aspnetcore/test/Kestrel.FunctionalTests/ConnectionAdapterTests.cs

357 lines
12 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{
public class ConnectionAdapterTests : LoggedTest
{
/// <summary>
/// Sends a POST through a server configured with a <see cref="RewritingConnectionAdapter"/> and
/// checks both directions of the adapted stream: the response body must arrive with "?" replaced
/// by "!", and the adapter must report having read as many bytes as the request contains.
/// </summary>
[Fact]
public async Task CanReadAndWriteWithRewritingConnectionAdapter()
{
    const string request = "POST / HTTP/1.0\r\nContent-Length: 12\r\n\r\nHello World?";

    var rewritingAdapter = new RewritingConnectionAdapter();
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(rewritingAdapter);
    var testContext = new TestServiceContext(LoggerFactory);

    using (var server = new TestServer(TestApp.EchoApp, testContext, options))
    using (var client = server.CreateConnection())
    {
        await client.Send(request);

        // The trailing "?" of the request body is expected back as "!".
        await client.ReceiveEnd(
            "HTTP/1.1 200 OK",
            "Connection: close",
            $"Date: {testContext.DateHeaderValue}",
            "",
            "Hello World!");
    }

    // The request is pure ASCII, so its character count equals the number of bytes on the wire.
    Assert.Equal(request.Length, rewritingAdapter.BytesRead);
}
/// <summary>
/// Performs a POST round trip through a server that uses an <see cref="AsyncConnectionAdapter"/>
/// and expects the response body to come back with "?" replaced by "!".
/// </summary>
[Fact]
public async Task CanReadAndWriteWithAsyncConnectionAdapter()
{
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(new AsyncConnectionAdapter());
    var testContext = new TestServiceContext(LoggerFactory);

    using (var server = new TestServer(TestApp.EchoApp, testContext, options))
    using (var client = server.CreateConnection())
    {
        // Request: headers, blank separator line, then a 12-byte body.
        await client.Send(
            "POST / HTTP/1.0",
            "Content-Length: 12",
            "",
            "Hello World?");

        await client.ReceiveEnd(
            "HTTP/1.1 200 OK",
            "Connection: close",
            $"Date: {testContext.DateHeaderValue}",
            "",
            "Hello World!");
    }
}
/// <summary>
/// Connects to a server using an <see cref="AsyncConnectionAdapter"/>, shuts down the client's
/// send side straight away without sending a request, and waits for the connection to close.
/// </summary>
[Fact]
public async Task ImmediateFinAfterOnConnectionAsyncClosesGracefully()
{
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(new AsyncConnectionAdapter());
    var testContext = new TestServiceContext(LoggerFactory);

    using (var server = new TestServer(TestApp.EchoApp, testContext, options))
    using (var client = server.CreateConnection())
    {
        // Half-close from the client side (FIN) before any data is sent.
        client.Shutdown(SocketShutdown.Send);

        await client.WaitForConnectionClose();
    }
}
/// <summary>
/// Connects to a server whose only adapter is a <see cref="ThrowingConnectionAdapter"/>, shuts
/// down the client's send side straight away, and waits for the connection to close.
/// </summary>
[Fact]
public async Task ImmediateFinAfterThrowingClosesGracefully()
{
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(new ThrowingConnectionAdapter());
    var testContext = new TestServiceContext(LoggerFactory);

    using (var server = new TestServer(TestApp.EchoApp, testContext, options))
    using (var client = server.CreateConnection())
    {
        // Half-close from the client side (FIN) before any data is sent.
        client.Shutdown(SocketShutdown.Send);

        await client.WaitForConnectionClose();
    }
}
/// <summary>
/// Starts stopping the server while a freshly created client connection is still open, disposes
/// the connection, and then awaits the stop. The test passes when none of this throws.
/// </summary>
[Fact]
public async Task ImmediateShutdownAfterOnConnectionAsyncDoesNotCrash()
{
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(new AsyncConnectionAdapter());
    var testContext = new TestServiceContext(LoggerFactory);

    using (var server = new TestServer(TestApp.EchoApp, testContext, options))
    {
        Task pendingStop;

        using (var client = server.CreateConnection())
        {
            // Kick off the shutdown while the connection is still open; do not await it yet.
            pendingStop = server.StopAsync();
        }

        // The client connection has been disposed by now; let the shutdown finish.
        await pendingStop;
    }
}
/// <summary>
/// Uses a <see cref="ThrowingConnectionAdapter"/> and asserts that writing a request to the
/// connection fails on the client with an <see cref="IOException"/>.
/// </summary>
[Fact]
public async Task ThrowingSynchronousConnectionAdapterDoesNotCrashServer()
{
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(new ThrowingConnectionAdapter());
    var testContext = new TestServiceContext(LoggerFactory);

    using (var server = new TestServer(TestApp.EchoApp, testContext, options))
    using (var client = server.CreateConnection())
    {
        // Sends the headers, then trickles the body one byte at a time with a short pause
        // between writes.
        Func<Task> sendRequest = async () =>
        {
            await client.Send(
                "POST / HTTP/1.0",
                "Content-Length: 1000",
                "\r\n");

            var sent = 0;
            while (sent < 1000)
            {
                await client.Send("a");
                await Task.Delay(5);
                sent++;
            }
        };

        // Will throw because the exception in the connection adapter will close the connection.
        await Assert.ThrowsAsync<IOException>(sendRequest);
    }
}
/// <summary>
/// Runs an application that writes "Hello ", flushes the response body, then writes "World!"
/// behind a <see cref="PassThroughConnectionAdapter"/>, and expects the client to receive the
/// complete "Hello World!" response.
/// </summary>
[Fact]
public async Task CanFlushAsyncWithConnectionAdapter()
{
    var options = new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0));
    options.ConnectionAdapters.Add(new PassThroughConnectionAdapter());
    var testContext = new TestServiceContext(LoggerFactory);

    // Application under test: two writes separated by an explicit flush.
    RequestDelegate app = async httpContext =>
    {
        var response = httpContext.Response;
        await response.WriteAsync("Hello ");
        await response.Body.FlushAsync();
        await response.WriteAsync("World!");
    };

    using (var server = new TestServer(app, testContext, options))
    using (var client = server.CreateConnection())
    {
        await client.Send(
            "GET / HTTP/1.0",
            "",
            "");

        await client.ReceiveEnd(
            "HTTP/1.1 200 OK",
            "Connection: close",
            $"Date: {testContext.DateHeaderValue}",
            "",
            "Hello World!");
    }
}
/// <summary>
/// Connection adapter that wraps each connection's stream in a <see cref="RewritingStream"/>
/// and exposes the byte count of the most recently adapted connection.
/// </summary>
private class RewritingConnectionAdapter : IConnectionAdapter
{
    // Stream created by the latest OnConnectionAsync call; null until a connection has been adapted.
    private RewritingStream _rewritingStream;

    public bool IsHttps => false;

    /// <summary>
    /// Number of bytes read through the most recently created <see cref="RewritingStream"/>,
    /// or 0 when <see cref="OnConnectionAsync"/> has not been called yet.
    /// </summary>
    /// <remarks>
    /// Returning 0 instead of dereferencing a null stream lets an assertion on this value fail
    /// with an expected/actual message rather than a NullReferenceException.
    /// </remarks>
    public int BytesRead => _rewritingStream?.BytesRead ?? 0;

    /// <summary>
    /// Wraps the connection stream from <paramref name="context"/> in a new
    /// <see cref="RewritingStream"/> and returns it as an already-completed task.
    /// </summary>
    /// <param name="context">Supplies the connection stream to wrap.</param>
    /// <returns>A completed task holding the adapted connection.</returns>
    public Task<IAdaptedConnection> OnConnectionAsync(ConnectionAdapterContext context)
    {
        _rewritingStream = new RewritingStream(context.ConnectionStream);
        return Task.FromResult<IAdaptedConnection>(new AdaptedConnection(_rewritingStream));
    }
}
/// <summary>
/// Connection adapter whose <see cref="OnConnectionAsync"/> completes asynchronously: it waits
/// 100 milliseconds and then returns a <see cref="RewritingStream"/> over the connection stream.
/// </summary>
private class AsyncConnectionAdapter : IConnectionAdapter
{
    // Pause, in milliseconds, before the adapted connection is produced.
    private const int DelayMilliseconds = 100;

    public bool IsHttps
    {
        get { return false; }
    }

    public async Task<IAdaptedConnection> OnConnectionAsync(ConnectionAdapterContext context)
    {
        await Task.Delay(DelayMilliseconds);

        // The wrapper is only created once the delay has elapsed.
        var rewritingStream = new RewritingStream(context.ConnectionStream);
        return new AdaptedConnection(rewritingStream);
    }
}
/// <summary>
/// Connection adapter whose <see cref="OnConnectionAsync"/> always throws and never returns
/// an adapted connection.
/// </summary>
private class ThrowingConnectionAdapter : IConnectionAdapter
{
    public bool IsHttps
    {
        get { return false; }
    }

    public Task<IAdaptedConnection> OnConnectionAsync(ConnectionAdapterContext context)
    {
        // The method is not async, so the exception escapes the call itself instead of being
        // captured in a faulted task.
        var failure = new Exception();
        throw failure;
    }
}
/// <summary>
/// Minimal <see cref="IAdaptedConnection"/> that simply hands back the stream it was created with.
/// </summary>
private class AdaptedConnection : IAdaptedConnection
{
    private readonly Stream _adaptedStream;

    public AdaptedConnection(Stream adaptedStream)
    {
        _adaptedStream = adaptedStream;
    }

    /// <summary>The stream passed to the constructor.</summary>
    public Stream ConnectionStream => _adaptedStream;

    public void Dispose()
    {
        // No-op: this wrapper does not dispose the stream it was given.
    }
}
/// <summary>
/// Stream decorator used by the test adapters. It counts the bytes read from the inner stream
/// and replaces every '?' byte with '!' in data written through it. Every other member simply
/// forwards to the inner stream.
/// </summary>
private class RewritingStream : Stream
{
    private readonly Stream _innerStream;

    /// <param name="innerStream">The stream all operations are forwarded to.</param>
    public RewritingStream(Stream innerStream)
    {
        _innerStream = innerStream;
    }

    /// <summary>Total number of bytes returned by <c>Read</c> and <c>ReadAsync</c> so far.</summary>
    public int BytesRead { get; private set; }

    public override bool CanRead => _innerStream.CanRead;

    public override bool CanSeek => _innerStream.CanSeek;

    public override bool CanWrite => _innerStream.CanWrite;

    public override long Length => _innerStream.Length;

    public override long Position
    {
        get
        {
            return _innerStream.Position;
        }
        set
        {
            _innerStream.Position = value;
        }
    }

    public override void Flush()
    {
        _innerStream.Flush();
    }

    public override Task FlushAsync(CancellationToken cancellationToken)
    {
        return _innerStream.FlushAsync(cancellationToken);
    }

    /// <summary>Reads from the inner stream and adds the number of bytes read to <see cref="BytesRead"/>.</summary>
    public override int Read(byte[] buffer, int offset, int count)
    {
        var actual = _innerStream.Read(buffer, offset, count);
        BytesRead += actual;
        return actual;
    }

    /// <summary>
    /// Asynchronously reads from the inner stream and adds the number of bytes read to
    /// <see cref="BytesRead"/>.
    /// </summary>
    public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        // Forward the token so that cancellation requested by the caller reaches the inner stream.
        var actual = await _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
        BytesRead += actual;
        return actual;
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        return _innerStream.Seek(offset, origin);
    }

    public override void SetLength(long value)
    {
        _innerStream.SetLength(value);
    }

    /// <summary>
    /// Replaces '?' with '!' inside the written range of <paramref name="buffer"/> (in place)
    /// and then writes that range to the inner stream.
    /// </summary>
    public override void Write(byte[] buffer, int offset, int count)
    {
        ReplaceQuestionMarks(buffer, offset, count);
        _innerStream.Write(buffer, offset, count);
    }

    /// <summary>
    /// Replaces '?' with '!' inside the written range of <paramref name="buffer"/> (in place)
    /// and then asynchronously writes that range to the inner stream.
    /// </summary>
    public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        ReplaceQuestionMarks(buffer, offset, count);
        return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
    }

    /// <summary>
    /// Rewrites '?' to '!' only within [offset, offset + count), i.e. the bytes that are
    /// actually being written; the rest of the caller's buffer is left untouched.
    /// </summary>
    private static void ReplaceQuestionMarks(byte[] buffer, int offset, int count)
    {
        // Clamp to the buffer bounds so that an invalid offset/count is still reported by the
        // inner stream's own argument validation rather than by an index error here.
        // The upper bound is computed as a long so that offset + count cannot overflow.
        var start = Math.Max(offset, 0);
        long end = Math.Min((long)offset + count, buffer.Length);

        for (var i = start; i < end; i++)
        {
            if (buffer[i] == '?')
            {
                buffer[i] = (byte)'!';
            }
        }
    }
}
}
}