From 1647432ef6a00d6d2cf4ff323de1abd3189e17c5 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Mon, 3 Oct 2016 22:58:56 -0700 Subject: [PATCH] Added Connection property to EndPoint - Exposes a list of connections for user code to act on - The connection list is thread safe (uses a concurrent dictionary under the hood) - Removed the Bus and just used the connection list in the samples --- samples/SocketsSample/Bus.cs | 71 -------------- .../SocketsSample/EndPoints/ChatEndPoint.cs | 61 ++++++------ .../SocketsSample/EndPoints/HubEndpoint.cs | 94 ++++++++++++------- .../EndPoints/JsonRpcEndpoint.cs | 16 +--- samples/SocketsSample/Hubs/Hub.cs | 35 ------- .../ConnectionList.cs | 51 ++++++++++ .../ConnectionManager.cs | 2 +- .../ConnectionState.cs | 5 +- src/Microsoft.AspNetCore.Sockets/EndPoint.cs | 12 ++- .../HttpConnectionDispatcher.cs | 24 +++++ 10 files changed, 182 insertions(+), 189 deletions(-) delete mode 100644 samples/SocketsSample/Bus.cs create mode 100644 src/Microsoft.AspNetCore.Sockets/ConnectionList.cs diff --git a/samples/SocketsSample/Bus.cs b/samples/SocketsSample/Bus.cs deleted file mode 100644 index 95fb8f2fed..0000000000 --- a/samples/SocketsSample/Bus.cs +++ /dev/null @@ -1,71 +0,0 @@ -using System; -using System.Collections.Concurrent; -using System.Collections.Generic; -using System.Linq; -using System.Threading; -using System.Threading.Tasks; -using Channels; - -namespace Microsoft.AspNetCore.Sockets -{ - public class Message - { - public byte[] Payload { get; set; } - } - - public class Bus - { - private readonly ConcurrentDictionary>> _subscriptions = new ConcurrentDictionary>>(); - - public IDisposable Subscribe(string key, Func observer) - { - var connections = _subscriptions.GetOrAdd(key, _ => new List>()); - lock (connections) - { - connections.Add(observer); - } - - return new DisposableAction(() => - { - lock (connections) - { - connections.Remove(observer); - } - }); - } - - public async Task Publish(string key, Message message) - { - List> connections; - if (_subscriptions.TryGetValue(key, out connections)) - { - Task[] tasks = null; - lock (connections) - { - tasks = new Task[connections.Count]; - for (int i = 0; i < connections.Count; i++) - { - tasks[i] = connections[i](message); - } - } - - await Task.WhenAll(tasks); - } - } - - private class DisposableAction : IDisposable - { - private Action _action; - - public DisposableAction(Action action) - { - _action = action; - } - - public void Dispose() - { - Interlocked.Exchange(ref _action, () => { }).Invoke(); - } - } - } -} diff --git a/samples/SocketsSample/EndPoints/ChatEndPoint.cs b/samples/SocketsSample/EndPoints/ChatEndPoint.cs index 6478e5d7c1..5308d793d0 100644 --- a/samples/SocketsSample/EndPoints/ChatEndPoint.cs +++ b/samples/SocketsSample/EndPoints/ChatEndPoint.cs @@ -10,51 +10,48 @@ namespace SocketsSample { public class ChatEndPoint : EndPoint { - private Bus bus = new Bus(); - public override async Task OnConnected(Connection connection) { - await bus.Publish(nameof(ChatEndPoint), new Message - { - Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})") - }); + await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})"); - using (bus.Subscribe(nameof(ChatEndPoint), message => OnMessage(message, connection))) + + while (true) { - while (true) + var input = await connection.Channel.Input.ReadAsync(); + try { - var input = await connection.Channel.Input.ReadAsync(); - try + if (input.IsEmpty && connection.Channel.Input.Reading.IsCompleted) { - if (input.IsEmpty && connection.Channel.Input.Reading.IsCompleted) - { - break; - } + break; + } - await bus.Publish(nameof(ChatEndPoint), new Message() - { - Payload = input.ToArray() - }); - } - finally - { - connection.Channel.Input.Advance(input.End); - } + // We can avoid the copy here but we'll deal with that later + await Broadcast(input.ToArray()); + } + finally + { + connection.Channel.Input.Advance(input.End); } } - await bus.Publish(nameof(ChatEndPoint), new Message - { - Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})") - }); + await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})"); } - private async Task OnMessage(Message message, Connection connection) + private Task Broadcast(string text) { - var buffer = connection.Channel.Output.Alloc(); - var payload = message.Payload; - buffer.Write(payload); - await buffer.FlushAsync(); + return Broadcast(Encoding.UTF8.GetBytes(text)); + } + + private Task Broadcast(byte[] payload) + { + var tasks = new List(Connections.Count); + + foreach (var c in Connections) + { + tasks.Add(c.Channel.Output.WriteAsync(payload)); + } + + return Task.WhenAll(tasks); } } diff --git a/samples/SocketsSample/EndPoints/HubEndpoint.cs b/samples/SocketsSample/EndPoints/HubEndpoint.cs index eaf2f1a5ff..eb1bab8533 100644 --- a/samples/SocketsSample/EndPoints/HubEndpoint.cs +++ b/samples/SocketsSample/EndPoints/HubEndpoint.cs @@ -1,55 +1,34 @@ using System; +using System.Collections.Generic; using System.Linq; using System.Text; using System.Threading.Tasks; using Channels; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; using Newtonsoft.Json.Linq; using SocketsSample.Hubs; namespace SocketsSample { - public class HubEndpoint : JsonRpcEndpoint + public class HubEndpoint : JsonRpcEndpoint, IHubConnectionContext { private readonly ILogger _logger; - private readonly Bus _bus = new Bus(); public HubEndpoint(ILogger logger, ILogger jsonRpcLogger, IServiceProvider serviceProvider) : base(jsonRpcLogger, serviceProvider) { _logger = logger; + All = new AllClientProxy(this); } - public override Task OnConnected(Connection connection) + public IClientProxy All { get; } + + public IClientProxy Client(string connectionId) { - // TODO: Get the list of hubs and signals over the connection - - // Subscribe to the hub - _bus.Subscribe(nameof(Chat), message => OnMessage(connection, message)); - - // Subscribe to the connection id - _bus.Subscribe(connection.ConnectionId, message => OnMessage(connection, message)); - - return base.OnConnected(connection); + return new SingleClientProxy(this, connectionId); } - protected override bool HandleResponse(string connectionId, JObject response) - { - var ignore = _bus.Publish(connectionId, new Message - { - Payload = Encoding.UTF8.GetBytes(response.ToString()) - }); - - return true; - } - - private Task OnMessage(Connection connection, Message message) - { - return connection.Channel.Output.WriteAsync(message.Payload); - } - - public Task Invoke(string key, string method, object[] args) + private byte[] Pack(string method, object[] args) { var obj = new JObject(); obj["method"] = method; @@ -57,18 +36,15 @@ namespace SocketsSample if (_logger.IsEnabled(LogLevel.Debug)) { - _logger.LogDebug("Outgoing RPC invocation method '{methodName}' to {signal}", method, key); + _logger.LogDebug("Outgoing RPC invocation method '{methodName}'", method); } - return _bus.Publish(key, new Message - { - Payload = Encoding.UTF8.GetBytes(obj.ToString()) - }); + return Encoding.UTF8.GetBytes(obj.ToString()); } protected override void Initialize(object endpoint) { - ((Hub)endpoint).Clients = new HubConnectionContext(endpoint.GetType().Name, this); + ((Hub)endpoint).Clients = this; base.Initialize(endpoint); } @@ -77,5 +53,53 @@ namespace SocketsSample // Register the chat hub RegisterJsonRPCEndPoint(typeof(Chat)); } + + private class AllClientProxy : IClientProxy + { + private readonly HubEndpoint _endPoint; + + public AllClientProxy(HubEndpoint endPoint) + { + _endPoint = endPoint; + } + + public Task Invoke(string method, params object[] args) + { + // REVIEW: Thread safety + var tasks = new List(_endPoint.Connections.Count); + + byte[] message = null; + + foreach (var connection in _endPoint.Connections) + { + if (message == null) + { + message = _endPoint.Pack(method, args); + } + + tasks.Add(connection.Channel.Output.WriteAsync(message)); + } + + return Task.WhenAll(tasks); + } + } + + private class SingleClientProxy : IClientProxy + { + private readonly string _connectionId; + private readonly HubEndpoint _endPoint; + + public SingleClientProxy(HubEndpoint endPoint, string connectionId) + { + _endPoint = endPoint; + _connectionId = connectionId; + } + + public Task Invoke(string method, params object[] args) + { + var connection = _endPoint.Connections[_connectionId]; + return connection?.Channel.Output.WriteAsync(_endPoint.Pack(method, args)); + } + } } } diff --git a/samples/SocketsSample/EndPoints/JsonRpcEndpoint.cs b/samples/SocketsSample/EndPoints/JsonRpcEndpoint.cs index 6628b5007e..8f58de9343 100644 --- a/samples/SocketsSample/EndPoints/JsonRpcEndpoint.cs +++ b/samples/SocketsSample/EndPoints/JsonRpcEndpoint.cs @@ -86,22 +86,14 @@ namespace SocketsSample response["error"] = string.Format("Unknown method '{0}'", request.Value("method")); } - if (!HandleResponse(connection.ConnectionId, response)) - { - _logger.LogDebug("Sending JSON RPC response: {data}", response); + _logger.LogDebug("Sending JSON RPC response: {data}", response); - var writer = new JsonTextWriter(new StreamWriter(stream)); - response.WriteTo(writer); - writer.Flush(); - } + var writer = new JsonTextWriter(new StreamWriter(stream)); + response.WriteTo(writer); + writer.Flush(); } } - protected virtual bool HandleResponse(string connectionId, JObject response) - { - return false; - } - protected virtual void Initialize(object endpoint) { diff --git a/samples/SocketsSample/Hubs/Hub.cs b/samples/SocketsSample/Hubs/Hub.cs index accf7e9b82..b43977fd1f 100644 --- a/samples/SocketsSample/Hubs/Hub.cs +++ b/samples/SocketsSample/Hubs/Hub.cs @@ -16,41 +16,6 @@ namespace SocketsSample.Hubs IClientProxy Client(string connectionId); } - public class HubConnectionContext : IHubConnectionContext - { - private readonly HubEndpoint _endPoint; - - public HubConnectionContext(string hubName, HubEndpoint endpoint) - { - _endPoint = endpoint; - All = new HubClientProxy(endpoint, hubName); - } - - public IClientProxy All { get; } - - public IClientProxy Client(string connectionId) - { - return new HubClientProxy(_endPoint, connectionId); - } - } - - public class HubClientProxy : IClientProxy - { - private readonly HubEndpoint _endPoint; - private readonly string _key; - - public HubClientProxy(HubEndpoint endPoint, string key) - { - _endPoint = endPoint; - _key = key; - } - - public Task Invoke(string method, params object[] args) - { - return _endPoint.Invoke(_key, method, args); - } - } - public interface IClientProxy { /// diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs new file mode 100644 index 0000000000..1550d20834 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs @@ -0,0 +1,51 @@ +using System; +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.Sockets +{ + public class ConnectionList : IReadOnlyCollection + { + private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); + + public Connection this[string connectionId] + { + get + { + Connection connection; + if (_connections.TryGetValue(connectionId, out connection)) + { + return connection; + } + return null; + } + } + + public int Count => _connections.Count; + + public void Add(Connection connection) + { + _connections.TryAdd(connection.ConnectionId, connection); + } + + public void Remove(Connection connection) + { + Connection dummy; + _connections.TryRemove(connection.ConnectionId, out dummy); + } + + public IEnumerator GetEnumerator() + { + foreach (var item in _connections) + { + yield return item.Value; + } + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index e4d696a962..72944f272d 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets ConnectionState s; if (_connections.TryRemove(c.Key, out s)) { - s.Connection.Channel.Dispose(); + s?.Close(); } else { diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs index 3386f910ec..0b21533a7d 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionState.cs @@ -4,8 +4,11 @@ namespace Microsoft.AspNetCore.Sockets { public class ConnectionState { + public Connection Connection { get; set; } + + // These are used for long polling mostly + public Action Close { get; set; } public DateTimeOffset LastSeen { get; set; } public bool Active { get; set; } = true; - public Connection Connection { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs b/src/Microsoft.AspNetCore.Sockets/EndPoint.cs index 918723e62a..41afc43841 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPoint.cs @@ -7,8 +7,16 @@ namespace Microsoft.AspNetCore.Sockets /// public class EndPoint { - // This is a stream based API, we might just want to change to a message based API or invent framing - // over this stream based API to do a message based API + /// + /// Live list of connections for this + /// + public ConnectionList Connections { get; } = new ConnectionList(); + + /// + /// Called when a new connection is accepted to the endpoint + /// + /// The new + /// A that represents the connection lifetime. When the task completes, the connection is complete. public virtual Task OnConnected(Connection connection) { return Task.CompletedTask; diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index e60f45ff5f..e6ef50c0f3 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -64,6 +64,9 @@ namespace Microsoft.AspNetCore.Sockets // Register this transport for disconnect RegisterDisconnect(context, connectionState.Connection); + // Add the connection to the list + endpoint.Connections.Add(connectionState.Connection); + // Call into the end point passing the connection var endpointTask = endpoint.OnConnected(connectionState.Connection); @@ -80,6 +83,8 @@ namespace Microsoft.AspNetCore.Sockets await Task.WhenAll(endpointTask, transportTask); _manager.RemoveConnection(connectionState.Connection.ConnectionId); + + endpoint.Connections.Remove(connectionState.Connection); } else if (context.Request.Path.StartsWithSegments(path + "/ws")) { @@ -93,6 +98,8 @@ namespace Microsoft.AspNetCore.Sockets // Register this transport for disconnect RegisterDisconnect(context, connectionState.Connection); + endpoint.Connections.Add(connectionState.Connection); + // Call into the end point passing the connection var endpointTask = endpoint.OnConnected(connectionState.Connection); @@ -109,6 +116,8 @@ namespace Microsoft.AspNetCore.Sockets await Task.WhenAll(endpointTask, transportTask); _manager.RemoveConnection(connectionState.Connection.ConnectionId); + + endpoint.Connections.Remove(connectionState.Connection); } else if (context.Request.Path.StartsWithSegments(path + "/poll")) { @@ -144,6 +153,19 @@ namespace Microsoft.AspNetCore.Sockets connectionState.Connection.Metadata["transport"] = "poll"; connectionState.Connection.Metadata.Format = format; connectionState.Connection.User = context.User; + + // REVIEW: This is super gross, this all needs to be cleaned up... + connectionState.Close = async () => + { + connectionState.Connection.Channel.Dispose(); + + await endpointTask; + + endpoint.Connections.Remove(connectionState.Connection); + }; + + endpoint.Connections.Add(connectionState.Connection); + endpointTask = endpoint.OnConnected(connectionState.Connection); connectionState.Connection.Metadata["endpoint"] = endpointTask; } @@ -168,6 +190,8 @@ namespace Microsoft.AspNetCore.Sockets connectionState.Connection.Channel.Dispose(); await transportTask; + + endpoint.Connections.Remove(connectionState.Connection); } // Mark the connection as inactive