Add support for hub specific IHubProtocols that don't affect other hubs (#15177)

This commit is contained in:
Brennan 2019-10-22 12:35:12 -07:00 committed by GitHub
parent 99e79a0bb3
commit 3d93e095db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 367 additions and 31 deletions

View File

@ -12,6 +12,7 @@ using MessagePack;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.Components.Server.BlazorPack
@ -19,6 +20,7 @@ namespace Microsoft.AspNetCore.Components.Server.BlazorPack
/// <summary>
/// Implements the SignalR Hub Protocol using MessagePack with limited type support.
/// </summary>
[NonDefaultHubProtocol]
internal sealed class BlazorPackHubProtocol : IHubProtocol
{
internal const string ProtocolName = "blazorpack";

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.SignalR.Internal
{
// Tells SignalR not to add the IHubProtocol with this attribute to all hubs by default
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
internal class NonDefaultHubProtocolAttribute : Attribute
{
}
}

View File

@ -6,8 +6,10 @@ using System.Buffers;
using System.Collections.Generic;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
using Microsoft.Extensions.Logging.Abstractions;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
@ -28,10 +30,10 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
_protocol = new RedisProtocol(new [] {
new DummyProtocol("protocol1"),
new DummyProtocol("protocol2")
});
var resolver = new DefaultHubProtocolResolver(new List<IHubProtocol> { new DummyProtocol("protocol1"),
new DummyProtocol("protocol2") }, NullLogger<DefaultHubProtocolResolver>.Instance);
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(resolver, new List<string>() { "protocol1", "protocol2" }, hubSupportedProtocols: null));
_groupCommand = new RedisGroupCommand(id: 42, serverName: "Server", GroupAction.Add, groupName: "group", connectionId: "connection");
@ -119,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
return ids;
}
private class DummyProtocol: IHubProtocol
private class DummyProtocol : IHubProtocol
{
private static readonly byte[] _fixedOutput = new byte[] { 0x68, 0x68, 0x6C, 0x6C, 0x6F };

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Options;
@ -26,6 +27,10 @@ namespace Microsoft.AspNetCore.SignalR
{
foreach (var hubProtocol in protocols)
{
if (hubProtocol.GetType().CustomAttributes.Where(a => a.AttributeType.FullName == "Microsoft.AspNetCore.SignalR.Internal.NonDefaultHubProtocolAttribute").Any())
{
continue;
}
_defaultProtocols.Add(hubProtocol.Name);
}
}

View File

@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.SignalR
public void Configure(HubOptions<THub> options)
{
// Do a deep copy, otherwise users modifying the HubOptions<THub> list would be changing the global options list
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols.Count);
foreach (var protocol in _hubOptions.SupportedProtocols)
{

View File

@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.SignalR
{
private SerializedMessage _cachedItem1;
private SerializedMessage _cachedItem2;
private IList<SerializedMessage> _cachedItems;
private List<SerializedMessage> _cachedItems;
private readonly object _lock = new object();
/// <summary>
@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR
for (var i = 0; i < messages.Count; i++)
{
var message = messages[i];
SetCache(message.ProtocolName, message.Serialized);
SetCacheUnsynchronized(message.ProtocolName, message.Serialized);
}
}
@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR
{
lock (_lock)
{
if (!TryGetCached(protocol.Name, out var serialized))
if (!TryGetCachedUnsynchronized(protocol.Name, out var serialized))
{
if (Message == null)
{
@ -63,7 +63,7 @@ namespace Microsoft.AspNetCore.SignalR
}
serialized = protocol.GetMessageBytes(Message);
SetCache(protocol.Name, serialized);
SetCacheUnsynchronized(protocol.Name, serialized);
}
return serialized;
@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
private void SetCacheUnsynchronized(string protocolName, ReadOnlyMemory<byte> serialized)
{
// We set the fields before moving on to the list, if we need it to hold more than 2 items.
// We have to read/write these fields under the lock because the structs might tear and another
@ -132,7 +132,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private bool TryGetCached(string protocolName, out ReadOnlyMemory<byte> result)
private bool TryGetCachedUnsynchronized(string protocolName, out ReadOnlyMemory<byte> result)
{
if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
{

View File

@ -2,12 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
using Xunit;
@ -148,6 +151,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Null(globalOptions.SupportedProtocols);
Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval);
}
[Fact]
public void HubProtocolsWithNonDefaultAttributeNotAddedToSupportedProtocols()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSignalR().AddHubOptions<CustomHub>(options =>
{
});
serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, CustomHubProtocol>());
serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton<IHubProtocol, MessagePackHubProtocol>());
var serviceProvider = serviceCollection.BuildServiceProvider();
Assert.Collection(serviceProvider.GetRequiredService<IOptions<HubOptions<CustomHub>>>().Value.SupportedProtocols,
p =>
{
Assert.Equal("json", p);
},
p =>
{
Assert.Equal("messagepack", p);
});
}
}
public class CustomHub : Hub
@ -276,4 +303,42 @@ namespace Microsoft.AspNetCore.SignalR.Tests
throw new System.NotImplementedException();
}
}
[NonDefaultHubProtocol]
internal class CustomHubProtocol : IHubProtocol
{
public string Name => "custom";
public int Version => throw new NotImplementedException();
public TransferFormat TransferFormat => throw new NotImplementedException();
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
{
throw new NotImplementedException();
}
public bool IsVersionSupported(int version)
{
throw new NotImplementedException();
}
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
throw new NotImplementedException();
}
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
throw new NotImplementedException();
}
}
}
namespace Microsoft.AspNetCore.SignalR.Internal
{
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)]
internal class NonDefaultHubProtocolAttribute : Attribute
{
}
}

View File

@ -6,6 +6,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis
public partial class RedisHubLifetimeManager<THub> : Microsoft.AspNetCore.SignalR.HubLifetimeManager<THub>, System.IDisposable where THub : Microsoft.AspNetCore.SignalR.Hub
{
public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager<THub>> logger, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisOptions> options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver) { }
public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisHubLifetimeManager<THub>> logger, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.StackExchangeRedis.RedisOptions> options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.HubOptions> globalHubOptions, Microsoft.Extensions.Options.IOptions<Microsoft.AspNetCore.SignalR.HubOptions<THub>> hubOptions) { }
public override System.Threading.Tasks.Task AddToGroupAsync(string connectionId, string groupName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; }
public void Dispose() { }
[System.Diagnostics.DebuggerStepThroughAttribute]

View File

@ -0,0 +1,39 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class DefaultHubMessageSerializer
{
private readonly List<IHubProtocol> _hubProtocols = new List<IHubProtocol>();
public DefaultHubMessageSerializer(IHubProtocolResolver hubProtocolResolver, IList<string> globalSupportedProtocols, IList<string> hubSupportedProtocols)
{
var supportedProtocols = hubSupportedProtocols ?? globalSupportedProtocols ?? Array.Empty<string>();
foreach (var protocolName in supportedProtocols)
{
var protocol = hubProtocolResolver.GetProtocol(protocolName, (supportedProtocols as IReadOnlyList<string>) ?? supportedProtocols.ToList());
if (protocol != null)
{
_hubProtocols.Add(protocol);
}
}
}
public IReadOnlyList<SerializedMessage> SerializeMessage(HubMessage message)
{
var list = new List<SerializedMessage>(_hubProtocols.Count);
foreach (var protocol in _hubProtocols)
{
list.Add(new SerializedMessage(protocol.Name, protocol.GetMessageBytes(message)));
}
return list;
}
}
}

View File

@ -8,17 +8,18 @@ using System.IO;
using System.Runtime.InteropServices;
using MessagePack;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
{
internal class RedisProtocol
{
private readonly IReadOnlyList<IHubProtocol> _protocols;
private readonly DefaultHubMessageSerializer _messageSerializer;
public RedisProtocol(IReadOnlyList<IHubProtocol> protocols)
public RedisProtocol(DefaultHubMessageSerializer messageSerializer)
{
_protocols = protocols;
_messageSerializer = messageSerializer;
}
// The Redis Protocol:
@ -60,8 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
MessagePackBinary.WriteArrayHeader(writer, 0);
}
WriteSerializedHubMessage(writer,
new SerializedHubMessage(new InvocationMessage(methodName, args)));
WriteHubMessage(writer, new InvocationMessage(methodName, args));
return writer.ToArray();
}
finally
@ -163,19 +163,20 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
return MessagePackUtil.ReadInt32(ref data);
}
private void WriteSerializedHubMessage(Stream stream, SerializedHubMessage message)
private void WriteHubMessage(Stream stream, HubMessage message)
{
// Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str')
// and the values are the serialized blob (as a MessagePack 'bin').
MessagePackBinary.WriteMapHeader(stream, _protocols.Count);
var serializedHubMessages = _messageSerializer.SerializeMessage(message);
foreach (var protocol in _protocols)
MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count);
foreach (var serializedMessage in serializedHubMessages)
{
MessagePackBinary.WriteString(stream, protocol.Name);
MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName);
var serialized = message.GetSerializedMessage(protocol);
var isArray = MemoryMarshal.TryGetArray(serialized, out var array);
var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array);
Debug.Assert(isArray);
MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count);
}

View File

@ -8,6 +8,7 @@ using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
using Microsoft.Extensions.Logging;
@ -36,12 +37,29 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
IOptions<RedisOptions> options,
IHubProtocolResolver hubProtocolResolver)
: this(logger, options, hubProtocolResolver, globalHubOptions: null, hubOptions: null)
{
}
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
IOptions<RedisOptions> options,
IHubProtocolResolver hubProtocolResolver,
IOptions<HubOptions> globalHubOptions,
IOptions<HubOptions<THub>> hubOptions)
{
_logger = logger;
_options = options.Value;
_ackHandler = new AckHandler();
_channels = new RedisChannels(typeof(THub).FullName);
_protocol = new RedisProtocol(hubProtocolResolver.AllProtocols);
if (globalHubOptions != null && hubOptions != null)
{
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols));
}
else
{
var supportedProtocols = hubProtocolResolver.AllProtocols.Select(p => p.Name).ToList();
_protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, supportedProtocols, null));
}
RedisLog.ConnectingToEndpoints(_logger, options.Value.Configuration.EndPoints, _serverName);
_ = EnsureRedisServerConnection();

View File

@ -0,0 +1,166 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging.Abstractions;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests.Internal
{
public class DefaultHubMessageSerializerTests
{
[Theory]
[MemberData(nameof(InvocationTestData))]
public void SerializeMessages(string testName)
{
var testData = _invocationTestData[testName];
var resolver = CreateHubProtocolResolver(new List<IHubProtocol> { new MessagePackHubProtocol(), new JsonHubProtocol() });
var protocolNames = testData.SupportedHubProtocols.ConvertAll(p => p.Name);
var serializer = new DefaultHubMessageSerializer(resolver, protocolNames, hubSupportedProtocols: null);
var serializedHubMessage = serializer.SerializeMessage(_testMessage);
var allBytes = new List<byte>();
Assert.Equal(testData.SerializedCount, serializedHubMessage.Count);
foreach (var message in serializedHubMessage)
{
allBytes.AddRange(message.Serialized.ToArray());
}
Assert.Equal(testData.Encoded, allBytes);
}
[Fact]
public void GlobalSupportedProtocolsOverriddenByHubSupportedProtocols()
{
var testData = _invocationTestData["Single supported protocol"];
var resolver = CreateHubProtocolResolver(new List<IHubProtocol> { new MessagePackHubProtocol(), new JsonHubProtocol() });
var serializer = new DefaultHubMessageSerializer(resolver, new List<string>() { "json" }, new List<string>() { "messagepack" });
var serializedHubMessage = serializer.SerializeMessage(_testMessage);
Assert.Equal(1, serializedHubMessage.Count);
Assert.Equal(new List<byte>() { 0x0D,
0x96,
0x01,
0x80,
0xC0,
0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t',
0x90,
0x90 },
serializedHubMessage[0].Serialized.ToArray());
}
private IHubProtocolResolver CreateHubProtocolResolver(List<IHubProtocol> hubProtocols)
{
return new DefaultHubProtocolResolver(hubProtocols, NullLogger<DefaultHubProtocolResolver>.Instance);
}
private static Dictionary<string, ProtocolTestData> _invocationTestData = new[]
{
new ProtocolTestData(
"Single supported protocol",
new List<IHubProtocol>() { new MessagePackHubProtocol() },
1,
0x0D,
0x96,
0x01,
0x80,
0xC0,
0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t',
0x90,
0x90),
new ProtocolTestData(
"Multiple supported protocols",
new List<IHubProtocol>() { new MessagePackHubProtocol(), new JsonHubProtocol() },
2,
0x0D,
0x96,
0x01,
0x80,
0xC0,
0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t',
0x90,
0x90,
(byte)'{', (byte)'"', (byte)'t', (byte)'y', (byte)'p', (byte)'e', (byte)'"', (byte)':', (byte)'1',
(byte)',',(byte)'"', (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', (byte)'"', (byte)':',
(byte)'"', (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', (byte)'"',
(byte)',', (byte)'"', (byte)'a', (byte)'r', (byte)'g', (byte)'u', (byte)'m', (byte)'e', (byte)'n', (byte)'t', (byte)'s', (byte)'"',
(byte)':', (byte)'[', (byte)']', (byte)'}', 0x1e),
new ProtocolTestData(
"Multiple protocols, one not in hub protocol resolver",
new List<IHubProtocol>() { new MessagePackHubProtocol(), new TestHubProtocol() },
1,
0x0D,
0x96,
0x01,
0x80,
0xC0,
0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t',
0x90,
0x90),
new ProtocolTestData(
"No protocols",
new List<IHubProtocol>(),
0)
}.ToDictionary(t => t.Name);
public static IEnumerable<object[]> InvocationTestData = _invocationTestData.Keys.Select(k => new object[] { k });
public class ProtocolTestData
{
public string Name { get; }
public byte[] Encoded { get; }
public int SerializedCount { get; }
public List<IHubProtocol> SupportedHubProtocols { get; }
public ProtocolTestData(string name, List<IHubProtocol> supportedHubProtocols, int serializedCount, params byte[] encoded)
{
Name = name;
Encoded = encoded;
SerializedCount = serializedCount;
SupportedHubProtocols = supportedHubProtocols;
}
}
// The actual invocation message doesn't matter
private static InvocationMessage _testMessage = new InvocationMessage("target", Array.Empty<object>());
internal class TestHubProtocol : IHubProtocol
{
public string Name => "test";
public int Version => throw new NotImplementedException();
public TransferFormat TransferFormat => throw new NotImplementedException();
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
{
throw new NotImplementedException();
}
public bool IsVersionSupported(int version)
{
throw new NotImplementedException();
}
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
throw new NotImplementedException();
}
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
throw new NotImplementedException();
}
}
}
}

View File

@ -2,14 +2,14 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal;
using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
public void ParseAck(string testName)
{
var testData = _ackTestData[testName];
var protocol = new RedisProtocol(Array.Empty<IHubProtocol>());
var protocol = new RedisProtocol(CreateHubMessageSerializer(new List<IHubProtocol>()));
var decoded = protocol.ReadAck(testData.Encoded);
@ -44,7 +44,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
public void WriteAck(string testName)
{
var testData = _ackTestData[testName];
var protocol = new RedisProtocol(Array.Empty<IHubProtocol>());
var protocol = new RedisProtocol(CreateHubMessageSerializer(new List<IHubProtocol>()));
var encoded = protocol.WriteAck(testData.Decoded);
@ -64,7 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
public void ParseGroupCommand(string testName)
{
var testData = _groupCommandTestData[testName];
var protocol = new RedisProtocol(Array.Empty<IHubProtocol>());
var protocol = new RedisProtocol(CreateHubMessageSerializer(new List<IHubProtocol>()));
var decoded = protocol.ReadGroupCommand(testData.Encoded);
@ -80,7 +80,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
public void WriteGroupCommand(string testName)
{
var testData = _groupCommandTestData[testName];
var protocol = new RedisProtocol(Array.Empty<IHubProtocol>());
var protocol = new RedisProtocol(CreateHubMessageSerializer(new List<IHubProtocol>()));
var encoded = protocol.WriteGroupCommand(testData.Decoded);
@ -140,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
{
var testData = _invocationTestData[testName];
var hubProtocols = new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") };
var protocol = new RedisProtocol(hubProtocols);
var protocol = new RedisProtocol(CreateHubMessageSerializer(hubProtocols.Cast<IHubProtocol>().ToList()));
var expected = testData.Decoded();
@ -171,7 +171,23 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
public void WriteInvocation(string testName)
{
var testData = _invocationTestData[testName];
var protocol = new RedisProtocol(new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") });
var protocol = new RedisProtocol(CreateHubMessageSerializer(new List<IHubProtocol>() { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }));
// Actual invocation doesn't matter because we're using a dummy hub protocol.
// But the dummy protocol will check that we gave it the test message to make sure everything flows through properly.
var expected = testData.Decoded();
var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds);
Assert.Equal(testData.Encoded, encoded);
}
[Theory]
[MemberData(nameof(InvocationTestData))]
public void WriteInvocationWithHubMessageSerializer(string testName)
{
var testData = _invocationTestData[testName];
var hubMessageSerializer = CreateHubMessageSerializer(new List<IHubProtocol>() { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") });
var protocol = new RedisProtocol(hubMessageSerializer);
// Actual invocation doesn't matter because we're using a dummy hub protocol.
// But the dummy protocol will check that we gave it the test message to make sure everything flows through properly.
@ -198,5 +214,12 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests
Encoded = encoded;
}
}
private DefaultHubMessageSerializer CreateHubMessageSerializer(List<IHubProtocol> protocols)
{
var protocolResolver = new DefaultHubProtocolResolver(protocols, NullLogger<DefaultHubProtocolResolver>.Instance);
return new DefaultHubMessageSerializer(protocolResolver, protocols.ConvertAll(p => p.Name), hubSupportedProtocols: null);
}
}
}