diff --git a/samples/SocketsSample/EndPoints/ChatEndPoint.cs b/samples/SocketsSample/EndPoints/ChatEndPoint.cs index 18218b025a..f78f7a33ba 100644 --- a/samples/SocketsSample/EndPoints/ChatEndPoint.cs +++ b/samples/SocketsSample/EndPoints/ChatEndPoint.cs @@ -8,9 +8,12 @@ namespace SocketsSample { public class ChatEndPoint : EndPoint { + public ConnectionList Connections { get; } = new ConnectionList(); public override async Task OnConnected(Connection connection) { + Connections.Add(connection); + await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})"); while (true) @@ -33,6 +36,8 @@ namespace SocketsSample } } + Connections.Remove(connection); + await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})"); } diff --git a/samples/SocketsSample/EndPoints/HubEndPoint.cs b/samples/SocketsSample/EndPoints/HubEndPoint.cs new file mode 100644 index 0000000000..7e89efde20 --- /dev/null +++ b/samples/SocketsSample/EndPoints/HubEndPoint.cs @@ -0,0 +1,89 @@ +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using SocketsSample.EndPoints.Hubs; +using SocketsSample.Hubs; + +namespace SocketsSample +{ + public class HubEndPoint : RpcEndpoint, IHubConnectionContext where THub : Hub + { + private readonly AllClientProxy _all; + private readonly HubLifetimeManager _lifetimeManager; + + public HubEndPoint(HubLifetimeManager lifetimeManager, + InvocationAdapterRegistry registry, + ILoggerFactory loggerFactory, + IServiceScopeFactory serviceScopeFactory) + : base(registry, loggerFactory, serviceScopeFactory) + { + _lifetimeManager = lifetimeManager; + _all = new AllClientProxy(_lifetimeManager); + } + + public virtual IClientProxy All => _all; + + public virtual IClientProxy Client(string connectionId) + { + return new SingleClientProxy(_lifetimeManager, connectionId); + } + + public virtual IClientProxy Group(string groupName) + { + return new GroupProxy(_lifetimeManager, groupName); + } + + public virtual IClientProxy User(string userId) + { + return new UserProxy(_lifetimeManager, userId); + } + + public override async Task OnConnected(Connection connection) + { + try + { + await _lifetimeManager.OnConnectedAsync(connection); + + using (var scope = _serviceScopeFactory.CreateScope()) + { + var hub = scope.ServiceProvider.GetService() ?? Activator.CreateInstance(); + Initialize(connection, hub); + await hub.OnConnectedAsync(); + } + + await base.OnConnected(connection); + } + finally + { + using (var scope = _serviceScopeFactory.CreateScope()) + { + var hub = scope.ServiceProvider.GetService() ?? Activator.CreateInstance(); + Initialize(connection, hub); + await hub.OnDisconnectedAsync(); + } + + await _lifetimeManager.OnDisconnectedAsync(connection); + } + } + + protected override void BeforeInvoke(Connection connection, THub endpoint) + { + Initialize(connection, endpoint); + } + + private void Initialize(Connection connection, THub endpoint) + { + var hub = endpoint; + hub.Clients = this; + hub.Context = new HubCallerContext(connection); + hub.Groups = new GroupManager(connection, _lifetimeManager); + } + + protected override void AfterInvoke(Connection connection, THub endpoint) + { + // Poison the hub make sure it can't be used after invocation + } + } +} diff --git a/samples/SocketsSample/EndPoints/HubEndpoint.cs b/samples/SocketsSample/EndPoints/HubEndpoint.cs deleted file mode 100644 index 69b1429c67..0000000000 --- a/samples/SocketsSample/EndPoints/HubEndpoint.cs +++ /dev/null @@ -1,117 +0,0 @@ -using System; -using System.Collections.Generic; -using System.Threading.Tasks; -using Channels; -using Microsoft.AspNetCore.Sockets; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; -using SocketsSample.Hubs; - -namespace SocketsSample -{ - public class HubEndpoint : RpcEndpoint, IHubConnectionContext - { - private readonly ILogger _logger; - private readonly IServiceProvider _serviceProvider; - - public HubEndpoint(ILogger logger, ILogger jsonRpcLogger, IServiceProvider serviceProvider) - : base(jsonRpcLogger, serviceProvider) - { - _logger = logger; - _serviceProvider = serviceProvider; - All = new AllClientProxy(this); - } - - public IClientProxy All { get; } - - public IClientProxy Client(string connectionId) - { - return new SingleClientProxy(this, connectionId); - } - - protected override void Initialize(Connection connection, object endpoint) - { - var hub = (Hub)endpoint; - hub.Clients = this; - hub.Context = new HubCallerContext(connection.ConnectionId, connection.User); - - base.Initialize(connection, endpoint); - } - - protected override void DiscoverEndpoints() - { - // Register the chat hub - RegisterRPCEndPoint(typeof(Chat)); - } - - private class AllClientProxy : IClientProxy - { - private readonly HubEndpoint _endPoint; - - public AllClientProxy(HubEndpoint endPoint) - { - _endPoint = endPoint; - } - - public Task Invoke(string method, params object[] args) - { - // REVIEW: Thread safety - var tasks = new List(_endPoint.Connections.Count); - var message = new InvocationDescriptor - { - Method = method, - Arguments = args - }; - - // TODO: serialize once per format by providing a different stream? - foreach (var connection in _endPoint.Connections) - { - - var invocationAdapter = - _endPoint._serviceProvider - .GetRequiredService() - .GetInvocationAdapter((string)connection.Metadata["formatType"]); - - tasks.Add(invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream())); - } - - return Task.WhenAll(tasks); - } - } - - private class SingleClientProxy : IClientProxy - { - private readonly string _connectionId; - private readonly HubEndpoint _endPoint; - - public SingleClientProxy(HubEndpoint endPoint, string connectionId) - { - _endPoint = endPoint; - _connectionId = connectionId; - } - - public Task Invoke(string method, params object[] args) - { - var connection = _endPoint.Connections[_connectionId]; - - var invocationAdapter = - _endPoint._serviceProvider - .GetRequiredService() - .GetInvocationAdapter((string)connection.Metadata["formatType"]); - - if (_endPoint._logger.IsEnabled(LogLevel.Debug)) - { - _endPoint._logger.LogDebug("Outgoing RPC invocation method '{methodName}'", method); - } - - var message = new InvocationDescriptor - { - Method = method, - Arguments = args - }; - - return invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream()); - } - } - } -} diff --git a/samples/SocketsSample/EndPoints/Hubs/DefaultHubLifetimeManager.cs b/samples/SocketsSample/EndPoints/Hubs/DefaultHubLifetimeManager.cs new file mode 100644 index 0000000000..081884dd07 --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/DefaultHubLifetimeManager.cs @@ -0,0 +1,115 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Channels; +using Microsoft.AspNetCore.Sockets; + +namespace SocketsSample.EndPoints.Hubs +{ + public class DefaultHubLifetimeManager : HubLifetimeManager + { + private readonly ConnectionList _connections = new ConnectionList(); + private readonly InvocationAdapterRegistry _registry; + + public DefaultHubLifetimeManager(InvocationAdapterRegistry registry) + { + _registry = registry; + } + + public override void AddGroup(Connection connection, string groupName) + { + var groups = connection.Metadata.GetOrAdd("groups", k => new HashSet()); + + lock (groups) + { + groups.Add(groupName); + } + } + + public override void RemoveGroup(Connection connection, string groupName) + { + var groups = connection.Metadata.Get>("groups"); + + lock (groups) + { + groups.Remove(groupName); + } + } + + public override Task InvokeAll(string methodName, params object[] args) + { + return InvokeAllWhere(methodName, args, c => true); + } + + private Task InvokeAllWhere(string methodName, object[] args, Func include) + { + var tasks = new List(_connections.Count); + var message = new InvocationDescriptor + { + Method = methodName, + Arguments = args + }; + + // TODO: serialize once per format by providing a different stream? + foreach (var connection in _connections) + { + if (!include(connection)) + { + continue; + } + + var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]); + + tasks.Add(invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream())); + } + + return Task.WhenAll(tasks); + } + + public override Task InvokeConnection(string connectionId, string methodName, params object[] args) + { + var connection = _connections[connectionId]; + + var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]); + + var message = new InvocationDescriptor + { + Method = methodName, + Arguments = args + }; + + return invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream()); + } + + public override Task InvokeGroup(string groupName, string methodName, params object[] args) + { + return InvokeAllWhere(methodName, args, connection => + { + var groups = connection.Metadata.Get>("groups"); + return groups?.Contains(groupName) == true; + }); + } + + public override Task InvokeUser(string userId, string methodName, params object[] args) + { + return InvokeAllWhere(methodName, args, connection => + { + return connection.User.Identity.Name == userId; + }); + } + + public override Task OnConnectedAsync(Connection connection) + { + _connections.Add(connection); + return Task.CompletedTask; + } + + public override Task OnDisconnectedAsync(Connection connection) + { + _connections.Remove(connection); + return Task.CompletedTask; + } + } + +} diff --git a/samples/SocketsSample/EndPoints/Hubs/HubCallerContext.cs b/samples/SocketsSample/EndPoints/Hubs/HubCallerContext.cs new file mode 100644 index 0000000000..794b6f445a --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/HubCallerContext.cs @@ -0,0 +1,21 @@ +using System.Security.Claims; +using Microsoft.AspNetCore.Sockets; + +namespace SocketsSample.Hubs +{ + public class HubCallerContext + { + public HubCallerContext(Connection connection) + { + ConnectionId = connection.ConnectionId; + User = connection.User; + Connection = connection; + } + + public Connection Connection { get; } + + public ClaimsPrincipal User { get; } + + public string ConnectionId { get; } + } +} diff --git a/samples/SocketsSample/EndPoints/Hubs/HubLifetimeManager.cs b/samples/SocketsSample/EndPoints/Hubs/HubLifetimeManager.cs new file mode 100644 index 0000000000..d60bc544b4 --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/HubLifetimeManager.cs @@ -0,0 +1,25 @@ +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; + +namespace SocketsSample.EndPoints.Hubs +{ + public abstract class HubLifetimeManager + { + public abstract Task OnConnectedAsync(Connection connection); + + public abstract Task OnDisconnectedAsync(Connection connection); + + public abstract Task InvokeAll(string methodName, params object[] args); + + public abstract Task InvokeConnection(string connectionId, string methodName, params object[] args); + + public abstract Task InvokeGroup(string groupName, string methodName, params object[] args); + + public abstract Task InvokeUser(string userId, string methodName, params object[] args); + + public abstract void AddGroup(Connection connection, string groupName); + + public abstract void RemoveGroup(Connection connection, string groupName); + } + +} diff --git a/samples/SocketsSample/EndPoints/Hubs/IClientProxy.cs b/samples/SocketsSample/EndPoints/Hubs/IClientProxy.cs new file mode 100644 index 0000000000..d2f180a7cd --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/IClientProxy.cs @@ -0,0 +1,15 @@ +using System.Threading.Tasks; + +namespace SocketsSample.Hubs +{ + public interface IClientProxy + { + /// + /// Invokes a method on the connection(s) represented by the instance. + /// + /// name of the method to invoke + /// argumetns to pass to the client + /// A task that represents when the data has been sent to the client. + Task Invoke(string method, params object[] args); + } +} diff --git a/samples/SocketsSample/EndPoints/Hubs/IGroupManager.cs b/samples/SocketsSample/EndPoints/Hubs/IGroupManager.cs new file mode 100644 index 0000000000..12ca3c714e --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/IGroupManager.cs @@ -0,0 +1,8 @@ +namespace SocketsSample.Hubs +{ + public interface IGroupManager + { + void Add(string groupName); + void Remove(string groupName); + } +} diff --git a/samples/SocketsSample/EndPoints/Hubs/IHubConnectionContext.cs b/samples/SocketsSample/EndPoints/Hubs/IHubConnectionContext.cs new file mode 100644 index 0000000000..ee25ac84ed --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/IHubConnectionContext.cs @@ -0,0 +1,13 @@ +namespace SocketsSample.Hubs +{ + public interface IHubConnectionContext + { + IClientProxy All { get; } + + IClientProxy Client(string connectionId); + + IClientProxy Group(string groupName); + + IClientProxy User(string userId); + } +} diff --git a/samples/SocketsSample/EndPoints/Hubs/Proxies.cs b/samples/SocketsSample/EndPoints/Hubs/Proxies.cs new file mode 100644 index 0000000000..2012abb258 --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/Proxies.cs @@ -0,0 +1,98 @@ +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Sockets; +using SocketsSample.Hubs; + +namespace SocketsSample.EndPoints.Hubs +{ + public class UserProxy : IClientProxy + { + private readonly string _userId; + private readonly HubLifetimeManager _lifetimeManager; + + public UserProxy(HubLifetimeManager lifetimeManager, string userId) + { + _lifetimeManager = lifetimeManager; + _userId = userId; + } + + public Task Invoke(string method, params object[] args) + { + return _lifetimeManager.InvokeUser(_userId, method, args); + } + } + + public class GroupProxy : IClientProxy + { + private readonly string _groupName; + private readonly HubLifetimeManager _lifetimeManager; + + public GroupProxy(HubLifetimeManager lifetimeManager, string groupName) + { + _lifetimeManager = lifetimeManager; + _groupName = groupName; + } + + public Task Invoke(string method, params object[] args) + { + return _lifetimeManager.InvokeGroup(_groupName, method, args); + } + } + + public class AllClientProxy : IClientProxy + { + private readonly HubLifetimeManager _lifetimeManager; + + public AllClientProxy(HubLifetimeManager lifetimeManager) + { + _lifetimeManager = lifetimeManager; + } + + public Task Invoke(string method, params object[] args) + { + // TODO: More than just chat + return _lifetimeManager.InvokeAll(method, args); + } + } + + public class SingleClientProxy : IClientProxy + { + private readonly string _connectionId; + private readonly HubLifetimeManager _lifetimeManager; + + + public SingleClientProxy(HubLifetimeManager lifetimeManager, string connectionId) + { + _lifetimeManager = lifetimeManager; + _connectionId = connectionId; + } + + public Task Invoke(string method, params object[] args) + { + return _lifetimeManager.InvokeConnection(_connectionId, method, args); + } + } + + public class GroupManager : IGroupManager + { + private readonly Connection _connection; + private HubLifetimeManager _lifetimeManager; + + public GroupManager(Connection connection, HubLifetimeManager lifetimeManager) + { + _connection = connection; + _lifetimeManager = lifetimeManager; + } + + public void Add(string groupName) + { + _lifetimeManager.AddGroup(_connection, groupName); + } + + public void Remove(string groupName) + { + _lifetimeManager.RemoveGroup(_connection, groupName); + } + } +} diff --git a/samples/SocketsSample/EndPoints/Hubs/PubSubHubLifetimeManager.cs b/samples/SocketsSample/EndPoints/Hubs/PubSubHubLifetimeManager.cs new file mode 100644 index 0000000000..1573104b02 --- /dev/null +++ b/samples/SocketsSample/EndPoints/Hubs/PubSubHubLifetimeManager.cs @@ -0,0 +1,123 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Channels; +using Microsoft.AspNetCore.Sockets; +using SocketsSample.Hubs; + +namespace SocketsSample.EndPoints.Hubs +{ + public class PubSubHubLifetimeManager : HubLifetimeManager + { + private readonly IPubSub _bus; + private readonly InvocationAdapterRegistry _registry; + + public PubSubHubLifetimeManager(IPubSub bus, InvocationAdapterRegistry registry) + { + _bus = bus; + _registry = registry; + } + + public override Task InvokeAll(string methodName, params object[] args) + { + var message = new InvocationDescriptor + { + Method = methodName, + Arguments = args + }; + + return _bus.Publish(typeof(THub).Name, message); + } + + public override Task InvokeConnection(string connectionId, string methodName, params object[] args) + { + var message = new InvocationDescriptor + { + Method = methodName, + Arguments = args + }; + + return _bus.Publish(typeof(THub) + "." + connectionId, message); + } + + public override Task InvokeGroup(string groupName, string methodName, params object[] args) + { + var message = new InvocationDescriptor + { + Method = methodName, + Arguments = args + }; + + return _bus.Publish(typeof(THub) + "." + groupName, message); + } + + public override Task InvokeUser(string userId, string methodName, params object[] args) + { + var message = new InvocationDescriptor + { + Method = methodName, + Arguments = args + }; + + return _bus.Publish(typeof(THub) + "." + userId, message); + } + + public override Task OnConnectedAsync(Connection connection) + { + var subs = connection.Metadata.GetOrAdd("subscriptions", k => new List()); + + subs.Add(Subscribe(typeof(THub).Name, connection)); + subs.Add(Subscribe(typeof(THub).Name + "." + connection.ConnectionId, connection)); + subs.Add(Subscribe(typeof(THub).Name + "." + connection.User.Identity.Name, connection)); + + return Task.CompletedTask; + } + + public override Task OnDisconnectedAsync(Connection connection) + { + var subs = connection.Metadata.Get>("subscriptions"); + + if (subs != null) + { + foreach (var sub in subs) + { + sub.Dispose(); + } + } + + return Task.CompletedTask; + } + + public override void AddGroup(Connection connection, string groupName) + { + var groups = connection.Metadata.GetOrAdd("groups", k => new ConcurrentDictionary()); + var key = typeof(THub).Name + "." + groupName; + groups.TryAdd(key, Subscribe(key, connection)); + } + + public override void RemoveGroup(Connection connection, string groupName) + { + var key = typeof(THub) + "." + groupName; + var groups = connection.Metadata.Get>("groups"); + + IDisposable subscription; + if (groups != null && groups.TryRemove(key, out subscription)) + { + subscription.Dispose(); + } + } + + private IDisposable Subscribe(string signal, Connection connection) + { + return _bus.Subscribe(signal, message => + { + var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]); + + return invocationAdapter.WriteInvocationDescriptor((InvocationDescriptor)message, connection.Channel.GetStream()); + }); + } + } + +} diff --git a/samples/SocketsSample/EndPoints/RpcEndpoint.cs b/samples/SocketsSample/EndPoints/RpcEndpoint.cs index 7c8ae5fccf..270cee717e 100644 --- a/samples/SocketsSample/EndPoints/RpcEndpoint.cs +++ b/samples/SocketsSample/EndPoints/RpcEndpoint.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.IO; using System.Linq; using System.Reflection; using System.Threading.Tasks; @@ -8,34 +7,26 @@ using Channels; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; -using Newtonsoft.Json; -using Newtonsoft.Json.Linq; -using SocketsSample.Protobuf; namespace SocketsSample { - public class RpcEndpoint : EndPoint + public class RpcEndpoint : EndPoint where T : class { private readonly Dictionary> _callbacks = new Dictionary>(StringComparer.OrdinalIgnoreCase); private readonly Dictionary _paramTypes = new Dictionary(); - private readonly ILogger _logger; - private readonly IServiceProvider _serviceProvider; + private readonly ILogger _logger; + private readonly InvocationAdapterRegistry _registry; + protected readonly IServiceScopeFactory _serviceScopeFactory; - - public RpcEndpoint(ILogger logger, IServiceProvider serviceProvider) + public RpcEndpoint(InvocationAdapterRegistry registry, ILoggerFactory loggerFactory, IServiceScopeFactory serviceScopeFactory) { - // TODO: Discover end points - _logger = logger; - _serviceProvider = serviceProvider; + _logger = loggerFactory.CreateLogger>(); + _registry = registry; + _serviceScopeFactory = serviceScopeFactory; - DiscoverEndpoints(); - } - - protected virtual void DiscoverEndpoints() - { - RegisterRPCEndPoint(typeof(Echo)); + RegisterRPCEndPoint(); } public override async Task OnConnected(Connection connection) @@ -44,16 +35,14 @@ namespace SocketsSample await Task.Yield(); var stream = connection.Channel.GetStream(); - var invocationAdapter = - _serviceProvider - .GetRequiredService() - .GetInvocationAdapter((string)connection.Metadata["formatType"]); + var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]); while (true) { var invocationDescriptor = await invocationAdapter.ReadInvocationDescriptor( - stream, methodName => { + stream, methodName => + { Type[] types; // TODO: null or throw? return _paramTypes.TryGetValue(methodName, out types) ? types : null; @@ -90,12 +79,19 @@ namespace SocketsSample } } - protected virtual void Initialize(Connection connection, object endpoint) + protected virtual void BeforeInvoke(Connection connection, T endpoint) { } - protected void RegisterRPCEndPoint(Type type) + protected virtual void AfterInvoke(Connection connection, T endpoint) { + + } + + protected void RegisterRPCEndPoint() + { + var type = typeof(T); + foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic)) { var methodName = type.FullName + "." + methodInfo.Name; @@ -115,21 +111,22 @@ namespace SocketsSample _callbacks[methodName] = (connection, invocationDescriptor) => { - var invocationResult = new InvocationResultDescriptor(); - invocationResult.Id = invocationDescriptor.Id; - - var scopeFactory = _serviceProvider.GetRequiredService(); - - // Scope per call so that deps injected get disposed - using (var scope = scopeFactory.CreateScope()) + var invocationResult = new InvocationResultDescriptor() { - object value = scope.ServiceProvider.GetService(type) ?? Activator.CreateInstance(type); + Id = invocationDescriptor.Id + }; - Initialize(connection, value); + using (var scope = _serviceScopeFactory.CreateScope()) + { + var value = scope.ServiceProvider.GetService() ?? Activator.CreateInstance(); + + BeforeInvoke(connection, value); try { - var args = invocationDescriptor.Arguments + var arguments = invocationDescriptor.Arguments ?? Array.Empty(); + + var args = arguments .Zip(parameters, (a, p) => Convert.ChangeType(a, p.ParameterType)) .ToArray(); @@ -137,12 +134,18 @@ namespace SocketsSample } catch (TargetInvocationException ex) { + _logger.LogError(0, ex, "Failed to invoke RPC method"); invocationResult.Error = ex.InnerException.Message; } catch (Exception ex) { + _logger.LogError(0, ex, "Failed to invoke RPC method"); invocationResult.Error = ex.Message; } + finally + { + AfterInvoke(connection, value); + } } return invocationResult; diff --git a/samples/SocketsSample/Hubs/Chat.cs b/samples/SocketsSample/Hubs/Chat.cs index a480807b57..fea75e62e0 100644 --- a/samples/SocketsSample/Hubs/Chat.cs +++ b/samples/SocketsSample/Hubs/Chat.cs @@ -1,20 +1,23 @@ using System; -using System.Collections.Generic; -using System.Linq; using System.Threading.Tasks; namespace SocketsSample.Hubs { public class Chat : Hub { - public void Send(string message) + public override async Task OnConnectedAsync() { - Clients.All.Invoke("Send", message); + await Clients.All.Invoke("Send", Context.Connection.ConnectionId + " joined"); } - public Person EchoPerson(Person p) + public override async Task OnDisconnectedAsync() { - return p; + await Clients.All.Invoke("Send", Context.Connection.ConnectionId + " left"); + } + + public Task Send(string message) + { + return Clients.All.Invoke("Send", Context.ConnectionId + ": " + message); } } } diff --git a/samples/SocketsSample/Hubs/Hub.cs b/samples/SocketsSample/Hubs/Hub.cs index 30b3ff220e..f8b6e22ed8 100644 --- a/samples/SocketsSample/Hubs/Hub.cs +++ b/samples/SocketsSample/Hubs/Hub.cs @@ -1,44 +1,23 @@ -using System; -using System.Collections.Generic; -using System.Security.Claims; -using System.Threading.Tasks; +using System.Threading.Tasks; + namespace SocketsSample.Hubs { public class Hub { + public virtual Task OnConnectedAsync() + { + return Task.CompletedTask; + } + + public virtual Task OnDisconnectedAsync() + { + return Task.CompletedTask; + } + public IHubConnectionContext Clients { get; set; } public HubCallerContext Context { get; set; } - } - public interface IHubConnectionContext - { - IClientProxy All { get; } - - IClientProxy Client(string connectionId); - } - - public interface IClientProxy - { - /// - /// Invokes a method on the connection(s) represented by the instance. - /// - /// name of the method to invoke - /// argumetns to pass to the client - /// A task that represents when the data has been sent to the client. - Task Invoke(string method, params object[] args); - } - - public class HubCallerContext - { - public HubCallerContext(string connectionId, ClaimsPrincipal user) - { - ConnectionId = connectionId; - User = user; - } - - public ClaimsPrincipal User { get; } - - public string ConnectionId { get; } + public IGroupManager Groups { get; set; } } } diff --git a/samples/SocketsSample/Hubs/PubSub.cs b/samples/SocketsSample/Hubs/PubSub.cs new file mode 100644 index 0000000000..9bbed023b3 --- /dev/null +++ b/samples/SocketsSample/Hubs/PubSub.cs @@ -0,0 +1,57 @@ +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; + +namespace SocketsSample.Hubs +{ + public interface IPubSub + { + IDisposable Subscribe(string topic, Func callback); + Task Publish(string topic, object data); + } + + public class Bus : IPubSub + { + private readonly ConcurrentDictionary>> _subscriptions = new ConcurrentDictionary>>(); + + public IDisposable Subscribe(string key, Func observer) + { + var subscriptions = _subscriptions.GetOrAdd(key, _ => new List>()); + subscriptions.Add(observer); + + return new DisposableAction(() => + { + subscriptions.Remove(observer); + }); + } + + public async Task Publish(string key, object data) + { + List> subscriptions; + if (_subscriptions.TryGetValue(key, out subscriptions)) + { + foreach (var c in subscriptions) + { + await c(data); + } + } + } + + private class DisposableAction : IDisposable + { + private Action _action; + + public DisposableAction(Action action) + { + _action = action; + } + + public void Dispose() + { + Interlocked.Exchange(ref _action, () => { }).Invoke(); + } + } + } +} diff --git a/samples/SocketsSample/InvocationDescriptor.cs b/samples/SocketsSample/InvocationDescriptor.cs index 2f5dcac8f3..c19c30851e 100644 --- a/samples/SocketsSample/InvocationDescriptor.cs +++ b/samples/SocketsSample/InvocationDescriptor.cs @@ -1,7 +1,4 @@ using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; namespace SocketsSample { @@ -12,5 +9,10 @@ namespace SocketsSample public string Method { get; set; } public object[] Arguments { get; set; } + + public override string ToString() + { + return $"{Id}: {Method}({(Arguments ?? new object[0]).Length})"; + } } } diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index e52c53325a..f9c7f8853b 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -1,10 +1,9 @@ -using System; -using System.Collections.Generic; -using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; -using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using SocketsSample.EndPoints.Hubs; +using SocketsSample.Hubs; namespace SocketsSample { @@ -16,9 +15,13 @@ namespace SocketsSample { services.AddRouting(); - services.AddSingleton(); - services.AddSingleton(); + services.AddSingleton(); + services.AddSingleton(typeof(HubLifetimeManager<>), typeof(PubSubHubLifetimeManager<>)); + services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>)); + services.AddSingleton(typeof(RpcEndpoint<>), typeof(RpcEndpoint<>)); + services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -36,12 +39,11 @@ namespace SocketsSample app.UseDeveloperExceptionPage(); } - app.UseSockets(routes => { - routes.MapSocketEndpoint("/hubs"); + routes.MapSocketEndpoint>("/hubs"); routes.MapSocketEndpoint("/chat"); - routes.MapSocketEndpoint("/jsonrpc"); + routes.MapSocketEndpoint>("/jsonrpc"); }); app.UseRpc(invocationAdapters => diff --git a/samples/SocketsSample/wwwroot/hubs.html b/samples/SocketsSample/wwwroot/hubs.html index cfd3488473..c93ef33a64 100644 --- a/samples/SocketsSample/wwwroot/hubs.html +++ b/samples/SocketsSample/wwwroot/hubs.html @@ -42,11 +42,11 @@ delete calls[response.Id]; - if (response.error) { - cb.error(response.error); + if (response.Error) { + cb.error(response.Error); } else { - cb.success(response.result); + cb.success(response.Result); } } else { diff --git a/samples/SocketsSample/wwwroot/rpc.html b/samples/SocketsSample/wwwroot/rpc.html index 2090ee3120..3aaad25f12 100644 --- a/samples/SocketsSample/wwwroot/rpc.html +++ b/samples/SocketsSample/wwwroot/rpc.html @@ -16,15 +16,15 @@ ws.onmessage = function (event) { var response = JSON.parse(event.data); - var cb = calls[response.id]; + var cb = calls[response.Id]; - delete calls[response.id]; + delete calls[response.Id]; - if (response.error) { - cb.error(response.error); + if (response.Error) { + cb.error(response.Error); } else { - cb.success(response.result); + cb.success(response.Result); } }; @@ -35,7 +35,7 @@ this.invoke = function (method, args) { return new Promise((resolve, reject) => { calls[id] = { success: resolve, error: reject }; - ws.send(JSON.stringify({ method: method, params: args, id: id })); + ws.send(JSON.stringify({ method: method, arguments: args, id: id })); id++; }); }; diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs index 18b4e8eccd..85289b2d38 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionMetadata.cs @@ -1,11 +1,12 @@  -using System.Collections.Generic; +using System; +using System.Collections.Concurrent; namespace Microsoft.AspNetCore.Sockets { public class ConnectionMetadata { - private IDictionary _metadata = new Dictionary(); + private ConcurrentDictionary _metadata = new ConcurrentDictionary(); public Format Format { get; set; } = Format.Text; @@ -13,12 +14,24 @@ namespace Microsoft.AspNetCore.Sockets { get { - return _metadata[key]; + object value; + _metadata.TryGetValue(key, out value); + return value; } set { _metadata[key] = value; } } + + public T GetOrAdd(string key, Func factory) + { + return (T)_metadata.GetOrAdd(key, k => factory(k)); + } + + public T Get(string key) + { + return (T)this[key]; + } } } diff --git a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs b/src/Microsoft.AspNetCore.Sockets/EndPoint.cs index 507db7862f..5deae9424a 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPoint.cs @@ -7,11 +7,6 @@ namespace Microsoft.AspNetCore.Sockets /// public abstract class EndPoint { - /// - /// Live list of connections for this - /// - public ConnectionList Connections { get; } = new ConnectionList(); - /// /// Called when a new connection is accepted to the endpoint /// diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 6203fd833a..12d57c73fc 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -94,12 +94,8 @@ namespace Microsoft.AspNetCore.Sockets state.Connection.Channel.Dispose(); await endpointTask; - - endpoint.Connections.Remove(state.Connection); }; - endpoint.Connections.Add(state.Connection); - endpointTask = endpoint.OnConnected(state.Connection); state.Connection.Metadata["endpoint"] = endpointTask; } @@ -124,8 +120,6 @@ namespace Microsoft.AspNetCore.Sockets state.Connection.Channel.Dispose(); await transportTask; - - endpoint.Connections.Remove(state.Connection); } // Mark the connection as inactive @@ -143,8 +137,6 @@ namespace Microsoft.AspNetCore.Sockets // Register this transport for disconnect RegisterDisconnect(context, connection); - endpoint.Connections.Add(connection); - // Call into the end point passing the connection var endpointTask = endpoint.OnConnected(connection); @@ -159,8 +151,6 @@ namespace Microsoft.AspNetCore.Sockets // Wait for both await Task.WhenAll(endpointTask, transportTask); - - endpoint.Connections.Remove(connection); } private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)