From bb92cc1c291cfe6e612b2ed1fb1f940745b60125 Mon Sep 17 00:00:00 2001 From: Pavel Krymets Date: Tue, 31 May 2016 09:24:46 -0700 Subject: [PATCH] Fix NRE when aborting connection or client disconects --- .../Http/SocketInput.cs | 96 +++++++++++-------- .../Infrastructure/TaskUtilities.cs | 5 + .../FrameRequestStreamTests.cs | 10 ++ 3 files changed, 73 insertions(+), 38 deletions(-) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs index 1b49246b92..87f82474c8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs @@ -27,9 +27,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private MemoryPoolBlock _tail; private MemoryPoolBlock _pinned; - private int _consumingState; private object _sync = new object(); + private bool _consuming; + private bool _disposed; + public SocketInput(MemoryPool memory, IThreadPool threadPool) { _memory = memory; @@ -163,12 +165,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public MemoryPoolIterator ConsumingStart() { - if (Interlocked.CompareExchange(ref _consumingState, 1, 0) != 0) + lock (_sync) { - throw new InvalidOperationException("Already consuming input."); + if (_consuming) + { + throw new InvalidOperationException("Already consuming input."); + } + _consuming = true; + return new MemoryPoolIterator(_head); } - - return new MemoryPoolIterator(_head); } public void ConsumingComplete( @@ -180,38 +185,44 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http lock (_sync) { - if (!consumed.IsDefault) + if (!_disposed) + { + if (!consumed.IsDefault) + { + returnStart = _head; + returnEnd = consumed.Block; + _head = consumed.Block; + _head.Start = consumed.Index; + } + + if (!examined.IsDefault && + examined.IsEnd && + RemoteIntakeFin == false && + _awaitableError == null) + { + _manualResetEvent.Reset(); + + Interlocked.CompareExchange( + ref _awaitableState, + _awaitableIsNotCompleted, + _awaitableIsCompleted); + } + } + else { returnStart = _head; - returnEnd = consumed.Block; - _head = consumed.Block; - _head.Start = consumed.Index; + returnEnd = null; + _head = null; + _tail = null; } - if (!examined.IsDefault && - examined.IsEnd && - RemoteIntakeFin == false && - _awaitableError == null) + ReturnBlocks(returnStart, returnEnd); + + if (!_consuming) { - _manualResetEvent.Reset(); - - Interlocked.CompareExchange( - ref _awaitableState, - _awaitableIsNotCompleted, - _awaitableIsCompleted); + throw new InvalidOperationException("No ongoing consuming operation to complete."); } - } - - while (returnStart != returnEnd) - { - var returnBlock = returnStart; - returnStart = returnStart.Next; - returnBlock.Pool.Return(returnBlock); - } - - if (Interlocked.CompareExchange(ref _consumingState, 0, 1) != 1) - { - throw new InvalidOperationException("No ongoing consuming operation to complete."); + _consuming = false; } } @@ -286,20 +297,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public void Dispose() { - AbortAwaiting(); + lock (_sync) + { + AbortAwaiting(); - // Return all blocks - var block = _head; - while (block != null) + if (!_consuming) + { + ReturnBlocks(_head, null); + _head = null; + _tail = null; + } + _disposed = true; + } + } + + private static void ReturnBlocks(MemoryPoolBlock block, MemoryPoolBlock end) + { + while (block != end) { var returnBlock = block; block = block.Next; returnBlock.Pool.Return(returnBlock); } - - _head = null; - _tail = null; } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs index 97f64e3e21..328c8ecb05 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/TaskUtilities.cs @@ -29,6 +29,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure public static Task GetCancelledZeroTask(CancellationToken cancellationToken = default(CancellationToken)) { #if NETSTANDARD1_3 + // Make sure cancellationToken is cancelled before passing to Task.FromCanceled + if (!cancellationToken.IsCancellationRequested) + { + cancellationToken = new CancellationToken(true); + } return Task.FromCanceled(cancellationToken); #else var tcs = new TaskCompletionSource(); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs index 10bdbf452c..b201db9b99 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs @@ -104,5 +104,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var stream = new FrameRequestStream(); await stream.FlushAsync(); } + + [Fact] + public void AbortCausesReadToCancel() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + stream.Abort(); + var task = stream.ReadAsync(new byte[1], 0, 1); + Assert.True(task.IsCanceled); + } } }