// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
|
|
using System.Collections.Generic;
|
|
using System.IO;
|
|
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
|
|
using MsgPack;
|
|
using MsgPack.Serialization;
|
|
|
|
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
    /// <summary>
    /// An <see cref="IHubProtocol"/> that encodes hub messages with MessagePack.
    /// Every message is a MessagePack array whose first element is the message type;
    /// messages are framed with <see cref="BinaryMessageFormatter"/> when written and
    /// unframed with <see cref="BinaryMessageParser"/> when read.
    /// </summary>
    public class MessagePackHubProtocol : IHubProtocol
    {
        // Message type discriminators: the first element of every encoded message array.
        private const int InvocationMessageType = 1;
        private const int StreamItemMessageType = 2;
        private const int CompletionMessageType = 3;

        // Completion result kinds: decide which payload (if any) follows the kind in a completion message.
        private const int ErrorResult = 1;
        private const int VoidResult = 2;
        private const int NonVoidResult = 3;

        // Used for both writing and reading argument/item/result payloads so that any
        // configuration on the context (e.g. registered serializers) applies in both directions.
        private readonly SerializationContext _serializationContext;

        public string Name => "messagepack";

        public ProtocolType Type => ProtocolType.Binary;

        /// <summary>
        /// Creates a protocol that uses the context returned by <see cref="CreateDefaultSerializationContext"/>.
        /// </summary>
        public MessagePackHubProtocol()
            : this(CreateDefaultSerializationContext())
        { }

        /// <summary>
        /// Creates a protocol that uses <paramref name="serializationContext"/> to serialize and
        /// deserialize message payloads.
        /// </summary>
        /// <param name="serializationContext">The MessagePack serialization context to use.</param>
        /// <exception cref="ArgumentNullException"><paramref name="serializationContext"/> is null.</exception>
        public MessagePackHubProtocol(SerializationContext serializationContext)
        {
            _serializationContext = serializationContext ?? throw new ArgumentNullException(nameof(serializationContext));
        }

        /// <summary>
        /// Parses every complete framed message available in <paramref name="input"/>.
        /// </summary>
        /// <param name="input">Buffer holding zero or more framed messages.</param>
        /// <param name="binder">Resolves parameter and return types used to deserialize payloads.</param>
        /// <param name="messages">Receives the parsed messages (never null).</param>
        /// <returns><c>true</c> if at least one message was parsed; otherwise <c>false</c>.</returns>
        /// <exception cref="FormatException">A framed payload is not a valid hub message.</exception>
        public bool TryParseMessages(ReadOnlyBuffer<byte> input, IInvocationBinder binder, out IList<HubMessage> messages)
        {
            messages = new List<HubMessage>();

            while (BinaryMessageParser.TryParseMessage(ref input, out var payload))
            {
                using (var memoryStream = new MemoryStream(payload.ToArray()))
                {
                    messages.Add(ParseMessage(memoryStream, binder));
                }
            }

            return messages.Count > 0;
        }

        // Decodes one unframed message: [messageType, ...type-specific fields].
        private HubMessage ParseMessage(Stream input, IInvocationBinder binder)
        {
            using (var unpacker = Unpacker.Create(input))
            {
                // Consumes the array header; the element count itself is not validated.
                _ = ReadArrayLength(unpacker, "elementCount");
                var messageType = ReadInt32(unpacker, "messageType");

                switch (messageType)
                {
                    case InvocationMessageType:
                        return CreateInvocationMessage(unpacker, binder);
                    case StreamItemMessageType:
                        return CreateStreamItemMessage(unpacker, binder);
                    case CompletionMessageType:
                        return CreateCompletionMessage(unpacker, binder);
                    default:
                        throw new FormatException($"Invalid message type: {messageType}.");
                }
            }
        }

        // Reads [invocationId, nonBlocking, target, [arguments...]] and deserializes each argument
        // to the parameter type the binder reports for the target.
        private InvocationMessage CreateInvocationMessage(Unpacker unpacker, IInvocationBinder binder)
        {
            var invocationId = ReadInvocationId(unpacker);
            var nonBlocking = ReadBoolean(unpacker, "nonBlocking");
            var target = ReadString(unpacker, "target");
            var argumentCount = ReadArrayLength(unpacker, "arguments");
            var parameterTypes = binder.GetParameterTypes(target);

            if (parameterTypes.Length != argumentCount)
            {
                throw new FormatException(
                    $"Target method expects {parameterTypes.Length} arguments(s) but invocation has {argumentCount} argument(s).");
            }

            var arguments = new object[argumentCount];
            for (var i = 0; i < argumentCount; i++)
            {
                arguments[i] = DeserializeObject(unpacker, parameterTypes[i], "argument");
            }

            return new InvocationMessage(invocationId, nonBlocking, target, arguments);
        }

        // Reads [invocationId, item]; the item type is the binder's return type for the invocation.
        private StreamItemMessage CreateStreamItemMessage(Unpacker unpacker, IInvocationBinder binder)
        {
            var invocationId = ReadInvocationId(unpacker);
            var itemType = binder.GetReturnType(invocationId);
            var value = DeserializeObject(unpacker, itemType, "item");
            return new StreamItemMessage(invocationId, value);
        }

        // Reads [invocationId, resultKind, (error | result)?]: an error string follows ErrorResult,
        // a value follows NonVoidResult, and nothing follows VoidResult.
        private CompletionMessage CreateCompletionMessage(Unpacker unpacker, IInvocationBinder binder)
        {
            var invocationId = ReadInvocationId(unpacker);
            var resultKind = ReadInt32(unpacker, "resultKind");

            string error = null;
            object result = null;
            var hasResult = false;

            switch (resultKind)
            {
                case ErrorResult:
                    error = ReadString(unpacker, "error");
                    break;
                case NonVoidResult:
                    var itemType = binder.GetReturnType(invocationId);
                    // NOTE(review): the field label is 'argument' even though this is a completion
                    // result; kept unchanged in case the error text is matched elsewhere.
                    result = DeserializeObject(unpacker, itemType, "argument");
                    hasResult = true;
                    break;
                case VoidResult:
                    hasResult = false;
                    break;
                default:
                    throw new FormatException("Invalid invocation result kind.");
            }

            return new CompletionMessage(invocationId, error, result, hasResult);
        }

        /// <summary>
        /// Serializes <paramref name="message"/> and writes it to <paramref name="output"/> as one framed message.
        /// </summary>
        /// <exception cref="FormatException"><paramref name="message"/> is not a supported message type.</exception>
        public void WriteMessage(HubMessage message, Stream output)
        {
            using (var memoryStream = new MemoryStream())
            {
                WriteMessageCore(message, memoryStream);

                // GetBuffer exposes the stream's backing array without the copy ToArray makes;
                // only the first Length bytes hold message data.
                var payload = new ReadOnlySpan<byte>(memoryStream.GetBuffer(), 0, (int)memoryStream.Length);
                BinaryMessageFormatter.WriteMessage(payload, output);
            }
        }

        // Writes the unframed MessagePack representation of a message to output.
        private void WriteMessageCore(HubMessage message, Stream output)
        {
            // PackerCompatibilityOptions.None prevents from serializing byte[] as strings
            // and allows extended objects
            var packer = Packer.Create(output, PackerCompatibilityOptions.None);
            switch (message)
            {
                case InvocationMessage invocationMessage:
                    WriteInvocationMessage(invocationMessage, packer);
                    break;
                case StreamItemMessage streamItemMessage:
                    WriteStreamingItemMessage(streamItemMessage, packer);
                    break;
                case CompletionMessage completionMessage:
                    WriteCompletionMessage(completionMessage, packer);
                    break;
                default:
                    throw new FormatException($"Unexpected message type: {message.GetType().Name}");
            }
        }

        // Writes [InvocationMessageType, invocationId, nonBlocking, target, arguments].
        private void WriteInvocationMessage(InvocationMessage invocationMessage, Packer packer)
        {
            packer.PackArrayHeader(5);
            packer.Pack(InvocationMessageType);
            packer.PackString(invocationMessage.InvocationId);
            packer.Pack(invocationMessage.NonBlocking);
            packer.PackString(invocationMessage.Target);
            packer.PackObject(invocationMessage.Arguments, _serializationContext);
        }

        // Writes [StreamItemMessageType, invocationId, item].
        private void WriteStreamingItemMessage(StreamItemMessage streamItemMessage, Packer packer)
        {
            packer.PackArrayHeader(3);
            packer.Pack(StreamItemMessageType);
            packer.PackString(streamItemMessage.InvocationId);
            packer.PackObject(streamItemMessage.Item, _serializationContext);
        }

        // Writes [CompletionMessageType, invocationId, resultKind, (error | result)?].
        // A non-null Error takes precedence over HasResult.
        private void WriteCompletionMessage(CompletionMessage completionMessage, Packer packer)
        {
            var resultKind =
                completionMessage.Error != null ? ErrorResult :
                completionMessage.HasResult ? NonVoidResult :
                VoidResult;

            // A void completion carries no trailing payload element.
            packer.PackArrayHeader(3 + (resultKind != VoidResult ? 1 : 0));
            packer.Pack(CompletionMessageType);
            packer.PackString(completionMessage.InvocationId);
            packer.Pack(resultKind);
            switch (resultKind)
            {
                case ErrorResult:
                    packer.PackString(completionMessage.Error);
                    break;
                case NonVoidResult:
                    packer.PackObject(completionMessage.Result, _serializationContext);
                    break;
            }
        }

        private static string ReadInvocationId(Unpacker unpacker)
        {
            return ReadString(unpacker, "invocationId");
        }

        // Each Read* helper throws FormatException (wrapping any MsgPack exception) when the
        // next value is missing or has the wrong type.
        private static int ReadInt32(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadInt32(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading '{field}' as Int32 failed.", msgPackException);
        }

        private static string ReadString(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadString(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading '{field}' as String failed.", msgPackException);
        }

        private static bool ReadBoolean(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadBoolean(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading '{field}' as Boolean failed.", msgPackException);
        }

        private static long ReadArrayLength(Unpacker unpacker, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.ReadArrayLength(out var value))
                {
                    return value;
                }
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new FormatException($"Reading array length for '{field}' failed.", msgPackException);
        }

        // Deserializes the next value as `type` using the same context the write path uses.
        private object DeserializeObject(Unpacker unpacker, Type type, string field)
        {
            Exception msgPackException = null;
            try
            {
                if (unpacker.Read())
                {
                    // Resolve the serializer from _serializationContext (not the global default)
                    // so custom configuration applies symmetrically to reads and writes.
                    var serializer = MessagePackSerializer.Get(type, _serializationContext);
                    return serializer.UnpackFrom(unpacker);
                }
            }
            catch (Exception ex)
            {
                msgPackException = ex;
            }

            throw new FormatException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException);
        }

        /// <summary>
        /// Creates the serialization context used by the parameterless constructor.
        /// </summary>
        public static SerializationContext CreateDefaultSerializationContext()
        {
            // serializes objects (here: arguments and results) as maps so that property names are preserved
            var serializationContext = new SerializationContext { SerializationMethod = SerializationMethod.Map };
            // allows for serializing objects that cannot be deserialized due to the lack of the default ctor etc.
            serializationContext.CompatibilityOptions.AllowAsymmetricSerializer = true;
            return serializationContext;
        }
    }
}
|