502 lines
18 KiB
C#
502 lines
18 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using System.Runtime.ExceptionServices;
|
|
using Microsoft.AspNetCore.Protocols;
|
|
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
|
|
using Microsoft.AspNetCore.Sockets;
|
|
using Microsoft.Extensions.Options;
|
|
using MsgPack;
|
|
using MsgPack.Serialization;
|
|
|
|
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
    /// <summary>
    /// <see cref="IHubProtocol"/> implementation that encodes hub messages with MsgPack.Cli.
    /// Each message is written as a length-prefixed MessagePack array whose first element is
    /// the message type and whose remaining elements depend on that type.
    /// </summary>
    public class MessagePackHubProtocol : IHubProtocol
    {
        // Values of the "resultKind" element of a completion message.
        private const int ErrorResult = 1;
        private const int VoidResult = 2;
        private const int NonVoidResult = 3;

        /// <summary>The protocol name reported by <see cref="Name"/>.</summary>
        public static readonly string ProtocolName = "messagepack";

        /// <summary>
        /// The MsgPack.Cli context used both to serialize outgoing arguments/results and to
        /// deserialize incoming ones. Taken from the options passed to the constructor.
        /// </summary>
        public SerializationContext SerializationContext { get; }

        /// <inheritdoc />
        public string Name => ProtocolName;

        /// <inheritdoc />
        public TransferFormat TransferFormat => TransferFormat.Binary;

        /// <summary>Creates a protocol instance using default <see cref="MessagePackHubProtocolOptions"/>.</summary>
        public MessagePackHubProtocol()
            : this(Options.Create(new MessagePackHubProtocolOptions()))
        { }

        /// <summary>Creates a protocol instance using the supplied options.</summary>
        /// <param name="options">Options whose <c>SerializationContext</c> is used for (de)serialization.</param>
        public MessagePackHubProtocol(IOptions<MessagePackHubProtocolOptions> options)
        {
            SerializationContext = options.Value.SerializationContext;
        }

        /// <summary>
        /// Parses every complete length-prefixed frame that <see cref="BinaryMessageParser"/> can extract
        /// from <paramref name="input"/> and appends the resulting messages to <paramref name="messages"/>.
        /// </summary>
        /// <param name="input">Raw bytes, possibly containing several frames.</param>
        /// <param name="binder">Resolves parameter/return types for invocation targets and invocation IDs.</param>
        /// <param name="messages">Receives the parsed messages.</param>
        /// <returns>
        /// <c>true</c> if <paramref name="messages"/> is non-empty afterwards. This also counts entries the
        /// caller put in the list before the call, not only the newly parsed ones.
        /// </returns>
        /// <exception cref="FormatException">Thrown when a frame cannot be read as a hub message.</exception>
        public bool TryParseMessages(ReadOnlyMemory<byte> input, IInvocationBinder binder, IList<HubMessage> messages)
        {
            // Deserialize with the same context used for writing so that reads and writes are symmetric.
            // Fall back to the library default when none was configured.
            var context = SerializationContext ?? SerializationContext.Default;

            while (BinaryMessageParser.TryParseMessage(ref input, out var payload))
            {
                messages.Add(ParseMessage(payload.ToArray(), binder, context));
            }

            return messages.Count > 0;
        }

        // Decodes one frame payload of the form [messageType, ...type-specific elements].
        private static HubMessage ParseMessage(byte[] input, IInvocationBinder binder, SerializationContext context)
        {
            using (var unpacker = Unpacker.Create(input))
            {
                // The element count is read but not validated; each reader below consumes the fields it expects.
                _ = ReadArrayLength(unpacker, "elementCount");

                var messageType = ReadInt32(unpacker, "messageType");

                switch (messageType)
                {
                    case HubProtocolConstants.InvocationMessageType:
                        return CreateInvocationMessage(unpacker, binder, context);
                    case HubProtocolConstants.StreamInvocationMessageType:
                        return CreateStreamInvocationMessage(unpacker, binder, context);
                    case HubProtocolConstants.StreamItemMessageType:
                        return CreateStreamItemMessage(unpacker, binder, context);
                    case HubProtocolConstants.CompletionMessageType:
                        return CreateCompletionMessage(unpacker, binder, context);
                    case HubProtocolConstants.CancelInvocationMessageType:
                        return CreateCancelInvocationMessage(unpacker);
                    case HubProtocolConstants.PingMessageType:
                        return PingMessage.Instance;
                    default:
                        throw new FormatException($"Invalid message type: {messageType}.");
                }
            }
        }

        // Reads [headers, invocationId, target, arguments]. Argument binding failures are captured
        // in the returned message instead of being thrown.
        private static InvocationMessage CreateInvocationMessage(Unpacker unpacker, IInvocationBinder binder, SerializationContext context)
        {
            var headers = ReadHeaders(unpacker);
            var invocationId = ReadInvocationId(unpacker);

            // For MsgPack, we represent an empty invocation ID as an empty string,
            // so we need to normalize that to "null", which is what indicates a non-blocking invocation.
            if (string.IsNullOrEmpty(invocationId))
            {
                invocationId = null;
            }

            var target = ReadString(unpacker, "target");
            var parameterTypes = binder.GetParameterTypes(target);

            try
            {
                var arguments = BindArguments(unpacker, parameterTypes, context);
                return ApplyHeaders(headers, new InvocationMessage(invocationId, target, argumentBindingException: null, arguments: arguments));
            }
            catch (Exception ex)
            {
                return ApplyHeaders(headers, new InvocationMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex)));
            }
        }

        // Reads [headers, invocationId, target, arguments]. Argument binding failures are captured
        // in the returned message instead of being thrown.
        private static StreamInvocationMessage CreateStreamInvocationMessage(Unpacker unpacker, IInvocationBinder binder, SerializationContext context)
        {
            var headers = ReadHeaders(unpacker);
            var invocationId = ReadInvocationId(unpacker);
            var target = ReadString(unpacker, "target");
            var parameterTypes = binder.GetParameterTypes(target);

            try
            {
                var arguments = BindArguments(unpacker, parameterTypes, context);
                return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, argumentBindingException: null, arguments: arguments));
            }
            catch (Exception ex)
            {
                return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex)));
            }
        }

        // Reads [headers, invocationId, item]; the item type comes from the binder's return type for the invocation.
        private static StreamItemMessage CreateStreamItemMessage(Unpacker unpacker, IInvocationBinder binder, SerializationContext context)
        {
            var headers = ReadHeaders(unpacker);
            var invocationId = ReadInvocationId(unpacker);
            var itemType = binder.GetReturnType(invocationId);
            var value = DeserializeObject(unpacker, itemType, "item", context);
            return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
        }

        // Reads [headers, invocationId, resultKind, (error | result)?]. The trailing element is present
        // only for ErrorResult (a string) and NonVoidResult (a value of the invocation's return type).
        private static CompletionMessage CreateCompletionMessage(Unpacker unpacker, IInvocationBinder binder, SerializationContext context)
        {
            var headers = ReadHeaders(unpacker);
            var invocationId = ReadInvocationId(unpacker);
            var resultKind = ReadInt32(unpacker, "resultKind");

            string error = null;
            object result = null;
            var hasResult = false;

            switch (resultKind)
            {
                case ErrorResult:
                    error = ReadString(unpacker, "error");
                    break;
                case NonVoidResult:
                    var itemType = binder.GetReturnType(invocationId);
                    result = DeserializeObject(unpacker, itemType, "argument", context);
                    hasResult = true;
                    break;
                case VoidResult:
                    hasResult = false;
                    break;
                default:
                    throw new FormatException("Invalid invocation result kind.");
            }

            return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
        }

        // Reads [headers, invocationId].
        private static CancelInvocationMessage CreateCancelInvocationMessage(Unpacker unpacker)
        {
            var headers = ReadHeaders(unpacker);
            var invocationId = ReadInvocationId(unpacker);
            return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
        }

        // Reads a string->string map of headers. Returns null when the map is empty.
        private static Dictionary<string, string> ReadHeaders(Unpacker unpacker)
        {
            var headerCount = ReadMapLength(unpacker, "headers");
            if (headerCount <= 0)
            {
                return null;
            }

            // The count comes straight off the wire, so it is deliberately NOT used to presize the
            // dictionary. A tiny malicious frame could otherwise claim billions of entries and force a
            // huge allocation (or an overflowing int cast) before any entry is read. If the frame is
            // actually shorter than claimed, reading the missing entries fails with a FormatException.
            var headers = new Dictionary<string, string>();

            for (long i = 0; i < headerCount; i++)
            {
                var key = ReadString(unpacker, $"headers[{i}].Key");
                var value = ReadString(unpacker, $"headers[{i}].Value");
                headers[key] = value;
            }

            return headers;
        }

        // Reads the arguments array and deserializes each element to the matching parameter type.
        // Any failure, including a count mismatch, surfaces as a FormatException.
        private static object[] BindArguments(Unpacker unpacker, IReadOnlyList<Type> parameterTypes, SerializationContext context)
        {
            var argumentCount = ReadArrayLength(unpacker, "arguments");

            if (parameterTypes.Count != argumentCount)
            {
                throw new FormatException(
                    $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}.");
            }

            try
            {
                var arguments = new object[argumentCount];
                for (var i = 0; i < argumentCount; i++)
                {
                    arguments[i] = DeserializeObject(unpacker, parameterTypes[i], "argument", context);
                }

                return arguments;
            }
            catch (Exception ex)
            {
                throw new FormatException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex);
            }
        }

        // Attaches headers to the message only when there is at least one header.
        private static T ApplyHeaders<T>(IDictionary<string, string> source, T destination) where T : HubInvocationMessage
        {
            if (source != null && source.Count > 0)
            {
                destination.Headers = source;
            }

            return destination;
        }

        /// <summary>
        /// Serializes <paramref name="message"/> and writes it to <paramref name="output"/> as a
        /// length prefix followed by the MessagePack payload.
        /// </summary>
        /// <exception cref="FormatException">Thrown for an unsupported <see cref="HubMessage"/> type.</exception>
        public void WriteMessage(HubMessage message, Stream output)
        {
            // We're writing data into the memoryStream so that we can get the length prefix
            using (var memoryStream = new MemoryStream())
            {
                WriteMessageCore(message, memoryStream);
                if (memoryStream.TryGetBuffer(out var buffer))
                {
                    // Write the buffer directly
                    BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output);
                    output.Write(buffer.Array, buffer.Offset, buffer.Count);
                }
                else
                {
                    BinaryMessageFormatter.WriteLengthPrefix(memoryStream.Length, output);
                    memoryStream.Position = 0;
                    memoryStream.CopyTo(output);
                }
            }
        }

        // Dispatches to the writer for the concrete message type.
        private void WriteMessageCore(HubMessage message, Stream output)
        {
            // PackerCompatibilityOptions.None prevents from serializing byte[] as strings
            // and allows extended objects
            var packer = Packer.Create(output, PackerCompatibilityOptions.None);
            switch (message)
            {
                case InvocationMessage invocationMessage:
                    WriteInvocationMessage(invocationMessage, packer);
                    break;
                case StreamInvocationMessage streamInvocationMessage:
                    WriteStreamInvocationMessage(streamInvocationMessage, packer);
                    break;
                case StreamItemMessage streamItemMessage:
                    WriteStreamingItemMessage(streamItemMessage, packer);
                    break;
                case CompletionMessage completionMessage:
                    WriteCompletionMessage(completionMessage, packer);
                    break;
                case CancelInvocationMessage cancelInvocationMessage:
                    WriteCancelInvocationMessage(cancelInvocationMessage, packer);
                    break;
                case PingMessage pingMessage:
                    WritePingMessage(pingMessage, packer);
                    break;
                default:
                    throw new FormatException($"Unexpected message type: {message.GetType().Name}");
            }
        }

        // Writes [type, headers, invocationId, target, arguments]. A null/empty ID is written as nil.
        private void WriteInvocationMessage(InvocationMessage message, Packer packer)
        {
            packer.PackArrayHeader(5);
            packer.Pack(HubProtocolConstants.InvocationMessageType);
            PackHeaders(packer, message.Headers);
            if (string.IsNullOrEmpty(message.InvocationId))
            {
                packer.PackNull();
            }
            else
            {
                packer.PackString(message.InvocationId);
            }
            packer.PackString(message.Target);
            packer.PackObject(message.Arguments, SerializationContext);
        }

        // Writes [type, headers, invocationId, target, arguments].
        private void WriteStreamInvocationMessage(StreamInvocationMessage message, Packer packer)
        {
            packer.PackArrayHeader(5);
            packer.Pack(HubProtocolConstants.StreamInvocationMessageType);
            PackHeaders(packer, message.Headers);
            packer.PackString(message.InvocationId);
            packer.PackString(message.Target);
            packer.PackObject(message.Arguments, SerializationContext);
        }

        // Writes [type, headers, invocationId, item].
        private void WriteStreamingItemMessage(StreamItemMessage message, Packer packer)
        {
            packer.PackArrayHeader(4);
            packer.Pack(HubProtocolConstants.StreamItemMessageType);
            PackHeaders(packer, message.Headers);
            packer.PackString(message.InvocationId);
            packer.PackObject(message.Item, SerializationContext);
        }

        // Writes [type, headers, invocationId, resultKind, (error | result)?]. An error takes
        // precedence over a result, and void results omit the trailing element.
        private void WriteCompletionMessage(CompletionMessage message, Packer packer)
        {
            var resultKind =
                message.Error != null ? ErrorResult :
                message.HasResult ? NonVoidResult :
                VoidResult;

            packer.PackArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
            packer.Pack(HubProtocolConstants.CompletionMessageType);
            PackHeaders(packer, message.Headers);
            packer.PackString(message.InvocationId);
            packer.Pack(resultKind);
            switch (resultKind)
            {
                case ErrorResult:
                    packer.PackString(message.Error);
                    break;
                case NonVoidResult:
                    packer.PackObject(message.Result, SerializationContext);
                    break;
            }
        }

        // Writes [type, headers, invocationId].
        private void WriteCancelInvocationMessage(CancelInvocationMessage message, Packer packer)
        {
            packer.PackArrayHeader(3);
            packer.Pack(HubProtocolConstants.CancelInvocationMessageType);
            PackHeaders(packer, message.Headers);
            packer.PackString(message.InvocationId);
        }

        // Writes [type]; pings carry no other data.
        private void WritePingMessage(PingMessage pingMessage, Packer packer)
        {
            packer.PackArrayHeader(1);
            packer.Pack(HubProtocolConstants.PingMessageType);
        }

        // Writes headers as a string->string map. A null dictionary is written as an empty map.
        private void PackHeaders(Packer packer, IDictionary<string, string> headers)
        {
            if (headers == null)
            {
                packer.PackMapHeader(0);
                return;
            }

            packer.PackMapHeader(headers.Count);
            foreach (var header in headers)
            {
                packer.PackString(header.Key);
                packer.PackString(header.Value);
            }
        }

        private static string ReadInvocationId(Unpacker unpacker)
        {
            return ReadString(unpacker, "invocationId");
        }

        // The Read* helpers below wrap MsgPack.Cli reads and convert any failure, including
        // running out of data, into a FormatException that names the field being read.

        private static int ReadInt32(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadInt32(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading '{field}' as Int32 failed.", msgPackException);
        }

        // Returns null for a MessagePack nil value.
        private static string ReadString(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.Read())
                {
                    if (unpacker.LastReadData.IsNil)
                    {
                        return null;
                    }
                    else
                    {
                        return unpacker.LastReadData.AsString();
                    }
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading '{field}' as String failed.", msgPackException);
        }

        private static bool ReadBoolean(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadBoolean(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading '{field}' as Boolean failed.", msgPackException);
        }

        private static long ReadMapLength(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadMapLength(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading map length for '{field}' failed.", msgPackException);
        }

        private static long ReadArrayLength(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadArrayLength(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading array length for '{field}' failed.", msgPackException);
        }

        // Reads the next value and deserializes it as `type` using `context`'s serializer for that type.
        private static object DeserializeObject(Unpacker unpacker, Type type, string field, SerializationContext context)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.Read())
                {
                    // Look the serializer up in the configured context rather than the process-wide
                    // default context that MessagePackSerializer.Get(type) would use. Writes go through
                    // PackObject(..., SerializationContext), so this keeps both directions consistent.
                    var serializer = MessagePackSerializer.Get(type, context);
                    return serializer.UnpackFrom(unpacker);
                }
            }
            catch (Exception ex)
            {
                msgPackException = ex;
            }

            throw new FormatException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException);
        }

        internal static SerializationContext CreateDefaultSerializationContext()
        {
            // serializes objects (here: arguments and results) as maps so that property names are preserved
            var serializationContext = new SerializationContext { SerializationMethod = SerializationMethod.Map };

            // allows for serializing objects that cannot be deserialized due to the lack of the default ctor etc.
            serializationContext.CompatibilityOptions.AllowAsymmetricSerializer = true;
            return serializationContext;
        }
    }
}
|