From b63dd40efbb1a57ea1fd102d97e42650a24f6987 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 6 Jan 2016 17:08:25 -0800 Subject: [PATCH] Protect SocketInput against concurrent consumption --- .../Http/SocketInput.cs | 20 ++++++++++++++---- .../SocketInputTests.cs | 21 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs index efe431f943..0afad29f09 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs @@ -25,6 +25,8 @@ namespace Microsoft.AspNet.Server.Kestrel.Http private MemoryPoolBlock2 _tail; private MemoryPoolBlock2 _pinned; + private int _consumingState; + public SocketInput(MemoryPool2 memory, IThreadPool threadPool) { _memory = memory; @@ -81,10 +83,8 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public void IncomingComplete(int count, Exception error) { - // Unpin may called without an earlier Pin if (_pinned != null) { - _pinned.End += count; if (_head == null) @@ -133,6 +133,11 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public MemoryPoolIterator2 ConsumingStart() { + if (Interlocked.CompareExchange(ref _consumingState, 1, 0) != 0) + { + throw new InvalidOperationException("Already consuming input."); + } + return new MemoryPoolIterator2(_head); } @@ -142,6 +147,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { MemoryPoolBlock2 returnStart = null; MemoryPoolBlock2 returnEnd = null; + if (!consumed.IsDefault) { returnStart = _head; @@ -149,6 +155,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http _head = consumed.Block; _head.Start = consumed.Index; } + if (!examined.IsDefault && examined.IsEnd && RemoteIntakeFin == false && @@ -156,7 +163,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { _manualResetEvent.Reset(); - var awaitableState = Interlocked.CompareExchange( + Interlocked.CompareExchange( ref _awaitableState, _awaitableIsNotCompleted, _awaitableIsCompleted); @@ -168,6 +175,11 @@ namespace Microsoft.AspNet.Server.Kestrel.Http returnStart = returnStart.Next; returnBlock.Pool.Return(returnBlock); } + + if (Interlocked.CompareExchange(ref _consumingState, 0, 1) != 1) + { + throw new InvalidOperationException("No ongoing consuming operation to complete."); + } } public void AbortAwaiting() @@ -201,7 +213,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { _awaitableError = new InvalidOperationException("Concurrent reads are not supported."); - awaitableState = Interlocked.Exchange( + Interlocked.Exchange( ref _awaitableState, _awaitableIsCompleted); diff --git a/test/Microsoft.AspNet.Server.KestrelTests/SocketInputTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/SocketInputTests.cs index 5bb02f5c0d..5a38bf87ef 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/SocketInputTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/SocketInputTests.cs @@ -78,6 +78,27 @@ namespace Microsoft.AspNet.Server.KestrelTests } } + [Fact] + public void ConsumingOutOfOrderFailsGracefully() + { + var defultIter = new MemoryPoolIterator2(); + + // Calling ConsumingComplete without a preceding calling to ConsumingStart fails + var socketInput = new SocketInput(null, null); + Assert.Throws(() => socketInput.ConsumingComplete(defultIter, defultIter)); + + // Calling ConsumingStart twice in a row fails + socketInput = new SocketInput(null, null); + socketInput.ConsumingStart(); + Assert.Throws(() => socketInput.ConsumingStart()); + + // Calling ConsumingComplete twice in a row fails + socketInput = new SocketInput(null, null); + socketInput.ConsumingStart(); + socketInput.ConsumingComplete(defultIter, defultIter); + Assert.Throws(() => socketInput.ConsumingComplete(defultIter, defultIter)); + } + private static void TestConcurrentFaultedTask(Task t) { Assert.True(t.IsFaulted);