From 19b2fea0d8f5b829588f136fd4335a67f5caf677 Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Thu, 5 Apr 2018 13:48:14 -0700 Subject: [PATCH] Preserialize for all formats when sending through Redis (#1843) --- .../HubConnectionContextBenchmark.cs | 5 +- ....AspNetCore.SignalR.Microbenchmarks.csproj | 4 +- .../RedisHubLifetimeManagerBenchmark.cs | 203 ++++++++ samples/SignalRSamples/Startup.cs | 2 +- .../Internal/Protocol/HubMessage.cs | 78 --- .../DefaultHubLifetimeManager.cs | 5 +- .../HubConnectionContext.cs | 83 +++- .../Internal/DefaultHubProtocolResolver.cs | 8 +- .../Internal/IHubProtocolResolver.cs | 3 +- .../Internal/SerializedHubMessage.cs | 161 +++++++ .../Internal/GroupAction.cs | 15 + .../Internal/RedisChannels.cs | 75 +++ .../Internal/RedisGroupCommand.cs | 39 ++ .../Internal/RedisInvocation.cs | 33 ++ .../Internal/RedisProtocol.cs | 170 +++++++ .../RedisHubLifetimeManager.cs | 453 +++++++----------- .../RedisLoggerExtensions.cs => RedisLog.cs} | 47 +- .../RedisHubLifetimeManagerTests.cs | 331 ++++++------- .../HubConnectionContextUtils.cs | 5 +- ...soft.AspNetCore.SignalR.Tests.Utils.csproj | 1 + .../TestConnectionMultiplexer.cs | 81 ++-- 21 files changed, 1200 insertions(+), 602 deletions(-) create mode 100644 benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Redis/Internal/GroupAction.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisChannels.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs rename src/Microsoft.AspNetCore.SignalR.Redis/{Internal/RedisLoggerExtensions.cs => RedisLog.cs} (63%) rename test/{Microsoft.AspNetCore.SignalR.Redis.Tests => Microsoft.AspNetCore.SignalR.Tests.Utils}/TestConnectionMultiplexer.cs (85%) diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs index d8d70a608f..2bd07d955a 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs @@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _successHubProtocolResolver = new TestHubProtocolResolver(new JsonHubProtocol()); _failureHubProtocolResolver = new TestHubProtocolResolver(null); _userIdProvider = new TestUserIdProvider(); - _supportedProtocols = new List {"json"}; + _supportedProtocols = new List { "json" }; } [Benchmark] @@ -83,8 +83,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { private readonly IHubProtocol _instance; + public IReadOnlyList AllProtocols { get; } + public TestHubProtocolResolver(IHubProtocol instance) { + AllProtocols = new[] { instance }; _instance = instance; } diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj index 4ff4eb83a8..333e897d4e 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj @@ -1,4 +1,4 @@ - + Exe @@ -12,6 +12,8 @@ + + diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs new file mode 100644 index 0000000000..40426e995d --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs @@ -0,0 +1,203 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Redis; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks +{ + public class RedisHubLifetimeManagerBenchmark + { + private RedisHubLifetimeManager _manager1; + private RedisHubLifetimeManager _manager2; + private TestClient[] _clients; + private object[] _args; + private List _excludedIds = new List(); + private List _sendIds = new List(); + private List _groups = new List(); + private List _users = new List(); + + private const int ClientCount = 20; + + [Params(2, 20)] + public int ProtocolCount { get; set; } + + [GlobalSetup] + public void GlobalSetup() + { + var server = new TestRedisServer(); + var logger = NullLogger>.Instance; + var protocols = GenerateProtocols(ProtocolCount).ToArray(); + var options = Options.Create(new RedisOptions() + { + Factory = t => new TestConnectionMultiplexer(server) + }); + var resolver = new DefaultHubProtocolResolver(protocols, NullLogger.Instance); + + _manager1 = new RedisHubLifetimeManager(logger, options, resolver); + _manager2 = new RedisHubLifetimeManager(logger, options, resolver); + + async Task ConnectClient(TestClient client, IHubProtocol protocol, string userId, string group) + { + await _manager2.OnConnectedAsync(HubConnectionContextUtils.Create(client.Connection, protocol, userId)); + await _manager2.AddGroupAsync(client.Connection.ConnectionId, "Everyone"); + await _manager2.AddGroupAsync(client.Connection.ConnectionId, group); + } + + // Connect clients + _clients = new TestClient[ClientCount]; + var tasks = new Task[ClientCount]; + for (var i = 0; i < _clients.Length; i++) + { + var protocol = protocols[i % ProtocolCount]; + _clients[i] = new TestClient(protocol: protocol); + + string group; + string user; + if ((i % 2) == 0) + { + group = "Evens"; + user = "EvenUser"; + _excludedIds.Add(_clients[i].Connection.ConnectionId); + } + else + { + group = "Odds"; + user = "OddUser"; + _sendIds.Add(_clients[i].Connection.ConnectionId); + } + + tasks[i] = ConnectClient(_clients[i], protocol, user, group); + _ = ConsumeAsync(_clients[i]); + } + + Task.WaitAll(tasks); + + _groups.Add("Evens"); + _groups.Add("Odds"); + _users.Add("EvenUser"); + _users.Add("OddUser"); + + _args = new object[] {"Foo"}; + } + + private IEnumerable GenerateProtocols(int protocolCount) + { + for (var i = 0; i < protocolCount; i++) + { + yield return ((i % 2) == 0) + ? new WrappedHubProtocol($"json_{i}", new JsonHubProtocol()) + : new WrappedHubProtocol($"msgpack_{i}", new MessagePackHubProtocol()); + } + } + + private async Task ConsumeAsync(TestClient testClient) + { + while (await testClient.ReadAsync() != null) + { + // Just dump the message + } + } + + [Benchmark] + public async Task SendAll() + { + await _manager1.SendAllAsync("Test", _args); + } + + [Benchmark] + public async Task SendGroup() + { + await _manager1.SendGroupAsync("Everyone", "Test", _args); + } + + [Benchmark] + public async Task SendUser() + { + await _manager1.SendUserAsync("EvenUser", "Test", _args); + } + + [Benchmark] + public async Task SendConnection() + { + await _manager1.SendConnectionAsync(_clients[0].Connection.ConnectionId, "Test", _args); + } + + [Benchmark] + public async Task SendConnections() + { + await _manager1.SendConnectionsAsync(_sendIds, "Test", _args); + } + + [Benchmark] + public async Task SendAllExcept() + { + await _manager1.SendAllExceptAsync("Test", _args, _excludedIds); + } + + [Benchmark] + public async Task SendGroupExcept() + { + await _manager1.SendGroupExceptAsync("Everyone", "Test", _args, _excludedIds); + } + + [Benchmark] + public async Task SendGroups() + { + await _manager1.SendGroupsAsync(_groups, "Test", _args); + } + + [Benchmark] + public async Task SendUsers() + { + await _manager1.SendUsersAsync(_users, "Test", _args); + } + + public class TestHub : Hub + { + } + + private class WrappedHubProtocol : IHubProtocol + { + private readonly string _name; + private readonly IHubProtocol _innerProtocol; + + public string Name => _name; + + public int Version => _innerProtocol.Version; + + public TransferFormat TransferFormat => _innerProtocol.TransferFormat; + + public WrappedHubProtocol(string name, IHubProtocol innerProtocol) + { + _name = name; + _innerProtocol = innerProtocol; + } + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) + { + return _innerProtocol.TryParseMessage(ref input, binder, out message); + } + + public void WriteMessage(HubMessage message, IBufferWriter output) + { + _innerProtocol.WriteMessage(message, output); + } + + public bool IsVersionSupported(int version) + { + return _innerProtocol.IsVersionSupported(version); + } + } + } +} diff --git a/samples/SignalRSamples/Startup.cs b/samples/SignalRSamples/Startup.cs index e1233d8732..30de147776 100644 --- a/samples/SignalRSamples/Startup.cs +++ b/samples/SignalRSamples/Startup.cs @@ -28,7 +28,7 @@ namespace SignalRSamples { options.SerializationContext.DictionarySerlaizationOptions.KeyTransformer = DictionaryKeyTransformers.LowerCamel; }); - // .AddRedis(); + //.AddRedis(); services.AddCors(o => { diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs index 9cb4aabf5a..3c9378f51c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs @@ -1,87 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Collections.Generic; - namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public abstract class HubMessage { - protected HubMessage() - { - } - - private object _lock = new object(); - private List _serializedMessages; - private SerializedMessage _message1; - private SerializedMessage _message2; - - public byte[] WriteMessage(IHubProtocol protocol) - { - // REVIEW: Revisit lock - // Could use a reader/writer lock to allow the loop to take place in "unlocked" code - // Or, could use a fixed size array and Interlocked to manage it. - // Or, Immutable *ducks* - - lock (_lock) - { - if (ReferenceEquals(_message1.Protocol, protocol)) - { - return _message1.Message; - } - - if (ReferenceEquals(_message2.Protocol, protocol)) - { - return _message2.Message; - } - - for (var i = 0; i < _serializedMessages?.Count; i++) - { - if (ReferenceEquals(_serializedMessages[i].Protocol, protocol)) - { - return _serializedMessages[i].Message; - } - } - - var bytes = protocol.WriteToArray(this); - - if (_message1.Protocol == null) - { - _message1 = new SerializedMessage(protocol, bytes); - } - else if (_message2.Protocol == null) - { - _message2 = new SerializedMessage(protocol, bytes); - } - else - { - if (_serializedMessages == null) - { - _serializedMessages = new List(); - } - - // We don't want to balloon memory if someone writes a poor IHubProtocolResolver - // So we cap how many caches we store and worst case just serialize the message for every connection - if (_serializedMessages.Count < 10) - { - _serializedMessages.Add(new SerializedMessage(protocol, bytes)); - } - } - - return bytes; - } - } - - private readonly struct SerializedMessage - { - public readonly IHubProtocol Protocol; - public readonly byte[] Message; - - public SerializedMessage(IHubProtocol protocol, byte[] message) - { - Protocol = protocol; - Message = message; - } - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 2385bc76db..ebfa3c312d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; @@ -210,9 +211,9 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - private InvocationMessage CreateInvocationMessage(string methodName, object[] args) + private SerializedHubMessage CreateInvocationMessage(string methodName, object[] args) { - return new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args); + return new SerializedHubMessage(new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args)); } public override Task SendUserAsync(string userId, string methodName, object[] args) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 7fb8d6a4d4..180fbf4409 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -6,19 +6,16 @@ using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; -using System.IO; using System.IO.Pipelines; -using System.Net; using System.Runtime.ExceptionServices; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal; -using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; @@ -73,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR public virtual PipeReader Input => _connectionContext.Transport.Input; - public string UserIdentifier { get; private set; } + public string UserIdentifier { get; set; } internal virtual IHubProtocol Protocol { get; set; } @@ -84,7 +81,36 @@ namespace Microsoft.AspNetCore.SignalR public virtual ValueTask WriteAsync(HubMessage message) { - // We were unable to get the lock so take the slow async path of waiting for the semaphore + // Try to grab the lock synchronously, if we fail, go to the slower path + if (!_writeLock.Wait(0)) + { + return new ValueTask(WriteSlowAsync(message)); + } + + // This method should never throw synchronously + var task = WriteCore(message); + + // The write didn't complete synchronously so await completion + if (!task.IsCompletedSuccessfully) + { + return new ValueTask(CompleteWriteAsync(task)); + } + + // Otherwise, release the lock acquired when entering WriteAsync + _writeLock.Release(); + + return default; + } + + /// + /// This method is designed to support the framework and is not intended to be used by application code. Writes a pre-serialized message to the + /// connection. + /// + /// The serialization cache to use. + /// + public virtual ValueTask WriteAsync(SerializedHubMessage message) + { + // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { return new ValueTask(WriteSlowAsync(message)); @@ -109,14 +135,28 @@ namespace Microsoft.AspNetCore.SignalR { try { - // This will internally cache the buffer for each unique HubProtocol - // So that we don't serialize the HubMessage for every single connection - var buffer = message.WriteMessage(Protocol); + // We know that we are only writing this message to one receiver, so we can + // write it without caching. + Protocol.WriteMessage(message, _connectionContext.Transport.Output); - var output = _connectionContext.Transport.Output; - output.Write(buffer); + return _connectionContext.Transport.Output.FlushAsync(); + } + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); - return output.FlushAsync(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); + } + } + + private ValueTask WriteCore(SerializedHubMessage message) + { + try + { + // Grab a preserialized buffer for this protocol. + var buffer = message.GetSerializedMessage(Protocol); + + return _connectionContext.Transport.Output.WriteAsync(buffer); } catch (Exception ex) { @@ -162,6 +202,25 @@ namespace Microsoft.AspNetCore.SignalR } } + private async Task WriteSlowAsync(SerializedHubMessage message) + { + try + { + // Failed to get the lock immediately when entering WriteAsync so await until it is available + await _writeLock.WaitAsync(); + + await WriteCore(message); + } + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); + } + finally + { + _writeLock.Release(); + } + } + private ValueTask TryWritePingAsync() { // Don't wait for the lock, if it returns false that means someone wrote to the connection diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs index a11c2d380e..b7ddf87b3b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubProtocolResolver.cs @@ -7,21 +7,25 @@ using System.Linq; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.SignalR.Internal { public class DefaultHubProtocolResolver : IHubProtocolResolver { private readonly ILogger _logger; + private readonly List _hubProtocols; private readonly Dictionary _availableProtocols; + public IReadOnlyList AllProtocols => _hubProtocols; + public DefaultHubProtocolResolver(IEnumerable availableProtocols, ILogger logger) { _logger = logger ?? NullLogger.Instance; _availableProtocols = new Dictionary(StringComparer.OrdinalIgnoreCase); - foreach (var protocol in availableProtocols) + // We might get duplicates in _hubProtocols, but we're going to check it and throw in just a sec. + _hubProtocols = availableProtocols.ToList(); + foreach (var protocol in _hubProtocols) { if (_availableProtocols.ContainsKey(protocol.Name)) { diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs index 0141bf6a92..102c1f2567 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/IHubProtocolResolver.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; @@ -8,6 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { public interface IHubProtocolResolver { + IReadOnlyList AllProtocols { get; } IHubProtocol GetProtocol(string protocolName, IList supportedProtocols); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs new file mode 100644 index 0000000000..b54de1dade --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs @@ -0,0 +1,161 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + /// + /// This class is designed to support the framework. The API is subject to breaking changes. + /// Represents a serialization cache for a single message. + /// + public class SerializedHubMessage + { + private SerializedMessage _cachedItem1; + private SerializedMessage _cachedItem2; + private IList _cachedItems; + + public HubMessage Message { get; } + + private SerializedHubMessage() + { + } + + public SerializedHubMessage(HubMessage message) + { + Message = message; + } + + public ReadOnlyMemory GetSerializedMessage(IHubProtocol protocol) + { + if (!TryGetCached(protocol.Name, out var serialized)) + { + if (Message == null) + { + throw new InvalidOperationException( + "This message was received from another server that did not have the requested protocol available."); + } + + serialized = protocol.WriteToArray(Message); + SetCache(protocol.Name, serialized); + } + + return serialized; + } + + public static void WriteAllSerializedVersions(BinaryWriter writer, HubMessage message, IReadOnlyList protocols) + { + // The serialization format is based on BinaryWriter + // * 1 byte number of protocols + // * For each protocol: + // * Length-prefixed string using 7-bit variable length encoding (length depends on BinaryWriter's encoding) + // * 4 byte length of the buffer + // * N byte buffer + + if (protocols.Count > byte.MaxValue) + { + throw new InvalidOperationException($"Can't serialize cache containing more than {byte.MaxValue} entries"); + } + + writer.Write((byte)protocols.Count); + foreach (var protocol in protocols) + { + writer.Write(protocol.Name); + + var buffer = protocol.WriteToArray(message); + writer.Write(buffer.Length); + writer.Write(buffer); + } + } + + public static SerializedHubMessage ReadAllSerializedVersions(BinaryReader reader) + { + var cache = new SerializedHubMessage(); + var count = reader.ReadByte(); + for (var i = 0; i < count; i++) + { + var protocol = reader.ReadString(); + var length = reader.ReadInt32(); + var serialized = reader.ReadBytes(length); + cache.SetCache(protocol, serialized); + } + + return cache; + } + + private void SetCache(string protocolName, byte[] serialized) + { + if (_cachedItem1.ProtocolName == null) + { + _cachedItem1 = new SerializedMessage(protocolName, serialized); + } + else if (_cachedItem2.ProtocolName == null) + { + _cachedItem2 = new SerializedMessage(protocolName, serialized); + } + else + { + if (_cachedItems == null) + { + _cachedItems = new List(); + } + + foreach (var item in _cachedItems) + { + if (string.Equals(item.ProtocolName, protocolName, StringComparison.Ordinal)) + { + // No need to add + return; + } + } + + _cachedItems.Add(new SerializedMessage(protocolName, serialized)); + } + } + + private bool TryGetCached(string protocolName, out byte[] result) + { + if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = _cachedItem1.Serialized; + return true; + } + + if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = _cachedItem2.Serialized; + return true; + } + + if (_cachedItems != null) + { + foreach (var serializedMessage in _cachedItems) + { + if (string.Equals(serializedMessage.ProtocolName, protocolName, StringComparison.Ordinal)) + { + result = serializedMessage.Serialized; + return true; + } + } + } + + result = default; + return false; + } + + private readonly struct SerializedMessage + { + public string ProtocolName { get; } + public byte[] Serialized { get; } + + public SerializedMessage(string protocolName, byte[] serialized) + { + ProtocolName = protocolName; + Serialized = serialized; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/GroupAction.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/GroupAction.cs new file mode 100644 index 0000000000..874d190f84 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/GroupAction.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + // The size of the enum is defined by the protocol. Do not change it. If you need more than 255 items, + // add an additional enum. + public enum GroupAction : byte + { + // These numbers are used by the protocol, do not change them and always use explicit assignment + // when adding new items to this enum. 0 is intentionally omitted + Add = 1, + Remove = 2, + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisChannels.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisChannels.cs new file mode 100644 index 0000000000..28393b2886 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisChannels.cs @@ -0,0 +1,75 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + internal class RedisChannels + { + private readonly string _prefix; + + /// + /// Gets the name of the channel for sending to all connections. + /// + /// + /// The payload on this channel is objects containing + /// invocations to be sent to all connections + /// + public string All { get; } + + /// + /// Gets the name of the internal channel for group management messages. + /// + public string GroupManagement { get; } + + public RedisChannels(string prefix) + { + _prefix = prefix; + + All = prefix + ":all"; + GroupManagement = prefix + ":internal:groups"; + } + + /// + /// Gets the name of the channel for sending a message to a specific connection. + /// + /// The ID of the connection to get the channel for. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string Connection(string connectionId) + { + return _prefix + ":connection:" + connectionId; + } + + /// + /// Gets the name of the channel for sending a message to a named group of connections. + /// + /// The name of the group to get the channel for. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string Group(string groupName) + { + return _prefix + ":group:" + groupName; + } + + /// + /// Gets the name of the channel for sending a message to all collections associated with a user. + /// + /// The ID of the user to get the channel for. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string User(string userId) + { + return _prefix + ":user:" + userId; + } + + /// + /// Gets the name of the acknowledgement channel for the specified server. + /// + /// The name of the server to get the acknowledgement channel for. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string Ack(string serverName) + { + return _prefix + ":internal:ack:" + serverName; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs new file mode 100644 index 0000000000..a2ef82f373 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs @@ -0,0 +1,39 @@ +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + public readonly struct RedisGroupCommand + { + /// + /// Gets the ID of the group command. + /// + public int Id { get; } + + /// + /// Gets the name of the server that sent the command. + /// + public string ServerName { get; } + + /// + /// Gets the action to be performed on the group. + /// + public GroupAction Action { get; } + + /// + /// Gets the group on which the action is performed. + /// + public string GroupName { get; } + + /// + /// Gets the ID of the connection to be added or removed from the group. + /// + public string ConnectionId { get; } + + public RedisGroupCommand(int id, string serverName, GroupAction action, string groupName, string connectionId) + { + Id = id; + ServerName = serverName; + Action = action; + GroupName = groupName; + ConnectionId = connectionId; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs new file mode 100644 index 0000000000..66618f5c88 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs @@ -0,0 +1,33 @@ +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + public readonly struct RedisInvocation + { + /// + /// Gets a list of connections that should be excluded from this invocation. + /// May be null to indicate that no connections are to be excluded. + /// + public IReadOnlyList ExcludedIds { get; } + + /// + /// Gets the message serialization cache containing serialized payloads for the message. + /// + public SerializedHubMessage Message { get; } + + public RedisInvocation(SerializedHubMessage message, IReadOnlyList excludedIds) + { + Message = message; + ExcludedIds = excludedIds; + } + + public static RedisInvocation Create(string target, object[] arguments, IReadOnlyList excludedIds = null) + { + return new RedisInvocation( + new SerializedHubMessage(new InvocationMessage(target, argumentBindingException: null, arguments)), + excludedIds); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs new file mode 100644 index 0000000000..74511ed641 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs @@ -0,0 +1,170 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.IO; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + public class RedisProtocol + { + private readonly IReadOnlyList _protocols; + private static readonly Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false); + + public RedisProtocol(IReadOnlyList protocols) + { + _protocols = protocols; + } + + // The Redis Protocol: + // * The message type is known in advance because messages are sent to different channels based on type + // * Invocations are sent to the All, Group, Connection and User channels + // * Group Commands are sent to the GroupManagement channel + // * Acks are sent to the Acknowledgement channel. + // * See the Write[type] methods for a description of the protocol for each in-depth. + // * The "Variable length integer" is the length-prefixing format used by BinaryReader/BinaryWriter: + // * https://docs.microsoft.com/en-us/dotnet/api/system.io.binarywriter.write?view=netstandard-2.0 + // * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter: + // * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8. + + public byte[] WriteInvocation(string methodName, object[] args) => + WriteInvocation(methodName, args, excludedIds: null); + + public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList excludedIds) + { + // Redis Invocation Format: + // * Variable length integer: Number of excluded Ids + // * For each excluded Id: + // * Length prefixed string: ID + // * SerializedHubMessage encoded by the format described by that type. + + using (var stream = new MemoryStream()) + using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom)) + { + if (excludedIds != null) + { + writer.WriteVarInt(excludedIds.Count); + foreach (var id in excludedIds) + { + writer.Write(id); + } + } + else + { + writer.WriteVarInt(0); + } + + SerializedHubMessage.WriteAllSerializedVersions(writer, new InvocationMessage(methodName, argumentBindingException: null, args), _protocols); + return stream.ToArray(); + } + } + + public byte[] WriteGroupCommand(RedisGroupCommand command) + { + // Group Command Format: + // * Variable length integer: Id + // * Length prefixed string: ServerName + // * 1 byte: Action + // * Length prefixed string: GroupName + // * Length prefixed string: ConnectionId + + using (var stream = new MemoryStream()) + using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom)) + { + writer.WriteVarInt(command.Id); + writer.Write(command.ServerName); + writer.Write((byte)command.Action); + writer.Write(command.GroupName); + writer.Write(command.ConnectionId); + return stream.ToArray(); + } + } + + public byte[] WriteAck(int messageId) + { + // Acknowledgement Format: + // * Variable length integer: Id + + using (var stream = new MemoryStream()) + using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom)) + { + writer.WriteVarInt(messageId); + return stream.ToArray(); + } + } + + public RedisInvocation ReadInvocation(byte[] data) + { + // See WriteInvocation for format. + + using (var stream = new MemoryStream(data)) + using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom)) + { + IReadOnlyList excludedIds = null; + + var idCount = reader.ReadVarInt(); + if (idCount > 0) + { + var ids = new string[idCount]; + for (var i = 0; i < idCount; i++) + { + ids[i] = reader.ReadString(); + } + + excludedIds = ids; + } + + var message = SerializedHubMessage.ReadAllSerializedVersions(reader); + return new RedisInvocation(message, excludedIds); + } + } + + public RedisGroupCommand ReadGroupCommand(byte[] data) + { + // See WriteGroupCommand for format. + using (var stream = new MemoryStream(data)) + using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom)) + { + var id = reader.ReadVarInt(); + var serverName = reader.ReadString(); + var action = (GroupAction)reader.ReadByte(); + var groupName = reader.ReadString(); + var connectionId = reader.ReadString(); + + return new RedisGroupCommand(id, serverName, action, groupName, connectionId); + } + } + + public int ReadAck(byte[] data) + { + // See WriteAck for format + using (var stream = new MemoryStream(data)) + using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom)) + { + return reader.ReadVarInt(); + } + } + + // Kinda cheaty way to get access to write the 7-bit varint format directly + private class BinaryWriterWithVarInt : BinaryWriter + { + public BinaryWriterWithVarInt(Stream output, Encoding encoding) : base(output, encoding) + { + } + + public void WriteVarInt(int value) => Write7BitEncodedInt(value); + } + + private class BinaryReaderWithVarInt : BinaryReader + { + public BinaryReaderWithVarInt(Stream input, Encoding encoding) : base(input, encoding) + { + } + + public int ReadVarInt() => Read7BitEncodedInt(); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 08e63ade67..aeae6d1df2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -9,7 +9,7 @@ using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Redis.Internal; using Microsoft.Extensions.Logging; @@ -28,8 +28,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis private readonly ISubscriber _bus; private readonly ILogger _logger; private readonly RedisOptions _options; - private readonly string _channelNamePrefix = typeof(THub).FullName; - private readonly string _serverName = Guid.NewGuid().ToString(); + private readonly RedisChannels _channels; + private readonly string _serverName = GenerateServerName(); + private readonly RedisProtocol _protocol; + private readonly AckHandler _ackHandler; private int _internalId; @@ -41,14 +43,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis }; public RedisHubLifetimeManager(ILogger> logger, - IOptions options) + IOptions options, + IHubProtocolResolver hubProtocolResolver) { _logger = logger; _options = options.Value; _ackHandler = new AckHandler(); + _channels = new RedisChannels(typeof(THub).FullName); + _protocol = new RedisProtocol(hubProtocolResolver.AllProtocols); var writer = new LoggerTextWriter(logger); - _logger.ConnectingToEndpoints(options.Value.Options.EndPoints); + RedisLog.ConnectingToEndpoints(_logger, options.Value.Options.EndPoints, _serverName); _redisServerConnection = _options.Connect(writer); _redisServerConnection.ConnectionRestored += (_, e) => @@ -60,7 +65,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - _logger.ConnectionRestored(); + RedisLog.ConnectionRestored(_logger); }; _redisServerConnection.ConnectionFailed += (_, e) => @@ -72,23 +77,22 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - _logger.ConnectionFailed(e.Exception); + RedisLog.ConnectionFailed(_logger, e.Exception); }; if (_redisServerConnection.IsConnected) { - _logger.Connected(); + RedisLog.Connected(_logger); } else { - _logger.NotConnected(); + RedisLog.NotConnected(_logger); } _bus = _redisServerConnection.GetSubscriber(); - SubscribeToHub(); - SubscribeToAllExcept(); - SubscribeToInternalGroup(); - SubscribeToInternalServerName(); + SubscribeToAll(); + SubscribeToGroupManagementChannel(); + SubscribeToAckChannel(); } public override Task OnConnectedAsync(HubConnectionContext connection) @@ -125,7 +129,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { foreach (var subscription in redisSubscriptions) { - _logger.Unsubscribe(subscription); + RedisLog.Unsubscribe(_logger, subscription); tasks.Add(_bus.UnsubscribeAsync(subscription)); } } @@ -149,15 +153,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task SendAllAsync(string methodName, object[] args) { - var message = new RedisInvocationMessage(target: methodName, arguments: args); - - return PublishAsync(_channelNamePrefix, message); + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.All, message); } public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedIds) { - var message = new RedisInvocationMessage(target: methodName, excludedIds: excludedIds, arguments: args); - return PublishAsync(_channelNamePrefix + ".AllExcept", message); + var message = _protocol.WriteInvocation(methodName, args, excludedIds); + return PublishAsync(_channels.All, message); } public override Task SendConnectionAsync(string connectionId, string methodName, object[] args) @@ -167,17 +170,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis throw new ArgumentNullException(nameof(connectionId)); } - var message = new RedisInvocationMessage(target: methodName, arguments: args); - // If the connection is local we can skip sending the message through the bus since we require sticky connections. // This also saves serializing and deserializing the message! var connection = _connections[connectionId]; if (connection != null) { - return SafeWriteAsync(connection, message.CreateInvocation()); + return connection.WriteAsync(new InvocationMessage(methodName, argumentBindingException: null, args)).AsTask(); } - return PublishAsync(_channelNamePrefix + "." + connectionId, message); + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.Connection(connectionId), message); } public override Task SendGroupAsync(string groupName, string methodName, object[] args) @@ -187,9 +189,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis throw new ArgumentNullException(nameof(groupName)); } - var message = new RedisInvocationMessage(target: methodName, excludedIds: null, arguments: args); - - return PublishAsync(_channelNamePrefix + ".group." + groupName, message); + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.Group(groupName), message); } public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedIds) @@ -199,31 +200,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis throw new ArgumentNullException(nameof(groupName)); } - var message = new RedisInvocationMessage(methodName, excludedIds, args); - - return PublishAsync(_channelNamePrefix + ".group." + groupName, message); + var message = _protocol.WriteInvocation(methodName, args, excludedIds); + return PublishAsync(_channels.Group(groupName), message); } public override Task SendUserAsync(string userId, string methodName, object[] args) { - var message = new RedisInvocationMessage(methodName, args); - - return PublishAsync(_channelNamePrefix + ".user." + userId, message); - } - - private async Task PublishAsync(string channel, IRedisMessage message) - { - byte[] payload; - using (var stream = new LimitArrayPoolWriteStream()) - using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream))) - { - _serializer.Serialize(writer, message); - writer.Flush(); - payload = stream.ToArray(); - } - - _logger.PublishToChannel(channel); - await _bus.PublishAsync(channel, payload); + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.User(userId), message); } public override async Task AddGroupAsync(string connectionId, string groupName) @@ -249,6 +233,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Add); } + public override async Task RemoveGroupAsync(string connectionId, string groupName) + { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var connection = _connections[connectionId]; + if (connection != null) + { + // short circuit if connection is on this server + await RemoveGroupAsyncCore(connection, groupName); + return; + } + + await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove); + } + + public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args) + { + if (connectionIds == null) + { + throw new ArgumentNullException(nameof(connectionIds)); + } + + var publishTasks = new List(connectionIds.Count); + var payload = _protocol.WriteInvocation(methodName, args); + + foreach (var connectionId in connectionIds) + { + publishTasks.Add(PublishAsync(_channels.Connection(connectionId), payload)); + } + + return Task.WhenAll(publishTasks); + } + + public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args) + { + if (groupNames == null) + { + throw new ArgumentNullException(nameof(groupNames)); + } + var publishTasks = new List(groupNames.Count); + var payload = _protocol.WriteInvocation(methodName, args); + + foreach (var groupName in groupNames) + { + if (!string.IsNullOrEmpty(groupName)) + { + publishTasks.Add(PublishAsync(_channels.Group(groupName), payload)); + } + } + + return Task.WhenAll(publishTasks); + } + + public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args) + { + if (userIds.Count > 0) + { + var payload = _protocol.WriteInvocation(methodName, args); + var publishTasks = new List(userIds.Count); + foreach (var userId in userIds) + { + if (!string.IsNullOrEmpty(userId)) + { + publishTasks.Add(PublishAsync(_channels.User(userId), payload)); + } + } + + return Task.WhenAll(publishTasks); + } + + return Task.CompletedTask; + } + + private Task PublishAsync(string channel, byte[] payload) + { + RedisLog.PublishToChannel(_logger, channel); + return _bus.PublishAsync(channel, payload); + } + private async Task AddGroupAsyncCore(HubConnectionContext connection, string groupName) { var feature = connection.Features.Get(); @@ -263,7 +334,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - var groupChannel = _channelNamePrefix + ".group." + groupName; + var groupChannel = _channels.Group(groupName); var group = _groups.GetOrAdd(groupChannel, _ => new GroupData()); await group.Lock.WaitAsync(); @@ -285,37 +356,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - public override async Task RemoveGroupAsync(string connectionId, string groupName) - { - if (connectionId == null) - { - throw new ArgumentNullException(nameof(connectionId)); - } - - if (groupName == null) - { - throw new ArgumentNullException(nameof(groupName)); - } - - - var connection = _connections[connectionId]; - if (connection != null) - { - // short circuit if connection is on this server - await RemoveGroupAsyncCore(connection, groupName); - return; - } - - await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove); - } - /// /// This takes because we want to remove the connection from the /// _connections list in OnDisconnectedAsync and still be able to remove groups with this method. /// private async Task RemoveGroupAsyncCore(HubConnectionContext connection, string groupName) { - var groupChannel = _channelNamePrefix + ".group." + groupName; + var groupChannel = _channels.Group(groupName); if (!_groups.TryGetValue(groupChannel, out var group)) { @@ -341,7 +388,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis if (group.Connections.Count == 0) { - _logger.Unsubscribe(groupChannel); + RedisLog.Unsubscribe(_logger, groupChannel); await _bus.UnsubscribeAsync(groupChannel); } } @@ -350,8 +397,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis { group.Lock.Release(); } - - return; } private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action) @@ -359,14 +404,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis var id = Interlocked.Increment(ref _internalId); var ack = _ackHandler.CreateAck(id); // Send Add/Remove Group to other servers and wait for an ack or timeout - await PublishAsync(_channelNamePrefix + ".internal.group", new RedisGroupMessage - { - Action = action, - ConnectionId = connectionId, - Group = groupName, - Id = id, - Server = _serverName - }); + var message = _protocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId)); + await PublishAsync(_channels.GroupManagement, message); await ack; } @@ -378,63 +417,24 @@ namespace Microsoft.AspNetCore.SignalR.Redis _ackHandler.Dispose(); } - private T DeserializeMessage(RedisValue data) + private void SubscribeToAll() { - using (var reader = new JsonTextReader(new StreamReader(new MemoryStream(data)))) - { - return _serializer.Deserialize(reader); - } - } - - private void SubscribeToHub() - { - _logger.Subscribing(_channelNamePrefix); - _bus.Subscribe(_channelNamePrefix, async (c, data) => + RedisLog.Subscribing(_logger, _channels.All); + _bus.Subscribe(_channels.All, async (c, data) => { try { - _logger.ReceivedFromChannel(_channelNamePrefix); + RedisLog.ReceivedFromChannel(_logger, _channels.All); - var message = DeserializeMessage(data); + var invocation = _protocol.ReadInvocation(data); var tasks = new List(_connections.Count); - var invocation = message.CreateInvocation(); foreach (var connection in _connections) { - tasks.Add(SafeWriteAsync(connection, invocation)); - } - - await Task.WhenAll(tasks); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } - }); - } - - private void SubscribeToAllExcept() - { - var channelName = _channelNamePrefix + ".AllExcept"; - _logger.Subscribing(channelName); - _bus.Subscribe(channelName, async (c, data) => - { - try - { - _logger.ReceivedFromChannel(channelName); - - var message = DeserializeMessage(data); - var excludedIds = message.ExcludedIds ?? Array.Empty(); - - var tasks = new List(_connections.Count); - - var invocation = message.CreateInvocation(); - foreach (var connection in _connections) - { - if (!excludedIds.Contains(connection.ConnectionId)) + if (invocation.ExcludedIds == null || !invocation.ExcludedIds.Contains(connection.ConnectionId)) { - tasks.Add(SafeWriteAsync(connection, invocation)); + tasks.Add(connection.WriteAsync(invocation.Message).AsTask()); } } @@ -442,19 +442,18 @@ namespace Microsoft.AspNetCore.SignalR.Redis } catch (Exception ex) { - _logger.FailedWritingMessage(ex); + RedisLog.FailedWritingMessage(_logger, ex); } }); } - private void SubscribeToInternalGroup() + private void SubscribeToGroupManagementChannel() { - var channelName = _channelNamePrefix + ".internal.group"; - _bus.Subscribe(channelName, async (c, data) => + _bus.Subscribe(_channels.GroupManagement, async (c, data) => { try { - var groupMessage = DeserializeMessage(data); + var groupMessage = _protocol.ReadGroupCommand(data); var connection = _connections[groupMessage.ConnectionId]; if (connection == null) @@ -465,179 +464,95 @@ namespace Microsoft.AspNetCore.SignalR.Redis if (groupMessage.Action == GroupAction.Remove) { - await RemoveGroupAsyncCore(connection, groupMessage.Group); + await RemoveGroupAsyncCore(connection, groupMessage.GroupName); } if (groupMessage.Action == GroupAction.Add) { - await AddGroupAsyncCore(connection, groupMessage.Group); + await AddGroupAsyncCore(connection, groupMessage.GroupName); } - // Sending ack to server that sent the original add/remove - await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new RedisGroupMessage - { - Action = GroupAction.Ack, - Id = groupMessage.Id - }); + // Send an ack to the server that sent the original command. + await PublishAsync(_channels.Ack(groupMessage.ServerName), _protocol.WriteAck(groupMessage.Id)); } catch (Exception ex) { - _logger.InternalMessageFailed(ex); + RedisLog.InternalMessageFailed(_logger, ex); } }); } - private void SubscribeToInternalServerName() + private void SubscribeToAckChannel() { // Create server specific channel in order to send an ack to a single server - var serverChannel = $"{_channelNamePrefix}.internal.{_serverName}"; - _bus.Subscribe(serverChannel, (c, data) => + _bus.Subscribe(_channels.Ack(_serverName), (c, data) => { - var groupMessage = DeserializeMessage(data); + var ackId = _protocol.ReadAck(data); - if (groupMessage.Action == GroupAction.Ack) - { - _ackHandler.TriggerAck(groupMessage.Id); - } + _ackHandler.TriggerAck(ackId); }); } private Task SubscribeToConnection(HubConnectionContext connection, HashSet redisSubscriptions) { - var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId; + var connectionChannel = _channels.Connection(connection.ConnectionId); redisSubscriptions.Add(connectionChannel); - _logger.Subscribing(connectionChannel); + RedisLog.Subscribing(_logger, connectionChannel); return _bus.SubscribeAsync(connectionChannel, async (c, data) => { - var message = DeserializeMessage(data); - - await SafeWriteAsync(connection, message.CreateInvocation()); + var invocation = _protocol.ReadInvocation(data); + await connection.WriteAsync(invocation.Message); }); } private Task SubscribeToUser(HubConnectionContext connection, HashSet redisSubscriptions) { - var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier; + var userChannel = _channels.User(connection.UserIdentifier); redisSubscriptions.Add(userChannel); // TODO: Look at optimizing (looping over connections checking for Name) return _bus.SubscribeAsync(userChannel, async (c, data) => { - var message = DeserializeMessage(data); - - await SafeWriteAsync(connection, message.CreateInvocation()); + var invocation = _protocol.ReadInvocation(data); + await connection.WriteAsync(invocation.Message); }); } private Task SubscribeToGroup(string groupChannel, GroupData group) { - _logger.Subscribing(groupChannel); + RedisLog.Subscribing(_logger, groupChannel); return _bus.SubscribeAsync(groupChannel, async (c, data) => { try { - var message = DeserializeMessage(data); + var invocation = _protocol.ReadInvocation(data); var tasks = new List(); - var invocation = message.CreateInvocation(); foreach (var groupConnection in group.Connections) { - if (message.ExcludedIds?.Contains(groupConnection.ConnectionId) == true) + if (invocation.ExcludedIds?.Contains(groupConnection.ConnectionId) == true) { continue; } - tasks.Add(SafeWriteAsync(groupConnection, invocation)); + tasks.Add(groupConnection.WriteAsync(invocation.Message).AsTask()); } await Task.WhenAll(tasks); } catch (Exception ex) { - _logger.FailedWritingMessage(ex); + RedisLog.FailedWritingMessage(_logger, ex); } }); } - public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args) + private static string GenerateServerName() { - if (connectionIds == null) - { - throw new ArgumentNullException(nameof(connectionIds)); - } - var publishTasks = new List(connectionIds.Count); - var message = new RedisInvocationMessage(target: methodName, arguments: args); - - foreach (var connectionId in connectionIds) - { - var connection = _connections[connectionId]; - // If the connection is local we can skip sending the message through the bus since we require sticky connections. - // This also saves serializing and deserializing the message! - if (connection != null) - { - publishTasks.Add(SafeWriteAsync(connection, message.CreateInvocation())); - } - else - { - publishTasks.Add(PublishAsync(_channelNamePrefix + "." + connectionId, message)); - } - } - - return Task.WhenAll(publishTasks); - } - - public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args) - { - if (groupNames == null) - { - throw new ArgumentNullException(nameof(groupNames)); - } - var publishTasks = new List(groupNames.Count); - var message = new RedisInvocationMessage(target: methodName, arguments: args); - - foreach (var groupName in groupNames) - { - if (!string.IsNullOrEmpty(groupName)) - { - publishTasks.Add(PublishAsync(_channelNamePrefix + "." + groupName, message)); - } - } - - return Task.WhenAll(publishTasks); - } - - public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args) - { - if (userIds.Count > 0) - { - var message = new RedisInvocationMessage(methodName, args); - var publishTasks = new List(userIds.Count); - foreach (var userId in userIds) - { - if (!string.IsNullOrEmpty(userId)) - { - publishTasks.Add(PublishAsync(_channelNamePrefix + ".user." + userId, message)); - } - } - - return Task.WhenAll(publishTasks); - } - - return Task.CompletedTask; - } - - // This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to - private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message) - { - try - { - await connection.WriteAsync(message); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + // Use the machine name for convenient diagnostics, but add a guid to make it unique. + // Example: MyServerName_02db60e5fab243b890a847fa5c4dcb29 + return $"{Environment.MachineName}_{Guid.NewGuid():N}"; } private class LoggerTextWriter : TextWriter @@ -658,7 +573,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override void WriteLine(string value) { - _logger.LogDebug(value); + RedisLog.ConnectionMultiplexerMessage(_logger, value); } } @@ -679,53 +594,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis public HashSet Subscriptions { get; } = new HashSet(); public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); } - - private enum GroupAction - { - Remove, - Add, - Ack - } - - // Marker interface to represent the messages that can be sent over Redis. - private interface IRedisMessage { } - - private class RedisGroupMessage : IRedisMessage - { - public string ConnectionId { get; set; } - public string Group { get; set; } - public int Id { get; set; } - public GroupAction Action { get; set; } - public string Server { get; set; } - } - - // Represents a message published to the Redis bus - private class RedisInvocationMessage : IRedisMessage - { - public string Target { get; set; } - public IReadOnlyList ExcludedIds { get; set; } - public object[] Arguments { get; set; } - - public RedisInvocationMessage() - { - } - - public RedisInvocationMessage(string target, object[] arguments) - : this(target, excludedIds: null, arguments: arguments) - { - } - - public RedisInvocationMessage(string target, IReadOnlyList excludedIds, object[] arguments) - { - Target = target; - ExcludedIds = excludedIds; - Arguments = arguments; - } - - public InvocationMessage CreateInvocation() - { - return new InvocationMessage(Target, argumentBindingException: null, arguments: Arguments); - } - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisLog.cs similarity index 63% rename from src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs rename to src/Microsoft.AspNetCore.SignalR.Redis/RedisLog.cs index e8862cf0e1..085133facc 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisLog.cs @@ -6,13 +6,14 @@ using System.Linq; using Microsoft.Extensions.Logging; using StackExchange.Redis; -namespace Microsoft.AspNetCore.SignalR.Redis.Internal +namespace Microsoft.AspNetCore.SignalR.Redis { - internal static class RedisLoggerExtensions + // We don't want to use our nested static class here because RedisHubLifetimeManager is generic. + // We'd end up creating separate instances of all the LoggerMessage.Define values for each Hub. + internal static class RedisLog { - // Category: RedisHubLifetimeManager - private static readonly Action _connectingToEndpoints = - LoggerMessage.Define(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}."); + private static readonly Action _connectingToEndpoints = + LoggerMessage.Define(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}. Using Server Name: {ServerName}"); private static readonly Action _connected = LoggerMessage.Define(LogLevel.Information, new EventId(2, "Connected"), "Connected to Redis."); @@ -44,65 +45,75 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal private static readonly Action _internalMessageFailed = LoggerMessage.Define(LogLevel.Warning, new EventId(11, "InternalMessageFailed"), "Error processing message for internal server message."); - public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints) + public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endpoints, string serverName) { if (logger.IsEnabled(LogLevel.Information)) { if (endpoints.Count > 0) { - _connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), null); + _connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), serverName, null); } } } - public static void Connected(this ILogger logger) + public static void Connected(ILogger logger) { _connected(logger, null); } - public static void Subscribing(this ILogger logger, string channelName) + public static void Subscribing(ILogger logger, string channelName) { _subscribing(logger, channelName, null); } - public static void ReceivedFromChannel(this ILogger logger, string channelName) + public static void ReceivedFromChannel(ILogger logger, string channelName) { _receivedFromChannel(logger, channelName, null); } - public static void PublishToChannel(this ILogger logger, string channelName) + public static void PublishToChannel(ILogger logger, string channelName) { _publishToChannel(logger, channelName, null); } - public static void Unsubscribe(this ILogger logger, string channelName) + public static void Unsubscribe(ILogger logger, string channelName) { _unsubscribe(logger, channelName, null); } - public static void NotConnected(this ILogger logger) + public static void NotConnected(ILogger logger) { _notConnected(logger, null); } - public static void ConnectionRestored(this ILogger logger) + public static void ConnectionRestored(ILogger logger) { _connectionRestored(logger, null); } - public static void ConnectionFailed(this ILogger logger, Exception exception) + public static void ConnectionFailed(ILogger logger, Exception exception) { _connectionFailed(logger, exception); } - public static void FailedWritingMessage(this ILogger logger, Exception exception) + public static void FailedWritingMessage(ILogger logger, Exception exception) { _failedWritingMessage(logger, exception); } - public static void InternalMessageFailed(this ILogger logger, Exception exception) + public static void InternalMessageFailed(ILogger logger, Exception exception) { _internalMessageFailed(logger, exception); } + + // This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer + public static void ConnectionMultiplexerMessage(ILogger logger, string message) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + // We tag it with EventId 100 though so it can be pulled out of logs easily. + logger.LogDebug(new EventId(100, "RedisConnectionLog"), message); + } + } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index 7a6aeca22b..cb46ecb3eb 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -5,11 +5,15 @@ using System; using System.Collections.Generic; using System.Threading.Channels; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests; -using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using Moq; +using MsgPack.Serialization; +using Newtonsoft.Json.Linq; +using Newtonsoft.Json.Serialization; using Xunit; namespace Microsoft.AspNetCore.SignalR.Redis.Tests @@ -19,14 +23,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeAllAsyncWritesToAllConnectionsOutput() { + var server = new TestRedisServer(); + using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var manager = CreateLifetimeManager(server); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -40,17 +42,44 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } } + [Fact] + public async Task InvokeAllExceptAsyncExcludesSpecifiedConnections() + { + var server = new TestRedisServer(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + using (var client3 = new TestClient()) + { + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); + var manager3 = CreateLifetimeManager(server); + + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + var connection3 = HubConnectionContextUtils.Create(client3.Connection); + + await manager1.OnConnectedAsync(connection1).OrTimeout(); + await manager2.OnConnectedAsync(connection2).OrTimeout(); + await manager3.OnConnectedAsync(connection3).OrTimeout(); + + await manager1.SendAllExceptAsync("Hello", new object[] { "World" }, new [] { client3.Connection.ConnectionId }).OrTimeout(); + + await AssertMessageAsync(client1); + await AssertMessageAsync(client2); + Assert.Null(client3.TryRead()); + } + } + [Fact] public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() { + var server = new TestRedisServer(); + using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var manager = CreateLifetimeManager(server); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -70,14 +99,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() { + var server = new TestRedisServer(); + using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var manager = CreateLifetimeManager(server); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -96,14 +123,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeGroupExceptAsyncWritesToAllValidConnectionsInGroupOutput() { + var server = new TestRedisServer(); + using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var manager = CreateLifetimeManager(server); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -124,13 +149,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeConnectionAsyncWritesToConnectionOutput() { + var server = new TestRedisServer(); + using (var client = new TestClient()) { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var manager = CreateLifetimeManager(server); var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -144,27 +167,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeConnectionAsyncOnNonExistentConnectionDoesNotThrow() { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager = CreateLifetimeManager(server); await manager.SendConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout(); } [Fact] public async Task InvokeAllAsyncWithMultipleServersWritesToAllConnectionsOutput() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client1 = new TestClient()) using (var client2 = new TestClient()) @@ -185,16 +200,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeAllAsyncWithMultipleServersDoesNotWriteToDisconnectedConnectionsOutput() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client1 = new TestClient()) using (var client2 = new TestClient()) @@ -218,16 +227,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeConnectionAsyncOnServerWithoutConnectionWritesOutputToConnection() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -244,16 +247,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeGroupAsyncOnServerWithoutConnectionWritesOutputToGroupConnection() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -272,11 +269,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task DisconnectConnectionRemovesConnectionFromGroup() { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -297,11 +292,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task RemoveGroupFromLocalConnectionNotInGroupDoesNothing() { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -316,16 +309,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task RemoveGroupFromConnectionOnDifferentServerNotInGroupDoesNothing() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), - Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -340,14 +327,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task AddGroupAsyncForConnectionOnDifferentServerWorks() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -366,10 +349,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task AddGroupAsyncForLocalConnectionAlreadyInGroupDoesNothing() { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -382,7 +364,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout(); - await AssertMessageAsync(client); Assert.Null(client.TryRead()); } @@ -391,14 +372,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task AddGroupAsyncForConnectionOnDifferentServerAlreadyInGroupDoesNothing() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -419,14 +396,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task RemoveGroupAsyncForConnectionOnDifferentServerWorks() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -451,14 +424,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task InvokeConnectionAsyncForLocalConnectionDoesNotPublishToRedis() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -478,14 +447,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests [Fact] public async Task WritingToRemoteConnectionThatFailsDoesNotThrow() { - var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager1 = CreateLifetimeManager(server); + var manager2 = CreateLifetimeManager(server); using (var client = new TestClient()) { @@ -502,34 +467,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } } - [Fact] - public async Task WritingToLocalConnectionThatFailsDoesNotThrowException() - { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); - - using (var client = new TestClient()) - { - // Force an exception when writing to connection - var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); - connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception("Message")); - var connection = connectionMock.Object; - - await manager.OnConnectedAsync(connection).OrTimeout(); - - await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - } - } - [Fact] public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage() { - var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() - { - Factory = t => new TestConnectionMultiplexer() - })); + var server = new TestRedisServer(); + + var manager = CreateLifetimeManager(server); using (var client1 = new TestClient()) using (var client2 = new TestClient()) @@ -557,6 +500,72 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } } + [Fact] + public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary() + { + var server = new TestRedisServer(); + + var messagePackOptions = new MessagePackHubProtocolOptions(); + messagePackOptions.SerializationContext.DictionarySerlaizationOptions.KeyTransformer = DictionaryKeyTransformers.LowerCamel; + + var jsonOptions = new JsonHubProtocolOptions(); + jsonOptions.PayloadSerializerSettings.ContractResolver = new CamelCasePropertyNamesContractResolver(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + // The sending manager has serializer settings + var manager1 = CreateLifetimeManager(server, messagePackOptions, jsonOptions); + + // The receiving one doesn't matter because of how we serialize! + var manager2 = CreateLifetimeManager(server); + + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager1.OnConnectedAsync(connection1).OrTimeout(); + await manager2.OnConnectedAsync(connection2).OrTimeout(); + + await manager1.SendAllAsync("Hello", new object[] { new TestObject { TestProperty = "Foo" } }); + + var message = Assert.IsType(await client2.ReadAsync().OrTimeout()); + Assert.Equal("Hello", message.Target); + Assert.Collection( + message.Arguments, + arg0 => + { + var dict = Assert.IsType(arg0); + Assert.Collection(dict.Properties(), + prop => + { + Assert.Equal("testProperty", prop.Name); + Assert.Equal("Foo", prop.Value.Value()); + }); + }); + } + } + + public class TestObject + { + public string TestProperty { get; set; } + } + + private RedisHubLifetimeManager CreateLifetimeManager(TestRedisServer server, MessagePackHubProtocolOptions messagePackOptions = null, JsonHubProtocolOptions jsonOptions = null) + { + var options = new RedisOptions() { Factory = t => new TestConnectionMultiplexer(server) }; + messagePackOptions = messagePackOptions ?? new MessagePackHubProtocolOptions(); + jsonOptions = jsonOptions ?? new JsonHubProtocolOptions(); + + return new RedisHubLifetimeManager( + NullLogger>.Instance, + Options.Create(options), + new DefaultHubProtocolResolver(new IHubProtocol[] + { + new JsonHubProtocol(Options.Create(jsonOptions)), + new MessagePackHubProtocol(Options.Create(messagePackOptions)), + }, NullLogger.Instance)); + } + private async Task AssertMessageAsync(TestClient client) { var message = Assert.IsType(await client.ReadAsync().OrTimeout()); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs index 8529ccaf1d..3a1914732e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/HubConnectionContextUtils.cs @@ -11,11 +11,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public static class HubConnectionContextUtils { - public static HubConnectionContext Create(ConnectionContext connection) + public static HubConnectionContext Create(ConnectionContext connection, IHubProtocol protocol = null, string userIdentifier = null) { return new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance) { - Protocol = new JsonHubProtocol() + Protocol = protocol ?? new JsonHubProtocol(), + UserIdentifier = userIdentifier, }; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj index 5d99a4a543..c9074935ac 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj @@ -21,6 +21,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/TestConnectionMultiplexer.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestConnectionMultiplexer.cs similarity index 85% rename from test/Microsoft.AspNetCore.SignalR.Redis.Tests/TestConnectionMultiplexer.cs rename to test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestConnectionMultiplexer.cs index 98164b622e..412ccf0cd4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/TestConnectionMultiplexer.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestConnectionMultiplexer.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -9,7 +9,7 @@ using System.Net; using System.Threading.Tasks; using StackExchange.Redis; -namespace Microsoft.AspNetCore.SignalR.Redis.Tests +namespace Microsoft.AspNetCore.SignalR.Tests { public class TestConnectionMultiplexer : IConnectionMultiplexer { @@ -70,7 +70,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests remove { } } - private ISubscriber _subscriber = new TestSubscriber(); + private ISubscriber _subscriber; + + public TestConnectionMultiplexer(TestRedisServer server) + { + _subscriber = new TestSubscriber(server); + } public void BeginProfiling(object forContext) { @@ -203,19 +208,52 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } } - public class TestSubscriber : ISubscriber + public class TestRedisServer { - // _globalSubscriptions represents the Redis Server you are connected to. - // So when publishing from a TestSubscriber you fake sending through the server by grabbing the callbacks - // from the _globalSubscriptions and inoking them inplace. - private static ConcurrentDictionary>> _globalSubscriptions = + private ConcurrentDictionary>> _subscriptions = new ConcurrentDictionary>>(); - private ConcurrentDictionary> _subscriptions = - new ConcurrentDictionary>(); + public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) + { + if (_subscriptions.TryGetValue(channel, out var handlers)) + { + foreach (var handler in handlers) + { + handler(channel, message); + } + } + return handlers != null ? handlers.Count : 0; + } + + public void Subscribe(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) + { + _subscriptions.AddOrUpdate(channel, _ => new List> { handler }, (_, list) => + { + list.Add(handler); + return list; + }); + } + + public void Unsubscribe(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None) + { + if (_subscriptions.TryGetValue(channel, out var list)) + { + list.Remove(handler); + } + } + } + + public class TestSubscriber : ISubscriber + { + private readonly TestRedisServer _server; public ConnectionMultiplexer Multiplexer => throw new NotImplementedException(); + public TestSubscriber(TestRedisServer server) + { + _server = server; + } + public EndPoint IdentifyEndpoint(RedisChannel channel, CommandFlags flags = CommandFlags.None) { throw new NotImplementedException(); @@ -243,15 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) { - if (_globalSubscriptions.TryGetValue(channel, out var handlers)) - { - foreach (var handler in handlers) - { - handler(channel, message); - } - } - - return handlers != null ? handlers.Count : 0; + return _server.Publish(channel, message, flags); } public async Task PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) @@ -262,12 +292,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests public void Subscribe(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) { - _globalSubscriptions.AddOrUpdate(channel, _ => new List> { handler }, (_, list) => - { - list.Add(handler); - return list; - }); - _subscriptions.AddOrUpdate(channel, handler, (_, __) => handler); + _server.Subscribe(channel, handler, flags); } public Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) @@ -288,11 +313,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests public void Unsubscribe(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None) { - _subscriptions.TryRemove(channel, out var handle); - if (_globalSubscriptions.TryGetValue(channel, out var list)) - { - list.Remove(handle); - } + _server.Unsubscribe(channel, handler, flags); } public void UnsubscribeAll(CommandFlags flags = CommandFlags.None)