From 286e4bebf75ca5a729cf0d2a33b941802dc0459d Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Wed, 10 Oct 2018 16:40:27 -0700 Subject: [PATCH] Add StackExchange.Redis 2.X.X package (#3089) --- SignalR.sln | 16 +- ....AspNetCore.SignalR.Microbenchmarks.csproj | 3 +- .../RedisHubLifetimeManagerBenchmark.cs | 25 +- .../RedisProtocolBenchmark.cs | 2 +- build/dependencies.props | 1 + .../Internal/MessagePackUtil.cs | 3 + .../Internal/RedisGroupCommand.cs | 3 + .../Internal/RedisInvocation.cs | 4 +- .../Internal/RedisProtocol.cs | 1 - .../Internal/AckHandler.cs | 117 ++++ .../Internal/GroupAction.cs | 15 + .../Internal/MessagePackUtil.cs | 68 ++ .../Internal/RedisChannels.cs | 75 +++ .../Internal/RedisGroupCommand.cs | 42 ++ .../Internal/RedisInvocation.cs | 35 ++ .../Internal/RedisLog.cs | 119 ++++ .../Internal/RedisProtocol.cs | 208 ++++++ .../Internal/RedisSubscriptionManager.cs | 63 ++ ...pNetCore.SignalR.StackExchangeRedis.csproj | 23 + .../RedisDependencyInjectionExtensions.cs | 69 ++ .../RedisHubLifetimeManager.cs | 593 ++++++++++++++++++ .../RedisOptions.cs | 50 ++ .../Docker.cs | 8 +- .../RedisProtocolTests.cs | 3 + .../Startup.cs | 2 +- .../TestConnectionMultiplexer.cs | 0 .../Docker.cs | 192 ++++++ .../EchoHub.cs | 31 + ...re.SignalR.StackExchangeRedis.Tests.csproj | 27 + ...RedisDependencyInjectionExtensionsTests.cs | 41 ++ .../RedisEndToEnd.cs | 198 ++++++ .../RedisHubLifetimeManagerTests.cs | 84 +++ .../RedisProtocolTests.cs | 202 ++++++ .../RedisServerFixture.cs | 64 ++ .../SkipIfDockerNotPresentAttribute.cs | 39 ++ .../Startup.cs | 51 ++ .../TestConnectionMultiplexer.cs | 376 +++++++++++ ...soft.AspNetCore.SignalR.Tests.Utils.csproj | 1 - 38 files changed, 2831 insertions(+), 23 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/AckHandler.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/GroupAction.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/MessagePackUtil.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisChannels.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisGroupCommand.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisInvocation.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisLog.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisProtocol.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisSubscriptionManager.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisDependencyInjectionExtensions.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisHubLifetimeManager.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisOptions.cs rename test/{Microsoft.AspNetCore.SignalR.Tests.Utils => Microsoft.AspNetCore.SignalR.Redis.Tests}/TestConnectionMultiplexer.cs (100%) create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Docker.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/EchoHub.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests.csproj create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisDependencyInjectionExtensionsTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisEndToEnd.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisHubLifetimeManagerTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisProtocolTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisServerFixture.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/SkipIfDockerNotPresentAttribute.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Startup.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/TestConnectionMultiplexer.cs diff --git a/SignalR.sln b/SignalR.sln index f53f22d4ca..7f925af865 100644 --- a/SignalR.sln +++ b/SignalR.sln @@ -89,7 +89,11 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Crankier", "benchmarkapps\C EndProject Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "benchmarkapps", "benchmarkapps", "{43F352F3-4E2B-4ED7-901B-36E6671251F5}" EndProject -Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "Microsoft.AspNetCore.SignalR.Specification.Tests", "src\Microsoft.AspNetCore.SignalR.Specification.Tests\Microsoft.AspNetCore.SignalR.Specification.Tests.csproj", "{2B03333F-3ACD-474C-862B-FA97D3BA03B5}" +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.Specification.Tests", "src\Microsoft.AspNetCore.SignalR.Specification.Tests\Microsoft.AspNetCore.SignalR.Specification.Tests.csproj", "{2B03333F-3ACD-474C-862B-FA97D3BA03B5}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.StackExchangeRedis", "src\Microsoft.AspNetCore.SignalR.StackExchangeRedis\Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj", "{D1334F29-5C19-4C7B-B62D-0A2F23AFB31C}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests", "test\Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests\Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests.csproj", "{A5006087-81B0-4C62-B847-50ED5C37069D}" EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution @@ -213,6 +217,14 @@ Global {2B03333F-3ACD-474C-862B-FA97D3BA03B5}.Debug|Any CPU.Build.0 = Debug|Any CPU {2B03333F-3ACD-474C-862B-FA97D3BA03B5}.Release|Any CPU.ActiveCfg = Release|Any CPU {2B03333F-3ACD-474C-862B-FA97D3BA03B5}.Release|Any CPU.Build.0 = Release|Any CPU + {D1334F29-5C19-4C7B-B62D-0A2F23AFB31C}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {D1334F29-5C19-4C7B-B62D-0A2F23AFB31C}.Debug|Any CPU.Build.0 = Debug|Any CPU + {D1334F29-5C19-4C7B-B62D-0A2F23AFB31C}.Release|Any CPU.ActiveCfg = Release|Any CPU + {D1334F29-5C19-4C7B-B62D-0A2F23AFB31C}.Release|Any CPU.Build.0 = Release|Any CPU + {A5006087-81B0-4C62-B847-50ED5C37069D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {A5006087-81B0-4C62-B847-50ED5C37069D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {A5006087-81B0-4C62-B847-50ED5C37069D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {A5006087-81B0-4C62-B847-50ED5C37069D}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -247,6 +259,8 @@ Global {8C75AC94-C980-4FE1-9F79-6CED3C8665CE} = {43F352F3-4E2B-4ED7-901B-36E6671251F5} {8D3E3E7D-452B-44F4-86CA-111003EA11ED} = {43F352F3-4E2B-4ED7-901B-36E6671251F5} {2B03333F-3ACD-474C-862B-FA97D3BA03B5} = {DA69F624-5398-4884-87E4-B816698CDE65} + {D1334F29-5C19-4C7B-B62D-0A2F23AFB31C} = {DA69F624-5398-4884-87E4-B816698CDE65} + {A5006087-81B0-4C62-B847-50ED5C37069D} = {6A35B453-52EC-48AF-89CA-D4A69800F131} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {7945A4E4-ACDB-4F6E-95CA-6AC6E7C2CD59} diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj index d391a18436..0143f5ffae 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj @@ -17,8 +17,9 @@ - + + diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs index 3e333ddd74..852fbdfa3c 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs @@ -10,7 +10,7 @@ using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; -using Microsoft.AspNetCore.SignalR.Redis; +using Microsoft.AspNetCore.SignalR.StackExchangeRedis; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; @@ -34,7 +34,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Params(2, 20)] public int ProtocolCount { get; set; } - [GlobalSetup] + // Re-enable micro-benchmark when https://github.com/aspnet/SignalR/issues/3088 is fixed + // [GlobalSetup] public void GlobalSetup() { var server = new TestRedisServer(); @@ -90,7 +91,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _users.Add("EvenUser"); _users.Add("OddUser"); - _args = new object[] {"Foo"}; + _args = new object[] { "Foo" }; } private IEnumerable GenerateProtocols(int protocolCount) @@ -111,55 +112,55 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks } } - [Benchmark] + //[Benchmark] public async Task SendAll() { await _manager1.SendAllAsync("Test", _args); } - [Benchmark] + //[Benchmark] public async Task SendGroup() { await _manager1.SendGroupAsync("Everyone", "Test", _args); } - [Benchmark] + //[Benchmark] public async Task SendUser() { await _manager1.SendUserAsync("EvenUser", "Test", _args); } - [Benchmark] + //[Benchmark] public async Task SendConnection() { await _manager1.SendConnectionAsync(_clients[0].Connection.ConnectionId, "Test", _args); } - [Benchmark] + //[Benchmark] public async Task SendConnections() { await _manager1.SendConnectionsAsync(_sendIds, "Test", _args); } - [Benchmark] + //[Benchmark] public async Task SendAllExcept() { await _manager1.SendAllExceptAsync("Test", _args, _excludedConnectionIds); } - [Benchmark] + //[Benchmark] public async Task SendGroupExcept() { await _manager1.SendGroupExceptAsync("Everyone", "Test", _args, _excludedConnectionIds); } - [Benchmark] + //[Benchmark] public async Task SendGroups() { await _manager1.SendGroupsAsync(_groups, "Test", _args); } - [Benchmark] + //[Benchmark] public async Task SendUsers() { await _manager1.SendUsersAsync(_users, "Test", _args); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs index 3008ed999c..f5e02e489b 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs @@ -7,7 +7,7 @@ using System.Collections.Generic; using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Protocol; -using Microsoft.AspNetCore.SignalR.Redis.Internal; +using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { diff --git a/build/dependencies.props b/build/dependencies.props index 826da88432..13a42ca540 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -62,6 +62,7 @@ 2.0.3 11.0.2 1.2.6 + 2.0.513 4.5.0 4.5.0 4.5.1 diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MessagePackUtil.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MessagePackUtil.cs index d190bb74e8..b824d90394 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MessagePackUtil.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MessagePackUtil.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; using System.Diagnostics; using System.Runtime.InteropServices; diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs index a2ef82f373..3759da98ae 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisGroupCommand.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + namespace Microsoft.AspNetCore.SignalR.Redis.Internal { public readonly struct RedisGroupCommand diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs index e9cedbd5b0..a1a8a3ee07 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisInvocation.cs @@ -1,5 +1,7 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System.Collections.Generic; -using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.SignalR.Redis.Internal diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs index 6d3c51659b..6eaeb2ee79 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs @@ -8,7 +8,6 @@ using System.IO; using System.Runtime.InteropServices; using MessagePack; using Microsoft.AspNetCore.Internal; -using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.SignalR.Redis.Internal diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/AckHandler.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/AckHandler.cs new file mode 100644 index 0000000000..863fcdcb53 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/AckHandler.cs @@ -0,0 +1,117 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + internal class AckHandler : IDisposable + { + private readonly ConcurrentDictionary _acks = new ConcurrentDictionary(); + private readonly Timer _timer; + private readonly TimeSpan _ackThreshold = TimeSpan.FromSeconds(30); + private readonly TimeSpan _ackInterval = TimeSpan.FromSeconds(5); + private readonly object _lock = new object(); + private bool _disposed; + + public AckHandler() + { + // Don't capture the current ExecutionContext and its AsyncLocals onto the timer + bool restoreFlow = false; + try + { + if (!ExecutionContext.IsFlowSuppressed()) + { + ExecutionContext.SuppressFlow(); + restoreFlow = true; + } + + _timer = new Timer(state => ((AckHandler)state).CheckAcks(), state: this, dueTime: _ackInterval, period: _ackInterval); + } + finally + { + // Restore the current ExecutionContext + if (restoreFlow) + { + ExecutionContext.RestoreFlow(); + } + } + } + + public Task CreateAck(int id) + { + lock (_lock) + { + if (_disposed) + { + return Task.CompletedTask; + } + + return _acks.GetOrAdd(id, _ => new AckInfo()).Tcs.Task; + } + } + + public void TriggerAck(int id) + { + if (_acks.TryRemove(id, out var ack)) + { + ack.Tcs.TrySetResult(null); + } + } + + private void CheckAcks() + { + if (_disposed) + { + return; + } + + var utcNow = DateTime.UtcNow; + + foreach (var pair in _acks) + { + var elapsed = utcNow - pair.Value.Created; + if (elapsed > _ackThreshold) + { + if (_acks.TryRemove(pair.Key, out var ack)) + { + ack.Tcs.TrySetCanceled(); + } + } + } + } + + public void Dispose() + { + lock (_lock) + { + _disposed = true; + + _timer.Dispose(); + + foreach (var pair in _acks) + { + if (_acks.TryRemove(pair.Key, out var ack)) + { + ack.Tcs.TrySetCanceled(); + } + } + } + } + + private class AckInfo + { + public TaskCompletionSource Tcs { get; private set; } + public DateTime Created { get; private set; } + + public AckInfo() + { + Created = DateTime.UtcNow; + Tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/GroupAction.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/GroupAction.cs new file mode 100644 index 0000000000..e3aae4c006 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/GroupAction.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + // The size of the enum is defined by the protocol. Do not change it. If you need more than 255 items, + // add an additional enum. + public enum GroupAction : byte + { + // These numbers are used by the protocol, do not change them and always use explicit assignment + // when adding new items to this enum. 0 is intentionally omitted + Add = 1, + Remove = 2, + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/MessagePackUtil.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/MessagePackUtil.cs new file mode 100644 index 0000000000..7780bca988 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/MessagePackUtil.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using MessagePack; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + internal static class MessagePackUtil + { + public static int ReadArrayHeader(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadArrayHeader(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static int ReadMapHeader(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadMapHeader(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static string ReadString(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadString(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static byte[] ReadBytes(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadBytes(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static int ReadInt32(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadInt32(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static byte ReadByte(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadByte(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + private static ArraySegment GetArray(ReadOnlyMemory data) + { + var isArray = MemoryMarshal.TryGetArray(data, out var array); + Debug.Assert(isArray); + return array; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisChannels.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisChannels.cs new file mode 100644 index 0000000000..f377392bb1 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisChannels.cs @@ -0,0 +1,75 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + internal class RedisChannels + { + private readonly string _prefix; + + /// + /// Gets the name of the channel for sending to all connections. + /// + /// + /// The payload on this channel is objects containing + /// invocations to be sent to all connections + /// + public string All { get; } + + /// + /// Gets the name of the internal channel for group management messages. + /// + public string GroupManagement { get; } + + public RedisChannels(string prefix) + { + _prefix = prefix; + + All = prefix + ":all"; + GroupManagement = prefix + ":internal:groups"; + } + + /// + /// Gets the name of the channel for sending a message to a specific connection. + /// + /// The ID of the connection to get the channel for. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string Connection(string connectionId) + { + return _prefix + ":connection:" + connectionId; + } + + /// + /// Gets the name of the channel for sending a message to a named group of connections. + /// + /// The name of the group to get the channel for. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string Group(string groupName) + { + return _prefix + ":group:" + groupName; + } + + /// + /// Gets the name of the channel for sending a message to all collections associated with a user. + /// + /// The ID of the user to get the channel for. + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string User(string userId) + { + return _prefix + ":user:" + userId; + } + + /// + /// Gets the name of the acknowledgement channel for the specified server. + /// + /// The name of the server to get the acknowledgement channel for. + /// + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public string Ack(string serverName) + { + return _prefix + ":internal:ack:" + serverName; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisGroupCommand.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisGroupCommand.cs new file mode 100644 index 0000000000..1cb155d4aa --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisGroupCommand.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + public readonly struct RedisGroupCommand + { + /// + /// Gets the ID of the group command. + /// + public int Id { get; } + + /// + /// Gets the name of the server that sent the command. + /// + public string ServerName { get; } + + /// + /// Gets the action to be performed on the group. + /// + public GroupAction Action { get; } + + /// + /// Gets the group on which the action is performed. + /// + public string GroupName { get; } + + /// + /// Gets the ID of the connection to be added or removed from the group. + /// + public string ConnectionId { get; } + + public RedisGroupCommand(int id, string serverName, GroupAction action, string groupName, string connectionId) + { + Id = id; + ServerName = serverName; + Action = action; + GroupName = groupName; + ConnectionId = connectionId; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisInvocation.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisInvocation.cs new file mode 100644 index 0000000000..aae0e88e59 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisInvocation.cs @@ -0,0 +1,35 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + public readonly struct RedisInvocation + { + /// + /// Gets a list of connections that should be excluded from this invocation. + /// May be null to indicate that no connections are to be excluded. + /// + public IReadOnlyList ExcludedConnectionIds { get; } + + /// + /// Gets the message serialization cache containing serialized payloads for the message. + /// + public SerializedHubMessage Message { get; } + + public RedisInvocation(SerializedHubMessage message, IReadOnlyList excludedConnectionIds) + { + Message = message; + ExcludedConnectionIds = excludedConnectionIds; + } + + public static RedisInvocation Create(string target, object[] arguments, IReadOnlyList excludedConnectionIds = null) + { + return new RedisInvocation( + new SerializedHubMessage(new InvocationMessage(target, null, arguments)), + excludedConnectionIds); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisLog.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisLog.cs new file mode 100644 index 0000000000..bd8d228ee3 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisLog.cs @@ -0,0 +1,119 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using Microsoft.Extensions.Logging; +using StackExchange.Redis; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + // We don't want to use our nested static class here because RedisHubLifetimeManager is generic. + // We'd end up creating separate instances of all the LoggerMessage.Define values for each Hub. + internal static class RedisLog + { + private static readonly Action _connectingToEndpoints = + LoggerMessage.Define(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}. Using Server Name: {ServerName}"); + + private static readonly Action _connected = + LoggerMessage.Define(LogLevel.Information, new EventId(2, "Connected"), "Connected to Redis."); + + private static readonly Action _subscribing = + LoggerMessage.Define(LogLevel.Trace, new EventId(3, "Subscribing"), "Subscribing to channel: {Channel}."); + + private static readonly Action _receivedFromChannel = + LoggerMessage.Define(LogLevel.Trace, new EventId(4, "ReceivedFromChannel"), "Received message from Redis channel {Channel}."); + + private static readonly Action _publishToChannel = + LoggerMessage.Define(LogLevel.Trace, new EventId(5, "PublishToChannel"), "Publishing message to Redis channel {Channel}."); + + private static readonly Action _unsubscribe = + LoggerMessage.Define(LogLevel.Trace, new EventId(6, "Unsubscribe"), "Unsubscribing from channel: {Channel}."); + + private static readonly Action _notConnected = + LoggerMessage.Define(LogLevel.Error, new EventId(7, "Connected"), "Not connected to Redis."); + + private static readonly Action _connectionRestored = + LoggerMessage.Define(LogLevel.Information, new EventId(8, "ConnectionRestored"), "Connection to Redis restored."); + + private static readonly Action _connectionFailed = + LoggerMessage.Define(LogLevel.Error, new EventId(9, "ConnectionFailed"), "Connection to Redis failed."); + + private static readonly Action _failedWritingMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(10, "FailedWritingMessage"), "Failed writing message."); + + private static readonly Action _internalMessageFailed = + LoggerMessage.Define(LogLevel.Warning, new EventId(11, "InternalMessageFailed"), "Error processing message for internal server message."); + + public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endpoints, string serverName) + { + if (logger.IsEnabled(LogLevel.Information)) + { + if (endpoints.Count > 0) + { + _connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), serverName, null); + } + } + } + + public static void Connected(ILogger logger) + { + _connected(logger, null); + } + + public static void Subscribing(ILogger logger, string channelName) + { + _subscribing(logger, channelName, null); + } + + public static void ReceivedFromChannel(ILogger logger, string channelName) + { + _receivedFromChannel(logger, channelName, null); + } + + public static void PublishToChannel(ILogger logger, string channelName) + { + _publishToChannel(logger, channelName, null); + } + + public static void Unsubscribe(ILogger logger, string channelName) + { + _unsubscribe(logger, channelName, null); + } + + public static void NotConnected(ILogger logger) + { + _notConnected(logger, null); + } + + public static void ConnectionRestored(ILogger logger) + { + _connectionRestored(logger, null); + } + + public static void ConnectionFailed(ILogger logger, Exception exception) + { + _connectionFailed(logger, exception); + } + + public static void FailedWritingMessage(ILogger logger, Exception exception) + { + _failedWritingMessage(logger, exception); + } + + public static void InternalMessageFailed(ILogger logger, Exception exception) + { + _internalMessageFailed(logger, exception); + } + + // This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer + public static void ConnectionMultiplexerMessage(ILogger logger, string message) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + // We tag it with EventId 100 though so it can be pulled out of logs easily. + logger.LogDebug(new EventId(100, "RedisConnectionLog"), message); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisProtocol.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisProtocol.cs new file mode 100644 index 0000000000..5185b946c9 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisProtocol.cs @@ -0,0 +1,208 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using MessagePack; +using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + public class RedisProtocol + { + private readonly IReadOnlyList _protocols; + + public RedisProtocol(IReadOnlyList protocols) + { + _protocols = protocols; + } + + // The Redis Protocol: + // * The message type is known in advance because messages are sent to different channels based on type + // * Invocations are sent to the All, Group, Connection and User channels + // * Group Commands are sent to the GroupManagement channel + // * Acks are sent to the Acknowledgement channel. + // * See the Write[type] methods for a description of the protocol for each in-depth. + // * The "Variable length integer" is the length-prefixing format used by BinaryReader/BinaryWriter: + // * https://docs.microsoft.com/en-us/dotnet/api/system.io.binarywriter.write?view=netstandard-2.0 + // * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter: + // * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8. + + public byte[] WriteInvocation(string methodName, object[] args) => + WriteInvocation(methodName, args, excludedConnectionIds: null); + + public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList excludedConnectionIds) + { + // Written as a MessagePack 'arr' containing at least these items: + // * A MessagePack 'arr' of 'str's representing the excluded ids + // * [The output of WriteSerializedHubMessage, which is an 'arr'] + // Any additional items are discarded. + + var writer = MemoryBufferWriter.Get(); + + try + { + MessagePackBinary.WriteArrayHeader(writer, 2); + if (excludedConnectionIds != null && excludedConnectionIds.Count > 0) + { + MessagePackBinary.WriteArrayHeader(writer, excludedConnectionIds.Count); + foreach (var id in excludedConnectionIds) + { + MessagePackBinary.WriteString(writer, id); + } + } + else + { + MessagePackBinary.WriteArrayHeader(writer, 0); + } + + WriteSerializedHubMessage(writer, + new SerializedHubMessage(new InvocationMessage(methodName, args))); + return writer.ToArray(); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + + public byte[] WriteGroupCommand(RedisGroupCommand command) + { + // Written as a MessagePack 'arr' containing at least these items: + // * An 'int': the Id of the command + // * A 'str': The server name + // * An 'int': The action (likely less than 0x7F and thus a single-byte fixnum) + // * A 'str': The group name + // * A 'str': The connection Id + // Any additional items are discarded. + + var writer = MemoryBufferWriter.Get(); + try + { + MessagePackBinary.WriteArrayHeader(writer, 5); + MessagePackBinary.WriteInt32(writer, command.Id); + MessagePackBinary.WriteString(writer, command.ServerName); + MessagePackBinary.WriteByte(writer, (byte)command.Action); + MessagePackBinary.WriteString(writer, command.GroupName); + MessagePackBinary.WriteString(writer, command.ConnectionId); + + return writer.ToArray(); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + + public byte[] WriteAck(int messageId) + { + // Written as a MessagePack 'arr' containing at least these items: + // * An 'int': The Id of the command being acknowledged. + // Any additional items are discarded. + + var writer = MemoryBufferWriter.Get(); + try + { + MessagePackBinary.WriteArrayHeader(writer, 1); + MessagePackBinary.WriteInt32(writer, messageId); + + return writer.ToArray(); + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + + public RedisInvocation ReadInvocation(ReadOnlyMemory data) + { + // See WriteInvocation for the format + ValidateArraySize(ref data, 2, "Invocation"); + + // Read excluded Ids + IReadOnlyList excludedConnectionIds = null; + var idCount = MessagePackUtil.ReadArrayHeader(ref data); + if (idCount > 0) + { + var ids = new string[idCount]; + for (var i = 0; i < idCount; i++) + { + ids[i] = MessagePackUtil.ReadString(ref data); + } + + excludedConnectionIds = ids; + } + + // Read payload + var message = ReadSerializedHubMessage(ref data); + return new RedisInvocation(message, excludedConnectionIds); + } + + public RedisGroupCommand ReadGroupCommand(ReadOnlyMemory data) + { + // See WriteGroupCommand for format. + ValidateArraySize(ref data, 5, "GroupCommand"); + + var id = MessagePackUtil.ReadInt32(ref data); + var serverName = MessagePackUtil.ReadString(ref data); + var action = (GroupAction)MessagePackUtil.ReadByte(ref data); + var groupName = MessagePackUtil.ReadString(ref data); + var connectionId = MessagePackUtil.ReadString(ref data); + + return new RedisGroupCommand(id, serverName, action, groupName, connectionId); + } + + public int ReadAck(ReadOnlyMemory data) + { + // See WriteAck for format + ValidateArraySize(ref data, 1, "Ack"); + return MessagePackUtil.ReadInt32(ref data); + } + + private void WriteSerializedHubMessage(Stream stream, SerializedHubMessage message) + { + // Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str') + // and the values are the serialized blob (as a MessagePack 'bin'). + + MessagePackBinary.WriteMapHeader(stream, _protocols.Count); + + foreach (var protocol in _protocols) + { + MessagePackBinary.WriteString(stream, protocol.Name); + + var serialized = message.GetSerializedMessage(protocol); + var isArray = MemoryMarshal.TryGetArray(serialized, out var array); + Debug.Assert(isArray); + MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count); + } + } + + public static SerializedHubMessage ReadSerializedHubMessage(ref ReadOnlyMemory data) + { + var count = MessagePackUtil.ReadMapHeader(ref data); + var serializations = new SerializedMessage[count]; + for (var i = 0; i < count; i++) + { + var protocol = MessagePackUtil.ReadString(ref data); + var serialized = MessagePackUtil.ReadBytes(ref data); + serializations[i] = new SerializedMessage(protocol, serialized); + } + + return new SerializedHubMessage(serializations); + } + + private static void ValidateArraySize(ref ReadOnlyMemory data, int expectedLength, string messageType) + { + var length = MessagePackUtil.ReadArrayHeader(ref data); + + if (length < expectedLength) + { + throw new InvalidDataException($"Insufficient items in {messageType} array."); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisSubscriptionManager.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisSubscriptionManager.cs new file mode 100644 index 0000000000..18863dbce3 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Internal/RedisSubscriptionManager.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal +{ + internal class RedisSubscriptionManager + { + private readonly ConcurrentDictionary _subscriptions = new ConcurrentDictionary(StringComparer.Ordinal); + private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1); + + public async Task AddSubscriptionAsync(string id, HubConnectionContext connection, Func subscribeMethod) + { + await _lock.WaitAsync(); + + try + { + var subscription = _subscriptions.GetOrAdd(id, _ => new HubConnectionStore()); + + subscription.Add(connection); + + // Subscribe once + if (subscription.Count == 1) + { + await subscribeMethod(id, subscription); + } + } + finally + { + _lock.Release(); + } + } + + public async Task RemoveSubscriptionAsync(string id, HubConnectionContext connection, Func unsubscribeMethod) + { + await _lock.WaitAsync(); + + try + { + if (!_subscriptions.TryGetValue(id, out var subscription)) + { + return; + } + + subscription.Remove(connection); + + if (subscription.Count == 0) + { + _subscriptions.TryRemove(id, out _); + await unsubscribeMethod(id); + } + } + finally + { + _lock.Release(); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj new file mode 100644 index 0000000000..f1fabc764b --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/Microsoft.AspNetCore.SignalR.StackExchangeRedis.csproj @@ -0,0 +1,23 @@ + + + + Provides scale-out support for ASP.NET Core SignalR using a Redis server and the StackExchange.Redis client. + netstandard2.0 + + + + + + + + + + + + + + + + + + diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisDependencyInjectionExtensions.cs new file mode 100644 index 0000000000..6be7e6ff65 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisDependencyInjectionExtensions.cs @@ -0,0 +1,69 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.StackExchangeRedis; +using StackExchange.Redis; + +namespace Microsoft.Extensions.DependencyInjection +{ + /// + /// Extension methods for configuring Redis-based scale-out for a SignalR Server in an . + /// + public static class StackExchangeRedisDependencyInjectionExtensions + { + /// + /// Adds scale-out to a , using a shared Redis server. + /// + /// The . + /// The same instance of the for chaining. + public static ISignalRServerBuilder AddStackExchangeRedis(this ISignalRServerBuilder signalrBuilder) + { + return AddStackExchangeRedis(signalrBuilder, o => { }); + } + + /// + /// Adds scale-out to a , using a shared Redis server. + /// + /// The . + /// The connection string used to connect to the Redis server. + /// The same instance of the for chaining. + public static ISignalRServerBuilder AddStackExchangeRedis(this ISignalRServerBuilder signalrBuilder, string redisConnectionString) + { + return AddStackExchangeRedis(signalrBuilder, o => + { + o.Configuration = ConfigurationOptions.Parse(redisConnectionString); + }); + } + + /// + /// Adds scale-out to a , using a shared Redis server. + /// + /// The . + /// A callback to configure the Redis options. + /// The same instance of the for chaining. + public static ISignalRServerBuilder AddStackExchangeRedis(this ISignalRServerBuilder signalrBuilder, Action configure) + { + signalrBuilder.Services.Configure(configure); + signalrBuilder.Services.AddSingleton(typeof(HubLifetimeManager<>), typeof(RedisHubLifetimeManager<>)); + return signalrBuilder; + } + + /// + /// Adds scale-out to a , using a shared Redis server. + /// + /// The . + /// The connection string used to connect to the Redis server. + /// A callback to configure the Redis options. + /// The same instance of the for chaining. + public static ISignalRServerBuilder AddStackExchangeRedis(this ISignalRServerBuilder signalrBuilder, string redisConnectionString, Action configure) + { + return AddStackExchangeRedis(signalrBuilder, o => + { + o.Configuration = ConfigurationOptions.Parse(redisConnectionString); + configure(o); + }); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisHubLifetimeManager.cs new file mode 100644 index 0000000000..17b462bfd0 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisHubLifetimeManager.cs @@ -0,0 +1,593 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using StackExchange.Redis; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis +{ + public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable where THub : Hub + { + private readonly HubConnectionStore _connections = new HubConnectionStore(); + private readonly RedisSubscriptionManager _groups = new RedisSubscriptionManager(); + private readonly RedisSubscriptionManager _users = new RedisSubscriptionManager(); + private IConnectionMultiplexer _redisServerConnection; + private ISubscriber _bus; + private readonly ILogger _logger; + private readonly RedisOptions _options; + private readonly RedisChannels _channels; + private readonly string _serverName = GenerateServerName(); + private readonly RedisProtocol _protocol; + private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1); + + private readonly AckHandler _ackHandler; + private int _internalId; + + public RedisHubLifetimeManager(ILogger> logger, + IOptions options, + IHubProtocolResolver hubProtocolResolver) + { + _logger = logger; + _options = options.Value; + _ackHandler = new AckHandler(); + _channels = new RedisChannels(typeof(THub).FullName); + _protocol = new RedisProtocol(hubProtocolResolver.AllProtocols); + + RedisLog.ConnectingToEndpoints(_logger, options.Value.Configuration.EndPoints, _serverName); + _ = EnsureRedisServerConnection(); + } + + public override async Task OnConnectedAsync(HubConnectionContext connection) + { + await EnsureRedisServerConnection(); + var feature = new RedisFeature(); + connection.Features.Set(feature); + + var connectionTask = Task.CompletedTask; + var userTask = Task.CompletedTask; + + _connections.Add(connection); + + connectionTask = SubscribeToConnection(connection); + + if (!string.IsNullOrEmpty(connection.UserIdentifier)) + { + userTask = SubscribeToUser(connection); + } + + await Task.WhenAll(connectionTask, userTask); + } + + public override Task OnDisconnectedAsync(HubConnectionContext connection) + { + _connections.Remove(connection); + + var tasks = new List(); + + var connectionChannel = _channels.Connection(connection.ConnectionId); + RedisLog.Unsubscribe(_logger, connectionChannel); + tasks.Add(_bus.UnsubscribeAsync(connectionChannel)); + + var feature = connection.Features.Get(); + var groupNames = feature.Groups; + + if (groupNames != null) + { + // Copy the groups to an array here because they get removed from this collection + // in RemoveFromGroupAsync + foreach (var group in groupNames.ToArray()) + { + // Use RemoveGroupAsyncCore because the connection is local and we don't want to + // accidentally go to other servers with our remove request. + tasks.Add(RemoveGroupAsyncCore(connection, group)); + } + } + + if (!string.IsNullOrEmpty(connection.UserIdentifier)) + { + tasks.Add(RemoveUserAsync(connection)); + } + + return Task.WhenAll(tasks); + } + + public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) + { + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.All, message); + } + + public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + { + var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds); + return PublishAsync(_channels.All, message); + } + + public override Task SendConnectionAsync(string connectionId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + // If the connection is local we can skip sending the message through the bus since we require sticky connections. + // This also saves serializing and deserializing the message! + var connection = _connections[connectionId]; + if (connection != null) + { + return connection.WriteAsync(new InvocationMessage(methodName, args)).AsTask(); + } + + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.Connection(connectionId), message); + } + + public override Task SendGroupAsync(string groupName, string methodName, object[] args, CancellationToken cancellationToken = default) + { + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.Group(groupName), message); + } + + public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) + { + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var message = _protocol.WriteInvocation(methodName, args, excludedConnectionIds); + return PublishAsync(_channels.Group(groupName), message); + } + + public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) + { + var message = _protocol.WriteInvocation(methodName, args); + return PublishAsync(_channels.User(userId), message); + } + + public override Task AddToGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) + { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var connection = _connections[connectionId]; + if (connection != null) + { + // short circuit if connection is on this server + return AddGroupAsyncCore(connection, groupName); + } + + return SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Add); + } + + public override Task RemoveFromGroupAsync(string connectionId, string groupName, CancellationToken cancellationToken = default) + { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + + var connection = _connections[connectionId]; + if (connection != null) + { + // short circuit if connection is on this server + return RemoveGroupAsyncCore(connection, groupName); + } + + return SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove); + } + + public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) + { + if (connectionIds == null) + { + throw new ArgumentNullException(nameof(connectionIds)); + } + + var publishTasks = new List(connectionIds.Count); + var payload = _protocol.WriteInvocation(methodName, args); + + foreach (var connectionId in connectionIds) + { + publishTasks.Add(PublishAsync(_channels.Connection(connectionId), payload)); + } + + return Task.WhenAll(publishTasks); + } + + public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args, CancellationToken cancellationToken = default) + { + if (groupNames == null) + { + throw new ArgumentNullException(nameof(groupNames)); + } + var publishTasks = new List(groupNames.Count); + var payload = _protocol.WriteInvocation(methodName, args); + + foreach (var groupName in groupNames) + { + if (!string.IsNullOrEmpty(groupName)) + { + publishTasks.Add(PublishAsync(_channels.Group(groupName), payload)); + } + } + + return Task.WhenAll(publishTasks); + } + + public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) + { + if (userIds.Count > 0) + { + var payload = _protocol.WriteInvocation(methodName, args); + var publishTasks = new List(userIds.Count); + foreach (var userId in userIds) + { + if (!string.IsNullOrEmpty(userId)) + { + publishTasks.Add(PublishAsync(_channels.User(userId), payload)); + } + } + + return Task.WhenAll(publishTasks); + } + + return Task.CompletedTask; + } + + private async Task PublishAsync(string channel, byte[] payload) + { + await EnsureRedisServerConnection(); + RedisLog.PublishToChannel(_logger, channel); + await _bus.PublishAsync(channel, payload); + } + + private Task AddGroupAsyncCore(HubConnectionContext connection, string groupName) + { + var feature = connection.Features.Get(); + var groupNames = feature.Groups; + + lock (groupNames) + { + // Connection already in group + if (!groupNames.Add(groupName)) + { + return Task.CompletedTask; + } + } + + var groupChannel = _channels.Group(groupName); + return _groups.AddSubscriptionAsync(groupChannel, connection, SubscribeToGroupAsync); + } + + /// + /// This takes because we want to remove the connection from the + /// _connections list in OnDisconnectedAsync and still be able to remove groups with this method. + /// + private async Task RemoveGroupAsyncCore(HubConnectionContext connection, string groupName) + { + var groupChannel = _channels.Group(groupName); + + await _groups.RemoveSubscriptionAsync(groupChannel, connection, channelName => + { + RedisLog.Unsubscribe(_logger, channelName); + return _bus.UnsubscribeAsync(channelName); + }); + + var feature = connection.Features.Get(); + var groupNames = feature.Groups; + if (groupNames != null) + { + lock (groupNames) + { + groupNames.Remove(groupName); + } + } + } + + private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action) + { + var id = Interlocked.Increment(ref _internalId); + var ack = _ackHandler.CreateAck(id); + // Send Add/Remove Group to other servers and wait for an ack or timeout + var message = _protocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId)); + await PublishAsync(_channels.GroupManagement, message); + + await ack; + } + + private Task RemoveUserAsync(HubConnectionContext connection) + { + var userChannel = _channels.User(connection.UserIdentifier); + + return _users.RemoveSubscriptionAsync(userChannel, connection, channelName => + { + RedisLog.Unsubscribe(_logger, channelName); + return _bus.UnsubscribeAsync(channelName); + }); + } + + public void Dispose() + { + _bus?.UnsubscribeAll(); + _redisServerConnection?.Dispose(); + _ackHandler.Dispose(); + } + + private async Task SubscribeToAll() + { + RedisLog.Subscribing(_logger, _channels.All); + var channel = await _bus.SubscribeAsync(_channels.All); + channel.OnMessage(async channelMessage => + { + try + { + RedisLog.ReceivedFromChannel(_logger, _channels.All); + + var invocation = _protocol.ReadInvocation((byte[])channelMessage.Message); + + var tasks = new List(_connections.Count); + + foreach (var connection in _connections) + { + if (invocation.ExcludedConnectionIds == null || !invocation.ExcludedConnectionIds.Contains(connection.ConnectionId)) + { + tasks.Add(connection.WriteAsync(invocation.Message).AsTask()); + } + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + RedisLog.FailedWritingMessage(_logger, ex); + } + }); + } + + private async Task SubscribeToGroupManagementChannel() + { + var channel = await _bus.SubscribeAsync(_channels.GroupManagement); + channel.OnMessage(async channelMessage => + { + try + { + var groupMessage = _protocol.ReadGroupCommand((byte[])channelMessage.Message); + + var connection = _connections[groupMessage.ConnectionId]; + if (connection == null) + { + // user not on this server + return; + } + + if (groupMessage.Action == GroupAction.Remove) + { + await RemoveGroupAsyncCore(connection, groupMessage.GroupName); + } + + if (groupMessage.Action == GroupAction.Add) + { + await AddGroupAsyncCore(connection, groupMessage.GroupName); + } + + // Send an ack to the server that sent the original command. + await PublishAsync(_channels.Ack(groupMessage.ServerName), _protocol.WriteAck(groupMessage.Id)); + } + catch (Exception ex) + { + RedisLog.InternalMessageFailed(_logger, ex); + } + }); + } + + private async Task SubscribeToAckChannel() + { + // Create server specific channel in order to send an ack to a single server + var channel = await _bus.SubscribeAsync(_channels.Ack(_serverName)); + channel.OnMessage(channelMessage => + { + var ackId = _protocol.ReadAck((byte[])channelMessage.Message); + + _ackHandler.TriggerAck(ackId); + }); + } + + private async Task SubscribeToConnection(HubConnectionContext connection) + { + var connectionChannel = _channels.Connection(connection.ConnectionId); + + RedisLog.Subscribing(_logger, connectionChannel); + var channel = await _bus.SubscribeAsync(connectionChannel); + channel.OnMessage(channelMessage => + { + var invocation = _protocol.ReadInvocation((byte[])channelMessage.Message); + return connection.WriteAsync(invocation.Message).AsTask(); + }); + } + + private Task SubscribeToUser(HubConnectionContext connection) + { + var userChannel = _channels.User(connection.UserIdentifier); + + return _users.AddSubscriptionAsync(userChannel, connection, async (channelName, subscriptions) => + { + RedisLog.Subscribing(_logger, channelName); + var channel = await _bus.SubscribeAsync(channelName); + channel.OnMessage(async channelMessage => + { + try + { + var invocation = _protocol.ReadInvocation((byte[])channelMessage.Message); + + var tasks = new List(); + foreach (var userConnection in subscriptions) + { + tasks.Add(userConnection.WriteAsync(invocation.Message).AsTask()); + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + RedisLog.FailedWritingMessage(_logger, ex); + } + }); + }); + } + + private async Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore groupConnections) + { + RedisLog.Subscribing(_logger, groupChannel); + var channel = await _bus.SubscribeAsync(groupChannel); + channel.OnMessage(async (channelMessage) => + { + try + { + var invocation = _protocol.ReadInvocation((byte[])channelMessage.Message); + + var tasks = new List(); + foreach (var groupConnection in groupConnections) + { + if (invocation.ExcludedConnectionIds?.Contains(groupConnection.ConnectionId) == true) + { + continue; + } + + tasks.Add(groupConnection.WriteAsync(invocation.Message).AsTask()); + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + RedisLog.FailedWritingMessage(_logger, ex); + } + }); + } + + private async Task EnsureRedisServerConnection() + { + if (_redisServerConnection == null) + { + await _connectionLock.WaitAsync(); + try + { + if (_redisServerConnection == null) + { + var writer = new LoggerTextWriter(_logger); + _redisServerConnection = await _options.ConnectAsync(writer); + _bus = _redisServerConnection.GetSubscriber(); + + _redisServerConnection.ConnectionRestored += (_, e) => + { + // We use the subscription connection type + // Ignore messages from the interactive connection (avoids duplicates) + if (e.ConnectionType == ConnectionType.Interactive) + { + return; + } + + RedisLog.ConnectionRestored(_logger); + }; + + _redisServerConnection.ConnectionFailed += (_, e) => + { + // We use the subscription connection type + // Ignore messages from the interactive connection (avoids duplicates) + if (e.ConnectionType == ConnectionType.Interactive) + { + return; + } + + RedisLog.ConnectionFailed(_logger, e.Exception); + }; + + if (_redisServerConnection.IsConnected) + { + RedisLog.Connected(_logger); + } + else + { + RedisLog.NotConnected(_logger); + } + + await SubscribeToAll(); + await SubscribeToGroupManagementChannel(); + await SubscribeToAckChannel(); + } + } + finally + { + _connectionLock.Release(); + } + } + } + + private static string GenerateServerName() + { + // Use the machine name for convenient diagnostics, but add a guid to make it unique. + // Example: MyServerName_02db60e5fab243b890a847fa5c4dcb29 + return $"{Environment.MachineName}_{Guid.NewGuid():N}"; + } + + private class LoggerTextWriter : TextWriter + { + private readonly ILogger _logger; + + public LoggerTextWriter(ILogger logger) + { + _logger = logger; + } + + public override Encoding Encoding => Encoding.UTF8; + + public override void Write(char value) + { + + } + + public override void WriteLine(string value) + { + RedisLog.ConnectionMultiplexerMessage(_logger, value); + } + } + + private interface IRedisFeature + { + HashSet Groups { get; } + } + + private class RedisFeature : IRedisFeature + { + public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisOptions.cs b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisOptions.cs new file mode 100644 index 0000000000..b34c7fb117 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.StackExchangeRedis/RedisOptions.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net; +using System.Threading.Tasks; +using StackExchange.Redis; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis +{ + /// + /// Options used to configure . + /// + public class RedisOptions + { + /// + /// Gets or sets configuration options exposed by StackExchange.Redis. + /// + public ConfigurationOptions Configuration { get; set; } = new ConfigurationOptions + { + // Enable reconnecting by default + AbortOnConnectFail = false + }; + + /// + /// Gets or sets the Redis connection factory. + /// + public Func> ConnectionFactory { get; set; } + + internal async Task ConnectAsync(TextWriter log) + { + // Factory is publically settable. Assigning to a local variable before null check for thread safety. + var factory = ConnectionFactory; + if (factory == null) + { + // REVIEW: Should we do this? + if (Configuration.EndPoints.Count == 0) + { + Configuration.EndPoints.Add(IPAddress.Loopback, 0); + Configuration.SetDefaultPorts(); + } + + return await ConnectionMultiplexer.ConnectAsync(Configuration, log); + } + + return await factory(log); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs index f60e557ae5..10f087ac6a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs @@ -15,8 +15,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests { private static readonly string _exeSuffix = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? ".exe" : string.Empty; - private static readonly string _dockerContainerName = "redisTestContainer"; - private static readonly string _dockerMonitorContainerName = _dockerContainerName + "Monitor"; + private static readonly string _dockerContainerName = "redisTestContainer-1x"; + private static readonly string _dockerMonitorContainerName = _dockerContainerName + "Monitor-1x"; private static readonly Lazy _instance = new Lazy(Create); public static Docker Default => _instance.Value; @@ -82,7 +82,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests // use static name 'redisTestContainer' so if the container doesn't get removed we don't keep adding more // use redis base docker image // 20 second timeout to allow redis image to be downloaded, should be a rare occurrence, only happening when a new version is released - RunProcessAndThrowIfFailed(_path, $"run --rm -p 6379:6379 --name {_dockerContainerName} -d redis", "redis", logger, TimeSpan.FromSeconds(20)); + RunProcessAndThrowIfFailed(_path, $"run --rm -p 6380:6379 --name {_dockerContainerName} -d redis", "redis", logger, TimeSpan.FromSeconds(20)); // inspect the redis docker image and extract the IPAddress. Necessary when running tests from inside a docker container, spinning up a new docker container for redis // outside the current container requires linking the networks (difficult to automate) or using the IP:Port combo @@ -90,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests output = output.Trim().Replace(Environment.NewLine, ""); // variable used by Startup.cs - Environment.SetEnvironmentVariable("REDIS_CONNECTION", $"{output}:6379"); + Environment.SetEnvironmentVariable("REDIS_CONNECTION-PREV", $"{output}:6379"); var (monitorProcess, monitorOutput) = RunProcess(_path, $"run -i --name {_dockerMonitorContainerName} --link {_dockerContainerName}:redis --rm redis redis-cli -h redis -p 6379", "redis monitor", logger); monitorProcess.StandardInput.WriteLine("MONITOR"); diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs index b3ab182d5d..89b960df71 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; using System.Buffers; using System.Collections.Generic; diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Startup.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Startup.cs index f760a4e869..e99631513e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Startup.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Startup.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests { // We start the servers before starting redis so we want to time them out ASAP options.Configuration.ConnectTimeout = 1; - options.Configuration.EndPoints.Add(Environment.GetEnvironmentVariable("REDIS_CONNECTION")); + options.Configuration.EndPoints.Add(Environment.GetEnvironmentVariable("REDIS_CONNECTION-PREV")); }); services.AddSingleton(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestConnectionMultiplexer.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/TestConnectionMultiplexer.cs similarity index 100% rename from test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestConnectionMultiplexer.cs rename to test/Microsoft.AspNetCore.SignalR.Redis.Tests/TestConnectionMultiplexer.cs diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Docker.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Docker.cs new file mode 100644 index 0000000000..d8722505c2 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Docker.cs @@ -0,0 +1,192 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Diagnostics; +using System.IO; +using System.Runtime.InteropServices; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + public class Docker + { + private static readonly string _exeSuffix = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? ".exe" : string.Empty; + + private static readonly string _dockerContainerName = "redisTestContainer"; + private static readonly string _dockerMonitorContainerName = _dockerContainerName + "Monitor"; + private static readonly Lazy _instance = new Lazy(Create); + + public static Docker Default => _instance.Value; + + private readonly string _path; + + public Docker(string path) + { + _path = path; + } + + private static Docker Create() + { + var location = GetDockerLocation(); + if (location == null) + { + return null; + } + + var docker = new Docker(location); + + docker.RunCommand("info --format '{{.OSType}}'", "docker info", out var output); + + if (!string.Equals(output.Trim('\'', '"', '\r', '\n', ' '), "linux")) + { + Console.WriteLine($"'docker info' output: {output}"); + return null; + } + + return docker; + } + + private static string GetDockerLocation() + { + // OSX + Docker + Redis don't play well together for some reason. We already have these tests covered on Linux and Windows + // So we are happy ignoring them on OSX + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + return null; + } + + foreach (var dir in Environment.GetEnvironmentVariable("PATH").Split(Path.PathSeparator)) + { + var candidate = Path.Combine(dir, "docker" + _exeSuffix); + if (File.Exists(candidate)) + { + return candidate; + } + } + + return null; + } + + public void Start(ILogger logger) + { + logger.LogInformation("Starting docker container"); + + // stop container if there is one, could be from a previous test run, ignore failures + RunProcessAndWait(_path, $"stop {_dockerMonitorContainerName}", "docker stop", logger, TimeSpan.FromSeconds(15), out var _); + RunProcessAndWait(_path, $"stop {_dockerContainerName}", "docker stop", logger, TimeSpan.FromSeconds(15), out var output); + + // create and run docker container, remove automatically when stopped, map 6379 from the container to 6379 localhost + // use static name 'redisTestContainer' so if the container doesn't get removed we don't keep adding more + // use redis base docker image + // 20 second timeout to allow redis image to be downloaded, should be a rare occurrence, only happening when a new version is released + RunProcessAndThrowIfFailed(_path, $"run --rm -p 6379:6379 --name {_dockerContainerName} -d redis", "redis", logger, TimeSpan.FromSeconds(20)); + + // inspect the redis docker image and extract the IPAddress. Necessary when running tests from inside a docker container, spinning up a new docker container for redis + // outside the current container requires linking the networks (difficult to automate) or using the IP:Port combo + RunProcessAndWait(_path, "inspect --format=\"{{range .NetworkSettings.Networks}}{{.IPAddress}}{{end}}\" " + _dockerContainerName, "docker ipaddress", logger, TimeSpan.FromSeconds(5), out output); + output = output.Trim().Replace(Environment.NewLine, ""); + + // variable used by Startup.cs + Environment.SetEnvironmentVariable("REDIS_CONNECTION", $"{output}:6379"); + + var (monitorProcess, monitorOutput) = RunProcess(_path, $"run -i --name {_dockerMonitorContainerName} --link {_dockerContainerName}:redis --rm redis redis-cli -h redis -p 6379", "redis monitor", logger); + monitorProcess.StandardInput.WriteLine("MONITOR"); + monitorProcess.StandardInput.Flush(); + } + + public void Stop(ILogger logger) + { + // Get logs from Redis container before stopping the container + RunProcessAndThrowIfFailed(_path, $"logs {_dockerContainerName}", "docker logs", logger, TimeSpan.FromSeconds(5)); + + logger.LogInformation("Stopping docker container"); + RunProcessAndWait(_path, $"stop {_dockerMonitorContainerName}", "docker stop", logger, TimeSpan.FromSeconds(15), out var _); + RunProcessAndWait(_path, $"stop {_dockerContainerName}", "docker stop", logger, TimeSpan.FromSeconds(15), out var _); + } + + public int RunCommand(string commandAndArguments, string prefix, out string output) => + RunCommand(commandAndArguments, prefix, NullLogger.Instance, out output); + + public int RunCommand(string commandAndArguments, string prefix, ILogger logger, out string output) + { + return RunProcessAndWait(_path, commandAndArguments, prefix, logger, TimeSpan.FromSeconds(5), out output); + } + + private static void RunProcessAndThrowIfFailed(string fileName, string arguments, string prefix, ILogger logger, TimeSpan timeout) + { + var exitCode = RunProcessAndWait(fileName, arguments, prefix, logger, timeout, out var output); + + if (exitCode != 0) + { + throw new Exception($"Command '{fileName} {arguments}' failed with exit code '{exitCode}'. Output:{Environment.NewLine}{output}"); + } + } + + private static int RunProcessAndWait(string fileName, string arguments, string prefix, ILogger logger, TimeSpan timeout, out string output) + { + var (process, lines) = RunProcess(fileName, arguments, prefix, logger); + + if (!process.WaitForExit((int)timeout.TotalMilliseconds)) + { + process.Close(); + logger.LogError("Closing process '{processName}' because it is running longer than the configured timeout.", fileName); + } + + // Need to WaitForExit without a timeout to guarantee the output stream has written everything + process.WaitForExit(); + + output = string.Join(Environment.NewLine, lines); + + return process.ExitCode; + } + + private static (Process, ConcurrentQueue) RunProcess(string fileName, string arguments, string prefix, ILogger logger) + { + var process = new Process + { + StartInfo = new ProcessStartInfo + { + FileName = fileName, + Arguments = arguments, + UseShellExecute = false, + RedirectStandardError = true, + RedirectStandardOutput = true, + RedirectStandardInput = true + }, + EnableRaisingEvents = true + }; + + var exitCode = 0; + var lines = new ConcurrentQueue(); + process.Exited += (_, __) => exitCode = process.ExitCode; + process.OutputDataReceived += (_, a) => + { + LogIfNotNull(logger.LogInformation, $"'{prefix}' stdout: {{0}}", a.Data); + lines.Enqueue(a.Data); + }; + process.ErrorDataReceived += (_, a) => + { + LogIfNotNull(logger.LogError, $"'{prefix}' stderr: {{0}}", a.Data); + lines.Enqueue(a.Data); + }; + + process.Start(); + + process.BeginErrorReadLine(); + process.BeginOutputReadLine(); + + return (process, lines); + } + + private static void LogIfNotNull(Action logger, string message, string data) + { + if (!string.IsNullOrEmpty(data)) + { + logger(message, new[] { data }); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/EchoHub.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/EchoHub.cs new file mode 100644 index 0000000000..bfde399d32 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/EchoHub.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + public class EchoHub : Hub + { + public string Echo(string message) + { + return message; + } + + public Task EchoGroup(string groupName, string message) + { + return Clients.Group(groupName).SendAsync("Echo", message); + } + + public Task EchoUser(string userName, string message) + { + return Clients.User(userName).SendAsync("Echo", message); + } + + public Task AddSelfToGroup(string groupName) + { + return Groups.AddToGroupAsync(Context.ConnectionId, groupName); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests.csproj b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests.csproj new file mode 100644 index 0000000000..ef017f8ba8 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests.csproj @@ -0,0 +1,27 @@ + + + + $(StandardTestTfms) + + + + + PreserveNewest + + + + + + + + + + + + + + + + + + diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisDependencyInjectionExtensionsTests.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisDependencyInjectionExtensionsTests.cs new file mode 100644 index 0000000000..14b4ca1026 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisDependencyInjectionExtensionsTests.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Net; +using System.Text; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + public class RedisDependencyInjectionExtensionsTests + { + // No need to go too deep with these tests, or we're just testing StackExchange.Redis again :). It's the one doing the parsing. + [Theory] + [InlineData("testredis.example.com", "testredis.example.com", 0, null, false)] + [InlineData("testredis.example.com:6380,ssl=True", "testredis.example.com", 6380, null, true)] + [InlineData("testredis.example.com:6380,password=hunter2,ssl=True", "testredis.example.com", 6380, "hunter2", true)] + public void AddRedisWithConnectionStringProperlyParsesOptions(string connectionString, string host, int port, string password, bool useSsl) + { + var services = new ServiceCollection(); + services.AddSignalR().AddStackExchangeRedis(connectionString); + var provider = services.BuildServiceProvider(); + + var options = provider.GetService>(); + Assert.NotNull(options.Value); + Assert.NotNull(options.Value.Configuration); + Assert.Equal(password, options.Value.Configuration.Password); + Assert.Collection(options.Value.Configuration.EndPoints, + endpoint => + { + var dnsEndpoint = Assert.IsType(endpoint); + Assert.Equal(host, dnsEndpoint.Host); + Assert.Equal(port, dnsEndpoint.Port); + }); + Assert.Equal(useSsl, options.Value.Configuration.Ssl); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisEndToEnd.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisEndToEnd.cs new file mode 100644 index 0000000000..50eb0dfb98 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisEndToEnd.cs @@ -0,0 +1,198 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Connections; +using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.AspNetCore.Testing.xunit; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + // Disable running server tests in parallel so server logs can accurately be captured per test + [CollectionDefinition(Name, DisableParallelization = true)] + public class RedisEndToEndTestsCollection : ICollectionFixture> + { + public const string Name = nameof(RedisEndToEndTestsCollection); + } + + [Collection(RedisEndToEndTestsCollection.Name)] + public class RedisEndToEndTests : VerifiableLoggedTest + { + private readonly RedisServerFixture _serverFixture; + + public RedisEndToEndTests(RedisServerFixture serverFixture, ITestOutputHelper output) : base(output) + { + if (serverFixture == null) + { + throw new ArgumentNullException(nameof(serverFixture)); + } + + _serverFixture = serverFixture; + } + + [ConditionalTheory] + [SkipIfDockerNotPresent] + [MemberData(nameof(TransportTypesAndProtocolTypes))] + public async Task HubConnectionCanSendAndReceiveMessages(HttpTransportType transportType, string protocolName) + { + using (StartVerifiableLog(out var loggerFactory, testName: + $"{nameof(HubConnectionCanSendAndReceiveMessages)}_{transportType.ToString()}_{protocolName}")) + { + var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); + + var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory); + + await connection.StartAsync().OrTimeout(); + var str = await connection.InvokeAsync("Echo", "Hello, World!").OrTimeout(); + + Assert.Equal("Hello, World!", str); + + await connection.DisposeAsync().OrTimeout(); + } + } + + [ConditionalTheory] + [SkipIfDockerNotPresent] + [MemberData(nameof(TransportTypesAndProtocolTypes))] + public async Task HubConnectionCanSendAndReceiveGroupMessages(HttpTransportType transportType, string protocolName) + { + using (StartVerifiableLog(out var loggerFactory, testName: + $"{nameof(HubConnectionCanSendAndReceiveGroupMessages)}_{transportType.ToString()}_{protocolName}")) + { + var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); + + var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory); + var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory); + + var tcs = new TaskCompletionSource(); + connection.On("Echo", message => tcs.TrySetResult(message)); + var tcs2 = new TaskCompletionSource(); + secondConnection.On("Echo", message => tcs2.TrySetResult(message)); + + var groupName = $"TestGroup_{transportType}_{protocolName}_{Guid.NewGuid()}"; + + await secondConnection.StartAsync().OrTimeout(); + await connection.StartAsync().OrTimeout(); + await connection.InvokeAsync("AddSelfToGroup", groupName).OrTimeout(); + await secondConnection.InvokeAsync("AddSelfToGroup", groupName).OrTimeout(); + await connection.InvokeAsync("EchoGroup", groupName, "Hello, World!").OrTimeout(); + + Assert.Equal("Hello, World!", await tcs.Task.OrTimeout()); + Assert.Equal("Hello, World!", await tcs2.Task.OrTimeout()); + + await connection.DisposeAsync().OrTimeout(); + } + } + + [ConditionalTheory(Skip= "https://github.com/aspnet/SignalR/issues/3058")] + [SkipIfDockerNotPresent] + [MemberData(nameof(TransportTypesAndProtocolTypes))] + public async Task CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser(HttpTransportType transportType, string protocolName) + { + using (StartVerifiableLog(out var loggerFactory, testName: + $"{nameof(CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser)}_{transportType.ToString()}_{protocolName}")) + { + var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); + + var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA"); + var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA"); + + var tcs = new TaskCompletionSource(); + connection.On("Echo", message => tcs.TrySetResult(message)); + var tcs2 = new TaskCompletionSource(); + secondConnection.On("Echo", message => tcs2.TrySetResult(message)); + + await secondConnection.StartAsync().OrTimeout(); + await connection.StartAsync().OrTimeout(); + await connection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout(); + + Assert.Equal("Hello, World!", await tcs.Task.OrTimeout()); + Assert.Equal("Hello, World!", await tcs2.Task.OrTimeout()); + + await connection.DisposeAsync().OrTimeout(); + await secondConnection.DisposeAsync().OrTimeout(); + } + } + + [ConditionalTheory] + [SkipIfDockerNotPresent] + [MemberData(nameof(TransportTypesAndProtocolTypes))] + public async Task CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects(HttpTransportType transportType, string protocolName) + { + // Regression test: + // When multiple connections from the same user were connected and one left, it used to unsubscribe from the user channel + // Now we keep track of users connections and only unsubscribe when no users are listening + using (StartVerifiableLog(out var loggerFactory, testName: + $"{nameof(CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects)}_{transportType.ToString()}_{protocolName}")) + { + var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); + + var firstConnection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA"); + var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA"); + + var tcs = new TaskCompletionSource(); + firstConnection.On("Echo", message => tcs.TrySetResult(message)); + + await secondConnection.StartAsync().OrTimeout(); + await firstConnection.StartAsync().OrTimeout(); + await secondConnection.DisposeAsync().OrTimeout(); + await firstConnection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout(); + + Assert.Equal("Hello, World!", await tcs.Task.OrTimeout()); + + await firstConnection.DisposeAsync().OrTimeout(); + } + } + + private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null) + { + var hubConnectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(url, transportType, httpConnectionOptions => + { + if (!string.IsNullOrEmpty(userName)) + { + httpConnectionOptions.Headers["UserName"] = userName; + } + }); + + hubConnectionBuilder.Services.AddSingleton(protocol); + + return hubConnectionBuilder.Build(); + } + + private static IEnumerable TransportTypes() + { + if (TestHelpers.IsWebSocketsSupported()) + { + yield return HttpTransportType.WebSockets; + } + yield return HttpTransportType.ServerSentEvents; + yield return HttpTransportType.LongPolling; + } + + public static IEnumerable TransportTypesAndProtocolTypes + { + get + { + foreach (var transport in TransportTypes()) + { + yield return new object[] { transport, "json" }; + + if (transport != HttpTransportType.ServerSentEvents) + { + yield return new object[] { transport, "messagepack" }; + } + } + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisHubLifetimeManagerTests.cs new file mode 100644 index 0000000000..c9137975b5 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisHubLifetimeManagerTests.cs @@ -0,0 +1,84 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Microsoft.AspNetCore.SignalR.Specification.Tests; +using Newtonsoft.Json.Linq; +using Newtonsoft.Json.Serialization; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + // Add ScaleoutHubLifetimeManagerTests back after https://github.com/aspnet/SignalR/issues/3088 + public class RedisHubLifetimeManagerTests + { + public class TestObject + { + public string TestProperty { get; set; } + } + + private RedisHubLifetimeManager CreateLifetimeManager(TestRedisServer server, MessagePackHubProtocolOptions messagePackOptions = null, JsonHubProtocolOptions jsonOptions = null) + { + var options = new RedisOptions() { ConnectionFactory = async (t) => await Task.FromResult(new TestConnectionMultiplexer(server)) }; + messagePackOptions = messagePackOptions ?? new MessagePackHubProtocolOptions(); + jsonOptions = jsonOptions ?? new JsonHubProtocolOptions(); + return new RedisHubLifetimeManager( + NullLogger>.Instance, + Options.Create(options), + new DefaultHubProtocolResolver(new IHubProtocol[] + { + new JsonHubProtocol(Options.Create(jsonOptions)), + new MessagePackHubProtocol(Options.Create(messagePackOptions)), + }, NullLogger.Instance)); + } + + [Fact(Skip = "https://github.com/aspnet/SignalR/issues/3088")] + public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary() + { + var server = new TestRedisServer(); + + var messagePackOptions = new MessagePackHubProtocolOptions(); + + var jsonOptions = new JsonHubProtocolOptions(); + jsonOptions.PayloadSerializerSettings.ContractResolver = new CamelCasePropertyNamesContractResolver(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + // The sending manager has serializer settings + var manager1 = CreateLifetimeManager(server, messagePackOptions, jsonOptions); + + // The receiving one doesn't matter because of how we serialize! + var manager2 = CreateLifetimeManager(server); + + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager1.OnConnectedAsync(connection1).OrTimeout(); + await manager2.OnConnectedAsync(connection2).OrTimeout(); + + await manager1.SendAllAsync("Hello", new object[] { new TestObject { TestProperty = "Foo" } }); + + var message = Assert.IsType(await client2.ReadAsync().OrTimeout()); + Assert.Equal("Hello", message.Target); + Assert.Collection( + message.Arguments, + arg0 => + { + var dict = Assert.IsType(arg0); + Assert.Collection(dict.Properties(), + prop => + { + Assert.Equal("testProperty", prop.Name); + Assert.Equal("Foo", prop.Value.Value()); + }); + }); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisProtocolTests.cs new file mode 100644 index 0000000000..b8b3a0bca9 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisProtocolTests.cs @@ -0,0 +1,202 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; +using Microsoft.AspNetCore.SignalR.Tests; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + public class RedisProtocolTests + { + private static Dictionary> _ackTestData = new[] + { + CreateTestData("Zero", 0, 0x91, 0x00), + CreateTestData("Fixnum", 42, 0x91, 0x2A), + CreateTestData("Uint8", 180, 0x91, 0xCC, 0xB4), + CreateTestData("Uint16", 384, 0x91, 0xCD, 0x01, 0x80), + CreateTestData("Uint32", 70_000, 0x91, 0xCE, 0x00, 0x01, 0x11, 0x70), + }.ToDictionary(t => t.Name); + + public static IEnumerable AckTestData = _ackTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(AckTestData))] + public void ParseAck(string testName) + { + var testData = _ackTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var decoded = protocol.ReadAck(testData.Encoded); + + Assert.Equal(testData.Decoded, decoded); + } + + [Theory] + [MemberData(nameof(AckTestData))] + public void WriteAck(string testName) + { + var testData = _ackTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var encoded = protocol.WriteAck(testData.Decoded); + + Assert.Equal(testData.Encoded, encoded); + } + + private static Dictionary> _groupCommandTestData = new[] + { + CreateTestData("GroupAdd", new RedisGroupCommand(42, "S", GroupAction.Add, "G", "C" ), 0x95, 0x2A, 0xA1, (byte)'S', 0x01, 0xA1, (byte)'G', 0xA1, (byte)'C'), + CreateTestData("GroupRemove", new RedisGroupCommand(42, "S", GroupAction.Remove, "G", "C" ), 0x95, 0x2A, 0xA1, (byte)'S', 0x02, 0xA1, (byte)'G', 0xA1, (byte)'C'), + }.ToDictionary(t => t.Name); + + public static IEnumerable GroupCommandTestData = _groupCommandTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(GroupCommandTestData))] + public void ParseGroupCommand(string testName) + { + var testData = _groupCommandTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var decoded = protocol.ReadGroupCommand(testData.Encoded); + + Assert.Equal(testData.Decoded.Id, decoded.Id); + Assert.Equal(testData.Decoded.ServerName, decoded.ServerName); + Assert.Equal(testData.Decoded.Action, decoded.Action); + Assert.Equal(testData.Decoded.GroupName, decoded.GroupName); + Assert.Equal(testData.Decoded.ConnectionId, decoded.ConnectionId); + } + + [Theory] + [MemberData(nameof(GroupCommandTestData))] + public void WriteGroupCommand(string testName) + { + var testData = _groupCommandTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var encoded = protocol.WriteGroupCommand(testData.Decoded); + + Assert.Equal(testData.Encoded, encoded); + } + + // The actual invocation message doesn't matter + private static InvocationMessage _testMessage = new InvocationMessage("target", Array.Empty()); + + // We use a func so we are guaranteed to get a new SerializedHubMessage for each test + private static Dictionary>> _invocationTestData = new[] + { + CreateTestData>( + "NoExcludedIds", + () => new RedisInvocation(new SerializedHubMessage(_testMessage), null), + 0x92, + 0x90, + 0x82, + 0xA2, (byte)'p', (byte)'1', + 0xC4, 0x01, 0x2A, + 0xA2, (byte)'p', (byte)'2', + 0xC4, 0x01, 0x2A), + CreateTestData>( + "OneExcludedId", + () => new RedisInvocation(new SerializedHubMessage(_testMessage), new [] { "a" }), + 0x92, + 0x91, + 0xA1, (byte)'a', + 0x82, + 0xA2, (byte)'p', (byte)'1', + 0xC4, 0x01, 0x2A, + 0xA2, (byte)'p', (byte)'2', + 0xC4, 0x01, 0x2A), + CreateTestData>( + "ManyExcludedIds", + () => new RedisInvocation(new SerializedHubMessage(_testMessage), new [] { "a", "b", "c", "d", "e", "f" }), + 0x92, + 0x96, + 0xA1, (byte)'a', + 0xA1, (byte)'b', + 0xA1, (byte)'c', + 0xA1, (byte)'d', + 0xA1, (byte)'e', + 0xA1, (byte)'f', + 0x82, + 0xA2, (byte)'p', (byte)'1', + 0xC4, 0x01, 0x2A, + 0xA2, (byte)'p', (byte)'2', + 0xC4, 0x01, 0x2A), + }.ToDictionary(t => t.Name); + + public static IEnumerable InvocationTestData = _invocationTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(InvocationTestData))] + public void ParseInvocation(string testName) + { + var testData = _invocationTestData[testName]; + var hubProtocols = new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }; + var protocol = new RedisProtocol(hubProtocols); + + var expected = testData.Decoded(); + + var decoded = protocol.ReadInvocation(testData.Encoded); + + Assert.Equal(expected.ExcludedConnectionIds, decoded.ExcludedConnectionIds); + + // Verify the deserialized object has the necessary serialized forms + foreach (var hubProtocol in hubProtocols) + { + Assert.Equal( + expected.Message.GetSerializedMessage(hubProtocol).ToArray(), + decoded.Message.GetSerializedMessage(hubProtocol).ToArray()); + + var writtenMessages = hubProtocol.GetWrittenMessages(); + Assert.Collection(writtenMessages, + actualMessage => + { + var invocation = Assert.IsType(actualMessage); + Assert.Same(_testMessage.Target, invocation.Target); + Assert.Same(_testMessage.Arguments, invocation.Arguments); + }); + } + } + + [Theory] + [MemberData(nameof(InvocationTestData))] + public void WriteInvocation(string testName) + { + var testData = _invocationTestData[testName]; + var protocol = new RedisProtocol(new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }); + + // Actual invocation doesn't matter because we're using a dummy hub protocol. + // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. + var expected = testData.Decoded(); + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds); + + Assert.Equal(testData.Encoded, encoded); + } + + // Create ProtocolTestData using the Power of Type Inference(TM). + private static ProtocolTestData CreateTestData(string name, T decoded, params byte[] encoded) + => new ProtocolTestData(name, decoded, encoded); + + public class ProtocolTestData + { + public string Name { get; } + public T Decoded { get; } + public byte[] Encoded { get; } + + public ProtocolTestData(string name, T decoded, byte[] encoded) + { + Name = name; + Decoded = decoded; + Encoded = encoded; + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisServerFixture.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisServerFixture.cs new file mode 100644 index 0000000000..0d0da42680 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/RedisServerFixture.cs @@ -0,0 +1,64 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + public class RedisServerFixture : IDisposable + where TStartup : class + { + public ServerFixture FirstServer { get; private set; } + public ServerFixture SecondServer { get; private set; } + + private readonly ILogger _logger; + private readonly ILoggerFactory _loggerFactory; + private readonly IDisposable _logToken; + + public RedisServerFixture() + { + // Docker is not available on the machine, tests using this fixture + // should be using SkipIfDockerNotPresentAttribute and will be skipped. + if (Docker.Default == null) + { + return; + } + + var testLog = AssemblyTestLog.ForAssembly(typeof(RedisServerFixture).Assembly); + _logToken = testLog.StartTestLog(null, $"{nameof(RedisServerFixture)}_{typeof(TStartup).Name}", out _loggerFactory, LogLevel.Trace, "RedisServerFixture"); + _logger = _loggerFactory.CreateLogger>(); + + Docker.Default.Start(_logger); + + FirstServer = StartServer(); + SecondServer = StartServer(); + } + + private ServerFixture StartServer() + { + try + { + return new ServerFixture(_loggerFactory); + } + catch (Exception ex) + { + _logger.LogError(ex, "Server failed to start."); + throw; + } + } + + public void Dispose() + { + if (Docker.Default != null) + { + FirstServer.Dispose(); + SecondServer.Dispose(); + Docker.Default.Stop(_logger); + _logToken.Dispose(); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/SkipIfDockerNotPresentAttribute.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/SkipIfDockerNotPresentAttribute.cs new file mode 100644 index 0000000000..bf6c9ad91a --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/SkipIfDockerNotPresentAttribute.cs @@ -0,0 +1,39 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Testing.xunit; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + [AttributeUsage(AttributeTargets.Method, AllowMultiple = false)] + public class SkipIfDockerNotPresentAttribute : Attribute, ITestCondition + { + public bool IsMet => CheckDocker(); + public string SkipReason { get; private set; } = "Docker is not available"; + + private bool CheckDocker() + { + if (Docker.Default != null) + { + // Docker is present, but is it working? + if (Docker.Default.RunCommand("ps", "docker ps", out var output) != 0) + { + SkipReason = $"Failed to invoke test command 'docker ps'. Output: {output}"; + } + else + { + // We have a docker + return true; + } + } + else + { + SkipReason = "Docker is not installed on the host machine."; + } + + // If we get here, we don't have a docker + return false; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Startup.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Startup.cs new file mode 100644 index 0000000000..56bf354306 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/Startup.cs @@ -0,0 +1,51 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests +{ + public class Startup + { + public void ConfigureServices(IServiceCollection services) + { + services.AddSignalR(options => + { + options.EnableDetailedErrors = true; + }) + .AddMessagePackProtocol() + .AddStackExchangeRedis(options => + { + // We start the servers before starting redis so we want to time them out ASAP + options.Configuration.ConnectTimeout = 1; + options.Configuration.EndPoints.Add(Environment.GetEnvironmentVariable("REDIS_CONNECTION")); + }); + + services.AddSingleton(); + } + + public void Configure(IApplicationBuilder app, IHostingEnvironment env) + { + app.UseSignalR(options => options.MapHub("/echo")); + } + + private class UserNameIdProvider : IUserIdProvider + { + public string GetUserId(HubConnectionContext connection) + { + // This is an AWFUL way to authenticate users! We're just using it for test purposes. + var userNameHeader = connection.GetHttpContext().Request.Headers["UserName"]; + if (!StringValues.IsNullOrEmpty(userNameHeader)) + { + return userNameHeader; + } + + return null; + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/TestConnectionMultiplexer.cs b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/TestConnectionMultiplexer.cs new file mode 100644 index 0000000000..664316c469 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests/TestConnectionMultiplexer.cs @@ -0,0 +1,376 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.IO; +using System.Net; +using System.Threading.Tasks; +using StackExchange.Redis; +using StackExchange.Redis.Profiling; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class TestConnectionMultiplexer : IConnectionMultiplexer + { + public string ClientName => throw new NotImplementedException(); + + public string Configuration => throw new NotImplementedException(); + + public int TimeoutMilliseconds => throw new NotImplementedException(); + + public long OperationCount => throw new NotImplementedException(); + + public bool PreserveAsyncOrder { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public bool IsConnected => true; + + public bool IncludeDetailInExceptions { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public int StormLogThreshold { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + + public bool IsConnecting => throw new NotImplementedException(); + + public event EventHandler ErrorMessage + { + add { } + remove { } + } + + public event EventHandler ConnectionFailed + { + add { } + remove { } + } + + public event EventHandler InternalError + { + add { } + remove { } + } + + public event EventHandler ConnectionRestored + { + add { } + remove { } + } + + public event EventHandler ConfigurationChanged + { + add { } + remove { } + } + + public event EventHandler ConfigurationChangedBroadcast + { + add { } + remove { } + } + + public event EventHandler HashSlotMoved + { + add { } + remove { } + } + + private readonly ISubscriber _subscriber; + + public TestConnectionMultiplexer(TestRedisServer server) + { + _subscriber = new TestSubscriber(server); + } + + public void BeginProfiling(object forContext) + { + throw new NotImplementedException(); + } + + public void Close(bool allowCommandsToComplete = true) + { + throw new NotImplementedException(); + } + + public Task CloseAsync(bool allowCommandsToComplete = true) + { + throw new NotImplementedException(); + } + + public bool Configure(TextWriter log = null) + { + throw new NotImplementedException(); + } + + public Task ConfigureAsync(TextWriter log = null) + { + throw new NotImplementedException(); + } + + public void Dispose() + { + throw new NotImplementedException(); + } + + public ProfiledCommandEnumerable FinishProfiling(object forContext, bool allowCleanupSweep = true) + { + throw new NotImplementedException(); + } + + public ServerCounters GetCounters() + { + throw new NotImplementedException(); + } + + public IDatabase GetDatabase(int db = -1, object asyncState = null) + { + throw new NotImplementedException(); + } + + public EndPoint[] GetEndPoints(bool configuredOnly = false) + { + throw new NotImplementedException(); + } + + public IServer GetServer(string host, int port, object asyncState = null) + { + throw new NotImplementedException(); + } + + public IServer GetServer(string hostAndPort, object asyncState = null) + { + throw new NotImplementedException(); + } + + public IServer GetServer(IPAddress host, int port) + { + throw new NotImplementedException(); + } + + public IServer GetServer(EndPoint endpoint, object asyncState = null) + { + throw new NotImplementedException(); + } + + public string GetStatus() + { + throw new NotImplementedException(); + } + + public void GetStatus(TextWriter log) + { + throw new NotImplementedException(); + } + + public string GetStormLog() + { + throw new NotImplementedException(); + } + + public ISubscriber GetSubscriber(object asyncState = null) + { + return _subscriber; + } + + public int HashSlot(RedisKey key) + { + throw new NotImplementedException(); + } + + public long PublishReconfigure(CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public Task PublishReconfigureAsync(CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public void ResetStormLog() + { + throw new NotImplementedException(); + } + + public void Wait(Task task) + { + throw new NotImplementedException(); + } + + public T Wait(Task task) + { + throw new NotImplementedException(); + } + + public void WaitAll(params Task[] tasks) + { + throw new NotImplementedException(); + } + + public void RegisterProfiler(Func profilingSessionProvider) + { + throw new NotImplementedException(); + } + + public int GetHashSlot(RedisKey key) + { + throw new NotImplementedException(); + } + + public void ExportConfiguration(Stream destination, ExportOptions options = (ExportOptions)(-1)) + { + throw new NotImplementedException(); + } + } + + public class TestRedisServer + { + private readonly ConcurrentDictionary>> _subscriptions = + new ConcurrentDictionary>>(); + + public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) + { + if (_subscriptions.TryGetValue(channel, out var handlers)) + { + foreach (var handler in handlers) + { + handler(channel, message); + } + } + + return handlers != null ? handlers.Count : 0; + } + + public void Subscribe(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) + { + _subscriptions.AddOrUpdate(channel, _ => new List> { handler }, (_, list) => + { + list.Add(handler); + return list; + }); + } + + public void Unsubscribe(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None) + { + if (_subscriptions.TryGetValue(channel, out var list)) + { + list.Remove(handler); + } + } + } + + public class TestSubscriber : ISubscriber + { + private readonly TestRedisServer _server; + public ConnectionMultiplexer Multiplexer => throw new NotImplementedException(); + + IConnectionMultiplexer IRedisAsync.Multiplexer => throw new NotImplementedException(); + + public TestSubscriber(TestRedisServer server) + { + _server = server; + } + + public EndPoint IdentifyEndpoint(RedisChannel channel, CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public Task IdentifyEndpointAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public bool IsConnected(RedisChannel channel = default) + { + throw new NotImplementedException(); + } + + public TimeSpan Ping(CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public Task PingAsync(CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) + { + return _server.Publish(channel, message, flags); + } + + public async Task PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None) + { + await Task.Yield(); + return Publish(channel, message, flags); + } + + public void Subscribe(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) + { + _server.Subscribe(channel, handler, flags); + } + + public Task SubscribeAsync(RedisChannel channel, Action handler, CommandFlags flags = CommandFlags.None) + { + Subscribe(channel, handler, flags); + return Task.CompletedTask; + } + + public EndPoint SubscribedEndpoint(RedisChannel channel) + { + throw new NotImplementedException(); + } + + public bool TryWait(Task task) + { + throw new NotImplementedException(); + } + + public void Unsubscribe(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None) + { + _server.Unsubscribe(channel, handler, flags); + } + + public void UnsubscribeAll(CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public Task UnsubscribeAllAsync(CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public Task UnsubscribeAsync(RedisChannel channel, Action handler = null, CommandFlags flags = CommandFlags.None) + { + Unsubscribe(channel, handler, flags); + return Task.CompletedTask; + } + + public void Wait(Task task) + { + throw new NotImplementedException(); + } + + public T Wait(Task task) + { + throw new NotImplementedException(); + } + + public void WaitAll(params Task[] tasks) + { + throw new NotImplementedException(); + } + + public ChannelMessageQueue Subscribe(RedisChannel channel, CommandFlags flags = CommandFlags.None) + { + throw new NotImplementedException(); + } + + public Task SubscribeAsync(RedisChannel channel, CommandFlags flags = CommandFlags.None) + { + var t = Subscribe(channel, flags); + return Task.FromResult(t); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj index b6884a9fdc..976041ab7c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj @@ -22,7 +22,6 @@ -