diff --git a/src/Microsoft.AspNetCore.Sockets/Connection.cs b/src/Microsoft.AspNetCore.Sockets/Connection.cs index f1c0c759e1..10746ae64d 100644 --- a/src/Microsoft.AspNetCore.Sockets/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets/Connection.cs @@ -12,6 +12,6 @@ namespace Microsoft.AspNetCore.Sockets public string ConnectionId { get; set; } public ClaimsPrincipal User { get; set; } public IChannel Channel { get; set; } - public IDictionary Metadata { get; } = new Dictionary(); + public IDictionary Metadata { get; } = new Dictionary(); } } diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 0843a9bc62..08febfc486 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -63,13 +63,16 @@ namespace Microsoft.AspNetCore.Sockets var endpointTask = endpoint.OnConnected(connectionState.Connection); // Start the transport - await sse.ProcessRequest(context); + var transportTask = sse.ProcessRequest(context); + + // Wait for any of them to end + await Task.WhenAny(endpointTask, transportTask); // Transport has ended so kill the channel connectionState.Connection.Channel.Dispose(); - // Wait for the user code to unwind - await endpointTask; + // Wait on both to end + await Task.WhenAll(endpointTask, transportTask); _manager.RemoveConnection(connectionState.Connection.ConnectionId); } @@ -88,13 +91,16 @@ namespace Microsoft.AspNetCore.Sockets var endpointTask = endpoint.OnConnected(connectionState.Connection); // Start the transport - await ws.ProcessRequest(context); + var transportTask = ws.ProcessRequest(context); - // Transport has ended so kill the channel + // Wait for any of them to end + await Task.WhenAny(endpointTask, transportTask); + + // Kill the channel connectionState.Connection.Channel.Dispose(); - // Wait for the user code to unwind - await endpointTask; + // Wait for both + await Task.WhenAll(endpointTask, transportTask); _manager.RemoveConnection(connectionState.Connection.ConnectionId); } @@ -124,13 +130,20 @@ namespace Microsoft.AspNetCore.Sockets // Mark the connection as active connectionState.Active = true; + Task endpointTask = null; + // Raise OnConnected for new connections only since polls happen all the time if (isNewConnection) { connectionState.Connection.Metadata["transport"] = "poll"; connectionState.Connection.User = context.User; - // REVIEW: We should await this task after disposing the connection - var ignore = endpoint.OnConnected(connectionState.Connection); + endpointTask = endpoint.OnConnected(connectionState.Connection); + connectionState.Connection.Metadata["endpoint"] = endpointTask; + } + else + { + // Get the endpoint task from connection state + endpointTask = (Task)connectionState.Connection.Metadata["endpoint"]; } RegisterLongPollingDisconnect(context, connectionState.Connection); @@ -138,7 +151,17 @@ namespace Microsoft.AspNetCore.Sockets var longPolling = new LongPolling(connectionState.Connection); // Start the transport - await longPolling.ProcessRequest(context); + var transportTask = longPolling.ProcessRequest(context); + + var resultTask = await Task.WhenAny(endpointTask, transportTask); + + if (resultTask == endpointTask) + { + // Notify the long polling transport to end + connectionState.Connection.Channel.Dispose(); + + await transportTask; + } // Mark the connection as inactive connectionState.LastSeen = DateTimeOffset.UtcNow; diff --git a/src/Microsoft.AspNetCore.Sockets/LongPolling.cs b/src/Microsoft.AspNetCore.Sockets/LongPolling.cs index a61f11a90b..618ccb1b63 100644 --- a/src/Microsoft.AspNetCore.Sockets/LongPolling.cs +++ b/src/Microsoft.AspNetCore.Sockets/LongPolling.cs @@ -22,7 +22,8 @@ namespace Microsoft.AspNetCore.Sockets if (buffer.IsEmpty && _channel.Output.Reading.IsCompleted) { - // REVIEW: Set the status code here so the client doesn't reconnect + // Client should stop if it receives a 204 + context.Response.StatusCode = 204; return; }