From 3d93e095dbf2297fabf595099341c4cce673f32d Mon Sep 17 00:00:00 2001 From: Brennan Date: Tue, 22 Oct 2019 12:35:12 -0700 Subject: [PATCH] Add support for hub specific IHubProtocols that don't affect other hubs (#15177) --- .../src/BlazorPack/BlazorPackHubProtocol.cs | 2 + .../NonDefaultHubProtocolAttribute.cs | 13 ++ .../Microbenchmarks/RedisProtocolBenchmark.cs | 12 +- .../server/Core/src/HubOptionsSetup.cs | 5 + .../server/Core/src/HubOptionsSetup`T.cs | 1 + .../server/Core/src/SerializedHubMessage.cs | 12 +- .../server/SignalR/test/AddSignalRTests.cs | 65 +++++++ ...e.SignalR.StackExchangeRedis.netcoreapp.cs | 1 + .../Internal/DefaultHubMessageSerializer.cs | 39 ++++ .../src/Internal/RedisProtocol.cs | 23 +-- .../src/RedisHubLifetimeManager.cs | 20 ++- .../test/DefaultHubMessageSerializerTests.cs | 166 ++++++++++++++++++ .../test/RedisProtocolTests.cs | 39 +++- 13 files changed, 367 insertions(+), 31 deletions(-) create mode 100644 src/Components/Server/src/BlazorPack/NonDefaultHubProtocolAttribute.cs create mode 100644 src/SignalR/server/StackExchangeRedis/src/Internal/DefaultHubMessageSerializer.cs create mode 100644 src/SignalR/server/StackExchangeRedis/test/DefaultHubMessageSerializerTests.cs diff --git a/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs b/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs index bc5b7825df..6d726262a4 100644 --- a/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs +++ b/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs @@ -12,6 +12,7 @@ using MessagePack; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.Components.Server.BlazorPack @@ -19,6 +20,7 @@ namespace Microsoft.AspNetCore.Components.Server.BlazorPack /// /// Implements the SignalR Hub Protocol using MessagePack with limited type support. /// + [NonDefaultHubProtocol] internal sealed class BlazorPackHubProtocol : IHubProtocol { internal const string ProtocolName = "blazorpack"; diff --git a/src/Components/Server/src/BlazorPack/NonDefaultHubProtocolAttribute.cs b/src/Components/Server/src/BlazorPack/NonDefaultHubProtocolAttribute.cs new file mode 100644 index 0000000000..141d8b3194 --- /dev/null +++ b/src/Components/Server/src/BlazorPack/NonDefaultHubProtocolAttribute.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + // Tells SignalR not to add the IHubProtocol with this attribute to all hubs by default + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)] + internal class NonDefaultHubProtocolAttribute : Attribute + { + } +} diff --git a/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs b/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs index c87d0e5226..25380dfbe9 100644 --- a/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/RedisProtocolBenchmark.cs @@ -6,8 +6,10 @@ using System.Buffers; using System.Collections.Generic; using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; +using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { @@ -28,10 +30,10 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - _protocol = new RedisProtocol(new [] { - new DummyProtocol("protocol1"), - new DummyProtocol("protocol2") - }); + var resolver = new DefaultHubProtocolResolver(new List { new DummyProtocol("protocol1"), + new DummyProtocol("protocol2") }, NullLogger.Instance); + + _protocol = new RedisProtocol(new DefaultHubMessageSerializer(resolver, new List() { "protocol1", "protocol2" }, hubSupportedProtocols: null)); _groupCommand = new RedisGroupCommand(id: 42, serverName: "Server", GroupAction.Add, groupName: "group", connectionId: "connection"); @@ -119,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks return ids; } - private class DummyProtocol: IHubProtocol + private class DummyProtocol : IHubProtocol { private static readonly byte[] _fixedOutput = new byte[] { 0x68, 0x68, 0x6C, 0x6C, 0x6F }; diff --git a/src/SignalR/server/Core/src/HubOptionsSetup.cs b/src/SignalR/server/Core/src/HubOptionsSetup.cs index 8f1affe718..6b6a08abab 100644 --- a/src/SignalR/server/Core/src/HubOptionsSetup.cs +++ b/src/SignalR/server/Core/src/HubOptionsSetup.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.Options; @@ -26,6 +27,10 @@ namespace Microsoft.AspNetCore.SignalR { foreach (var hubProtocol in protocols) { + if (hubProtocol.GetType().CustomAttributes.Where(a => a.AttributeType.FullName == "Microsoft.AspNetCore.SignalR.Internal.NonDefaultHubProtocolAttribute").Any()) + { + continue; + } _defaultProtocols.Add(hubProtocol.Name); } } diff --git a/src/SignalR/server/Core/src/HubOptionsSetup`T.cs b/src/SignalR/server/Core/src/HubOptionsSetup`T.cs index ee1ccdf1da..9f4fb17c0a 100644 --- a/src/SignalR/server/Core/src/HubOptionsSetup`T.cs +++ b/src/SignalR/server/Core/src/HubOptionsSetup`T.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.SignalR public void Configure(HubOptions options) { + // Do a deep copy, otherwise users modifying the HubOptions list would be changing the global options list options.SupportedProtocols = new List(_hubOptions.SupportedProtocols.Count); foreach (var protocol in _hubOptions.SupportedProtocols) { diff --git a/src/SignalR/server/Core/src/SerializedHubMessage.cs b/src/SignalR/server/Core/src/SerializedHubMessage.cs index 99a969a3a9..8cd578b6a0 100644 --- a/src/SignalR/server/Core/src/SerializedHubMessage.cs +++ b/src/SignalR/server/Core/src/SerializedHubMessage.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.SignalR { private SerializedMessage _cachedItem1; private SerializedMessage _cachedItem2; - private IList _cachedItems; + private List _cachedItems; private readonly object _lock = new object(); /// @@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR for (var i = 0; i < messages.Count; i++) { var message = messages[i]; - SetCache(message.ProtocolName, message.Serialized); + SetCacheUnsynchronized(message.ProtocolName, message.Serialized); } } @@ -54,7 +54,7 @@ namespace Microsoft.AspNetCore.SignalR { lock (_lock) { - if (!TryGetCached(protocol.Name, out var serialized)) + if (!TryGetCachedUnsynchronized(protocol.Name, out var serialized)) { if (Message == null) { @@ -63,7 +63,7 @@ namespace Microsoft.AspNetCore.SignalR } serialized = protocol.GetMessageBytes(Message); - SetCache(protocol.Name, serialized); + SetCacheUnsynchronized(protocol.Name, serialized); } return serialized; @@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private void SetCache(string protocolName, ReadOnlyMemory serialized) + private void SetCacheUnsynchronized(string protocolName, ReadOnlyMemory serialized) { // We set the fields before moving on to the list, if we need it to hold more than 2 items. // We have to read/write these fields under the lock because the structs might tear and another @@ -132,7 +132,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private bool TryGetCached(string protocolName, out ReadOnlyMemory result) + private bool TryGetCachedUnsynchronized(string protocolName, out ReadOnlyMemory result) { if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal)) { diff --git a/src/SignalR/server/SignalR/test/AddSignalRTests.cs b/src/SignalR/server/SignalR/test/AddSignalRTests.cs index d711f242b5..a8cd5a9342 100644 --- a/src/SignalR/server/SignalR/test/AddSignalRTests.cs +++ b/src/SignalR/server/SignalR/test/AddSignalRTests.cs @@ -2,12 +2,15 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using Xunit; @@ -148,6 +151,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Null(globalOptions.SupportedProtocols); Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval); } + + [Fact] + public void HubProtocolsWithNonDefaultAttributeNotAddedToSupportedProtocols() + { + var serviceCollection = new ServiceCollection(); + + serviceCollection.AddSignalR().AddHubOptions(options => + { + }); + + serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton()); + serviceCollection.TryAddEnumerable(ServiceDescriptor.Singleton()); + + var serviceProvider = serviceCollection.BuildServiceProvider(); + Assert.Collection(serviceProvider.GetRequiredService>>().Value.SupportedProtocols, + p => + { + Assert.Equal("json", p); + }, + p => + { + Assert.Equal("messagepack", p); + }); + } } public class CustomHub : Hub @@ -276,4 +303,42 @@ namespace Microsoft.AspNetCore.SignalR.Tests throw new System.NotImplementedException(); } } + + [NonDefaultHubProtocol] + internal class CustomHubProtocol : IHubProtocol + { + public string Name => "custom"; + + public int Version => throw new NotImplementedException(); + + public TransferFormat TransferFormat => throw new NotImplementedException(); + + public ReadOnlyMemory GetMessageBytes(HubMessage message) + { + throw new NotImplementedException(); + } + + public bool IsVersionSupported(int version) + { + throw new NotImplementedException(); + } + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) + { + throw new NotImplementedException(); + } + + public void WriteMessage(HubMessage message, IBufferWriter output) + { + throw new NotImplementedException(); + } + } +} + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + [AttributeUsage(AttributeTargets.Class, AllowMultiple = false, Inherited = true)] + internal class NonDefaultHubProtocolAttribute : Attribute + { + } } diff --git a/src/SignalR/server/StackExchangeRedis/ref/Microsoft.AspNetCore.SignalR.StackExchangeRedis.netcoreapp.cs b/src/SignalR/server/StackExchangeRedis/ref/Microsoft.AspNetCore.SignalR.StackExchangeRedis.netcoreapp.cs index 5610a42318..8de5c13646 100644 --- a/src/SignalR/server/StackExchangeRedis/ref/Microsoft.AspNetCore.SignalR.StackExchangeRedis.netcoreapp.cs +++ b/src/SignalR/server/StackExchangeRedis/ref/Microsoft.AspNetCore.SignalR.StackExchangeRedis.netcoreapp.cs @@ -6,6 +6,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis public partial class RedisHubLifetimeManager : Microsoft.AspNetCore.SignalR.HubLifetimeManager, System.IDisposable where THub : Microsoft.AspNetCore.SignalR.Hub { public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger> logger, Microsoft.Extensions.Options.IOptions options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver) { } + public RedisHubLifetimeManager(Microsoft.Extensions.Logging.ILogger> logger, Microsoft.Extensions.Options.IOptions options, Microsoft.AspNetCore.SignalR.IHubProtocolResolver hubProtocolResolver, Microsoft.Extensions.Options.IOptions globalHubOptions, Microsoft.Extensions.Options.IOptions> hubOptions) { } public override System.Threading.Tasks.Task AddToGroupAsync(string connectionId, string groupName, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public void Dispose() { } [System.Diagnostics.DebuggerStepThroughAttribute] diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/DefaultHubMessageSerializer.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/DefaultHubMessageSerializer.cs new file mode 100644 index 0000000000..7bcd4089e8 --- /dev/null +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/DefaultHubMessageSerializer.cs @@ -0,0 +1,39 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + internal class DefaultHubMessageSerializer + { + private readonly List _hubProtocols = new List(); + + public DefaultHubMessageSerializer(IHubProtocolResolver hubProtocolResolver, IList globalSupportedProtocols, IList hubSupportedProtocols) + { + var supportedProtocols = hubSupportedProtocols ?? globalSupportedProtocols ?? Array.Empty(); + foreach (var protocolName in supportedProtocols) + { + var protocol = hubProtocolResolver.GetProtocol(protocolName, (supportedProtocols as IReadOnlyList) ?? supportedProtocols.ToList()); + if (protocol != null) + { + _hubProtocols.Add(protocol); + } + } + } + + public IReadOnlyList SerializeMessage(HubMessage message) + { + var list = new List(_hubProtocols.Count); + foreach (var protocol in _hubProtocols) + { + list.Add(new SerializedMessage(protocol.Name, protocol.GetMessageBytes(message))); + } + + return list; + } + } +} diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index a1594b0fd3..76b5ed1b6e 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -8,17 +8,18 @@ using System.IO; using System.Runtime.InteropServices; using MessagePack; using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal { internal class RedisProtocol { - private readonly IReadOnlyList _protocols; + private readonly DefaultHubMessageSerializer _messageSerializer; - public RedisProtocol(IReadOnlyList protocols) + public RedisProtocol(DefaultHubMessageSerializer messageSerializer) { - _protocols = protocols; + _messageSerializer = messageSerializer; } // The Redis Protocol: @@ -60,8 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal MessagePackBinary.WriteArrayHeader(writer, 0); } - WriteSerializedHubMessage(writer, - new SerializedHubMessage(new InvocationMessage(methodName, args))); + WriteHubMessage(writer, new InvocationMessage(methodName, args)); return writer.ToArray(); } finally @@ -163,19 +163,20 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal return MessagePackUtil.ReadInt32(ref data); } - private void WriteSerializedHubMessage(Stream stream, SerializedHubMessage message) + private void WriteHubMessage(Stream stream, HubMessage message) { // Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str') // and the values are the serialized blob (as a MessagePack 'bin'). - MessagePackBinary.WriteMapHeader(stream, _protocols.Count); + var serializedHubMessages = _messageSerializer.SerializeMessage(message); - foreach (var protocol in _protocols) + MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count); + + foreach (var serializedMessage in serializedHubMessages) { - MessagePackBinary.WriteString(stream, protocol.Name); + MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName); - var serialized = message.GetSerializedMessage(protocol); - var isArray = MemoryMarshal.TryGetArray(serialized, out var array); + var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array); Debug.Assert(isArray); MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count); } diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index 17b462bfd0..df81ed61b3 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -8,6 +8,7 @@ using System.Linq; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; using Microsoft.Extensions.Logging; @@ -36,12 +37,29 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis public RedisHubLifetimeManager(ILogger> logger, IOptions options, IHubProtocolResolver hubProtocolResolver) + : this(logger, options, hubProtocolResolver, globalHubOptions: null, hubOptions: null) + { + } + + public RedisHubLifetimeManager(ILogger> logger, + IOptions options, + IHubProtocolResolver hubProtocolResolver, + IOptions globalHubOptions, + IOptions> hubOptions) { _logger = logger; _options = options.Value; _ackHandler = new AckHandler(); _channels = new RedisChannels(typeof(THub).FullName); - _protocol = new RedisProtocol(hubProtocolResolver.AllProtocols); + if (globalHubOptions != null && hubOptions != null) + { + _protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols)); + } + else + { + var supportedProtocols = hubProtocolResolver.AllProtocols.Select(p => p.Name).ToList(); + _protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, supportedProtocols, null)); + } RedisLog.ConnectingToEndpoints(_logger, options.Value.Configuration.EndPoints, _serverName); _ = EnsureRedisServerConnection(); diff --git a/src/SignalR/server/StackExchangeRedis/test/DefaultHubMessageSerializerTests.cs b/src/SignalR/server/StackExchangeRedis/test/DefaultHubMessageSerializerTests.cs new file mode 100644 index 0000000000..0c1ac1f62d --- /dev/null +++ b/src/SignalR/server/StackExchangeRedis/test/DefaultHubMessageSerializerTests.cs @@ -0,0 +1,166 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests.Internal +{ + public class DefaultHubMessageSerializerTests + { + [Theory] + [MemberData(nameof(InvocationTestData))] + public void SerializeMessages(string testName) + { + var testData = _invocationTestData[testName]; + + var resolver = CreateHubProtocolResolver(new List { new MessagePackHubProtocol(), new JsonHubProtocol() }); + var protocolNames = testData.SupportedHubProtocols.ConvertAll(p => p.Name); + var serializer = new DefaultHubMessageSerializer(resolver, protocolNames, hubSupportedProtocols: null); + var serializedHubMessage = serializer.SerializeMessage(_testMessage); + + var allBytes = new List(); + Assert.Equal(testData.SerializedCount, serializedHubMessage.Count); + foreach (var message in serializedHubMessage) + { + allBytes.AddRange(message.Serialized.ToArray()); + } + + Assert.Equal(testData.Encoded, allBytes); + } + + [Fact] + public void GlobalSupportedProtocolsOverriddenByHubSupportedProtocols() + { + var testData = _invocationTestData["Single supported protocol"]; + + var resolver = CreateHubProtocolResolver(new List { new MessagePackHubProtocol(), new JsonHubProtocol() }); + + var serializer = new DefaultHubMessageSerializer(resolver, new List() { "json" }, new List() { "messagepack" }); + var serializedHubMessage = serializer.SerializeMessage(_testMessage); + + Assert.Equal(1, serializedHubMessage.Count); + + Assert.Equal(new List() { 0x0D, + 0x96, + 0x01, + 0x80, + 0xC0, + 0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', + 0x90, + 0x90 }, + serializedHubMessage[0].Serialized.ToArray()); + } + + private IHubProtocolResolver CreateHubProtocolResolver(List hubProtocols) + { + return new DefaultHubProtocolResolver(hubProtocols, NullLogger.Instance); + } + + private static Dictionary _invocationTestData = new[] + { + new ProtocolTestData( + "Single supported protocol", + new List() { new MessagePackHubProtocol() }, + 1, + 0x0D, + 0x96, + 0x01, + 0x80, + 0xC0, + 0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', + 0x90, + 0x90), + new ProtocolTestData( + "Multiple supported protocols", + new List() { new MessagePackHubProtocol(), new JsonHubProtocol() }, + 2, + 0x0D, + 0x96, + 0x01, + 0x80, + 0xC0, + 0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', + 0x90, + 0x90, + (byte)'{', (byte)'"', (byte)'t', (byte)'y', (byte)'p', (byte)'e', (byte)'"', (byte)':', (byte)'1', + (byte)',',(byte)'"', (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', (byte)'"', (byte)':', + (byte)'"', (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', (byte)'"', + (byte)',', (byte)'"', (byte)'a', (byte)'r', (byte)'g', (byte)'u', (byte)'m', (byte)'e', (byte)'n', (byte)'t', (byte)'s', (byte)'"', + (byte)':', (byte)'[', (byte)']', (byte)'}', 0x1e), + new ProtocolTestData( + "Multiple protocols, one not in hub protocol resolver", + new List() { new MessagePackHubProtocol(), new TestHubProtocol() }, + 1, + 0x0D, + 0x96, + 0x01, + 0x80, + 0xC0, + 0xA6, (byte)'t', (byte)'a', (byte)'r', (byte)'g', (byte)'e', (byte)'t', + 0x90, + 0x90), + new ProtocolTestData( + "No protocols", + new List(), + 0) + }.ToDictionary(t => t.Name); + + public static IEnumerable InvocationTestData = _invocationTestData.Keys.Select(k => new object[] { k }); + + public class ProtocolTestData + { + public string Name { get; } + public byte[] Encoded { get; } + public int SerializedCount { get; } + public List SupportedHubProtocols { get; } + + public ProtocolTestData(string name, List supportedHubProtocols, int serializedCount, params byte[] encoded) + { + Name = name; + Encoded = encoded; + SerializedCount = serializedCount; + SupportedHubProtocols = supportedHubProtocols; + } + } + + // The actual invocation message doesn't matter + private static InvocationMessage _testMessage = new InvocationMessage("target", Array.Empty()); + + internal class TestHubProtocol : IHubProtocol + { + public string Name => "test"; + + public int Version => throw new NotImplementedException(); + + public TransferFormat TransferFormat => throw new NotImplementedException(); + + public ReadOnlyMemory GetMessageBytes(HubMessage message) + { + throw new NotImplementedException(); + } + + public bool IsVersionSupported(int version) + { + throw new NotImplementedException(); + } + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) + { + throw new NotImplementedException(); + } + + public void WriteMessage(HubMessage message, IBufferWriter output) + { + throw new NotImplementedException(); + } + } + } +} diff --git a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs index b8b3a0bca9..1afca99c94 100644 --- a/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs +++ b/src/SignalR/server/StackExchangeRedis/test/RedisProtocolTests.cs @@ -2,14 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.Collections.Generic; using System.Linq; -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal; using Microsoft.AspNetCore.SignalR.Tests; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; using Xunit; namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests @@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests public void ParseAck(string testName) { var testData = _ackTestData[testName]; - var protocol = new RedisProtocol(Array.Empty()); + var protocol = new RedisProtocol(CreateHubMessageSerializer(new List())); var decoded = protocol.ReadAck(testData.Encoded); @@ -44,7 +44,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests public void WriteAck(string testName) { var testData = _ackTestData[testName]; - var protocol = new RedisProtocol(Array.Empty()); + var protocol = new RedisProtocol(CreateHubMessageSerializer(new List())); var encoded = protocol.WriteAck(testData.Decoded); @@ -64,7 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests public void ParseGroupCommand(string testName) { var testData = _groupCommandTestData[testName]; - var protocol = new RedisProtocol(Array.Empty()); + var protocol = new RedisProtocol(CreateHubMessageSerializer(new List())); var decoded = protocol.ReadGroupCommand(testData.Encoded); @@ -80,7 +80,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests public void WriteGroupCommand(string testName) { var testData = _groupCommandTestData[testName]; - var protocol = new RedisProtocol(Array.Empty()); + var protocol = new RedisProtocol(CreateHubMessageSerializer(new List())); var encoded = protocol.WriteGroupCommand(testData.Decoded); @@ -140,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests { var testData = _invocationTestData[testName]; var hubProtocols = new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }; - var protocol = new RedisProtocol(hubProtocols); + var protocol = new RedisProtocol(CreateHubMessageSerializer(hubProtocols.Cast().ToList())); var expected = testData.Decoded(); @@ -171,7 +171,23 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests public void WriteInvocation(string testName) { var testData = _invocationTestData[testName]; - var protocol = new RedisProtocol(new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }); + var protocol = new RedisProtocol(CreateHubMessageSerializer(new List() { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") })); + + // Actual invocation doesn't matter because we're using a dummy hub protocol. + // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. + var expected = testData.Decoded(); + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, expected.ExcludedConnectionIds); + + Assert.Equal(testData.Encoded, encoded); + } + + [Theory] + [MemberData(nameof(InvocationTestData))] + public void WriteInvocationWithHubMessageSerializer(string testName) + { + var testData = _invocationTestData[testName]; + var hubMessageSerializer = CreateHubMessageSerializer(new List() { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }); + var protocol = new RedisProtocol(hubMessageSerializer); // Actual invocation doesn't matter because we're using a dummy hub protocol. // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. @@ -198,5 +214,12 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Tests Encoded = encoded; } } + + private DefaultHubMessageSerializer CreateHubMessageSerializer(List protocols) + { + var protocolResolver = new DefaultHubProtocolResolver(protocols, NullLogger.Instance); + + return new DefaultHubMessageSerializer(protocolResolver, protocols.ConvertAll(p => p.Name), hubSupportedProtocols: null); + } } }