diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs index b71cacae7c..dd2049d0ae 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1ContentLengthMessageBody.cs @@ -11,11 +11,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { internal class Http1ContentLengthMessageBody : Http1MessageBody { + private ReadResult _readResult; private readonly long _contentLength; private long _inputLength; - private ReadResult _readResult; + private bool _readCompleted; private bool _completed; + private bool _isReading; private int _userCanceled; + private long _totalExaminedInPreviousReadResult; + private bool _finalAdvanceCalled; public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context) : base(context) @@ -29,9 +33,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { ThrowIfCompleted(); - if (_inputLength == 0) + if (_isReading) { - _readResult = new ReadResult(default, isCanceled: false, isCompleted: true); + throw new InvalidOperationException("Reading is already in progress."); + } + + if (_readCompleted) + { + _isReading = true; return _readResult; } @@ -53,6 +62,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http try { var readAwaitable = _context.Input.ReadAsync(cancellationToken); + + _isReading = true; _readResult = await StartTimingReadAsync(readAwaitable, cancellationToken); } catch (ConnectionAbortedException ex) @@ -102,9 +113,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { ThrowIfCompleted(); - if (_inputLength == 0) + if (_isReading) { - readResult = new ReadResult(default, isCanceled: false, isCompleted: true); + throw new InvalidOperationException("Reading is already in progress."); + } + + if (_readCompleted) + { + _isReading = true; + readResult = _readResult; return true; } @@ -126,6 +143,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } + // Only set _isReading if we are returing true. + _isReading = true; + CreateReadResultFromConnectionReadResult(); readResult = _readResult; @@ -133,6 +153,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return true; } + public override Task ConsumeAsync() + { + TryStart(); + + if (!_readResult.Buffer.IsEmpty && _inputLength == 0) + { + _context.Input.AdvanceTo(_readResult.Buffer.End); + } + + return OnConsumeAsync(); + } + private void ThrowIfCompleted() { if (_completed) @@ -143,13 +175,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private void CreateReadResultFromConnectionReadResult() { - if (_readResult.Buffer.Length > _inputLength) + if (_readResult.Buffer.Length >= _inputLength + _totalExaminedInPreviousReadResult) { - _readResult = new ReadResult(_readResult.Buffer.Slice(0, _inputLength), _readResult.IsCanceled, isCompleted: true); - } - else if (_readResult.Buffer.Length == _inputLength) - { - _readResult = new ReadResult(_readResult.Buffer, _readResult.IsCanceled, isCompleted: true); + _readCompleted = true; + _readResult = new ReadResult( + _readResult.Buffer.Slice(0, _inputLength + _totalExaminedInPreviousReadResult), + _readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 1, + _readCompleted); } if (_readResult.IsCompleted) @@ -165,18 +197,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) { - if (_inputLength == 0) + if (!_isReading) { + throw new InvalidOperationException("No reading operation to complete."); + } + + _isReading = false; + + if (_readCompleted) + { + _readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), Interlocked.Exchange(ref _userCanceled, 0) == 1, _readCompleted); + + if (_readResult.Buffer.Length == 0 && !_finalAdvanceCalled) + { + _context.Input.AdvanceTo(consumed); + _finalAdvanceCalled = true; + } + return; } - var dataLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length; - - _inputLength -= dataLength; + var consumedLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length; + var examinedLength = consumedLength + _readResult.Buffer.Slice(consumed, examined).Length; _context.Input.AdvanceTo(consumed, examined); - OnDataRead(dataLength); + var newlyExamined = examinedLength - _totalExaminedInPreviousReadResult; + + OnDataRead(newlyExamined); + _totalExaminedInPreviousReadResult += newlyExamined; + _inputLength -= newlyExamined; + + _totalExaminedInPreviousReadResult -= consumedLength; } protected override void OnReadStarting() diff --git a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs index e4e44978af..94d4ff0b74 100644 --- a/src/Servers/Kestrel/Core/test/MessageBodyTests.cs +++ b/src/Servers/Kestrel/Core/test/MessageBodyTests.cs @@ -790,7 +790,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests // Add some input and read it to start PumpAsync input.Add("a"); - Assert.Equal(1, (await body.ReadAsync()).Buffer.Length); // Time out on the next read input.Http1Connection.SendTimeoutResponse(); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs index 16cd2dee2d..bca9b9d1b1 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/RequestTests.cs @@ -727,6 +727,170 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests } } + [Fact] + public async Task ContentLengthReadAsyncPipeReaderBufferRequestBody() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var readResult = await httpContext.Request.BodyReader.ReadAsync(); + // This will hang if 0 content length is not assumed by the server + Assert.Equal(5, readResult.Buffer.Length); + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + readResult = await httpContext.Request.BodyReader.ReadAsync(); + Assert.Equal(5, readResult.Buffer.Length); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "hello"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + await server.StopAsync(); + } + } + + [Fact] + public async Task ContentLengthReadAsyncPipeReaderBufferRequestBodyMultipleTimes() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var readResult = await httpContext.Request.BodyReader.ReadAsync(); + // This will hang if 0 content length is not assumed by the server + Assert.Equal(5, readResult.Buffer.Length); + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + + for (var i = 0; i < 2; i++) + { + readResult = await httpContext.Request.BodyReader.ReadAsync(); + Assert.Equal(5, readResult.Buffer.Length); + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + } + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "hello"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + await server.StopAsync(); + } + } + + [Fact] + public async Task ContentLengthReadAsyncSingleBytesAtATime() + { + var testContext = new TestServiceContext(LoggerFactory); + var tcs = new TaskCompletionSource(); + var tcs2 = new TaskCompletionSource(); + using (var server = new TestServer(async httpContext => + { + var readResult = await httpContext.Request.BodyReader.ReadAsync(); + Assert.Equal(3, readResult.Buffer.Length); + tcs.SetResult(null); + + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + + readResult = await httpContext.Request.BodyReader.ReadAsync(); + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + tcs2.SetResult(null); + + readResult = await httpContext.Request.BodyReader.ReadAsync(); + Assert.Equal(5, readResult.Buffer.Length); + + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "fun"); + await tcs.Task; + await connection.Send( + "n"); + await tcs2.Task; + await connection.Send( + "y"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + await server.StopAsync(); + } + } + + [Fact] + public async Task ContentLengthDoesNotConsumeEntireBufferDoesNotThrow() + { + var testContext = new TestServiceContext(LoggerFactory); + using (var server = new TestServer(async httpContext => + { + var readResult = await httpContext.Request.BodyReader.ReadAsync(); + + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End); + + readResult = await httpContext.Request.BodyReader.ReadAsync(); + httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Slice(1).Start, readResult.Buffer.End); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "funny"); + + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + await server.StopAsync(); + } + } + [Fact] public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes() {