diff --git a/KestrelHttpServer.sln b/KestrelHttpServer.sln index 4280909874..fd5f9367d1 100644 --- a/KestrelHttpServer.sln +++ b/KestrelHttpServer.sln @@ -42,6 +42,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "shared", "shared", "{0EF2AC test\shared\MockSocketOutput.cs = test\shared\MockSocketOutput.cs test\shared\MockSystemClock.cs = test\shared\MockSystemClock.cs test\shared\SocketInputExtensions.cs = test\shared\SocketInputExtensions.cs + test\shared\TestApp.cs = test\shared\TestApp.cs test\shared\TestApplicationErrorLogger.cs = test\shared\TestApplicationErrorLogger.cs test\shared\TestConnection.cs = test\shared\TestConnection.cs test\shared\TestFrame.cs = test\shared\TestFrame.cs diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs index eb9fae5648..1b7eaa8e97 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs @@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Filter; using Microsoft.AspNetCore.Server.Kestrel.Filter.Internal; using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking; +using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http @@ -200,7 +201,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _frame.Input = _filteredStreamAdapter.SocketInput; _frame.Output = _filteredStreamAdapter.SocketOutput; - _readInputTask = _filteredStreamAdapter.ReadInputAsync(); + // Don't attempt to read input if connection has already closed. + // This can happen if a client opens a connection and immediately closes it. + _readInputTask = _socketClosedTcs.Task.Status == TaskStatus.WaitingForActivation ? + _filteredStreamAdapter.ReadInputAsync() : + TaskCache.CompletedTask; } _frame.PrepareRequest = _filterContext.PrepareRequest; @@ -278,7 +283,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http Input.IncomingComplete(readCount, error); - if (errorDone) + if (!normalRead) { Abort(error); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index 0b37975bc3..2cb16d13bb 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -564,6 +564,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } else { + CheckLastWrite(); Output.Write(data); } } @@ -599,6 +600,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } else { + CheckLastWrite(); return Output.WriteAsync(data, cancellationToken: cancellationToken); } } @@ -631,6 +633,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } else { + CheckLastWrite(); await Output.WriteAsync(data, cancellationToken: cancellationToken); } } @@ -658,6 +661,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _responseBytesWritten += count; } + private void CheckLastWrite() + { + var responseHeaders = FrameResponseHeaders; + + // Prevent firing request aborted token if this is the last write, to avoid + // aborting the request if the app is still running when the client receives + // the final bytes of the response and gracefully closes the connection. + // + // Called after VerifyAndUpdateWrite(), so _responseBytesWritten has already been updated. + if (responseHeaders != null && + !responseHeaders.HasTransferEncoding && + responseHeaders.HasContentLength && + _responseBytesWritten == responseHeaders.HeaderContentLengthValue.Value) + { + _abortedCts = null; + } + } + protected void VerifyResponseContentLength() { var responseHeaders = FrameResponseHeaders; @@ -838,6 +859,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http private async Task WriteAutoChunkSuffixAwaited() { + // For the same reason we call CheckLastWrite() in Content-Length responses. + _abortedCts = null; + await WriteChunkedResponseSuffix(); if (_keepAlive) diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs index d78f60a7cd..4ffab99136 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestBufferSizeTests.cs @@ -103,7 +103,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } Assert.Equal(data.Length, bytesWritten); - socket.Shutdown(SocketShutdown.Send); clientFinishedSendingRequestBody.Set(); }; diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestLineSizeTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestLineSizeTests.cs index 1118f28d07..1f6e95e767 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestLineSizeTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/MaxRequestLineSizeTests.cs @@ -25,13 +25,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests [InlineData("DELETE /a%20b%20c/d%20e?f=ghi HTTP/1.1\r\n\r\n", 1027)] public async Task ServerAcceptsRequestLineWithinLimit(string request, int limit) { - var maxRequestLineSize = limit; - using (var server = CreateServer(limit)) { using (var connection = new TestConnection(server.Port)) { - await connection.SendEnd(request); + await connection.Send(request); await connection.ReceiveEnd( "HTTP/1.1 200 OK", $"Date: {server.Context.DateHeaderValue}", @@ -57,8 +55,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { using (var connection = new TestConnection(server.Port)) { - await connection.SendAllTryEnd($"{requestLine}\r\n"); - await connection.Receive( + await connection.SendAll(requestLine); + await connection.ReceiveForcedEnd( "HTTP/1.1 414 URI Too Long", "Connection: close", $"Date: {server.Context.DateHeaderValue}", diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs index 712407f76c..7ce1faf4bc 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs @@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { using (var connection = new TestConnection(server.Port)) { - await connection.SendEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.Send($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.ReceiveEnd( "HTTP/1.1 200 OK", $"Date: {server.Context.DateHeaderValue}", @@ -60,7 +60,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { using (var connection = new TestConnection(server.Port)) { - await connection.SendEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.Send($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.ReceiveEnd( "HTTP/1.1 200 OK", $"Date: {server.Context.DateHeaderValue}", @@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { using (var connection = new TestConnection(server.Port)) { - await connection.SendAllTryEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.SendAll($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.ReceiveForcedEnd( "HTTP/1.1 431 Request Header Fields Too Large", "Connection: close", @@ -110,7 +110,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { using (var connection = new TestConnection(server.Port)) { - await connection.SendAllTryEnd($"GET / HTTP/1.1\r\n{headers}\r\n"); + await connection.SendAll($"GET / HTTP/1.1\r\n{headers}\r\n"); await connection.ReceiveForcedEnd( "HTTP/1.1 431 Request Header Fields Too Large", "Connection: close", diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs index 645f7c5205..bc1362023b 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs @@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Internal.Networking; +using Microsoft.AspNetCore.Testing; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; @@ -420,6 +421,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task RequestAbortedTokenFiredOnClientFIN() + { + var appStarted = new SemaphoreSlim(0); + var requestAborted = new SemaphoreSlim(0); + var builder = new WebHostBuilder() + .UseKestrel() + .UseUrls($"http://127.0.0.1:0") + .Configure(app => app.Run(async context => + { + appStarted.Release(); + + var token = context.RequestAborted; + token.Register(() => requestAborted.Release(2)); + await requestAborted.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10)); + })); + + using (var host = builder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + socket.Send(Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\n\r\n")); + await appStarted.WaitAsync(); + socket.Shutdown(SocketShutdown.Send); + await requestAborted.WaitAsync().TimeoutAfter(TimeSpan.FromSeconds(10)); + } + } + } + private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress) { var builder = new WebHostBuilder() diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index dbf6f5bdd4..49bdc6e99c 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -5,7 +5,6 @@ using System; using System.Linq; using System.Net; using System.Net.Http; -using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -288,26 +287,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests disposedTcs.TrySetResult(c.Response.StatusCode); }); - var hostBuilder = new WebHostBuilder() - .UseKestrel() - .UseUrls("http://127.0.0.1:0") - .ConfigureServices(services => services.AddSingleton(mockHttpContextFactory.Object)) - .Configure(app => - { - app.Run(handler); - }); - - using (var host = hostBuilder.Build()) + using (var server = new TestServer(handler, new TestServiceContext(), "http://127.0.0.1:0", mockHttpContextFactory.Object)) { - host.Start(); - if (!sendMalformedRequest) { using (var client = new HttpClient()) { try { - var response = await client.GetAsync($"http://localhost:{host.GetPort()}/"); + var response = await client.GetAsync($"http://127.0.0.1:{server.Port}/"); Assert.Equal(expectedClientStatusCode, response.StatusCode); } catch @@ -321,14 +309,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } else { - using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + using (var connection = new TestConnection(server.Port)) { - socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); - socket.Send(Encoding.ASCII.GetBytes( - "POST / HTTP/1.1\r\n" + - "Transfer-Encoding: chunked\r\n" + - "\r\n" + - "wrong")); + await connection.Send( + "POST / HTTP/1.1", + "Transfer-Encoding: chunked", + "", + "gg"); + await connection.ReceiveForcedEnd( + "HTTP/1.1 400 Bad Request", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); } } @@ -453,13 +447,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests [Fact] public async Task ResponseBodyNotWrittenOnHeadResponseAndLoggedOnlyOnce() { - var mockKestrelTrace = new Mock(); + const string response = "hello, world"; - using (var server = new TestServer(async httpContext => - { - await httpContext.Response.WriteAsync("hello, world"); - await httpContext.Response.Body.FlushAsync(); - }, new TestServiceContext { Log = mockKestrelTrace.Object })) + var logTcs = new TaskCompletionSource(); + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.ConnectionHeadResponseBodyWrite(It.IsAny(), response.Length)) + .Callback((connectionId, count) => logTcs.SetResult(null)); + + using (var server = new TestServer(async httpContext => + { + await httpContext.Response.WriteAsync(response); + await httpContext.Response.Body.FlushAsync(); + }, new TestServiceContext { Log = mockKestrelTrace.Object })) { using (var connection = server.CreateConnection()) { @@ -472,11 +472,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests $"Date: {server.Context.DateHeaderValue}", "", ""); + + // Wait for message to be logged before disposing the socket. + // Disposing the socket will abort the connection and Frame._requestAborted + // might be 1 by the time ProduceEnd() gets called and the message is logged. + await logTcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); } } mockKestrelTrace.Verify(kestrelTrace => - kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), "hello, world".Length), Times.Once); + kestrelTrace.ConnectionHeadResponseBodyWrite(It.IsAny(), response.Length), Times.Once); } [Fact] @@ -533,7 +538,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( $"HTTP/1.1 200 OK", $"Date: {server.Context.DateHeaderValue}", "Content-Length: 11", @@ -566,7 +571,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( $"HTTP/1.1 500 Internal Server Error", "Connection: close", $"Date: {server.Context.DateHeaderValue}", @@ -633,7 +638,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( $"HTTP/1.1 500 Internal Server Error", "Connection: close", $"Date: {server.Context.DateHeaderValue}", @@ -858,7 +863,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -880,7 +885,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests public async Task AppCanWriteOwnBadRequestResponse() { var expectedResponse = string.Empty; - var responseWrittenTcs = new TaskCompletionSource(); + var responseWritten = new SemaphoreSlim(0); using (var server = new TestServer(async httpContext => { @@ -894,18 +899,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests httpContext.Response.StatusCode = 400; httpContext.Response.ContentLength = ex.Message.Length; await httpContext.Response.WriteAsync(ex.Message); - responseWrittenTcs.SetResult(null); + responseWritten.Release(); } }, new TestServiceContext())) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.1", "Transfer-Encoding: chunked", "", - "bad"); - await responseWrittenTcs.Task; + "wrong"); + await responseWritten.WaitAsync(); await connection.ReceiveEnd( "HTTP/1.1 400 Bad Request", $"Date: {server.Context.DateHeaderValue}", @@ -935,7 +940,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", "Connection: close", $"Date: {server.Context.DateHeaderValue}", @@ -951,7 +956,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "Connection: keep-alive", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", "Connection: close", $"Date: {server.Context.DateHeaderValue}", @@ -982,7 +987,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", "Connection: keep-alive", $"Date: {server.Context.DateHeaderValue}", @@ -998,7 +1003,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "Connection: keep-alive", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", "Connection: keep-alive", $"Date: {server.Context.DateHeaderValue}", @@ -1036,7 +1041,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "hello, world"); // Make sure connection was kept open - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs index 7b8b6a0114..b211578da5 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd(request); + await connection.SendAll(request); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); } } @@ -88,15 +88,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd(request); + await connection.SendAll(request); await ReceiveBadRequestResponse(connection, "505 HTTP Version Not Supported", server.Context.DateHeaderValue); } } } [Theory] - // Missing final CRLF - [InlineData("Header-1: value1\r\nHeader-2: value2\r\n")] // Leading whitespace [InlineData(" Header-1: value1\r\nHeader-2: value2\r\n\r\n")] [InlineData("\tHeader-1: value1\r\nHeader-2: value2\r\n\r\n")] @@ -124,7 +122,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd($"GET / HTTP/1.1\r\n{rawHeaders}"); + await connection.SendAll($"GET / HTTP/1.1\r\n{rawHeaders}"); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); } } @@ -137,7 +135,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd( + await connection.SendAll( "GET / HTTP/1.1", "H\u00eb\u00e4d\u00ebr: value", "", @@ -168,7 +166,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd($"GET {path} HTTP/1.1\r\n"); + await connection.SendAll($"GET {path} HTTP/1.1\r\n"); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); } } @@ -183,7 +181,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd($"{method} / HTTP/1.1\r\n\r\n"); + await connection.Send($"{method} / HTTP/1.1\r\n\r\n"); await ReceiveBadRequestResponse(connection, "411 Length Required", server.Context.DateHeaderValue); } } @@ -198,7 +196,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd($"{method} / HTTP/1.0\r\n\r\n"); + await connection.Send($"{method} / HTTP/1.0\r\n\r\n"); await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs index 403e7a0208..6fe3bc1d6d 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs @@ -68,13 +68,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.0", "Transfer-Encoding: chunked", "", "5", "Hello", "6", " World", "0", + "", ""); await connection.ReceiveEnd( "HTTP/1.1 200 OK", @@ -94,7 +95,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.0", "Transfer-Encoding: chunked", "Connection: keep-alive", @@ -143,7 +144,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.1", "Content-Length: 5", "", @@ -254,8 +255,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var connection = server.CreateConnection()) { - await connection.SendEnd(fullRequest); - + await connection.Send(fullRequest); await connection.ReceiveEnd(expectedFullResponse); } } @@ -282,7 +282,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd( + await connection.SendAll( "POST / HTTP/1.1", $"{transferEncodingHeaderLine}", $"{headerLine}", @@ -322,7 +322,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendAllTryEnd( + await connection.SendAll( "POST / HTTP/1.1", $"{transferEncodingHeaderLine}", $"{headerLine}", @@ -423,8 +423,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var connection = server.CreateConnection()) { - await connection.SendEnd(fullRequest); - + await connection.Send(fullRequest); await connection.ReceiveEnd(expectedFullResponse); } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs index 7ce2297b00..9bbe7c88a7 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -75,7 +75,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.0", "Connection: keep-alive", "", @@ -102,7 +102,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "Connection: close", "", @@ -137,7 +137,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -172,7 +172,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -210,7 +210,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -246,7 +246,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -264,7 +264,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [Theory] [MemberData(nameof(ConnectionFilterData))] - public async Task ConnectionClosedIfExeptionThrownAfterWrite(TestServiceContext testContext) + public async Task ConnectionClosedIfExceptionThrownAfterWrite(TestServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -295,7 +295,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [Theory] [MemberData(nameof(ConnectionFilterData))] - public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite(TestServiceContext testContext) + public async Task ConnectionClosedIfExceptionThrownAfterZeroLengthWrite(TestServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -342,7 +342,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -383,7 +383,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionFilterTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionFilterTests.cs index 6ef4b9de24..a27e8a40bd 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionFilterTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ConnectionFilterTests.cs @@ -3,32 +3,17 @@ using System; using System.IO; +using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Filter; using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Internal; using Xunit; namespace Microsoft.AspNetCore.Server.KestrelTests { public class ConnectionFilterTests { - private async Task App(HttpContext httpContext) - { - var request = httpContext.Request; - var response = httpContext.Response; - while (true) - { - var buffer = new byte[8192]; - var count = await request.Body.ReadAsync(buffer, 0, buffer.Length); - if (count == 0) - { - break; - } - await response.Body.WriteAsync(buffer, 0, count); - } - } - [Fact] public async Task CanReadAndWriteWithRewritingConnectionFilter() { @@ -37,12 +22,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var sendString = "POST / HTTP/1.0\r\nContent-Length: 12\r\n\r\nHello World?"; - using (var server = new TestServer(App, serviceContext)) + using (var server = new TestServer(TestApp.EchoApp, serviceContext)) { using (var connection = server.CreateConnection()) { // "?" changes to "!" - await connection.SendEnd(sendString); + await connection.Send(sendString); await connection.ReceiveEnd( "HTTP/1.1 200 OK", "Connection: close", @@ -60,11 +45,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { var serviceContext = new TestServiceContext(new AsyncConnectionFilter()); - using (var server = new TestServer(App, serviceContext)) + using (var server = new TestServer(TestApp.EchoApp, serviceContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.0", "Content-Length: 12", "", @@ -84,7 +69,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { var serviceContext = new TestServiceContext(new ThrowingConnectionFilter()); - using (var server = new TestServer(App, serviceContext)) + using (var server = new TestServer(TestApp.EchoApp, serviceContext)) { using (var connection = server.CreateConnection()) { @@ -108,15 +93,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests private class RewritingConnectionFilter : IConnectionFilter { - private static Task _empty = Task.FromResult(null); - private RewritingStream _rewritingStream; public Task OnConnectionAsync(ConnectionFilterContext context) { _rewritingStream = new RewritingStream(context.Connection); context.Connection = _rewritingStream; - return _empty; + return TaskCache.CompletedTask; } public int BytesRead => _rewritingStream.BytesRead; @@ -189,6 +172,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests return actual; } + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + var actual = await _innerStream.ReadAsync(buffer, offset, count); + + BytesRead += actual; + + return actual; + } + public override long Seek(long offset, SeekOrigin origin) { return _innerStream.Seek(offset, origin); @@ -211,6 +203,19 @@ namespace Microsoft.AspNetCore.Server.KestrelTests _innerStream.Write(buffer, offset, count); } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + for (int i = 0; i < buffer.Length; i++) + { + if (buffer[i] == '?') + { + buffer[i] = (byte)'!'; + } + } + + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/DefaultHeaderTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/DefaultHeaderTests.cs index 3070332da9..b8a843f738 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/DefaultHeaderTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/DefaultHeaderTests.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.0", diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index 650fcedbe1..d5997be036 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -42,39 +42,6 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - private async Task App(HttpContext httpContext) - { - var request = httpContext.Request; - var response = httpContext.Response; - while (true) - { - var buffer = new byte[8192]; - var count = await request.Body.ReadAsync(buffer, 0, buffer.Length); - if (count == 0) - { - break; - } - await response.Body.WriteAsync(buffer, 0, count); - } - } - - private async Task AppChunked(HttpContext httpContext) - { - var request = httpContext.Request; - var response = httpContext.Response; - var data = new MemoryStream(); - await request.Body.CopyToAsync(data); - var bytes = data.ToArray(); - - response.Headers["Content-Length"] = bytes.Length.ToString(); - await response.Body.WriteAsync(bytes, 0, bytes.Length); - } - - private Task EmptyApp(HttpContext httpContext) - { - return Task.FromResult(null); - } - [Theory] [MemberData(nameof(ConnectionFilterData))] public void EngineCanStartAndStop(TestServiceContext testContext) @@ -88,7 +55,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public void ListenerCanCreateAndDispose(TestServiceContext testContext) { - testContext.App = App; + testContext.App = TestApp.EchoApp; var engine = new KestrelEngine(testContext); engine.Start(1); var address = ServerAddress.FromUrl("http://127.0.0.1:0/"); @@ -101,22 +68,22 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public void ConnectionCanReadAndWrite(TestServiceContext testContext) { - testContext.App = App; + testContext.App = TestApp.EchoApp; var engine = new KestrelEngine(testContext); engine.Start(1); var address = ServerAddress.FromUrl("http://127.0.0.1:0/"); var started = engine.CreateServer(address); var socket = TestConnection.CreateConnectedLoopbackSocket(address.Port); - socket.Send(Encoding.ASCII.GetBytes("POST / HTTP/1.0\r\nContent-Length: 11\r\n\r\nHello World")); - socket.Shutdown(SocketShutdown.Send); - var buffer = new byte[8192]; - while (true) + var data = "Hello World"; + socket.Send(Encoding.ASCII.GetBytes($"POST / HTTP/1.0\r\nContent-Length: 11\r\n\r\n{data}")); + var buffer = new byte[data.Length]; + var read = 0; + while (read < data.Length) { - var length = socket.Receive(buffer); - if (length == 0) { break; } - var text = Encoding.ASCII.GetString(buffer, 0, length); + read += socket.Receive(buffer, read, buffer.Length - read, SocketFlags.None); } + socket.Dispose(); started.Dispose(); engine.Dispose(); } @@ -125,11 +92,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Http10RequestReceivesHttp11Response(TestServiceContext testContext) { - using (var server = new TestServer(App, testContext)) + using (var server = new TestServer(TestApp.EchoApp, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.0", "Content-Length: 11", "", @@ -148,11 +115,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Http11(TestServiceContext testContext) { - using (var server = new TestServer(AppChunked, testContext)) + using (var server = new TestServer(TestApp.EchoAppChunked, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.1", @@ -243,7 +210,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Enumerable.Repeat(response, loopCount) .Concat(new[] { lastResponse }); - await connection.SendEnd(requestData.ToArray()); + await connection.Send(requestData.ToArray()); await connection.ReceiveEnd(responseData.ToArray()); } @@ -258,11 +225,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Http10ContentLength(TestServiceContext testContext) { - using (var server = new TestServer(App, testContext)) + using (var server = new TestServer(TestApp.EchoApp, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.0", "Content-Length: 11", "", @@ -281,11 +248,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Http10KeepAlive(TestServiceContext testContext) { - using (var server = new TestServer(AppChunked, testContext)) + using (var server = new TestServer(TestApp.EchoAppChunked, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.0", "Connection: keep-alive", "", @@ -314,11 +281,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Http10KeepAliveNotUsedIfResponseContentLengthNotSet(TestServiceContext testContext) { - using (var server = new TestServer(App, testContext)) + using (var server = new TestServer(TestApp.EchoApp, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.0", "Connection: keep-alive", "", @@ -347,11 +314,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Http10KeepAliveContentLength(TestServiceContext testContext) { - using (var server = new TestServer(AppChunked, testContext)) + using (var server = new TestServer(TestApp.EchoAppChunked, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.0", "Content-Length: 11", "Connection: keep-alive", @@ -382,7 +349,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task Expect100ContinueForBody(TestServiceContext testContext) { - using (var server = new TestServer(AppChunked, testContext)) + using (var server = new TestServer(TestApp.EchoAppChunked, testContext)) { using (var connection = server.CreateConnection()) { @@ -392,8 +359,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Connection: close", "Content-Length: 11", "\r\n"); - await connection.Receive("HTTP/1.1 100 Continue", "\r\n"); - await connection.SendEnd("Hello World"); + await connection.Receive( + "HTTP/1.1 100 Continue", + "", + ""); + await connection.Send("Hello World"); await connection.Receive( "HTTP/1.1 200 OK", "Connection: close", @@ -409,7 +379,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task DisconnectingClient(TestServiceContext testContext) { - using (var server = new TestServer(App, testContext)) + using (var server = new TestServer(TestApp.EchoApp, testContext)) { var socket = TestConnection.CreateConnectedLoopbackSocket(server.Port); await Task.Delay(200); @@ -418,15 +388,17 @@ namespace Microsoft.AspNetCore.Server.KestrelTests await Task.Delay(200); using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.0", - "\r\n"); + "", + ""); await connection.ReceiveEnd( "HTTP/1.1 200 OK", "Connection: close", $"Date: {testContext.DateHeaderValue}", "Content-Length: 0", - "\r\n"); + "", + ""); } } } @@ -435,11 +407,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task ZeroContentLengthSetAutomaticallyAfterNoWrites(TestServiceContext testContext) { - using (var server = new TestServer(EmptyApp, testContext)) + using (var server = new TestServer(TestApp.EmptyApp, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.0", @@ -472,7 +444,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "Connection: close", "", @@ -488,7 +460,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.0", "", ""); @@ -507,11 +479,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task ZeroContentLengthNotSetAutomaticallyForHeadRequests(TestServiceContext testContext) { - using (var server = new TestServer(EmptyApp, testContext)) + using (var server = new TestServer(TestApp.EmptyApp, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "HEAD / HTTP/1.1", "", ""); @@ -542,7 +514,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "POST / HTTP/1.1", "Content-Length: 3", "", @@ -639,7 +611,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 101 Switching Protocols", "Connection: Upgrade", $"Date: {testContext.DateHeaderValue}", @@ -654,7 +626,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Connection: keep-alive", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 101 Switching Protocols", "Connection: Upgrade", $"Date: {testContext.DateHeaderValue}", @@ -679,7 +651,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests response.OnStarting(_ => { onStartingCalled = true; - return Task.FromResult(null); + return TaskCache.CompletedTask; }, null); // Anything added to the ResponseHeaders dictionary is ignored @@ -689,25 +661,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.1", "Connection: close", "", ""); - await connection.Receive( + await connection.ReceiveForcedEnd( "HTTP/1.1 500 Internal Server Error", - ""); - await connection.Receive( $"Date: {testContext.DateHeaderValue}", "Content-Length: 0", "", "HTTP/1.1 500 Internal Server Error", - ""); - await connection.Receive("Connection: close", - ""); - await connection.ReceiveEnd( + "Connection: close", $"Date: {testContext.DateHeaderValue}", "Content-Length: 0", "", @@ -802,48 +769,19 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.True(onStartingCalled); Assert.Equal(1, testLogger.ApplicationErrorsLogged); } - - [Theory] - [MemberData(nameof(ConnectionFilterData))] - public async Task ConnectionClosesWhenFinReceived(TestServiceContext testContext) - { - using (var server = new TestServer(AppChunked, testContext)) - { - using (var connection = server.CreateConnection()) - { - await connection.SendEnd( - "GET / HTTP/1.1", - "", - "Post / HTTP/1.1", - "Content-Length: 7", - "", - "Goodbye"); - await connection.ReceiveEnd( - "HTTP/1.1 200 OK", - $"Date: {testContext.DateHeaderValue}", - "Content-Length: 0", - "", - "HTTP/1.1 200 OK", - $"Date: {testContext.DateHeaderValue}", - "Content-Length: 7", - "", - "Goodbye"); - } - } - } - - [Theory] + [MemberData(nameof(ConnectionFilterData))] public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes(TestServiceContext testContext) { - using (var server = new TestServer(AppChunked, testContext)) + using (var server = new TestServer(TestApp.EchoAppChunked, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "POST / HTTP/1.1"); + connection.Shutdown(SocketShutdown.Send); await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", $"Date: {testContext.DateHeaderValue}", @@ -859,11 +797,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "POST / HTTP/1.1", "Content-Length: 7"); + connection.Shutdown(SocketShutdown.Send); await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", $"Date: {testContext.DateHeaderValue}", @@ -918,24 +857,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.1", - "Connection: close", "", ""); - await connection.Receive( + await connection.ReceiveEnd( "HTTP/1.1 500 Internal Server Error", - ""); - await connection.Receive( $"Date: {testContext.DateHeaderValue}", "Content-Length: 0", "", "HTTP/1.1 500 Internal Server Error", - "Connection: close", - ""); - await connection.ReceiveEnd( $"Date: {testContext.DateHeaderValue}", "Content-Length: 0", "", @@ -1072,7 +1005,6 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal(2, abortedRequestId); } - [Theory] [MemberData(nameof(ConnectionFilterData))] public async Task FailedWritesResultInAbortedRequest(TestServiceContext testContext) { @@ -1215,7 +1147,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.1", @@ -1262,7 +1194,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", "GET / HTTP/1.1", @@ -1290,13 +1222,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var server = new TestServer(async httpContext => { var path = httpContext.Request.Path.Value; - httpContext.Response.Headers["Content-Length"] = new[] {path.Length.ToString() }; + httpContext.Response.Headers["Content-Length"] = new[] { path.Length.ToString() }; await httpContext.Response.WriteAsync(path); })) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( $"GET {inputPath} HTTP/1.1", "", ""); @@ -1330,14 +1262,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests callOrder.Push(2); return TaskCache.CompletedTask; }, null); - + context.Response.ContentLength = response.Length; await context.Response.WriteAsync(response); }, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -1348,10 +1280,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "", "hello, world"); - Assert.Equal(1, callOrder.Pop()); - Assert.Equal(2, callOrder.Pop()); + } } + + Assert.Equal(1, callOrder.Pop()); + Assert.Equal(2, callOrder.Pop()); } [Theory] @@ -1381,7 +1315,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "", ""); @@ -1391,47 +1325,47 @@ namespace Microsoft.AspNetCore.Server.KestrelTests $"Content-Length: {response.Length}", "", "hello, world"); - - Assert.Equal(1, callOrder.Pop()); - Assert.Equal(2, callOrder.Pop()); } } + + Assert.Equal(1, callOrder.Pop()); + Assert.Equal(2, callOrder.Pop()); } [Theory] [MemberData(nameof(ConnectionFilterData))] public async Task UpgradeRequestIsNotKeptAliveOrChunked(TestServiceContext testContext) { + const string message = "Hello World"; + using (var server = new TestServer(async context => { var upgradeFeature = context.Features.Get(); var duplexStream = await upgradeFeature.UpgradeAsync(); - while (true) + var buffer = new byte[message.Length]; + var read = 0; + while (read < message.Length) { - var buffer = new byte[8192]; - var count = await duplexStream.ReadAsync(buffer, 0, buffer.Length); - if (count == 0) - { - break; - } - await duplexStream.WriteAsync(buffer, 0, count); + read += await duplexStream.ReadAsync(buffer, read, buffer.Length - read).TimeoutAfter(TimeSpan.FromSeconds(10)); } + + await duplexStream.WriteAsync(buffer, 0, read); }, testContext)) { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET / HTTP/1.1", "Connection: Upgrade", "", - "Hello World"); - await connection.ReceiveEnd( + message); + await connection.ReceiveForcedEnd( "HTTP/1.1 101 Switching Protocols", "Connection: Upgrade", $"Date: {testContext.DateHeaderValue}", "", - "Hello World"); + message); } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs index b9444c0f79..318061b832 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs @@ -6,6 +6,7 @@ using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel; using Microsoft.AspNetCore.Server.Kestrel.Internal; @@ -22,10 +23,23 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { private readonly SocketInput _socketInput; private readonly MemoryPool _pool; - private readonly Frame _frame; + private readonly TestFrame _frame; private readonly ServiceContext _serviceContext; private readonly ConnectionContext _connectionContext; + private class TestFrame : Frame + { + public TestFrame(IHttpApplication application, ConnectionContext context) + : base(application, context) + { + } + + public Task ProduceEndAsync() + { + return ProduceEnd(); + } + } + public FrameTests() { var trace = new KestrelTrace(new TestKestrelTrace()); @@ -50,7 +64,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests ConnectionControl = Mock.Of() }; - _frame = new Frame(application: null, context: _connectionContext); + _frame = new TestFrame(application: null, context: _connectionContext); _frame.Reset(); _frame.InitializeHeaders(); } @@ -713,5 +727,73 @@ namespace Microsoft.AspNetCore.Server.KestrelTests await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); _socketInput.IncomingFin(); } + + [Fact] + public void RequestAbortedTokenIsResetBeforeLastWriteWithContentLength() + { + _frame.ResponseHeaders["Content-Length"] = "12"; + + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _frame.RequestAborted.WaitHandle; + + foreach (var ch in "hello, worl") + { + _frame.Write(new ArraySegment(new[] { (byte)ch })); + Assert.Same(original, _frame.RequestAborted.WaitHandle); + } + + _frame.Write(new ArraySegment(new[] { (byte)'d' })); + Assert.NotSame(original, _frame.RequestAborted.WaitHandle); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteAsyncWithContentLength() + { + _frame.ResponseHeaders["Content-Length"] = "12"; + + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _frame.RequestAborted.WaitHandle; + + foreach (var ch in "hello, worl") + { + await _frame.WriteAsync(new ArraySegment(new[] { (byte)ch }), default(CancellationToken)); + Assert.Same(original, _frame.RequestAborted.WaitHandle); + } + + await _frame.WriteAsync(new ArraySegment(new[] { (byte)'d' }), default(CancellationToken)); + Assert.NotSame(original, _frame.RequestAborted.WaitHandle); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteAsyncAwaitedWithContentLength() + { + _frame.ResponseHeaders["Content-Length"] = "12"; + + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _frame.RequestAborted.WaitHandle; + + foreach (var ch in "hello, worl") + { + await _frame.WriteAsyncAwaited(new ArraySegment(new[] { (byte)ch }), default(CancellationToken)); + Assert.Same(original, _frame.RequestAborted.WaitHandle); + } + + await _frame.WriteAsyncAwaited(new ArraySegment(new[] { (byte)'d' }), default(CancellationToken)); + Assert.NotSame(original, _frame.RequestAborted.WaitHandle); + } + + [Fact] + public async Task RequestAbortedTokenIsResetBeforeLastWriteWithChunkedEncoding() + { + // Need to compare WaitHandle ref since CancellationToken is struct + var original = _frame.RequestAborted.WaitHandle; + + _frame.HttpVersion = "HTTP/1.1"; + await _frame.WriteAsync(new ArraySegment(Encoding.ASCII.GetBytes("hello, world")), default(CancellationToken)); + Assert.Same(original, _frame.RequestAborted.WaitHandle); + + await _frame.ProduceEndAsync(); + Assert.NotSame(original, _frame.RequestAborted.WaitHandle); + } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/RequestTargetProcessingTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/RequestTargetProcessingTests.cs index e01610a93d..a793cf0269 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/RequestTargetProcessingTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/RequestTargetProcessingTests.cs @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( "GET /%41%CC%8A/A/../B/%41%CC%8A HTTP/1.1", "", ""); @@ -71,7 +71,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( $"GET {requestTarget} HTTP/1.1", "", ""); @@ -115,7 +115,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { using (var connection = server.CreateConnection()) { - await connection.SendEnd( + await connection.Send( $"GET {requestTarget} HTTP/1.1", "", ""); diff --git a/test/shared/DummyApplication.cs b/test/shared/DummyApplication.cs index a4b3eb4b1a..e944bf3562 100644 --- a/test/shared/DummyApplication.cs +++ b/test/shared/DummyApplication.cs @@ -12,20 +12,27 @@ namespace Microsoft.AspNetCore.Testing public class DummyApplication : IHttpApplication { private readonly RequestDelegate _requestDelegate; + private readonly IHttpContextFactory _httpContextFactory; public DummyApplication(RequestDelegate requestDelegate) + : this(requestDelegate, null) + { + } + + public DummyApplication(RequestDelegate requestDelegate, IHttpContextFactory httpContextFactory) { _requestDelegate = requestDelegate; + _httpContextFactory = httpContextFactory; } public HttpContext CreateContext(IFeatureCollection contextFeatures) { - return new DefaultHttpContext(contextFeatures); + return _httpContextFactory?.Create(contextFeatures) ?? new DefaultHttpContext(contextFeatures); } public void DisposeContext(HttpContext context, Exception exception) { - + _httpContextFactory?.Dispose(context); } public async Task ProcessRequestAsync(HttpContext context) diff --git a/test/shared/TestApp.cs b/test/shared/TestApp.cs new file mode 100644 index 0000000000..54a4f33a90 --- /dev/null +++ b/test/shared/TestApp.cs @@ -0,0 +1,49 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Internal; + +namespace Microsoft.AspNetCore.Testing +{ + public static class TestApp + { + public static async Task EchoApp(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + var buffer = new byte[httpContext.Request.ContentLength ?? 0]; + var bytesRead = 0; + + while (bytesRead < buffer.Length) + { + var count = await request.Body.ReadAsync(buffer, bytesRead, buffer.Length - bytesRead); + bytesRead += count; + } + + if (buffer.Length > 0) + { + await response.Body.WriteAsync(buffer, 0, buffer.Length); + } + } + + public static async Task EchoAppChunked(HttpContext httpContext) + { + var request = httpContext.Request; + var response = httpContext.Response; + var data = new MemoryStream(); + await request.Body.CopyToAsync(data); + var bytes = data.ToArray(); + + response.Headers["Content-Length"] = bytes.Length.ToString(); + await response.Body.WriteAsync(bytes, 0, bytes.Length); + } + + public static Task EmptyApp(HttpContext httpContext) + { + return TaskCache.CompletedTask; + } + } +} \ No newline at end of file diff --git a/test/shared/TestConnection.cs b/test/shared/TestConnection.cs index 56e6c53ce2..13b9d67fa8 100644 --- a/test/shared/TestConnection.cs +++ b/test/shared/TestConnection.cs @@ -50,22 +50,6 @@ namespace Microsoft.AspNetCore.Testing _stream.Flush(); } - public async Task SendAllTryEnd(params string[] lines) - { - await SendAll(lines); - - try - { - _socket.Shutdown(SocketShutdown.Send); - } - catch (IOException) - { - // The server may forcefully close the connection (usually due to a bad request), - // so an IOException: "An existing connection was forcibly closed by the remote host" - // isn't guaranteed but not unexpected. - } - } - public async Task Send(params string[] lines) { var text = string.Join("\r\n", lines); @@ -82,12 +66,6 @@ namespace Microsoft.AspNetCore.Testing _stream.Flush(); } - public async Task SendEnd(params string[] lines) - { - await Send(lines); - _socket.Shutdown(SocketShutdown.Send); - } - public async Task Receive(params string[] lines) { var expected = string.Join("\r\n", lines); @@ -95,6 +73,7 @@ namespace Microsoft.AspNetCore.Testing var offset = 0; while (offset < expected.Length) { + var data = new byte[expected.Length]; var task = _reader.ReadAsync(actual, offset, actual.Length - offset); if (!Debugger.IsAttached) { @@ -108,12 +87,13 @@ namespace Microsoft.AspNetCore.Testing offset += count; } - Assert.Equal(expected, new String(actual, 0, offset)); + Assert.Equal(expected, new string(actual, 0, offset)); } public async Task ReceiveEnd(params string[] lines) { await Receive(lines); + _socket.Shutdown(SocketShutdown.Send); var ch = new char[128]; var count = await _reader.ReadAsync(ch, 0, 128).TimeoutAfter(TimeSpan.FromMinutes(1)); var text = new string(ch, 0, count); @@ -139,11 +119,16 @@ namespace Microsoft.AspNetCore.Testing } } + public void Shutdown(SocketShutdown how) + { + _socket.Shutdown(how); + } + public Task WaitForConnectionClose() { var tcs = new TaskCompletionSource(); var eventArgs = new SocketAsyncEventArgs(); - eventArgs.SetBuffer(new byte[1], 0, 1); + eventArgs.SetBuffer(new byte[128], 0, 128); eventArgs.Completed += ReceiveAsyncCompleted; eventArgs.UserToken = tcs; @@ -157,11 +142,16 @@ namespace Microsoft.AspNetCore.Testing private void ReceiveAsyncCompleted(object sender, SocketAsyncEventArgs e) { + var tcs = (TaskCompletionSource)e.UserToken; if (e.BytesTransferred == 0) { - var tcs = (TaskCompletionSource)e.UserToken; tcs.SetResult(null); } + else + { + tcs.SetException(new IOException( + $"Expected connection close, received data instead: \"{_reader.CurrentEncoding.GetString(e.Buffer, 0, e.BytesTransferred)}\"")); + } } public static Socket CreateConnectedLoopbackSocket(int port) diff --git a/test/shared/TestServer.cs b/test/shared/TestServer.cs index 738b0b10df..f52cc190d3 100644 --- a/test/shared/TestServer.cs +++ b/test/shared/TestServer.cs @@ -29,12 +29,17 @@ namespace Microsoft.AspNetCore.Testing } public TestServer(RequestDelegate app, TestServiceContext context, string serverAddress) + : this(app, context, serverAddress, null) + { + } + + public TestServer(RequestDelegate app, TestServiceContext context, string serverAddress, IHttpContextFactory httpContextFactory) { Context = context; context.FrameFactory = connectionContext => { - return new Frame(new DummyApplication(app), connectionContext); + return new Frame(new DummyApplication(app, httpContextFactory), connectionContext); }; try