Reduce code duplication
This commit is contained in:
parent
1647432ef6
commit
411f44f263
|
|
@ -1,5 +1,4 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
|
|
@ -42,8 +41,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
else
|
||||
{
|
||||
// REVIEW: Errors?
|
||||
|
||||
// Get the end point mapped to this http connection
|
||||
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
|
||||
var format =
|
||||
|
|
@ -55,129 +52,72 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
if (context.Request.Path.StartsWithSegments(path + "/sse"))
|
||||
{
|
||||
// Get the connection state for the current http context
|
||||
var connectionState = GetOrCreateConnection(context);
|
||||
connectionState.Connection.User = context.User;
|
||||
connectionState.Connection.Metadata["transport"] = "sse";
|
||||
connectionState.Connection.Metadata.Format = format;
|
||||
var sse = new ServerSentEvents(connectionState.Connection);
|
||||
var state = GetOrCreateConnection(context);
|
||||
state.Connection.User = context.User;
|
||||
state.Connection.Metadata["transport"] = "sse";
|
||||
state.Connection.Metadata.Format = format;
|
||||
|
||||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, connectionState.Connection);
|
||||
var sse = new ServerSentEvents(state.Connection);
|
||||
|
||||
// Add the connection to the list
|
||||
endpoint.Connections.Add(connectionState.Connection);
|
||||
await DoPersistentConnection(endpoint, sse, context, state.Connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
|
||||
// Start the transport
|
||||
var transportTask = sse.ProcessRequest(context);
|
||||
|
||||
// Wait for any of them to end
|
||||
await Task.WhenAny(endpointTask, transportTask);
|
||||
|
||||
// Transport has ended so kill the channel
|
||||
connectionState.Connection.Channel.Dispose();
|
||||
|
||||
// Wait on both to end
|
||||
await Task.WhenAll(endpointTask, transportTask);
|
||||
|
||||
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
_manager.RemoveConnection(state.Connection.ConnectionId);
|
||||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
|
||||
{
|
||||
// Get the connection state for the current http context
|
||||
var connectionState = GetOrCreateConnection(context);
|
||||
connectionState.Connection.User = context.User;
|
||||
connectionState.Connection.Metadata["transport"] = "websockets";
|
||||
connectionState.Connection.Metadata.Format = format;
|
||||
var ws = new WebSockets(connectionState.Connection);
|
||||
var state = GetOrCreateConnection(context);
|
||||
state.Connection.User = context.User;
|
||||
state.Connection.Metadata["transport"] = "websockets";
|
||||
state.Connection.Metadata.Format = format;
|
||||
|
||||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, connectionState.Connection);
|
||||
var ws = new WebSockets(state.Connection);
|
||||
|
||||
endpoint.Connections.Add(connectionState.Connection);
|
||||
await DoPersistentConnection(endpoint, ws, context, state.Connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
|
||||
// Start the transport
|
||||
var transportTask = ws.ProcessRequest(context);
|
||||
|
||||
// Wait for any of them to end
|
||||
await Task.WhenAny(endpointTask, transportTask);
|
||||
|
||||
// Kill the channel
|
||||
connectionState.Connection.Channel.Dispose();
|
||||
|
||||
// Wait for both
|
||||
await Task.WhenAll(endpointTask, transportTask);
|
||||
|
||||
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
_manager.RemoveConnection(state.Connection.ConnectionId);
|
||||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
|
||||
{
|
||||
var connectionId = context.Request.Query["id"];
|
||||
ConnectionState connectionState;
|
||||
|
||||
bool isNewConnection = false;
|
||||
if (_manager.TryGetConnection(connectionId, out connectionState))
|
||||
{
|
||||
// Treat reserved connections like new ones
|
||||
if (connectionState.Connection.Channel == null)
|
||||
{
|
||||
// REVIEW: The connection manager should encapsulate this...
|
||||
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
|
||||
isNewConnection = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Add new connection
|
||||
connectionState = _manager.AddNewConnection(new HttpChannel(_channelFactory));
|
||||
isNewConnection = true;
|
||||
}
|
||||
bool isNewConnection;
|
||||
var state = GetOrCreateConnection(context, out isNewConnection);
|
||||
|
||||
// Mark the connection as active
|
||||
connectionState.Active = true;
|
||||
state.Active = true;
|
||||
|
||||
Task endpointTask = null;
|
||||
|
||||
// Raise OnConnected for new connections only since polls happen all the time
|
||||
if (isNewConnection)
|
||||
{
|
||||
connectionState.Connection.Metadata["transport"] = "poll";
|
||||
connectionState.Connection.Metadata.Format = format;
|
||||
connectionState.Connection.User = context.User;
|
||||
state.Connection.Metadata["transport"] = "poll";
|
||||
state.Connection.Metadata.Format = format;
|
||||
state.Connection.User = context.User;
|
||||
|
||||
// REVIEW: This is super gross, this all needs to be cleaned up...
|
||||
connectionState.Close = async () =>
|
||||
state.Close = async () =>
|
||||
{
|
||||
connectionState.Connection.Channel.Dispose();
|
||||
state.Connection.Channel.Dispose();
|
||||
|
||||
await endpointTask;
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
endpoint.Connections.Remove(state.Connection);
|
||||
};
|
||||
|
||||
endpoint.Connections.Add(connectionState.Connection);
|
||||
endpoint.Connections.Add(state.Connection);
|
||||
|
||||
endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
connectionState.Connection.Metadata["endpoint"] = endpointTask;
|
||||
endpointTask = endpoint.OnConnected(state.Connection);
|
||||
state.Connection.Metadata["endpoint"] = endpointTask;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Get the endpoint task from connection state
|
||||
endpointTask = (Task)connectionState.Connection.Metadata["endpoint"];
|
||||
endpointTask = (Task)state.Connection.Metadata["endpoint"];
|
||||
}
|
||||
|
||||
RegisterLongPollingDisconnect(context, connectionState.Connection);
|
||||
RegisterLongPollingDisconnect(context, state.Connection);
|
||||
|
||||
var longPolling = new LongPolling(connectionState.Connection);
|
||||
var longPolling = new LongPolling(state.Connection);
|
||||
|
||||
// Start the transport
|
||||
var transportTask = longPolling.ProcessRequest(context);
|
||||
|
|
@ -187,20 +127,48 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
if (resultTask == endpointTask)
|
||||
{
|
||||
// Notify the long polling transport to end
|
||||
connectionState.Connection.Channel.Dispose();
|
||||
state.Connection.Channel.Dispose();
|
||||
|
||||
await transportTask;
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
endpoint.Connections.Remove(state.Connection);
|
||||
}
|
||||
|
||||
// Mark the connection as inactive
|
||||
connectionState.LastSeen = DateTimeOffset.UtcNow;
|
||||
connectionState.Active = false;
|
||||
state.LastSeen = DateTimeOffset.UtcNow;
|
||||
state.Active = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task DoPersistentConnection(EndPoint endpoint,
|
||||
IHttpTransport transport,
|
||||
HttpContext context,
|
||||
Connection connection)
|
||||
{
|
||||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, connection);
|
||||
|
||||
endpoint.Connections.Add(connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connection);
|
||||
|
||||
// Start the transport
|
||||
var transportTask = transport.ProcessRequest(context);
|
||||
|
||||
// Wait for any of them to end
|
||||
await Task.WhenAny(endpointTask, transportTask);
|
||||
|
||||
// Kill the channel
|
||||
connection.Channel.Dispose();
|
||||
|
||||
// Wait for both
|
||||
await Task.WhenAll(endpointTask, transportTask);
|
||||
|
||||
endpoint.Connections.Remove(connection);
|
||||
}
|
||||
|
||||
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
|
||||
{
|
||||
// For long polling, we need to end the transport but not the overall connection so we write 0 bytes
|
||||
|
|
@ -253,13 +221,21 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
}
|
||||
|
||||
private ConnectionState GetOrCreateConnection(HttpContext context)
|
||||
{
|
||||
bool isNewConnection;
|
||||
return GetOrCreateConnection(context, out isNewConnection);
|
||||
}
|
||||
|
||||
private ConnectionState GetOrCreateConnection(HttpContext context, out bool isNewConnection)
|
||||
{
|
||||
var connectionId = context.Request.Query["id"];
|
||||
ConnectionState connectionState;
|
||||
isNewConnection = false;
|
||||
|
||||
// There's no connection id so this is a branch new connection
|
||||
if (StringValues.IsNullOrEmpty(connectionId))
|
||||
{
|
||||
isNewConnection = true;
|
||||
var channel = new HttpChannel(_channelFactory);
|
||||
connectionState = _manager.AddNewConnection(channel);
|
||||
}
|
||||
|
|
@ -276,7 +252,9 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// Reserved connection, we need to provide a channel
|
||||
if (connectionState.Connection.Channel == null)
|
||||
{
|
||||
isNewConnection = true;
|
||||
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
|
||||
connectionState.Active = true;
|
||||
connectionState.LastSeen = DateTimeOffset.UtcNow;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue