diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs index d5225447d1..5cd4f8c378 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs @@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private bool _requestProcessingStarted; private Task _requestProcessingTask; protected volatile bool _requestProcessingStopping; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx - protected volatile bool _requestAborted; + protected int _requestAborted; protected CancellationTokenSource _abortedCts; protected CancellationToken? _manuallySetRequestAbortToken; @@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http var cts = _abortedCts; return cts != null ? cts.Token : - _requestAborted ? new CancellationToken(true) : + (Volatile.Read(ref _requestAborted) == 1) ? new CancellationToken(true) : RequestAbortedSource.Token; } set @@ -185,7 +185,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http // Get the abort token, lazily-initializing it if necessary. // Make sure it's canceled if an abort request already came in. var cts = LazyInitializer.EnsureInitialized(ref _abortedCts, () => new CancellationTokenSource()); - if (_requestAborted) + if (Volatile.Read(ref _requestAborted) == 1) { cts.Cancel(); } @@ -288,24 +288,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http /// public void Abort() { - _requestProcessingStopping = true; - _requestAborted = true; + if (Interlocked.CompareExchange(ref _requestAborted, 1, 0) == 0) + { + _requestProcessingStopping = true; - _requestBody?.Abort(); - _responseBody?.Abort(); + _requestBody?.Abort(); + _responseBody?.Abort(); - try - { - ConnectionControl.End(ProduceEndType.SocketDisconnect); - SocketInput.AbortAwaiting(); - RequestAbortedSource.Cancel(); - } - catch (Exception ex) - { - Log.LogError("Abort", ex); - } - finally - { + try + { + ConnectionControl.End(ProduceEndType.SocketDisconnect); + SocketInput.AbortAwaiting(); + } + catch (Exception ex) + { + Log.LogError("Abort", ex); + } + + try + { + RequestAbortedSource.Cancel(); + } + catch (Exception ex) + { + Log.LogError("Abort", ex); + } _abortedCts = null; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs index 0f3d8d48d6..e9498b1d40 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs @@ -3,6 +3,7 @@ using System; using System.Net; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http.Features; @@ -111,7 +112,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _application.DisposeContext(context, _applicationException); // If _requestAbort is set, the connection has already been closed. - if (!_requestAborted) + if (Volatile.Read(ref _requestAborted) == 0) { _responseBody.ResumeAcceptingWrites(); await ProduceEnd(); @@ -148,7 +149,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _abortedCts = null; // If _requestAborted is set, the connection has already been closed. - if (!_requestAborted) + if (Volatile.Read(ref _requestAborted) == 0) { // Inform client no more data will ever arrive ConnectionControl.End(ProduceEndType.SocketShutdownSend); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs index dc70690b32..efd3669104 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameRequestStream.cs @@ -5,6 +5,7 @@ using System; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNet.Server.Kestrel.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Http { @@ -51,8 +52,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public override int Read(byte[] buffer, int offset, int count) { - ValidateState(); - // ValueTask uses .GetAwaiter().GetResult() if necessary return ReadAsync(buffer, offset, count).Result; } @@ -60,7 +59,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http #if NET451 public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) { - ValidateState(); + ValidateState(CancellationToken.None); var task = ReadAsync(buffer, offset, count, CancellationToken.None, state); if (callback != null) @@ -77,7 +76,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken, object state) { - ValidateState(); + ValidateState(cancellationToken); var tcs = new TaskCompletionSource(state); var task = _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken); @@ -103,10 +102,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - ValidateState(); - - // Needs .AsTask to match Stream's Async method return types - return _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken).AsTask(); + var task = ValidateState(cancellationToken); + if (task == null) + { + // Needs .AsTask to match Stream's Async method return types + return _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken).AsTask(); + } + return task; } public override void Write(byte[] buffer, int offset, int count) @@ -149,24 +151,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public void Abort() { // We don't want to throw an ODE until the app func actually completes. - // If the request is aborted, we throw an IOException instead. + // If the request is aborted, we throw an TaskCanceledException instead. if (_state != FrameStreamState.Closed) { _state = FrameStreamState.Aborted; } } - private void ValidateState() + private Task ValidateState(CancellationToken cancellationToken) { switch (_state) { case FrameStreamState.Open: - return; + if (cancellationToken.IsCancellationRequested) + { + return TaskUtilities.GetCancelledZeroTask(); + } + break; case FrameStreamState.Closed: throw new ObjectDisposedException(nameof(FrameRequestStream)); case FrameStreamState.Aborted: - throw new IOException("The request has been aborted."); + return TaskUtilities.GetCancelledZeroTask(); } + return null; } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs index f45470f555..d5b4f41aea 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameResponseStream.cs @@ -5,6 +5,7 @@ using System; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNet.Server.Kestrel.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Http { @@ -37,16 +38,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public override void Flush() { - ValidateState(); + ValidateState(CancellationToken.None); _context.FrameControl.Flush(); } public override Task FlushAsync(CancellationToken cancellationToken) { - ValidateState(); - - return _context.FrameControl.FlushAsync(cancellationToken); + var task = ValidateState(cancellationToken); + if (task == null) + { + return _context.FrameControl.FlushAsync(cancellationToken); + } + return task; } public override long Seek(long offset, SeekOrigin origin) @@ -66,16 +70,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public override void Write(byte[] buffer, int offset, int count) { - ValidateState(); + ValidateState(CancellationToken.None); _context.FrameControl.Write(new ArraySegment(buffer, offset, count)); } public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { - ValidateState(); - - return _context.FrameControl.WriteAsync(new ArraySegment(buffer, offset, count), cancellationToken); + var task = ValidateState(cancellationToken); + if (task == null) + { + return _context.FrameControl.WriteAsync(new ArraySegment(buffer, offset, count), cancellationToken); + } + return task; } public Stream StartAcceptingWrites() @@ -112,24 +119,36 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public void Abort() { // We don't want to throw an ODE until the app func actually completes. - // If the request is aborted, we throw an IOException instead. if (_state != FrameStreamState.Closed) { _state = FrameStreamState.Aborted; } } - private void ValidateState() + private Task ValidateState(CancellationToken cancellationToken) { switch (_state) { case FrameStreamState.Open: - return; + if (cancellationToken.IsCancellationRequested) + { + return TaskUtilities.GetCancelledTask(cancellationToken); + } + break; case FrameStreamState.Closed: throw new ObjectDisposedException(nameof(FrameResponseStream)); case FrameStreamState.Aborted: - throw new IOException("The request has been aborted."); + if (cancellationToken.CanBeCanceled) + { + // Aborted state only throws on write if cancellationToken requests it + return TaskUtilities.GetCancelledTask( + cancellationToken.IsCancellationRequested ? + cancellationToken : + new CancellationToken(true)); + } + break; } + return null; } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs index ce9f4aa9cc..daeeeaa9e1 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs @@ -2,11 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.IO; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Server.Kestrel.Http { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs index 94e335b345..83213cbead 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs @@ -5,6 +5,7 @@ using System; using System.IO; using System.Runtime.CompilerServices; using System.Threading; +using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Http @@ -184,7 +185,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public void AbortAwaiting() { - _awaitableError = new ObjectDisposedException(nameof(SocketInput), "The request was aborted"); + _awaitableError = new TaskCanceledException("The request was aborted"); Complete(); } @@ -238,6 +239,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http var error = _awaitableError; if (error != null) { + if (error is TaskCanceledException || error is InvalidOperationException) + { + throw error; + } throw new IOException(error.Message, error); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs index 15a56f9fc8..38533f447b 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketOutput.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; @@ -22,6 +21,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private const int _maxPooledWriteContexts = 32; private static readonly WaitCallback _returnBlocks = (state) => ReturnBlocks((MemoryPoolBlock2)state); + private static readonly Action _connectionCancellation = (state) => ((SocketOutput)state).CancellationTriggered(); private readonly KestrelThread _thread; private readonly UvStreamHandle _socket; @@ -78,6 +78,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public Task WriteAsync( ArraySegment buffer, + CancellationToken cancellationToken, bool immediate = true, bool chunk = false, bool socketShutdownSend = false, @@ -89,9 +90,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http lock (_contextLock) { + if (_lastWriteError != null || _socket.IsClosed) + { + _log.ConnectionDisconnectedWrite(_connectionId, buffer.Count, _lastWriteError); + + return TaskUtilities.CompletedTask; + } + if (buffer.Count > 0) { var tail = ProducingStart(); + if (tail.IsDefault) + { + return TaskUtilities.CompletedTask; + } + if (chunk) { _numBytesPreCompleted += ChunkWriter.WriteBeginChunkBytes(ref tail, buffer.Count); @@ -146,13 +159,36 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } else { - // immediate write, which is not eligable for instant completion above - tcs = new TaskCompletionSource(buffer.Count); - _tasksPending.Enqueue(new WaitingTask() { - CompletionSource = tcs, - BytesToWrite = buffer.Count, - IsSync = isSync - }); + if (cancellationToken.CanBeCanceled) + { + if (cancellationToken.IsCancellationRequested) + { + _connection.Abort(); + + return TaskUtilities.GetCancelledTask(cancellationToken); + } + else + { + // immediate write, which is not eligable for instant completion above + tcs = new TaskCompletionSource(); + _tasksPending.Enqueue(new WaitingTask() + { + CancellationToken = cancellationToken, + CancellationRegistration = cancellationToken.Register(_connectionCancellation, this), + BytesToWrite = buffer.Count, + CompletionSource = tcs + }); + } + } + else + { + tcs = new TaskCompletionSource(); + _tasksPending.Enqueue(new WaitingTask() { + IsSync = isSync, + BytesToWrite = buffer.Count, + CompletionSource = tcs + }); + } } if (!_writePending && immediate) @@ -177,12 +213,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { case ProduceEndType.SocketShutdownSend: WriteAsync(default(ArraySegment), + default(CancellationToken), immediate: true, socketShutdownSend: true, socketDisconnect: false); break; case ProduceEndType.SocketDisconnect: WriteAsync(default(ArraySegment), + default(CancellationToken), immediate: true, socketShutdownSend: false, socketDisconnect: true); @@ -198,7 +236,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http if (_tail == null) { - throw new IOException("The socket has been closed."); + return default(MemoryPoolIterator2); } _lastStart = new MemoryPoolIterator2(_tail, _tail.End); @@ -251,6 +289,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } } + private void CancellationTriggered() + { + lock (_contextLock) + { + // Abort the connection for any failed write + // Queued on threadpool so get it in as first op. + _connection?.Abort(); + + CompleteAllWrites(); + } + } + private static void ReturnBlocks(MemoryPoolBlock2 block) { while (block != null) @@ -305,10 +355,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http if (error != null) { - _lastWriteError = new IOException(error.Message, error); - - // Abort the connection for any failed write. + // Abort the connection for any failed write + // Queued on threadpool so get it in as first op. _connection.Abort(); + _lastWriteError = error; } PoolWriteContext(writeContext); @@ -317,43 +367,78 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http // completed writes that we haven't triggered callbacks for yet. _numBytesPreCompleted -= bytesWritten; + CompleteFinishedWrites(status); + + if (error != null) + { + _log.ConnectionError(_connectionId, error); + } + else + { + _log.ConnectionWriteCallback(_connectionId, status); + } + } + + private void CompleteNextWrite(ref int bytesLeftToBuffer) + { + var waitingTask = _tasksPending.Dequeue(); + var bytesToWrite = waitingTask.BytesToWrite; + + _numBytesPreCompleted += bytesToWrite; + bytesLeftToBuffer -= bytesToWrite; + + // Dispose registration if there is one + waitingTask.CancellationRegistration?.Dispose(); + + if (waitingTask.CancellationToken.IsCancellationRequested) + { + if (waitingTask.IsSync) + { + waitingTask.CompletionSource.TrySetCanceled(); + } + else + { + _threadPool.Cancel(waitingTask.CompletionSource); + } + } + else + { + if (waitingTask.IsSync) + { + waitingTask.CompletionSource.TrySetResult(null); + } + else + { + _threadPool.Complete(waitingTask.CompletionSource); + } + } + } + + private void CompleteFinishedWrites(int status) + { // bytesLeftToBuffer can be greater than _maxBytesPreCompleted // This allows large writes to complete once they've actually finished. var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; while (_tasksPending.Count > 0 && (_tasksPending.Peek().BytesToWrite) <= bytesLeftToBuffer) { - var waitingTask = _tasksPending.Dequeue(); - var bytesToWrite = waitingTask.BytesToWrite; + CompleteNextWrite(ref bytesLeftToBuffer); + } + } - _numBytesPreCompleted += bytesToWrite; - bytesLeftToBuffer -= bytesToWrite; - - if (_lastWriteError == null) - { - if (waitingTask.IsSync) - { - waitingTask.CompletionSource.TrySetResult(null); - } - else - { - _threadPool.Complete(waitingTask.CompletionSource); - } - } - else - { - if (waitingTask.IsSync) - { - waitingTask.CompletionSource.TrySetException(_lastWriteError); - } - else - { - _threadPool.Error(waitingTask.CompletionSource, _lastWriteError); - } - } + private void CompleteAllWrites() + { + var writesToComplete = _tasksPending.Count > 0; + var bytesLeftToBuffer = _maxBytesPreCompleted - _numBytesPreCompleted; + while (_tasksPending.Count > 0) + { + CompleteNextWrite(ref bytesLeftToBuffer); } - _log.ConnectionWriteCallback(_connectionId, status); + if (writesToComplete) + { + _log.ConnectionError(_connectionId, new TaskCanceledException("Connetcion")); + } } // This is called on the libuv event loop @@ -393,12 +478,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http void ISocketOutput.Write(ArraySegment buffer, bool immediate, bool chunk) { - WriteAsync(buffer, immediate, chunk, isSync: true).GetAwaiter().GetResult(); + WriteAsync(buffer, CancellationToken.None, immediate, chunk, isSync: true).GetAwaiter().GetResult(); } Task ISocketOutput.WriteAsync(ArraySegment buffer, bool immediate, bool chunk, CancellationToken cancellationToken) { - return WriteAsync(buffer, immediate, chunk); + if (cancellationToken.IsCancellationRequested) + { + _connection?.Abort(); + return TaskUtilities.GetCancelledTask(cancellationToken); + } + + return WriteAsync(buffer, cancellationToken, immediate, chunk); } private static void BytesBetween(MemoryPoolIterator2 start, MemoryPoolIterator2 end, out int bytes, out int buffers) @@ -649,6 +740,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { public bool IsSync; public int BytesToWrite; + public CancellationToken CancellationToken; + public IDisposable CancellationRegistration; public TaskCompletionSource CompletionSource; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs index 1cab425247..0ed3c6565a 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs @@ -29,6 +29,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure void ConnectionWriteCallback(long connectionId, int status); + void ConnectionError(long connectionId, Exception ex); + + void ConnectionDisconnectedWrite(long connectionId, int count, Exception ex); + void ApplicationError(Exception ex); } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs index 404bc01a55..f9217bd992 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IThreadPool.cs @@ -9,6 +9,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure public interface IThreadPool { void Complete(TaskCompletionSource tcs); + void Cancel(TaskCompletionSource tcs); void Error(TaskCompletionSource tcs, Exception ex); void Run(Action action); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs index f9b82ce25b..5e8a2ee28f 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs @@ -21,6 +21,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel private static readonly Action _connectionWroteFin; private static readonly Action _connectionKeepAlive; private static readonly Action _connectionDisconnect; + private static readonly Action _connectionError; + private static readonly Action _connectionDisconnectedWrite; protected readonly ILogger _logger; @@ -39,6 +41,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel // ConnectionWrite: Reserved: 11 // ConnectionWriteCallback: Reserved: 12 // ApplicationError: Reserved: 13 - LoggerMessage.Define overload not present + _connectionError = LoggerMessage.Define(LogLevel.Information, 14, @"Connection id ""{ConnectionId}"" communication error"); + _connectionDisconnectedWrite = LoggerMessage.Define(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client."); } public KestrelTrace(ILogger logger) @@ -114,6 +118,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel _logger.LogError(13, "An unhandled exception was thrown by the application.", ex); } + public virtual void ConnectionError(long connectionId, Exception ex) + { + _connectionError(_logger, connectionId, ex); + } + + public virtual void ConnectionDisconnectedWrite(long connectionId, int count, Exception ex) + { + _connectionDisconnectedWrite(_logger, connectionId, count, ex); + } + public virtual void Log(LogLevel logLevel, int eventId, object state, Exception exception, Func formatter) { _logger.Log(logLevel, eventId, state, exception, formatter); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs index 70f142f536..a5f41987d4 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/LoggingThreadPool.cs @@ -12,6 +12,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure private readonly IKestrelTrace _log; private readonly WaitCallback _runAction; + private readonly WaitCallback _cancelTcs; private readonly WaitCallback _completeTcs; public LoggingThreadPool(IKestrelTrace log) @@ -42,6 +43,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure _log.ApplicationError(e); } }; + + _cancelTcs = (o) => + { + try + { + ((TaskCompletionSource)o).TrySetCanceled(); + } + catch (Exception e) + { + _log.ApplicationError(e); + } + }; } public void Run(Action action) @@ -54,6 +67,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure ThreadPool.QueueUserWorkItem(_completeTcs, tcs); } + public void Cancel(TaskCompletionSource tcs) + { + ThreadPool.QueueUserWorkItem(_cancelTcs, tcs); + } + public void Error(TaskCompletionSource tcs, Exception ex) { // ex ang _log are closure captured diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs index 62b19ced94..8e22ee4018 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs @@ -724,6 +724,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure public void CopyFrom(byte[] data, int offset, int count) { + if (IsDefault) + { + return; + } + Debug.Assert(_block != null); Debug.Assert(_block.Next == null); Debug.Assert(_block.End == _index); @@ -766,6 +771,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure public unsafe void CopyFromAscii(string data) { + if (IsDefault) + { + return; + } + Debug.Assert(_block != null); Debug.Assert(_block.Next == null); Debug.Assert(_block.End == _index); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs index a59713eeaa..5e52222d3d 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure @@ -13,5 +14,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure public static Task CompletedTask = Task.FromResult(null); #endif public static Task ZeroTask = Task.FromResult(0); + + public static Task GetCancelledTask(CancellationToken cancellationToken) + { +#if DOTNET5_4 + return Task.FromCanceled(cancellationToken); +#else + var tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; +#endif + } + + public static Task GetCancelledZeroTask() + { + // Task.FromCanceled doesn't return Task + var tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; + } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index a44f096129..85858ddcf5 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -1084,7 +1084,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - await Assert.ThrowsAsync(async () => await readTcs.Task); + await Assert.ThrowsAsync(async () => await readTcs.Task); // The cancellation token for only the last request should be triggered. var abortedRequestId = await registrationTcs.Task; @@ -1096,6 +1096,12 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [FrameworkSkipCondition(RuntimeFrameworks.Mono, SkipReason = "Test hangs after execution on Mono.")] public async Task FailedWritesResultInAbortedRequest(ServiceContext testContext) { + const int resetEventTimeout = 2000; + // This should match _maxBytesPreCompleted in SocketOutput + const int maxBytesPreCompleted = 65536; + // Ensure string is long enough to disable write-behind buffering + var largeString = new string('a', maxBytesPreCompleted + 1); + var writeTcs = new TaskCompletionSource(); var registrationWh = new ManualResetEventSlim(); var connectionCloseWh = new ManualResetEventSlim(); @@ -1119,7 +1125,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests // Ensure write is long enough to disable write-behind buffering for (int i = 0; i < 10; i++) { - await response.WriteAsync(new string('a', 65537)); + await response.WriteAsync(largeString).ConfigureAwait(false); } } catch (Exception ex) @@ -1127,7 +1133,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests writeTcs.SetException(ex); // Give a chance for RequestAborted to trip before the app completes - registrationWh.Wait(1000); + registrationWh.Wait(resetEventTimeout); throw; } @@ -1141,16 +1147,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "POST / HTTP/1.1", "Content-Length: 5", "", - "Hello"); + "Hello").ConfigureAwait(false); // Don't wait to receive the response. Just close the socket. } connectionCloseWh.Set(); // Write failed - await Assert.ThrowsAsync(async () => await writeTcs.Task); + await Assert.ThrowsAsync(async () => await writeTcs.Task); // RequestAborted tripped - Assert.True(registrationWh.Wait(200)); + Assert.True(registrationWh.Wait(resetEventTimeout)); } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs index f9d1f7ac22..47d9eb2a87 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs @@ -102,7 +102,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests private static void TestConcurrentFaultedTask(Task t) { Assert.True(t.IsFaulted); - Assert.IsType(typeof(System.IO.IOException), t.Exception.InnerException); + Assert.IsType(typeof(System.InvalidOperationException), t.Exception.InnerException); Assert.Equal(t.Exception.InnerException.Message, "Concurrent reads are not supported."); } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs index 01c42c5f18..ec0fe0c9a0 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketOutputTests.cs @@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var completedWh = new ManualResetEventSlim(); // Act - socketOutput.WriteAsync(buffer).ContinueWith( + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith( (t) => { Assert.Null(t.Exception); @@ -101,14 +101,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests }; // Act - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // The first write should pre-complete since it is <= _maxBytesPreCompleted. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // Too many bytes are already pre-completed for the second write to pre-complete. Assert.False(completedWh.Wait(1000)); @@ -162,28 +162,28 @@ namespace Microsoft.AspNetCore.Server.KestrelTests }; // Act - socketOutput.WriteAsync(halfBuffer, false).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken), false).ContinueWith(onCompleted); // Assert // The first write should pre-complete since it is not immediate. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(halfBuffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // The second write should pre-complete since it is <= _maxBytesPreCompleted. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(halfBuffer, false).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken), false).ContinueWith(onCompleted); // Assert // The third write should pre-complete since it is not immediate, even though too many. Assert.True(completedWh.Wait(1000)); // Arrange completedWh.Reset(); // Act - socketOutput.WriteAsync(halfBuffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(halfBuffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // Too many bytes are already pre-completed for the fourth write to pre-complete. Assert.False(completedWh.Wait(1000)); @@ -198,6 +198,116 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } + [Fact] + public async Task OnlyWritesRequestingCancellationAreErroredOnCancellation() + { + // This should match _maxBytesPreCompleted in SocketOutput + var maxBytesPreCompleted = 65536; + var completeQueue = new Queue>(); + + // Arrange + var mockLibuv = new MockLibuv + { + OnWrite = (socket, buffers, triggerCompleted) => + { + completeQueue.Enqueue(triggerCompleted); + return 0; + } + }; + + using (var kestrelEngine = new KestrelEngine(mockLibuv, new TestServiceContext())) + using (var memory = new MemoryPool2()) + { + kestrelEngine.Start(count: 1); + + var kestrelThread = kestrelEngine.Threads[0]; + var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); + var trace = new KestrelTrace(new TestKestrelTrace()); + var ltp = new LoggingThreadPool(trace); + ISocketOutput socketOutput = new SocketOutput(kestrelThread, socket, memory, null, 0, trace, ltp, new Queue()); + + var bufferSize = maxBytesPreCompleted; + + var data = new byte[bufferSize]; + var fullBuffer = new ArraySegment(data, 0, bufferSize); + + var cts = new CancellationTokenSource(); + + // Act + var task1Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + // task1 should complete sucessfully as < _maxBytesPreCompleted + + // First task is completed and sucessful + Assert.True(task1Success.IsCompleted); + Assert.False(task1Success.IsCanceled); + Assert.False(task1Success.IsFaulted); + + task1Success.GetAwaiter().GetResult(); + + // following tasks should wait. + + var task2Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + var task3Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: default(CancellationToken)); + + // Give time for tasks to perculate + await Task.Delay(2000).ConfigureAwait(false); + + // Second task is not completed + Assert.False(task2Throw.IsCompleted); + Assert.False(task2Throw.IsCanceled); + Assert.False(task2Throw.IsFaulted); + + // Third task is not completed + Assert.False(task3Success.IsCompleted); + Assert.False(task3Success.IsCanceled); + Assert.False(task3Success.IsFaulted); + + cts.Cancel(); + + // Give time for tasks to perculate + await Task.Delay(2000).ConfigureAwait(false); + + // Second task is now cancelled + Assert.True(task2Throw.IsCompleted); + Assert.True(task2Throw.IsCanceled); + Assert.False(task2Throw.IsFaulted); + + // Third task is now completed + Assert.True(task3Success.IsCompleted); + Assert.False(task3Success.IsCanceled); + Assert.False(task3Success.IsFaulted); + + // Fourth task immediately cancels as the token is cancelled + var task4Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + + Assert.True(task4Throw.IsCompleted); + Assert.True(task4Throw.IsCanceled); + Assert.False(task4Throw.IsFaulted); + + Assert.Throws(() => task4Throw.GetAwaiter().GetResult()); + + var task5Success = socketOutput.WriteAsync(fullBuffer, cancellationToken: default(CancellationToken)); + // task5 should complete immedately + + Assert.True(task5Success.IsCompleted); + Assert.False(task5Success.IsCanceled); + Assert.False(task5Success.IsFaulted); + + cts = new CancellationTokenSource(); + + var task6Throw = socketOutput.WriteAsync(fullBuffer, cancellationToken: cts.Token); + // task6 should complete immedately but not cancel as its cancelation token isn't set + + Assert.True(task6Throw.IsCompleted); + Assert.False(task6Throw.IsCanceled); + Assert.False(task6Throw.IsFaulted); + + Assert.Throws(() => task6Throw.GetAwaiter().GetResult()); + + Assert.True(true); + } + } + [Fact] public void WritesDontGetCompletedTooQuickly() { @@ -247,7 +357,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests }; // Act (Pre-complete the maximum number of bytes in preparation for the rest of the test) - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); // Assert // The first write should pre-complete since it is <= _maxBytesPreCompleted. Assert.True(completedWh.Wait(1000)); @@ -257,8 +367,8 @@ namespace Microsoft.AspNetCore.Server.KestrelTests onWriteWh.Reset(); // Act - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted); - socketOutput.WriteAsync(buffer).ContinueWith(onCompleted2); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted); + socketOutput.WriteAsync(buffer, default(CancellationToken)).ContinueWith(onCompleted2); Assert.True(onWriteWh.Wait(1000)); completeQueue.Dequeue()(0); @@ -320,7 +430,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests socketOutput.ProducingComplete(end); // A call to Write is required to ensure a write is scheduled - socketOutput.WriteAsync(default(ArraySegment)); + socketOutput.WriteAsync(default(ArraySegment), default(CancellationToken)); Assert.True(nBufferWh.Wait(1000)); Assert.Equal(2, nBuffers);