diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 2b1b30ca92..242db513f6 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -42,38 +42,58 @@ namespace Microsoft.AspNetCore.Sockets } else { + // REVIEW: Errors? + // Get the end point mapped to this http connection var endpoint = (EndPoint)context.RequestServices.GetRequiredService(); // Server sent events transport if (context.Request.Path.StartsWithSegments(path + "/sse")) { + // Get the connection state for the current http context var connectionState = GetOrCreateConnection(context); - var sse = new ServerSentEvents((HttpChannel)connectionState.Connection.Channel); + var channel = (HttpChannel)connectionState.Connection.Channel; + var sse = new ServerSentEvents(channel); + // Register this transport for disconnect RegisterDisconnect(context, sse); - var ignore = endpoint.OnConnected(connectionState.Connection); + // Call into the end point passing the connection + var endpointTask = endpoint.OnConnected(connectionState.Connection); + // Start the transport await sse.ProcessRequest(context); + // Transport has ended so kill the channel connectionState.Connection.Channel.Dispose(); + // Wait for the user code to unwind + await endpointTask; + _manager.RemoveConnection(connectionState.Connection.ConnectionId); } else if (context.Request.Path.StartsWithSegments(path + "/ws")) { + // Get the connection state for the current http context var connectionState = GetOrCreateConnection(context); - var ws = new WebSockets((HttpChannel)connectionState.Connection.Channel); + var channel = (HttpChannel)connectionState.Connection.Channel; + var ws = new WebSockets(channel); + // Register this transport for disconnect RegisterDisconnect(context, ws); - var ignore = endpoint.OnConnected(connectionState.Connection); + // Call into the end point passing the connection + var endpointTask = endpoint.OnConnected(connectionState.Connection); + // Start the transport await ws.ProcessRequest(context); + // Transport has ended so kill the channel connectionState.Connection.Channel.Dispose(); + // Wait for the user code to unwind + await endpointTask; + _manager.RemoveConnection(connectionState.Connection.ConnectionId); } else if (context.Request.Path.StartsWithSegments(path + "/poll")) @@ -87,12 +107,10 @@ namespace Microsoft.AspNetCore.Sockets // Treat reserved connections like new ones if (connectionState.Connection.Channel == null) { - var channel = new HttpChannel(_channelFactory); - // REVIEW: The connection manager should encapsulate this... connectionState.Active = true; connectionState.LastSeen = DateTimeOffset.UtcNow; - connectionState.Connection.Channel = channel; + connectionState.Connection.Channel = new HttpChannel(_channelFactory); isNewConnection = true; } } @@ -106,13 +124,17 @@ namespace Microsoft.AspNetCore.Sockets // Raise OnConnected for new connections only since polls happen all the time if (isNewConnection) { + // REVIEW: We should await this task after disposing the connection var ignore = endpoint.OnConnected(connectionState.Connection); } - var longPolling = new LongPolling((HttpChannel)connectionState.Connection.Channel); + var channel = (HttpChannel)connectionState.Connection.Channel; + var longPolling = new LongPolling(channel); + // Register this transport for disconnect RegisterDisconnect(context, longPolling); + // Start the transport await longPolling.ProcessRequest(context); _manager.MarkConnectionInactive(connectionState.Connection.ConnectionId);