Use examined rather than consumed for content length of body. (#8223)

This commit is contained in:
Justin Kotalik 2019-03-19 08:25:57 -07:00 committed by GitHub
parent f3eaa73c1a
commit 26c487b0c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 232 additions and 17 deletions

View File

@ -11,11 +11,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
internal class Http1ContentLengthMessageBody : Http1MessageBody
{
private ReadResult _readResult;
private readonly long _contentLength;
private long _inputLength;
private ReadResult _readResult;
private bool _readCompleted;
private bool _completed;
private bool _isReading;
private int _userCanceled;
private long _totalExaminedInPreviousReadResult;
private bool _finalAdvanceCalled;
public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context)
: base(context)
@ -29,9 +33,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
ThrowIfCompleted();
if (_inputLength == 0)
if (_isReading)
{
_readResult = new ReadResult(default, isCanceled: false, isCompleted: true);
throw new InvalidOperationException("Reading is already in progress.");
}
if (_readCompleted)
{
_isReading = true;
return _readResult;
}
@ -53,6 +62,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
try
{
var readAwaitable = _context.Input.ReadAsync(cancellationToken);
_isReading = true;
_readResult = await StartTimingReadAsync(readAwaitable, cancellationToken);
}
catch (ConnectionAbortedException ex)
@ -102,9 +113,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
ThrowIfCompleted();
if (_inputLength == 0)
if (_isReading)
{
readResult = new ReadResult(default, isCanceled: false, isCompleted: true);
throw new InvalidOperationException("Reading is already in progress.");
}
if (_readCompleted)
{
_isReading = true;
readResult = _readResult;
return true;
}
@ -126,6 +143,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
// Only set _isReading if we are returing true.
_isReading = true;
CreateReadResultFromConnectionReadResult();
readResult = _readResult;
@ -133,6 +153,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return true;
}
public override Task ConsumeAsync()
{
TryStart();
if (!_readResult.Buffer.IsEmpty && _inputLength == 0)
{
_context.Input.AdvanceTo(_readResult.Buffer.End);
}
return OnConsumeAsync();
}
private void ThrowIfCompleted()
{
if (_completed)
@ -143,13 +175,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private void CreateReadResultFromConnectionReadResult()
{
if (_readResult.Buffer.Length > _inputLength)
if (_readResult.Buffer.Length >= _inputLength + _totalExaminedInPreviousReadResult)
{
_readResult = new ReadResult(_readResult.Buffer.Slice(0, _inputLength), _readResult.IsCanceled, isCompleted: true);
}
else if (_readResult.Buffer.Length == _inputLength)
{
_readResult = new ReadResult(_readResult.Buffer, _readResult.IsCanceled, isCompleted: true);
_readCompleted = true;
_readResult = new ReadResult(
_readResult.Buffer.Slice(0, _inputLength + _totalExaminedInPreviousReadResult),
_readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 1,
_readCompleted);
}
if (_readResult.IsCompleted)
@ -165,18 +197,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override void AdvanceTo(SequencePosition consumed, SequencePosition examined)
{
if (_inputLength == 0)
if (!_isReading)
{
throw new InvalidOperationException("No reading operation to complete.");
}
_isReading = false;
if (_readCompleted)
{
_readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), Interlocked.Exchange(ref _userCanceled, 0) == 1, _readCompleted);
if (_readResult.Buffer.Length == 0 && !_finalAdvanceCalled)
{
_context.Input.AdvanceTo(consumed);
_finalAdvanceCalled = true;
}
return;
}
var dataLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length;
_inputLength -= dataLength;
var consumedLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length;
var examinedLength = consumedLength + _readResult.Buffer.Slice(consumed, examined).Length;
_context.Input.AdvanceTo(consumed, examined);
OnDataRead(dataLength);
var newlyExamined = examinedLength - _totalExaminedInPreviousReadResult;
OnDataRead(newlyExamined);
_totalExaminedInPreviousReadResult += newlyExamined;
_inputLength -= newlyExamined;
_totalExaminedInPreviousReadResult -= consumedLength;
}
protected override void OnReadStarting()

View File

@ -790,7 +790,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
// Add some input and read it to start PumpAsync
input.Add("a");
Assert.Equal(1, (await body.ReadAsync()).Buffer.Length);
// Time out on the next read
input.Http1Connection.SendTimeoutResponse();

View File

@ -727,6 +727,170 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
}
}
[Fact]
public async Task ContentLengthReadAsyncPipeReaderBufferRequestBody()
{
var testContext = new TestServiceContext(LoggerFactory);
using (var server = new TestServer(async httpContext =>
{
var readResult = await httpContext.Request.BodyReader.ReadAsync();
// This will hang if 0 content length is not assumed by the server
Assert.Equal(5, readResult.Buffer.Length);
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
readResult = await httpContext.Request.BodyReader.ReadAsync();
Assert.Equal(5, readResult.Buffer.Length);
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.SendAll(
"POST / HTTP/1.0",
"Host:",
"Content-Length: 5",
"",
"hello");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
"Connection: close",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
await server.StopAsync();
}
}
[Fact]
public async Task ContentLengthReadAsyncPipeReaderBufferRequestBodyMultipleTimes()
{
var testContext = new TestServiceContext(LoggerFactory);
using (var server = new TestServer(async httpContext =>
{
var readResult = await httpContext.Request.BodyReader.ReadAsync();
// This will hang if 0 content length is not assumed by the server
Assert.Equal(5, readResult.Buffer.Length);
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
for (var i = 0; i < 2; i++)
{
readResult = await httpContext.Request.BodyReader.ReadAsync();
Assert.Equal(5, readResult.Buffer.Length);
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
}
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.SendAll(
"POST / HTTP/1.0",
"Host:",
"Content-Length: 5",
"",
"hello");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
"Connection: close",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
await server.StopAsync();
}
}
[Fact]
public async Task ContentLengthReadAsyncSingleBytesAtATime()
{
var testContext = new TestServiceContext(LoggerFactory);
var tcs = new TaskCompletionSource<object>();
var tcs2 = new TaskCompletionSource<object>();
using (var server = new TestServer(async httpContext =>
{
var readResult = await httpContext.Request.BodyReader.ReadAsync();
Assert.Equal(3, readResult.Buffer.Length);
tcs.SetResult(null);
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
readResult = await httpContext.Request.BodyReader.ReadAsync();
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
tcs2.SetResult(null);
readResult = await httpContext.Request.BodyReader.ReadAsync();
Assert.Equal(5, readResult.Buffer.Length);
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.0",
"Host:",
"Content-Length: 5",
"",
"fun");
await tcs.Task;
await connection.Send(
"n");
await tcs2.Task;
await connection.Send(
"y");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
"Connection: close",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
await server.StopAsync();
}
}
[Fact]
public async Task ContentLengthDoesNotConsumeEntireBufferDoesNotThrow()
{
var testContext = new TestServiceContext(LoggerFactory);
using (var server = new TestServer(async httpContext =>
{
var readResult = await httpContext.Request.BodyReader.ReadAsync();
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
readResult = await httpContext.Request.BodyReader.ReadAsync();
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Slice(1).Start, readResult.Buffer.End);
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.SendAll(
"POST / HTTP/1.0",
"Host:",
"Content-Length: 5",
"",
"funny");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
"Connection: close",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
await server.StopAsync();
}
}
[Fact]
public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes()
{