diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/AckHandler.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/AckHandler.cs new file mode 100644 index 0000000000..8c04870e71 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/AckHandler.cs @@ -0,0 +1,98 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + internal class AckHandler : IDisposable + { + private readonly ConcurrentDictionary _acks = new ConcurrentDictionary(); + private readonly Timer _timer; + private readonly TimeSpan _ackThreshold = TimeSpan.FromSeconds(30); + private readonly TimeSpan _ackInterval = TimeSpan.FromSeconds(5); + private readonly object _lock = new object(); + private bool _disposed; + + public AckHandler() + { + _timer = new Timer(_ => CheckAcks(), state: null, dueTime: _ackInterval, period: _ackInterval); + } + + public Task CreateAck(int id) + { + lock (_lock) + { + if (_disposed) + { + return Task.CompletedTask; + } + + return _acks.GetOrAdd(id, _ => new AckInfo()).Tcs.Task; + } + } + + public void TriggerAck(int id) + { + if (_acks.TryRemove(id, out var ack)) + { + ack.Tcs.TrySetResult(null); + } + } + + private void CheckAcks() + { + if (_disposed) + { + return; + } + + var utcNow = DateTime.UtcNow; + + foreach (var pair in _acks) + { + var elapsed = utcNow - pair.Value.Created; + if (elapsed > _ackThreshold) + { + if (_acks.TryRemove(pair.Key, out var ack)) + { + ack.Tcs.TrySetCanceled(); + } + } + } + } + + public void Dispose() + { + lock (_lock) + { + _disposed = true; + + _timer.Dispose(); + + foreach (var pair in _acks) + { + if (_acks.TryRemove(pair.Key, out var ack)) + { + ack.Tcs.TrySetCanceled(); + } + } + } + } + + private class AckInfo + { + public TaskCompletionSource Tcs { get; private set; } + public DateTime Created { get; private set; } + + public AckInfo() + { + Created = DateTime.UtcNow; + Tcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index fcc40935be..d6197dc5ea 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -10,6 +10,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Redis.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Newtonsoft.Json; @@ -27,6 +28,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis private readonly ILogger _logger; private readonly RedisOptions _options; private readonly string _channelNamePrefix = typeof(THub).FullName; + private readonly string _serverName = Guid.NewGuid().ToString(); + private readonly AckHandler _ackHandler; + private int _internalId; // This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection. private readonly JsonSerializer _serializer = new JsonSerializer @@ -44,6 +48,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { _logger = logger; _options = options.Value; + _ackHandler = new AckHandler(); var writer = new LoggerTextWriter(logger); _logger.LogInformation("Connecting to redis endpoints: {endpoints}", string.Join(", ", options.Value.Options.EndPoints.Select(e => EndPointCollection.ToString(e)))); @@ -108,6 +113,51 @@ namespace Microsoft.AspNetCore.SignalR.Redis allExceptTask = Task.WhenAll(tasks); }); + + channelName = _channelNamePrefix + ".internal.group"; + _bus.Subscribe(channelName, async (c, data) => + { + var groupMessage = DeserializeMessage(data); + + if (groupMessage.Action == GroupAction.Remove) + { + if (!await RemoveGroupAsyncCore(groupMessage.ConnectionId, groupMessage.Group)) + { + // user not on this server + return; + } + } + + if (groupMessage.Action == GroupAction.Add) + { + if (!await AddGroupAsyncCore(groupMessage.ConnectionId, groupMessage.Group)) + { + // user not on this server + return; + } + } + + // Sending ack to server that sent the original add/remove + await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage + { + Action = GroupAction.Ack, + ConnectionId = groupMessage.ConnectionId, + Group = groupMessage.Group, + Id = groupMessage.Id + }); + }); + + // Create server specific channel in order to send an ack to a single server + var serverChannel = $"{_channelNamePrefix}.internal.{_serverName}"; + _bus.Subscribe(serverChannel, (c, data) => + { + var groupMessage = DeserializeMessage(data); + + if (groupMessage.Action == GroupAction.Ack) + { + _ackHandler.TriggerAck(groupMessage.Id); + } + }); } public override Task InvokeAllAsync(string methodName, object[] args) @@ -144,7 +194,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return PublishAsync(_channelNamePrefix + ".user." + userId, message); } - private async Task PublishAsync(string channel, HubMessage hubMessage) + private async Task PublishAsync(string channel, TMessage hubMessage) { byte[] payload; using (var stream = new MemoryStream()) @@ -241,11 +291,21 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override async Task AddGroupAsync(string connectionId, string groupName) { - var groupChannel = _channelNamePrefix + ".group." + groupName; + if (await AddGroupAsyncCore(connectionId, groupName)) + { + // short circuit if connection is on this server + return; + } + + await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Add); + } + + private async Task AddGroupAsyncCore(string connectionId, string groupName) + { var connection = _connections[connectionId]; if (connection == null) { - return; + return false; } var feature = connection.Features.Get(); @@ -256,6 +316,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis groupNames.Add(groupName); } + var groupChannel = _channelNamePrefix + ".group." + groupName; var group = _groups.GetOrAdd(groupChannel, _ => new GroupData()); await group.Lock.WaitAsync(); @@ -266,7 +327,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis // Subscribe once if (group.Connections.Count > 1) { - return; + return true; } var previousTask = Task.CompletedTask; @@ -294,22 +355,35 @@ namespace Microsoft.AspNetCore.SignalR.Redis { group.Lock.Release(); } + + return true; } public override async Task RemoveGroupAsync(string connectionId, string groupName) + { + if (await RemoveGroupAsyncCore(connectionId, groupName)) + { + // short circuit if connection is on this server + return; + } + + await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove); + } + + private async Task RemoveGroupAsyncCore(string connectionId, string groupName) { var groupChannel = _channelNamePrefix + ".group." + groupName; GroupData group; if (!_groups.TryGetValue(groupChannel, out group)) { - return; + return false; } var connection = _connections[connectionId]; if (connection == null) { - return; + return false; } var feature = connection.Features.Get(); @@ -325,24 +399,47 @@ namespace Microsoft.AspNetCore.SignalR.Redis await group.Lock.WaitAsync(); try { - group.Connections.Remove(connection); - - if (group.Connections.Count == 0) + if (group.Connections.Count > 0) { - _logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel); - await _bus.UnsubscribeAsync(groupChannel); + group.Connections.Remove(connection); + + if (group.Connections.Count == 0) + { + _logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel); + await _bus.UnsubscribeAsync(groupChannel); + } } } finally { group.Lock.Release(); } + + return true; + } + + private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action) + { + var id = Interlocked.Increment(ref _internalId); + var ack = _ackHandler.CreateAck(id); + // Send Add/Remove Group to other servers and wait for an ack or timeout + await PublishAsync(_channelNamePrefix + ".internal.group", new GroupMessage + { + Action = action, + ConnectionId = connectionId, + Group = groupName, + Id = id, + Server = _serverName + }); + + await ack; } public void Dispose() { _bus.UnsubscribeAll(); _redisServerConnection.Dispose(); + _ackHandler.Dispose(); } private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage) @@ -420,5 +517,21 @@ namespace Microsoft.AspNetCore.SignalR.Redis public HashSet Subscriptions { get; } = new HashSet(); public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); } + + private enum GroupAction + { + Remove, + Add, + Ack + } + + private class GroupMessage + { + public string ConnectionId; + public string Group; + public int Id; + public GroupAction Action; + public string Server; + } } }