diff --git a/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs b/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs
index 6d726262a4..3fca949e08 100644
--- a/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs
+++ b/src/Components/Server/src/BlazorPack/BlazorPackHubProtocol.cs
@@ -3,14 +3,7 @@
using System;
using System.Buffers;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.IO;
-using System.Runtime.CompilerServices;
-using System.Runtime.ExceptionServices;
-using MessagePack;
using Microsoft.AspNetCore.Connections;
-using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
@@ -24,12 +17,10 @@ namespace Microsoft.AspNetCore.Components.Server.BlazorPack
internal sealed class BlazorPackHubProtocol : IHubProtocol
{
internal const string ProtocolName = "blazorpack";
- private const int ErrorResult = 1;
- private const int VoidResult = 2;
- private const int NonVoidResult = 3;
-
private static readonly int ProtocolVersion = 1;
+ private readonly BlazorPackHubProtocolWorker _worker = new BlazorPackHubProtocolWorker();
+
///
public string Name => ProtocolName;
@@ -47,609 +38,14 @@ namespace Microsoft.AspNetCore.Components.Server.BlazorPack
///
public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message)
- {
- if (!BinaryMessageParser.TryParseMessage(ref input, out var payload))
- {
- message = null;
- return false;
- }
-
- var reader = new MessagePackReader(payload);
-
- var itemCount = reader.ReadArrayHeader();
- var messageType = ReadInt32(ref reader, "messageType");
-
- switch (messageType)
- {
- case HubProtocolConstants.InvocationMessageType:
- message = CreateInvocationMessage(ref reader, binder, itemCount);
- return true;
- case HubProtocolConstants.StreamInvocationMessageType:
- message = CreateStreamInvocationMessage(ref reader, binder, itemCount);
- return true;
- case HubProtocolConstants.StreamItemMessageType:
- message = CreateStreamItemMessage(ref reader, binder);
- return true;
- case HubProtocolConstants.CompletionMessageType:
- message = CreateCompletionMessage(ref reader, binder);
- return true;
- case HubProtocolConstants.CancelInvocationMessageType:
- message = CreateCancelInvocationMessage(ref reader);
- return true;
- case HubProtocolConstants.PingMessageType:
- message = PingMessage.Instance;
- return true;
- case HubProtocolConstants.CloseMessageType:
- message = CreateCloseMessage(ref reader, itemCount);
- return true;
- default:
- // Future protocol changes can add message types, old clients can ignore them
- message = null;
- return false;
- }
- }
-
- private static HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadString(ref reader, "invocationId");
-
- // For MsgPack, we represent an empty invocation ID as an empty string,
- // so we need to normalize that to "null", which is what indicates a non-blocking invocation.
- if (string.IsNullOrEmpty(invocationId))
- {
- invocationId = null;
- }
-
- var target = ReadString(ref reader, "target");
-
- object[] arguments;
- try
- {
- var parameterTypes = binder.GetParameterTypes(target);
- arguments = BindArguments(ref reader, parameterTypes);
- }
- catch (Exception ex)
- {
- return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
- }
-
- string[] streams = null;
- // Previous clients will send 5 items, so we check if they sent a stream array or not
- if (itemCount > 5)
- {
- streams = ReadStreamIds(ref reader);
- }
-
- return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams));
- }
-
- private static HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadString(ref reader, "invocationId");
- var target = ReadString(ref reader, "target"); ;
-
- object[] arguments;
- try
- {
- var parameterTypes = binder.GetParameterTypes(target);
- arguments = BindArguments(ref reader, parameterTypes);
- }
- catch (Exception ex)
- {
- return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
- }
-
- string[] streams = null;
- // Previous clients will send 5 items, so we check if they sent a stream array or not
- if (itemCount > 5)
- {
- streams = ReadStreamIds(ref reader);
- }
-
- return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
- }
-
- private static StreamItemMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadString(ref reader, "invocationId");
-
- var itemType = binder.GetStreamItemType(invocationId);
- var value = DeserializeObject(ref reader, itemType, "item");
- return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
- }
-
- private static CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadString(ref reader, "invocationId");
- var resultKind = ReadInt32(ref reader, "resultKind");
-
- string error = null;
- object result = null;
- var hasResult = false;
-
- switch (resultKind)
- {
- case ErrorResult:
- error = ReadString(ref reader, "error");
- break;
- case NonVoidResult:
- var itemType = binder.GetReturnType(invocationId);
- result = DeserializeObject(ref reader, itemType, "argument");
- hasResult = true;
- break;
- case VoidResult:
- hasResult = false;
- break;
- default:
- throw new InvalidDataException("Invalid invocation result kind.");
- }
-
- return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
- }
-
- private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadString(ref reader, "invocationId");
- return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
- }
-
- private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount)
- {
- var error = ReadString(ref reader, "error");
- var allowReconnect = false;
-
- if (itemCount > 2)
- {
- allowReconnect = ReadBoolean(ref reader, "allowReconnect");
- }
-
- // An empty string is still an error
- if (error == null && !allowReconnect)
- {
- return CloseMessage.Empty;
- }
-
- return new CloseMessage(error, allowReconnect);
- }
-
- private static Dictionary ReadHeaders(ref MessagePackReader reader)
- {
- var headerCount = ReadMapHeader(ref reader, "headers");
- if (headerCount == 0)
- {
- return null;
- }
-
- var headers = new Dictionary(StringComparer.Ordinal);
- for (var i = 0; i < headerCount; i++)
- {
- var key = ReadString(ref reader, $"headers[{i}].Key");
- var value = ReadString(ref reader, $"headers[{i}].Value");
-
- headers[key] = value;
- }
-
- return headers;
- }
-
- private static string[] ReadStreamIds(ref MessagePackReader reader)
- {
- var streamIdCount = ReadArrayHeader(ref reader, "streamIds");
-
- if (streamIdCount == 0)
- {
- return null;
- }
-
- var streams = new List();
- for (var i = 0; i < streamIdCount; i++)
- {
- streams.Add(reader.ReadString());
- }
-
- return streams.ToArray();
- }
-
- private static object[] BindArguments(ref MessagePackReader reader, IReadOnlyList parameterTypes)
- {
- var argumentCount = ReadArrayHeader(ref reader, "arguments");
-
- if (parameterTypes.Count != argumentCount)
- {
- throw new InvalidDataException(
- $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}.");
- }
-
- try
- {
- var arguments = new object[argumentCount];
- for (var i = 0; i < argumentCount; i++)
- {
- arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument");
- }
-
- return arguments;
- }
- catch (Exception ex)
- {
- throw new InvalidDataException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex);
- }
- }
+ => _worker.TryParseMessage(ref input, binder, out message);
///
public void WriteMessage(HubMessage message, IBufferWriter output)
- {
- var writer = MemoryBufferWriter.Get();
-
- try
- {
- // Write message to a buffer so we can get its length
- WriteMessageCore(message, writer);
-
- // Write length then message to output
- BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output);
- writer.CopyTo(output);
- }
- finally
- {
- MemoryBufferWriter.Return(writer);
- }
- }
+ => _worker.WriteMessage(message, output);
/////
public ReadOnlyMemory GetMessageBytes(HubMessage message)
- {
- using var writer = new ArrayBufferWriter();
-
- // Write message to a buffer so we can get its length
- WriteMessageCore(message, writer);
-
- var memory = writer.WrittenMemory;
-
- var dataLength = memory.Length;
- var prefixLength = BinaryMessageFormatter.LengthPrefixLength(dataLength);
-
- var array = new byte[dataLength + prefixLength];
- var span = array.AsSpan();
-
- // Write length then message to output
- var written = BinaryMessageFormatter.WriteLengthPrefix(dataLength, span);
- Debug.Assert(written == prefixLength);
-
- memory.Span.CopyTo(span.Slice(prefixLength));
-
- return array;
- }
-
- private void WriteMessageCore(HubMessage message, IBufferWriter bufferWriter)
- {
- var writer = new MessagePackWriter(bufferWriter);
-
- switch (message)
- {
- case InvocationMessage invocationMessage:
- WriteInvocationMessage(invocationMessage, ref writer);
- break;
- case StreamInvocationMessage streamInvocationMessage:
- WriteStreamInvocationMessage(streamInvocationMessage, ref writer);
- break;
- case StreamItemMessage streamItemMessage:
- WriteStreamingItemMessage(streamItemMessage, ref writer);
- break;
- case CompletionMessage completionMessage:
- WriteCompletionMessage(completionMessage, ref writer);
- break;
- case CancelInvocationMessage cancelInvocationMessage:
- WriteCancelInvocationMessage(cancelInvocationMessage, ref writer);
- break;
- case PingMessage pingMessage:
- WritePingMessage(pingMessage, ref writer);
- break;
- case CloseMessage closeMessage:
- WriteCloseMessage(closeMessage, ref writer);
- break;
- default:
- throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
- }
-
- writer.Flush();
- }
-
- private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(6);
-
- writer.Write(HubProtocolConstants.InvocationMessageType);
- PackHeaders(ref writer, message.Headers);
- if (string.IsNullOrEmpty(message.InvocationId))
- {
- writer.WriteNil();
- }
- else
- {
- writer.Write(message.InvocationId);
- }
- writer.Write(message.Target);
- writer.WriteArrayHeader(message.Arguments.Length);
- foreach (var arg in message.Arguments)
- {
- SerializeArgument(ref writer, arg);
- }
-
- WriteStreamIds(message.StreamIds, ref writer);
- }
-
- private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(6);
-
- writer.Write(HubProtocolConstants.StreamInvocationMessageType);
- PackHeaders(ref writer, message.Headers);
- writer.Write(message.InvocationId);
- writer.Write(message.Target);
-
- writer.WriteArrayHeader(message.Arguments.Length);
- foreach (var arg in message.Arguments)
- {
- SerializeArgument(ref writer, arg);
- }
-
- WriteStreamIds(message.StreamIds, ref writer);
- }
-
- private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(4);
- writer.Write(HubProtocolConstants.StreamItemMessageType);
- PackHeaders(ref writer, message.Headers);
- writer.Write(message.InvocationId);
- SerializeArgument(ref writer, message.Item);
- }
-
- private void SerializeArgument(ref MessagePackWriter writer, object argument)
- {
- switch (argument)
- {
- case null:
- writer.WriteNil();
- break;
-
- case bool boolValue:
- writer.Write(boolValue);
- break;
-
- case string stringValue:
- writer.Write(stringValue);
- break;
-
- case int intValue:
- writer.Write(intValue);
- break;
-
- case long longValue:
- writer.Write(longValue);
- break;
-
- case float floatValue:
- writer.Write(floatValue);
- break;
-
- case ArraySegment bytes:
- writer.Write(bytes);
- break;
-
- default:
- throw new FormatException($"Unsupported argument type {argument.GetType()}");
- }
- }
-
- private static object DeserializeObject(ref MessagePackReader reader, Type type, string field)
- {
- try
- {
- if (type == typeof(string))
- {
- return ReadString(ref reader, "argument");
- }
- else if (type == typeof(bool))
- {
- return reader.ReadBoolean();
- }
- else if (type == typeof(int))
- {
- return reader.ReadInt32();
- }
- else if (type == typeof(long))
- {
- return reader.ReadInt64();
- }
- else if (type == typeof(float))
- {
- return reader.ReadSingle();
- }
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex);
- }
-
- throw new FormatException($"Type {type} is not supported");
- }
-
- private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer)
- {
- if (streamIds != null)
- {
- writer.WriteArrayHeader(streamIds.Length);
- foreach (var streamId in streamIds)
- {
- writer.Write(streamId);
- }
- }
- else
- {
- writer.WriteArrayHeader(0);
- }
- }
-
- private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer)
- {
- var resultKind =
- message.Error != null ? ErrorResult :
- message.HasResult ? NonVoidResult :
- VoidResult;
-
- writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
- writer.Write(HubProtocolConstants.CompletionMessageType);
- PackHeaders(ref writer, message.Headers);
- writer.Write(message.InvocationId);
- writer.Write(resultKind);
- switch (resultKind)
- {
- case ErrorResult:
- writer.Write(message.Error);
- break;
- case NonVoidResult:
- SerializeArgument(ref writer, message.Result);
- break;
- }
- }
-
- private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(3);
- writer.Write(HubProtocolConstants.CancelInvocationMessageType);
- PackHeaders(ref writer, message.Headers);
- writer.Write(message.InvocationId);
- }
-
- private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(3);
- writer.Write(HubProtocolConstants.CloseMessageType);
- if (string.IsNullOrEmpty(message.Error))
- {
- writer.WriteNil();
- }
- else
- {
- writer.Write(message.Error);
- }
-
- writer.Write(message.AllowReconnect);
- }
-
- private void WritePingMessage(PingMessage _, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(1);
- writer.Write(HubProtocolConstants.PingMessageType);
- }
-
- private void PackHeaders(ref MessagePackWriter writer, IDictionary headers)
- {
- if (headers == null)
- {
- writer.WriteMapHeader(0);
- return;
- }
-
- writer.WriteMapHeader(headers.Count);
- foreach (var header in headers)
- {
- writer.Write(header.Key);
- writer.Write(header.Value);
- }
- }
-
- private static T ApplyHeaders(IDictionary source, T destination) where T : HubInvocationMessage
- {
- if (source != null && source.Count > 0)
- {
- destination.Headers = source;
- }
-
- return destination;
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private static bool ReadBoolean(ref MessagePackReader reader, string field)
- {
- if (reader.End || reader.NextMessagePackType != MessagePackType.Boolean)
- {
- ThrowInvalidDataException(field, "Boolean");
- }
-
- return reader.ReadBoolean();
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private static int ReadInt32(ref MessagePackReader reader, string field)
- {
- if (reader.End || reader.NextMessagePackType != MessagePackType.Integer)
- {
- ThrowInvalidDataException(field, "Int32");
- }
-
- return reader.ReadInt32();
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private static string ReadString(ref MessagePackReader reader, string field)
- {
- if (reader.End)
- {
- ThrowInvalidDataException(field, "String");
- }
-
- if (reader.IsNil)
- {
- reader.ReadNil();
- return null;
- }
- else if (reader.NextMessagePackType == MessagePackType.String)
- {
- return reader.ReadString();
- }
-
- ThrowInvalidDataException(field, "String");
- return null; //This should never be reached.
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private static int ReadArrayHeader(ref MessagePackReader reader, string field)
- {
- if (reader.End || reader.NextMessagePackType != MessagePackType.Array)
- {
- ThrowInvalidCollectionLengthException(field, "array");
- }
-
- return reader.ReadArrayHeader();
- }
-
- [MethodImpl(MethodImplOptions.AggressiveInlining)]
- private static int ReadMapHeader(ref MessagePackReader reader, string field)
- {
- if (reader.End || reader.NextMessagePackType != MessagePackType.Map)
- {
- ThrowInvalidCollectionLengthException(field, "map");
- }
-
- return reader.ReadMapHeader();
- }
-
- private static void ThrowInvalidDataException(string field, string targetType)
- {
- throw new InvalidDataException($"Reading '{field}' as {targetType} failed.");
- }
-
- private static void ThrowInvalidCollectionLengthException(string field, string collection)
- {
- throw new InvalidDataException($"Reading {collection} length for '{field}' failed.");
- }
+ => _worker.GetMessageBytes(message);
}
}
diff --git a/src/Components/Server/src/BlazorPack/BlazorPackHubProtocolWorker.cs b/src/Components/Server/src/BlazorPack/BlazorPackHubProtocolWorker.cs
new file mode 100644
index 0000000000..721eaedf2b
--- /dev/null
+++ b/src/Components/Server/src/BlazorPack/BlazorPackHubProtocolWorker.cs
@@ -0,0 +1,83 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.IO;
+using MessagePack;
+using Microsoft.AspNetCore.SignalR.Protocol;
+
+namespace Microsoft.AspNetCore.Components.Server.BlazorPack
+{
+ internal sealed class BlazorPackHubProtocolWorker : MessagePackHubProtocolWorker
+ {
+ protected override object DeserializeObject(ref MessagePackReader reader, Type type, string field)
+ {
+ try
+ {
+ if (type == typeof(string))
+ {
+ return ReadString(ref reader, "argument");
+ }
+ else if (type == typeof(bool))
+ {
+ return reader.ReadBoolean();
+ }
+ else if (type == typeof(int))
+ {
+ return reader.ReadInt32();
+ }
+ else if (type == typeof(long))
+ {
+ return reader.ReadInt64();
+ }
+ else if (type == typeof(float))
+ {
+ return reader.ReadSingle();
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex);
+ }
+
+ throw new FormatException($"Type {type} is not supported");
+ }
+
+ protected override void Serialize(ref MessagePackWriter writer, Type type, object value)
+ {
+ switch (value)
+ {
+ case null:
+ writer.WriteNil();
+ break;
+
+ case bool boolValue:
+ writer.Write(boolValue);
+ break;
+
+ case string stringValue:
+ writer.Write(stringValue);
+ break;
+
+ case int intValue:
+ writer.Write(intValue);
+ break;
+
+ case long longValue:
+ writer.Write(longValue);
+ break;
+
+ case float floatValue:
+ writer.Write(floatValue);
+ break;
+
+ case ArraySegment bytes:
+ writer.Write(bytes);
+ break;
+
+ default:
+ throw new FormatException($"Unsupported argument type {type}");
+ }
+ }
+ }
+}
diff --git a/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj b/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj
index 98fa33ca3e..f176a30378 100644
--- a/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj
+++ b/src/Components/Server/src/Microsoft.AspNetCore.Components.Server.csproj
@@ -61,6 +61,7 @@
+
diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/DefaultMessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/DefaultMessagePackHubProtocolWorker.cs
new file mode 100644
index 0000000000..de61d27c8d
--- /dev/null
+++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/DefaultMessagePackHubProtocolWorker.cs
@@ -0,0 +1,36 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.IO;
+using MessagePack;
+
+namespace Microsoft.AspNetCore.SignalR.Protocol
+{
+ internal sealed class DefaultMessagePackHubProtocolWorker : MessagePackHubProtocolWorker
+ {
+ private readonly MessagePackSerializerOptions _messagePackSerializerOptions;
+
+ public DefaultMessagePackHubProtocolWorker(MessagePackSerializerOptions messagePackSerializerOptions)
+ {
+ _messagePackSerializerOptions = messagePackSerializerOptions;
+ }
+
+ protected override object DeserializeObject(ref MessagePackReader reader, Type type, string field)
+ {
+ try
+ {
+ return MessagePackSerializer.Deserialize(type, ref reader, _messagePackSerializerOptions);
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex);
+ }
+ }
+
+ protected override void Serialize(ref MessagePackWriter writer, Type type, object value)
+ {
+ MessagePackSerializer.Serialize(type, ref writer, value, _messagePackSerializerOptions);
+ }
+ }
+}
diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs
index ffa814da22..d9a8ea817b 100644
--- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs
+++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs
@@ -4,15 +4,10 @@
using System;
using System.Buffers;
using System.Collections.Generic;
-using System.Diagnostics;
-using System.IO;
-using System.Linq;
-using System.Runtime.ExceptionServices;
using MessagePack;
using MessagePack.Formatters;
using MessagePack.Resolvers;
using Microsoft.AspNetCore.Connections;
-using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR.Protocol
@@ -22,14 +17,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
///
public class MessagePackHubProtocol : IHubProtocol
{
- private const int ErrorResult = 1;
- private const int VoidResult = 2;
- private const int NonVoidResult = 3;
-
- private readonly MessagePackSerializerOptions _msgPackSerializerOptions;
-
private static readonly string ProtocolName = "messagepack";
private static readonly int ProtocolVersion = 1;
+ private readonly DefaultMessagePackHubProtocolWorker _worker;
///
public string Name => ProtocolName;
@@ -53,7 +43,12 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
/// The options used to initialize the protocol.
public MessagePackHubProtocol(IOptions options)
{
- _msgPackSerializerOptions = options.Value.SerializerOptions;
+ if (options is null)
+ {
+ throw new ArgumentNullException(nameof(options));
+ }
+
+ _worker = new DefaultMessagePackHubProtocolWorker(options.Value.SerializerOptions);
}
///
@@ -64,568 +59,16 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
///
public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message)
- {
- if (!BinaryMessageParser.TryParseMessage(ref input, out var payload))
- {
- message = null;
- return false;
- }
-
- var reader = new MessagePackReader(payload);
- message = ParseMessage(ref reader, binder, _msgPackSerializerOptions);
- return true;
- }
-
- private static HubMessage ParseMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
- {
- var itemCount = reader.ReadArrayHeader();
-
- var messageType = ReadInt32(ref reader, "messageType");
-
- switch (messageType)
- {
- case HubProtocolConstants.InvocationMessageType:
- return CreateInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount);
- case HubProtocolConstants.StreamInvocationMessageType:
- return CreateStreamInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount);
- case HubProtocolConstants.StreamItemMessageType:
- return CreateStreamItemMessage(ref reader, binder, msgPackSerializerOptions);
- case HubProtocolConstants.CompletionMessageType:
- return CreateCompletionMessage(ref reader, binder, msgPackSerializerOptions);
- case HubProtocolConstants.CancelInvocationMessageType:
- return CreateCancelInvocationMessage(ref reader);
- case HubProtocolConstants.PingMessageType:
- return PingMessage.Instance;
- case HubProtocolConstants.CloseMessageType:
- return CreateCloseMessage(ref reader, itemCount);
- default:
- // Future protocol changes can add message types, old clients can ignore them
- return null;
- }
- }
-
- private static HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadInvocationId(ref reader);
-
- // For MsgPack, we represent an empty invocation ID as an empty string,
- // so we need to normalize that to "null", which is what indicates a non-blocking invocation.
- if (string.IsNullOrEmpty(invocationId))
- {
- invocationId = null;
- }
-
- var target = ReadString(ref reader, "target");
-
- object[] arguments = null;
- try
- {
- var parameterTypes = binder.GetParameterTypes(target);
- arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions);
- }
- catch (Exception ex)
- {
- return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
- }
-
- string[] streams = null;
- // Previous clients will send 5 items, so we check if they sent a stream array or not
- if (itemCount > 5)
- {
- streams = ReadStreamIds(ref reader);
- }
-
- return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams));
- }
-
- private static HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadInvocationId(ref reader);
- var target = ReadString(ref reader, "target");
-
- object[] arguments = null;
- try
- {
- var parameterTypes = binder.GetParameterTypes(target);
- arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions);
- }
- catch (Exception ex)
- {
- return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
- }
-
- string[] streams = null;
- // Previous clients will send 5 items, so we check if they sent a stream array or not
- if (itemCount > 5)
- {
- streams = ReadStreamIds(ref reader);
- }
-
- return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
- }
-
- private static HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadInvocationId(ref reader);
- object value;
- try
- {
- var itemType = binder.GetStreamItemType(invocationId);
- value = DeserializeObject(ref reader, itemType, "item", msgPackSerializerOptions);
- }
- catch (Exception ex)
- {
- return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
- }
-
- return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
- }
-
- private static CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadInvocationId(ref reader);
- var resultKind = ReadInt32(ref reader, "resultKind");
-
- string error = null;
- object result = null;
- var hasResult = false;
-
- switch (resultKind)
- {
- case ErrorResult:
- error = ReadString(ref reader, "error");
- break;
- case NonVoidResult:
- var itemType = binder.GetReturnType(invocationId);
- result = DeserializeObject(ref reader, itemType, "argument", msgPackSerializerOptions);
- hasResult = true;
- break;
- case VoidResult:
- hasResult = false;
- break;
- default:
- throw new InvalidDataException("Invalid invocation result kind.");
- }
-
- return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
- }
-
- private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader)
- {
- var headers = ReadHeaders(ref reader);
- var invocationId = ReadInvocationId(ref reader);
- return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
- }
-
- private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount)
- {
- var error = ReadString(ref reader, "error");
- var allowReconnect = false;
-
- if (itemCount > 2)
- {
- allowReconnect = ReadBoolean(ref reader, "allowReconnect");
- }
-
- // An empty string is still an error
- if (error == null && !allowReconnect)
- {
- return CloseMessage.Empty;
- }
-
- return new CloseMessage(error, allowReconnect);
- }
-
- private static Dictionary ReadHeaders(ref MessagePackReader reader)
- {
- var headerCount = ReadMapLength(ref reader, "headers");
- if (headerCount > 0)
- {
- var headers = new Dictionary(StringComparer.Ordinal);
-
- for (var i = 0; i < headerCount; i++)
- {
- var key = ReadString(ref reader, $"headers[{i}].Key");
- var value = ReadString(ref reader, $"headers[{i}].Value");
- headers.Add(key, value);
- }
- return headers;
- }
- else
- {
- return null;
- }
- }
-
- private static string[] ReadStreamIds(ref MessagePackReader reader)
- {
- var streamIdCount = ReadArrayLength(ref reader, "streamIds");
- List streams = null;
-
- if (streamIdCount > 0)
- {
- streams = new List();
- for (var i = 0; i < streamIdCount; i++)
- {
- streams.Add(reader.ReadString());
- }
- }
-
- return streams?.ToArray();
- }
-
- private static object[] BindArguments(ref MessagePackReader reader, IReadOnlyList parameterTypes, MessagePackSerializerOptions msgPackSerializerOptions)
- {
- var argumentCount = ReadArrayLength(ref reader, "arguments");
-
- if (parameterTypes.Count != argumentCount)
- {
- throw new InvalidDataException(
- $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}.");
- }
-
- try
- {
- var arguments = new object[argumentCount];
- for (var i = 0; i < argumentCount; i++)
- {
- arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument", msgPackSerializerOptions);
- }
-
- return arguments;
- }
- catch (Exception ex)
- {
- throw new InvalidDataException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex);
- }
- }
-
- private static T ApplyHeaders(IDictionary source, T destination) where T : HubInvocationMessage
- {
- if (source != null && source.Count > 0)
- {
- destination.Headers = source;
- }
-
- return destination;
- }
+ => _worker.TryParseMessage(ref input, binder, out message);
///
public void WriteMessage(HubMessage message, IBufferWriter output)
- {
- var memoryBufferWriter = MemoryBufferWriter.Get();
+ => _worker.WriteMessage(message, output);
- try
- {
- var writer = new MessagePackWriter(memoryBufferWriter);
-
- // Write message to a buffer so we can get its length
- WriteMessageCore(message, ref writer);
-
- // Write length then message to output
- BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output);
- memoryBufferWriter.CopyTo(output);
- }
- finally
- {
- MemoryBufferWriter.Return(memoryBufferWriter);
- }
- }
///
public ReadOnlyMemory GetMessageBytes(HubMessage message)
- {
- var memoryBufferWriter = MemoryBufferWriter.Get();
-
- try
- {
- var writer = new MessagePackWriter(memoryBufferWriter);
-
- // Write message to a buffer so we can get its length
- WriteMessageCore(message, ref writer);
-
- var dataLength = memoryBufferWriter.Length;
- var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length);
-
- var array = new byte[dataLength + prefixLength];
- var span = array.AsSpan();
-
- // Write length then message to output
- var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span);
- Debug.Assert(written == prefixLength);
- memoryBufferWriter.CopyTo(span.Slice(prefixLength));
-
- return array;
- }
- finally
- {
- MemoryBufferWriter.Return(memoryBufferWriter);
- }
- }
-
- private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer)
- {
- switch (message)
- {
- case InvocationMessage invocationMessage:
- WriteInvocationMessage(invocationMessage, ref writer);
- break;
- case StreamInvocationMessage streamInvocationMessage:
- WriteStreamInvocationMessage(streamInvocationMessage, ref writer);
- break;
- case StreamItemMessage streamItemMessage:
- WriteStreamingItemMessage(streamItemMessage, ref writer);
- break;
- case CompletionMessage completionMessage:
- WriteCompletionMessage(completionMessage, ref writer);
- break;
- case CancelInvocationMessage cancelInvocationMessage:
- WriteCancelInvocationMessage(cancelInvocationMessage, ref writer);
- break;
- case PingMessage pingMessage:
- WritePingMessage(pingMessage, ref writer);
- break;
- case CloseMessage closeMessage:
- WriteCloseMessage(closeMessage, ref writer);
- break;
- default:
- throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
- }
-
- writer.Flush();
- }
-
- private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(6);
-
- writer.Write(HubProtocolConstants.InvocationMessageType);
- PackHeaders(message.Headers, ref writer);
- if (string.IsNullOrEmpty(message.InvocationId))
- {
- writer.WriteNil();
- }
- else
- {
- writer.Write(message.InvocationId);
- }
- writer.Write(message.Target);
- writer.WriteArrayHeader(message.Arguments.Length);
- foreach (var arg in message.Arguments)
- {
- WriteArgument(arg, ref writer);
- }
-
- WriteStreamIds(message.StreamIds, ref writer);
- }
-
- private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(6);
-
- writer.Write(HubProtocolConstants.StreamInvocationMessageType);
- PackHeaders(message.Headers, ref writer);
- writer.Write(message.InvocationId);
- writer.Write(message.Target);
-
- writer.WriteArrayHeader(message.Arguments.Length);
- foreach (var arg in message.Arguments)
- {
- WriteArgument(arg, ref writer);
- }
-
- WriteStreamIds(message.StreamIds, ref writer);
- }
-
- private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(4);
- writer.Write(HubProtocolConstants.StreamItemMessageType);
- PackHeaders(message.Headers, ref writer);
- writer.Write(message.InvocationId);
- WriteArgument(message.Item, ref writer);
- }
-
- private void WriteArgument(object argument, ref MessagePackWriter writer)
- {
- if (argument == null)
- {
- writer.WriteNil();
- }
- else
- {
- MessagePackSerializer.Serialize(argument.GetType(), ref writer, argument, _msgPackSerializerOptions);
- }
- }
-
- private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer)
- {
- if (streamIds != null)
- {
- writer.WriteArrayHeader(streamIds.Length);
- foreach (var streamId in streamIds)
- {
- writer.Write(streamId);
- }
- }
- else
- {
- writer.WriteArrayHeader(0);
- }
- }
-
- private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer)
- {
- var resultKind =
- message.Error != null ? ErrorResult :
- message.HasResult ? NonVoidResult :
- VoidResult;
-
- writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
- writer.Write(HubProtocolConstants.CompletionMessageType);
- PackHeaders(message.Headers, ref writer);
- writer.Write(message.InvocationId);
- writer.Write(resultKind);
- switch (resultKind)
- {
- case ErrorResult:
- writer.Write(message.Error);
- break;
- case NonVoidResult:
- WriteArgument(message.Result, ref writer);
- break;
- }
- }
-
- private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(3);
- writer.Write(HubProtocolConstants.CancelInvocationMessageType);
- PackHeaders(message.Headers, ref writer);
- writer.Write(message.InvocationId);
- }
-
- private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(3);
- writer.Write(HubProtocolConstants.CloseMessageType);
- if (string.IsNullOrEmpty(message.Error))
- {
- writer.WriteNil();
- }
- else
- {
- writer.Write(message.Error);
- }
-
- writer.Write(message.AllowReconnect);
- }
-
- private void WritePingMessage(PingMessage pingMessage, ref MessagePackWriter writer)
- {
- writer.WriteArrayHeader(1);
- writer.Write(HubProtocolConstants.PingMessageType);
- }
-
- private void PackHeaders(IDictionary headers, ref MessagePackWriter writer)
- {
- if (headers != null)
- {
- writer.WriteMapHeader(headers.Count);
- if (headers.Count > 0)
- {
- foreach (var header in headers)
- {
- writer.Write(header.Key);
- writer.Write(header.Value);
- }
- }
- }
- else
- {
- writer.WriteMapHeader(0);
- }
- }
-
- private static string ReadInvocationId(ref MessagePackReader reader) =>
- ReadString(ref reader, "invocationId");
-
- private static bool ReadBoolean(ref MessagePackReader reader, string field)
- {
- try
- {
- return reader.ReadBoolean();
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex);
- }
- }
-
- private static int ReadInt32(ref MessagePackReader reader, string field)
- {
- try
- {
- return reader.ReadInt32();
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex);
- }
- }
-
- private static string ReadString(ref MessagePackReader reader, string field)
- {
- try
- {
- return reader.ReadString();
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
- }
- }
-
- private static long ReadMapLength(ref MessagePackReader reader, string field)
- {
- try
- {
- return reader.ReadMapHeader();
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Reading map length for '{field}' failed.", ex);
- }
-
- }
-
- private static long ReadArrayLength(ref MessagePackReader reader, string field)
- {
- try
- {
- return reader.ReadArrayHeader();
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Reading array length for '{field}' failed.", ex);
- }
- }
-
- private static object DeserializeObject(ref MessagePackReader reader, Type type, string field, MessagePackSerializerOptions msgPackSerializerOptions)
- {
- try
- {
- return MessagePackSerializer.Deserialize(type, ref reader, msgPackSerializerOptions);
- }
- catch (Exception ex)
- {
- throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex);
- }
- }
+ => _worker.GetMessageBytes(message);
internal static MessagePackSerializerOptions CreateDefaultMessagePackSerializerOptions() =>
MessagePackSerializerOptions
diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs
new file mode 100644
index 0000000000..43c9a83929
--- /dev/null
+++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocolWorker.cs
@@ -0,0 +1,580 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Buffers;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Runtime.ExceptionServices;
+using MessagePack;
+using Microsoft.AspNetCore.Internal;
+
+namespace Microsoft.AspNetCore.SignalR.Protocol
+{
+ ///
+ /// Implements support for MessagePackHubProtocol. This code is shared between SignalR and Blazor.
+ ///
+ internal abstract class MessagePackHubProtocolWorker
+ {
+ private const int ErrorResult = 1;
+ private const int VoidResult = 2;
+ private const int NonVoidResult = 3;
+
+ public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message)
+ {
+ if (!BinaryMessageParser.TryParseMessage(ref input, out var payload))
+ {
+ message = null;
+ return false;
+ }
+
+ var reader = new MessagePackReader(payload);
+ message = ParseMessage(ref reader, binder);
+ return true;
+ }
+
+ private HubMessage ParseMessage(ref MessagePackReader reader, IInvocationBinder binder)
+ {
+ var itemCount = reader.ReadArrayHeader();
+
+ var messageType = ReadInt32(ref reader, "messageType");
+
+ switch (messageType)
+ {
+ case HubProtocolConstants.InvocationMessageType:
+ return CreateInvocationMessage(ref reader, binder, itemCount);
+ case HubProtocolConstants.StreamInvocationMessageType:
+ return CreateStreamInvocationMessage(ref reader, binder, itemCount);
+ case HubProtocolConstants.StreamItemMessageType:
+ return CreateStreamItemMessage(ref reader, binder);
+ case HubProtocolConstants.CompletionMessageType:
+ return CreateCompletionMessage(ref reader, binder);
+ case HubProtocolConstants.CancelInvocationMessageType:
+ return CreateCancelInvocationMessage(ref reader);
+ case HubProtocolConstants.PingMessageType:
+ return PingMessage.Instance;
+ case HubProtocolConstants.CloseMessageType:
+ return CreateCloseMessage(ref reader, itemCount);
+ default:
+ // Future protocol changes can add message types, old clients can ignore them
+ return null;
+ }
+ }
+
+ private HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount)
+ {
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+
+ // For MsgPack, we represent an empty invocation ID as an empty string,
+ // so we need to normalize that to "null", which is what indicates a non-blocking invocation.
+ if (string.IsNullOrEmpty(invocationId))
+ {
+ invocationId = null;
+ }
+
+ var target = ReadString(ref reader, "target");
+
+ object[] arguments = null;
+ try
+ {
+ var parameterTypes = binder.GetParameterTypes(target);
+ arguments = BindArguments(ref reader, parameterTypes);
+ }
+ catch (Exception ex)
+ {
+ return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
+ }
+
+ string[] streams = null;
+ // Previous clients will send 5 items, so we check if they sent a stream array or not
+ if (itemCount > 5)
+ {
+ streams = ReadStreamIds(ref reader);
+ }
+
+ return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams));
+ }
+
+ private HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, int itemCount)
+ {
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+ var target = ReadString(ref reader, "target");
+
+ object[] arguments = null;
+ try
+ {
+ var parameterTypes = binder.GetParameterTypes(target);
+ arguments = BindArguments(ref reader, parameterTypes);
+ }
+ catch (Exception ex)
+ {
+ return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
+ }
+
+ string[] streams = null;
+ // Previous clients will send 5 items, so we check if they sent a stream array or not
+ if (itemCount > 5)
+ {
+ streams = ReadStreamIds(ref reader);
+ }
+
+ return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
+ }
+
+ private HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder)
+ {
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+ object value;
+ try
+ {
+ var itemType = binder.GetStreamItemType(invocationId);
+ value = DeserializeObject(ref reader, itemType, "item");
+ }
+ catch (Exception ex)
+ {
+ return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
+ }
+
+ return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
+ }
+
+ private CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder)
+ {
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+ var resultKind = ReadInt32(ref reader, "resultKind");
+
+ string error = null;
+ object result = null;
+ var hasResult = false;
+
+ switch (resultKind)
+ {
+ case ErrorResult:
+ error = ReadString(ref reader, "error");
+ break;
+ case NonVoidResult:
+ var itemType = binder.GetReturnType(invocationId);
+ result = DeserializeObject(ref reader, itemType, "argument");
+ hasResult = true;
+ break;
+ case VoidResult:
+ hasResult = false;
+ break;
+ default:
+ throw new InvalidDataException("Invalid invocation result kind.");
+ }
+
+ return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
+ }
+
+ private CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader)
+ {
+ var headers = ReadHeaders(ref reader);
+ var invocationId = ReadInvocationId(ref reader);
+ return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
+ }
+
+ private CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount)
+ {
+ var error = ReadString(ref reader, "error");
+ var allowReconnect = false;
+
+ if (itemCount > 2)
+ {
+ allowReconnect = ReadBoolean(ref reader, "allowReconnect");
+ }
+
+ // An empty string is still an error
+ if (error == null && !allowReconnect)
+ {
+ return CloseMessage.Empty;
+ }
+
+ return new CloseMessage(error, allowReconnect);
+ }
+
+ private Dictionary ReadHeaders(ref MessagePackReader reader)
+ {
+ var headerCount = ReadMapLength(ref reader, "headers");
+ if (headerCount > 0)
+ {
+ var headers = new Dictionary(StringComparer.Ordinal);
+
+ for (var i = 0; i < headerCount; i++)
+ {
+ var key = ReadString(ref reader, $"headers[{i}].Key");
+ var value = ReadString(ref reader, $"headers[{i}].Value");
+ headers.Add(key, value);
+ }
+ return headers;
+ }
+ else
+ {
+ return null;
+ }
+ }
+
+ private string[] ReadStreamIds(ref MessagePackReader reader)
+ {
+ var streamIdCount = ReadArrayLength(ref reader, "streamIds");
+ List streams = null;
+
+ if (streamIdCount > 0)
+ {
+ streams = new List();
+ for (var i = 0; i < streamIdCount; i++)
+ {
+ streams.Add(reader.ReadString());
+ }
+ }
+
+ return streams?.ToArray();
+ }
+
+ private object[] BindArguments(ref MessagePackReader reader, IReadOnlyList parameterTypes)
+ {
+ var argumentCount = ReadArrayLength(ref reader, "arguments");
+
+ if (parameterTypes.Count != argumentCount)
+ {
+ throw new InvalidDataException(
+ $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}.");
+ }
+
+ try
+ {
+ var arguments = new object[argumentCount];
+ for (var i = 0; i < argumentCount; i++)
+ {
+ arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument");
+ }
+
+ return arguments;
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex);
+ }
+ }
+
+ protected abstract object DeserializeObject(ref MessagePackReader reader, Type type, string field);
+
+ private T ApplyHeaders(IDictionary source, T destination) where T : HubInvocationMessage
+ {
+ if (source != null && source.Count > 0)
+ {
+ destination.Headers = source;
+ }
+
+ return destination;
+ }
+
+ ///
+ public void WriteMessage(HubMessage message, IBufferWriter output)
+ {
+ var memoryBufferWriter = MemoryBufferWriter.Get();
+
+ try
+ {
+ var writer = new MessagePackWriter(memoryBufferWriter);
+
+ // Write message to a buffer so we can get its length
+ WriteMessageCore(message, ref writer);
+
+ // Write length then message to output
+ BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output);
+ memoryBufferWriter.CopyTo(output);
+ }
+ finally
+ {
+ MemoryBufferWriter.Return(memoryBufferWriter);
+ }
+ }
+
+ ///
+ public ReadOnlyMemory GetMessageBytes(HubMessage message)
+ {
+ var memoryBufferWriter = MemoryBufferWriter.Get();
+
+ try
+ {
+ var writer = new MessagePackWriter(memoryBufferWriter);
+
+ // Write message to a buffer so we can get its length
+ WriteMessageCore(message, ref writer);
+
+ var dataLength = memoryBufferWriter.Length;
+ var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length);
+
+ var array = new byte[dataLength + prefixLength];
+ var span = array.AsSpan();
+
+ // Write length then message to output
+ var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span);
+ Debug.Assert(written == prefixLength);
+ memoryBufferWriter.CopyTo(span.Slice(prefixLength));
+
+ return array;
+ }
+ finally
+ {
+ MemoryBufferWriter.Return(memoryBufferWriter);
+ }
+ }
+
+ private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer)
+ {
+ switch (message)
+ {
+ case InvocationMessage invocationMessage:
+ WriteInvocationMessage(invocationMessage, ref writer);
+ break;
+ case StreamInvocationMessage streamInvocationMessage:
+ WriteStreamInvocationMessage(streamInvocationMessage, ref writer);
+ break;
+ case StreamItemMessage streamItemMessage:
+ WriteStreamingItemMessage(streamItemMessage, ref writer);
+ break;
+ case CompletionMessage completionMessage:
+ WriteCompletionMessage(completionMessage, ref writer);
+ break;
+ case CancelInvocationMessage cancelInvocationMessage:
+ WriteCancelInvocationMessage(cancelInvocationMessage, ref writer);
+ break;
+ case PingMessage pingMessage:
+ WritePingMessage(pingMessage, ref writer);
+ break;
+ case CloseMessage closeMessage:
+ WriteCloseMessage(closeMessage, ref writer);
+ break;
+ default:
+ throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
+ }
+
+ writer.Flush();
+ }
+
+ private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer)
+ {
+ writer.WriteArrayHeader(6);
+
+ writer.Write(HubProtocolConstants.InvocationMessageType);
+ PackHeaders(message.Headers, ref writer);
+ if (string.IsNullOrEmpty(message.InvocationId))
+ {
+ writer.WriteNil();
+ }
+ else
+ {
+ writer.Write(message.InvocationId);
+ }
+ writer.Write(message.Target);
+ writer.WriteArrayHeader(message.Arguments.Length);
+ foreach (var arg in message.Arguments)
+ {
+ WriteArgument(arg, ref writer);
+ }
+
+ WriteStreamIds(message.StreamIds, ref writer);
+ }
+
+ private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer)
+ {
+ writer.WriteArrayHeader(6);
+
+ writer.Write(HubProtocolConstants.StreamInvocationMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ writer.Write(message.Target);
+
+ writer.WriteArrayHeader(message.Arguments.Length);
+ foreach (var arg in message.Arguments)
+ {
+ WriteArgument(arg, ref writer);
+ }
+
+ WriteStreamIds(message.StreamIds, ref writer);
+ }
+
+ private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer)
+ {
+ writer.WriteArrayHeader(4);
+ writer.Write(HubProtocolConstants.StreamItemMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ WriteArgument(message.Item, ref writer);
+ }
+
+ private void WriteArgument(object argument, ref MessagePackWriter writer)
+ {
+ if (argument == null)
+ {
+ writer.WriteNil();
+ }
+ else
+ {
+ Serialize(ref writer, argument.GetType(), argument);
+ }
+ }
+
+ protected abstract void Serialize(ref MessagePackWriter writer, Type type, object value);
+
+ private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer)
+ {
+ if (streamIds != null)
+ {
+ writer.WriteArrayHeader(streamIds.Length);
+ foreach (var streamId in streamIds)
+ {
+ writer.Write(streamId);
+ }
+ }
+ else
+ {
+ writer.WriteArrayHeader(0);
+ }
+ }
+
+ private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer)
+ {
+ var resultKind =
+ message.Error != null ? ErrorResult :
+ message.HasResult ? NonVoidResult :
+ VoidResult;
+
+ writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
+ writer.Write(HubProtocolConstants.CompletionMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ writer.Write(resultKind);
+ switch (resultKind)
+ {
+ case ErrorResult:
+ writer.Write(message.Error);
+ break;
+ case NonVoidResult:
+ WriteArgument(message.Result, ref writer);
+ break;
+ }
+ }
+
+ private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer)
+ {
+ writer.WriteArrayHeader(3);
+ writer.Write(HubProtocolConstants.CancelInvocationMessageType);
+ PackHeaders(message.Headers, ref writer);
+ writer.Write(message.InvocationId);
+ }
+
+ private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer)
+ {
+ writer.WriteArrayHeader(3);
+ writer.Write(HubProtocolConstants.CloseMessageType);
+ if (string.IsNullOrEmpty(message.Error))
+ {
+ writer.WriteNil();
+ }
+ else
+ {
+ writer.Write(message.Error);
+ }
+
+ writer.Write(message.AllowReconnect);
+ }
+
+ private void WritePingMessage(PingMessage pingMessage, ref MessagePackWriter writer)
+ {
+ writer.WriteArrayHeader(1);
+ writer.Write(HubProtocolConstants.PingMessageType);
+ }
+
+ private void PackHeaders(IDictionary headers, ref MessagePackWriter writer)
+ {
+ if (headers != null)
+ {
+ writer.WriteMapHeader(headers.Count);
+ if (headers.Count > 0)
+ {
+ foreach (var header in headers)
+ {
+ writer.Write(header.Key);
+ writer.Write(header.Value);
+ }
+ }
+ }
+ else
+ {
+ writer.WriteMapHeader(0);
+ }
+ }
+
+ private string ReadInvocationId(ref MessagePackReader reader) =>
+ ReadString(ref reader, "invocationId");
+
+ private bool ReadBoolean(ref MessagePackReader reader, string field)
+ {
+ try
+ {
+ return reader.ReadBoolean();
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex);
+ }
+ }
+
+ private int ReadInt32(ref MessagePackReader reader, string field)
+ {
+ try
+ {
+ return reader.ReadInt32();
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex);
+ }
+ }
+
+ protected string ReadString(ref MessagePackReader reader, string field)
+ {
+ try
+ {
+ return reader.ReadString();
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
+ }
+ }
+
+ private long ReadMapLength(ref MessagePackReader reader, string field)
+ {
+ try
+ {
+ return reader.ReadMapHeader();
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Reading map length for '{field}' failed.", ex);
+ }
+
+ }
+
+ private long ReadArrayLength(ref MessagePackReader reader, string field)
+ {
+ try
+ {
+ return reader.ReadArrayHeader();
+ }
+ catch (Exception ex)
+ {
+ throw new InvalidDataException($"Reading array length for '{field}' failed.", ex);
+ }
+ }
+ }
+}