250 lines
10 KiB
C#
250 lines
10 KiB
C#
using System;
|
|
using System.Text;
|
|
using System.Threading.Tasks;
|
|
using Channels;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.Extensions.DependencyInjection;
|
|
using Microsoft.Extensions.Primitives;
|
|
|
|
namespace Microsoft.AspNetCore.Sockets
|
|
{
|
|
public class HttpConnectionDispatcher
|
|
{
|
|
private readonly ConnectionManager _manager;
|
|
private readonly ChannelFactory _channelFactory;
|
|
|
|
public HttpConnectionDispatcher(ConnectionManager manager, ChannelFactory factory)
|
|
{
|
|
_manager = manager;
|
|
_channelFactory = factory;
|
|
}
|
|
|
|
public async Task Execute<TEndPoint>(string path, HttpContext context) where TEndPoint : EndPoint
|
|
{
|
|
if (context.Request.Path.StartsWithSegments(path + "/getid"))
|
|
{
|
|
await ProcessGetId(context);
|
|
}
|
|
else if (context.Request.Path.StartsWithSegments(path + "/send"))
|
|
{
|
|
await ProcessSend(context);
|
|
}
|
|
else
|
|
{
|
|
// Get the end point mapped to this http connection
|
|
var endpoint = (EndPoint)context.RequestServices.GetRequiredService<TEndPoint>();
|
|
var format =
|
|
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
|
|
? Format.Binary
|
|
: Format.Text;
|
|
|
|
// Server sent events transport
|
|
if (context.Request.Path.StartsWithSegments(path + "/sse"))
|
|
{
|
|
// Get the connection state for the current http context
|
|
var state = GetOrCreateConnection(context);
|
|
state.Connection.User = context.User;
|
|
state.Connection.Metadata["transport"] = "sse";
|
|
state.Connection.Metadata.Format = format;
|
|
|
|
var sse = new ServerSentEvents(state.Connection);
|
|
|
|
await DoPersistentConnection(endpoint, sse, context, state.Connection);
|
|
|
|
_manager.RemoveConnection(state.Connection.ConnectionId);
|
|
}
|
|
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
|
|
{
|
|
// Get the connection state for the current http context
|
|
var state = GetOrCreateConnection(context);
|
|
state.Connection.User = context.User;
|
|
state.Connection.Metadata["transport"] = "websockets";
|
|
state.Connection.Metadata.Format = format;
|
|
|
|
// TODO: this is wrong. + how does the user add their own metadata based on HttpContext
|
|
var formatType = (string)context.Request.Query["formatType"];
|
|
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
|
|
|
|
var ws = new WebSockets(state.Connection, format);
|
|
|
|
await DoPersistentConnection(endpoint, ws, context, state.Connection);
|
|
|
|
_manager.RemoveConnection(state.Connection.ConnectionId);
|
|
}
|
|
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
|
|
{
|
|
bool isNewConnection;
|
|
var state = GetOrCreateConnection(context, out isNewConnection);
|
|
|
|
// Mark the connection as active
|
|
state.Active = true;
|
|
|
|
Task endpointTask = null;
|
|
|
|
// Raise OnConnected for new connections only since polls happen all the time
|
|
if (isNewConnection)
|
|
{
|
|
state.Connection.Metadata["transport"] = "poll";
|
|
state.Connection.Metadata.Format = format;
|
|
state.Connection.User = context.User;
|
|
|
|
// REVIEW: This is super gross, this all needs to be cleaned up...
|
|
state.Close = async () =>
|
|
{
|
|
state.Connection.Channel.Dispose();
|
|
|
|
await endpointTask;
|
|
};
|
|
|
|
endpointTask = endpoint.OnConnected(state.Connection);
|
|
state.Connection.Metadata["endpoint"] = endpointTask;
|
|
}
|
|
else
|
|
{
|
|
// Get the endpoint task from connection state
|
|
endpointTask = (Task)state.Connection.Metadata["endpoint"];
|
|
}
|
|
|
|
RegisterLongPollingDisconnect(context, state.Connection);
|
|
|
|
var longPolling = new LongPolling(state.Connection);
|
|
|
|
// Start the transport
|
|
var transportTask = longPolling.ProcessRequest(context);
|
|
|
|
var resultTask = await Task.WhenAny(endpointTask, transportTask);
|
|
|
|
if (resultTask == endpointTask)
|
|
{
|
|
// Notify the long polling transport to end
|
|
state.Connection.Channel.Dispose();
|
|
|
|
await transportTask;
|
|
}
|
|
|
|
// Mark the connection as inactive
|
|
state.LastSeen = DateTimeOffset.UtcNow;
|
|
state.Active = false;
|
|
}
|
|
}
|
|
}
|
|
|
|
private static async Task DoPersistentConnection(EndPoint endpoint,
|
|
IHttpTransport transport,
|
|
HttpContext context,
|
|
Connection connection)
|
|
{
|
|
// Register this transport for disconnect
|
|
RegisterDisconnect(context, connection);
|
|
|
|
// Call into the end point passing the connection
|
|
var endpointTask = endpoint.OnConnected(connection);
|
|
|
|
// Start the transport
|
|
var transportTask = transport.ProcessRequest(context);
|
|
|
|
// Wait for any of them to end
|
|
await Task.WhenAny(endpointTask, transportTask);
|
|
|
|
// Kill the channel
|
|
connection.Channel.Dispose();
|
|
|
|
// Wait for both
|
|
await Task.WhenAll(endpointTask, transportTask);
|
|
}
|
|
|
|
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
|
|
{
|
|
// For long polling, we need to end the transport but not the overall connection so we write 0 bytes
|
|
context.RequestAborted.Register(state => ((HttpChannel)state).Output.WriteAsync(Span<byte>.Empty), connection.Channel);
|
|
}
|
|
|
|
private static void RegisterDisconnect(HttpContext context, Connection connection)
|
|
{
|
|
// We just kill the output writing as a signal to the transport that it is done
|
|
context.RequestAborted.Register(state => ((HttpChannel)state).Output.CompleteWriter(), connection.Channel);
|
|
}
|
|
|
|
private Task ProcessGetId(HttpContext context)
|
|
{
|
|
// Reserve an id for this connection
|
|
var state = _manager.ReserveConnection();
|
|
|
|
// Get the bytes for the connection id
|
|
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
|
|
|
|
// Write it out to the response with the right content length
|
|
context.Response.ContentLength = connectionIdBuffer.Length;
|
|
return context.Response.Body.WriteAsync(connectionIdBuffer, 0, connectionIdBuffer.Length);
|
|
}
|
|
|
|
private Task ProcessSend(HttpContext context)
|
|
{
|
|
var connectionId = context.Request.Query["id"];
|
|
if (StringValues.IsNullOrEmpty(connectionId))
|
|
{
|
|
throw new InvalidOperationException("Missing connection id");
|
|
}
|
|
|
|
ConnectionState state;
|
|
if (_manager.TryGetConnection(connectionId, out state))
|
|
{
|
|
// If we received an HTTP POST for the connection id and it's not an HttpChannel then fail.
|
|
// You can't write to a TCP channel directly from here.
|
|
var httpChannel = state.Connection.Channel as HttpChannel;
|
|
|
|
if (httpChannel == null)
|
|
{
|
|
throw new InvalidOperationException("No channel");
|
|
}
|
|
|
|
return context.Request.Body.CopyToAsync(httpChannel.Input);
|
|
}
|
|
|
|
throw new InvalidOperationException("Unknown connection id");
|
|
}
|
|
|
|
private ConnectionState GetOrCreateConnection(HttpContext context)
|
|
{
|
|
bool isNewConnection;
|
|
return GetOrCreateConnection(context, out isNewConnection);
|
|
}
|
|
|
|
private ConnectionState GetOrCreateConnection(HttpContext context, out bool isNewConnection)
|
|
{
|
|
var connectionId = context.Request.Query["id"];
|
|
ConnectionState connectionState;
|
|
isNewConnection = false;
|
|
|
|
// There's no connection id so this is a branch new connection
|
|
if (StringValues.IsNullOrEmpty(connectionId))
|
|
{
|
|
isNewConnection = true;
|
|
var channel = new HttpChannel(_channelFactory);
|
|
connectionState = _manager.AddNewConnection(channel);
|
|
}
|
|
else
|
|
{
|
|
// REVIEW: Fail if not reserved? Reused an existing connection id?
|
|
|
|
// There's a connection id
|
|
if (!_manager.TryGetConnection(connectionId, out connectionState))
|
|
{
|
|
throw new InvalidOperationException("Unknown connection id");
|
|
}
|
|
|
|
// Reserved connection, we need to provide a channel
|
|
if (connectionState.Connection.Channel == null)
|
|
{
|
|
isNewConnection = true;
|
|
connectionState.Connection.Channel = new HttpChannel(_channelFactory);
|
|
connectionState.Active = true;
|
|
connectionState.LastSeen = DateTimeOffset.UtcNow;
|
|
}
|
|
}
|
|
|
|
return connectionState;
|
|
}
|
|
}
|
|
}
|