Reduce code duplication

This commit is contained in:
David Fowler 2016-10-03 23:52:51 -07:00
parent 1647432ef6
commit 411f44f263
1 changed files with 71 additions and 93 deletions

View File

@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Tasks;
using Channels;
@ -42,8 +41,6 @@ namespace Microsoft.AspNetCore.Sockets
}
else
{
// REVIEW: Errors?
// Get the end point mapped to this http connection
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
var format =
@ -55,129 +52,72 @@ namespace Microsoft.AspNetCore.Sockets
if (context.Request.Path.StartsWithSegments(path + "/sse"))
{
// Get the connection state for the current http context
var connectionState = GetOrCreateConnection(context);
connectionState.Connection.User = context.User;
connectionState.Connection.Metadata["transport"] = "sse";
connectionState.Connection.Metadata.Format = format;
var sse = new ServerSentEvents(connectionState.Connection);
var state = GetOrCreateConnection(context);
state.Connection.User = context.User;
state.Connection.Metadata["transport"] = "sse";
state.Connection.Metadata.Format = format;
// Register this transport for disconnect
RegisterDisconnect(context, connectionState.Connection);
var sse = new ServerSentEvents(state.Connection);
// Add the connection to the list
endpoint.Connections.Add(connectionState.Connection);
await DoPersistentConnection(endpoint, sse, context, state.Connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connectionState.Connection);
// Start the transport
var transportTask = sse.ProcessRequest(context);
// Wait for any of them to end
await Task.WhenAny(endpointTask, transportTask);
// Transport has ended so kill the channel
connectionState.Connection.Channel.Dispose();
// Wait on both to end
await Task.WhenAll(endpointTask, transportTask);
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
endpoint.Connections.Remove(connectionState.Connection);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
{
// Get the connection state for the current http context
var connectionState = GetOrCreateConnection(context);
connectionState.Connection.User = context.User;
connectionState.Connection.Metadata["transport"] = "websockets";
connectionState.Connection.Metadata.Format = format;
var ws = new WebSockets(connectionState.Connection);
var state = GetOrCreateConnection(context);
state.Connection.User = context.User;
state.Connection.Metadata["transport"] = "websockets";
state.Connection.Metadata.Format = format;
// Register this transport for disconnect
RegisterDisconnect(context, connectionState.Connection);
var ws = new WebSockets(state.Connection);
endpoint.Connections.Add(connectionState.Connection);
await DoPersistentConnection(endpoint, ws, context, state.Connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connectionState.Connection);
// Start the transport
var transportTask = ws.ProcessRequest(context);
// Wait for any of them to end
await Task.WhenAny(endpointTask, transportTask);
// Kill the channel
connectionState.Connection.Channel.Dispose();
// Wait for both
await Task.WhenAll(endpointTask, transportTask);
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
endpoint.Connections.Remove(connectionState.Connection);
_manager.RemoveConnection(state.Connection.ConnectionId);
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
bool isNewConnection = false;
if (_manager.TryGetConnection(connectionId, out connectionState))
{
// Treat reserved connections like new ones
if (connectionState.Connection.Channel == null)
{
// REVIEW: The connection manager should encapsulate this...
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
isNewConnection = true;
}
}
else
{
// Add new connection
connectionState = _manager.AddNewConnection(new HttpChannel(_channelFactory));
isNewConnection = true;
}
bool isNewConnection;
var state = GetOrCreateConnection(context, out isNewConnection);
// Mark the connection as active
connectionState.Active = true;
state.Active = true;
Task endpointTask = null;
// Raise OnConnected for new connections only since polls happen all the time
if (isNewConnection)
{
connectionState.Connection.Metadata["transport"] = "poll";
connectionState.Connection.Metadata.Format = format;
connectionState.Connection.User = context.User;
state.Connection.Metadata["transport"] = "poll";
state.Connection.Metadata.Format = format;
state.Connection.User = context.User;
// REVIEW: This is super gross, this all needs to be cleaned up...
connectionState.Close = async () =>
state.Close = async () =>
{
connectionState.Connection.Channel.Dispose();
state.Connection.Channel.Dispose();
await endpointTask;
endpoint.Connections.Remove(connectionState.Connection);
endpoint.Connections.Remove(state.Connection);
};
endpoint.Connections.Add(connectionState.Connection);
endpoint.Connections.Add(state.Connection);
endpointTask = endpoint.OnConnected(connectionState.Connection);
connectionState.Connection.Metadata["endpoint"] = endpointTask;
endpointTask = endpoint.OnConnected(state.Connection);
state.Connection.Metadata["endpoint"] = endpointTask;
}
else
{
// Get the endpoint task from connection state
endpointTask = (Task)connectionState.Connection.Metadata["endpoint"];
endpointTask = (Task)state.Connection.Metadata["endpoint"];
}
RegisterLongPollingDisconnect(context, connectionState.Connection);
RegisterLongPollingDisconnect(context, state.Connection);
var longPolling = new LongPolling(connectionState.Connection);
var longPolling = new LongPolling(state.Connection);
// Start the transport
var transportTask = longPolling.ProcessRequest(context);
@ -187,20 +127,48 @@ namespace Microsoft.AspNetCore.Sockets
if (resultTask == endpointTask)
{
// Notify the long polling transport to end
connectionState.Connection.Channel.Dispose();
state.Connection.Channel.Dispose();
await transportTask;
endpoint.Connections.Remove(connectionState.Connection);
endpoint.Connections.Remove(state.Connection);
}
// Mark the connection as inactive
connectionState.LastSeen = DateTimeOffset.UtcNow;
connectionState.Active = false;
state.LastSeen = DateTimeOffset.UtcNow;
state.Active = false;
}
}
}
private static async Task DoPersistentConnection(EndPoint endpoint,
IHttpTransport transport,
HttpContext context,
Connection connection)
{
// Register this transport for disconnect
RegisterDisconnect(context, connection);
endpoint.Connections.Add(connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connection);
// Start the transport
var transportTask = transport.ProcessRequest(context);
// Wait for any of them to end
await Task.WhenAny(endpointTask, transportTask);
// Kill the channel
connection.Channel.Dispose();
// Wait for both
await Task.WhenAll(endpointTask, transportTask);
endpoint.Connections.Remove(connection);
}
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
{
// For long polling, we need to end the transport but not the overall connection so we write 0 bytes
@ -253,13 +221,21 @@ namespace Microsoft.AspNetCore.Sockets
}
private ConnectionState GetOrCreateConnection(HttpContext context)
{
bool isNewConnection;
return GetOrCreateConnection(context, out isNewConnection);
}
private ConnectionState GetOrCreateConnection(HttpContext context, out bool isNewConnection)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
isNewConnection = false;
// There's no connection id so this is a branch new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
isNewConnection = true;
var channel = new HttpChannel(_channelFactory);
connectionState = _manager.AddNewConnection(channel);
}
@ -276,7 +252,9 @@ namespace Microsoft.AspNetCore.Sockets
// Reserved connection, we need to provide a channel
if (connectionState.Connection.Channel == null)
{
isNewConnection = true;
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
connectionState.Active = true;
connectionState.LastSeen = DateTimeOffset.UtcNow;
}
}