From edc19354751dd43dc48e71cdfe192147e40122c8 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 6 Sep 2018 09:04:39 -0700 Subject: [PATCH] Flow the cancellation tokens to ReadAsync and WriteAsync (#2865) --- .../Internal/Http/Http1MessageBody.cs | 5 ++ .../Internal/Http/Http1OutputProducer.cs | 20 +----- .../Internal/Http/HttpProtocol.cs | 2 +- .../Internal/Http/IHttpOutputProducer.cs | 2 +- src/Kestrel.Core/Internal/Http/MessageBody.cs | 6 +- .../Internal/Http2/Http2OutputProducer.cs | 2 +- .../Kestrel.Core.Tests/OutputProducerTests.cs | 8 ++- .../RequestTests.cs | 68 +++++++++++++++++++ .../ResponseTests.cs | 56 +++++++++++++++ .../RequestTests.cs | 2 +- .../LibuvOutputConsumerTests.cs | 5 +- 11 files changed, 146 insertions(+), 30 deletions(-) diff --git a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs index 602df5100a..8440927229 100644 --- a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs @@ -155,6 +155,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } } + catch (OperationCanceledException) + { + // TryRead can throw OperationCanceledException https://github.com/dotnet/corefx/issues/32029 + // beacuse of buggy logic, this works around that for now + } catch (BadHttpRequestException ex) { // At this point, the response has already been written, so this won't result in a 4XX response; diff --git a/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs b/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs index 1337b6b38b..5acff95b16 100644 --- a/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http/Http1OutputProducer.cs @@ -73,23 +73,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return WriteAsync(Constants.EmptyData, cancellationToken); } - public void Write(Func callback, T state) - { - lock (_contextLock) - { - if (_completed) - { - return; - } - - var buffer = _pipeWriter; - var bytesCommitted = callback(buffer, state); - _unflushedBytes += bytesCommitted; - _totalBytesCommitted += bytesCommitted; - } - } - - public Task WriteAsync(Func callback, T state) + public Task WriteAsync(Func callback, T state, CancellationToken cancellationToken) { lock (_contextLock) { @@ -104,7 +88,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _totalBytesCommitted += bytesCommitted; } - return FlushAsync(); + return FlushAsync(cancellationToken); } public void WriteResponseHeaders(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders) diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs index 0c27877c4a..2e80c87bae 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs @@ -915,7 +915,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private Task WriteChunkedAsync(ReadOnlyMemory data, CancellationToken cancellationToken) { - return Output.WriteAsync(_writeChunk, data); + return Output.WriteAsync(_writeChunk, data, cancellationToken); } private static long WriteChunk(PipeWriter writableBuffer, ReadOnlyMemory buffer) diff --git a/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs b/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs index 6dbdaca7f4..25f3af7012 100644 --- a/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http/IHttpOutputProducer.cs @@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public interface IHttpOutputProducer : IDisposable { void Abort(ConnectionAbortedException abortReason); - Task WriteAsync(Func callback, T state); + Task WriteAsync(Func callback, T state, CancellationToken cancellationToken); Task FlushAsync(CancellationToken cancellationToken); Task Write100ContinueAsync(); void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpResponseHeaders responseHeaders); diff --git a/src/Kestrel.Core/Internal/Http/MessageBody.cs b/src/Kestrel.Core/Internal/Http/MessageBody.cs index ed0bab633b..2ef88d5f33 100644 --- a/src/Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/MessageBody.cs @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http while (true) { - var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var result = await _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken); var readableBuffer = result.Buffer; var consumed = readableBuffer.End; var actual = 0; @@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http while (true) { - var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var result = await _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken); var readableBuffer = result.Buffer; var consumed = readableBuffer.End; var bytesRead = 0; @@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http bytesRead += memory.Length; #if NETCOREAPP2_1 - await destination.WriteAsync(memory); + await destination.WriteAsync(memory, cancellationToken); #elif NETSTANDARD2_0 var array = memory.GetArray(); await destination.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken); diff --git a/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs b/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs index 1a9f88b4c7..405b339436 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs @@ -77,7 +77,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 Dispose(); } - public Task WriteAsync(Func callback, T state) + public Task WriteAsync(Func callback, T state, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/test/Kestrel.Core.Tests/OutputProducerTests.cs b/test/Kestrel.Core.Tests/OutputProducerTests.cs index 7f3d566ff5..22102de0be 100644 --- a/test/Kestrel.Core.Tests/OutputProducerTests.cs +++ b/test/Kestrel.Core.Tests/OutputProducerTests.cs @@ -5,6 +5,7 @@ using System; using System.Buffers; using System.IO.Pipelines; using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -31,7 +32,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } [Fact] - public void WritesNoopAfterConnectionCloses() + public async Task WritesNoopAfterConnectionCloses() { var pipeOptions = new PipeOptions ( @@ -48,12 +49,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var called = false; - socketOutput.Write((buffer, state) => + await socketOutput.WriteAsync((buffer, state) => { called = true; return 0; }, - 0); + 0, + default); Assert.False(called); } diff --git a/test/Kestrel.InMemory.FunctionalTests/RequestTests.cs b/test/Kestrel.InMemory.FunctionalTests/RequestTests.cs index 0edb95b273..6760c4b0b5 100644 --- a/test/Kestrel.InMemory.FunctionalTests/RequestTests.cs +++ b/test/Kestrel.InMemory.FunctionalTests/RequestTests.cs @@ -7,6 +7,7 @@ using System.IO; using System.IO.Pipelines; using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -55,6 +56,73 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests } } + [Fact] + public async Task RequestBodyReadAsyncCanBeCancelled() + { + var helloTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var cts = new CancellationTokenSource(); + + using (var server = new TestServer(async context => + { + var buffer = new byte[1024]; + try + { + + int read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length, cts.Token); + + Assert.Equal("Hello ", Encoding.UTF8.GetString(buffer, 0, read)); + + helloTcs.TrySetResult(null); + } + catch (Exception ex) + { + // This shouldn't fail + helloTcs.TrySetException(ex); + } + + try + { + await context.Request.Body.ReadAsync(buffer, 0, buffer.Length, cts.Token); + + context.Response.ContentLength = 12; + await context.Response.WriteAsync("Read success"); + } + catch (OperationCanceledException) + { + context.Response.ContentLength = 14; + await context.Response.WriteAsync("Read cancelled"); + } + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Connection: keep-alive", + "Content-Length: 11", + "", + ""); + + await connection.Send("Hello "); + + await helloTcs.Task; + + // Cancel the body after hello is read + cts.Cancel(); + + await connection.Send("World"); + + await connection.Receive($"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 14", + "", + "Read cancelled"); + } + } + } + [Fact] public async Task CanUpgradeRequestWithConnectionKeepAliveUpgradeHeader() { diff --git a/test/Kestrel.InMemory.FunctionalTests/ResponseTests.cs b/test/Kestrel.InMemory.FunctionalTests/ResponseTests.cs index e0e63fd0d5..26c1433b16 100644 --- a/test/Kestrel.InMemory.FunctionalTests/ResponseTests.cs +++ b/test/Kestrel.InMemory.FunctionalTests/ResponseTests.cs @@ -7,6 +7,7 @@ using System.IO; using System.Linq; using System.Net; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -99,6 +100,61 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests } } + [Fact] + public async Task ResponseBodyWriteAsyncCanBeCancelled() + { + var serviceContext = new TestServiceContext(LoggerFactory); + serviceContext.ServerOptions.Limits.MaxResponseBufferSize = 5; + var cts = new CancellationTokenSource(); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var writeReturnedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + try + { + await context.Response.WriteAsync("hello", cts.Token).DefaultTimeout(); + writeReturnedTcs.TrySetResult(null); + + var task = context.Response.WriteAsync("world", cts.Token); + Assert.False(task.IsCompleted); + await task.DefaultTimeout(); + } + catch (Exception ex) + { + appTcs.TrySetException(ex); + } + finally + { + appTcs.TrySetResult(null); + writeReturnedTcs.TrySetCanceled(); + } + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await connection.Receive($"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "5", + "hello"); + + await writeReturnedTcs.Task.DefaultTimeout(); + + cts.Cancel(); + + await Assert.ThrowsAsync(() => appTcs.Task).DefaultTimeout(); + } + } + } + [Fact] public Task ResponseStatusCodeSetBeforeHttpContextDisposeAppException() { diff --git a/test/Kestrel.Transport.FunctionalTests/RequestTests.cs b/test/Kestrel.Transport.FunctionalTests/RequestTests.cs index e194331df1..ce253762df 100644 --- a/test/Kestrel.Transport.FunctionalTests/RequestTests.cs +++ b/test/Kestrel.Transport.FunctionalTests/RequestTests.cs @@ -658,7 +658,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests await Assert.ThrowsAsync(async () => await readTcs.Task); // The cancellation token for only the last request should be triggered. - var abortedRequestId = await registrationTcs.Task; + var abortedRequestId = await registrationTcs.Task.DefaultTimeout(); Assert.Equal(2, abortedRequestId); Assert.Single(TestSink.Writes.Where(w => w.LoggerName == "Microsoft.AspNetCore.Server.Kestrel" && diff --git a/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs b/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs index 466f68ac40..48c2cee506 100644 --- a/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs +++ b/test/Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs @@ -303,12 +303,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests Assert.NotEmpty(completeQueue); // Add more bytes to the write-behind buffer to prevent the next write from - outputProducer.Write((writableBuffer, state) => + _ = outputProducer.WriteAsync((writableBuffer, state) => { writableBuffer.Write(state); return state.Count; }, - halfWriteBehindBuffer); + halfWriteBehindBuffer, + default); // Act var writeTask2 = outputProducer.WriteDataAsync(halfWriteBehindBuffer);