using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Channels; using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR { public class DefaultHubLifetimeManager : HubLifetimeManager { private readonly ConnectionList _connections = new ConnectionList(); private readonly InvocationAdapterRegistry _registry; public DefaultHubLifetimeManager(InvocationAdapterRegistry registry) { _registry = registry; } public override Task AddGroupAsync(Connection connection, string groupName) { var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet()); lock (groups) { groups.Add(groupName); } return Task.CompletedTask; } public override Task RemoveGroupAsync(Connection connection, string groupName) { var groups = connection.Metadata.Get>("groups"); lock (groups) { groups.Remove(groupName); } return Task.CompletedTask; } public override Task InvokeAllAsync(string methodName, params object[] args) { return InvokeAllWhere(methodName, args, c => true); } private Task InvokeAllWhere(string methodName, object[] args, Func include) { var tasks = new List(_connections.Count); var message = new InvocationDescriptor { Method = methodName, Arguments = args }; // TODO: serialize once per format by providing a different stream? foreach (var connection in _connections) { if (!include(connection)) { continue; } var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType")); tasks.Add(invocationAdapter.WriteInvocationDescriptorAsync(message, connection.Channel.GetStream())); } return Task.WhenAll(tasks); } public override Task InvokeConnectionAsync(string connectionId, string methodName, params object[] args) { var connection = _connections[connectionId]; var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType")); var message = new InvocationDescriptor { Method = methodName, Arguments = args }; return invocationAdapter.WriteInvocationDescriptorAsync(message, connection.Channel.GetStream()); } public override Task InvokeGroupAsync(string groupName, string methodName, params object[] args) { return InvokeAllWhere(methodName, args, connection => { var groups = connection.Metadata.Get>("groups"); return groups?.Contains(groupName) == true; }); } public override Task InvokeUserAsync(string userId, string methodName, params object[] args) { return InvokeAllWhere(methodName, args, connection => { return connection.User.Identity.Name == userId; }); } public override Task OnConnectedAsync(Connection connection) { _connections.Add(connection); return Task.CompletedTask; } public override Task OnDisconnectedAsync(Connection connection) { _connections.Remove(connection); return Task.CompletedTask; } } }