Preserialize for all formats when sending through Redis (#1843)
This commit is contained in:
parent
39f693b9ed
commit
19b2fea0d8
|
|
@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
_successHubProtocolResolver = new TestHubProtocolResolver(new JsonHubProtocol());
|
||||
_failureHubProtocolResolver = new TestHubProtocolResolver(null);
|
||||
_userIdProvider = new TestUserIdProvider();
|
||||
_supportedProtocols = new List<string> {"json"};
|
||||
_supportedProtocols = new List<string> { "json" };
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
|
|
@ -83,8 +83,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
{
|
||||
private readonly IHubProtocol _instance;
|
||||
|
||||
public IReadOnlyList<IHubProtocol> AllProtocols { get; }
|
||||
|
||||
public TestHubProtocolResolver(IHubProtocol instance)
|
||||
{
|
||||
AllProtocols = new[] { instance };
|
||||
_instance = instance;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<OutputType>Exe</OutputType>
|
||||
|
|
@ -12,6 +12,8 @@
|
|||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Client.Core\Microsoft.AspNetCore.SignalR.Client.Core.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Protocols.MsgPack\Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj" />
|
||||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Redis\Microsoft.AspNetCore.SignalR.Redis.csproj" />
|
||||
<ProjectReference Include="..\..\test\Microsoft.AspNetCore.SignalR.Tests.Utils\Microsoft.AspNetCore.SignalR.Tests.Utils.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,203 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using BenchmarkDotNet.Attributes;
|
||||
using Microsoft.AspNetCore.Connections;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Redis;
|
||||
using Microsoft.AspNetCore.SignalR.Tests;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
||||
{
|
||||
public class RedisHubLifetimeManagerBenchmark
|
||||
{
|
||||
private RedisHubLifetimeManager<TestHub> _manager1;
|
||||
private RedisHubLifetimeManager<TestHub> _manager2;
|
||||
private TestClient[] _clients;
|
||||
private object[] _args;
|
||||
private List<string> _excludedIds = new List<string>();
|
||||
private List<string> _sendIds = new List<string>();
|
||||
private List<string> _groups = new List<string>();
|
||||
private List<string> _users = new List<string>();
|
||||
|
||||
private const int ClientCount = 20;
|
||||
|
||||
[Params(2, 20)]
|
||||
public int ProtocolCount { get; set; }
|
||||
|
||||
[GlobalSetup]
|
||||
public void GlobalSetup()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
var logger = NullLogger<RedisHubLifetimeManager<TestHub>>.Instance;
|
||||
var protocols = GenerateProtocols(ProtocolCount).ToArray();
|
||||
var options = Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer(server)
|
||||
});
|
||||
var resolver = new DefaultHubProtocolResolver(protocols, NullLogger<DefaultHubProtocolResolver>.Instance);
|
||||
|
||||
_manager1 = new RedisHubLifetimeManager<TestHub>(logger, options, resolver);
|
||||
_manager2 = new RedisHubLifetimeManager<TestHub>(logger, options, resolver);
|
||||
|
||||
async Task ConnectClient(TestClient client, IHubProtocol protocol, string userId, string group)
|
||||
{
|
||||
await _manager2.OnConnectedAsync(HubConnectionContextUtils.Create(client.Connection, protocol, userId));
|
||||
await _manager2.AddGroupAsync(client.Connection.ConnectionId, "Everyone");
|
||||
await _manager2.AddGroupAsync(client.Connection.ConnectionId, group);
|
||||
}
|
||||
|
||||
// Connect clients
|
||||
_clients = new TestClient[ClientCount];
|
||||
var tasks = new Task[ClientCount];
|
||||
for (var i = 0; i < _clients.Length; i++)
|
||||
{
|
||||
var protocol = protocols[i % ProtocolCount];
|
||||
_clients[i] = new TestClient(protocol: protocol);
|
||||
|
||||
string group;
|
||||
string user;
|
||||
if ((i % 2) == 0)
|
||||
{
|
||||
group = "Evens";
|
||||
user = "EvenUser";
|
||||
_excludedIds.Add(_clients[i].Connection.ConnectionId);
|
||||
}
|
||||
else
|
||||
{
|
||||
group = "Odds";
|
||||
user = "OddUser";
|
||||
_sendIds.Add(_clients[i].Connection.ConnectionId);
|
||||
}
|
||||
|
||||
tasks[i] = ConnectClient(_clients[i], protocol, user, group);
|
||||
_ = ConsumeAsync(_clients[i]);
|
||||
}
|
||||
|
||||
Task.WaitAll(tasks);
|
||||
|
||||
_groups.Add("Evens");
|
||||
_groups.Add("Odds");
|
||||
_users.Add("EvenUser");
|
||||
_users.Add("OddUser");
|
||||
|
||||
_args = new object[] {"Foo"};
|
||||
}
|
||||
|
||||
private IEnumerable<IHubProtocol> GenerateProtocols(int protocolCount)
|
||||
{
|
||||
for (var i = 0; i < protocolCount; i++)
|
||||
{
|
||||
yield return ((i % 2) == 0)
|
||||
? new WrappedHubProtocol($"json_{i}", new JsonHubProtocol())
|
||||
: new WrappedHubProtocol($"msgpack_{i}", new MessagePackHubProtocol());
|
||||
}
|
||||
}
|
||||
|
||||
private async Task ConsumeAsync(TestClient testClient)
|
||||
{
|
||||
while (await testClient.ReadAsync() != null)
|
||||
{
|
||||
// Just dump the message
|
||||
}
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendAll()
|
||||
{
|
||||
await _manager1.SendAllAsync("Test", _args);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendGroup()
|
||||
{
|
||||
await _manager1.SendGroupAsync("Everyone", "Test", _args);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendUser()
|
||||
{
|
||||
await _manager1.SendUserAsync("EvenUser", "Test", _args);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendConnection()
|
||||
{
|
||||
await _manager1.SendConnectionAsync(_clients[0].Connection.ConnectionId, "Test", _args);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendConnections()
|
||||
{
|
||||
await _manager1.SendConnectionsAsync(_sendIds, "Test", _args);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendAllExcept()
|
||||
{
|
||||
await _manager1.SendAllExceptAsync("Test", _args, _excludedIds);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendGroupExcept()
|
||||
{
|
||||
await _manager1.SendGroupExceptAsync("Everyone", "Test", _args, _excludedIds);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendGroups()
|
||||
{
|
||||
await _manager1.SendGroupsAsync(_groups, "Test", _args);
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public async Task SendUsers()
|
||||
{
|
||||
await _manager1.SendUsersAsync(_users, "Test", _args);
|
||||
}
|
||||
|
||||
public class TestHub : Hub
|
||||
{
|
||||
}
|
||||
|
||||
private class WrappedHubProtocol : IHubProtocol
|
||||
{
|
||||
private readonly string _name;
|
||||
private readonly IHubProtocol _innerProtocol;
|
||||
|
||||
public string Name => _name;
|
||||
|
||||
public int Version => _innerProtocol.Version;
|
||||
|
||||
public TransferFormat TransferFormat => _innerProtocol.TransferFormat;
|
||||
|
||||
public WrappedHubProtocol(string name, IHubProtocol innerProtocol)
|
||||
{
|
||||
_name = name;
|
||||
_innerProtocol = innerProtocol;
|
||||
}
|
||||
|
||||
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
|
||||
{
|
||||
return _innerProtocol.TryParseMessage(ref input, binder, out message);
|
||||
}
|
||||
|
||||
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
|
||||
{
|
||||
_innerProtocol.WriteMessage(message, output);
|
||||
}
|
||||
|
||||
public bool IsVersionSupported(int version)
|
||||
{
|
||||
return _innerProtocol.IsVersionSupported(version);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -28,7 +28,7 @@ namespace SignalRSamples
|
|||
{
|
||||
options.SerializationContext.DictionarySerlaizationOptions.KeyTransformer = DictionaryKeyTransformers.LowerCamel;
|
||||
});
|
||||
// .AddRedis();
|
||||
//.AddRedis();
|
||||
|
||||
services.AddCors(o =>
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,87 +1,9 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
public abstract class HubMessage
|
||||
{
|
||||
protected HubMessage()
|
||||
{
|
||||
}
|
||||
|
||||
private object _lock = new object();
|
||||
private List<SerializedMessage> _serializedMessages;
|
||||
private SerializedMessage _message1;
|
||||
private SerializedMessage _message2;
|
||||
|
||||
public byte[] WriteMessage(IHubProtocol protocol)
|
||||
{
|
||||
// REVIEW: Revisit lock
|
||||
// Could use a reader/writer lock to allow the loop to take place in "unlocked" code
|
||||
// Or, could use a fixed size array and Interlocked to manage it.
|
||||
// Or, Immutable *ducks*
|
||||
|
||||
lock (_lock)
|
||||
{
|
||||
if (ReferenceEquals(_message1.Protocol, protocol))
|
||||
{
|
||||
return _message1.Message;
|
||||
}
|
||||
|
||||
if (ReferenceEquals(_message2.Protocol, protocol))
|
||||
{
|
||||
return _message2.Message;
|
||||
}
|
||||
|
||||
for (var i = 0; i < _serializedMessages?.Count; i++)
|
||||
{
|
||||
if (ReferenceEquals(_serializedMessages[i].Protocol, protocol))
|
||||
{
|
||||
return _serializedMessages[i].Message;
|
||||
}
|
||||
}
|
||||
|
||||
var bytes = protocol.WriteToArray(this);
|
||||
|
||||
if (_message1.Protocol == null)
|
||||
{
|
||||
_message1 = new SerializedMessage(protocol, bytes);
|
||||
}
|
||||
else if (_message2.Protocol == null)
|
||||
{
|
||||
_message2 = new SerializedMessage(protocol, bytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (_serializedMessages == null)
|
||||
{
|
||||
_serializedMessages = new List<SerializedMessage>();
|
||||
}
|
||||
|
||||
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
|
||||
// So we cap how many caches we store and worst case just serialize the message for every connection
|
||||
if (_serializedMessages.Count < 10)
|
||||
{
|
||||
_serializedMessages.Add(new SerializedMessage(protocol, bytes));
|
||||
}
|
||||
}
|
||||
|
||||
return bytes;
|
||||
}
|
||||
}
|
||||
|
||||
private readonly struct SerializedMessage
|
||||
{
|
||||
public readonly IHubProtocol Protocol;
|
||||
public readonly byte[] Message;
|
||||
|
||||
public SerializedMessage(IHubProtocol protocol, byte[] message)
|
||||
{
|
||||
Protocol = protocol;
|
||||
Message = message;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
|
|
@ -210,9 +211,9 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private InvocationMessage CreateInvocationMessage(string methodName, object[] args)
|
||||
private SerializedHubMessage CreateInvocationMessage(string methodName, object[] args)
|
||||
{
|
||||
return new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args);
|
||||
return new SerializedHubMessage(new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args));
|
||||
}
|
||||
|
||||
public override Task SendUserAsync(string userId, string methodName, object[] args)
|
||||
|
|
|
|||
|
|
@ -6,19 +6,16 @@ using System.Buffers;
|
|||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Net;
|
||||
using System.Runtime.ExceptionServices;
|
||||
using System.Security.Claims;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
using Microsoft.AspNetCore.Connections;
|
||||
using Microsoft.AspNetCore.Connections.Features;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
using Microsoft.AspNetCore.SignalR.Core;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
||||
|
|
@ -73,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public virtual PipeReader Input => _connectionContext.Transport.Input;
|
||||
|
||||
public string UserIdentifier { get; private set; }
|
||||
public string UserIdentifier { get; set; }
|
||||
|
||||
internal virtual IHubProtocol Protocol { get; set; }
|
||||
|
||||
|
|
@ -84,7 +81,36 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public virtual ValueTask WriteAsync(HubMessage message)
|
||||
{
|
||||
// We were unable to get the lock so take the slow async path of waiting for the semaphore
|
||||
// Try to grab the lock synchronously, if we fail, go to the slower path
|
||||
if (!_writeLock.Wait(0))
|
||||
{
|
||||
return new ValueTask(WriteSlowAsync(message));
|
||||
}
|
||||
|
||||
// This method should never throw synchronously
|
||||
var task = WriteCore(message);
|
||||
|
||||
// The write didn't complete synchronously so await completion
|
||||
if (!task.IsCompletedSuccessfully)
|
||||
{
|
||||
return new ValueTask(CompleteWriteAsync(task));
|
||||
}
|
||||
|
||||
// Otherwise, release the lock acquired when entering WriteAsync
|
||||
_writeLock.Release();
|
||||
|
||||
return default;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This method is designed to support the framework and is not intended to be used by application code. Writes a pre-serialized message to the
|
||||
/// connection.
|
||||
/// </summary>
|
||||
/// <param name="message">The serialization cache to use.</param>
|
||||
/// <returns></returns>
|
||||
public virtual ValueTask WriteAsync(SerializedHubMessage message)
|
||||
{
|
||||
// Try to grab the lock synchronously, if we fail, go to the slower path
|
||||
if (!_writeLock.Wait(0))
|
||||
{
|
||||
return new ValueTask(WriteSlowAsync(message));
|
||||
|
|
@ -109,14 +135,28 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
{
|
||||
try
|
||||
{
|
||||
// This will internally cache the buffer for each unique HubProtocol
|
||||
// So that we don't serialize the HubMessage for every single connection
|
||||
var buffer = message.WriteMessage(Protocol);
|
||||
// We know that we are only writing this message to one receiver, so we can
|
||||
// write it without caching.
|
||||
Protocol.WriteMessage(message, _connectionContext.Transport.Output);
|
||||
|
||||
var output = _connectionContext.Transport.Output;
|
||||
output.Write(buffer);
|
||||
return _connectionContext.Transport.Output.FlushAsync();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
|
||||
return output.FlushAsync();
|
||||
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
|
||||
}
|
||||
}
|
||||
|
||||
private ValueTask<FlushResult> WriteCore(SerializedHubMessage message)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Grab a preserialized buffer for this protocol.
|
||||
var buffer = message.GetSerializedMessage(Protocol);
|
||||
|
||||
return _connectionContext.Transport.Output.WriteAsync(buffer);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
|
|
@ -162,6 +202,25 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task WriteSlowAsync(SerializedHubMessage message)
|
||||
{
|
||||
try
|
||||
{
|
||||
// Failed to get the lock immediately when entering WriteAsync so await until it is available
|
||||
await _writeLock.WaitAsync();
|
||||
|
||||
await WriteCore(message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
Log.FailedWritingMessage(_logger, ex);
|
||||
}
|
||||
finally
|
||||
{
|
||||
_writeLock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
private ValueTask TryWritePingAsync()
|
||||
{
|
||||
// Don't wait for the lock, if it returns false that means someone wrote to the connection
|
||||
|
|
|
|||
|
|
@ -7,21 +7,25 @@ using System.Linq;
|
|||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
public class DefaultHubProtocolResolver : IHubProtocolResolver
|
||||
{
|
||||
private readonly ILogger<DefaultHubProtocolResolver> _logger;
|
||||
private readonly List<IHubProtocol> _hubProtocols;
|
||||
private readonly Dictionary<string, IHubProtocol> _availableProtocols;
|
||||
|
||||
public IReadOnlyList<IHubProtocol> AllProtocols => _hubProtocols;
|
||||
|
||||
public DefaultHubProtocolResolver(IEnumerable<IHubProtocol> availableProtocols, ILogger<DefaultHubProtocolResolver> logger)
|
||||
{
|
||||
_logger = logger ?? NullLogger<DefaultHubProtocolResolver>.Instance;
|
||||
_availableProtocols = new Dictionary<string, IHubProtocol>(StringComparer.OrdinalIgnoreCase);
|
||||
|
||||
foreach (var protocol in availableProtocols)
|
||||
// We might get duplicates in _hubProtocols, but we're going to check it and throw in just a sec.
|
||||
_hubProtocols = availableProtocols.ToList();
|
||||
foreach (var protocol in _hubProtocols)
|
||||
{
|
||||
if (_availableProtocols.ContainsKey(protocol.Name))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
|
@ -8,6 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
{
|
||||
public interface IHubProtocolResolver
|
||||
{
|
||||
IReadOnlyList<IHubProtocol> AllProtocols { get; }
|
||||
IHubProtocol GetProtocol(string protocolName, IList<string> supportedProtocols);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,161 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
/// <summary>
|
||||
/// This class is designed to support the framework. The API is subject to breaking changes.
|
||||
/// Represents a serialization cache for a single message.
|
||||
/// </summary>
|
||||
public class SerializedHubMessage
|
||||
{
|
||||
private SerializedMessage _cachedItem1;
|
||||
private SerializedMessage _cachedItem2;
|
||||
private IList<SerializedMessage> _cachedItems;
|
||||
|
||||
public HubMessage Message { get; }
|
||||
|
||||
private SerializedHubMessage()
|
||||
{
|
||||
}
|
||||
|
||||
public SerializedHubMessage(HubMessage message)
|
||||
{
|
||||
Message = message;
|
||||
}
|
||||
|
||||
public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
|
||||
{
|
||||
if (!TryGetCached(protocol.Name, out var serialized))
|
||||
{
|
||||
if (Message == null)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"This message was received from another server that did not have the requested protocol available.");
|
||||
}
|
||||
|
||||
serialized = protocol.WriteToArray(Message);
|
||||
SetCache(protocol.Name, serialized);
|
||||
}
|
||||
|
||||
return serialized;
|
||||
}
|
||||
|
||||
public static void WriteAllSerializedVersions(BinaryWriter writer, HubMessage message, IReadOnlyList<IHubProtocol> protocols)
|
||||
{
|
||||
// The serialization format is based on BinaryWriter
|
||||
// * 1 byte number of protocols
|
||||
// * For each protocol:
|
||||
// * Length-prefixed string using 7-bit variable length encoding (length depends on BinaryWriter's encoding)
|
||||
// * 4 byte length of the buffer
|
||||
// * N byte buffer
|
||||
|
||||
if (protocols.Count > byte.MaxValue)
|
||||
{
|
||||
throw new InvalidOperationException($"Can't serialize cache containing more than {byte.MaxValue} entries");
|
||||
}
|
||||
|
||||
writer.Write((byte)protocols.Count);
|
||||
foreach (var protocol in protocols)
|
||||
{
|
||||
writer.Write(protocol.Name);
|
||||
|
||||
var buffer = protocol.WriteToArray(message);
|
||||
writer.Write(buffer.Length);
|
||||
writer.Write(buffer);
|
||||
}
|
||||
}
|
||||
|
||||
public static SerializedHubMessage ReadAllSerializedVersions(BinaryReader reader)
|
||||
{
|
||||
var cache = new SerializedHubMessage();
|
||||
var count = reader.ReadByte();
|
||||
for (var i = 0; i < count; i++)
|
||||
{
|
||||
var protocol = reader.ReadString();
|
||||
var length = reader.ReadInt32();
|
||||
var serialized = reader.ReadBytes(length);
|
||||
cache.SetCache(protocol, serialized);
|
||||
}
|
||||
|
||||
return cache;
|
||||
}
|
||||
|
||||
private void SetCache(string protocolName, byte[] serialized)
|
||||
{
|
||||
if (_cachedItem1.ProtocolName == null)
|
||||
{
|
||||
_cachedItem1 = new SerializedMessage(protocolName, serialized);
|
||||
}
|
||||
else if (_cachedItem2.ProtocolName == null)
|
||||
{
|
||||
_cachedItem2 = new SerializedMessage(protocolName, serialized);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (_cachedItems == null)
|
||||
{
|
||||
_cachedItems = new List<SerializedMessage>();
|
||||
}
|
||||
|
||||
foreach (var item in _cachedItems)
|
||||
{
|
||||
if (string.Equals(item.ProtocolName, protocolName, StringComparison.Ordinal))
|
||||
{
|
||||
// No need to add
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
_cachedItems.Add(new SerializedMessage(protocolName, serialized));
|
||||
}
|
||||
}
|
||||
|
||||
private bool TryGetCached(string protocolName, out byte[] result)
|
||||
{
|
||||
if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
|
||||
{
|
||||
result = _cachedItem1.Serialized;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal))
|
||||
{
|
||||
result = _cachedItem2.Serialized;
|
||||
return true;
|
||||
}
|
||||
|
||||
if (_cachedItems != null)
|
||||
{
|
||||
foreach (var serializedMessage in _cachedItems)
|
||||
{
|
||||
if (string.Equals(serializedMessage.ProtocolName, protocolName, StringComparison.Ordinal))
|
||||
{
|
||||
result = serializedMessage.Serialized;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
result = default;
|
||||
return false;
|
||||
}
|
||||
|
||||
private readonly struct SerializedMessage
|
||||
{
|
||||
public string ProtocolName { get; }
|
||||
public byte[] Serialized { get; }
|
||||
|
||||
public SerializedMessage(string protocolName, byte[] serialized)
|
||||
{
|
||||
ProtocolName = protocolName;
|
||||
Serialized = serialized;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
{
|
||||
// The size of the enum is defined by the protocol. Do not change it. If you need more than 255 items,
|
||||
// add an additional enum.
|
||||
public enum GroupAction : byte
|
||||
{
|
||||
// These numbers are used by the protocol, do not change them and always use explicit assignment
|
||||
// when adding new items to this enum. 0 is intentionally omitted
|
||||
Add = 1,
|
||||
Remove = 2,
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,75 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Runtime.CompilerServices;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
{
|
||||
internal class RedisChannels
|
||||
{
|
||||
private readonly string _prefix;
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the channel for sending to all connections.
|
||||
/// </summary>
|
||||
/// <remarks>
|
||||
/// The payload on this channel is <see cref="RedisInvocation"/> objects containing
|
||||
/// invocations to be sent to all connections
|
||||
/// </remarks>
|
||||
public string All { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the internal channel for group management messages.
|
||||
/// </summary>
|
||||
public string GroupManagement { get; }
|
||||
|
||||
public RedisChannels(string prefix)
|
||||
{
|
||||
_prefix = prefix;
|
||||
|
||||
All = prefix + ":all";
|
||||
GroupManagement = prefix + ":internal:groups";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the channel for sending a message to a specific connection.
|
||||
/// </summary>
|
||||
/// <param name="connectionId">The ID of the connection to get the channel for.</param>
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public string Connection(string connectionId)
|
||||
{
|
||||
return _prefix + ":connection:" + connectionId;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the channel for sending a message to a named group of connections.
|
||||
/// </summary>
|
||||
/// <param name="groupName">The name of the group to get the channel for.</param>
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public string Group(string groupName)
|
||||
{
|
||||
return _prefix + ":group:" + groupName;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the channel for sending a message to all collections associated with a user.
|
||||
/// </summary>
|
||||
/// <param name="userId">The ID of the user to get the channel for.</param>
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public string User(string userId)
|
||||
{
|
||||
return _prefix + ":user:" + userId;
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the acknowledgement channel for the specified server.
|
||||
/// </summary>
|
||||
/// <param name="serverName">The name of the server to get the acknowledgement channel for.</param>
|
||||
/// <returns></returns>
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public string Ack(string serverName)
|
||||
{
|
||||
return _prefix + ":internal:ack:" + serverName;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
{
|
||||
public readonly struct RedisGroupCommand
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets the ID of the group command.
|
||||
/// </summary>
|
||||
public int Id { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the name of the server that sent the command.
|
||||
/// </summary>
|
||||
public string ServerName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the action to be performed on the group.
|
||||
/// </summary>
|
||||
public GroupAction Action { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the group on which the action is performed.
|
||||
/// </summary>
|
||||
public string GroupName { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the ID of the connection to be added or removed from the group.
|
||||
/// </summary>
|
||||
public string ConnectionId { get; }
|
||||
|
||||
public RedisGroupCommand(int id, string serverName, GroupAction action, string groupName, string connectionId)
|
||||
{
|
||||
Id = id;
|
||||
ServerName = serverName;
|
||||
Action = action;
|
||||
GroupName = groupName;
|
||||
ConnectionId = connectionId;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
using System.Collections.Generic;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
{
|
||||
public readonly struct RedisInvocation
|
||||
{
|
||||
/// <summary>
|
||||
/// Gets a list of connections that should be excluded from this invocation.
|
||||
/// May be null to indicate that no connections are to be excluded.
|
||||
/// </summary>
|
||||
public IReadOnlyList<string> ExcludedIds { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets the message serialization cache containing serialized payloads for the message.
|
||||
/// </summary>
|
||||
public SerializedHubMessage Message { get; }
|
||||
|
||||
public RedisInvocation(SerializedHubMessage message, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
Message = message;
|
||||
ExcludedIds = excludedIds;
|
||||
}
|
||||
|
||||
public static RedisInvocation Create(string target, object[] arguments, IReadOnlyList<string> excludedIds = null)
|
||||
{
|
||||
return new RedisInvocation(
|
||||
new SerializedHubMessage(new InvocationMessage(target, argumentBindingException: null, arguments)),
|
||||
excludedIds);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Text;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
{
|
||||
public class RedisProtocol
|
||||
{
|
||||
private readonly IReadOnlyList<IHubProtocol> _protocols;
|
||||
private static readonly Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);
|
||||
|
||||
public RedisProtocol(IReadOnlyList<IHubProtocol> protocols)
|
||||
{
|
||||
_protocols = protocols;
|
||||
}
|
||||
|
||||
// The Redis Protocol:
|
||||
// * The message type is known in advance because messages are sent to different channels based on type
|
||||
// * Invocations are sent to the All, Group, Connection and User channels
|
||||
// * Group Commands are sent to the GroupManagement channel
|
||||
// * Acks are sent to the Acknowledgement channel.
|
||||
// * See the Write[type] methods for a description of the protocol for each in-depth.
|
||||
// * The "Variable length integer" is the length-prefixing format used by BinaryReader/BinaryWriter:
|
||||
// * https://docs.microsoft.com/en-us/dotnet/api/system.io.binarywriter.write?view=netstandard-2.0
|
||||
// * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter:
|
||||
// * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8.
|
||||
|
||||
public byte[] WriteInvocation(string methodName, object[] args) =>
|
||||
WriteInvocation(methodName, args, excludedIds: null);
|
||||
|
||||
public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
// Redis Invocation Format:
|
||||
// * Variable length integer: Number of excluded Ids
|
||||
// * For each excluded Id:
|
||||
// * Length prefixed string: ID
|
||||
// * SerializedHubMessage encoded by the format described by that type.
|
||||
|
||||
using (var stream = new MemoryStream())
|
||||
using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom))
|
||||
{
|
||||
if (excludedIds != null)
|
||||
{
|
||||
writer.WriteVarInt(excludedIds.Count);
|
||||
foreach (var id in excludedIds)
|
||||
{
|
||||
writer.Write(id);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
writer.WriteVarInt(0);
|
||||
}
|
||||
|
||||
SerializedHubMessage.WriteAllSerializedVersions(writer, new InvocationMessage(methodName, argumentBindingException: null, args), _protocols);
|
||||
return stream.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] WriteGroupCommand(RedisGroupCommand command)
|
||||
{
|
||||
// Group Command Format:
|
||||
// * Variable length integer: Id
|
||||
// * Length prefixed string: ServerName
|
||||
// * 1 byte: Action
|
||||
// * Length prefixed string: GroupName
|
||||
// * Length prefixed string: ConnectionId
|
||||
|
||||
using (var stream = new MemoryStream())
|
||||
using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom))
|
||||
{
|
||||
writer.WriteVarInt(command.Id);
|
||||
writer.Write(command.ServerName);
|
||||
writer.Write((byte)command.Action);
|
||||
writer.Write(command.GroupName);
|
||||
writer.Write(command.ConnectionId);
|
||||
return stream.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
public byte[] WriteAck(int messageId)
|
||||
{
|
||||
// Acknowledgement Format:
|
||||
// * Variable length integer: Id
|
||||
|
||||
using (var stream = new MemoryStream())
|
||||
using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom))
|
||||
{
|
||||
writer.WriteVarInt(messageId);
|
||||
return stream.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
public RedisInvocation ReadInvocation(byte[] data)
|
||||
{
|
||||
// See WriteInvocation for format.
|
||||
|
||||
using (var stream = new MemoryStream(data))
|
||||
using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom))
|
||||
{
|
||||
IReadOnlyList<string> excludedIds = null;
|
||||
|
||||
var idCount = reader.ReadVarInt();
|
||||
if (idCount > 0)
|
||||
{
|
||||
var ids = new string[idCount];
|
||||
for (var i = 0; i < idCount; i++)
|
||||
{
|
||||
ids[i] = reader.ReadString();
|
||||
}
|
||||
|
||||
excludedIds = ids;
|
||||
}
|
||||
|
||||
var message = SerializedHubMessage.ReadAllSerializedVersions(reader);
|
||||
return new RedisInvocation(message, excludedIds);
|
||||
}
|
||||
}
|
||||
|
||||
public RedisGroupCommand ReadGroupCommand(byte[] data)
|
||||
{
|
||||
// See WriteGroupCommand for format.
|
||||
using (var stream = new MemoryStream(data))
|
||||
using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom))
|
||||
{
|
||||
var id = reader.ReadVarInt();
|
||||
var serverName = reader.ReadString();
|
||||
var action = (GroupAction)reader.ReadByte();
|
||||
var groupName = reader.ReadString();
|
||||
var connectionId = reader.ReadString();
|
||||
|
||||
return new RedisGroupCommand(id, serverName, action, groupName, connectionId);
|
||||
}
|
||||
}
|
||||
|
||||
public int ReadAck(byte[] data)
|
||||
{
|
||||
// See WriteAck for format
|
||||
using (var stream = new MemoryStream(data))
|
||||
using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom))
|
||||
{
|
||||
return reader.ReadVarInt();
|
||||
}
|
||||
}
|
||||
|
||||
// Kinda cheaty way to get access to write the 7-bit varint format directly
|
||||
private class BinaryWriterWithVarInt : BinaryWriter
|
||||
{
|
||||
public BinaryWriterWithVarInt(Stream output, Encoding encoding) : base(output, encoding)
|
||||
{
|
||||
}
|
||||
|
||||
public void WriteVarInt(int value) => Write7BitEncodedInt(value);
|
||||
}
|
||||
|
||||
private class BinaryReaderWithVarInt : BinaryReader
|
||||
{
|
||||
public BinaryReaderWithVarInt(Stream input, Encoding encoding) : base(input, encoding)
|
||||
{
|
||||
}
|
||||
|
||||
public int ReadVarInt() => Read7BitEncodedInt();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -9,7 +9,7 @@ using System.Linq;
|
|||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Redis.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
@ -28,8 +28,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
private readonly ISubscriber _bus;
|
||||
private readonly ILogger _logger;
|
||||
private readonly RedisOptions _options;
|
||||
private readonly string _channelNamePrefix = typeof(THub).FullName;
|
||||
private readonly string _serverName = Guid.NewGuid().ToString();
|
||||
private readonly RedisChannels _channels;
|
||||
private readonly string _serverName = GenerateServerName();
|
||||
private readonly RedisProtocol _protocol;
|
||||
|
||||
private readonly AckHandler _ackHandler;
|
||||
private int _internalId;
|
||||
|
||||
|
|
@ -41,14 +43,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
};
|
||||
|
||||
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
|
||||
IOptions<RedisOptions> options)
|
||||
IOptions<RedisOptions> options,
|
||||
IHubProtocolResolver hubProtocolResolver)
|
||||
{
|
||||
_logger = logger;
|
||||
_options = options.Value;
|
||||
_ackHandler = new AckHandler();
|
||||
_channels = new RedisChannels(typeof(THub).FullName);
|
||||
_protocol = new RedisProtocol(hubProtocolResolver.AllProtocols);
|
||||
|
||||
var writer = new LoggerTextWriter(logger);
|
||||
_logger.ConnectingToEndpoints(options.Value.Options.EndPoints);
|
||||
RedisLog.ConnectingToEndpoints(_logger, options.Value.Options.EndPoints, _serverName);
|
||||
_redisServerConnection = _options.Connect(writer);
|
||||
|
||||
_redisServerConnection.ConnectionRestored += (_, e) =>
|
||||
|
|
@ -60,7 +65,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return;
|
||||
}
|
||||
|
||||
_logger.ConnectionRestored();
|
||||
RedisLog.ConnectionRestored(_logger);
|
||||
};
|
||||
|
||||
_redisServerConnection.ConnectionFailed += (_, e) =>
|
||||
|
|
@ -72,23 +77,22 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return;
|
||||
}
|
||||
|
||||
_logger.ConnectionFailed(e.Exception);
|
||||
RedisLog.ConnectionFailed(_logger, e.Exception);
|
||||
};
|
||||
|
||||
if (_redisServerConnection.IsConnected)
|
||||
{
|
||||
_logger.Connected();
|
||||
RedisLog.Connected(_logger);
|
||||
}
|
||||
else
|
||||
{
|
||||
_logger.NotConnected();
|
||||
RedisLog.NotConnected(_logger);
|
||||
}
|
||||
_bus = _redisServerConnection.GetSubscriber();
|
||||
|
||||
SubscribeToHub();
|
||||
SubscribeToAllExcept();
|
||||
SubscribeToInternalGroup();
|
||||
SubscribeToInternalServerName();
|
||||
SubscribeToAll();
|
||||
SubscribeToGroupManagementChannel();
|
||||
SubscribeToAckChannel();
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(HubConnectionContext connection)
|
||||
|
|
@ -125,7 +129,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
{
|
||||
foreach (var subscription in redisSubscriptions)
|
||||
{
|
||||
_logger.Unsubscribe(subscription);
|
||||
RedisLog.Unsubscribe(_logger, subscription);
|
||||
tasks.Add(_bus.UnsubscribeAsync(subscription));
|
||||
}
|
||||
}
|
||||
|
|
@ -149,15 +153,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
public override Task SendAllAsync(string methodName, object[] args)
|
||||
{
|
||||
var message = new RedisInvocationMessage(target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(_channelNamePrefix, message);
|
||||
var message = _protocol.WriteInvocation(methodName, args);
|
||||
return PublishAsync(_channels.All, message);
|
||||
}
|
||||
|
||||
public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
var message = new RedisInvocationMessage(target: methodName, excludedIds: excludedIds, arguments: args);
|
||||
return PublishAsync(_channelNamePrefix + ".AllExcept", message);
|
||||
var message = _protocol.WriteInvocation(methodName, args, excludedIds);
|
||||
return PublishAsync(_channels.All, message);
|
||||
}
|
||||
|
||||
public override Task SendConnectionAsync(string connectionId, string methodName, object[] args)
|
||||
|
|
@ -167,17 +170,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
throw new ArgumentNullException(nameof(connectionId));
|
||||
}
|
||||
|
||||
var message = new RedisInvocationMessage(target: methodName, arguments: args);
|
||||
|
||||
// If the connection is local we can skip sending the message through the bus since we require sticky connections.
|
||||
// This also saves serializing and deserializing the message!
|
||||
var connection = _connections[connectionId];
|
||||
if (connection != null)
|
||||
{
|
||||
return SafeWriteAsync(connection, message.CreateInvocation());
|
||||
return connection.WriteAsync(new InvocationMessage(methodName, argumentBindingException: null, args)).AsTask();
|
||||
}
|
||||
|
||||
return PublishAsync(_channelNamePrefix + "." + connectionId, message);
|
||||
var message = _protocol.WriteInvocation(methodName, args);
|
||||
return PublishAsync(_channels.Connection(connectionId), message);
|
||||
}
|
||||
|
||||
public override Task SendGroupAsync(string groupName, string methodName, object[] args)
|
||||
|
|
@ -187,9 +189,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
throw new ArgumentNullException(nameof(groupName));
|
||||
}
|
||||
|
||||
var message = new RedisInvocationMessage(target: methodName, excludedIds: null, arguments: args);
|
||||
|
||||
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
|
||||
var message = _protocol.WriteInvocation(methodName, args);
|
||||
return PublishAsync(_channels.Group(groupName), message);
|
||||
}
|
||||
|
||||
public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedIds)
|
||||
|
|
@ -199,31 +200,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
throw new ArgumentNullException(nameof(groupName));
|
||||
}
|
||||
|
||||
var message = new RedisInvocationMessage(methodName, excludedIds, args);
|
||||
|
||||
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
|
||||
var message = _protocol.WriteInvocation(methodName, args, excludedIds);
|
||||
return PublishAsync(_channels.Group(groupName), message);
|
||||
}
|
||||
|
||||
public override Task SendUserAsync(string userId, string methodName, object[] args)
|
||||
{
|
||||
var message = new RedisInvocationMessage(methodName, args);
|
||||
|
||||
return PublishAsync(_channelNamePrefix + ".user." + userId, message);
|
||||
}
|
||||
|
||||
private async Task PublishAsync(string channel, IRedisMessage message)
|
||||
{
|
||||
byte[] payload;
|
||||
using (var stream = new LimitArrayPoolWriteStream())
|
||||
using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream)))
|
||||
{
|
||||
_serializer.Serialize(writer, message);
|
||||
writer.Flush();
|
||||
payload = stream.ToArray();
|
||||
}
|
||||
|
||||
_logger.PublishToChannel(channel);
|
||||
await _bus.PublishAsync(channel, payload);
|
||||
var message = _protocol.WriteInvocation(methodName, args);
|
||||
return PublishAsync(_channels.User(userId), message);
|
||||
}
|
||||
|
||||
public override async Task AddGroupAsync(string connectionId, string groupName)
|
||||
|
|
@ -249,6 +233,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Add);
|
||||
}
|
||||
|
||||
public override async Task RemoveGroupAsync(string connectionId, string groupName)
|
||||
{
|
||||
if (connectionId == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(connectionId));
|
||||
}
|
||||
|
||||
if (groupName == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(groupName));
|
||||
}
|
||||
|
||||
var connection = _connections[connectionId];
|
||||
if (connection != null)
|
||||
{
|
||||
// short circuit if connection is on this server
|
||||
await RemoveGroupAsyncCore(connection, groupName);
|
||||
return;
|
||||
}
|
||||
|
||||
await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove);
|
||||
}
|
||||
|
||||
public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
|
||||
{
|
||||
if (connectionIds == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(connectionIds));
|
||||
}
|
||||
|
||||
var publishTasks = new List<Task>(connectionIds.Count);
|
||||
var payload = _protocol.WriteInvocation(methodName, args);
|
||||
|
||||
foreach (var connectionId in connectionIds)
|
||||
{
|
||||
publishTasks.Add(PublishAsync(_channels.Connection(connectionId), payload));
|
||||
}
|
||||
|
||||
return Task.WhenAll(publishTasks);
|
||||
}
|
||||
|
||||
public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
|
||||
{
|
||||
if (groupNames == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(groupNames));
|
||||
}
|
||||
var publishTasks = new List<Task>(groupNames.Count);
|
||||
var payload = _protocol.WriteInvocation(methodName, args);
|
||||
|
||||
foreach (var groupName in groupNames)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(groupName))
|
||||
{
|
||||
publishTasks.Add(PublishAsync(_channels.Group(groupName), payload));
|
||||
}
|
||||
}
|
||||
|
||||
return Task.WhenAll(publishTasks);
|
||||
}
|
||||
|
||||
public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
|
||||
{
|
||||
if (userIds.Count > 0)
|
||||
{
|
||||
var payload = _protocol.WriteInvocation(methodName, args);
|
||||
var publishTasks = new List<Task>(userIds.Count);
|
||||
foreach (var userId in userIds)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(userId))
|
||||
{
|
||||
publishTasks.Add(PublishAsync(_channels.User(userId), payload));
|
||||
}
|
||||
}
|
||||
|
||||
return Task.WhenAll(publishTasks);
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
private Task PublishAsync(string channel, byte[] payload)
|
||||
{
|
||||
RedisLog.PublishToChannel(_logger, channel);
|
||||
return _bus.PublishAsync(channel, payload);
|
||||
}
|
||||
|
||||
private async Task AddGroupAsyncCore(HubConnectionContext connection, string groupName)
|
||||
{
|
||||
var feature = connection.Features.Get<IRedisFeature>();
|
||||
|
|
@ -263,7 +334,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
var groupChannel = _channelNamePrefix + ".group." + groupName;
|
||||
var groupChannel = _channels.Group(groupName);
|
||||
var group = _groups.GetOrAdd(groupChannel, _ => new GroupData());
|
||||
|
||||
await group.Lock.WaitAsync();
|
||||
|
|
@ -285,37 +356,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
public override async Task RemoveGroupAsync(string connectionId, string groupName)
|
||||
{
|
||||
if (connectionId == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(connectionId));
|
||||
}
|
||||
|
||||
if (groupName == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(groupName));
|
||||
}
|
||||
|
||||
|
||||
var connection = _connections[connectionId];
|
||||
if (connection != null)
|
||||
{
|
||||
// short circuit if connection is on this server
|
||||
await RemoveGroupAsyncCore(connection, groupName);
|
||||
return;
|
||||
}
|
||||
|
||||
await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// This takes <see cref="HubConnectionContext"/> because we want to remove the connection from the
|
||||
/// _connections list in OnDisconnectedAsync and still be able to remove groups with this method.
|
||||
/// </summary>
|
||||
private async Task RemoveGroupAsyncCore(HubConnectionContext connection, string groupName)
|
||||
{
|
||||
var groupChannel = _channelNamePrefix + ".group." + groupName;
|
||||
var groupChannel = _channels.Group(groupName);
|
||||
|
||||
if (!_groups.TryGetValue(groupChannel, out var group))
|
||||
{
|
||||
|
|
@ -341,7 +388,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
if (group.Connections.Count == 0)
|
||||
{
|
||||
_logger.Unsubscribe(groupChannel);
|
||||
RedisLog.Unsubscribe(_logger, groupChannel);
|
||||
await _bus.UnsubscribeAsync(groupChannel);
|
||||
}
|
||||
}
|
||||
|
|
@ -350,8 +397,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
{
|
||||
group.Lock.Release();
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action)
|
||||
|
|
@ -359,14 +404,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var id = Interlocked.Increment(ref _internalId);
|
||||
var ack = _ackHandler.CreateAck(id);
|
||||
// Send Add/Remove Group to other servers and wait for an ack or timeout
|
||||
await PublishAsync(_channelNamePrefix + ".internal.group", new RedisGroupMessage
|
||||
{
|
||||
Action = action,
|
||||
ConnectionId = connectionId,
|
||||
Group = groupName,
|
||||
Id = id,
|
||||
Server = _serverName
|
||||
});
|
||||
var message = _protocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId));
|
||||
await PublishAsync(_channels.GroupManagement, message);
|
||||
|
||||
await ack;
|
||||
}
|
||||
|
|
@ -378,63 +417,24 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
_ackHandler.Dispose();
|
||||
}
|
||||
|
||||
private T DeserializeMessage<T>(RedisValue data)
|
||||
private void SubscribeToAll()
|
||||
{
|
||||
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream(data))))
|
||||
{
|
||||
return _serializer.Deserialize<T>(reader);
|
||||
}
|
||||
}
|
||||
|
||||
private void SubscribeToHub()
|
||||
{
|
||||
_logger.Subscribing(_channelNamePrefix);
|
||||
_bus.Subscribe(_channelNamePrefix, async (c, data) =>
|
||||
RedisLog.Subscribing(_logger, _channels.All);
|
||||
_bus.Subscribe(_channels.All, async (c, data) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
_logger.ReceivedFromChannel(_channelNamePrefix);
|
||||
RedisLog.ReceivedFromChannel(_logger, _channels.All);
|
||||
|
||||
var message = DeserializeMessage<RedisInvocationMessage>(data);
|
||||
var invocation = _protocol.ReadInvocation(data);
|
||||
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
var invocation = message.CreateInvocation();
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
tasks.Add(SafeWriteAsync(connection, invocation));
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void SubscribeToAllExcept()
|
||||
{
|
||||
var channelName = _channelNamePrefix + ".AllExcept";
|
||||
_logger.Subscribing(channelName);
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
_logger.ReceivedFromChannel(channelName);
|
||||
|
||||
var message = DeserializeMessage<RedisInvocationMessage>(data);
|
||||
var excludedIds = message.ExcludedIds ?? Array.Empty<string>();
|
||||
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
var invocation = message.CreateInvocation();
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
if (!excludedIds.Contains(connection.ConnectionId))
|
||||
if (invocation.ExcludedIds == null || !invocation.ExcludedIds.Contains(connection.ConnectionId))
|
||||
{
|
||||
tasks.Add(SafeWriteAsync(connection, invocation));
|
||||
tasks.Add(connection.WriteAsync(invocation.Message).AsTask());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -442,19 +442,18 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
RedisLog.FailedWritingMessage(_logger, ex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void SubscribeToInternalGroup()
|
||||
private void SubscribeToGroupManagementChannel()
|
||||
{
|
||||
var channelName = _channelNamePrefix + ".internal.group";
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
_bus.Subscribe(_channels.GroupManagement, async (c, data) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var groupMessage = DeserializeMessage<RedisGroupMessage>(data);
|
||||
var groupMessage = _protocol.ReadGroupCommand(data);
|
||||
|
||||
var connection = _connections[groupMessage.ConnectionId];
|
||||
if (connection == null)
|
||||
|
|
@ -465,179 +464,95 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
if (groupMessage.Action == GroupAction.Remove)
|
||||
{
|
||||
await RemoveGroupAsyncCore(connection, groupMessage.Group);
|
||||
await RemoveGroupAsyncCore(connection, groupMessage.GroupName);
|
||||
}
|
||||
|
||||
if (groupMessage.Action == GroupAction.Add)
|
||||
{
|
||||
await AddGroupAsyncCore(connection, groupMessage.Group);
|
||||
await AddGroupAsyncCore(connection, groupMessage.GroupName);
|
||||
}
|
||||
|
||||
// Sending ack to server that sent the original add/remove
|
||||
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new RedisGroupMessage
|
||||
{
|
||||
Action = GroupAction.Ack,
|
||||
Id = groupMessage.Id
|
||||
});
|
||||
// Send an ack to the server that sent the original command.
|
||||
await PublishAsync(_channels.Ack(groupMessage.ServerName), _protocol.WriteAck(groupMessage.Id));
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.InternalMessageFailed(ex);
|
||||
RedisLog.InternalMessageFailed(_logger, ex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void SubscribeToInternalServerName()
|
||||
private void SubscribeToAckChannel()
|
||||
{
|
||||
// Create server specific channel in order to send an ack to a single server
|
||||
var serverChannel = $"{_channelNamePrefix}.internal.{_serverName}";
|
||||
_bus.Subscribe(serverChannel, (c, data) =>
|
||||
_bus.Subscribe(_channels.Ack(_serverName), (c, data) =>
|
||||
{
|
||||
var groupMessage = DeserializeMessage<RedisGroupMessage>(data);
|
||||
var ackId = _protocol.ReadAck(data);
|
||||
|
||||
if (groupMessage.Action == GroupAction.Ack)
|
||||
{
|
||||
_ackHandler.TriggerAck(groupMessage.Id);
|
||||
}
|
||||
_ackHandler.TriggerAck(ackId);
|
||||
});
|
||||
}
|
||||
|
||||
private Task SubscribeToConnection(HubConnectionContext connection, HashSet<string> redisSubscriptions)
|
||||
{
|
||||
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
|
||||
var connectionChannel = _channels.Connection(connection.ConnectionId);
|
||||
redisSubscriptions.Add(connectionChannel);
|
||||
|
||||
_logger.Subscribing(connectionChannel);
|
||||
RedisLog.Subscribing(_logger, connectionChannel);
|
||||
return _bus.SubscribeAsync(connectionChannel, async (c, data) =>
|
||||
{
|
||||
var message = DeserializeMessage<RedisInvocationMessage>(data);
|
||||
|
||||
await SafeWriteAsync(connection, message.CreateInvocation());
|
||||
var invocation = _protocol.ReadInvocation(data);
|
||||
await connection.WriteAsync(invocation.Message);
|
||||
});
|
||||
}
|
||||
|
||||
private Task SubscribeToUser(HubConnectionContext connection, HashSet<string> redisSubscriptions)
|
||||
{
|
||||
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
|
||||
var userChannel = _channels.User(connection.UserIdentifier);
|
||||
redisSubscriptions.Add(userChannel);
|
||||
|
||||
// TODO: Look at optimizing (looping over connections checking for Name)
|
||||
return _bus.SubscribeAsync(userChannel, async (c, data) =>
|
||||
{
|
||||
var message = DeserializeMessage<RedisInvocationMessage>(data);
|
||||
|
||||
await SafeWriteAsync(connection, message.CreateInvocation());
|
||||
var invocation = _protocol.ReadInvocation(data);
|
||||
await connection.WriteAsync(invocation.Message);
|
||||
});
|
||||
}
|
||||
|
||||
private Task SubscribeToGroup(string groupChannel, GroupData group)
|
||||
{
|
||||
_logger.Subscribing(groupChannel);
|
||||
RedisLog.Subscribing(_logger, groupChannel);
|
||||
return _bus.SubscribeAsync(groupChannel, async (c, data) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var message = DeserializeMessage<RedisInvocationMessage>(data);
|
||||
var invocation = _protocol.ReadInvocation(data);
|
||||
|
||||
var tasks = new List<Task>();
|
||||
var invocation = message.CreateInvocation();
|
||||
foreach (var groupConnection in group.Connections)
|
||||
{
|
||||
if (message.ExcludedIds?.Contains(groupConnection.ConnectionId) == true)
|
||||
if (invocation.ExcludedIds?.Contains(groupConnection.ConnectionId) == true)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
tasks.Add(SafeWriteAsync(groupConnection, invocation));
|
||||
tasks.Add(groupConnection.WriteAsync(invocation.Message).AsTask());
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
RedisLog.FailedWritingMessage(_logger, ex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
|
||||
private static string GenerateServerName()
|
||||
{
|
||||
if (connectionIds == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(connectionIds));
|
||||
}
|
||||
var publishTasks = new List<Task>(connectionIds.Count);
|
||||
var message = new RedisInvocationMessage(target: methodName, arguments: args);
|
||||
|
||||
foreach (var connectionId in connectionIds)
|
||||
{
|
||||
var connection = _connections[connectionId];
|
||||
// If the connection is local we can skip sending the message through the bus since we require sticky connections.
|
||||
// This also saves serializing and deserializing the message!
|
||||
if (connection != null)
|
||||
{
|
||||
publishTasks.Add(SafeWriteAsync(connection, message.CreateInvocation()));
|
||||
}
|
||||
else
|
||||
{
|
||||
publishTasks.Add(PublishAsync(_channelNamePrefix + "." + connectionId, message));
|
||||
}
|
||||
}
|
||||
|
||||
return Task.WhenAll(publishTasks);
|
||||
}
|
||||
|
||||
public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
|
||||
{
|
||||
if (groupNames == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(groupNames));
|
||||
}
|
||||
var publishTasks = new List<Task>(groupNames.Count);
|
||||
var message = new RedisInvocationMessage(target: methodName, arguments: args);
|
||||
|
||||
foreach (var groupName in groupNames)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(groupName))
|
||||
{
|
||||
publishTasks.Add(PublishAsync(_channelNamePrefix + "." + groupName, message));
|
||||
}
|
||||
}
|
||||
|
||||
return Task.WhenAll(publishTasks);
|
||||
}
|
||||
|
||||
public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
|
||||
{
|
||||
if (userIds.Count > 0)
|
||||
{
|
||||
var message = new RedisInvocationMessage(methodName, args);
|
||||
var publishTasks = new List<Task>(userIds.Count);
|
||||
foreach (var userId in userIds)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(userId))
|
||||
{
|
||||
publishTasks.Add(PublishAsync(_channelNamePrefix + ".user." + userId, message));
|
||||
}
|
||||
}
|
||||
|
||||
return Task.WhenAll(publishTasks);
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
// This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to
|
||||
private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message)
|
||||
{
|
||||
try
|
||||
{
|
||||
await connection.WriteAsync(message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
// Use the machine name for convenient diagnostics, but add a guid to make it unique.
|
||||
// Example: MyServerName_02db60e5fab243b890a847fa5c4dcb29
|
||||
return $"{Environment.MachineName}_{Guid.NewGuid():N}";
|
||||
}
|
||||
|
||||
private class LoggerTextWriter : TextWriter
|
||||
|
|
@ -658,7 +573,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
public override void WriteLine(string value)
|
||||
{
|
||||
_logger.LogDebug(value);
|
||||
RedisLog.ConnectionMultiplexerMessage(_logger, value);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -679,53 +594,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
public HashSet<string> Subscriptions { get; } = new HashSet<string>();
|
||||
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
|
||||
}
|
||||
|
||||
private enum GroupAction
|
||||
{
|
||||
Remove,
|
||||
Add,
|
||||
Ack
|
||||
}
|
||||
|
||||
// Marker interface to represent the messages that can be sent over Redis.
|
||||
private interface IRedisMessage { }
|
||||
|
||||
private class RedisGroupMessage : IRedisMessage
|
||||
{
|
||||
public string ConnectionId { get; set; }
|
||||
public string Group { get; set; }
|
||||
public int Id { get; set; }
|
||||
public GroupAction Action { get; set; }
|
||||
public string Server { get; set; }
|
||||
}
|
||||
|
||||
// Represents a message published to the Redis bus
|
||||
private class RedisInvocationMessage : IRedisMessage
|
||||
{
|
||||
public string Target { get; set; }
|
||||
public IReadOnlyList<string> ExcludedIds { get; set; }
|
||||
public object[] Arguments { get; set; }
|
||||
|
||||
public RedisInvocationMessage()
|
||||
{
|
||||
}
|
||||
|
||||
public RedisInvocationMessage(string target, object[] arguments)
|
||||
: this(target, excludedIds: null, arguments: arguments)
|
||||
{
|
||||
}
|
||||
|
||||
public RedisInvocationMessage(string target, IReadOnlyList<string> excludedIds, object[] arguments)
|
||||
{
|
||||
Target = target;
|
||||
ExcludedIds = excludedIds;
|
||||
Arguments = arguments;
|
||||
}
|
||||
|
||||
public InvocationMessage CreateInvocation()
|
||||
{
|
||||
return new InvocationMessage(Target, argumentBindingException: null, arguments: Arguments);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,13 +6,14 @@ using System.Linq;
|
|||
using Microsoft.Extensions.Logging;
|
||||
using StackExchange.Redis;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis
|
||||
{
|
||||
internal static class RedisLoggerExtensions
|
||||
// We don't want to use our nested static class here because RedisHubLifetimeManager is generic.
|
||||
// We'd end up creating separate instances of all the LoggerMessage.Define values for each Hub.
|
||||
internal static class RedisLog
|
||||
{
|
||||
// Category: RedisHubLifetimeManager<THub>
|
||||
private static readonly Action<ILogger, string, Exception> _connectingToEndpoints =
|
||||
LoggerMessage.Define<string>(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}.");
|
||||
private static readonly Action<ILogger, string, string, Exception> _connectingToEndpoints =
|
||||
LoggerMessage.Define<string, string>(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}. Using Server Name: {ServerName}");
|
||||
|
||||
private static readonly Action<ILogger, Exception> _connected =
|
||||
LoggerMessage.Define(LogLevel.Information, new EventId(2, "Connected"), "Connected to Redis.");
|
||||
|
|
@ -44,65 +45,75 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
|||
private static readonly Action<ILogger, Exception> _internalMessageFailed =
|
||||
LoggerMessage.Define(LogLevel.Warning, new EventId(11, "InternalMessageFailed"), "Error processing message for internal server message.");
|
||||
|
||||
public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints)
|
||||
public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endpoints, string serverName)
|
||||
{
|
||||
if (logger.IsEnabled(LogLevel.Information))
|
||||
{
|
||||
if (endpoints.Count > 0)
|
||||
{
|
||||
_connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), null);
|
||||
_connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), serverName, null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static void Connected(this ILogger logger)
|
||||
public static void Connected(ILogger logger)
|
||||
{
|
||||
_connected(logger, null);
|
||||
}
|
||||
|
||||
public static void Subscribing(this ILogger logger, string channelName)
|
||||
public static void Subscribing(ILogger logger, string channelName)
|
||||
{
|
||||
_subscribing(logger, channelName, null);
|
||||
}
|
||||
|
||||
public static void ReceivedFromChannel(this ILogger logger, string channelName)
|
||||
public static void ReceivedFromChannel(ILogger logger, string channelName)
|
||||
{
|
||||
_receivedFromChannel(logger, channelName, null);
|
||||
}
|
||||
|
||||
public static void PublishToChannel(this ILogger logger, string channelName)
|
||||
public static void PublishToChannel(ILogger logger, string channelName)
|
||||
{
|
||||
_publishToChannel(logger, channelName, null);
|
||||
}
|
||||
|
||||
public static void Unsubscribe(this ILogger logger, string channelName)
|
||||
public static void Unsubscribe(ILogger logger, string channelName)
|
||||
{
|
||||
_unsubscribe(logger, channelName, null);
|
||||
}
|
||||
|
||||
public static void NotConnected(this ILogger logger)
|
||||
public static void NotConnected(ILogger logger)
|
||||
{
|
||||
_notConnected(logger, null);
|
||||
}
|
||||
|
||||
public static void ConnectionRestored(this ILogger logger)
|
||||
public static void ConnectionRestored(ILogger logger)
|
||||
{
|
||||
_connectionRestored(logger, null);
|
||||
}
|
||||
|
||||
public static void ConnectionFailed(this ILogger logger, Exception exception)
|
||||
public static void ConnectionFailed(ILogger logger, Exception exception)
|
||||
{
|
||||
_connectionFailed(logger, exception);
|
||||
}
|
||||
|
||||
public static void FailedWritingMessage(this ILogger logger, Exception exception)
|
||||
public static void FailedWritingMessage(ILogger logger, Exception exception)
|
||||
{
|
||||
_failedWritingMessage(logger, exception);
|
||||
}
|
||||
|
||||
public static void InternalMessageFailed(this ILogger logger, Exception exception)
|
||||
public static void InternalMessageFailed(ILogger logger, Exception exception)
|
||||
{
|
||||
_internalMessageFailed(logger, exception);
|
||||
}
|
||||
|
||||
// This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer
|
||||
public static void ConnectionMultiplexerMessage(ILogger logger, string message)
|
||||
{
|
||||
if (logger.IsEnabled(LogLevel.Debug))
|
||||
{
|
||||
// We tag it with EventId 100 though so it can be pulled out of logs easily.
|
||||
logger.LogDebug(new EventId(100, "RedisConnectionLog"), message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -5,11 +5,15 @@ using System;
|
|||
using System.Collections.Generic;
|
||||
using System.Threading.Channels;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Tests;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Moq;
|
||||
using MsgPack.Serialization;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using Newtonsoft.Json.Serialization;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||
|
|
@ -19,14 +23,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeAllAsyncWritesToAllConnectionsOutput()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager = CreateLifetimeManager(server);
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
|
||||
|
||||
|
|
@ -40,17 +42,44 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAllExceptAsyncExcludesSpecifiedConnections()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
using (var client3 = new TestClient())
|
||||
{
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
var manager3 = CreateLifetimeManager(server);
|
||||
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
|
||||
var connection3 = HubConnectionContextUtils.Create(client3.Connection);
|
||||
|
||||
await manager1.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager2.OnConnectedAsync(connection2).OrTimeout();
|
||||
await manager3.OnConnectedAsync(connection3).OrTimeout();
|
||||
|
||||
await manager1.SendAllExceptAsync("Hello", new object[] { "World" }, new [] { client3.Connection.ConnectionId }).OrTimeout();
|
||||
|
||||
await AssertMessageAsync(client1);
|
||||
await AssertMessageAsync(client2);
|
||||
Assert.Null(client3.TryRead());
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager = CreateLifetimeManager(server);
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
|
||||
|
||||
|
|
@ -70,14 +99,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager = CreateLifetimeManager(server);
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
|
||||
|
||||
|
|
@ -96,14 +123,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeGroupExceptAsyncWritesToAllValidConnectionsInGroupOutput()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager = CreateLifetimeManager(server);
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
|
||||
|
||||
|
|
@ -124,13 +149,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeConnectionAsyncWritesToConnectionOutput()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager = CreateLifetimeManager(server);
|
||||
var connection = HubConnectionContextUtils.Create(client.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection).OrTimeout();
|
||||
|
|
@ -144,27 +167,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeConnectionAsyncOnNonExistentConnectionDoesNotThrow()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
await manager.SendConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeAllAsyncWithMultipleServersWritesToAllConnectionsOutput()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
|
|
@ -185,16 +200,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeAllAsyncWithMultipleServersDoesNotWriteToDisconnectedConnectionsOutput()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
|
|
@ -218,16 +227,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeConnectionAsyncOnServerWithoutConnectionWritesOutputToConnection()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -244,16 +247,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeGroupAsyncOnServerWithoutConnectionWritesOutputToGroupConnection()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -272,11 +269,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task DisconnectConnectionRemovesConnectionFromGroup()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -297,11 +292,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task RemoveGroupFromLocalConnectionNotInGroupDoesNothing()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -316,16 +309,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task RemoveGroupFromConnectionOnDifferentServerNotInGroupDoesNothing()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
|
||||
Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -340,14 +327,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task AddGroupAsyncForConnectionOnDifferentServerWorks()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -366,10 +349,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task AddGroupAsyncForLocalConnectionAlreadyInGroupDoesNothing()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -382,7 +364,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
|
||||
await manager.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout();
|
||||
|
||||
|
||||
await AssertMessageAsync(client);
|
||||
Assert.Null(client.TryRead());
|
||||
}
|
||||
|
|
@ -391,14 +372,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task AddGroupAsyncForConnectionOnDifferentServerAlreadyInGroupDoesNothing()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -419,14 +396,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task RemoveGroupAsyncForConnectionOnDifferentServerWorks()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -451,14 +424,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task InvokeConnectionAsyncForLocalConnectionDoesNotPublishToRedis()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -478,14 +447,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
[Fact]
|
||||
public async Task WritingToRemoteConnectionThatFailsDoesNotThrow()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager1 = CreateLifetimeManager(server);
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
|
|
@ -502,34 +467,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritingToLocalConnectionThatFailsDoesNotThrowException()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
// Force an exception when writing to connection
|
||||
var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection);
|
||||
connectionMock.Setup(m => m.WriteAsync(It.IsAny<HubMessage>())).Throws(new Exception("Message"));
|
||||
var connection = connectionMock.Object;
|
||||
|
||||
await manager.OnConnectedAsync(connection).OrTimeout();
|
||||
|
||||
await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
|
|
@ -557,6 +500,72 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var messagePackOptions = new MessagePackHubProtocolOptions();
|
||||
messagePackOptions.SerializationContext.DictionarySerlaizationOptions.KeyTransformer = DictionaryKeyTransformers.LowerCamel;
|
||||
|
||||
var jsonOptions = new JsonHubProtocolOptions();
|
||||
jsonOptions.PayloadSerializerSettings.ContractResolver = new CamelCasePropertyNamesContractResolver();
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
{
|
||||
// The sending manager has serializer settings
|
||||
var manager1 = CreateLifetimeManager(server, messagePackOptions, jsonOptions);
|
||||
|
||||
// The receiving one doesn't matter because of how we serialize!
|
||||
var manager2 = CreateLifetimeManager(server);
|
||||
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
|
||||
|
||||
await manager1.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager2.OnConnectedAsync(connection2).OrTimeout();
|
||||
|
||||
await manager1.SendAllAsync("Hello", new object[] { new TestObject { TestProperty = "Foo" } });
|
||||
|
||||
var message = Assert.IsType<InvocationMessage>(await client2.ReadAsync().OrTimeout());
|
||||
Assert.Equal("Hello", message.Target);
|
||||
Assert.Collection(
|
||||
message.Arguments,
|
||||
arg0 =>
|
||||
{
|
||||
var dict = Assert.IsType<JObject>(arg0);
|
||||
Assert.Collection(dict.Properties(),
|
||||
prop =>
|
||||
{
|
||||
Assert.Equal("testProperty", prop.Name);
|
||||
Assert.Equal("Foo", prop.Value.Value<string>());
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
public class TestObject
|
||||
{
|
||||
public string TestProperty { get; set; }
|
||||
}
|
||||
|
||||
private RedisHubLifetimeManager<MyHub> CreateLifetimeManager(TestRedisServer server, MessagePackHubProtocolOptions messagePackOptions = null, JsonHubProtocolOptions jsonOptions = null)
|
||||
{
|
||||
var options = new RedisOptions() { Factory = t => new TestConnectionMultiplexer(server) };
|
||||
messagePackOptions = messagePackOptions ?? new MessagePackHubProtocolOptions();
|
||||
jsonOptions = jsonOptions ?? new JsonHubProtocolOptions();
|
||||
|
||||
return new RedisHubLifetimeManager<MyHub>(
|
||||
NullLogger<RedisHubLifetimeManager<MyHub>>.Instance,
|
||||
Options.Create(options),
|
||||
new DefaultHubProtocolResolver(new IHubProtocol[]
|
||||
{
|
||||
new JsonHubProtocol(Options.Create(jsonOptions)),
|
||||
new MessagePackHubProtocol(Options.Create(messagePackOptions)),
|
||||
}, NullLogger<DefaultHubProtocolResolver>.Instance));
|
||||
}
|
||||
|
||||
private async Task AssertMessageAsync(TestClient client)
|
||||
{
|
||||
var message = Assert.IsType<InvocationMessage>(await client.ReadAsync().OrTimeout());
|
||||
|
|
|
|||
|
|
@ -11,11 +11,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
public static class HubConnectionContextUtils
|
||||
{
|
||||
public static HubConnectionContext Create(ConnectionContext connection)
|
||||
public static HubConnectionContext Create(ConnectionContext connection, IHubProtocol protocol = null, string userIdentifier = null)
|
||||
{
|
||||
return new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance)
|
||||
{
|
||||
Protocol = new JsonHubProtocol()
|
||||
Protocol = protocol ?? new JsonHubProtocol(),
|
||||
UserIdentifier = userIdentifier,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@
|
|||
<PackageReference Include="Microsoft.Extensions.ValueStopwatch.Sources" Version="$(MicrosoftExtensionsValueStopwatchSourcesPackageVersion)" PrivateAssets="All" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Hosting" Version="$(MicrosoftAspNetCoreHostingPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="$(MicrosoftAspNetCoreServerKestrelPackageVersion)" />
|
||||
<PackageReference Include="StackExchange.Redis.StrongName" Version="$(StackExchangeRedisStrongNamePackageVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
|
|
@ -9,7 +9,7 @@ using System.Net;
|
|||
using System.Threading.Tasks;
|
||||
using StackExchange.Redis;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
public class TestConnectionMultiplexer : IConnectionMultiplexer
|
||||
{
|
||||
|
|
@ -70,7 +70,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
remove { }
|
||||
}
|
||||
|
||||
private ISubscriber _subscriber = new TestSubscriber();
|
||||
private ISubscriber _subscriber;
|
||||
|
||||
public TestConnectionMultiplexer(TestRedisServer server)
|
||||
{
|
||||
_subscriber = new TestSubscriber(server);
|
||||
}
|
||||
|
||||
public void BeginProfiling(object forContext)
|
||||
{
|
||||
|
|
@ -203,19 +208,52 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
public class TestSubscriber : ISubscriber
|
||||
public class TestRedisServer
|
||||
{
|
||||
// _globalSubscriptions represents the Redis Server you are connected to.
|
||||
// So when publishing from a TestSubscriber you fake sending through the server by grabbing the callbacks
|
||||
// from the _globalSubscriptions and inoking them inplace.
|
||||
private static ConcurrentDictionary<RedisChannel, List<Action<RedisChannel, RedisValue>>> _globalSubscriptions =
|
||||
private ConcurrentDictionary<RedisChannel, List<Action<RedisChannel, RedisValue>>> _subscriptions =
|
||||
new ConcurrentDictionary<RedisChannel, List<Action<RedisChannel, RedisValue>>>();
|
||||
|
||||
private ConcurrentDictionary<RedisChannel, Action<RedisChannel, RedisValue>> _subscriptions =
|
||||
new ConcurrentDictionary<RedisChannel, Action<RedisChannel, RedisValue>>();
|
||||
public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
if (_subscriptions.TryGetValue(channel, out var handlers))
|
||||
{
|
||||
foreach (var handler in handlers)
|
||||
{
|
||||
handler(channel, message);
|
||||
}
|
||||
}
|
||||
|
||||
return handlers != null ? handlers.Count : 0;
|
||||
}
|
||||
|
||||
public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
_subscriptions.AddOrUpdate(channel, _ => new List<Action<RedisChannel, RedisValue>> { handler }, (_, list) =>
|
||||
{
|
||||
list.Add(handler);
|
||||
return list;
|
||||
});
|
||||
}
|
||||
|
||||
public void Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
if (_subscriptions.TryGetValue(channel, out var list))
|
||||
{
|
||||
list.Remove(handler);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class TestSubscriber : ISubscriber
|
||||
{
|
||||
private readonly TestRedisServer _server;
|
||||
public ConnectionMultiplexer Multiplexer => throw new NotImplementedException();
|
||||
|
||||
public TestSubscriber(TestRedisServer server)
|
||||
{
|
||||
_server = server;
|
||||
}
|
||||
|
||||
public EndPoint IdentifyEndpoint(RedisChannel channel, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
|
|
@ -243,15 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
|
||||
public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
if (_globalSubscriptions.TryGetValue(channel, out var handlers))
|
||||
{
|
||||
foreach (var handler in handlers)
|
||||
{
|
||||
handler(channel, message);
|
||||
}
|
||||
}
|
||||
|
||||
return handlers != null ? handlers.Count : 0;
|
||||
return _server.Publish(channel, message, flags);
|
||||
}
|
||||
|
||||
public async Task<long> PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
|
||||
|
|
@ -262,12 +292,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
|
||||
public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
_globalSubscriptions.AddOrUpdate(channel, _ => new List<Action<RedisChannel, RedisValue>> { handler }, (_, list) =>
|
||||
{
|
||||
list.Add(handler);
|
||||
return list;
|
||||
});
|
||||
_subscriptions.AddOrUpdate(channel, handler, (_, __) => handler);
|
||||
_server.Subscribe(channel, handler, flags);
|
||||
}
|
||||
|
||||
public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
|
||||
|
|
@ -288,11 +313,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
|
||||
public void Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
|
||||
{
|
||||
_subscriptions.TryRemove(channel, out var handle);
|
||||
if (_globalSubscriptions.TryGetValue(channel, out var list))
|
||||
{
|
||||
list.Remove(handle);
|
||||
}
|
||||
_server.Unsubscribe(channel, handler, flags);
|
||||
}
|
||||
|
||||
public void UnsubscribeAll(CommandFlags flags = CommandFlags.None)
|
||||
Loading…
Reference in New Issue