diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs index 094cb3ca13..5ecaaa0814 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs @@ -155,20 +155,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } // ForZeroContentLength does not complete the reader nor the writer - if (!messageBody.IsEmpty) + if (!messageBody.IsEmpty && _keepAlive) { - if (_keepAlive) - { - // Finish reading the request body in case the app did not. - TimeoutControl.SetTimeout(Constants.RequestBodyDrainTimeout.Ticks, TimeoutAction.SendTimeoutResponse); - await messageBody.ConsumeAsync(); - TimeoutControl.CancelTimeout(); - } - else - { - messageBody.Cancel(); - Input.CancelPendingRead(); - } + // Finish reading the request body in case the app did not. + TimeoutControl.SetTimeout(Constants.RequestBodyDrainTimeout.Ticks, TimeoutAction.SendTimeoutResponse); + await messageBody.ConsumeAsync(); + TimeoutControl.CancelTimeout(); } if (!HasResponseStarted) @@ -202,6 +194,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { RequestBodyPipe.Reader.Complete(); + // Wait for MessageBody.PumpAsync() to call RequestBodyPipe.Writer.Complete(). + await messageBody.StopAsync(); + // At this point both the request body pipe reader and writer should be completed. RequestBodyPipe.Reset(); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs index 2bc4ff6e5a..407dc15ec3 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs @@ -21,6 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private bool _send100Continue = true; private volatile bool _canceled; + private Task _pumpTask; protected MessageBody(Frame context) { @@ -132,11 +133,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } - public void Cancel() - { - _canceled = true; - } - public virtual async Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) { TryInit(); @@ -213,6 +209,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } while (!result.IsCompleted); } + public virtual Task StopAsync() + { + if (!_context.HasStartedConsumingRequestBody) + { + return Task.CompletedTask; + } + + _canceled = true; + _context.Input.CancelPendingRead(); + return _pumpTask; + } + protected void Copy(ReadableBuffer readableBuffer, WritableBuffer writableBuffer) { _context.TimeoutControl.BytesRead(readableBuffer.Length); @@ -245,7 +253,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { OnReadStart(); _context.HasStartedConsumingRequestBody = true; - _ = PumpAsync(); + _pumpTask = PumpAsync(); } } @@ -411,6 +419,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { return Task.CompletedTask; } + + public override Task StopAsync() + { + return Task.CompletedTask; + } } private class ForContentLength : MessageBody diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs index 3b1cf10a92..655582dcf1 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs @@ -23,7 +23,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Theory] [InlineData(HttpVersion.Http10)] [InlineData(HttpVersion.Http11)] - public void CanReadFromContentLength(HttpVersion httpVersion) + public async Task CanReadFromContentLength(HttpVersion httpVersion) { using (var input = new TestInput()) { @@ -41,6 +41,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await body.StopAsync(); } } @@ -65,11 +67,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await body.StopAsync(); } } [Fact] - public void CanReadFromChunkedEncoding() + public async Task CanReadFromChunkedEncoding() { using (var input = new TestInput()) { @@ -89,6 +93,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await body.StopAsync(); } } @@ -113,6 +119,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await body.StopAsync(); } } @@ -136,6 +144,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(5, await readTask.TimeoutAfter(TimeSpan.FromSeconds(10))); Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + + await body.StopAsync(); } } @@ -155,6 +165,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await stream.ReadAsync(buffer, 0, buffer.Length)); Assert.IsType(ex.InnerException); Assert.Equal(CoreStrings.BadRequest_BadChunkSizeData, ex.Message); + + await body.StopAsync(); } } @@ -174,13 +186,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await stream.ReadAsync(buffer, 0, buffer.Length)); Assert.Equal(CoreStrings.BadRequest_BadChunkSizeData, ex.Message); + + await body.StopAsync(); } } [Theory] [InlineData(HttpVersion.Http10)] [InlineData(HttpVersion.Http11)] - public void CanReadFromRemainingData(HttpVersion httpVersion) + public async Task CanReadFromRemainingData(HttpVersion httpVersion) { using (var input = new TestInput()) { @@ -197,6 +211,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests AssertASCII("Hello", new ArraySegment(buffer, 0, count)); input.Fin(); + + await body.StopAsync(); } } @@ -220,13 +236,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests AssertASCII("Hello", new ArraySegment(buffer, 0, count)); input.Fin(); + + await body.StopAsync(); } } [Theory] [InlineData(HttpVersion.Http10)] [InlineData(HttpVersion.Http11)] - public void ReadFromNoContentLengthReturnsZero(HttpVersion httpVersion) + public async Task ReadFromNoContentLengthReturnsZero(HttpVersion httpVersion) { using (var input = new TestInput()) { @@ -238,6 +256,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var buffer = new byte[1024]; Assert.Equal(0, stream.Read(buffer, 0, buffer.Length)); + + await body.StopAsync(); } } @@ -256,6 +276,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var buffer = new byte[1024]; Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + + await body.StopAsync(); } } @@ -282,6 +304,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var requestArray = ms.ToArray(); Assert.Equal(8197, requestArray.Length); AssertASCII(largeInput + "Hello", new ArraySegment(requestArray, 0, requestArray.Length)); + + await body.StopAsync(); } } @@ -345,6 +369,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + + await body.StopAsync(); } } @@ -360,6 +386,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await body.ConsumeAsync(); Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + + await body.StopAsync(); } } @@ -436,6 +464,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await copyToAsyncTask; Assert.Equal(2, writeCount); + + await body.StopAsync(); } } @@ -444,7 +474,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [InlineData("Keep-Alive, Upgrade")] [InlineData("upgrade, keep-alive")] [InlineData("Upgrade, Keep-Alive")] - public void ConnectionUpgradeKeepAlive(string headerConnection) + public async Task ConnectionUpgradeKeepAlive(string headerConnection) { using (var input = new TestInput()) { @@ -459,6 +489,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests AssertASCII("Hello", new ArraySegment(buffer, 0, 5)); input.Fin(); + + await body.StopAsync(); } } @@ -481,11 +513,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests input.Add("b"); Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + await body.StopAsync(); } } [Fact] - public async Task PumpAsyncReturnsAfterCanceling() + public async Task StopAsyncPreventsFurtherDataConsumption() { using (var input = new TestInput()) { @@ -497,15 +530,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests input.Add("a"); Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); - body.Cancel(); + await body.StopAsync(); // Add some more data. Checking for cancelation and exiting the loop // should take priority over reading this data. input.Add("b"); - // Unblock the loop - input.Pipe.Reader.CancelPendingRead(); - // There shouldn't be any additional data available Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1)); } @@ -535,6 +565,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var exception = await Assert.ThrowsAsync(() => body.ReadAsync(new ArraySegment(new byte[1]))); Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); + + await body.StopAsync(); } } @@ -562,6 +594,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var exception = await Assert.ThrowsAsync(() => body.ConsumeAsync()); Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); + + await body.StopAsync(); } } @@ -592,6 +626,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var exception = await Assert.ThrowsAsync(() => body.CopyToAsync(ms)); Assert.Equal(StatusCodes.Status408RequestTimeout, exception.StatusCode); } + + await body.StopAsync(); } } @@ -616,6 +652,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests mockLogger.Verify(logger => logger.RequestBodyStart("ConnectionId", "RequestId")); input.Fin(); + + await body.StopAsync(); } } @@ -644,6 +682,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests input.Fin(); Assert.True(logEvent.Wait(TimeSpan.FromSeconds(10))); + + await body.StopAsync(); } } @@ -700,6 +740,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests input.Add("a"); await readTask; + + await body.StopAsync(); } } @@ -729,6 +771,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests // requests. mockTimeoutControl.Verify(timeoutControl => timeoutControl.PauseTimingReads(), Times.Never); mockTimeoutControl.Verify(timeoutControl => timeoutControl.ResumeTimingReads(), Times.Never); + + await body.StopAsync(); } }