diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index fc4ba2b202..d829c68c18 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -533,23 +533,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public void Flush() { - ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); + InitializeResponse(0).GetAwaiter().GetResult(); Output.Flush(); } public async Task FlushAsync(CancellationToken cancellationToken) { - await ProduceStartAndFireOnStarting(); + await InitializeResponse(0); await Output.FlushAsync(cancellationToken); } public void Write(ArraySegment data) { - // For the first write, ensure headers are flushed if Write(Chunked)isn't called. + // For the first write, ensure headers are flushed if Write(Chunked) isn't called. var firstWrite = !HasResponseStarted; - VerifyAndUpdateWrite(data.Count); - ProduceStartAndFireOnStarting().GetAwaiter().GetResult(); + if (firstWrite) + { + InitializeResponse(data.Count).GetAwaiter().GetResult(); + } + else + { + VerifyAndUpdateWrite(data.Count); + } if (_canHaveBody) { @@ -616,9 +622,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public async Task WriteAsyncAwaited(ArraySegment data, CancellationToken cancellationToken) { - VerifyAndUpdateWrite(data.Count); - - await ProduceStartAndFireOnStarting(); + await InitializeResponseAwaited(data.Count); // WriteAsyncAwaited is only called for the first write to the body. // Ensure headers are flushed if Write(Chunked)Async isn't called. @@ -734,7 +738,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } - public Task ProduceStartAndFireOnStarting() + public Task InitializeResponse(int firstWriteByteCount) { if (HasResponseStarted) { @@ -743,7 +747,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http if (_onStarting != null) { - return ProduceStartAndFireOnStartingAwaited(); + return InitializeResponseAwaited(firstWriteByteCount); } if (_applicationException != null) @@ -751,11 +755,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ThrowResponseAbortedException(); } + VerifyAndUpdateWrite(firstWriteByteCount); ProduceStart(appCompleted: false); + return TaskCache.CompletedTask; } - private async Task ProduceStartAndFireOnStartingAwaited() + private async Task InitializeResponseAwaited(int firstWriteByteCount) { await FireOnStarting(); @@ -764,6 +770,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http ThrowResponseAbortedException(); } + VerifyAndUpdateWrite(firstWriteByteCount); ProduceStart(appCompleted: false); } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs index 31ffb191ec..c84b3a4663 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ResponseTests.cs @@ -485,7 +485,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppWritesMoreThanContentLengthWriteThrowsAndConnectionCloses() + public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWrite() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; @@ -520,7 +520,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppWritesMoreThanContentLengthWriteAsyncThrowsAndConnectionCloses() + public async Task ThrowsAndClosesConnectionWhenAppWritesMoreThanContentLengthWriteAsync() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; @@ -554,15 +554,52 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } [Fact] - public async Task WhenAppWritesMoreThanContentLengthAndResponseNotStarted500ResponseSentAndConnectionCloses() + public async Task InternalServerErrorAndConnectionClosedOnWriteWithMoreThanContentLengthAndResponseNotStarted() { var testLogger = new TestApplicationErrorLogger(); var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; - using (var server = new TestServer(async httpContext => + using (var server = new TestServer(httpContext => { + var response = Encoding.ASCII.GetBytes("hello, world"); httpContext.Response.ContentLength = 5; - await httpContext.Response.WriteAsync("hello, world"); + httpContext.Response.Body.Write(response, 0, response.Length); + return TaskCache.CompletedTask; + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.ReceiveForcedEnd( + $"HTTP/1.1 500 Internal Server Error", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + var logMessage = Assert.Single(testLogger.Messages, message => message.LogLevel == LogLevel.Error); + Assert.Equal( + $"Response Content-Length mismatch: too many bytes written (12 of 5).", + logMessage.Exception.Message); + } + + [Fact] + public async Task InternalServerErrorAndConnectionClosedOnWriteAsyncWithMoreThanContentLengthAndResponseNotStarted() + { + var testLogger = new TestApplicationErrorLogger(); + var serviceContext = new TestServiceContext { Log = new TestKestrelTrace(testLogger) }; + + using (var server = new TestServer(httpContext => + { + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = 5; + return httpContext.Response.Body.WriteAsync(response, 0, response.Length); }, serviceContext)) { using (var connection = server.CreateConnection()) @@ -1065,6 +1102,170 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task FirstWriteVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + httpContext.Response.Body.Write(response, 0, response.Length); + return TaskCache.CompletedTask; + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task SubsequentWriteVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + httpContext.Response.Body.Write(response, 0, response.Length / 2); + httpContext.Response.Body.Write(response, response.Length / 2, response.Length - response.Length / 2); + return TaskCache.CompletedTask; + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "6", + "hello,", + "6", + " world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task FirstWriteAsyncVerifiedAfterOnStarting() + { + using (var server = new TestServer(httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + return httpContext.Response.Body.WriteAsync(response, 0, response.Length); + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Fact] + public async Task SubsequentWriteAsyncVerifiedAfterOnStarting() + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.OnStarting(() => + { + // Change response to chunked + httpContext.Response.ContentLength = null; + return TaskCache.CompletedTask; + }); + + var response = Encoding.ASCII.GetBytes("hello, world"); + httpContext.Response.ContentLength = response.Length - 1; + + // If OnStarting is not run before verifying writes, an error response will be sent. + await httpContext.Response.Body.WriteAsync(response, 0, response.Length / 2); + await httpContext.Response.Body.WriteAsync(response, response.Length / 2, response.Length - response.Length / 2); + }, new TestServiceContext())) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + $"Transfer-Encoding: chunked", + "", + "6", + "hello,", + "6", + " world", + "0", + "", + ""); + } + } + } + public static TheoryData NullHeaderData { get diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index dff37141cb..f4965d908f 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -878,11 +878,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests throw onStartingException; }, null); - response.Headers["Content-Length"] = new[] { "11" }; - - var writeException = await Assert.ThrowsAsync(async () => - await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); - + var writeException = await Assert.ThrowsAsync(async () => await response.Body.FlushAsync()); Assert.Same(onStartingException, writeException.InnerException); failedWriteCount++; @@ -896,7 +892,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "GET / HTTP/1.1", "", ""); - await connection.ReceiveEnd( + await connection.Receive( "HTTP/1.1 500 Internal Server Error", $"Date: {testContext.DateHeaderValue}", "Content-Length: 0",