diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs index dcff449701..e13f8b68c3 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs @@ -3,7 +3,6 @@ using System; using System.Collections.Generic; -using System.IO; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR.Internal @@ -20,8 +19,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal public HubMessage Message { get; } - private SerializedHubMessage() + public SerializedHubMessage(IReadOnlyList messages) { + for (var i = 0; i < messages.Count; i++) + { + var message = messages[i]; + SetCache(message.ProtocolName, message.Serialized); + } } public SerializedHubMessage(HubMessage message) @@ -46,46 +50,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal return serialized; } - public static void WriteAllSerializedVersions(BinaryWriter writer, HubMessage message, IReadOnlyList protocols) - { - // The serialization format is based on BinaryWriter - // * 1 byte number of protocols - // * For each protocol: - // * Length-prefixed string using 7-bit variable length encoding (length depends on BinaryWriter's encoding) - // * 4 byte length of the buffer - // * N byte buffer - - if (protocols.Count > byte.MaxValue) - { - throw new InvalidOperationException($"Can't serialize cache containing more than {byte.MaxValue} entries"); - } - - writer.Write((byte)protocols.Count); - foreach (var protocol in protocols) - { - writer.Write(protocol.Name); - - var buffer = protocol.GetMessageBytes(message); - writer.Write(buffer.Length); - writer.Write(buffer); - } - } - - public static SerializedHubMessage ReadAllSerializedVersions(BinaryReader reader) - { - var cache = new SerializedHubMessage(); - var count = reader.ReadByte(); - for (var i = 0; i < count; i++) - { - var protocol = reader.ReadString(); - var length = reader.ReadInt32(); - var serialized = reader.ReadBytes(length); - cache.SetCache(protocol, serialized); - } - - return cache; - } - private void SetCache(string protocolName, byte[] serialized) { if (_cachedItem1.ProtocolName == null) @@ -145,17 +109,5 @@ namespace Microsoft.AspNetCore.SignalR.Internal result = default; return false; } - - private readonly struct SerializedMessage - { - public string ProtocolName { get; } - public byte[] Serialized { get; } - - public SerializedMessage(string protocolName, byte[] serialized) - { - ProtocolName = protocolName; - Serialized = serialized; - } - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedMessage.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedMessage.cs new file mode 100644 index 0000000000..4d2d80fda0 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedMessage.cs @@ -0,0 +1,14 @@ +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public readonly struct SerializedMessage + { + public string ProtocolName { get; } + public byte[] Serialized { get; } + + public SerializedMessage(string protocolName, byte[] serialized) + { + ProtocolName = protocolName; + Serialized = serialized; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MsgPackUtil.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MsgPackUtil.cs new file mode 100644 index 0000000000..a1254b58aa --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/MsgPackUtil.cs @@ -0,0 +1,65 @@ +using System; +using System.Diagnostics; +using System.Runtime.InteropServices; +using MessagePack; + +namespace Microsoft.AspNetCore.SignalR.Redis.Internal +{ + internal static class MsgPackUtil + { + public static int ReadArrayHeader(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadArrayHeader(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static int ReadMapHeader(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadMapHeader(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static string ReadString(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadString(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static byte[] ReadBytes(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadBytes(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static int ReadInt32(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadInt32(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + public static byte ReadByte(ref ReadOnlyMemory data) + { + var arr = GetArray(data); + var val = MessagePackBinary.ReadByte(arr.Array, arr.Offset, out var readSize); + data = data.Slice(readSize); + return val; + } + + private static ArraySegment GetArray(ReadOnlyMemory data) + { + var isArray = MemoryMarshal.TryGetArray(data, out var array); + Debug.Assert(isArray); + return array; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs index 1728fc39a4..babd1835fe 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisProtocol.cs @@ -1,18 +1,22 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; -using System.Text; +using System.Runtime.InteropServices; +using MessagePack; +using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using StackExchange.Redis; namespace Microsoft.AspNetCore.SignalR.Redis.Internal { public class RedisProtocol { private readonly IReadOnlyList _protocols; - private static readonly Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false); public RedisProtocol(IReadOnlyList protocols) { @@ -35,136 +39,172 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal public byte[] WriteInvocation(string methodName, object[] args, IReadOnlyList excludedIds) { - // Redis Invocation Format: - // * Variable length integer: Number of excluded Ids - // * For each excluded Id: - // * Length prefixed string: ID - // * SerializedHubMessage encoded by the format described by that type. + // Written as a MessagePack 'arr' containing at least these items: + // * A MessagePack 'arr' of 'str's representing the excluded ids + // * [The output of WriteSerializedHubMessage, which is an 'arr'] + // Any additional items are discarded. - using (var stream = new MemoryStream()) - using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom)) + var writer = MemoryBufferWriter.Get(); + + try { - if (excludedIds != null) + MessagePackBinary.WriteArrayHeader(writer, 2); + if (excludedIds != null && excludedIds.Count > 0) { - writer.WriteVarInt(excludedIds.Count); + MessagePackBinary.WriteArrayHeader(writer, excludedIds.Count); foreach (var id in excludedIds) { - writer.Write(id); + MessagePackBinary.WriteString(writer, id); } } else { - writer.WriteVarInt(0); + MessagePackBinary.WriteArrayHeader(writer, 0); } - SerializedHubMessage.WriteAllSerializedVersions(writer, new InvocationMessage(methodName, null, args), _protocols); - return stream.ToArray(); + WriteSerializedHubMessage(writer, + new SerializedHubMessage(new InvocationMessage(methodName, null, args))); + return writer.ToArray(); + } + finally + { + MemoryBufferWriter.Return(writer); } } public byte[] WriteGroupCommand(RedisGroupCommand command) { - // Group Command Format: - // * Variable length integer: Id - // * Length prefixed string: ServerName - // * 1 byte: Action - // * Length prefixed string: GroupName - // * Length prefixed string: ConnectionId + // Written as a MessagePack 'arr' containing at least these items: + // * An 'int': the Id of the command + // * A 'str': The server name + // * An 'int': The action (likely less than 0x7F and thus a single-byte fixnum) + // * A 'str': The group name + // * A 'str': The connection Id + // Any additional items are discarded. - using (var stream = new MemoryStream()) - using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom)) + var writer = MemoryBufferWriter.Get(); + try { - writer.WriteVarInt(command.Id); - writer.Write(command.ServerName); - writer.Write((byte)command.Action); - writer.Write(command.GroupName); - writer.Write(command.ConnectionId); - return stream.ToArray(); + MessagePackBinary.WriteArrayHeader(writer, 5); + MessagePackBinary.WriteInt32(writer, command.Id); + MessagePackBinary.WriteString(writer, command.ServerName); + MessagePackBinary.WriteByte(writer, (byte)command.Action); + MessagePackBinary.WriteString(writer, command.GroupName); + MessagePackBinary.WriteString(writer, command.ConnectionId); + + return writer.ToArray(); + } + finally + { + MemoryBufferWriter.Return(writer); } } public byte[] WriteAck(int messageId) { - // Acknowledgement Format: - // * Variable length integer: Id + // Written as a MessagePack 'arr' containing at least these items: + // * An 'int': The Id of the command being acknowledged. + // Any additional items are discarded. - using (var stream = new MemoryStream()) - using (var writer = new BinaryWriterWithVarInt(stream, _utf8NoBom)) + var writer = MemoryBufferWriter.Get(); + try { - writer.WriteVarInt(messageId); - return stream.ToArray(); + MessagePackBinary.WriteArrayHeader(writer, 1); + MessagePackBinary.WriteInt32(writer, messageId); + + return writer.ToArray(); + } + finally + { + MemoryBufferWriter.Return(writer); } } - public RedisInvocation ReadInvocation(byte[] data) + public RedisInvocation ReadInvocation(ReadOnlyMemory data) { - // See WriteInvocation for format. + // See WriteInvocation for the format + ValidateArraySize(ref data, 2, "Invocation"); - using (var stream = new MemoryStream(data)) - using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom)) + // Read excluded Ids + IReadOnlyList excludedIds = null; + var idCount = MsgPackUtil.ReadArrayHeader(ref data); + if (idCount > 0) { - IReadOnlyList excludedIds = null; - - var idCount = reader.ReadVarInt(); - if (idCount > 0) + var ids = new string[idCount]; + for (var i = 0; i < idCount; i++) { - var ids = new string[idCount]; - for (var i = 0; i < idCount; i++) - { - ids[i] = reader.ReadString(); - } - - excludedIds = ids; + ids[i] = MsgPackUtil.ReadString(ref data); } - var message = SerializedHubMessage.ReadAllSerializedVersions(reader); - return new RedisInvocation(message, excludedIds); + excludedIds = ids; } + + // Read payload + var message = ReadSerializedHubMessage(ref data); + return new RedisInvocation(message, excludedIds); } - public RedisGroupCommand ReadGroupCommand(byte[] data) + public RedisGroupCommand ReadGroupCommand(ReadOnlyMemory data) { // See WriteGroupCommand for format. - using (var stream = new MemoryStream(data)) - using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom)) - { - var id = reader.ReadVarInt(); - var serverName = reader.ReadString(); - var action = (GroupAction)reader.ReadByte(); - var groupName = reader.ReadString(); - var connectionId = reader.ReadString(); + ValidateArraySize(ref data, 5, "GroupCommand"); - return new RedisGroupCommand(id, serverName, action, groupName, connectionId); - } + var id = MsgPackUtil.ReadInt32(ref data); + var serverName = MsgPackUtil.ReadString(ref data); + var action = (GroupAction)MsgPackUtil.ReadByte(ref data); + var groupName = MsgPackUtil.ReadString(ref data); + var connectionId = MsgPackUtil.ReadString(ref data); + + return new RedisGroupCommand(id, serverName, action, groupName, connectionId); } - public int ReadAck(byte[] data) + public int ReadAck(ReadOnlyMemory data) { // See WriteAck for format - using (var stream = new MemoryStream(data)) - using (var reader = new BinaryReaderWithVarInt(stream, _utf8NoBom)) + ValidateArraySize(ref data, 1, "Ack"); + return MsgPackUtil.ReadInt32(ref data); + } + + private void WriteSerializedHubMessage(Stream stream, SerializedHubMessage message) + { + // Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str') + // and the values are the serialized blob (as a MessagePack 'bin'). + + MessagePackBinary.WriteMapHeader(stream, _protocols.Count); + + foreach (var protocol in _protocols) { - return reader.ReadVarInt(); + MessagePackBinary.WriteString(stream, protocol.Name); + + var serialized = message.GetSerializedMessage(protocol); + var isArray = MemoryMarshal.TryGetArray(serialized, out var array); + Debug.Assert(isArray); + MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count); } } - // Kinda cheaty way to get access to write the 7-bit varint format directly - private class BinaryWriterWithVarInt : BinaryWriter + public static SerializedHubMessage ReadSerializedHubMessage(ref ReadOnlyMemory data) { - public BinaryWriterWithVarInt(Stream output, Encoding encoding) : base(output, encoding) + var count = MsgPackUtil.ReadMapHeader(ref data); + var serializations = new SerializedMessage[count]; + for (var i = 0; i < count; i++) { + var protocol = MsgPackUtil.ReadString(ref data); + var serialized = MsgPackUtil.ReadBytes(ref data); + serializations[i] = new SerializedMessage(protocol, serialized); } - public void WriteVarInt(int value) => Write7BitEncodedInt(value); + return new SerializedHubMessage(serializations); } - private class BinaryReaderWithVarInt : BinaryReader + private static void ValidateArraySize(ref ReadOnlyMemory data, int expectedLength, string messageType) { - public BinaryReaderWithVarInt(Stream input, Encoding encoding) : base(input, encoding) - { - } + var length = MsgPackUtil.ReadArrayHeader(ref data); - public int ReadVarInt() => Read7BitEncodedInt(); + if (length < expectedLength) + { + throw new InvalidDataException($"Insufficient items in {messageType} array."); + } } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj index e8b3c9f741..8862770b75 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Microsoft.AspNetCore.SignalR.Redis.csproj @@ -7,11 +7,13 @@ + + diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 807bdb55eb..15e438b154 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -383,7 +383,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { RedisLog.ReceivedFromChannel(_logger, _channels.All); - var invocation = _protocol.ReadInvocation(data); + var invocation = _protocol.ReadInvocation((byte[])data); var tasks = new List(_connections.Count); @@ -410,7 +410,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { try { - var groupMessage = _protocol.ReadGroupCommand(data); + var groupMessage = _protocol.ReadGroupCommand((byte[])data); var connection = _connections[groupMessage.ConnectionId]; if (connection == null) @@ -444,7 +444,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis // Create server specific channel in order to send an ack to a single server _bus.Subscribe(_channels.Ack(_serverName), (c, data) => { - var ackId = _protocol.ReadAck(data); + var ackId = _protocol.ReadAck((byte[])data); _ackHandler.TriggerAck(ackId); }); @@ -458,7 +458,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis RedisLog.Subscribing(_logger, connectionChannel); return _bus.SubscribeAsync(connectionChannel, async (c, data) => { - var invocation = _protocol.ReadInvocation(data); + var invocation = _protocol.ReadInvocation((byte[])data); await connection.WriteAsync(invocation.Message); }); } @@ -471,7 +471,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis // TODO: Look at optimizing (looping over connections checking for Name) return _bus.SubscribeAsync(userChannel, async (c, data) => { - var invocation = _protocol.ReadInvocation(data); + var invocation = _protocol.ReadInvocation((byte[])data); await connection.WriteAsync(invocation.Message); }); } @@ -483,7 +483,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { try { - var invocation = _protocol.ReadInvocation(data); + var invocation = _protocol.ReadInvocation((byte[])data); var tasks = new List(); foreach (var groupConnection in group.Connections) diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs new file mode 100644 index 0000000000..3f6459239e --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs @@ -0,0 +1,226 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Redis.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Redis.Tests +{ + public class RedisProtocolTests + { + private static Dictionary> _ackTestData = new[] + { + CreateTestData("Zero", 0, 0x91, 0x00), + CreateTestData("Fixnum", 42, 0x91, 0x2A), + CreateTestData("Uint8", 180, 0x91, 0xCC, 0xB4), + CreateTestData("Uint16", 384, 0x91, 0xCD, 0x01, 0x80), + CreateTestData("Uint32", 70_000, 0x91, 0xCE, 0x00, 0x01, 0x11, 0x70), + }.ToDictionary(t => t.Name); + + public static IEnumerable AckTestData = _ackTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(AckTestData))] + public void ParseAck(string testName) + { + var testData = _ackTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var decoded = protocol.ReadAck(testData.Encoded); + + Assert.Equal(testData.Decoded, decoded); + } + + [Theory] + [MemberData(nameof(AckTestData))] + public void WriteAck(string testName) + { + var testData = _ackTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var encoded = protocol.WriteAck(testData.Decoded); + + Assert.Equal(testData.Encoded, encoded); + } + + private static Dictionary> _groupCommandTestData = new[] + { + CreateTestData("GroupAdd", new RedisGroupCommand(42, "S", GroupAction.Add, "G", "C" ), 0x95, 0x2A, 0xA1, (byte)'S', 0x01, 0xA1, (byte)'G', 0xA1, (byte)'C'), + CreateTestData("GroupRemove", new RedisGroupCommand(42, "S", GroupAction.Remove, "G", "C" ), 0x95, 0x2A, 0xA1, (byte)'S', 0x02, 0xA1, (byte)'G', 0xA1, (byte)'C'), + }.ToDictionary(t => t.Name); + + public static IEnumerable GroupCommandTestData = _groupCommandTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(GroupCommandTestData))] + public void ParseGroupCommand(string testName) + { + var testData = _groupCommandTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var decoded = protocol.ReadGroupCommand(testData.Encoded); + + Assert.Equal(testData.Decoded.Id, decoded.Id); + Assert.Equal(testData.Decoded.ServerName, decoded.ServerName); + Assert.Equal(testData.Decoded.Action, decoded.Action); + Assert.Equal(testData.Decoded.GroupName, decoded.GroupName); + Assert.Equal(testData.Decoded.ConnectionId, decoded.ConnectionId); + } + + [Theory] + [MemberData(nameof(GroupCommandTestData))] + public void WriteGroupCommand(string testName) + { + var testData = _groupCommandTestData[testName]; + var protocol = new RedisProtocol(Array.Empty()); + + var encoded = protocol.WriteGroupCommand(testData.Decoded); + + Assert.Equal(testData.Encoded, encoded); + } + + // The actual invocation message doesn't matter + private static InvocationMessage _testMessage = new InvocationMessage("target", null, Array.Empty()); + private static Dictionary> _invocationTestData = new[] + { + CreateTestData( + "NoExcludedIds", + new RedisInvocation(new SerializedHubMessage(_testMessage), null), + 0x92, + 0x90, + 0x82, + 0xA2, (byte)'p', (byte)'1', + 0xC4, 0x01, 0x2A, + 0xA2, (byte)'p', (byte)'2', + 0xC4, 0x01, 0x2A), + CreateTestData( + "OneExcludedId", + new RedisInvocation(new SerializedHubMessage(_testMessage), new [] { "a" }), + 0x92, + 0x91, + 0xA1, (byte)'a', + 0x82, + 0xA2, (byte)'p', (byte)'1', + 0xC4, 0x01, 0x2A, + 0xA2, (byte)'p', (byte)'2', + 0xC4, 0x01, 0x2A), + CreateTestData( + "ManyExcludedIds", + new RedisInvocation(new SerializedHubMessage(_testMessage), new [] { "a", "b", "c", "d", "e", "f" }), + 0x92, + 0x96, + 0xA1, (byte)'a', + 0xA1, (byte)'b', + 0xA1, (byte)'c', + 0xA1, (byte)'d', + 0xA1, (byte)'e', + 0xA1, (byte)'f', + 0x82, + 0xA2, (byte)'p', (byte)'1', + 0xC4, 0x01, 0x2A, + 0xA2, (byte)'p', (byte)'2', + 0xC4, 0x01, 0x2A), + }.ToDictionary(t => t.Name); + + public static IEnumerable InvocationTestData = _invocationTestData.Keys.Select(k => new object[] { k }); + + [Theory] + [MemberData(nameof(InvocationTestData))] + public void ParseInvocation(string testName) + { + var testData = _invocationTestData[testName]; + var hubProtocols = new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }; + var protocol = new RedisProtocol(hubProtocols); + + var decoded = protocol.ReadInvocation(testData.Encoded); + + Assert.Equal(testData.Decoded.ExcludedIds, decoded.ExcludedIds); + + // Verify the deserialized object has the necessary serialized forms + foreach (var hubProtocol in hubProtocols) + { + Assert.Equal( + testData.Decoded.Message.GetSerializedMessage(hubProtocol).ToArray(), + decoded.Message.GetSerializedMessage(hubProtocol).ToArray()); + Assert.Equal(1, hubProtocol.SerializationCount); + } + } + + [Theory] + [MemberData(nameof(InvocationTestData))] + public void WriteInvocation(string testName) + { + var testData = _invocationTestData[testName]; + var protocol = new RedisProtocol(new[] { new DummyHubProtocol("p1"), new DummyHubProtocol("p2") }); + + // Actual invocation doesn't matter because we're using a dummy hub protocol. + // But the dummy protocol will check that we gave it the test message to make sure everything flows through properly. + var encoded = protocol.WriteInvocation(_testMessage.Target, _testMessage.Arguments, testData.Decoded.ExcludedIds); + + Assert.Equal(testData.Encoded, encoded); + } + + // Create ProtocolTestData using the Power of Type Inference(TM). + private static ProtocolTestData CreateTestData(string name, T decoded, params byte[] encoded) + => new ProtocolTestData(name, decoded, encoded); + + public class ProtocolTestData + { + public string Name { get; } + public T Decoded { get; } + public byte[] Encoded { get; } + + public ProtocolTestData(string name, T decoded, byte[] encoded) + { + Name = name; + Decoded = decoded; + Encoded = encoded; + } + } + + public class DummyHubProtocol : IHubProtocol + { + public int SerializationCount { get; private set; } + + public string Name { get; } + public int Version => 1; + public TransferFormat TransferFormat => TransferFormat.Text; + + public DummyHubProtocol(string name) + { + Name = name; + } + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) + { + throw new NotSupportedException(); + } + + public void WriteMessage(HubMessage message, IBufferWriter output) + { + output.Write(GetMessageBytes(message)); + } + + public byte[] GetMessageBytes(HubMessage message) + { + SerializationCount += 1; + + // Assert that we got the test message + var invocation = Assert.IsType(message); + Assert.Same(_testMessage.Target, invocation.Target); + Assert.Same(_testMessage.Arguments, invocation.Arguments); + + return new byte[] { 0x2A }; + } + + public bool IsVersionSupported(int version) + { + throw new NotSupportedException(); + } + } + } +}