From 411f44f263d9134a0e729deb02d280c0c1c8d467 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Mon, 3 Oct 2016 23:52:51 -0700 Subject: [PATCH] Reduce code duplication --- .../HttpConnectionDispatcher.cs | 164 ++++++++---------- 1 file changed, 71 insertions(+), 93 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index e6ef50c0f3..f5aea8747c 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -1,5 +1,4 @@ using System; -using System.Collections.Generic; using System.Text; using System.Threading.Tasks; using Channels; @@ -42,8 +41,6 @@ namespace Microsoft.AspNetCore.Sockets } else { - // REVIEW: Errors? - // Get the end point mapped to this http connection var endpoint = (EndPoint)context.RequestServices.GetRequiredService(); var format = @@ -55,129 +52,72 @@ namespace Microsoft.AspNetCore.Sockets if (context.Request.Path.StartsWithSegments(path + "/sse")) { // Get the connection state for the current http context - var connectionState = GetOrCreateConnection(context); - connectionState.Connection.User = context.User; - connectionState.Connection.Metadata["transport"] = "sse"; - connectionState.Connection.Metadata.Format = format; - var sse = new ServerSentEvents(connectionState.Connection); + var state = GetOrCreateConnection(context); + state.Connection.User = context.User; + state.Connection.Metadata["transport"] = "sse"; + state.Connection.Metadata.Format = format; - // Register this transport for disconnect - RegisterDisconnect(context, connectionState.Connection); + var sse = new ServerSentEvents(state.Connection); - // Add the connection to the list - endpoint.Connections.Add(connectionState.Connection); + await DoPersistentConnection(endpoint, sse, context, state.Connection); - // Call into the end point passing the connection - var endpointTask = endpoint.OnConnected(connectionState.Connection); - - // Start the transport - var transportTask = sse.ProcessRequest(context); - - // Wait for any of them to end - await Task.WhenAny(endpointTask, transportTask); - - // Transport has ended so kill the channel - connectionState.Connection.Channel.Dispose(); - - // Wait on both to end - await Task.WhenAll(endpointTask, transportTask); - - _manager.RemoveConnection(connectionState.Connection.ConnectionId); - - endpoint.Connections.Remove(connectionState.Connection); + _manager.RemoveConnection(state.Connection.ConnectionId); } else if (context.Request.Path.StartsWithSegments(path + "/ws")) { // Get the connection state for the current http context - var connectionState = GetOrCreateConnection(context); - connectionState.Connection.User = context.User; - connectionState.Connection.Metadata["transport"] = "websockets"; - connectionState.Connection.Metadata.Format = format; - var ws = new WebSockets(connectionState.Connection); + var state = GetOrCreateConnection(context); + state.Connection.User = context.User; + state.Connection.Metadata["transport"] = "websockets"; + state.Connection.Metadata.Format = format; - // Register this transport for disconnect - RegisterDisconnect(context, connectionState.Connection); + var ws = new WebSockets(state.Connection); - endpoint.Connections.Add(connectionState.Connection); + await DoPersistentConnection(endpoint, ws, context, state.Connection); - // Call into the end point passing the connection - var endpointTask = endpoint.OnConnected(connectionState.Connection); - - // Start the transport - var transportTask = ws.ProcessRequest(context); - - // Wait for any of them to end - await Task.WhenAny(endpointTask, transportTask); - - // Kill the channel - connectionState.Connection.Channel.Dispose(); - - // Wait for both - await Task.WhenAll(endpointTask, transportTask); - - _manager.RemoveConnection(connectionState.Connection.ConnectionId); - - endpoint.Connections.Remove(connectionState.Connection); + _manager.RemoveConnection(state.Connection.ConnectionId); } else if (context.Request.Path.StartsWithSegments(path + "/poll")) { - var connectionId = context.Request.Query["id"]; - ConnectionState connectionState; - - bool isNewConnection = false; - if (_manager.TryGetConnection(connectionId, out connectionState)) - { - // Treat reserved connections like new ones - if (connectionState.Connection.Channel == null) - { - // REVIEW: The connection manager should encapsulate this... - connectionState.Connection.Channel = new HttpChannel(_channelFactory); - isNewConnection = true; - } - } - else - { - // Add new connection - connectionState = _manager.AddNewConnection(new HttpChannel(_channelFactory)); - isNewConnection = true; - } + bool isNewConnection; + var state = GetOrCreateConnection(context, out isNewConnection); // Mark the connection as active - connectionState.Active = true; + state.Active = true; Task endpointTask = null; // Raise OnConnected for new connections only since polls happen all the time if (isNewConnection) { - connectionState.Connection.Metadata["transport"] = "poll"; - connectionState.Connection.Metadata.Format = format; - connectionState.Connection.User = context.User; + state.Connection.Metadata["transport"] = "poll"; + state.Connection.Metadata.Format = format; + state.Connection.User = context.User; // REVIEW: This is super gross, this all needs to be cleaned up... - connectionState.Close = async () => + state.Close = async () => { - connectionState.Connection.Channel.Dispose(); + state.Connection.Channel.Dispose(); await endpointTask; - endpoint.Connections.Remove(connectionState.Connection); + endpoint.Connections.Remove(state.Connection); }; - endpoint.Connections.Add(connectionState.Connection); + endpoint.Connections.Add(state.Connection); - endpointTask = endpoint.OnConnected(connectionState.Connection); - connectionState.Connection.Metadata["endpoint"] = endpointTask; + endpointTask = endpoint.OnConnected(state.Connection); + state.Connection.Metadata["endpoint"] = endpointTask; } else { // Get the endpoint task from connection state - endpointTask = (Task)connectionState.Connection.Metadata["endpoint"]; + endpointTask = (Task)state.Connection.Metadata["endpoint"]; } - RegisterLongPollingDisconnect(context, connectionState.Connection); + RegisterLongPollingDisconnect(context, state.Connection); - var longPolling = new LongPolling(connectionState.Connection); + var longPolling = new LongPolling(state.Connection); // Start the transport var transportTask = longPolling.ProcessRequest(context); @@ -187,20 +127,48 @@ namespace Microsoft.AspNetCore.Sockets if (resultTask == endpointTask) { // Notify the long polling transport to end - connectionState.Connection.Channel.Dispose(); + state.Connection.Channel.Dispose(); await transportTask; - endpoint.Connections.Remove(connectionState.Connection); + endpoint.Connections.Remove(state.Connection); } // Mark the connection as inactive - connectionState.LastSeen = DateTimeOffset.UtcNow; - connectionState.Active = false; + state.LastSeen = DateTimeOffset.UtcNow; + state.Active = false; } } } + private static async Task DoPersistentConnection(EndPoint endpoint, + IHttpTransport transport, + HttpContext context, + Connection connection) + { + // Register this transport for disconnect + RegisterDisconnect(context, connection); + + endpoint.Connections.Add(connection); + + // Call into the end point passing the connection + var endpointTask = endpoint.OnConnected(connection); + + // Start the transport + var transportTask = transport.ProcessRequest(context); + + // Wait for any of them to end + await Task.WhenAny(endpointTask, transportTask); + + // Kill the channel + connection.Channel.Dispose(); + + // Wait for both + await Task.WhenAll(endpointTask, transportTask); + + endpoint.Connections.Remove(connection); + } + private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection) { // For long polling, we need to end the transport but not the overall connection so we write 0 bytes @@ -253,13 +221,21 @@ namespace Microsoft.AspNetCore.Sockets } private ConnectionState GetOrCreateConnection(HttpContext context) + { + bool isNewConnection; + return GetOrCreateConnection(context, out isNewConnection); + } + + private ConnectionState GetOrCreateConnection(HttpContext context, out bool isNewConnection) { var connectionId = context.Request.Query["id"]; ConnectionState connectionState; + isNewConnection = false; // There's no connection id so this is a branch new connection if (StringValues.IsNullOrEmpty(connectionId)) { + isNewConnection = true; var channel = new HttpChannel(_channelFactory); connectionState = _manager.AddNewConnection(channel); } @@ -276,7 +252,9 @@ namespace Microsoft.AspNetCore.Sockets // Reserved connection, we need to provide a channel if (connectionState.Connection.Channel == null) { + isNewConnection = true; connectionState.Connection.Channel = new HttpChannel(_channelFactory); + connectionState.Active = true; connectionState.LastSeen = DateTimeOffset.UtcNow; } }