Flow the cancellation tokens to ReadAsync and WriteAsync (#2865)

This commit is contained in:
Chris Ross (ASP.NET) 2018-11-14 16:20:31 -08:00
parent c2685813ab
commit 88273412b8
11 changed files with 195 additions and 8 deletions

View File

@ -148,6 +148,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
}
catch (OperationCanceledException)
{
// TryRead can throw OperationCanceledException https://github.com/dotnet/corefx/issues/32029
// beacuse of buggy logic, this works around that for now
}
catch (BadHttpRequestException ex)
{
// At this point, the response has already been written, so this won't result in a 4XX response;

View File

@ -100,7 +100,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
}
}
public Task WriteAsync<T>(Func<PipeWriter, T, long> callback, T state)
public Task WriteAsync<T>(Func<PipeWriter, T, long> callback, T state, CancellationToken cancellationToken)
{
lock (_contextLock)
{
@ -115,7 +115,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
_totalBytesCommitted += bytesCommitted;
}
return FlushAsync();
return FlushAsync(cancellationToken);
}
public void WriteResponseHeaders(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders)

View File

@ -928,7 +928,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private Task WriteChunkedAsync(ReadOnlyMemory<byte> data, CancellationToken cancellationToken)
{
return Output.WriteAsync(_writeChunk, data);
return Output.WriteAsync(_writeChunk, data, cancellationToken);
}
private static long WriteChunk(PipeWriter writableBuffer, ReadOnlyMemory<byte> buffer)

View File

@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public interface IHttpOutputProducer : IDisposable
{
void Abort(ConnectionAbortedException abortReason);
Task WriteAsync<T>(Func<PipeWriter, T, long> callback, T state);
Task WriteAsync<T>(Func<PipeWriter, T, long> callback, T state, CancellationToken cancellationToken);
Task FlushAsync(CancellationToken cancellationToken);
Task Write100ContinueAsync(CancellationToken cancellationToken);
void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpResponseHeaders responseHeaders);

View File

@ -42,7 +42,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
while (true)
{
var result = await _context.RequestBodyPipe.Reader.ReadAsync();
var result = await _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken);
var readableBuffer = result.Buffer;
var consumed = readableBuffer.End;
@ -76,7 +76,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
while (true)
{
var result = await _context.RequestBodyPipe.Reader.ReadAsync();
var result = await _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken);
var readableBuffer = result.Buffer;
var consumed = readableBuffer.End;
@ -90,7 +90,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
// - The WriteAsync(ReadOnlyMemory<byte>) isn't overridden on the destination
// - We change the Kestrel Memory Pool to not use pinned arrays but instead use native memory
#if NETCOREAPP2_1
await destination.WriteAsync(memory);
await destination.WriteAsync(memory, cancellationToken);
#else
var array = memory.GetArray();
await destination.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken);

View File

@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
// TODO: RST_STREAM?
}
public Task WriteAsync<T>(Func<PipeWriter, T, long> callback, T state)
public Task WriteAsync<T>(Func<PipeWriter, T, long> callback, T state, CancellationToken cancellationToken)
{
throw new NotImplementedException();
}

View File

@ -222,6 +222,74 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
}
}
[Fact]
public async Task RequestBodyReadAsyncCanBeCancelled()
{
var helloTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
var readTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
var cts = new CancellationTokenSource();
using (var server = new TestServer(async context =>
{
var buffer = new byte[1024];
try
{
await context.Request.Body.ReadUntilLengthAsync(buffer, 6, cts.Token).DefaultTimeout();
Assert.Equal("Hello ", Encoding.ASCII.GetString(buffer, 0, 6));
helloTcs.TrySetResult(null);
}
catch (Exception ex)
{
// This shouldn't fail
helloTcs.TrySetException(ex);
}
try
{
var task = context.Request.Body.ReadAsync(buffer, 0, buffer.Length, cts.Token);
readTcs.TrySetResult(null);
await task;
context.Response.ContentLength = 12;
await context.Response.WriteAsync("Read success");
}
catch (OperationCanceledException)
{
context.Response.ContentLength = 14;
await context.Response.WriteAsync("Read cancelled");
}
}, new TestServiceContext(LoggerFactory)))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
"Connection: keep-alive",
"Content-Length: 11",
"",
"");
await connection.Send("Hello ");
await helloTcs.Task;
await readTcs.Task;
// Cancel the body after hello is read
cts.Cancel();
await connection.Receive($"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
"Content-Length: 14",
"",
"Read cancelled");
}
}
}
[Fact]
public void CanUpgradeRequestWithConnectionKeepAliveUpgradeHeader()
{

View File

@ -216,6 +216,73 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
}
}
[Fact]
public async Task ResponseBodyWriteAsyncCanBeCancelled()
{
var serviceContext = new TestServiceContext(LoggerFactory);
var cts = new CancellationTokenSource();
var appTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
var writeBlockedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
using (var server = new TestServer(async context =>
{
try
{
await context.Response.WriteAsync("hello", cts.Token).DefaultTimeout();
var data = new byte[1024 * 1024 * 10];
var timerTask = Task.Delay(TimeSpan.FromSeconds(1));
var writeTask = context.Response.Body.WriteAsync(data, 0, data.Length, cts.Token).DefaultTimeout();
var completedTask = await Task.WhenAny(writeTask, timerTask);
while (completedTask == writeTask)
{
await writeTask;
timerTask = Task.Delay(TimeSpan.FromSeconds(1));
writeTask = context.Response.Body.WriteAsync(data, 0, data.Length, cts.Token).DefaultTimeout();
completedTask = await Task.WhenAny(writeTask, timerTask);
}
writeBlockedTcs.TrySetResult(null);
await writeTask;
}
catch (Exception ex)
{
appTcs.TrySetException(ex);
writeBlockedTcs.TrySetException(ex);
}
finally
{
appTcs.TrySetResult(null);
}
}, serviceContext))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"Host:",
"",
"");
await connection.Receive($"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
"Transfer-Encoding: chunked",
"",
"5",
"hello");
await writeBlockedTcs.Task.DefaultTimeout();
cts.Cancel();
await Assert.ThrowsAsync<OperationCanceledException>(() => appTcs.Task).DefaultTimeout();
}
}
}
[Fact]
public Task ResponseStatusCodeSetBeforeHttpContextDisposeAppException()
{

View File

@ -0,0 +1,45 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace System.IO
{
public static class StreamFillBufferExtensions
{
public static async Task<int> ReadUntilEndAsync(this Stream stream, byte[] buffer, CancellationToken cancellationToken = default)
{
var offset = 0;
while (offset < buffer.Length)
{
var read = await stream.ReadAsync(buffer, offset, buffer.Length - offset, cancellationToken);
offset += read;
if (read == 0)
{
return offset;
}
}
Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1, cancellationToken));
return offset;
}
public static async Task ReadUntilLengthAsync(this Stream stream, byte[] buffer, int length, CancellationToken cancellationToken = default)
{
var offset = 0;
while (offset < length)
{
var read = await stream.ReadAsync(buffer, offset, length - offset, cancellationToken);
offset += read;
Assert.NotEqual(0, read);
}
}
}
}

View File

@ -23,6 +23,7 @@
<Reference Include="Microsoft.AspNetCore.Server.Kestrel.Https" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel.Core" />
<Reference Include="Newtonsoft.Json" />
</ItemGroup>

View File

@ -22,6 +22,7 @@
<Reference Include="Microsoft.AspNetCore.Server.Kestrel.Https" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel.Core" />
<Reference Include="Newtonsoft.Json" />
</ItemGroup>