// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Redis.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StackExchange.Redis;

namespace Microsoft.AspNetCore.SignalR.Redis
{
    /// <summary>
    /// A <see cref="HubLifetimeManager{THub}"/> that delivers hub invocations by publishing them to,
    /// and subscribing to, Redis pub/sub channels (all / connection / user / group / group-management / ack).
    /// </summary>
    /// <typeparam name="THub">The hub type; its full name is used to derive the channel names.</typeparam>
    public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable where THub : Hub
    {
        // Connections that are attached to this server instance.
        private readonly HubConnectionStore _connections = new HubConnectionStore();

        // Keyed by group channel name.
        // TODO: Investigate "memory leak" entries never get removed
        private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>(StringComparer.Ordinal);

        // Both are assigned lazily by EnsureRedisServerConnection.
        private IConnectionMultiplexer _redisServerConnection;
        private ISubscriber _bus;

        private readonly ILogger _logger;
        private readonly RedisOptions _options;
        private readonly RedisChannels _channels;
        private readonly string _serverName = GenerateServerName();
        private readonly RedisProtocol _protocol;

        // Serializes the one-time connection setup in EnsureRedisServerConnection.
        private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1);
        private readonly AckHandler _ackHandler;

        // Source of ids for group commands sent to other servers (see SendGroupActionAndWaitForAck).
        private int _internalId;

        /// <summary>
        /// Creates the manager and starts connecting to Redis in the background.
        /// </summary>
        /// <param name="logger">Logger used for all Redis diagnostics.</param>
        /// <param name="options">Redis options; <c>Value</c> supplies the configuration and connection factory.</param>
        /// <param name="hubProtocolResolver">Supplies the hub protocols handed to <see cref="RedisProtocol"/>.</param>
        public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
                                       IOptions<RedisOptions> options,
                                       IHubProtocolResolver hubProtocolResolver)
        {
            _logger = logger;
            _options = options.Value;
            _ackHandler = new AckHandler();
            _channels = new RedisChannels(typeof(THub).FullName);
            _protocol = new RedisProtocol(hubProtocolResolver.AllProtocols);

            RedisLog.ConnectingToEndpoints(_logger, options.Value.Configuration.EndPoints, _serverName);

            // Fire-and-forget: the task is deliberately discarded. Every code path that needs the
            // connection awaits EnsureRedisServerConnection again before touching _bus.
            _ = EnsureRedisServerConnection();
        }

        /// <summary>
        /// Registers a new connection: attaches an <see cref="IRedisFeature"/> to it, adds it to the local
        /// store and subscribes to its connection channel and, when it has a user identifier, its user channel.
        /// </summary>
        /// <param name="connection">The connection that was just established.</param>
        public override async Task OnConnectedAsync(HubConnectionContext connection)
        {
            await EnsureRedisServerConnection();

            var feature = new RedisFeature();
            connection.Features.Set<IRedisFeature>(feature);

            var redisSubscriptions = feature.Subscriptions;

            _connections.Add(connection);

            var connectionTask = SubscribeToConnection(connection, redisSubscriptions);
            var userTask = Task.CompletedTask;

            if (!string.IsNullOrEmpty(connection.UserIdentifier))
            {
                userTask = SubscribeToUser(connection, redisSubscriptions);
            }

            await Task.WhenAll(connectionTask, userTask);
        }

        /// <summary>
        /// Removes the connection from the local store, unsubscribes every channel recorded in its
        /// <see cref="IRedisFeature"/> and removes it from every group it joined.
        /// </summary>
        /// <param name="connection">The connection that is going away.</param>
        /// <returns>A task that completes when all unsubscribe/remove operations have completed.</returns>
        public override Task OnDisconnectedAsync(HubConnectionContext connection)
        {
            _connections.Remove(connection);

            var feature = connection.Features.Get<IRedisFeature>();
            if (feature == null)
            {
                // The feature is only attached after EnsureRedisServerConnection succeeds in
                // OnConnectedAsync, so a connection whose setup failed has nothing to clean up.
                return Task.CompletedTask;
            }

            var tasks = new List<Task>();

            var redisSubscriptions = feature.Subscriptions;
            if (redisSubscriptions != null)
            {
                // NOTE(review): UnsubscribeAsync is called without a handler; if several local connections
                // share a user channel this presumably drops the subscription for all of them - confirm
                // against the StackExchange.Redis ISubscriber documentation.
                foreach (var subscription in redisSubscriptions)
                {
                    RedisLog.Unsubscribe(_logger, subscription);
                    tasks.Add(_bus.UnsubscribeAsync(subscription));
                }
            }

            var groupNames = feature.Groups;
            if (groupNames != null)
            {
                // Copy the groups to an array here because they get removed from this collection
                // in RemoveGroupAsyncCore
                foreach (var group in groupNames.ToArray())
                {
                    // Use RemoveGroupAsyncCore because the connection is local and we don't want to
                    // accidentally go to other servers with our remove request.
                    tasks.Add(RemoveGroupAsyncCore(connection, group));
                }
            }

            return Task.WhenAll(tasks);
        }

        /// <summary>Publishes an invocation to the "all" channel.</summary>
        public override Task SendAllAsync(string methodName, object[] args)
        {
            var message = _protocol.WriteInvocation(methodName, args);
            return PublishAsync(_channels.All, message);
        }

        /// <summary>
        /// Publishes an invocation to the "all" channel; the excluded connection ids travel inside the
        /// message and are filtered out by the receiving handler.
        /// </summary>
        public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedConnectionIds)
        {
            var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds);
            return PublishAsync(_channels.All, message);
        }

        /// <summary>
        /// Sends an invocation to one connection, writing directly when the connection is local and
        /// publishing to its connection channel otherwise.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> is null.</exception>
        public override Task SendConnectionAsync(string connectionId, string methodName, object[] args)
        {
            if (connectionId == null)
            {
                throw new ArgumentNullException(nameof(connectionId));
            }

            // If the connection is local we can skip sending the message through the bus since we require sticky connections.
            // This also saves serializing and deserializing the message!
            var connection = _connections[connectionId];
            if (connection != null)
            {
                return connection.WriteAsync(new InvocationMessage(methodName, null, args)).AsTask();
            }

            var message = _protocol.WriteInvocation(methodName, args);
            return PublishAsync(_channels.Connection(connectionId), message);
        }

        /// <summary>Publishes an invocation to the channel of the given group.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="groupName"/> is null.</exception>
        public override Task SendGroupAsync(string groupName, string methodName, object[] args)
        {
            if (groupName == null)
            {
                throw new ArgumentNullException(nameof(groupName));
            }

            var message = _protocol.WriteInvocation(methodName, args);
            return PublishAsync(_channels.Group(groupName), message);
        }

        /// <summary>
        /// Publishes an invocation to the channel of the given group; the excluded connection ids travel
        /// inside the message and are filtered out by the receiving handler.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="groupName"/> is null (surfaced through the returned task).</exception>
        public override async Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedConnectionIds)
        {
            if (groupName == null)
            {
                throw new ArgumentNullException(nameof(groupName));
            }

            var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds);
            await PublishAsync(_channels.Group(groupName), message);
        }

        /// <summary>Publishes an invocation to the channel of the given user.</summary>
        public override Task SendUserAsync(string userId, string methodName, object[] args)
        {
            var message = _protocol.WriteInvocation(methodName, args);
            return PublishAsync(_channels.User(userId), message);
        }

        /// <summary>
        /// Adds a connection to a group. Local connections are handled in-process; otherwise an Add
        /// command is published to the group-management channel and the ack is awaited.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> or <paramref name="groupName"/> is null.</exception>
        public override async Task AddToGroupAsync(string connectionId, string groupName)
        {
            if (connectionId == null)
            {
                throw new ArgumentNullException(nameof(connectionId));
            }

            if (groupName == null)
            {
                throw new ArgumentNullException(nameof(groupName));
            }

            var connection = _connections[connectionId];
            if (connection != null)
            {
                // short circuit if connection is on this server
                await AddGroupAsyncCore(connection, groupName);
                return;
            }

            await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Add);
        }

        /// <summary>
        /// Removes a connection from a group. Local connections are handled in-process; otherwise a Remove
        /// command is published to the group-management channel and the ack is awaited.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionId"/> or <paramref name="groupName"/> is null.</exception>
        public override async Task RemoveFromGroupAsync(string connectionId, string groupName)
        {
            if (connectionId == null)
            {
                throw new ArgumentNullException(nameof(connectionId));
            }

            if (groupName == null)
            {
                throw new ArgumentNullException(nameof(groupName));
            }

            var connection = _connections[connectionId];
            if (connection != null)
            {
                // short circuit if connection is on this server
                await RemoveGroupAsyncCore(connection, groupName);
                return;
            }

            await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove);
        }

        /// <summary>Publishes one serialized invocation to the connection channel of every listed id.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionIds"/> is null.</exception>
        public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
        {
            if (connectionIds == null)
            {
                throw new ArgumentNullException(nameof(connectionIds));
            }

            var publishTasks = new List<Task>(connectionIds.Count);

            // Serialize once and reuse the payload for every channel.
            var payload = _protocol.WriteInvocation(methodName, args);

            foreach (var connectionId in connectionIds)
            {
                publishTasks.Add(PublishAsync(_channels.Connection(connectionId), payload));
            }

            return Task.WhenAll(publishTasks);
        }

        /// <summary>
        /// Publishes one serialized invocation to the channel of every listed group; null or empty group
        /// names are skipped.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="groupNames"/> is null.</exception>
        public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
        {
            if (groupNames == null)
            {
                throw new ArgumentNullException(nameof(groupNames));
            }

            var publishTasks = new List<Task>(groupNames.Count);
            var payload = _protocol.WriteInvocation(methodName, args);

            foreach (var groupName in groupNames)
            {
                if (!string.IsNullOrEmpty(groupName))
                {
                    publishTasks.Add(PublishAsync(_channels.Group(groupName), payload));
                }
            }

            return Task.WhenAll(publishTasks);
        }

        /// <summary>
        /// Publishes one serialized invocation to the channel of every listed user; null or empty user ids
        /// are skipped, and an empty list completes immediately without serializing anything.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="userIds"/> is null.</exception>
        public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
        {
            // Fail the same way SendConnectionsAsync/SendGroupsAsync do instead of with a NullReferenceException.
            if (userIds == null)
            {
                throw new ArgumentNullException(nameof(userIds));
            }

            if (userIds.Count > 0)
            {
                var payload = _protocol.WriteInvocation(methodName, args);
                var publishTasks = new List<Task>(userIds.Count);

                foreach (var userId in userIds)
                {
                    if (!string.IsNullOrEmpty(userId))
                    {
                        publishTasks.Add(PublishAsync(_channels.User(userId), payload));
                    }
                }

                return Task.WhenAll(publishTasks);
            }

            return Task.CompletedTask;
        }

        /// <summary>Ensures the Redis connection exists, then publishes <paramref name="payload"/> to <paramref name="channel"/>.</summary>
        private async Task PublishAsync(string channel, byte[] payload)
        {
            await EnsureRedisServerConnection();
            RedisLog.PublishToChannel(_logger, channel);
            await _bus.PublishAsync(channel, payload);
        }

        /// <summary>
        /// Adds a local connection to a group, subscribing to the group channel when it is the first
        /// connection in that group on this server.
        /// </summary>
        private async Task AddGroupAsyncCore(HubConnectionContext connection, string groupName)
        {
            var feature = connection.Features.Get<IRedisFeature>();
            var groupNames = feature.Groups;

            lock (groupNames)
            {
                // Connection already in group
                if (!groupNames.Add(groupName))
                {
                    return;
                }
            }

            var groupChannel = _channels.Group(groupName);
            var group = _groups.GetOrAdd(groupChannel, _ => new GroupData());

            await group.Lock.WaitAsync();
            try
            {
                group.Connections.Add(connection);

                // Subscribe once
                if (group.Connections.Count > 1)
                {
                    return;
                }

                await SubscribeToGroup(groupChannel, group);
            }
            finally
            {
                group.Lock.Release();
            }
        }

        /// <summary>
        /// This takes <see cref="HubConnectionContext"/> because we want to remove the connection from the
        /// _connections list in OnDisconnectedAsync and still be able to remove groups with this method.
        /// </summary>
        private async Task RemoveGroupAsyncCore(HubConnectionContext connection, string groupName)
        {
            var groupChannel = _channels.Group(groupName);

            if (!_groups.TryGetValue(groupChannel, out var group))
            {
                return;
            }

            var feature = connection.Features.Get<IRedisFeature>();
            var groupNames = feature.Groups;
            if (groupNames != null)
            {
                lock (groupNames)
                {
                    groupNames.Remove(groupName);
                }
            }

            await group.Lock.WaitAsync();
            try
            {
                if (group.Connections.Count > 0)
                {
                    group.Connections.Remove(connection);

                    // Last local member left: stop listening on the group channel.
                    // The GroupData entry itself stays in _groups (see the TODO on that field).
                    if (group.Connections.Count == 0)
                    {
                        RedisLog.Unsubscribe(_logger, groupChannel);
                        await _bus.UnsubscribeAsync(groupChannel);
                    }
                }
            }
            finally
            {
                group.Lock.Release();
            }
        }

        /// <summary>
        /// Publishes an add/remove group command for a connection that is not on this server and waits
        /// for the matching ack.
        /// </summary>
        private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action)
        {
            var id = Interlocked.Increment(ref _internalId);

            // Register the ack before publishing so a fast reply cannot be missed.
            var ack = _ackHandler.CreateAck(id);

            // Send Add/Remove Group to other servers and wait for an ack or timeout
            var message = _protocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId));
            await PublishAsync(_channels.GroupManagement, message);

            await ack;
        }

        /// <summary>Unsubscribes from all channels and disposes the Redis connection and the ack handler.</summary>
        public void Dispose()
        {
            // _bus and _redisServerConnection stay null if the connection was never established.
            _bus?.UnsubscribeAll();
            _redisServerConnection?.Dispose();
            _ackHandler.Dispose();
        }

        /// <summary>Subscribes to the "all" channel and fans each message out to every local, non-excluded connection.</summary>
        private void SubscribeToAll()
        {
            RedisLog.Subscribing(_logger, _channels.All);
            _bus.Subscribe(_channels.All, async (c, data) =>
            {
                try
                {
                    RedisLog.ReceivedFromChannel(_logger, _channels.All);

                    var invocation = _protocol.ReadInvocation((byte[])data);

                    var tasks = new List<Task>(_connections.Count);

                    foreach (var connection in _connections)
                    {
                        if (invocation.ExcludedConnectionIds == null || !invocation.ExcludedConnectionIds.Contains(connection.ConnectionId))
                        {
                            tasks.Add(connection.WriteAsync(invocation.Message).AsTask());
                        }
                    }

                    await Task.WhenAll(tasks);
                }
                catch (Exception ex)
                {
                    RedisLog.FailedWritingMessage(_logger, ex);
                }
            });
        }

        /// <summary>
        /// Subscribes to the group-management channel: applies add/remove commands that target a local
        /// connection and acknowledges them to the originating server.
        /// </summary>
        private void SubscribeToGroupManagementChannel()
        {
            _bus.Subscribe(_channels.GroupManagement, async (c, data) =>
            {
                try
                {
                    var groupMessage = _protocol.ReadGroupCommand((byte[])data);

                    var connection = _connections[groupMessage.ConnectionId];
                    if (connection == null)
                    {
                        // user not on this server
                        return;
                    }

                    if (groupMessage.Action == GroupAction.Remove)
                    {
                        await RemoveGroupAsyncCore(connection, groupMessage.GroupName);
                    }

                    if (groupMessage.Action == GroupAction.Add)
                    {
                        await AddGroupAsyncCore(connection, groupMessage.GroupName);
                    }

                    // Send an ack to the server that sent the original command.
                    await PublishAsync(_channels.Ack(groupMessage.ServerName), _protocol.WriteAck(groupMessage.Id));
                }
                catch (Exception ex)
                {
                    RedisLog.InternalMessageFailed(_logger, ex);
                }
            });
        }

        /// <summary>Subscribes to this server's ack channel and triggers the pending ack for each received id.</summary>
        private void SubscribeToAckChannel()
        {
            // Create server specific channel in order to send an ack to a single server
            _bus.Subscribe(_channels.Ack(_serverName), (c, data) =>
            {
                var ackId = _protocol.ReadAck((byte[])data);

                _ackHandler.TriggerAck(ackId);
            });
        }

        /// <summary>
        /// Subscribes to the connection's own channel and records the channel name in
        /// <paramref name="redisSubscriptions"/> so it can be unsubscribed on disconnect.
        /// </summary>
        private Task SubscribeToConnection(HubConnectionContext connection, HashSet<string> redisSubscriptions)
        {
            var connectionChannel = _channels.Connection(connection.ConnectionId);
            redisSubscriptions.Add(connectionChannel);

            RedisLog.Subscribing(_logger, connectionChannel);
            return _bus.SubscribeAsync(connectionChannel, async (c, data) =>
            {
                // The handler is an async lambda with no awaiter, so an escaping exception would be
                // unobserved; log it the same way the "all" and group handlers do.
                try
                {
                    var invocation = _protocol.ReadInvocation((byte[])data);
                    await connection.WriteAsync(invocation.Message);
                }
                catch (Exception ex)
                {
                    RedisLog.FailedWritingMessage(_logger, ex);
                }
            });
        }

        /// <summary>
        /// Subscribes to the channel of the connection's user and records the channel name in
        /// <paramref name="redisSubscriptions"/> so it can be unsubscribed on disconnect.
        /// </summary>
        private Task SubscribeToUser(HubConnectionContext connection, HashSet<string> redisSubscriptions)
        {
            var userChannel = _channels.User(connection.UserIdentifier);
            redisSubscriptions.Add(userChannel);

            // TODO: Look at optimizing (looping over connections checking for Name)
            return _bus.SubscribeAsync(userChannel, async (c, data) =>
            {
                // Same reasoning as SubscribeToConnection: never let the handler throw unobserved.
                try
                {
                    var invocation = _protocol.ReadInvocation((byte[])data);
                    await connection.WriteAsync(invocation.Message);
                }
                catch (Exception ex)
                {
                    RedisLog.FailedWritingMessage(_logger, ex);
                }
            });
        }

        /// <summary>
        /// Subscribes to a group channel and fans each message out to every non-excluded connection
        /// currently stored in <paramref name="group"/>.
        /// </summary>
        private Task SubscribeToGroup(string groupChannel, GroupData group)
        {
            RedisLog.Subscribing(_logger, groupChannel);
            return _bus.SubscribeAsync(groupChannel, async (c, data) =>
            {
                try
                {
                    var invocation = _protocol.ReadInvocation((byte[])data);

                    var tasks = new List<Task>();
                    foreach (var groupConnection in group.Connections)
                    {
                        if (invocation.ExcludedConnectionIds?.Contains(groupConnection.ConnectionId) == true)
                        {
                            continue;
                        }

                        tasks.Add(groupConnection.WriteAsync(invocation.Message).AsTask());
                    }

                    await Task.WhenAll(tasks);
                }
                catch (Exception ex)
                {
                    RedisLog.FailedWritingMessage(_logger, ex);
                }
            });
        }

        /// <summary>
        /// Connects to Redis on first use, wires the connection-restored/failed logging and subscribes to
        /// the all, group-management and ack channels. Later calls return once the connection field is set.
        /// </summary>
        private async Task EnsureRedisServerConnection()
        {
            if (_redisServerConnection == null)
            {
                await _connectionLock.WaitAsync();
                try
                {
                    // Re-check under the lock: another caller may have connected while we waited.
                    if (_redisServerConnection == null)
                    {
                        var writer = new LoggerTextWriter(_logger);
                        _redisServerConnection = await _options.ConnectAsync(writer);
                        _bus = _redisServerConnection.GetSubscriber();

                        _redisServerConnection.ConnectionRestored += (_, e) =>
                        {
                            // We use the subscription connection type
                            // Ignore messages from the interactive connection (avoids duplicates)
                            if (e.ConnectionType == ConnectionType.Interactive)
                            {
                                return;
                            }

                            RedisLog.ConnectionRestored(_logger);
                        };

                        _redisServerConnection.ConnectionFailed += (_, e) =>
                        {
                            // We use the subscription connection type
                            // Ignore messages from the interactive connection (avoids duplicates)
                            if (e.ConnectionType == ConnectionType.Interactive)
                            {
                                return;
                            }

                            RedisLog.ConnectionFailed(_logger, e.Exception);
                        };

                        if (_redisServerConnection.IsConnected)
                        {
                            RedisLog.Connected(_logger);
                        }
                        else
                        {
                            RedisLog.NotConnected(_logger);
                        }

                        SubscribeToAll();
                        SubscribeToGroupManagementChannel();
                        SubscribeToAckChannel();
                    }
                }
                finally
                {
                    _connectionLock.Release();
                }
            }
        }

        /// <summary>Builds the name that identifies this server instance (used for its ack channel).</summary>
        private static string GenerateServerName()
        {
            // Use the machine name for convenient diagnostics, but add a guid to make it unique.
            // Example: MyServerName_02db60e5fab243b890a847fa5c4dcb29
            return $"{Environment.MachineName}_{Guid.NewGuid():N}";
        }

        /// <summary>
        /// <see cref="TextWriter"/> handed to the Redis connect call; forwards each written line to the logger.
        /// </summary>
        private class LoggerTextWriter : TextWriter
        {
            private readonly ILogger _logger;

            public LoggerTextWriter(ILogger logger)
            {
                _logger = logger;
            }

            public override Encoding Encoding => Encoding.UTF8;

            // Single characters are intentionally dropped; only whole lines are logged.
            public override void Write(char value)
            {
            }

            public override void WriteLine(string value)
            {
                RedisLog.ConnectionMultiplexerMessage(_logger, value);
            }
        }

        /// <summary>Per-group state: the local member connections and the lock guarding them.</summary>
        private class GroupData
        {
            public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
            public readonly HubConnectionStore Connections = new HubConnectionStore();
        }

        /// <summary>Per-connection bookkeeping stored in the connection's feature collection.</summary>
        private interface IRedisFeature
        {
            // Channel names this connection subscribed to.
            HashSet<string> Subscriptions { get; }

            // Names of the groups this connection belongs to.
            HashSet<string> Groups { get; }
        }

        private class RedisFeature : IRedisFeature
        {
            public HashSet<string> Subscriptions { get; } = new HashSet<string>();

            // Group names are compared case-insensitively.
            public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
        }
    }
}