Preserialize for all formats when sending through Redis (#1843)

This commit is contained in:
Andrew Stanton-Nurse 2018-04-05 13:48:14 -07:00 committed by GitHub
parent 39f693b9ed
commit 19b2fea0d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1200 additions and 602 deletions

View File

@ -49,7 +49,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
_successHubProtocolResolver = new TestHubProtocolResolver(new JsonHubProtocol());
_failureHubProtocolResolver = new TestHubProtocolResolver(null);
_userIdProvider = new TestUserIdProvider();
_supportedProtocols = new List<string> {"json"};
_supportedProtocols = new List<string> { "json" };
}
[Benchmark]
@ -83,8 +83,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
private readonly IHubProtocol _instance;
public IReadOnlyList<IHubProtocol> AllProtocols { get; }
public TestHubProtocolResolver(IHubProtocol instance)
{
AllProtocols = new[] { instance };
_instance = instance;
}

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
@ -12,6 +12,8 @@
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Client.Core\Microsoft.AspNetCore.SignalR.Client.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Protocols.MsgPack\Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Redis\Microsoft.AspNetCore.SignalR.Redis.csproj" />
<ProjectReference Include="..\..\test\Microsoft.AspNetCore.SignalR.Tests.Utils\Microsoft.AspNetCore.SignalR.Tests.Utils.csproj" />
</ItemGroup>
<ItemGroup>

View File

@ -0,0 +1,203 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Redis;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
public class RedisHubLifetimeManagerBenchmark
{
private RedisHubLifetimeManager<TestHub> _manager1;
private RedisHubLifetimeManager<TestHub> _manager2;
private TestClient[] _clients;
private object[] _args;
private List<string> _excludedIds = new List<string>();
private List<string> _sendIds = new List<string>();
private List<string> _groups = new List<string>();
private List<string> _users = new List<string>();
private const int ClientCount = 20;
[Params(2, 20)]
public int ProtocolCount { get; set; }
[GlobalSetup]
public void GlobalSetup()
{
var server = new TestRedisServer();
var logger = NullLogger<RedisHubLifetimeManager<TestHub>>.Instance;
var protocols = GenerateProtocols(ProtocolCount).ToArray();
var options = Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer(server)
});
var resolver = new DefaultHubProtocolResolver(protocols, NullLogger<DefaultHubProtocolResolver>.Instance);
_manager1 = new RedisHubLifetimeManager<TestHub>(logger, options, resolver);
_manager2 = new RedisHubLifetimeManager<TestHub>(logger, options, resolver);
async Task ConnectClient(TestClient client, IHubProtocol protocol, string userId, string group)
{
await _manager2.OnConnectedAsync(HubConnectionContextUtils.Create(client.Connection, protocol, userId));
await _manager2.AddGroupAsync(client.Connection.ConnectionId, "Everyone");
await _manager2.AddGroupAsync(client.Connection.ConnectionId, group);
}
// Connect clients
_clients = new TestClient[ClientCount];
var tasks = new Task[ClientCount];
for (var i = 0; i < _clients.Length; i++)
{
var protocol = protocols[i % ProtocolCount];
_clients[i] = new TestClient(protocol: protocol);
string group;
string user;
if ((i % 2) == 0)
{
group = "Evens";
user = "EvenUser";
_excludedIds.Add(_clients[i].Connection.ConnectionId);
}
else
{
group = "Odds";
user = "OddUser";
_sendIds.Add(_clients[i].Connection.ConnectionId);
}
tasks[i] = ConnectClient(_clients[i], protocol, user, group);
_ = ConsumeAsync(_clients[i]);
}
Task.WaitAll(tasks);
_groups.Add("Evens");
_groups.Add("Odds");
_users.Add("EvenUser");
_users.Add("OddUser");
_args = new object[] {"Foo"};
}
private IEnumerable<IHubProtocol> GenerateProtocols(int protocolCount)
{
for (var i = 0; i < protocolCount; i++)
{
yield return ((i % 2) == 0)
? new WrappedHubProtocol($"json_{i}", new JsonHubProtocol())
: new WrappedHubProtocol($"msgpack_{i}", new MessagePackHubProtocol());
}
}
private async Task ConsumeAsync(TestClient testClient)
{
while (await testClient.ReadAsync() != null)
{
// Just dump the message
}
}
[Benchmark]
public async Task SendAll()
{
await _manager1.SendAllAsync("Test", _args);
}
[Benchmark]
public async Task SendGroup()
{
await _manager1.SendGroupAsync("Everyone", "Test", _args);
}
[Benchmark]
public async Task SendUser()
{
await _manager1.SendUserAsync("EvenUser", "Test", _args);
}
[Benchmark]
public async Task SendConnection()
{
await _manager1.SendConnectionAsync(_clients[0].Connection.ConnectionId, "Test", _args);
}
[Benchmark]
public async Task SendConnections()
{
await _manager1.SendConnectionsAsync(_sendIds, "Test", _args);
}
[Benchmark]
public async Task SendAllExcept()
{
await _manager1.SendAllExceptAsync("Test", _args, _excludedIds);
}
[Benchmark]
public async Task SendGroupExcept()
{
await _manager1.SendGroupExceptAsync("Everyone", "Test", _args, _excludedIds);
}
[Benchmark]
public async Task SendGroups()
{
await _manager1.SendGroupsAsync(_groups, "Test", _args);
}
[Benchmark]
public async Task SendUsers()
{
await _manager1.SendUsersAsync(_users, "Test", _args);
}
public class TestHub : Hub
{
}
private class WrappedHubProtocol : IHubProtocol
{
private readonly string _name;
private readonly IHubProtocol _innerProtocol;
public string Name => _name;
public int Version => _innerProtocol.Version;
public TransferFormat TransferFormat => _innerProtocol.TransferFormat;
public WrappedHubProtocol(string name, IHubProtocol innerProtocol)
{
_name = name;
_innerProtocol = innerProtocol;
}
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
return _innerProtocol.TryParseMessage(ref input, binder, out message);
}
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
_innerProtocol.WriteMessage(message, output);
}
public bool IsVersionSupported(int version)
{
return _innerProtocol.IsVersionSupported(version);
}
}
}
}

View File

@ -28,7 +28,7 @@ namespace SignalRSamples
{
options.SerializationContext.DictionarySerlaizationOptions.KeyTransformer = DictionaryKeyTransformers.LowerCamel;
});
// .AddRedis();
//.AddRedis();
services.AddCors(o =>
{

View File

@ -1,87 +1,9 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public abstract class HubMessage
{
protected HubMessage()
{
}
private object _lock = new object();
private List<SerializedMessage> _serializedMessages;
private SerializedMessage _message1;
private SerializedMessage _message2;
public byte[] WriteMessage(IHubProtocol protocol)
{
// REVIEW: Revisit lock
// Could use a reader/writer lock to allow the loop to take place in "unlocked" code
// Or, could use a fixed size array and Interlocked to manage it.
// Or, Immutable *ducks*
lock (_lock)
{
if (ReferenceEquals(_message1.Protocol, protocol))
{
return _message1.Message;
}
if (ReferenceEquals(_message2.Protocol, protocol))
{
return _message2.Message;
}
for (var i = 0; i < _serializedMessages?.Count; i++)
{
if (ReferenceEquals(_serializedMessages[i].Protocol, protocol))
{
return _serializedMessages[i].Message;
}
}
var bytes = protocol.WriteToArray(this);
if (_message1.Protocol == null)
{
_message1 = new SerializedMessage(protocol, bytes);
}
else if (_message2.Protocol == null)
{
_message2 = new SerializedMessage(protocol, bytes);
}
else
{
if (_serializedMessages == null)
{
_serializedMessages = new List<SerializedMessage>();
}
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
// So we cap how many caches we store and worst case just serialize the message for every connection
if (_serializedMessages.Count < 10)
{
_serializedMessages.Add(new SerializedMessage(protocol, bytes));
}
}
return bytes;
}
}
private readonly struct SerializedMessage
{
public readonly IHubProtocol Protocol;
public readonly byte[] Message;
public SerializedMessage(IHubProtocol protocol, byte[] message)
{
Protocol = protocol;
Message = message;
}
}
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;
@ -210,9 +211,9 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
private InvocationMessage CreateInvocationMessage(string methodName, object[] args)
private SerializedHubMessage CreateInvocationMessage(string methodName, object[] args)
{
return new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args);
return new SerializedHubMessage(new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args));
}
public override Task SendUserAsync(string userId, string methodName, object[] args)

View File

@ -6,19 +6,16 @@ using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Runtime.ExceptionServices;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;
@ -73,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR
public virtual PipeReader Input => _connectionContext.Transport.Input;
public string UserIdentifier { get; private set; }
public string UserIdentifier { get; set; }
internal virtual IHubProtocol Protocol { get; set; }
@ -84,7 +81,36 @@ namespace Microsoft.AspNetCore.SignalR
public virtual ValueTask WriteAsync(HubMessage message)
{
// We were unable to get the lock so take the slow async path of waiting for the semaphore
// Try to grab the lock synchronously, if we fail, go to the slower path
if (!_writeLock.Wait(0))
{
return new ValueTask(WriteSlowAsync(message));
}
// This method should never throw synchronously
var task = WriteCore(message);
// The write didn't complete synchronously so await completion
if (!task.IsCompletedSuccessfully)
{
return new ValueTask(CompleteWriteAsync(task));
}
// Otherwise, release the lock acquired when entering WriteAsync
_writeLock.Release();
return default;
}
/// <summary>
/// This method is designed to support the framework and is not intended to be used by application code. Writes a pre-serialized message to the
/// connection.
/// </summary>
/// <param name="message">The serialization cache to use.</param>
/// <returns></returns>
public virtual ValueTask WriteAsync(SerializedHubMessage message)
{
// Try to grab the lock synchronously, if we fail, go to the slower path
if (!_writeLock.Wait(0))
{
return new ValueTask(WriteSlowAsync(message));
@ -109,14 +135,28 @@ namespace Microsoft.AspNetCore.SignalR
{
try
{
// This will internally cache the buffer for each unique HubProtocol
// So that we don't serialize the HubMessage for every single connection
var buffer = message.WriteMessage(Protocol);
// We know that we are only writing this message to one receiver, so we can
// write it without caching.
Protocol.WriteMessage(message, _connectionContext.Transport.Output);
var output = _connectionContext.Transport.Output;
output.Write(buffer);
return _connectionContext.Transport.Output.FlushAsync();
}
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
return output.FlushAsync();
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
}
}
private ValueTask<FlushResult> WriteCore(SerializedHubMessage message)
{
try
{
// Grab a preserialized buffer for this protocol.
var buffer = message.GetSerializedMessage(Protocol);
return _connectionContext.Transport.Output.WriteAsync(buffer);
}
catch (Exception ex)
{
@ -162,6 +202,25 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task WriteSlowAsync(SerializedHubMessage message)
{
try
{
// Failed to get the lock immediately when entering WriteAsync so await until it is available
await _writeLock.WaitAsync();
await WriteCore(message);
}
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
}
finally
{
_writeLock.Release();
}
}
private ValueTask TryWritePingAsync()
{
// Don't wait for the lock, if it returns false that means someone wrote to the connection

View File

@ -7,21 +7,25 @@ using System.Linq;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public class DefaultHubProtocolResolver : IHubProtocolResolver
{
private readonly ILogger<DefaultHubProtocolResolver> _logger;
private readonly List<IHubProtocol> _hubProtocols;
private readonly Dictionary<string, IHubProtocol> _availableProtocols;
public IReadOnlyList<IHubProtocol> AllProtocols => _hubProtocols;
public DefaultHubProtocolResolver(IEnumerable<IHubProtocol> availableProtocols, ILogger<DefaultHubProtocolResolver> logger)
{
_logger = logger ?? NullLogger<DefaultHubProtocolResolver>.Instance;
_availableProtocols = new Dictionary<string, IHubProtocol>(StringComparer.OrdinalIgnoreCase);
foreach (var protocol in availableProtocols)
// We might get duplicates in _hubProtocols, but we're going to check it and throw in just a sec.
_hubProtocols = availableProtocols.ToList();
foreach (var protocol in _hubProtocols)
{
if (_availableProtocols.ContainsKey(protocol.Name))
{

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
@ -8,6 +8,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
public interface IHubProtocolResolver
{
IReadOnlyList<IHubProtocol> AllProtocols { get; }
IHubProtocol GetProtocol(string protocolName, IList<string> supportedProtocols);
}
}

View File

@ -0,0 +1,161 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Internal
{
/// <summary>
/// This class is designed to support the framework. The API is subject to breaking changes.
/// Represents a serialization cache for a single message.
/// </summary>
public class SerializedHubMessage
{
private SerializedMessage _cachedItem1;
private SerializedMessage _cachedItem2;
private IList<SerializedMessage> _cachedItems;
public HubMessage Message { get; }
private SerializedHubMessage()
{
}
public SerializedHubMessage(HubMessage message)
{
Message = message;
}
public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
{
if (!TryGetCached(protocol.Name, out var serialized))
{
if (Message == null)
{
throw new InvalidOperationException(
"This message was received from another server that did not have the requested protocol available.");
}
serialized = protocol.WriteToArray(Message);
SetCache(protocol.Name, serialized);
}
return serialized;
}
public static void WriteAllSerializedVersions(BinaryWriter writer, HubMessage message, IReadOnlyList<IHubProtocol> protocols)
{
// The serialization format is based on BinaryWriter
// * 1 byte number of protocols
// * For each protocol:
// * Length-prefixed string using 7-bit variable length encoding (length depends on BinaryWriter's encoding)
// * 4 byte length of the buffer
// * N byte buffer
if (protocols.Count > byte.MaxValue)
{
throw new InvalidOperationException($"Can't serialize cache containing more than {byte.MaxValue} entries");
}
writer.Write((byte)protocols.Count);
foreach (var protocol in protocols)
{
writer.Write(protocol.Name);
var buffer = protocol.WriteToArray(message);
writer.Write(buffer.Length);
writer.Write(buffer);
}
}
public static SerializedHubMessage ReadAllSerializedVersions(BinaryReader reader)
{
var cache = new SerializedHubMessage();
var count = reader.ReadByte();
for (var i = 0; i < count; i++)
{
var protocol = reader.ReadString();
var length = reader.ReadInt32();
var serialized = reader.ReadBytes(length);
cache.SetCache(protocol, serialized);
}
return cache;
}
private void SetCache(string protocolName, byte[] serialized)
{
if (_cachedItem1.ProtocolName == null)
{
_cachedItem1 = new SerializedMessage(protocolName, serialized);
}
else if (_cachedItem2.ProtocolName == null)
{
_cachedItem2 = new SerializedMessage(protocolName, serialized);
}
else
{
if (_cachedItems == null)
{
_cachedItems = new List<SerializedMessage>();
}
foreach (var item in _cachedItems)
{
if (string.Equals(item.ProtocolName, protocolName, StringComparison.Ordinal))
{
// No need to add
return;
}
}
_cachedItems.Add(new SerializedMessage(protocolName, serialized));
}
}
private bool TryGetCached(string protocolName, out byte[] result)
{
if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
{
result = _cachedItem1.Serialized;
return true;
}
if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal))
{
result = _cachedItem2.Serialized;
return true;
}
if (_cachedItems != null)
{
foreach (var serializedMessage in _cachedItems)
{
if (string.Equals(serializedMessage.ProtocolName, protocolName, StringComparison.Ordinal))
{
result = serializedMessage.Serialized;
return true;
}
}
}
result = default;
return false;
}
private readonly struct SerializedMessage
{
public string ProtocolName { get; }
public byte[] Serialized { get; }
public SerializedMessage(string protocolName, byte[] serialized)
{
ProtocolName = protocolName;
Serialized = serialized;
}
}
}
}

View File

@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
// The size of the enum is defined by the protocol. Do not change it. If you need more than 255 items,
// add an additional enum.
public enum GroupAction : byte
{
// These numbers are used by the protocol, do not change them and always use explicit assignment
// when adding new items to this enum. 0 is intentionally omitted
Add = 1,
Remove = 2,
}
}

View File

@ -0,0 +1,75 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Runtime.CompilerServices;
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
internal class RedisChannels
{
private readonly string _prefix;
/// <summary>
/// Gets the name of the channel for sending to all connections.
/// </summary>
/// <remarks>
/// The payload on this channel is <see cref="RedisInvocation"/> objects containing
/// invocations to be sent to all connections
/// </remarks>
public string All { get; }
/// <summary>
/// Gets the name of the internal channel for group management messages.
/// </summary>
public string GroupManagement { get; }
public RedisChannels(string prefix)
{
_prefix = prefix;
All = prefix + ":all";
GroupManagement = prefix + ":internal:groups";
}
/// <summary>
/// Gets the name of the channel for sending a message to a specific connection.
/// </summary>
/// <param name="connectionId">The ID of the connection to get the channel for.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public string Connection(string connectionId)
{
return _prefix + ":connection:" + connectionId;
}
/// <summary>
/// Gets the name of the channel for sending a message to a named group of connections.
/// </summary>
/// <param name="groupName">The name of the group to get the channel for.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public string Group(string groupName)
{
return _prefix + ":group:" + groupName;
}
/// <summary>
/// Gets the name of the channel for sending a message to all collections associated with a user.
/// </summary>
/// <param name="userId">The ID of the user to get the channel for.</param>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public string User(string userId)
{
return _prefix + ":user:" + userId;
}
/// <summary>
/// Gets the name of the acknowledgement channel for the specified server.
/// </summary>
/// <param name="serverName">The name of the server to get the acknowledgement channel for.</param>
/// <returns></returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public string Ack(string serverName)
{
return _prefix + ":internal:ack:" + serverName;
}
}
}

View File

@ -0,0 +1,39 @@
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
public readonly struct RedisGroupCommand
{
/// <summary>
/// Gets the ID of the group command.
/// </summary>
public int Id { get; }
/// <summary>
/// Gets the name of the server that sent the command.
/// </summary>
public string ServerName { get; }
/// <summary>
/// Gets the action to be performed on the group.
/// </summary>
public GroupAction Action { get; }
/// <summary>
/// Gets the group on which the action is performed.
/// </summary>
public string GroupName { get; }
/// <summary>
/// Gets the ID of the connection to be added or removed from the group.
/// </summary>
public string ConnectionId { get; }
public RedisGroupCommand(int id, string serverName, GroupAction action, string groupName, string connectionId)
{
Id = id;
ServerName = serverName;
Action = action;
GroupName = groupName;
ConnectionId = connectionId;
}
}
}

View File

@ -0,0 +1,33 @@
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
public readonly struct RedisInvocation
{
/// <summary>
/// Gets a list of connections that should be excluded from this invocation.
/// May be null to indicate that no connections are to be excluded.
/// </summary>
public IReadOnlyList<string> ExcludedIds { get; }
/// <summary>
/// Gets the message serialization cache containing serialized payloads for the message.
/// </summary>
public SerializedHubMessage Message { get; }
public RedisInvocation(SerializedHubMessage message, IReadOnlyList<string> excludedIds)
{
Message = message;
ExcludedIds = excludedIds;
}
public static RedisInvocation Create(string target, object[] arguments, IReadOnlyList<string> excludedIds = null)
{
return new RedisInvocation(
new SerializedHubMessage(new InvocationMessage(target, argumentBindingException: null, arguments)),
excludedIds);
}
}
}

View File

@ -0,0 +1,170 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.IO;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
public class RedisProtocol
{
private readonly IReadOnlyList<IHubProtocol> _protocols;
private static readonly Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);
public RedisProtocol(IReadOnlyList<IHubProtocol> protocols)
{
_protocols = protocols;
}
// The Redis Protocol:
// * The message type is known in advance because messages are sent to different channels based on type
// * Invocations are sent to the All, Group, Connection and User channels
// * Group Commands are sent to the GroupManagement channel
// * Acks are sent to the Acknowledgement channel.
// * See the Write[type] methods for a description of the protocol for each in-depth.
// * The "Variable length integer" is the length-prefixing format used by BinaryReader/BinaryWriter:
// * https://docs.microsoft.com/en-us/dotnet/api/system.io.binarywriter.write?view=netstandard-2.0
// * The "Length prefixed string" is the string format used by BinaryReader/BinaryWriter:
// * A 7-bit variable length integer encodes the length in bytes, followed by the encoded string in UTF-8.
public byte[] WriteInvocation(string methodName, object[] args) =>
WriteInvocation(methodName, args, excludedIds: null);
public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList<string> excludedIds)
{
// Redis Invocation Format:
// * Variable length integer: Number of excluded Ids
// * For each excluded Id:
// * Length prefixed string: ID
// * SerializedHubMessage encoded by the format described by that type.
using (var stream = new MemoryStream())
using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom))
{
if (excludedIds != null)
{
writer.WriteVarInt(excludedIds.Count);
foreach (var id in excludedIds)
{
writer.Write(id);
}
}
else
{
writer.WriteVarInt(0);
}
SerializedHubMessage.WriteAllSerializedVersions(writer, new InvocationMessage(methodName, argumentBindingException: null, args), _protocols);
return stream.ToArray();
}
}
public byte[] WriteGroupCommand(RedisGroupCommand command)
{
// Group Command Format:
// * Variable length integer: Id
// * Length prefixed string: ServerName
// * 1 byte: Action
// * Length prefixed string: GroupName
// * Length prefixed string: ConnectionId
using (var stream = new MemoryStream())
using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom))
{
writer.WriteVarInt(command.Id);
writer.Write(command.ServerName);
writer.Write((byte)command.Action);
writer.Write(command.GroupName);
writer.Write(command.ConnectionId);
return stream.ToArray();
}
}
public byte[] WriteAck(int messageId)
{
// Acknowledgement Format:
// * Variable length integer: Id
using (var stream = new MemoryStream())
using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom))
{
writer.WriteVarInt(messageId);
return stream.ToArray();
}
}
public RedisInvocation ReadInvocation(byte[] data)
{
// See WriteInvocation for format.
using (var stream = new MemoryStream(data))
using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom))
{
IReadOnlyList<string> excludedIds = null;
var idCount = reader.ReadVarInt();
if (idCount > 0)
{
var ids = new string[idCount];
for (var i = 0; i < idCount; i++)
{
ids[i] = reader.ReadString();
}
excludedIds = ids;
}
var message = SerializedHubMessage.ReadAllSerializedVersions(reader);
return new RedisInvocation(message, excludedIds);
}
}
public RedisGroupCommand ReadGroupCommand(byte[] data)
{
// See WriteGroupCommand for format.
using (var stream = new MemoryStream(data))
using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom))
{
var id = reader.ReadVarInt();
var serverName = reader.ReadString();
var action = (GroupAction)reader.ReadByte();
var groupName = reader.ReadString();
var connectionId = reader.ReadString();
return new RedisGroupCommand(id, serverName, action, groupName, connectionId);
}
}
public int ReadAck(byte[] data)
{
// See WriteAck for format
using (var stream = new MemoryStream(data))
using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom))
{
return reader.ReadVarInt();
}
}
// Kinda cheaty way to get access to write the 7-bit varint format directly
private class BinaryWriterWithVarInt : BinaryWriter
{
public BinaryWriterWithVarInt(Stream output, Encoding encoding) : base(output, encoding)
{
}
public void WriteVarInt(int value) => Write7BitEncodedInt(value);
}
private class BinaryReaderWithVarInt : BinaryReader
{
public BinaryReaderWithVarInt(Stream input, Encoding encoding) : base(input, encoding)
{
}
public int ReadVarInt() => Read7BitEncodedInt();
}
}
}

View File

@ -9,7 +9,7 @@ using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Redis.Internal;
using Microsoft.Extensions.Logging;
@ -28,8 +28,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private readonly ISubscriber _bus;
private readonly ILogger _logger;
private readonly RedisOptions _options;
private readonly string _channelNamePrefix = typeof(THub).FullName;
private readonly string _serverName = Guid.NewGuid().ToString();
private readonly RedisChannels _channels;
private readonly string _serverName = GenerateServerName();
private readonly RedisProtocol _protocol;
private readonly AckHandler _ackHandler;
private int _internalId;
@ -41,14 +43,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
};
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
IOptions<RedisOptions> options)
IOptions<RedisOptions> options,
IHubProtocolResolver hubProtocolResolver)
{
_logger = logger;
_options = options.Value;
_ackHandler = new AckHandler();
_channels = new RedisChannels(typeof(THub).FullName);
_protocol = new RedisProtocol(hubProtocolResolver.AllProtocols);
var writer = new LoggerTextWriter(logger);
_logger.ConnectingToEndpoints(options.Value.Options.EndPoints);
RedisLog.ConnectingToEndpoints(_logger, options.Value.Options.EndPoints, _serverName);
_redisServerConnection = _options.Connect(writer);
_redisServerConnection.ConnectionRestored += (_, e) =>
@ -60,7 +65,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
_logger.ConnectionRestored();
RedisLog.ConnectionRestored(_logger);
};
_redisServerConnection.ConnectionFailed += (_, e) =>
@ -72,23 +77,22 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
_logger.ConnectionFailed(e.Exception);
RedisLog.ConnectionFailed(_logger, e.Exception);
};
if (_redisServerConnection.IsConnected)
{
_logger.Connected();
RedisLog.Connected(_logger);
}
else
{
_logger.NotConnected();
RedisLog.NotConnected(_logger);
}
_bus = _redisServerConnection.GetSubscriber();
SubscribeToHub();
SubscribeToAllExcept();
SubscribeToInternalGroup();
SubscribeToInternalServerName();
SubscribeToAll();
SubscribeToGroupManagementChannel();
SubscribeToAckChannel();
}
public override Task OnConnectedAsync(HubConnectionContext connection)
@ -125,7 +129,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
foreach (var subscription in redisSubscriptions)
{
_logger.Unsubscribe(subscription);
RedisLog.Unsubscribe(_logger, subscription);
tasks.Add(_bus.UnsubscribeAsync(subscription));
}
}
@ -149,15 +153,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override Task SendAllAsync(string methodName, object[] args)
{
var message = new RedisInvocationMessage(target: methodName, arguments: args);
return PublishAsync(_channelNamePrefix, message);
var message = _protocol.WriteInvocation(methodName, args);
return PublishAsync(_channels.All, message);
}
public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
{
var message = new RedisInvocationMessage(target: methodName, excludedIds: excludedIds, arguments: args);
return PublishAsync(_channelNamePrefix + ".AllExcept", message);
var message = _protocol.WriteInvocation(methodName, args, excludedIds);
return PublishAsync(_channels.All, message);
}
public override Task SendConnectionAsync(string connectionId, string methodName, object[] args)
@ -167,17 +170,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
throw new ArgumentNullException(nameof(connectionId));
}
var message = new RedisInvocationMessage(target: methodName, arguments: args);
// If the connection is local we can skip sending the message through the bus since we require sticky connections.
// This also saves serializing and deserializing the message!
var connection = _connections[connectionId];
if (connection != null)
{
return SafeWriteAsync(connection, message.CreateInvocation());
return connection.WriteAsync(new InvocationMessage(methodName, argumentBindingException: null, args)).AsTask();
}
return PublishAsync(_channelNamePrefix + "." + connectionId, message);
var message = _protocol.WriteInvocation(methodName, args);
return PublishAsync(_channels.Connection(connectionId), message);
}
public override Task SendGroupAsync(string groupName, string methodName, object[] args)
@ -187,9 +189,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
throw new ArgumentNullException(nameof(groupName));
}
var message = new RedisInvocationMessage(target: methodName, excludedIds: null, arguments: args);
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
var message = _protocol.WriteInvocation(methodName, args);
return PublishAsync(_channels.Group(groupName), message);
}
public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedIds)
@ -199,31 +200,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
throw new ArgumentNullException(nameof(groupName));
}
var message = new RedisInvocationMessage(methodName, excludedIds, args);
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
var message = _protocol.WriteInvocation(methodName, args, excludedIds);
return PublishAsync(_channels.Group(groupName), message);
}
public override Task SendUserAsync(string userId, string methodName, object[] args)
{
var message = new RedisInvocationMessage(methodName, args);
return PublishAsync(_channelNamePrefix + ".user." + userId, message);
}
private async Task PublishAsync(string channel, IRedisMessage message)
{
byte[] payload;
using (var stream = new LimitArrayPoolWriteStream())
using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream)))
{
_serializer.Serialize(writer, message);
writer.Flush();
payload = stream.ToArray();
}
_logger.PublishToChannel(channel);
await _bus.PublishAsync(channel, payload);
var message = _protocol.WriteInvocation(methodName, args);
return PublishAsync(_channels.User(userId), message);
}
public override async Task AddGroupAsync(string connectionId, string groupName)
@ -249,6 +233,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Add);
}
public override async Task RemoveGroupAsync(string connectionId, string groupName)
{
if (connectionId == null)
{
throw new ArgumentNullException(nameof(connectionId));
}
if (groupName == null)
{
throw new ArgumentNullException(nameof(groupName));
}
var connection = _connections[connectionId];
if (connection != null)
{
// short circuit if connection is on this server
await RemoveGroupAsyncCore(connection, groupName);
return;
}
await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove);
}
public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
{
if (connectionIds == null)
{
throw new ArgumentNullException(nameof(connectionIds));
}
var publishTasks = new List<Task>(connectionIds.Count);
var payload = _protocol.WriteInvocation(methodName, args);
foreach (var connectionId in connectionIds)
{
publishTasks.Add(PublishAsync(_channels.Connection(connectionId), payload));
}
return Task.WhenAll(publishTasks);
}
public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
{
if (groupNames == null)
{
throw new ArgumentNullException(nameof(groupNames));
}
var publishTasks = new List<Task>(groupNames.Count);
var payload = _protocol.WriteInvocation(methodName, args);
foreach (var groupName in groupNames)
{
if (!string.IsNullOrEmpty(groupName))
{
publishTasks.Add(PublishAsync(_channels.Group(groupName), payload));
}
}
return Task.WhenAll(publishTasks);
}
public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
{
if (userIds.Count > 0)
{
var payload = _protocol.WriteInvocation(methodName, args);
var publishTasks = new List<Task>(userIds.Count);
foreach (var userId in userIds)
{
if (!string.IsNullOrEmpty(userId))
{
publishTasks.Add(PublishAsync(_channels.User(userId), payload));
}
}
return Task.WhenAll(publishTasks);
}
return Task.CompletedTask;
}
private Task PublishAsync(string channel, byte[] payload)
{
RedisLog.PublishToChannel(_logger, channel);
return _bus.PublishAsync(channel, payload);
}
private async Task AddGroupAsyncCore(HubConnectionContext connection, string groupName)
{
var feature = connection.Features.Get<IRedisFeature>();
@ -263,7 +334,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
var groupChannel = _channelNamePrefix + ".group." + groupName;
var groupChannel = _channels.Group(groupName);
var group = _groups.GetOrAdd(groupChannel, _ => new GroupData());
await group.Lock.WaitAsync();
@ -285,37 +356,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override async Task RemoveGroupAsync(string connectionId, string groupName)
{
if (connectionId == null)
{
throw new ArgumentNullException(nameof(connectionId));
}
if (groupName == null)
{
throw new ArgumentNullException(nameof(groupName));
}
var connection = _connections[connectionId];
if (connection != null)
{
// short circuit if connection is on this server
await RemoveGroupAsyncCore(connection, groupName);
return;
}
await SendGroupActionAndWaitForAck(connectionId, groupName, GroupAction.Remove);
}
/// <summary>
/// This takes <see cref="HubConnectionContext"/> because we want to remove the connection from the
/// _connections list in OnDisconnectedAsync and still be able to remove groups with this method.
/// </summary>
private async Task RemoveGroupAsyncCore(HubConnectionContext connection, string groupName)
{
var groupChannel = _channelNamePrefix + ".group." + groupName;
var groupChannel = _channels.Group(groupName);
if (!_groups.TryGetValue(groupChannel, out var group))
{
@ -341,7 +388,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
if (group.Connections.Count == 0)
{
_logger.Unsubscribe(groupChannel);
RedisLog.Unsubscribe(_logger, groupChannel);
await _bus.UnsubscribeAsync(groupChannel);
}
}
@ -350,8 +397,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
group.Lock.Release();
}
return;
}
private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action)
@ -359,14 +404,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var id = Interlocked.Increment(ref _internalId);
var ack = _ackHandler.CreateAck(id);
// Send Add/Remove Group to other servers and wait for an ack or timeout
await PublishAsync(_channelNamePrefix + ".internal.group", new RedisGroupMessage
{
Action = action,
ConnectionId = connectionId,
Group = groupName,
Id = id,
Server = _serverName
});
var message = _protocol.WriteGroupCommand(new RedisGroupCommand(id, _serverName, action, groupName, connectionId));
await PublishAsync(_channels.GroupManagement, message);
await ack;
}
@ -378,63 +417,24 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_ackHandler.Dispose();
}
private T DeserializeMessage<T>(RedisValue data)
private void SubscribeToAll()
{
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream(data))))
{
return _serializer.Deserialize<T>(reader);
}
}
private void SubscribeToHub()
{
_logger.Subscribing(_channelNamePrefix);
_bus.Subscribe(_channelNamePrefix, async (c, data) =>
RedisLog.Subscribing(_logger, _channels.All);
_bus.Subscribe(_channels.All, async (c, data) =>
{
try
{
_logger.ReceivedFromChannel(_channelNamePrefix);
RedisLog.ReceivedFromChannel(_logger, _channels.All);
var message = DeserializeMessage<RedisInvocationMessage>(data);
var invocation = _protocol.ReadInvocation(data);
var tasks = new List<Task>(_connections.Count);
var invocation = message.CreateInvocation();
foreach (var connection in _connections)
{
tasks.Add(SafeWriteAsync(connection, invocation));
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
});
}
private void SubscribeToAllExcept()
{
var channelName = _channelNamePrefix + ".AllExcept";
_logger.Subscribing(channelName);
_bus.Subscribe(channelName, async (c, data) =>
{
try
{
_logger.ReceivedFromChannel(channelName);
var message = DeserializeMessage<RedisInvocationMessage>(data);
var excludedIds = message.ExcludedIds ?? Array.Empty<string>();
var tasks = new List<Task>(_connections.Count);
var invocation = message.CreateInvocation();
foreach (var connection in _connections)
{
if (!excludedIds.Contains(connection.ConnectionId))
if (invocation.ExcludedIds == null || !invocation.ExcludedIds.Contains(connection.ConnectionId))
{
tasks.Add(SafeWriteAsync(connection, invocation));
tasks.Add(connection.WriteAsync(invocation.Message).AsTask());
}
}
@ -442,19 +442,18 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
RedisLog.FailedWritingMessage(_logger, ex);
}
});
}
private void SubscribeToInternalGroup()
private void SubscribeToGroupManagementChannel()
{
var channelName = _channelNamePrefix + ".internal.group";
_bus.Subscribe(channelName, async (c, data) =>
_bus.Subscribe(_channels.GroupManagement, async (c, data) =>
{
try
{
var groupMessage = DeserializeMessage<RedisGroupMessage>(data);
var groupMessage = _protocol.ReadGroupCommand(data);
var connection = _connections[groupMessage.ConnectionId];
if (connection == null)
@ -465,179 +464,95 @@ namespace Microsoft.AspNetCore.SignalR.Redis
if (groupMessage.Action == GroupAction.Remove)
{
await RemoveGroupAsyncCore(connection, groupMessage.Group);
await RemoveGroupAsyncCore(connection, groupMessage.GroupName);
}
if (groupMessage.Action == GroupAction.Add)
{
await AddGroupAsyncCore(connection, groupMessage.Group);
await AddGroupAsyncCore(connection, groupMessage.GroupName);
}
// Sending ack to server that sent the original add/remove
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new RedisGroupMessage
{
Action = GroupAction.Ack,
Id = groupMessage.Id
});
// Send an ack to the server that sent the original command.
await PublishAsync(_channels.Ack(groupMessage.ServerName), _protocol.WriteAck(groupMessage.Id));
}
catch (Exception ex)
{
_logger.InternalMessageFailed(ex);
RedisLog.InternalMessageFailed(_logger, ex);
}
});
}
private void SubscribeToInternalServerName()
private void SubscribeToAckChannel()
{
// Create server specific channel in order to send an ack to a single server
var serverChannel = $"{_channelNamePrefix}.internal.{_serverName}";
_bus.Subscribe(serverChannel, (c, data) =>
_bus.Subscribe(_channels.Ack(_serverName), (c, data) =>
{
var groupMessage = DeserializeMessage<RedisGroupMessage>(data);
var ackId = _protocol.ReadAck(data);
if (groupMessage.Action == GroupAction.Ack)
{
_ackHandler.TriggerAck(groupMessage.Id);
}
_ackHandler.TriggerAck(ackId);
});
}
private Task SubscribeToConnection(HubConnectionContext connection, HashSet<string> redisSubscriptions)
{
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
var connectionChannel = _channels.Connection(connection.ConnectionId);
redisSubscriptions.Add(connectionChannel);
_logger.Subscribing(connectionChannel);
RedisLog.Subscribing(_logger, connectionChannel);
return _bus.SubscribeAsync(connectionChannel, async (c, data) =>
{
var message = DeserializeMessage<RedisInvocationMessage>(data);
await SafeWriteAsync(connection, message.CreateInvocation());
var invocation = _protocol.ReadInvocation(data);
await connection.WriteAsync(invocation.Message);
});
}
private Task SubscribeToUser(HubConnectionContext connection, HashSet<string> redisSubscriptions)
{
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
var userChannel = _channels.User(connection.UserIdentifier);
redisSubscriptions.Add(userChannel);
// TODO: Look at optimizing (looping over connections checking for Name)
return _bus.SubscribeAsync(userChannel, async (c, data) =>
{
var message = DeserializeMessage<RedisInvocationMessage>(data);
await SafeWriteAsync(connection, message.CreateInvocation());
var invocation = _protocol.ReadInvocation(data);
await connection.WriteAsync(invocation.Message);
});
}
private Task SubscribeToGroup(string groupChannel, GroupData group)
{
_logger.Subscribing(groupChannel);
RedisLog.Subscribing(_logger, groupChannel);
return _bus.SubscribeAsync(groupChannel, async (c, data) =>
{
try
{
var message = DeserializeMessage<RedisInvocationMessage>(data);
var invocation = _protocol.ReadInvocation(data);
var tasks = new List<Task>();
var invocation = message.CreateInvocation();
foreach (var groupConnection in group.Connections)
{
if (message.ExcludedIds?.Contains(groupConnection.ConnectionId) == true)
if (invocation.ExcludedIds?.Contains(groupConnection.ConnectionId) == true)
{
continue;
}
tasks.Add(SafeWriteAsync(groupConnection, invocation));
tasks.Add(groupConnection.WriteAsync(invocation.Message).AsTask());
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
RedisLog.FailedWritingMessage(_logger, ex);
}
});
}
public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
private static string GenerateServerName()
{
if (connectionIds == null)
{
throw new ArgumentNullException(nameof(connectionIds));
}
var publishTasks = new List<Task>(connectionIds.Count);
var message = new RedisInvocationMessage(target: methodName, arguments: args);
foreach (var connectionId in connectionIds)
{
var connection = _connections[connectionId];
// If the connection is local we can skip sending the message through the bus since we require sticky connections.
// This also saves serializing and deserializing the message!
if (connection != null)
{
publishTasks.Add(SafeWriteAsync(connection, message.CreateInvocation()));
}
else
{
publishTasks.Add(PublishAsync(_channelNamePrefix + "." + connectionId, message));
}
}
return Task.WhenAll(publishTasks);
}
public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
{
if (groupNames == null)
{
throw new ArgumentNullException(nameof(groupNames));
}
var publishTasks = new List<Task>(groupNames.Count);
var message = new RedisInvocationMessage(target: methodName, arguments: args);
foreach (var groupName in groupNames)
{
if (!string.IsNullOrEmpty(groupName))
{
publishTasks.Add(PublishAsync(_channelNamePrefix + "." + groupName, message));
}
}
return Task.WhenAll(publishTasks);
}
public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
{
if (userIds.Count > 0)
{
var message = new RedisInvocationMessage(methodName, args);
var publishTasks = new List<Task>(userIds.Count);
foreach (var userId in userIds)
{
if (!string.IsNullOrEmpty(userId))
{
publishTasks.Add(PublishAsync(_channelNamePrefix + ".user." + userId, message));
}
}
return Task.WhenAll(publishTasks);
}
return Task.CompletedTask;
}
// This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to
private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message)
{
try
{
await connection.WriteAsync(message);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
// Use the machine name for convenient diagnostics, but add a guid to make it unique.
// Example: MyServerName_02db60e5fab243b890a847fa5c4dcb29
return $"{Environment.MachineName}_{Guid.NewGuid():N}";
}
private class LoggerTextWriter : TextWriter
@ -658,7 +573,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override void WriteLine(string value)
{
_logger.LogDebug(value);
RedisLog.ConnectionMultiplexerMessage(_logger, value);
}
}
@ -679,53 +594,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public HashSet<string> Subscriptions { get; } = new HashSet<string>();
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}
private enum GroupAction
{
Remove,
Add,
Ack
}
// Marker interface to represent the messages that can be sent over Redis.
private interface IRedisMessage { }
private class RedisGroupMessage : IRedisMessage
{
public string ConnectionId { get; set; }
public string Group { get; set; }
public int Id { get; set; }
public GroupAction Action { get; set; }
public string Server { get; set; }
}
// Represents a message published to the Redis bus
private class RedisInvocationMessage : IRedisMessage
{
public string Target { get; set; }
public IReadOnlyList<string> ExcludedIds { get; set; }
public object[] Arguments { get; set; }
public RedisInvocationMessage()
{
}
public RedisInvocationMessage(string target, object[] arguments)
: this(target, excludedIds: null, arguments: arguments)
{
}
public RedisInvocationMessage(string target, IReadOnlyList<string> excludedIds, object[] arguments)
{
Target = target;
ExcludedIds = excludedIds;
Arguments = arguments;
}
public InvocationMessage CreateInvocation()
{
return new InvocationMessage(Target, argumentBindingException: null, arguments: Arguments);
}
}
}
}

View File

@ -6,13 +6,14 @@ using System.Linq;
using Microsoft.Extensions.Logging;
using StackExchange.Redis;
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
namespace Microsoft.AspNetCore.SignalR.Redis
{
internal static class RedisLoggerExtensions
// We don't want to use our nested static class here because RedisHubLifetimeManager is generic.
// We'd end up creating separate instances of all the LoggerMessage.Define values for each Hub.
internal static class RedisLog
{
// Category: RedisHubLifetimeManager<THub>
private static readonly Action<ILogger, string, Exception> _connectingToEndpoints =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}.");
private static readonly Action<ILogger, string, string, Exception> _connectingToEndpoints =
LoggerMessage.Define<string, string>(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}. Using Server Name: {ServerName}");
private static readonly Action<ILogger, Exception> _connected =
LoggerMessage.Define(LogLevel.Information, new EventId(2, "Connected"), "Connected to Redis.");
@ -44,65 +45,75 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
private static readonly Action<ILogger, Exception> _internalMessageFailed =
LoggerMessage.Define(LogLevel.Warning, new EventId(11, "InternalMessageFailed"), "Error processing message for internal server message.");
public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints)
public static void ConnectingToEndpoints(ILogger logger, EndPointCollection endpoints, string serverName)
{
if (logger.IsEnabled(LogLevel.Information))
{
if (endpoints.Count > 0)
{
_connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), null);
_connectingToEndpoints(logger, string.Join(", ", endpoints.Select(e => EndPointCollection.ToString(e))), serverName, null);
}
}
}
public static void Connected(this ILogger logger)
public static void Connected(ILogger logger)
{
_connected(logger, null);
}
public static void Subscribing(this ILogger logger, string channelName)
public static void Subscribing(ILogger logger, string channelName)
{
_subscribing(logger, channelName, null);
}
public static void ReceivedFromChannel(this ILogger logger, string channelName)
public static void ReceivedFromChannel(ILogger logger, string channelName)
{
_receivedFromChannel(logger, channelName, null);
}
public static void PublishToChannel(this ILogger logger, string channelName)
public static void PublishToChannel(ILogger logger, string channelName)
{
_publishToChannel(logger, channelName, null);
}
public static void Unsubscribe(this ILogger logger, string channelName)
public static void Unsubscribe(ILogger logger, string channelName)
{
_unsubscribe(logger, channelName, null);
}
public static void NotConnected(this ILogger logger)
public static void NotConnected(ILogger logger)
{
_notConnected(logger, null);
}
public static void ConnectionRestored(this ILogger logger)
public static void ConnectionRestored(ILogger logger)
{
_connectionRestored(logger, null);
}
public static void ConnectionFailed(this ILogger logger, Exception exception)
public static void ConnectionFailed(ILogger logger, Exception exception)
{
_connectionFailed(logger, exception);
}
public static void FailedWritingMessage(this ILogger logger, Exception exception)
public static void FailedWritingMessage(ILogger logger, Exception exception)
{
_failedWritingMessage(logger, exception);
}
public static void InternalMessageFailed(this ILogger logger, Exception exception)
public static void InternalMessageFailed(ILogger logger, Exception exception)
{
_internalMessageFailed(logger, exception);
}
// This isn't DefineMessage-based because it's just the simple TextWriter logging from ConnectionMultiplexer
public static void ConnectionMultiplexerMessage(ILogger logger, string message)
{
if (logger.IsEnabled(LogLevel.Debug))
{
// We tag it with EventId 100 though so it can be pulled out of logs easily.
logger.LogDebug(new EventId(100, "RedisConnectionLog"), message);
}
}
}
}
}

View File

@ -5,11 +5,15 @@ using System;
using System.Collections.Generic;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using MsgPack.Serialization;
using Newtonsoft.Json.Linq;
using Newtonsoft.Json.Serialization;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
@ -19,14 +23,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeAllAsyncWritesToAllConnectionsOutput()
{
var server = new TestRedisServer();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager = CreateLifetimeManager(server);
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
@ -40,17 +42,44 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
[Fact]
public async Task InvokeAllExceptAsyncExcludesSpecifiedConnections()
{
var server = new TestRedisServer();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
using (var client3 = new TestClient())
{
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
var manager3 = CreateLifetimeManager(server);
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
var connection3 = HubConnectionContextUtils.Create(client3.Connection);
await manager1.OnConnectedAsync(connection1).OrTimeout();
await manager2.OnConnectedAsync(connection2).OrTimeout();
await manager3.OnConnectedAsync(connection3).OrTimeout();
await manager1.SendAllExceptAsync("Hello", new object[] { "World" }, new [] { client3.Connection.ConnectionId }).OrTimeout();
await AssertMessageAsync(client1);
await AssertMessageAsync(client2);
Assert.Null(client3.TryRead());
}
}
[Fact]
public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput()
{
var server = new TestRedisServer();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager = CreateLifetimeManager(server);
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
@ -70,14 +99,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput()
{
var server = new TestRedisServer();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager = CreateLifetimeManager(server);
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
@ -96,14 +123,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeGroupExceptAsyncWritesToAllValidConnectionsInGroupOutput()
{
var server = new TestRedisServer();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager = CreateLifetimeManager(server);
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
@ -124,13 +149,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeConnectionAsyncWritesToConnectionOutput()
{
var server = new TestRedisServer();
using (var client = new TestClient())
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager = CreateLifetimeManager(server);
var connection = HubConnectionContextUtils.Create(client.Connection);
await manager.OnConnectedAsync(connection).OrTimeout();
@ -144,27 +167,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeConnectionAsyncOnNonExistentConnectionDoesNotThrow()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
await manager.SendConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout();
}
[Fact]
public async Task InvokeAllAsyncWithMultipleServersWritesToAllConnectionsOutput()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
@ -185,16 +200,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeAllAsyncWithMultipleServersDoesNotWriteToDisconnectedConnectionsOutput()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
@ -218,16 +227,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeConnectionAsyncOnServerWithoutConnectionWritesOutputToConnection()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -244,16 +247,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeGroupAsyncOnServerWithoutConnectionWritesOutputToGroupConnection()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -272,11 +269,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task DisconnectConnectionRemovesConnectionFromGroup()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -297,11 +292,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task RemoveGroupFromLocalConnectionNotInGroupDoesNothing()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -316,16 +309,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task RemoveGroupFromConnectionOnDifferentServerNotInGroupDoesNothing()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(),
Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -340,14 +327,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task AddGroupAsyncForConnectionOnDifferentServerWorks()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -366,10 +349,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task AddGroupAsyncForLocalConnectionAlreadyInGroupDoesNothing()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -382,7 +364,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
await manager.SendGroupAsync("name", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client);
Assert.Null(client.TryRead());
}
@ -391,14 +372,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task AddGroupAsyncForConnectionOnDifferentServerAlreadyInGroupDoesNothing()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -419,14 +396,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task RemoveGroupAsyncForConnectionOnDifferentServerWorks()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -451,14 +424,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task InvokeConnectionAsyncForLocalConnectionDoesNotPublishToRedis()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -478,14 +447,10 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
[Fact]
public async Task WritingToRemoteConnectionThatFailsDoesNotThrow()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager1 = CreateLifetimeManager(server);
var manager2 = CreateLifetimeManager(server);
using (var client = new TestClient())
{
@ -502,34 +467,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
[Fact]
public async Task WritingToLocalConnectionThatFailsDoesNotThrowException()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client = new TestClient())
{
// Force an exception when writing to connection
var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection);
connectionMock.Setup(m => m.WriteAsync(It.IsAny<HubMessage>())).Throws(new Exception("Message"));
var connection = connectionMock.Object;
await manager.OnConnectedAsync(connection).OrTimeout();
await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
}
}
[Fact]
public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
@ -557,6 +500,72 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
[Fact]
public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary()
{
var server = new TestRedisServer();
var messagePackOptions = new MessagePackHubProtocolOptions();
messagePackOptions.SerializationContext.DictionarySerlaizationOptions.KeyTransformer = DictionaryKeyTransformers.LowerCamel;
var jsonOptions = new JsonHubProtocolOptions();
jsonOptions.PayloadSerializerSettings.ContractResolver = new CamelCasePropertyNamesContractResolver();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
// The sending manager has serializer settings
var manager1 = CreateLifetimeManager(server, messagePackOptions, jsonOptions);
// The receiving one doesn't matter because of how we serialize!
var manager2 = CreateLifetimeManager(server);
var connection1 = HubConnectionContextUtils.Create(client1.Connection);
var connection2 = HubConnectionContextUtils.Create(client2.Connection);
await manager1.OnConnectedAsync(connection1).OrTimeout();
await manager2.OnConnectedAsync(connection2).OrTimeout();
await manager1.SendAllAsync("Hello", new object[] { new TestObject { TestProperty = "Foo" } });
var message = Assert.IsType<InvocationMessage>(await client2.ReadAsync().OrTimeout());
Assert.Equal("Hello", message.Target);
Assert.Collection(
message.Arguments,
arg0 =>
{
var dict = Assert.IsType<JObject>(arg0);
Assert.Collection(dict.Properties(),
prop =>
{
Assert.Equal("testProperty", prop.Name);
Assert.Equal("Foo", prop.Value.Value<string>());
});
});
}
}
public class TestObject
{
public string TestProperty { get; set; }
}
private RedisHubLifetimeManager<MyHub> CreateLifetimeManager(TestRedisServer server, MessagePackHubProtocolOptions messagePackOptions = null, JsonHubProtocolOptions jsonOptions = null)
{
var options = new RedisOptions() { Factory = t => new TestConnectionMultiplexer(server) };
messagePackOptions = messagePackOptions ?? new MessagePackHubProtocolOptions();
jsonOptions = jsonOptions ?? new JsonHubProtocolOptions();
return new RedisHubLifetimeManager<MyHub>(
NullLogger<RedisHubLifetimeManager<MyHub>>.Instance,
Options.Create(options),
new DefaultHubProtocolResolver(new IHubProtocol[]
{
new JsonHubProtocol(Options.Create(jsonOptions)),
new MessagePackHubProtocol(Options.Create(messagePackOptions)),
}, NullLogger<DefaultHubProtocolResolver>.Instance));
}
private async Task AssertMessageAsync(TestClient client)
{
var message = Assert.IsType<InvocationMessage>(await client.ReadAsync().OrTimeout());

View File

@ -11,11 +11,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
public static class HubConnectionContextUtils
{
public static HubConnectionContext Create(ConnectionContext connection)
public static HubConnectionContext Create(ConnectionContext connection, IHubProtocol protocol = null, string userIdentifier = null)
{
return new HubConnectionContext(connection, TimeSpan.FromSeconds(15), NullLoggerFactory.Instance)
{
Protocol = new JsonHubProtocol()
Protocol = protocol ?? new JsonHubProtocol(),
UserIdentifier = userIdentifier,
};
}

View File

@ -21,6 +21,7 @@
<PackageReference Include="Microsoft.Extensions.ValueStopwatch.Sources" Version="$(MicrosoftExtensionsValueStopwatchSourcesPackageVersion)" PrivateAssets="All" />
<PackageReference Include="Microsoft.AspNetCore.Hosting" Version="$(MicrosoftAspNetCoreHostingPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Server.Kestrel" Version="$(MicrosoftAspNetCoreServerKestrelPackageVersion)" />
<PackageReference Include="StackExchange.Redis.StrongName" Version="$(StackExchangeRedisStrongNamePackageVersion)" />
</ItemGroup>
<ItemGroup>

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -9,7 +9,7 @@ using System.Net;
using System.Threading.Tasks;
using StackExchange.Redis;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class TestConnectionMultiplexer : IConnectionMultiplexer
{
@ -70,7 +70,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
remove { }
}
private ISubscriber _subscriber = new TestSubscriber();
private ISubscriber _subscriber;
public TestConnectionMultiplexer(TestRedisServer server)
{
_subscriber = new TestSubscriber(server);
}
public void BeginProfiling(object forContext)
{
@ -203,19 +208,52 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
public class TestSubscriber : ISubscriber
public class TestRedisServer
{
// _globalSubscriptions represents the Redis Server you are connected to.
// So when publishing from a TestSubscriber you fake sending through the server by grabbing the callbacks
// from the _globalSubscriptions and inoking them inplace.
private static ConcurrentDictionary<RedisChannel, List<Action<RedisChannel, RedisValue>>> _globalSubscriptions =
private ConcurrentDictionary<RedisChannel, List<Action<RedisChannel, RedisValue>>> _subscriptions =
new ConcurrentDictionary<RedisChannel, List<Action<RedisChannel, RedisValue>>>();
private ConcurrentDictionary<RedisChannel, Action<RedisChannel, RedisValue>> _subscriptions =
new ConcurrentDictionary<RedisChannel, Action<RedisChannel, RedisValue>>();
public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
{
if (_subscriptions.TryGetValue(channel, out var handlers))
{
foreach (var handler in handlers)
{
handler(channel, message);
}
}
return handlers != null ? handlers.Count : 0;
}
public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
{
_subscriptions.AddOrUpdate(channel, _ => new List<Action<RedisChannel, RedisValue>> { handler }, (_, list) =>
{
list.Add(handler);
return list;
});
}
public void Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
{
if (_subscriptions.TryGetValue(channel, out var list))
{
list.Remove(handler);
}
}
}
public class TestSubscriber : ISubscriber
{
private readonly TestRedisServer _server;
public ConnectionMultiplexer Multiplexer => throw new NotImplementedException();
public TestSubscriber(TestRedisServer server)
{
_server = server;
}
public EndPoint IdentifyEndpoint(RedisChannel channel, CommandFlags flags = CommandFlags.None)
{
throw new NotImplementedException();
@ -243,15 +281,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
public long Publish(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
{
if (_globalSubscriptions.TryGetValue(channel, out var handlers))
{
foreach (var handler in handlers)
{
handler(channel, message);
}
}
return handlers != null ? handlers.Count : 0;
return _server.Publish(channel, message, flags);
}
public async Task<long> PublishAsync(RedisChannel channel, RedisValue message, CommandFlags flags = CommandFlags.None)
@ -262,12 +292,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
public void Subscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
{
_globalSubscriptions.AddOrUpdate(channel, _ => new List<Action<RedisChannel, RedisValue>> { handler }, (_, list) =>
{
list.Add(handler);
return list;
});
_subscriptions.AddOrUpdate(channel, handler, (_, __) => handler);
_server.Subscribe(channel, handler, flags);
}
public Task SubscribeAsync(RedisChannel channel, Action<RedisChannel, RedisValue> handler, CommandFlags flags = CommandFlags.None)
@ -288,11 +313,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
public void Unsubscribe(RedisChannel channel, Action<RedisChannel, RedisValue> handler = null, CommandFlags flags = CommandFlags.None)
{
_subscriptions.TryRemove(channel, out var handle);
if (_globalSubscriptions.TryGetValue(channel, out var list))
{
list.Remove(handle);
}
_server.Unsubscribe(channel, handler, flags);
}
public void UnsubscribeAll(CommandFlags flags = CommandFlags.None)