From ed416723818e5d8ccd54d68995eb3ed24bf9fc89 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Thu, 3 Nov 2016 19:03:44 -0700 Subject: [PATCH] Implemented better Redis scaleout - Less subscriptions and connections to RedisHubLifetimeManager --- .../RedisHubLifetimeManager.cs | 179 +++++++++++++----- 1 file changed, 128 insertions(+), 51 deletions(-) diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index a9c11bf624..78f244f83c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -1,6 +1,9 @@ using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading; using System.Threading.Tasks; using Channels; using Microsoft.AspNetCore.Sockets; @@ -12,6 +15,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis { public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable { + private readonly ConnectionList _connections = new ConnectionList(); + // TODO: Investigate "memory leak" entries never get removed + private readonly ConcurrentDictionary _groups = new ConcurrentDictionary(); private readonly InvocationAdapterRegistry _registry; private readonly ConnectionMultiplexer _redisServerConnection; private readonly ISubscriber _bus; @@ -29,6 +35,20 @@ namespace Microsoft.AspNetCore.SignalR.Redis var writer = new LoggerTextWriter(loggerFactory.CreateLogger>()); _redisServerConnection = _options.Connect(writer); _bus = _redisServerConnection.GetSubscriber(); + + _bus.Subscribe(typeof(THub).FullName, (c, data) => + { + var tasks = new List(_connections.Count); + + // TODO: serialize once per format by providing a different stream? + foreach (var connection in _connections) + { + tasks.Add(connection.Channel.Output.WriteAsync((byte[])data)); + } + + // TODO: Task Queue + Task.WhenAll(tasks).GetAwaiter().GetResult(); + }); } public override Task InvokeAllAsync(string methodName, params object[] args) @@ -91,74 +111,125 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task OnConnectedAsync(Connection connection) { - var task1 = SubscribeAsync(typeof(THub).FullName, connection); - var task2 = SubscribeAsync(typeof(THub).FullName + "." + connection.ConnectionId, connection); - var task3 = SubscribeAsync(typeof(THub).FullName + "." + connection.User.Identity.Name, connection); + _connections.Add(connection); - return Task.WhenAll(task2, task2, task3); - } + var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId; + var userChannel = typeof(THub).FullName + "." + connection.User.Identity.Name; - public override Task OnDisconnectedAsync(Connection connection) - { - var redisConnection = connection.Metadata.Get("redis"); - - if (redisConnection == null) + var task1 = _bus.SubscribeAsync(connectionChannel, (c, data) => { - return Task.CompletedTask; - } - - redisConnection.GetSubscriber().UnsubscribeAll(); - redisConnection.Close(allowCommandsToComplete: true); - - return Task.CompletedTask; - } - - public override Task AddGroupAsync(Connection connection, string groupName) - { - var key = typeof(THub).FullName + "." + groupName; - return SubscribeAsync(key, connection); - } - - public override Task RemoveGroupAsync(Connection connection, string groupName) - { - var key = typeof(THub).FullName + "." + groupName; - return UnsubscribeAsync(key, connection); - } - - private Task SubscribeAsync(string channel, Connection connection) - { - var redisConnection = connection.Metadata.GetOrAdd("redis", _ => - { - var logger = _loggerFactory.CreateLogger("REDIS_" + connection.ConnectionId); - // TODO: Async - return _options.Connect(new LoggerTextWriter(logger)); - }); - - var subscriber = redisConnection.GetSubscriber(); - - return subscriber.SubscribeAsync(channel, (c, data) => - { - // TODO: Use Task Queue + // TODO: serialize once per format by providing a different stream? + // TODO: Task Queue connection.Channel.Output.WriteAsync((byte[])data).GetAwaiter().GetResult(); }); + + var task2 = _bus.SubscribeAsync(userChannel, (c, data) => + { + // TODO: serialize once per format by providing a different stream? + // TODO: Task Queue + // TODO: Look at optimizing (looping over connections checking for Name) + connection.Channel.Output.WriteAsync((byte[])data).GetAwaiter().GetResult(); + }); + + var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet()); + redisSubscriptions.Add(connectionChannel); + redisSubscriptions.Add(userChannel); + + return Task.WhenAll(task1, task2); } - private Task UnsubscribeAsync(string channel, Connection connection) + public override async Task OnDisconnectedAsync(Connection connection) { - var redisConnection = connection.Metadata.Get("redis"); + _connections.Remove(connection); - if (redisConnection == null) + var redisSubscriptions = connection.Metadata.Get>("redis_subscriptions"); + if (redisSubscriptions != null) { - return Task.CompletedTask; + foreach (var subscription in redisSubscriptions) + { + await _bus.UnsubscribeAsync(subscription); + } } - var subscriber = redisConnection.GetSubscriber(); + var groupNames = connection.Metadata.Get>("group"); - return subscriber.UnsubscribeAsync(channel); + if (groupNames != null) + { + foreach (var group in groupNames) + { + await RemoveGroupAsync(connection, group); + } + } + } + + public override async Task AddGroupAsync(Connection connection, string groupName) + { + var groupChannel = typeof(THub).FullName + "." + groupName; + + var groupNames = connection.Metadata.GetOrAdd("group", _ => new HashSet()); + groupNames.Add(groupName); + + var group = _groups.GetOrAdd(groupChannel, _ => new GroupData()); + + await group.Lock.WaitAsync(); + try + { + group.Connections.Add(connection); + + // Subscribe once + if (group.Connections.Count > 1) + { + return; + } + + await _bus.SubscribeAsync(groupChannel, (c, data) => + { + foreach (var groupConnection in group.Connections) + { + // TODO: serialize once per format by providing a different stream? + // TODO: Task Queue + groupConnection.Channel.Output.WriteAsync((byte[])data).GetAwaiter().GetResult(); + } + }); + } + finally + { + group.Lock.Release(); + } + } + + public override async Task RemoveGroupAsync(Connection connection, string groupName) + { + var groupChannel = typeof(THub).FullName + "." + groupName; + + GroupData group; + if (!_groups.TryGetValue(groupChannel, out group)) + { + return; + } + + var groupNames = connection.Metadata.Get>("group"); + groupNames?.Remove(groupName); + + await group.Lock.WaitAsync(); + try + { + group.Connections.Remove(connection); + + if (group.Connections.Count == 0) + { + await _bus.UnsubscribeAsync(groupChannel); + } + } + finally + { + group.Lock.Release(); + } } public void Dispose() { + _bus.UnsubscribeAll(); _redisServerConnection.Dispose(); } @@ -183,5 +254,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis _logger.LogDebug(value); } } + + private class GroupData + { + public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); + public ConnectionList Connections = new ConnectionList(); + } } }