diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index 307ce7bc49..fb5e82df2a 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -508,23 +508,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public void Flush() { - ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); + InitializeResponse(0).GetAwaiter().GetResult(); SocketOutput.Flush(); } public async Task FlushAsync(CancellationToken cancellationToken) { - await ProduceStartAndFireOnStarting(); + await InitializeResponse(0); await SocketOutput.FlushAsync(cancellationToken); } public void Write(ArraySegment data) { - // For the first write, ensure headers are flushed if Write(Chunked)isn't called. + // For the first write, ensure headers are flushed if Write(Chunked) isn't called. var firstWrite = !HasResponseStarted; - VerifyAndUpdateWrite(data.Count); - ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); + if (firstWrite) + { + InitializeResponse(data.Count).GetAwaiter().GetResult(); + } + else + { + VerifyAndUpdateWrite(data.Count); + } if (_canHaveBody) { @@ -589,9 +595,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken cancellationToken) { - VerifyAndUpdateWrite(data.Count); - - await ProduceStartAndFireOnStarting(); + await InitializeResponseAwaited(data.Count); // WriteAsyncAwaited is only called for the first write to the body. // Ensure headers are flushed if Write(Chunked)Async isn't called. @@ -645,7 +649,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http responseHeaders.HeaderContentLengthValue.HasValue && _responseBytesWritten < responseHeaders.HeaderContentLengthValue.Value) { - _keepAlive = false; + // We need to close the connection if any bytes were written since the client + // cannot be certain of how many bytes it will receive. + if (_responseBytesWritten > 0) + { + _keepAlive = false; + } + ReportApplicationError(new InvalidOperationException( $"Response Content-Length mismatch: too few bytes written ({_responseBytesWritten} of {responseHeaders.HeaderContentLengthValue.Value}).")); } @@ -688,7 +698,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } - public Task ProduceStartAndFireOnStarting() + public Task InitializeResponse(int firstWriteByteCount) { if (HasResponseStarted) { @@ -697,7 +707,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http if (_onStarting != null) { - return ProduceStartAndFireOnStartingAwaited(); + return InitializeResponseAwaited(firstWriteByteCount); } if (_applicationException != null) @@ -705,11 +715,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ThrowResponseAbortedException(); } + VerifyAndUpdateWrite(firstWriteByteCount); ProduceStart(appCompleted: false); + return TaskCache.CompletedTask; } - private async Task ProduceStartAndFireOnStartingAwaited() + private async Task InitializeResponseAwaited(int firstWriteByteCount) { await FireOnStarting(); @@ -718,6 +730,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ThrowResponseAbortedException(); } + VerifyAndUpdateWrite(firstWriteByteCount); ProduceStart(appCompleted: false); } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index dbf6f5bdd4..af166d1ecf 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -480,7 +480,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppWritesMoreThanContentLengthWriteThrowsAndConnectionCloses() + public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWrite() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; @@ -515,7 +515,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppWritesMoreThanContentLengthWriteAsyncThrowsAndConnectionCloses() + public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWriteAsync() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; @@ -549,15 +549,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppWritesMoreThanContentLengthAndResponseNotStarted500ResponseSentAndConnectionCloses() + public async Task InternalServerErrorAndConnectionClosedOnWriteWithMoreThanContentLengthAndResponseNotStarted() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; - using (var server = new TestServer(async httpContext => + using (var server = new TestServer(httpContext => { + var response = Encoding.ASCII.GetBytes("hello, world"); httpContext.Response.ContentLength = 5; - await httpContext.Response.WriteAsync("hello, world"); + httpContext.Response.Body.Write(response, 0, response.Length); + return TaskCache.CompletedTask; }, serviceContext)) { using (var connection = server.CreateConnection()) @@ -566,7 +568,42 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 5).", + logMessage.Exception.Message); + } + + [Fact] + public async Task InternalServerErrorAndConnectionClosedOnWriteAsyncWithMoreThanContentLengthAndResponseNotStarted() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(httpContext => + { + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = 5; + return httpContext.Response.Body.WriteAsync(response, 0, response.Length); + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveForcedEnd( $"HTTP/1.1 500 Internal Server Error", "Connection: close", $"Date: {server.Context.DateHeaderValue}", @@ -616,7 +653,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionCloses() + public async Task WhenAppSetsContentLengthButDoesNotWriteBody500ResponseSentAndConnectionDoesNotClose() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; @@ -630,12 +667,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests using (var connection = server.CreateConnection()) { await connection.Send( + "GET / HTTP/1.1", + "", "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( - $"HTTP/1.1 500 Internal Server Error", - "Connection: close", + await connection.Receive( + "HTTP/1.1 500 Internal Server Error", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + "HTTP/1.1 500 Internal Server Error", $"Date: {server.Context.DateHeaderValue}", "Content-Length: 0", "", @@ -643,10 +685,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } - var errorMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); - Assert.Equal( - $"Response Content-Length mismatch: too few bytes written (0 of 5).", - errorMessage.Exception.Message); + var error = testLogger.Messages.Where(message => message.LogLevel == LogLevel.Error); + Assert.Equal(2, error.Count()); + Assert.All(error, message => message.Equals("Response Content-Length mismatch: too few bytes written (0 of 5).")); } [Theory] @@ -1050,6 +1091,170 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task FirstWriteVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + httpContext.Response.Body.Write(response, 0, response.Length); + return TaskCache.CompletedTask; + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task SubsequentWriteVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + httpContext.Response.Body.Write(response, 0, response.Length / 2); + httpContext.Response.Body.Write(response, response.Length / 2, response.Length - response.Length / 2); + return TaskCache.CompletedTask; + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "6", + "hello,", + "6", + " world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task FirstWriteAsyncVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + return httpContext.Response.Body.WriteAsync(response, 0, response.Length); + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task SubsequentWriteAsyncVerifiedAfterOnStarting() + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + await httpContext.Response.Body.WriteAsync(response, 0, response.Length / 2); + await httpContext.Response.Body.WriteAsync(response, response.Length / 2, response.Length - response.Length / 2); + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "6", + "hello,", + "6", + " world", + "0", + "", + ""); + } + } + } + public static TheoryData NullHeaderData { get diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index 650fcedbe1..35144e1f26 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -883,9 +883,8 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task ThrowingInOnStartingResultsInFailedWritesAnd500Response(TestServiceContext testContext) { - var onStartingCallCount1 = 0; - var onStartingCallCount2 = 0; - var failedWriteCount = 0; + var callback1Called = false; + var callback2CallCount = 0; var testLogger = new TestApplicationErrorLogger(); testContext.Log = new KestrelTrace(testLogger); @@ -897,23 +896,17 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var response = httpContext.Response; response.OnStarting(_ => { - onStartingCallCount1++; + callback1Called = true; throw onStartingException; }, null); response.OnStarting(_ => { - onStartingCallCount2++; + callback2CallCount++; throw onStartingException; }, null); - response.Headers["Content-Length"] = new[] { "11" }; - - var writeException = await Assert.ThrowsAsync(async () => - await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); - + var writeException = await Assert.ThrowsAsync(async () => await response.Body.FlushAsync()); Assert.Same(onStartingException, writeException.InnerException); - - failedWriteCount++; }, testContext)) { using (var connection = server.CreateConnection()) @@ -943,10 +936,10 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - // The first registered OnStarting callback should not be called, + // The first registered OnStarting callback should not have been called, // since they are called LIFO and the other one failed. - Assert.Equal(0, onStartingCallCount1); - Assert.Equal(2, onStartingCallCount2); + Assert.False(callback1Called); + Assert.Equal(2, callback2CallCount); Assert.Equal(2, testLogger.ApplicationErrorsLogged); }