// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Testing;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Moq;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Xunit;

namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{
    /// <summary>
    /// Functional tests that drive a running server over real sockets/HTTP clients and check
    /// request handling: large uploads, remote address reporting, connection resets, aborts
    /// and the connection/request lifetime tokens.
    /// </summary>
    public class RequestTests : LoggedTest
    {
        // Event ids the tests look for in the logger callbacks below.
        private const int _connectionStartedEventId = 1;
        private const int _connectionResetEventId = 19;

        // Longer wait while a debugger is attached so breakpoints do not fail the test.
        private static readonly int _semaphoreWaitTimeout = Debugger.IsAttached ? 10000 : 2500;

        /// <summary>
        /// Listen options used by the theories below: one plain loopback endpoint and one
        /// loopback endpoint with a <c>PassThroughConnectionAdapter</c> installed.
        /// </summary>
        public static TheoryData<ListenOptions> ConnectionAdapterData => new TheoryData<ListenOptions>
        {
            new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0)),
            new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))
            {
                ConnectionAdapters = { new PassThroughConnectionAdapter() }
            }
        };

        /// <summary>
        /// Posts <paramref name="contentLength"/> bytes over a raw socket in 1 MiB chunks and asserts
        /// that the response contains the number of bytes the application read.
        /// </summary>
        /// <param name="contentLength">Total body size; must be a multiple of the 1 MiB buffer.</param>
        /// <param name="checkBytes">When true, the app verifies every received byte against the expected pattern.</param>
        [Theory]
        [InlineData(10 * 1024 * 1024, true)]
        // In the following dataset, send at least 2GB.
        // Never change to a lower value, otherwise regression testing for
        // https://github.com/aspnet/KestrelHttpServer/issues/520#issuecomment-188591242
        // will be lost.
        [InlineData((long)int.MaxValue + 1, false)]
        public void LargeUpload(long contentLength, bool checkBytes)
        {
            const int bufferLength = 1024 * 1024;
            Assert.True(contentLength % bufferLength == 0, $"{nameof(contentLength)} sent must be evenly divisible by {bufferLength}.");
            Assert.True(bufferLength % 256 == 0, $"{nameof(bufferLength)} must be evenly divisible by 256");

            var builder = TransportSelector.GetWebHostBuilder()
                .ConfigureServices(AddTestLogging)
                .UseKestrel(options =>
                {
                    options.Limits.MaxRequestBodySize = contentLength;
                    options.Limits.MinRequestBodyDataRate = null;
                })
                .UseUrls("http://127.0.0.1:0/")
                .Configure(app =>
                {
                    app.Run(async context =>
                    {
                        // Read the full request body
                        long total = 0;
                        var receivedBytes = new byte[bufferLength];
                        var received = 0;
                        while ((received = await context.Request.Body.ReadAsync(receivedBytes, 0, receivedBytes.Length)) > 0)
                        {
                            if (checkBytes)
                            {
                                for (var i = 0; i < received; i++)
                                {
                                    // Do not use Assert.Equal here, it is too slow for this hot path
                                    Assert.True((byte)((total + i) % 256) == receivedBytes[i], "Data received is incorrect");
                                }
                            }

                            total += received;
                        }

                        await context.Response.WriteAsync(total.ToString(CultureInfo.InvariantCulture));
                    });
                });

            using (var host = builder.Build())
            {
                host.Start();

                using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
                {
                    socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                    socket.Send(Encoding.ASCII.GetBytes($"POST / HTTP/1.0\r\nContent-Length: {contentLength}\r\n\r\n"));

                    var contentBytes = new byte[bufferLength];

                    if (checkBytes)
                    {
                        // Repeating 0..255 pattern; the app checks it against (offset % 256).
                        for (var i = 0; i < contentBytes.Length; i++)
                        {
                            contentBytes[i] = (byte)i;
                        }
                    }

                    for (var i = 0; i < contentLength / contentBytes.Length; i++)
                    {
                        socket.Send(contentBytes);
                    }

                    // Read until the server closes the connection.
                    var response = new StringBuilder();
                    var responseBytes = new byte[4096];
                    var received = 0;
                    while ((received = socket.Receive(responseBytes)) > 0)
                    {
                        response.Append(Encoding.ASCII.GetString(responseBytes, 0, received));
                    }

                    Assert.Contains(contentLength.ToString(CultureInfo.InvariantCulture), response.ToString());
                }
            }
        }

        /// <summary>
        /// Runs <see cref="TestRemoteIPAddress"/> against the IPv4 loopback address.
        /// </summary>
        [Fact]
        public Task RemoteIPv4Address()
        {
            return TestRemoteIPAddress("127.0.0.1", "127.0.0.1", "127.0.0.1");
        }

        /// <summary>
        /// Runs <see cref="TestRemoteIPAddress"/> against the IPv6 loopback address.
        /// </summary>
        [ConditionalFact]
        [IPv6SupportedCondition]
        public Task RemoteIPv6Address()
        {
            return TestRemoteIPAddress("[::1]", "[::1]", "::1");
        }

        /// <summary>
        /// Sends a GET with a "Connection: close" header through <see cref="HttpClient"/> and
        /// asserts the request completes with a success status code.
        /// </summary>
        [Fact]
        public async Task DoesNotHangOnConnectionCloseRequest()
        {
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .UseUrls("http://127.0.0.1:0")
                .ConfigureServices(AddTestLogging)
                .Configure(app =>
                {
                    app.Run(async context =>
                    {
                        await context.Response.WriteAsync("hello, world");
                    });
                });

            using (var host = builder.Build())
            using (var client = new HttpClient())
            {
                host.Start();

                client.DefaultRequestHeaders.Connection.Clear();
                client.DefaultRequestHeaders.Connection.Add("close");

                var response = await client.GetAsync($"http://127.0.0.1:{host.GetPort()}/");
                response.EnsureSuccessStatusCode();
            }
        }

        /// <summary>
        /// Resets a connection before any request is sent and asserts that the connection-reset
        /// event is logged and that the mocked Kestrel loggers saw nothing above Debug level.
        /// </summary>
        [Fact]
        public async Task ConnectionResetPriorToRequestIsLoggedAsDebug()
        {
            var connectionStarted = new SemaphoreSlim(0);
            var connectionReset = new SemaphoreSlim(0);
            var loggedHigherThanDebug = false;

            var mockLogger = new Mock<ILogger>();
            mockLogger
                .Setup(logger => logger.IsEnabled(It.IsAny<LogLevel>()))
                .Returns(true);
            mockLogger
                .Setup(logger => logger.Log(It.IsAny<LogLevel>(), It.IsAny<EventId>(), It.IsAny<object>(), It.IsAny<Exception>(), It.IsAny<Func<object, Exception, string>>()))
                .Callback<LogLevel, EventId, object, Exception, Func<object, Exception, string>>((logLevel, eventId, state, exception, formatter) =>
                {
                    // Forward to the real test logger so the message still shows up in the test log.
                    Logger.Log(logLevel, eventId, state, exception, formatter);

                    if (eventId.Id == _connectionStartedEventId)
                    {
                        connectionStarted.Release();
                    }
                    else if (eventId.Id == _connectionResetEventId)
                    {
                        connectionReset.Release();
                    }

                    if (logLevel > LogLevel.Debug)
                    {
                        loggedHigherThanDebug = true;
                    }
                });

            // Only the Kestrel and transport categories get the intercepting logger.
            var mockLoggerFactory = new Mock<ILoggerFactory>();
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsAny<string>()))
                .Returns(Logger);
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")))
                .Returns(mockLogger.Object);

            using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(mockLoggerFactory.Object)))
            {
                using (var connection = server.CreateConnection())
                {
                    // Wait until connection is established
                    Assert.True(await connectionStarted.WaitAsync(TestConstants.DefaultTimeout));

                    connection.Reset();
                }

                // If the reset is correctly logged as Debug, the wait below should complete shortly.
                // This check MUST come before disposing the server, otherwise there's a race where the RST
                // is still in flight when the connection is aborted, leading to the reset never being received
                // and therefore not logged.
                Assert.True(await connectionReset.WaitAsync(TestConstants.DefaultTimeout));
            }

            Assert.False(loggedHigherThanDebug);
        }

        /// <summary>
        /// Completes one request/response exchange, then resets the connection and asserts that the
        /// connection-reset event is logged and that nothing above Debug level was logged.
        /// </summary>
        [Fact]
        public async Task ConnectionResetBetweenRequestsIsLoggedAsDebug()
        {
            var connectionReset = new SemaphoreSlim(0);
            var loggedHigherThanDebug = false;

            var mockLogger = new Mock<ILogger>();
            mockLogger
                .Setup(logger => logger.IsEnabled(It.IsAny<LogLevel>()))
                .Returns(true);
            mockLogger
                .Setup(logger => logger.Log(It.IsAny<LogLevel>(), It.IsAny<EventId>(), It.IsAny<object>(), It.IsAny<Exception>(), It.IsAny<Func<object, Exception, string>>()))
                .Callback<LogLevel, EventId, object, Exception, Func<object, Exception, string>>((logLevel, eventId, state, exception, formatter) =>
                {
                    // Forward to the real test logger so the message still shows up in the test log.
                    Logger.Log(logLevel, eventId, state, exception, formatter);

                    if (eventId.Id == _connectionResetEventId)
                    {
                        connectionReset.Release();
                    }

                    if (logLevel > LogLevel.Debug)
                    {
                        loggedHigherThanDebug = true;
                    }
                });

            // Only the Kestrel and transport categories get the intercepting logger.
            var mockLoggerFactory = new Mock<ILoggerFactory>();
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsAny<string>()))
                .Returns(Logger);
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")))
                .Returns(mockLogger.Object);

            using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(mockLoggerFactory.Object)))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "",
                        "");

                    // Make sure the response is fully received, so a write failure (e.g. EPIPE) doesn't cause
                    // a more critical log message.
                    await connection.Receive(
                        "HTTP/1.1 200 OK",
                        $"Date: {server.Context.DateHeaderValue}",
                        "Content-Length: 0",
                        "",
                        "");

                    connection.Reset();
                    // Force a reset
                }

                // If the reset is correctly logged as Debug, the wait below should complete shortly.
                // This check MUST come before disposing the server, otherwise there's a race where the RST
                // is still in flight when the connection is aborted, leading to the reset never being received
                // and therefore not logged.
                Assert.True(await connectionReset.WaitAsync(TestConstants.DefaultTimeout));
            }

            Assert.False(loggedHigherThanDebug);
        }

        /// <summary>
        /// Resets the connection while the application delegate is still running and asserts that the
        /// connection-reset event is logged and that nothing above Debug level was logged.
        /// </summary>
        [Fact]
        public async Task ConnectionResetMidRequestIsLoggedAsDebug()
        {
            var requestStarted = new SemaphoreSlim(0);
            var connectionReset = new SemaphoreSlim(0);
            var connectionClosing = new SemaphoreSlim(0);
            var loggedHigherThanDebug = false;

            var mockLogger = new Mock<ILogger>();
            mockLogger
                .Setup(logger => logger.IsEnabled(It.IsAny<LogLevel>()))
                .Returns(true);
            mockLogger
                .Setup(logger => logger.Log(It.IsAny<LogLevel>(), It.IsAny<EventId>(), It.IsAny<object>(), It.IsAny<Exception>(), It.IsAny<Func<object, Exception, string>>()))
                .Callback<LogLevel, EventId, object, Exception, Func<object, Exception, string>>((logLevel, eventId, state, exception, formatter) =>
                {
                    // Forward to the real test logger and also echo the entry to the xunit output.
                    Logger.Log(logLevel, eventId, state, exception, formatter);

                    var log = $"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception}";
                    TestOutputHelper.WriteLine(log);

                    if (eventId.Id == _connectionResetEventId)
                    {
                        connectionReset.Release();
                    }

                    if (logLevel > LogLevel.Debug)
                    {
                        loggedHigherThanDebug = true;
                    }
                });

            // Only the Kestrel and transport categories get the intercepting logger.
            var mockLoggerFactory = new Mock<ILoggerFactory>();
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsAny<string>()))
                .Returns(Logger);
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")))
                .Returns(mockLogger.Object);

            using (var server = new TestServer(async context =>
                {
                    // Keep the request in flight until the test has observed the reset.
                    requestStarted.Release();
                    await connectionClosing.WaitAsync();
                },
                new TestServiceContext(mockLoggerFactory.Object)))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.SendEmptyGet();

                    // Wait until the application delegate has started handling the request
                    Assert.True(await requestStarted.WaitAsync(TestConstants.DefaultTimeout), "request should have started");

                    connection.Reset();
                }

                // If the reset is correctly logged as Debug, the wait below should complete shortly.
                // This check MUST come before disposing the server, otherwise there's a race where the RST
                // is still in flight when the connection is aborted, leading to the reset never being received
                // and therefore not logged.
                Assert.True(await connectionReset.WaitAsync(TestConstants.DefaultTimeout), "Connection reset event should have been logged");
                connectionClosing.Release();
            }

            Assert.False(loggedHigherThanDebug, "Logged event should not have been higher than debug.");
        }

        /// <summary>
        /// Closes a zero-linger client socket while the app is running and asserts that a subsequent
        /// request body read in the app throws <c>ConnectionResetException</c>.
        /// </summary>
        [Fact]
        public async Task ThrowsOnReadAfterConnectionError()
        {
            var requestStarted = new SemaphoreSlim(0);
            var connectionReset = new SemaphoreSlim(0);
            var appDone = new SemaphoreSlim(0);
            var expectedExceptionThrown = false;

            var builder = TransportSelector.GetWebHostBuilder()
                .ConfigureServices(AddTestLogging)
                .UseKestrel()
                .UseUrls("http://127.0.0.1:0")
                .Configure(app => app.Run(async context =>
                {
                    requestStarted.Release();
                    Assert.True(await connectionReset.WaitAsync(_semaphoreWaitTimeout));

                    try
                    {
                        await context.Request.Body.ReadAsync(new byte[1], 0, 1);
                    }
                    catch (ConnectionResetException)
                    {
                        expectedExceptionThrown = true;
                    }

                    appDone.Release();
                }));

            using (var host = builder.Build())
            {
                host.Start();

                using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
                {
                    socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                    // Linger enabled with a zero timeout, set before the socket is disposed below.
                    socket.LingerState = new LingerOption(true, 0);
                    socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\nContent-Length: 1\r\n\r\n"));
                    Assert.True(await requestStarted.WaitAsync(_semaphoreWaitTimeout));
                }

                connectionReset.Release();

                Assert.True(await appDone.WaitAsync(_semaphoreWaitTimeout));
                Assert.True(expectedExceptionThrown);
            }
        }

        /// <summary>
        /// Shuts down the sending side of the client socket while the app is running and asserts that
        /// the callback registered on <c>HttpContext.RequestAborted</c> runs.
        /// </summary>
        [Fact]
        public async Task RequestAbortedTokenFiredOnClientFIN()
        {
            var appStarted = new SemaphoreSlim(0);
            var requestAborted = new SemaphoreSlim(0);
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .UseUrls("http://127.0.0.1:0")
                .ConfigureServices(AddTestLogging)
                .Configure(app => app.Run(async context =>
                {
                    appStarted.Release();

                    var token = context.RequestAborted;
                    // Release twice: once for the wait in the app, once for the wait in the test.
                    token.Register(() => requestAborted.Release(2));
                    await requestAborted.WaitAsync().DefaultTimeout();
                }));

            using (var host = builder.Build())
            {
                host.Start();

                using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
                {
                    socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                    socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"));
                    // Bounded wait so the test fails instead of hanging if the app never starts.
                    await appStarted.WaitAsync().DefaultTimeout();
                    socket.Shutdown(SocketShutdown.Send);
                    await requestAborted.WaitAsync().DefaultTimeout();
                }
            }
        }

        /// <summary>
        /// Has the app call <c>HttpContext.Abort()</c> and asserts that the client's next receive
        /// returns zero bytes.
        /// </summary>
        [Fact]
        public void AbortingTheConnectionSendsFIN()
        {
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .UseUrls("http://127.0.0.1:0")
                .ConfigureServices(AddTestLogging)
                .Configure(app => app.Run(context =>
                {
                    context.Abort();
                    return Task.CompletedTask;
                }));

            using (var host = builder.Build())
            {
                host.Start();

                using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp))
                {
                    socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort()));
                    socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\nHost:\r\n\r\n"));
                    int result = socket.Receive(new byte[32]);
                    Assert.Equal(0, result);
                }
            }
        }

        /// <summary>
        /// Asserts that the <c>ConnectionClosed</c> token of the connection lifetime feature fires
        /// after the client shuts down the sending side of its socket.
        /// </summary>
        /// <param name="listenOptions">Endpoint configuration supplied by <see cref="ConnectionAdapterData"/>.</param>
        [Theory]
        [MemberData(nameof(ConnectionAdapterData))]
        public async Task ConnectionClosedTokenFiresOnClientFIN(ListenOptions listenOptions)
        {
            var testContext = new TestServiceContext(LoggerFactory);
            var appStartedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var connectionClosedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            using (var server = new TestServer(context =>
            {
                appStartedTcs.SetResult(null);

                var connectionLifetimeFeature = context.Features.Get<IConnectionLifetimeFeature>();
                connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null));

                return Task.CompletedTask;
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "",
                        "");

                    await appStartedTcs.Task.DefaultTimeout();

                    connection.Shutdown(SocketShutdown.Send);

                    await connectionClosedTcs.Task.DefaultTimeout();
                }
            }
        }

        /// <summary>
        /// Asserts that the <c>ConnectionClosed</c> token fires when the client sends
        /// "Connection: close", and that the client then receives the complete response.
        /// </summary>
        /// <param name="listenOptions">Endpoint configuration supplied by <see cref="ConnectionAdapterData"/>.</param>
        [Theory]
        [MemberData(nameof(ConnectionAdapterData))]
        public async Task ConnectionClosedTokenFiresOnServerFIN(ListenOptions listenOptions)
        {
            var testContext = new TestServiceContext(LoggerFactory);
            var connectionClosedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            using (var server = new TestServer(context =>
            {
                var connectionLifetimeFeature = context.Features.Get<IConnectionLifetimeFeature>();
                connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null));

                return Task.CompletedTask;
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "Connection: close",
                        "",
                        "");

                    await connectionClosedTcs.Task.DefaultTimeout();

                    await connection.ReceiveEnd(
                        "HTTP/1.1 200 OK",
                        "Connection: close",
                        $"Date: {server.Context.DateHeaderValue}",
                        "Content-Length: 0",
                        "",
                        "");
                }
            }
        }

        /// <summary>
        /// Asserts that the <c>ConnectionClosed</c> token fires when the app calls
        /// <c>HttpContext.Abort()</c>, and that the client then observes a forced end.
        /// </summary>
        /// <param name="listenOptions">Endpoint configuration supplied by <see cref="ConnectionAdapterData"/>.</param>
        [Theory]
        [MemberData(nameof(ConnectionAdapterData))]
        public async Task ConnectionClosedTokenFiresOnServerAbort(ListenOptions listenOptions)
        {
            var testContext = new TestServiceContext(LoggerFactory);
            var connectionClosedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            using (var server = new TestServer(context =>
            {
                var connectionLifetimeFeature = context.Features.Get<IConnectionLifetimeFeature>();
                connectionLifetimeFeature.ConnectionClosed.Register(() => connectionClosedTcs.SetResult(null));

                context.Abort();

                return Task.CompletedTask;
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "GET / HTTP/1.1",
                        "Host:",
                        "",
                        "");

                    await connectionClosedTcs.Task.DefaultTimeout();

                    await connection.ReceiveForcedEnd();
                }
            }
        }

        /// <summary>
        /// Serves one complete request, then aborts a second request from the app while its body read
        /// is pending. Asserts that the read fails, that only the second request's abort registration
        /// fired, and that exactly one "application aborted connection" event (id 34) was logged.
        /// </summary>
        /// <param name="listenOptions">Endpoint configuration supplied by <see cref="ConnectionAdapterData"/>.</param>
        [Theory]
        [MemberData(nameof(ConnectionAdapterData))]
        public async Task RequestsCanBeAbortedMidRead(ListenOptions listenOptions)
        {
            const int applicationAbortedConnectionId = 34;

            var testContext = new TestServiceContext(LoggerFactory);

            var readTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var registrationTcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
            var requestId = 0;

            using (var server = new TestServer(async httpContext =>
            {
                requestId++;

                var response = httpContext.Response;
                var request = httpContext.Request;
                var lifetime = httpContext.Features.Get<IHttpRequestLifetimeFeature>();

                lifetime.RequestAborted.Register(() => registrationTcs.TrySetResult(requestId));

                if (requestId == 1)
                {
                    response.Headers["Content-Length"] = new[] { "5" };

                    await response.WriteAsync("World");
                }
                else
                {
                    // Start the read first, then abort, so the abort lands on a pending read.
                    var readTask = request.Body.CopyToAsync(Stream.Null);

                    lifetime.Abort();

                    try
                    {
                        await readTask;
                    }
                    catch (Exception ex)
                    {
                        readTcs.SetException(ex);
                        throw;
                    }

                    readTcs.SetException(new Exception("This shouldn't be reached."));
                }
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    // Full request and response
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        "Content-Length: 5",
                        "",
                        "Hello");

                    await connection.Receive(
                        "HTTP/1.1 200 OK",
                        $"Date: {testContext.DateHeaderValue}",
                        "Content-Length: 5",
                        "",
                        "World");

                    // Never send the body so CopyToAsync always fails.
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        "Content-Length: 5",
                        "",
                        "");

                    await connection.WaitForConnectionClose();
                }
            }

            // NOTE(review): the exception type argument was missing from the reviewed text;
            // TaskCanceledException is assumed here - confirm against the project history.
            await Assert.ThrowsAsync<TaskCanceledException>(async () => await readTcs.Task);

            // The cancellation token for only the last request should be triggered.
            var abortedRequestId = await registrationTcs.Task;
            Assert.Equal(2, abortedRequestId);

            Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel" &&
                                                     w.EventId == applicationAbortedConnectionId));
        }

        /// <summary>
        /// Closes the client connection after the transport logs the "connection paused" event (id 4),
        /// then has the app abort the connection and wait for the "FIN sent" event (id 7). Asserts the
        /// app completes and that <c>ConnectionStop</c> is traced exactly once.
        /// </summary>
        /// <param name="listenOptions">Endpoint configuration supplied by <see cref="ConnectionAdapterData"/>.</param>
        [Theory]
        [MemberData(nameof(ConnectionAdapterData))]
        public async Task ServerCanAbortConnectionAfterUnobservedClose(ListenOptions listenOptions)
        {
            const int connectionPausedEventId = 4;
            const int connectionFinSentEventId = 7;
            const int maxRequestBufferSize = 4096;

            var readCallbackUnwired = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var clientClosedConnection = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var serverClosedConnection = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var appFuncCompleted = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            var mockLogger = new Mock<ILogger>();
            mockLogger
                .Setup(logger => logger.IsEnabled(It.IsAny<LogLevel>()))
                .Returns(true);
            mockLogger
                .Setup(logger => logger.Log(It.IsAny<LogLevel>(), It.IsAny<EventId>(), It.IsAny<object>(), It.IsAny<Exception>(), It.IsAny<Func<object, Exception, string>>()))
                .Callback<LogLevel, EventId, object, Exception, Func<object, Exception, string>>((logLevel, eventId, state, exception, formatter) =>
                {
                    // TrySetResult for both: a repeated event must not throw inside the logging callback.
                    if (eventId.Id == connectionPausedEventId)
                    {
                        readCallbackUnwired.TrySetResult(null);
                    }
                    else if (eventId.Id == connectionFinSentEventId)
                    {
                        serverClosedConnection.TrySetResult(null);
                    }

                    Logger.Log(logLevel, eventId, state, exception, formatter);
                });

            // Only the transport categories get the intercepting logger.
            var mockLoggerFactory = new Mock<ILoggerFactory>();
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsAny<string>()))
                .Returns(Logger);
            mockLoggerFactory
                .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv",
                                                               "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")))
                .Returns(mockLogger.Object);

            // NOTE(review): the mocked type argument was missing from the reviewed text; KestrelTrace is
            // assumed (constructed from a logger, exposes ConnectionStop, assignable to Log) - confirm.
            var mockKestrelTrace = new Mock<KestrelTrace>(Logger) { CallBase = true };
            var testContext = new TestServiceContext(mockLoggerFactory.Object)
            {
                Log = mockKestrelTrace.Object,
                ServerOptions =
                {
                    Limits =
                    {
                        MaxRequestBufferSize = maxRequestBufferSize,
                        MaxRequestLineSize = maxRequestBufferSize,
                        MaxRequestHeadersTotalSize = maxRequestBufferSize,
                    }
                }
            };

            // Body is eight times the configured request buffer limit.
            var scratchBuffer = new byte[maxRequestBufferSize * 8];

            using (var server = new TestServer(async context =>
            {
                await clientClosedConnection.Task;

                context.Abort();

                await serverClosedConnection.Task;

                appFuncCompleted.SetResult(null);
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        $"Content-Length: {scratchBuffer.Length}",
                        "",
                        "");

                    // Deliberately not awaited; the write task is left unobserved.
                    var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);

                    // Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed.
                    await readCallbackUnwired.Task.DefaultTimeout();
                }

                clientClosedConnection.SetResult(null);

                await appFuncCompleted.Task.DefaultTimeout();
            }

            mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
        }

        /// <summary>
        /// Sends half of a request body and then resets the connection. Asserts that the app's body
        /// read fails and that <c>ConnectionStop</c> is traced exactly once.
        /// </summary>
        /// <param name="listenOptions">Endpoint configuration supplied by <see cref="ConnectionAdapterData"/>.</param>
        [Theory]
        [MemberData(nameof(ConnectionAdapterData))]
        public async Task AppCanHandleClientAbortingConnectionMidRequest(ListenOptions listenOptions)
        {
            var readTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            var appStartedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

            // NOTE(review): the mocked type argument was missing from the reviewed text; KestrelTrace is
            // assumed (constructed from a logger, exposes ConnectionStop, assignable to Log) - confirm.
            var mockKestrelTrace = new Mock<KestrelTrace>(Logger) { CallBase = true };
            var testContext = new TestServiceContext()
            {
                Log = mockKestrelTrace.Object,
            };

            var scratchBuffer = new byte[4096];

            using (var server = new TestServer(async context =>
            {
                appStartedTcs.SetResult(null);

                try
                {
                    await context.Request.Body.CopyToAsync(Stream.Null);
                }
                catch (Exception ex)
                {
                    readTcs.SetException(ex);
                    throw;
                }

                readTcs.SetException(new Exception("This shouldn't be reached."));
            }, testContext, listenOptions))
            {
                using (var connection = server.CreateConnection())
                {
                    // Declare twice as many bytes as are actually written before the reset.
                    await connection.Send(
                        "POST / HTTP/1.1",
                        "Host:",
                        $"Content-Length: {scratchBuffer.Length * 2}",
                        "",
                        "");

                    await appStartedTcs.Task.DefaultTimeout();

                    await connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);

                    connection.Reset();
                }

                // NOTE(review): the exception type argument was missing from the reviewed text;
                // IOException (or any derived type) is assumed here - confirm.
                await Assert.ThrowsAnyAsync<IOException>(() => readTcs.Task).DefaultTimeout();
            }

            mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
        }

        /// <summary>
        /// Starts a host on <paramref name="registerAddress"/>, requests it through
        /// <paramref name="requestAddress"/> and asserts that the JSON the app returns reports
        /// <paramref name="expectAddress"/> as the remote IP address and a non-empty remote port.
        /// </summary>
        /// <param name="registerAddress">Host part of the URL the server listens on.</param>
        /// <param name="requestAddress">Host part of the URL the client requests.</param>
        /// <param name="expectAddress">Expected string form of the remote IP address seen by the app.</param>
        private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress)
        {
            var builder = TransportSelector.GetWebHostBuilder()
                .UseKestrel()
                .UseUrls($"http://{registerAddress}:0")
                .ConfigureServices(AddTestLogging)
                .Configure(app =>
                {
                    app.Run(async context =>
                    {
                        var connection = context.Connection;
                        await context.Response.WriteAsync(JsonConvert.SerializeObject(new
                        {
                            RemoteIPAddress = connection.RemoteIpAddress?.ToString(),
                            RemotePort = connection.RemotePort,
                            LocalIPAddress = connection.LocalIpAddress?.ToString(),
                            LocalPort = connection.LocalPort
                        }));
                    });
                });

            using (var host = builder.Build())
            using (var client = new HttpClient())
            {
                host.Start();

                var response = await client.GetAsync($"http://{requestAddress}:{host.GetPort()}/");
                response.EnsureSuccessStatusCode();

                var connectionFacts = await response.Content.ReadAsStringAsync();
                Assert.NotEmpty(connectionFacts);

                var facts = JsonConvert.DeserializeObject<JObject>(connectionFacts);
                Assert.Equal(expectAddress, facts["RemoteIPAddress"].Value<string>());
                Assert.NotEmpty(facts["RemotePort"].Value<string>());
            }
        }
    }
}