diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index e22a171cc2..888677f3a7 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -3,7 +3,9 @@ using System; using System.Collections.Concurrent; +using System.Collections.Generic; using System.Threading; +using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Sockets.Internal; @@ -71,11 +73,8 @@ namespace Microsoft.AspNetCore.Sockets ConnectionState s; if (_connections.TryRemove(c.Key, out s)) { - s?.Close(); - } - else - { - + // REVIEW: Should we keep firing and forgetting this? + var ignore = s.DisposeAsync(); } } } @@ -86,22 +85,18 @@ namespace Microsoft.AspNetCore.Sockets // Stop firing the timer _timer.Dispose(); + var tasks = new List(); + foreach (var c in _connections) { ConnectionState s; if (_connections.TryRemove(c.Key, out s)) { - // Longpolling connections should do this - if (s.Close != null) - { - s.Close(); - } - else - { - s.Dispose(); - } + tasks.Add(s.DisposeAsync()); } } + + Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5)); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index e01b2be6e4..5dc231eaf3 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -114,57 +114,31 @@ namespace Microsoft.AspNetCore.Sockets // Mark the connection as active state.Active = true; - var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory); - - // Start the transport - var transportTask = longPolling.ProcessRequestAsync(context); - // Raise OnConnected for new connections only since polls happen all the time - var endpointTask = state.Connection.Metadata.Get("endpoint"); - if (endpointTask == null) + if (state.ApplicationTask == null) { _logger.LogDebug("Establishing new Long Polling connection: {0}", state.Connection.ConnectionId); // This will re-initialize formatType metadata, but meh... state.Connection.Metadata["transport"] = LongPollingTransport.Name; - // REVIEW: This is super gross, this all needs to be cleaned up... - state.Close = async () => - { - // Close the end point's connection - state.Connection.Dispose(); - - try - { - await endpointTask; - } - catch - { - // possibly invoked on a ThreadPool thread - } - }; - - endpointTask = endpoint.OnConnectedAsync(state.Connection); - state.Connection.Metadata["endpoint"] = endpointTask; + state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); } else { _logger.LogDebug("Resuming existing Long Polling connection: {0}", state.Connection.ConnectionId); } - var resultTask = await Task.WhenAny(endpointTask, transportTask); + var longPolling = new LongPollingTransport(state.Application.Input, _loggerFactory); - if (resultTask == endpointTask) + // Start the transport + state.TransportTask = longPolling.ProcessRequestAsync(context); + + var resultTask = await Task.WhenAny(state.ApplicationTask, state.TransportTask); + + if (resultTask == state.ApplicationTask) { - // Notify the long polling transport to end - if (endpointTask.IsFaulted) - { - state.Connection.Transport.Output.TryComplete(endpointTask.Exception.InnerException); - } - - state.Connection.Dispose(); - - await transportTask; + await state.DisposeAsync(); } // Mark the connection as inactive @@ -194,20 +168,17 @@ namespace Microsoft.AspNetCore.Sockets HttpContext context, ConnectionState state) { - // Start the transport - var transportTask = transport.ProcessRequestAsync(context); - // Call into the end point passing the connection - var endpointTask = endpoint.OnConnectedAsync(state.Connection); + state.ApplicationTask = endpoint.OnConnectedAsync(state.Connection); + + // Start the transport + state.TransportTask = transport.ProcessRequestAsync(context); // Wait for any of them to end - await Task.WhenAny(endpointTask, transportTask); + await Task.WhenAny(state.ApplicationTask, state.TransportTask); // Kill the channel - state.Dispose(); - - // Wait for both - await Task.WhenAll(endpointTask, transportTask); + await state.DisposeAsync(); } private Task ProcessNegotiate(HttpContext context) diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index d360cb2deb..fa25a232ff 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -2,16 +2,18 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Sockets.Internal { - public class ConnectionState : IDisposable + public class ConnectionState { public Connection Connection { get; set; } public IChannelConnection Application { get; } - // These are used for long polling mostly - public Action Close { get; set; } + public Task TransportTask { get; set; } + public Task ApplicationTask { get; set; } + public DateTime LastSeenUtc { get; set; } public bool Active { get; set; } = true; @@ -22,10 +24,25 @@ namespace Microsoft.AspNetCore.Sockets.Internal LastSeenUtc = DateTime.UtcNow; } - public void Dispose() + public async Task DisposeAsync() { + // If the application task is faulted, propagate the error to the transport + if (ApplicationTask.IsFaulted) + { + Connection.Transport.Output.TryComplete(ApplicationTask.Exception.InnerException); + } + + // If the transport task is faulted, propagate the error to the application + if (TransportTask.IsFaulted) + { + Application.Output.TryComplete(TransportTask.Exception.InnerException); + } + Connection.Dispose(); Application.Dispose(); + + // REVIEW: Add a timeout so we don't wait forever + await Task.WhenAll(ApplicationTask, TransportTask); } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index 7ac86dcc19..8295793ac3 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -13,14 +13,14 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public void NewConnectionsHaveConnectionId() { - var connectionManager = new ConnectionManager(); var state = connectionManager.CreateConnection(); Assert.NotNull(state.Connection); Assert.NotNull(state.Connection.ConnectionId); Assert.True(state.Active); - Assert.Null(state.Close); + Assert.Null(state.ApplicationTask); + Assert.Null(state.TransportTask); Assert.NotNull(state.Connection.Transport); } @@ -83,17 +83,19 @@ namespace Microsoft.AspNetCore.Sockets.Tests var connectionManager = new ConnectionManager(); var state = connectionManager.CreateConnection(); - var task = Task.Run(async () => + state.ApplicationTask = Task.Run(async () => { - var connection = state.Connection; + Assert.False(await state.Connection.Transport.Input.WaitToReadAsync()); + }); - Assert.False(await connection.Transport.Input.WaitToReadAsync()); - Assert.True(connection.Transport.Input.Completion.IsCompleted); + state.TransportTask = Task.Run(async () => + { + Assert.False(await state.Application.Input.WaitToReadAsync()); }); connectionManager.CloseConnections(); - await task; + await state.DisposeAsync(); } } }