Fix Complete in Message Bodies (#11066)

This commit is contained in:
Justin Kotalik 2019-06-25 15:34:08 -07:00 committed by GitHub
parent 7003fb53d5
commit 73a2603aa6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 197 additions and 16 deletions

View File

@ -50,6 +50,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
public override bool TryRead(out ReadResult readResult)
{
ThrowIfCompleted();
return TryReadInternal(out readResult);
}
public override bool TryReadInternal(out ReadResult readResult)
{
TryStart();
@ -65,7 +72,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return boolResult;
}
public override async ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
{
ThrowIfCompleted();
return ReadAsyncInternal(cancellationToken);
}
public override async ValueTask<ReadResult> ReadAsyncInternal(CancellationToken cancellationToken = default)
{
TryStart();
@ -92,7 +105,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override void Complete(Exception exception)
{
_requestBodyPipe.Reader.Complete();
_completed = true;
_context.ReportApplicationError(exception);
}

View File

@ -15,7 +15,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private readonly long _contentLength;
private long _inputLength;
private bool _readCompleted;
private bool _completed;
private bool _isReading;
private int _userCanceled;
private long _totalExaminedInPreviousReadResult;
@ -29,10 +28,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
_inputLength = _contentLength;
}
public override async ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = default)
{
ThrowIfCompleted();
return ReadAsyncInternal(cancellationToken);
}
public override async ValueTask<ReadResult> ReadAsyncInternal(CancellationToken cancellationToken = default)
{
if (_isReading)
{
throw new InvalidOperationException("Reading is already in progress.");
@ -112,7 +115,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override bool TryRead(out ReadResult readResult)
{
ThrowIfCompleted();
return TryReadInternal(out readResult);
}
public override bool TryReadInternal(out ReadResult readResult)
{
if (_isReading)
{
throw new InvalidOperationException("Reading is already in progress.");
@ -164,14 +171,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return OnConsumeAsync();
}
private void ThrowIfCompleted()
{
if (_completed)
{
throw new InvalidOperationException("Reading is not allowed after the reader was completed.");
}
}
private void CreateReadResultFromConnectionReadResult()
{

View File

@ -3,6 +3,7 @@
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
@ -12,6 +13,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
internal abstract class Http1MessageBody : MessageBody
{
protected readonly Http1Connection _context;
protected bool _completed;
protected Http1MessageBody(Http1Connection context)
: base(context)
@ -34,11 +36,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
public abstract bool TryReadInternal(out ReadResult readResult);
public abstract ValueTask<ReadResult> ReadAsyncInternal(CancellationToken cancellationToken = default);
protected override Task OnConsumeAsync()
{
try
{
if (TryRead(out var readResult))
while (TryReadInternal(out var readResult))
{
AdvanceTo(readResult.Buffer.End);
@ -79,7 +85,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ReadResult result;
do
{
result = await ReadAsync();
result = await ReadAsyncInternal();
AdvanceTo(result.Buffer.End);
} while (!result.IsCompleted);
}
@ -177,5 +183,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
context.OnTrailersComplete(); // No trailers for these.
return keepAlive ? MessageBody.ZeroContentLengthKeepAlive : MessageBody.ZeroContentLengthClose;
}
protected void ThrowIfCompleted()
{
if (_completed)
{
throw new InvalidOperationException("Reading is not allowed after the reader was completed.");
}
}
}
}

View File

@ -14,7 +14,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
/// </summary>
internal sealed class Http1UpgradeMessageBody : Http1MessageBody
{
public bool _completed;
public Http1UpgradeMessageBody(Http1Connection context)
: base(context)
{
@ -78,5 +77,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
return Task.CompletedTask;
}
public override bool TryReadInternal(out ReadResult readResult)
{
return _context.Input.TryRead(out readResult);
}
public override ValueTask<ReadResult> ReadAsyncInternal(CancellationToken cancellationToken = default)
{
return _context.Input.ReadAsync(cancellationToken);
}
}
}

View File

@ -1235,6 +1235,30 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}
}
[Fact]
public async Task CompleteForContentLengthAllowsConsumeToWork()
{
using (var input = new TestInput())
{
var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderContentLength = "5" }, input.Http1Connection);
var reader = new HttpRequestPipeReader();
reader.StartAcceptingReads(body);
input.Add("a");
Assert.True(reader.TryRead(out var readResult));
Assert.False(readResult.IsCompleted);
input.Add("asdf");
reader.AdvanceTo(readResult.Buffer.End);
reader.Complete();
await body.ConsumeAsync();
}
}
[Fact]
public async Task CompleteForContentLengthDoesNotCompleteConnectionPipeMakesReadReturnThrow()
{
@ -1261,6 +1285,30 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}
}
[Fact]
public async Task CompleteForChunkedAllowsConsumeToWork()
{
using (var input = new TestInput())
{
var body = Http1MessageBody.For(HttpVersion.Http11, new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, input.Http1Connection);
var reader = new HttpRequestPipeReader();
reader.StartAcceptingReads(body);
input.Add("5\r\nHello\r\n");
Assert.True(reader.TryRead(out var readResult));
Assert.False(readResult.IsCompleted);
reader.AdvanceTo(readResult.Buffer.End);
input.Add("1\r\nH\r\n0\r\n\r\n");
reader.Complete();
await body.ConsumeAsync();
}
}
[Fact]
public async Task CompleteForChunkedDoesNotCompleteConnectionPipeMakesReadThrow()
{
@ -1313,7 +1361,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}
}
[Fact]
public async Task CompleteForZeroByteBodyDoesNotCompleteConnectionPipeNoopsReads()
{

View File

@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
@ -991,6 +992,62 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
}
}
[Fact]
public async Task ChunkedRequestCallCompleteDoesNotCauseException()
{
var testContext = new TestServiceContext(LoggerFactory);
await using (var server = new TestServer(async httpContext =>
{
var request = httpContext.Request;
// This read may receive all data, but what we care about
// is that ConsumeAsync is called and doesn't error. Calling
// TryRead before would always fail.
var readResult = await request.BodyReader.ReadAsync();
request.BodyReader.AdvanceTo(readResult.Buffer.End);
request.BodyReader.Complete();
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
"Transfer-Encoding: chunked",
"",
"1",
"H",
"4",
"ello",
"0",
"",
"");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"");
// start another request to make sure OnComsumeAsync is hit
await connection.Send(
"POST / HTTP/1.1",
"Host:",
"Transfer-Encoding: chunked",
"",
"0",
"",
"");
}
}
Assert.All(TestSink.Writes, w => Assert.InRange(w.LogLevel, LogLevel.Trace, LogLevel.Information));
}
[Fact]
public async Task ChunkedRequestCallCompleteWithExceptionCauses500()
{

View File

@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport;
using Microsoft.AspNetCore.Testing;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
@ -1535,6 +1536,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
}
}
[Fact]
public async Task ContentLengthRequestCallCompleteDoesNotCauseException()
{
var testContext = new TestServiceContext(LoggerFactory);
var tcs = new TaskCompletionSource<object>();
await using (var server = new TestServer(async httpContext =>
{
var request = httpContext.Request;
var readResult = await request.BodyReader.ReadAsync();
request.BodyReader.AdvanceTo(readResult.Buffer.End);
httpContext.Request.BodyReader.Complete();
tcs.SetResult(null);
}, testContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
"Content-Length: 5",
"",
"He");
await tcs.Task;
await connection.Send("llo");
await connection.Receive(
"HTTP/1.1 200 OK",
$"Date: {testContext.DateHeaderValue}",
"Content-Length: 0",
"",
"");
}
}
Assert.All(TestSink.Writes, w => Assert.InRange(w.LogLevel, LogLevel.Trace, LogLevel.Information));
}
[Fact]
public async Task ContentLengthCallCompleteWithExceptionCauses500()
{