From 8da2dddd4951d0d621bbf8cde08cb8baa2bf650c Mon Sep 17 00:00:00 2001 From: David Fowler Date: Mon, 3 Apr 2017 15:25:45 -0700 Subject: [PATCH] Fix issue where multiple calls to dispose don't wait properly (#360) * Fix issue where multiple calls to dispose don't wait properly - DisposeAsync returned immediately to anyone but the first caller. This means that it was possible to end the request before properly waiting on the transport task which means writing after dispose was possible. - Added a test --- .../Internal/ConnectionState.cs | 58 +++++++++++-------- .../ConnectionManagerTests.cs | 21 +++++++ .../HttpConnectionDispatcherTests.cs | 15 ++++- 3 files changed, 69 insertions(+), 25 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index addee46904..a8f24bc808 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -10,6 +10,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal { public class ConnectionState { + // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously + // on the same task + private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); + public Connection Connection { get; set; } public IChannelConnection Application { get; } @@ -34,8 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal public async Task DisposeAsync() { - Task applicationTask = TaskCache.CompletedTask; - Task transportTask = TaskCache.CompletedTask; + Task disposeTask = TaskCache.CompletedTask; try { @@ -43,30 +46,34 @@ namespace Microsoft.AspNetCore.Sockets.Internal if (Status == ConnectionStatus.Disposed) { - return; + disposeTask = _disposeTcs.Task; } - - Status = ConnectionStatus.Disposed; - - RequestId = null; - - // If the application task is faulted, propagate the error to the transport - if (ApplicationTask?.IsFaulted == true) + else { - Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + Status = ConnectionStatus.Disposed; + + RequestId = null; + + // If the application task is faulted, propagate the error to the transport + if (ApplicationTask?.IsFaulted == true) + { + Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + } + + // If the transport task is faulted, propagate the error to the application + if (TransportTask?.IsFaulted == true) + { + Application.Output.TryComplete(TransportTask.Exception.InnerException); + } + + Connection.Dispose(); + Application.Dispose(); + + var applicationTask = ApplicationTask ?? TaskCache.CompletedTask; + var transportTask = TransportTask ?? TaskCache.CompletedTask; + + disposeTask = Task.WhenAll(applicationTask, transportTask); } - - // If the transport task is faulted, propagate the error to the application - if (TransportTask?.IsFaulted == true) - { - Application.Output.TryComplete(TransportTask.Exception.InnerException); - } - - Connection.Dispose(); - Application.Dispose(); - - applicationTask = ApplicationTask ?? applicationTask; - transportTask = TransportTask ?? transportTask; } finally { @@ -74,7 +81,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal } // REVIEW: Add a timeout so we don't wait forever - await Task.WhenAll(applicationTask, transportTask); + await disposeTask; + + // Notify all waiters that we're done disposing + _disposeTcs.TrySetResult(null); } public enum ConnectionStatus diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 76bb9d5932..46549a578c 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Xunit; @@ -100,6 +101,26 @@ namespace Microsoft.AspNetCore.Sockets.Tests await state.DisposeAsync(); } + [Fact] + public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() + { + var connectionManager = CreateConnectionManager(); + var state = connectionManager.CreateConnection(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + state.ApplicationTask = tcs.Task; + state.TransportTask = tcs.Task; + + var firstTask = state.DisposeAsync(); + var secondTask = state.DisposeAsync(); + Assert.False(firstTask.IsCompleted); + Assert.False(secondTask.IsCompleted); + + tcs.TrySetResult(null); + + await Task.WhenAll(firstTask, secondTask).OrTimeout(); + } + [Fact] public async Task DisposeInactiveConnection() { diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 7bb7d4ebbe..7b80480a84 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Internal; @@ -14,6 +15,7 @@ using Microsoft.AspNetCore.WebSockets.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; +using Microsoft.Extensions.WebSockets.Internal; using Xunit; namespace Microsoft.AspNetCore.Sockets.Tests @@ -234,9 +236,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); + var webSocketTask = Task.CompletedTask; + + if (isWebSocketRequest) + { + var ws = (TestWebSocketConnectionFeature)context1.Features.Get(); + webSocketTask = ws.Client.ExecuteAsync(frame => Task.CompletedTask); + await ws.Client.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure), CancellationToken.None); + } + manager.CloseConnections(); - await request1; + await webSocketTask.OrTimeout(); + + await request1.OrTimeout(); } [Fact]