diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs index f11619d7f3..7aeea58f34 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1MessageBody.cs @@ -148,6 +148,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } } + catch (OperationCanceledException) + { + // TryRead can throw OperationCanceledException https://github.com/dotnet/corefx/issues/32029 + // beacuse of buggy logic, this works around that for now + } catch (BadHttpRequestException ex) { // At this point, the response has already been written, so this won't result in a 4XX response; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs index 685980d1c6..cf0443fbe7 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs @@ -100,7 +100,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } - public Task WriteAsync(Func callback, T state) + public Task WriteAsync(Func callback, T state, CancellationToken cancellationToken) { lock (_contextLock) { @@ -115,7 +115,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _totalBytesCommitted += bytesCommitted; } - return FlushAsync(); + return FlushAsync(cancellationToken); } public void WriteResponseHeaders(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index b0ae93147d..14871734dd 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -928,7 +928,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private Task WriteChunkedAsync(ReadOnlyMemory data, CancellationToken cancellationToken) { - return Output.WriteAsync(_writeChunk, data); + return Output.WriteAsync(_writeChunk, data, cancellationToken); } private static long WriteChunk(PipeWriter writableBuffer, ReadOnlyMemory buffer) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs index 41dfdbbbec..5e7a1d5b75 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs @@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public interface IHttpOutputProducer : IDisposable { void Abort(ConnectionAbortedException abortReason); - Task WriteAsync(Func callback, T state); + Task WriteAsync(Func callback, T state, CancellationToken cancellationToken); Task FlushAsync(CancellationToken cancellationToken); Task Write100ContinueAsync(CancellationToken cancellationToken); void WriteResponseHeaders(int statusCode, string ReasonPhrase, HttpResponseHeaders responseHeaders); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs index 33bd8ebfb5..5d0dee8db9 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs @@ -42,7 +42,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http while (true) { - var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var result = await _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken); var readableBuffer = result.Buffer; var consumed = readableBuffer.End; @@ -76,7 +76,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http while (true) { - var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var result = await _context.RequestBodyPipe.Reader.ReadAsync(cancellationToken); var readableBuffer = result.Buffer; var consumed = readableBuffer.End; @@ -90,7 +90,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http // - The WriteAsync(ReadOnlyMemory) isn't overridden on the destination // - We change the Kestrel Memory Pool to not use pinned arrays but instead use native memory #if NETCOREAPP2_1 - await destination.WriteAsync(memory); + await destination.WriteAsync(memory, cancellationToken); #else var array = memory.GetArray(); await destination.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs index e701654d1e..24ab8db7e6 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 // TODO: RST_STREAM? } - public Task WriteAsync(Func callback, T state) + public Task WriteAsync(Func callback, T state, CancellationToken cancellationToken) { throw new NotImplementedException(); } diff --git a/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs b/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs index b528e80b3a..c41212972d 100644 --- a/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs +++ b/src/Servers/Kestrel/test/FunctionalTests/RequestTests.cs @@ -222,6 +222,74 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task RequestBodyReadAsyncCanBeCancelled() + { + var helloTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var readTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var cts = new CancellationTokenSource(); + + using (var server = new TestServer(async context => + { + var buffer = new byte[1024]; + try + { + await context.Request.Body.ReadUntilLengthAsync(buffer, 6, cts.Token).DefaultTimeout(); + + Assert.Equal("Hello ", Encoding.ASCII.GetString(buffer, 0, 6)); + + helloTcs.TrySetResult(null); + } + catch (Exception ex) + { + // This shouldn't fail + helloTcs.TrySetException(ex); + } + + try + { + var task = context.Request.Body.ReadAsync(buffer, 0, buffer.Length, cts.Token); + readTcs.TrySetResult(null); + await task; + + context.Response.ContentLength = 12; + await context.Response.WriteAsync("Read success"); + } + catch (OperationCanceledException) + { + context.Response.ContentLength = 14; + await context.Response.WriteAsync("Read cancelled"); + } + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Connection: keep-alive", + "Content-Length: 11", + "", + ""); + + await connection.Send("Hello "); + + await helloTcs.Task; + await readTcs.Task; + + // Cancel the body after hello is read + cts.Cancel(); + + await connection.Receive($"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 14", + "", + "Read cancelled"); + } + } + } + [Fact] public void CanUpgradeRequestWithConnectionKeepAliveUpgradeHeader() { diff --git a/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs b/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs index 39b6725ee3..6cf324d36d 100644 --- a/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs +++ b/src/Servers/Kestrel/test/FunctionalTests/ResponseTests.cs @@ -216,6 +216,73 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task ResponseBodyWriteAsyncCanBeCancelled() + { + var serviceContext = new TestServiceContext(LoggerFactory); + var cts = new CancellationTokenSource(); + var appTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var writeBlockedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (var server = new TestServer(async context => + { + try + { + await context.Response.WriteAsync("hello", cts.Token).DefaultTimeout(); + + var data = new byte[1024 * 1024 * 10]; + + var timerTask = Task.Delay(TimeSpan.FromSeconds(1)); + var writeTask = context.Response.Body.WriteAsync(data, 0, data.Length, cts.Token).DefaultTimeout(); + var completedTask = await Task.WhenAny(writeTask, timerTask); + + while (completedTask == writeTask) + { + await writeTask; + timerTask = Task.Delay(TimeSpan.FromSeconds(1)); + writeTask = context.Response.Body.WriteAsync(data, 0, data.Length, cts.Token).DefaultTimeout(); + completedTask = await Task.WhenAny(writeTask, timerTask); + } + + writeBlockedTcs.TrySetResult(null); + + await writeTask; + } + catch (Exception ex) + { + appTcs.TrySetException(ex); + writeBlockedTcs.TrySetException(ex); + } + finally + { + appTcs.TrySetResult(null); + } + }, serviceContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + await connection.Receive($"HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "5", + "hello"); + + await writeBlockedTcs.Task.DefaultTimeout(); + + cts.Cancel(); + + await Assert.ThrowsAsync(() => appTcs.Task).DefaultTimeout(); + } + } + } + [Fact] public Task ResponseStatusCodeSetBeforeHttpContextDisposeAppException() { diff --git a/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/StreamExtensions.cs b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/StreamExtensions.cs new file mode 100644 index 0000000000..8fb041b7ef --- /dev/null +++ b/src/Servers/Kestrel/test/FunctionalTests/TestHelpers/StreamExtensions.cs @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace System.IO +{ + public static class StreamFillBufferExtensions + { + public static async Task ReadUntilEndAsync(this Stream stream, byte[] buffer, CancellationToken cancellationToken = default) + { + var offset = 0; + + while (offset < buffer.Length) + { + var read = await stream.ReadAsync(buffer, offset, buffer.Length - offset, cancellationToken); + offset += read; + + if (read == 0) + { + return offset; + } + } + + Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1, cancellationToken)); + + return offset; + } + + public static async Task ReadUntilLengthAsync(this Stream stream, byte[] buffer, int length, CancellationToken cancellationToken = default) + { + var offset = 0; + + while (offset < length) + { + var read = await stream.ReadAsync(buffer, offset, length - offset, cancellationToken); + offset += read; + + Assert.NotEqual(0, read); + } + } + } +} diff --git a/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj b/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj index 96bd15c6ac..742b913e55 100644 --- a/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj +++ b/src/Servers/Kestrel/test/Libuv.FunctionalTests/Libuv.FunctionalTests.csproj @@ -23,6 +23,7 @@ + diff --git a/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj b/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj index b4e5e24498..8600f1086c 100644 --- a/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj +++ b/src/Servers/Kestrel/test/Sockets.FunctionalTests/Sockets.FunctionalTests.csproj @@ -22,6 +22,7 @@ +