diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index addee46904..a8f24bc808 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -10,6 +10,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal { public class ConnectionState { + // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously + // on the same task + private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); + public Connection Connection { get; set; } public IChannelConnection Application { get; } @@ -34,8 +38,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal public async Task DisposeAsync() { - Task applicationTask = TaskCache.CompletedTask; - Task transportTask = TaskCache.CompletedTask; + Task disposeTask = TaskCache.CompletedTask; try { @@ -43,30 +46,34 @@ namespace Microsoft.AspNetCore.Sockets.Internal if (Status == ConnectionStatus.Disposed) { - return; + disposeTask = _disposeTcs.Task; } - - Status = ConnectionStatus.Disposed; - - RequestId = null; - - // If the application task is faulted, propagate the error to the transport - if (ApplicationTask?.IsFaulted == true) + else { - Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + Status = ConnectionStatus.Disposed; + + RequestId = null; + + // If the application task is faulted, propagate the error to the transport + if (ApplicationTask?.IsFaulted == true) + { + Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + } + + // If the transport task is faulted, propagate the error to the application + if (TransportTask?.IsFaulted == true) + { + Application.Output.TryComplete(TransportTask.Exception.InnerException); + } + + Connection.Dispose(); + Application.Dispose(); + + var applicationTask = ApplicationTask ?? TaskCache.CompletedTask; + var transportTask = TransportTask ?? TaskCache.CompletedTask; + + disposeTask = Task.WhenAll(applicationTask, transportTask); } - - // If the transport task is faulted, propagate the error to the application - if (TransportTask?.IsFaulted == true) - { - Application.Output.TryComplete(TransportTask.Exception.InnerException); - } - - Connection.Dispose(); - Application.Dispose(); - - applicationTask = ApplicationTask ?? applicationTask; - transportTask = TransportTask ?? transportTask; } finally { @@ -74,7 +81,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal } // REVIEW: Add a timeout so we don't wait forever - await Task.WhenAll(applicationTask, transportTask); + await disposeTask; + + // Notify all waiters that we're done disposing + _disposeTcs.TrySetResult(null); } public enum ConnectionStatus diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 76bb9d5932..46549a578c 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Xunit; @@ -100,6 +101,26 @@ namespace Microsoft.AspNetCore.Sockets.Tests await state.DisposeAsync(); } + [Fact] + public async Task DisposingConnectionMultipleTimesWaitsOnConnectionClose() + { + var connectionManager = CreateConnectionManager(); + var state = connectionManager.CreateConnection(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + state.ApplicationTask = tcs.Task; + state.TransportTask = tcs.Task; + + var firstTask = state.DisposeAsync(); + var secondTask = state.DisposeAsync(); + Assert.False(firstTask.IsCompleted); + Assert.False(secondTask.IsCompleted); + + tcs.TrySetResult(null); + + await Task.WhenAll(firstTask, secondTask).OrTimeout(); + } + [Fact] public async Task DisposeInactiveConnection() { diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 7bb7d4ebbe..7b80480a84 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Internal; @@ -14,6 +15,7 @@ using Microsoft.AspNetCore.WebSockets.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; +using Microsoft.Extensions.WebSockets.Internal; using Xunit; namespace Microsoft.AspNetCore.Sockets.Tests @@ -234,9 +236,20 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status409Conflict, context2.Response.StatusCode); + var webSocketTask = Task.CompletedTask; + + if (isWebSocketRequest) + { + var ws = (TestWebSocketConnectionFeature)context1.Features.Get(); + webSocketTask = ws.Client.ExecuteAsync(frame => Task.CompletedTask); + await ws.Client.CloseAsync(new WebSocketCloseResult(WebSocketCloseStatus.NormalClosure), CancellationToken.None); + } + manager.CloseConnections(); - await request1; + await webSocketTask.OrTimeout(); + + await request1.OrTimeout(); } [Fact]