diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs index db9e78aa2f..fabf8950a4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ChunkedEncodingMessageBody.cs @@ -50,6 +50,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } public override bool TryRead(out ReadResult readResult) + { + ThrowIfCompleted(); + + return TryReadInternal(out readResult); + } + + public override bool TryReadInternal(out ReadResult readResult) { TryStart(); @@ -65,7 +72,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return boolResult; } - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) + { + ThrowIfCompleted(); + return ReadAsyncInternal(cancellationToken); + } + + public override async ValueTask ReadAsyncInternal(CancellationToken cancellationToken = default) { TryStart(); @@ -92,7 +105,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public override void Complete(Exception exception) { - _requestBodyPipe.Reader.Complete(); + _completed = true; _context.ReportApplicationError(exception); } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs index 2d13e68680..7b4875aaa0 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs @@ -15,7 +15,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private readonly long _contentLength; private long _inputLength; private bool _readCompleted; - private bool _completed; private bool _isReading; private int _userCanceled; private long _totalExaminedInPreviousReadResult; @@ -29,10 +28,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _inputLength = _contentLength; } - public override async ValueTask ReadAsync(CancellationToken cancellationToken = default) + public override ValueTask ReadAsync(CancellationToken cancellationToken = default) { ThrowIfCompleted(); + return ReadAsyncInternal(cancellationToken); + } + public override async ValueTask ReadAsyncInternal(CancellationToken cancellationToken = default) + { if (_isReading) { throw new InvalidOperationException("Reading is already in progress."); @@ -112,7 +115,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public override bool TryRead(out ReadResult readResult) { ThrowIfCompleted(); + return TryReadInternal(out readResult); + } + public override bool TryReadInternal(out ReadResult readResult) + { if (_isReading) { throw new InvalidOperationException("Reading is already in progress."); @@ -164,14 +171,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return OnConsumeAsync(); } - - private void ThrowIfCompleted() - { - if (_completed) - { - throw new InvalidOperationException("Reading is not allowed after the reader was completed."); - } - } private void CreateReadResultFromConnectionReadResult() { diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs index 0691c842d7..f2f8ae6434 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs @@ -3,6 +3,7 @@ using System; using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; @@ -12,6 +13,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http internal abstract class Http1MessageBody : MessageBody { protected readonly Http1Connection _context; + protected bool _completed; protected Http1MessageBody(Http1Connection context) : base(context) @@ -34,11 +36,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } + public abstract bool TryReadInternal(out ReadResult readResult); + + public abstract ValueTask ReadAsyncInternal(CancellationToken cancellationToken = default); + protected override Task OnConsumeAsync() { try { - if (TryRead(out var readResult)) + while (TryReadInternal(out var readResult)) { AdvanceTo(readResult.Buffer.End); @@ -79,7 +85,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http ReadResult result; do { - result = await ReadAsync(); + result = await ReadAsyncInternal(); AdvanceTo(result.Buffer.End); } while (!result.IsCompleted); } @@ -177,5 +183,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http context.OnTrailersComplete(); // No trailers for these. return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose; } + + protected void ThrowIfCompleted() + { + if (_completed) + { + throw new InvalidOperationException("Reading is not allowed after the reader was completed."); + } + } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs index 3958fad0ba..8d01c9232f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1UpgradeMessageBody.cs @@ -14,7 +14,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http /// internal sealed class Http1UpgradeMessageBody : Http1MessageBody { - public bool _completed; public Http1UpgradeMessageBody(Http1Connection context) : base(context) { @@ -78,5 +77,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { return Task.CompletedTask; } + + public override bool TryReadInternal(out ReadResult readResult) + { + return _context.Input.TryRead(out readResult); + } + + public override ValueTask ReadAsyncInternal(CancellationToken cancellationToken = default) + { + return _context.Input.ReadAsync(cancellationToken); + } } } diff --git a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs index 94d4ff0b74..a361e2d1c4 100644 --- a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs +++ b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs @@ -1235,6 +1235,30 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } } + [Fact] + public async Task CompleteForContentLengthAllowsConsumeToWork() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("a"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.False(readResult.IsCompleted); + + input.Add("asdf"); + + reader.AdvanceTo(readResult.Buffer.End); + reader.Complete(); + + await body.ConsumeAsync(); + } + } + [Fact] public async Task CompleteForContentLengthDoesNotCompleteConnectionPipeMakesReadReturnThrow() { @@ -1261,6 +1285,30 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } } + [Fact] + public async Task CompleteForChunkedAllowsConsumeToWork() + { + using (var input = new TestInput()) + { + var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection); + var reader = new HttpRequestPipeReader(); + reader.StartAcceptingReads(body); + + input.Add("5\r\nHello\r\n"); + + Assert.True(reader.TryRead(out var readResult)); + + Assert.False(readResult.IsCompleted); + reader.AdvanceTo(readResult.Buffer.End); + + input.Add("1\r\nH\r\n0\r\n\r\n"); + + reader.Complete(); + + await body.ConsumeAsync(); + } + } + [Fact] public async Task CompleteForChunkedDoesNotCompleteConnectionPipeMakesReadThrow() { @@ -1313,7 +1361,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } } - [Fact] public async Task CompleteForZeroByteBodyDoesNotCompleteConnectionPipeNoopsReads() { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs index 12ac5791ff..72d7c2cc7c 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedRequestTests.cs @@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; using Xunit; @@ -991,6 +992,62 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests } } + [Fact] + public async Task ChunkedRequestCallCompleteDoesNotCauseException() + { + var testContext = new TestServiceContext(LoggerFactory); + + await using (var server = new TestServer(async httpContext => + { + var request = httpContext.Request; + + // This read may receive all data, but what we care about + // is that ConsumeAsync is called and doesn't error. Calling + // TryRead before would always fail. + var readResult = await request.BodyReader.ReadAsync(); + request.BodyReader.AdvanceTo(readResult.Buffer.End); + + request.BodyReader.Complete(); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "1", + "H", + "4", + "ello", + "0", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + // start another request to make sure OnComsumeAsync is hit + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "0", + "", + ""); + } + } + + Assert.All(TestSink.Writes, w => Assert.InRange(w.LogLevel, LogLevel.Trace, LogLevel.Information)); + } + [Fact] public async Task ChunkedRequestCallCompleteWithExceptionCauses500() { diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs index 7ed5fb5934..8e9414f519 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport; using Microsoft.AspNetCore.Testing; using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; using Xunit; @@ -1535,6 +1536,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests } } + [Fact] + public async Task ContentLengthRequestCallCompleteDoesNotCauseException() + { + var testContext = new TestServiceContext(LoggerFactory); + + var tcs = new TaskCompletionSource(); + await using (var server = new TestServer(async httpContext => + { + var request = httpContext.Request; + + var readResult = await request.BodyReader.ReadAsync(); + request.BodyReader.AdvanceTo(readResult.Buffer.End); + + httpContext.Request.BodyReader.Complete(); + + tcs.SetResult(null); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "He"); + await tcs.Task; + await connection.Send("llo"); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + + Assert.All(TestSink.Writes, w => Assert.InRange(w.LogLevel, LogLevel.Trace, LogLevel.Information)); + } + [Fact] public async Task ContentLengthCallCompleteWithExceptionCauses500() {