diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 42f7fdc713..e72e20b567 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -88,108 +88,66 @@ namespace Microsoft.AspNetCore.SignalR.Redis } _bus = _redisServerConnection.GetSubscriber(); - var channelName = _channelNamePrefix; - _logger.Subscribing(channelName); - _bus.Subscribe(channelName, async (c, data) => + SubscribeToHub(); + SubscribeToAllExcept(); + SubscribeToInternalGroup(); + SubscribeToInternalServerName(); + } + + public override Task OnConnectedAsync(HubConnectionContext connection) + { + var feature = new RedisFeature(); + connection.Features.Set(feature); + + var redisSubscriptions = feature.Subscriptions; + var connectionTask = Task.CompletedTask; + var userTask = Task.CompletedTask; + + _connections.Add(connection); + + connectionTask = SubscribeToConnection(connection, redisSubscriptions); + + if (!string.IsNullOrEmpty(connection.UserIdentifier)) { - try - { - _logger.ReceivedFromChannel(channelName); + userTask = SubscribeToUser(connection, redisSubscriptions); + } - var message = DeserializeMessage(data); + return Task.WhenAll(connectionTask, userTask); + } - var tasks = new List(_connections.Count); + public override Task OnDisconnectedAsync(HubConnectionContext connection) + { + _connections.Remove(connection); - foreach (var connection in _connections) - { - tasks.Add(WriteAsync(connection, message)); - } + var tasks = new List(); - await Task.WhenAll(tasks); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } - }); + var feature = connection.Features.Get(); - channelName = _channelNamePrefix + ".AllExcept"; - _logger.Subscribing(channelName); - _bus.Subscribe(channelName, async (c, data) => + var redisSubscriptions = feature.Subscriptions; + if (redisSubscriptions != null) { - try + foreach (var subscription in redisSubscriptions) { - _logger.ReceivedFromChannel(channelName); - - var message = DeserializeMessage(data); - var excludedIds = message.ExcludedIds; - - var tasks = new List(_connections.Count); - - foreach (var connection in _connections) - { - if (!excludedIds.Contains(connection.ConnectionId)) - { - tasks.Add(WriteAsync(connection, message)); - } - } - - await Task.WhenAll(tasks); + _logger.Unsubscribe(subscription); + tasks.Add(_bus.UnsubscribeAsync(subscription)); } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } - }); + } - channelName = _channelNamePrefix + ".internal.group"; - _bus.Subscribe(channelName, async (c, data) => + var groupNames = feature.Groups; + + if (groupNames != null) { - try + // Copy the groups to an array here because they get removed from this collection + // in RemoveGroupAsync + foreach (var group in groupNames.ToArray()) { - var groupMessage = DeserializeMessage(data); - - var connection = _connections[groupMessage.ConnectionId]; - if (connection == null) - { - // user not on this server - return; - } - - if (groupMessage.Action == GroupAction.Remove) - { - await RemoveGroupAsyncCore(connection, groupMessage.Group); - } - - if (groupMessage.Action == GroupAction.Add) - { - await AddGroupAsyncCore(connection, groupMessage.Group); - } - - // Sending ack to server that sent the original add/remove - await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage - { - Action = GroupAction.Ack, - Id = groupMessage.Id - }); + // Use RemoveGroupAsyncCore because the connection is local and we don't want to + // accidentally go to other servers with our remove request. + tasks.Add(RemoveGroupAsyncCore(connection, group)); } - catch (Exception ex) - { - _logger.InternalMessageFailed(ex); - } - }); + } - // Create server specific channel in order to send an ack to a single server - var serverChannel = $"{_channelNamePrefix}.internal.{_serverName}"; - _bus.Subscribe(serverChannel, (c, data) => - { - var groupMessage = DeserializeMessage(data); - - if (groupMessage.Action == GroupAction.Ack) - { - _ackHandler.TriggerAck(groupMessage.Id); - } - }); + return Task.WhenAll(tasks); } public override Task InvokeAllAsync(string methodName, object[] args) @@ -259,94 +217,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis await _bus.PublishAsync(channel, payload); } - public override Task OnConnectedAsync(HubConnectionContext connection) - { - var feature = new RedisFeature(); - connection.Features.Set(feature); - - var redisSubscriptions = feature.Subscriptions; - var connectionTask = Task.CompletedTask; - var userTask = Task.CompletedTask; - - _connections.Add(connection); - - var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId; - redisSubscriptions.Add(connectionChannel); - - _logger.Subscribing(connectionChannel); - connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) => - { - try - { - var message = DeserializeMessage(data); - - await WriteAsync(connection, message); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } - }); - - if (!string.IsNullOrEmpty(connection.UserIdentifier)) - { - var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier; - redisSubscriptions.Add(userChannel); - - // TODO: Look at optimizing (looping over connections checking for Name) - userTask = _bus.SubscribeAsync(userChannel, async (c, data) => - { - try - { - var message = DeserializeMessage(data); - - await WriteAsync(connection, message); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } - }); - } - - return Task.WhenAll(connectionTask, userTask); - } - - public override Task OnDisconnectedAsync(HubConnectionContext connection) - { - _connections.Remove(connection); - - var tasks = new List(); - - var feature = connection.Features.Get(); - - var redisSubscriptions = feature.Subscriptions; - if (redisSubscriptions != null) - { - foreach (var subscription in redisSubscriptions) - { - _logger.Unsubscribe(subscription); - tasks.Add(_bus.UnsubscribeAsync(subscription)); - } - } - - var groupNames = feature.Groups; - - if (groupNames != null) - { - // Copy the groups to an array here because they get removed from this collection - // in RemoveGroupAsync - foreach (var group in groupNames.ToArray()) - { - // Use RemoveGroupAsyncCore because the connection is local and we don't want to - // accidentally go to other servers with our remove request. - tasks.Add(RemoveGroupAsyncCore(connection, group)); - } - } - - return Task.WhenAll(tasks); - } - public override async Task AddGroupAsync(string connectionId, string groupName) { if (connectionId == null) @@ -398,26 +268,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - _logger.Subscribing(groupChannel); - await _bus.SubscribeAsync(groupChannel, async (c, data) => - { - try - { - var message = DeserializeMessage(data); - - var tasks = new List(group.Connections.Count); - foreach (var groupConnection in group.Connections) - { - tasks.Add(WriteAsync(groupConnection, message)); - } - - await Task.WhenAll(tasks); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } - }); + await SubscribeToGroup(groupChannel, group); } finally { @@ -519,7 +370,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis _ackHandler.Dispose(); } - private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage) + private static async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage) { while (await connection.Output.WaitToWriteAsync()) { @@ -544,6 +395,186 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } + private void SubscribeToHub() + { + _logger.Subscribing(_channelNamePrefix); + _bus.Subscribe(_channelNamePrefix, async (c, data) => + { + try + { + _logger.ReceivedFromChannel(_channelNamePrefix); + + var message = DeserializeMessage(data); + + var tasks = new List(_connections.Count); + + foreach (var connection in _connections) + { + tasks.Add(WriteAsync(connection, message)); + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + }); + } + + private void SubscribeToAllExcept() + { + var channelName = _channelNamePrefix + ".AllExcept"; + _logger.Subscribing(channelName); + _bus.Subscribe(channelName, async (c, data) => + { + try + { + _logger.ReceivedFromChannel(channelName); + + var message = DeserializeMessage(data); + var excludedIds = message.ExcludedIds; + + var tasks = new List(_connections.Count); + + foreach (var connection in _connections) + { + if (!excludedIds.Contains(connection.ConnectionId)) + { + tasks.Add(WriteAsync(connection, message)); + } + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + }); + } + + private void SubscribeToInternalGroup() + { + var channelName = _channelNamePrefix + ".internal.group"; + _bus.Subscribe(channelName, async (c, data) => + { + try + { + var groupMessage = DeserializeMessage(data); + + var connection = _connections[groupMessage.ConnectionId]; + if (connection == null) + { + // user not on this server + return; + } + + if (groupMessage.Action == GroupAction.Remove) + { + await RemoveGroupAsyncCore(connection, groupMessage.Group); + } + + if (groupMessage.Action == GroupAction.Add) + { + await AddGroupAsyncCore(connection, groupMessage.Group); + } + + // Sending ack to server that sent the original add/remove + await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage + { + Action = GroupAction.Ack, + Id = groupMessage.Id + }); + } + catch (Exception ex) + { + _logger.InternalMessageFailed(ex); + } + }); + } + + private void SubscribeToInternalServerName() + { + // Create server specific channel in order to send an ack to a single server + var serverChannel = $"{_channelNamePrefix}.internal.{_serverName}"; + _bus.Subscribe(serverChannel, (c, data) => + { + var groupMessage = DeserializeMessage(data); + + if (groupMessage.Action == GroupAction.Ack) + { + _ackHandler.TriggerAck(groupMessage.Id); + } + }); + } + + private Task SubscribeToConnection(HubConnectionContext connection, HashSet redisSubscriptions) + { + var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId; + redisSubscriptions.Add(connectionChannel); + + _logger.Subscribing(connectionChannel); + return _bus.SubscribeAsync(connectionChannel, async (c, data) => + { + try + { + var message = DeserializeMessage(data); + + await WriteAsync(connection, message); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + }); + } + + private Task SubscribeToUser(HubConnectionContext connection, HashSet redisSubscriptions) + { + var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier; + redisSubscriptions.Add(userChannel); + + // TODO: Look at optimizing (looping over connections checking for Name) + return _bus.SubscribeAsync(userChannel, async (c, data) => + { + try + { + var message = DeserializeMessage(data); + + await WriteAsync(connection, message); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + }); + } + + private Task SubscribeToGroup(string groupChannel, GroupData group) + { + _logger.Subscribing(groupChannel); + return _bus.SubscribeAsync(groupChannel, async (c, data) => + { + try + { + var message = DeserializeMessage(data); + + var tasks = new List(group.Connections.Count); + foreach (var groupConnection in group.Connections) + { + tasks.Add(WriteAsync(groupConnection, message)); + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + }); + } + private class LoggerTextWriter : TextWriter { private readonly ILogger _logger;