Use examined rather than consumed for content length of body. (#8223)
This commit is contained in:
parent
f3eaa73c1a
commit
26c487b0c0
|
|
@ -11,11 +11,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
{
|
||||
internal class Http1ContentLengthMessageBody : Http1MessageBody
|
||||
{
|
||||
private ReadResult _readResult;
|
||||
private readonly long _contentLength;
|
||||
private long _inputLength;
|
||||
private ReadResult _readResult;
|
||||
private bool _readCompleted;
|
||||
private bool _completed;
|
||||
private bool _isReading;
|
||||
private int _userCanceled;
|
||||
private long _totalExaminedInPreviousReadResult;
|
||||
private bool _finalAdvanceCalled;
|
||||
|
||||
public Http1ContentLengthMessageBody(bool keepAlive, long contentLength, Http1Connection context)
|
||||
: base(context)
|
||||
|
|
@ -29,9 +33,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
{
|
||||
ThrowIfCompleted();
|
||||
|
||||
if (_inputLength == 0)
|
||||
if (_isReading)
|
||||
{
|
||||
_readResult = new ReadResult(default, isCanceled: false, isCompleted: true);
|
||||
throw new InvalidOperationException("Reading is already in progress.");
|
||||
}
|
||||
|
||||
if (_readCompleted)
|
||||
{
|
||||
_isReading = true;
|
||||
return _readResult;
|
||||
}
|
||||
|
||||
|
|
@ -53,6 +62,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
try
|
||||
{
|
||||
var readAwaitable = _context.Input.ReadAsync(cancellationToken);
|
||||
|
||||
_isReading = true;
|
||||
_readResult = await StartTimingReadAsync(readAwaitable, cancellationToken);
|
||||
}
|
||||
catch (ConnectionAbortedException ex)
|
||||
|
|
@ -102,9 +113,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
{
|
||||
ThrowIfCompleted();
|
||||
|
||||
if (_inputLength == 0)
|
||||
if (_isReading)
|
||||
{
|
||||
readResult = new ReadResult(default, isCanceled: false, isCompleted: true);
|
||||
throw new InvalidOperationException("Reading is already in progress.");
|
||||
}
|
||||
|
||||
if (_readCompleted)
|
||||
{
|
||||
_isReading = true;
|
||||
readResult = _readResult;
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
@ -126,6 +143,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
}
|
||||
}
|
||||
|
||||
// Only set _isReading if we are returing true.
|
||||
_isReading = true;
|
||||
|
||||
CreateReadResultFromConnectionReadResult();
|
||||
|
||||
readResult = _readResult;
|
||||
|
|
@ -133,6 +153,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
return true;
|
||||
}
|
||||
|
||||
public override Task ConsumeAsync()
|
||||
{
|
||||
TryStart();
|
||||
|
||||
if (!_readResult.Buffer.IsEmpty && _inputLength == 0)
|
||||
{
|
||||
_context.Input.AdvanceTo(_readResult.Buffer.End);
|
||||
}
|
||||
|
||||
return OnConsumeAsync();
|
||||
}
|
||||
|
||||
private void ThrowIfCompleted()
|
||||
{
|
||||
if (_completed)
|
||||
|
|
@ -143,13 +175,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
|
||||
private void CreateReadResultFromConnectionReadResult()
|
||||
{
|
||||
if (_readResult.Buffer.Length > _inputLength)
|
||||
if (_readResult.Buffer.Length >= _inputLength + _totalExaminedInPreviousReadResult)
|
||||
{
|
||||
_readResult = new ReadResult(_readResult.Buffer.Slice(0, _inputLength), _readResult.IsCanceled, isCompleted: true);
|
||||
}
|
||||
else if (_readResult.Buffer.Length == _inputLength)
|
||||
{
|
||||
_readResult = new ReadResult(_readResult.Buffer, _readResult.IsCanceled, isCompleted: true);
|
||||
_readCompleted = true;
|
||||
_readResult = new ReadResult(
|
||||
_readResult.Buffer.Slice(0, _inputLength + _totalExaminedInPreviousReadResult),
|
||||
_readResult.IsCanceled && Interlocked.Exchange(ref _userCanceled, 0) == 1,
|
||||
_readCompleted);
|
||||
}
|
||||
|
||||
if (_readResult.IsCompleted)
|
||||
|
|
@ -165,18 +197,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
|
|||
|
||||
public override void AdvanceTo(SequencePosition consumed, SequencePosition examined)
|
||||
{
|
||||
if (_inputLength == 0)
|
||||
if (!_isReading)
|
||||
{
|
||||
throw new InvalidOperationException("No reading operation to complete.");
|
||||
}
|
||||
|
||||
_isReading = false;
|
||||
|
||||
if (_readCompleted)
|
||||
{
|
||||
_readResult = new ReadResult(_readResult.Buffer.Slice(consumed, _readResult.Buffer.End), Interlocked.Exchange(ref _userCanceled, 0) == 1, _readCompleted);
|
||||
|
||||
if (_readResult.Buffer.Length == 0 && !_finalAdvanceCalled)
|
||||
{
|
||||
_context.Input.AdvanceTo(consumed);
|
||||
_finalAdvanceCalled = true;
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
var dataLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length;
|
||||
|
||||
_inputLength -= dataLength;
|
||||
var consumedLength = _readResult.Buffer.Slice(_readResult.Buffer.Start, consumed).Length;
|
||||
var examinedLength = consumedLength + _readResult.Buffer.Slice(consumed, examined).Length;
|
||||
|
||||
_context.Input.AdvanceTo(consumed, examined);
|
||||
|
||||
OnDataRead(dataLength);
|
||||
var newlyExamined = examinedLength - _totalExaminedInPreviousReadResult;
|
||||
|
||||
OnDataRead(newlyExamined);
|
||||
_totalExaminedInPreviousReadResult += newlyExamined;
|
||||
_inputLength -= newlyExamined;
|
||||
|
||||
_totalExaminedInPreviousReadResult -= consumedLength;
|
||||
}
|
||||
|
||||
protected override void OnReadStarting()
|
||||
|
|
|
|||
|
|
@ -790,7 +790,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
|
|||
|
||||
// Add some input and read it to start PumpAsync
|
||||
input.Add("a");
|
||||
Assert.Equal(1, (await body.ReadAsync()).Buffer.Length);
|
||||
|
||||
// Time out on the next read
|
||||
input.Http1Connection.SendTimeoutResponse();
|
||||
|
|
|
|||
|
|
@ -727,6 +727,170 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ContentLengthReadAsyncPipeReaderBufferRequestBody()
|
||||
{
|
||||
var testContext = new TestServiceContext(LoggerFactory);
|
||||
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
var readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
// This will hang if 0 content length is not assumed by the server
|
||||
Assert.Equal(5, readResult.Buffer.Length);
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
|
||||
readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
Assert.Equal(5, readResult.Buffer.Length);
|
||||
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = server.CreateConnection())
|
||||
{
|
||||
await connection.SendAll(
|
||||
"POST / HTTP/1.0",
|
||||
"Host:",
|
||||
"Content-Length: 5",
|
||||
"",
|
||||
"hello");
|
||||
await connection.ReceiveEnd(
|
||||
"HTTP/1.1 200 OK",
|
||||
"Connection: close",
|
||||
$"Date: {testContext.DateHeaderValue}",
|
||||
"Content-Length: 0",
|
||||
"",
|
||||
"");
|
||||
}
|
||||
|
||||
await server.StopAsync();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ContentLengthReadAsyncPipeReaderBufferRequestBodyMultipleTimes()
|
||||
{
|
||||
var testContext = new TestServiceContext(LoggerFactory);
|
||||
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
var readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
// This will hang if 0 content length is not assumed by the server
|
||||
Assert.Equal(5, readResult.Buffer.Length);
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
|
||||
|
||||
for (var i = 0; i < 2; i++)
|
||||
{
|
||||
readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
Assert.Equal(5, readResult.Buffer.Length);
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
|
||||
}
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = server.CreateConnection())
|
||||
{
|
||||
await connection.SendAll(
|
||||
"POST / HTTP/1.0",
|
||||
"Host:",
|
||||
"Content-Length: 5",
|
||||
"",
|
||||
"hello");
|
||||
await connection.ReceiveEnd(
|
||||
"HTTP/1.1 200 OK",
|
||||
"Connection: close",
|
||||
$"Date: {testContext.DateHeaderValue}",
|
||||
"Content-Length: 0",
|
||||
"",
|
||||
"");
|
||||
}
|
||||
|
||||
await server.StopAsync();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ContentLengthReadAsyncSingleBytesAtATime()
|
||||
{
|
||||
var testContext = new TestServiceContext(LoggerFactory);
|
||||
var tcs = new TaskCompletionSource<object>();
|
||||
var tcs2 = new TaskCompletionSource<object>();
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
var readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
Assert.Equal(3, readResult.Buffer.Length);
|
||||
tcs.SetResult(null);
|
||||
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
|
||||
|
||||
readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
|
||||
tcs2.SetResult(null);
|
||||
|
||||
readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
Assert.Equal(5, readResult.Buffer.Length);
|
||||
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = server.CreateConnection())
|
||||
{
|
||||
await connection.Send(
|
||||
"POST / HTTP/1.0",
|
||||
"Host:",
|
||||
"Content-Length: 5",
|
||||
"",
|
||||
"fun");
|
||||
await tcs.Task;
|
||||
await connection.Send(
|
||||
"n");
|
||||
await tcs2.Task;
|
||||
await connection.Send(
|
||||
"y");
|
||||
await connection.ReceiveEnd(
|
||||
"HTTP/1.1 200 OK",
|
||||
"Connection: close",
|
||||
$"Date: {testContext.DateHeaderValue}",
|
||||
"Content-Length: 0",
|
||||
"",
|
||||
"");
|
||||
}
|
||||
|
||||
await server.StopAsync();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ContentLengthDoesNotConsumeEntireBufferDoesNotThrow()
|
||||
{
|
||||
var testContext = new TestServiceContext(LoggerFactory);
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
var readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
|
||||
|
||||
readResult = await httpContext.Request.BodyReader.ReadAsync();
|
||||
httpContext.Request.BodyReader.AdvanceTo(readResult.Buffer.Slice(1).Start, readResult.Buffer.End);
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = server.CreateConnection())
|
||||
{
|
||||
await connection.SendAll(
|
||||
"POST / HTTP/1.0",
|
||||
"Host:",
|
||||
"Content-Length: 5",
|
||||
"",
|
||||
"funny");
|
||||
|
||||
await connection.ReceiveEnd(
|
||||
"HTTP/1.1 200 OK",
|
||||
"Connection: close",
|
||||
$"Date: {testContext.DateHeaderValue}",
|
||||
"Content-Length: 0",
|
||||
"",
|
||||
"");
|
||||
}
|
||||
|
||||
await server.StopAsync();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ConnectionClosesWhenFinReceivedBeforeRequestCompletes()
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue