// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using MessagePack;
using MessagePack.Formatters;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.SignalR.Protocol
{
    /// <summary>
    /// Implements the SignalR Hub Protocol using MessagePack.
    /// </summary>
    /// <remarks>
    /// Each message on the wire is a length prefix (written by <c>BinaryMessageFormatter</c>,
    /// read by <c>BinaryMessageParser</c>) followed by a MessagePack array whose first element
    /// is the message type; the remaining elements depend on that type.
    /// </remarks>
    public class MessagePackHubProtocol : IHubProtocol
    {
        // Values of the "resultKind" field of a completion message.
        private const int ErrorResult = 1;
        private const int VoidResult = 2;
        private const int NonVoidResult = 3;

        // Resolver used for (de)serializing invocation arguments and results.
        // Assigned once by SetupResolver from the constructor.
        private IFormatterResolver _resolver;

        private static readonly string ProtocolName = "messagepack";
        private static readonly int ProtocolVersion = 1;
        private static readonly int ProtocolMinorVersion = 0;

        /// <inheritdoc />
        public string Name => ProtocolName;

        /// <inheritdoc />
        public int Version => ProtocolVersion;

        /// <inheritdoc />
        public int MinorVersion => ProtocolMinorVersion;

        /// <inheritdoc />
        public TransferFormat TransferFormat => TransferFormat.Binary;

        /// <summary>
        /// Initializes a new instance of the <see cref="MessagePackHubProtocol"/> class
        /// using default <see cref="MessagePackHubProtocolOptions"/>.
        /// </summary>
        public MessagePackHubProtocol()
            : this(Options.Create(new MessagePackHubProtocolOptions()))
        { }

        /// <summary>
        /// Initializes a new instance of the <see cref="MessagePackHubProtocol"/> class.
        /// </summary>
        /// <param name="options">The options used to initialize the protocol.</param>
        public MessagePackHubProtocol(IOptions<MessagePackHubProtocolOptions> options)
        {
            var msgPackOptions = options.Value;
            SetupResolver(msgPackOptions);
        }

        /// <summary>
        /// Chooses the formatter resolver. If the configured resolver list is element-for-element
        /// identical (by reference) to <see cref="SignalRResolver.Resolvers"/>, the cached
        /// <see cref="SignalRResolver.Instance"/> is used; otherwise the user's list is wrapped
        /// in a <see cref="CombinedResolvers"/>.
        /// </summary>
        private void SetupResolver(MessagePackHubProtocolOptions options)
        {
            // If counts don't match, the user customized the resolvers.
            if (options.FormatterResolvers.Count != SignalRResolver.Resolvers.Count)
            {
                _resolver = new CombinedResolvers(options.FormatterResolvers);
                return;
            }

            for (var i = 0; i < options.FormatterResolvers.Count; i++)
            {
                // Check whether the user replaced or reordered any resolver.
                if (options.FormatterResolvers[i] != SignalRResolver.Resolvers[i])
                {
                    _resolver = new CombinedResolvers(options.FormatterResolvers);
                    return;
                }
            }

            // Use the optimized, per-type cached resolver when the default list is in use.
            _resolver = SignalRResolver.Instance;
        }

        /// <inheritdoc />
        public bool IsVersionSupported(int version)
        {
            return version == Version;
        }

        /// <inheritdoc />
        public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
        {
            if (!BinaryMessageParser.TryParseMessage(ref input, out var payload))
            {
                message = null;
                return false;
            }

            var arraySegment = GetArraySegment(payload);

            // NOTE(review): only Array and Offset are passed on. The readers below treat the end of
            // the underlying array as the end of input, not Offset + Count. When the payload is a
            // slice of a larger buffer, a malformed message may therefore read into bytes past the
            // payload instead of failing. Consider bounding reads to the segment.
            message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder, _resolver);
            return true;
        }

        /// <summary>
        /// Returns the payload as an <see cref="ArraySegment{T}"/>. The array backing a single
        /// segment is reused when possible; otherwise the payload is copied into a new array.
        /// </summary>
        private static ArraySegment<byte> GetArraySegment(in ReadOnlySequence<byte> input)
        {
            if (input.IsSingleSegment && MemoryMarshal.TryGetArray(input.First, out var arraySegment))
            {
                // Fast path: parse directly from the managed array, no copy.
                return arraySegment;
            }

            // Multi-segment payloads (rare) and memory not backed by a managed array are copied.
            // Without this fallback a non-array single segment would yield a default segment
            // whose Array is null, failing later with a NullReferenceException.
            return new ArraySegment<byte>(input.ToArray());
        }

        /// <summary>
        /// Parses one framed message. Returns <c>null</c> for unknown message types so that
        /// future protocol additions can be ignored by older peers.
        /// </summary>
        private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver)
        {
            // The outer array length is read only to advance past the header; it is not validated.
            _ = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize);
            startOffset += readSize;

            var messageType = ReadInt32(input, ref startOffset, "messageType");

            switch (messageType)
            {
                case HubProtocolConstants.InvocationMessageType:
                    return CreateInvocationMessage(input, ref startOffset, binder, resolver);
                case HubProtocolConstants.StreamInvocationMessageType:
                    return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver);
                case HubProtocolConstants.StreamItemMessageType:
                    return CreateStreamItemMessage(input, ref startOffset, binder, resolver);
                case HubProtocolConstants.CompletionMessageType:
                    return CreateCompletionMessage(input, ref startOffset, binder, resolver);
                case HubProtocolConstants.CancelInvocationMessageType:
                    return CreateCancelInvocationMessage(input, ref startOffset);
                case HubProtocolConstants.PingMessageType:
                    return PingMessage.Instance;
                case HubProtocolConstants.CloseMessageType:
                    return CreateCloseMessage(input, ref startOffset);
                default:
                    // Future protocol changes can add message types, old clients can ignore them
                    return null;
            }
        }

        /// <summary>
        /// Reads an invocation message: headers, invocation id, target, arguments.
        /// Argument binding failures produce an <see cref="InvocationBindingFailureMessage"/>
        /// instead of throwing.
        /// </summary>
        private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
        {
            var headers = ReadHeaders(input, ref offset);
            var invocationId = ReadInvocationId(input, ref offset);

            // For MsgPack, we represent an empty invocation ID as an empty string,
            // so we need to normalize that to "null", which is what indicates a non-blocking invocation.
            if (string.IsNullOrEmpty(invocationId))
            {
                invocationId = null;
            }

            var target = ReadString(input, ref offset, "target");
            var parameterTypes = binder.GetParameterTypes(target);

            try
            {
                var arguments = BindArguments(input, ref offset, parameterTypes, resolver);
                return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments));
            }
            catch (Exception ex)
            {
                return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
            }
        }

        /// <summary>
        /// Reads a stream invocation message. Unlike plain invocations, the invocation id is not
        /// normalized from empty to <c>null</c>.
        /// </summary>
        private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
        {
            var headers = ReadHeaders(input, ref offset);
            var invocationId = ReadInvocationId(input, ref offset);
            var target = ReadString(input, ref offset, "target");
            var parameterTypes = binder.GetParameterTypes(target);

            try
            {
                var arguments = BindArguments(input, ref offset, parameterTypes, resolver);
                return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments));
            }
            catch (Exception ex)
            {
                return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
            }
        }

        /// <summary>
        /// Reads a stream item; the item type comes from the binder for the invocation id.
        /// </summary>
        private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
        {
            var headers = ReadHeaders(input, ref offset);
            var invocationId = ReadInvocationId(input, ref offset);
            var itemType = binder.GetReturnType(invocationId);

            var value = DeserializeObject(input, ref offset, itemType, "item", resolver);
            return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
        }

        /// <summary>
        /// Reads a completion message. The "resultKind" field selects between an error string,
        /// a void result, or a non-void result deserialized as the binder's return type.
        /// </summary>
        /// <exception cref="InvalidDataException">The result kind is not 1, 2 or 3.</exception>
        private static CompletionMessage CreateCompletionMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
        {
            var headers = ReadHeaders(input, ref offset);
            var invocationId = ReadInvocationId(input, ref offset);
            var resultKind = ReadInt32(input, ref offset, "resultKind");

            string error = null;
            object result = null;
            var hasResult = false;

            switch (resultKind)
            {
                case ErrorResult:
                    error = ReadString(input, ref offset, "error");
                    break;
                case NonVoidResult:
                    var itemType = binder.GetReturnType(invocationId);
                    result = DeserializeObject(input, ref offset, itemType, "argument", resolver);
                    hasResult = true;
                    break;
                case VoidResult:
                    hasResult = false;
                    break;
                default:
                    throw new InvalidDataException("Invalid invocation result kind.");
            }

            return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
        }

        /// <summary>Reads a cancel-invocation message: headers and invocation id.</summary>
        private static CancelInvocationMessage CreateCancelInvocationMessage(byte[] input, ref int offset)
        {
            var headers = ReadHeaders(input, ref offset);
            var invocationId = ReadInvocationId(input, ref offset);
            return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
        }

        /// <summary>Reads a close message, which carries only an (optional) error string.</summary>
        private static CloseMessage CreateCloseMessage(byte[] input, ref int offset)
        {
            var error = ReadString(input, ref offset, "error");
            return new CloseMessage(error);
        }

        /// <summary>
        /// Reads the headers map. Returns <c>null</c> when the map is empty.
        /// </summary>
        private static Dictionary<string, string> ReadHeaders(byte[] input, ref int offset)
        {
            var headerCount = ReadMapLength(input, ref offset, "headers");
            if (headerCount <= 0)
            {
                return null;
            }

            // headerCount comes straight off the wire. Every entry needs at least two bytes
            // (one for the key, one for the value), so cap the pre-allocated capacity by what the
            // remaining buffer could possibly hold rather than trusting the declared count. A short
            // message declaring a huge map then fails while reading entries instead of forcing a
            // huge allocation up front.
            var maxPossibleEntries = (input.Length - offset) / 2;
            var capacity = (int)Math.Min(headerCount, maxPossibleEntries);
            var headers = new Dictionary<string, string>(capacity, StringComparer.Ordinal);

            for (var i = 0; i < headerCount; i++)
            {
                var key = ReadString(input, ref offset, $"headers[{i}].Key");
                var value = ReadString(input, ref offset, $"headers[{i}].Value");
                headers[key] = value;
            }

            return headers;
        }

        /// <summary>
        /// Reads the arguments array and deserializes each element as the matching parameter type.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// The argument count differs from the parameter count, or an argument fails to deserialize.
        /// </exception>
        private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyList<Type> parameterTypes, IFormatterResolver resolver)
        {
            var argumentCount = ReadArrayLength(input, ref offset, "arguments");

            if (parameterTypes.Count != argumentCount)
            {
                throw new InvalidDataException(
                    $"Invocation provides {argumentCount} argument(s) but target expects {parameterTypes.Count}.");
            }

            try
            {
                // Safe to allocate: argumentCount was just checked against the binder's parameter count.
                var arguments = new object[argumentCount];
                for (var i = 0; i < argumentCount; i++)
                {
                    arguments[i] = DeserializeObject(input, ref offset, parameterTypes[i], "argument", resolver);
                }

                return arguments;
            }
            catch (Exception ex)
            {
                throw new InvalidDataException("Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", ex);
            }
        }

        /// <summary>
        /// Attaches <paramref name="source"/> to <paramref name="destination"/> when it is non-empty
        /// and returns <paramref name="destination"/>.
        /// </summary>
        private static T ApplyHeaders<T>(IDictionary<string, string> source, T destination) where T : HubInvocationMessage
        {
            if (source != null && source.Count > 0)
            {
                destination.Headers = source;
            }

            return destination;
        }

        /// <inheritdoc />
        public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
        {
            var writer = MemoryBufferWriter.Get();

            try
            {
                // Write message to a buffer so we can get its length
                WriteMessageCore(message, writer);

                // Write length then message to output
                BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output);
                writer.CopyTo(output);
            }
            finally
            {
                MemoryBufferWriter.Return(writer);
            }
        }

        /// <inheritdoc />
        public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
        {
            var writer = MemoryBufferWriter.Get();

            try
            {
                // Write message to a buffer so we can get its length
                WriteMessageCore(message, writer);

                var dataLength = writer.Length;
                var prefixLength = BinaryMessageFormatter.LengthPrefixLength(dataLength);

                var array = new byte[dataLength + prefixLength];
                var span = array.AsSpan();

                // Write length then message to output
                var written = BinaryMessageFormatter.WriteLengthPrefix(dataLength, span);
                Debug.Assert(written == prefixLength);
                writer.CopyTo(span.Slice(prefixLength));

                return array;
            }
            finally
            {
                MemoryBufferWriter.Return(writer);
            }
        }

        /// <summary>
        /// Serializes <paramref name="message"/> (without length prefix) to <paramref name="packer"/>.
        /// </summary>
        /// <exception cref="InvalidDataException">The message type is not supported.</exception>
        private void WriteMessageCore(HubMessage message, Stream packer)
        {
            switch (message)
            {
                case InvocationMessage invocationMessage:
                    WriteInvocationMessage(invocationMessage, packer);
                    break;
                case StreamInvocationMessage streamInvocationMessage:
                    WriteStreamInvocationMessage(streamInvocationMessage, packer);
                    break;
                case StreamItemMessage streamItemMessage:
                    WriteStreamingItemMessage(streamItemMessage, packer);
                    break;
                case CompletionMessage completionMessage:
                    WriteCompletionMessage(completionMessage, packer);
                    break;
                case CancelInvocationMessage cancelInvocationMessage:
                    WriteCancelInvocationMessage(cancelInvocationMessage, packer);
                    break;
                case PingMessage pingMessage:
                    WritePingMessage(pingMessage, packer);
                    break;
                case CloseMessage closeMessage:
                    WriteCloseMessage(closeMessage, packer);
                    break;
                default:
                    throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
            }
        }

        // Layout: [type, headers, invocationId | nil, target, arguments[]]
        private void WriteInvocationMessage(InvocationMessage message, Stream packer)
        {
            MessagePackBinary.WriteArrayHeader(packer, 5);
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.InvocationMessageType);
            PackHeaders(packer, message.Headers);

            // A null/empty id (non-blocking invocation) is written as nil.
            if (string.IsNullOrEmpty(message.InvocationId))
            {
                MessagePackBinary.WriteNil(packer);
            }
            else
            {
                MessagePackBinary.WriteString(packer, message.InvocationId);
            }

            MessagePackBinary.WriteString(packer, message.Target);
            MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length);
            foreach (var arg in message.Arguments)
            {
                WriteArgument(arg, packer);
            }
        }

        // Layout: [type, headers, invocationId, target, arguments[]]
        private void WriteStreamInvocationMessage(StreamInvocationMessage message, Stream packer)
        {
            MessagePackBinary.WriteArrayHeader(packer, 5);
            // WriteInt32 for consistency with the other writers; small values encode identically.
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.StreamInvocationMessageType);
            PackHeaders(packer, message.Headers);
            MessagePackBinary.WriteString(packer, message.InvocationId);
            MessagePackBinary.WriteString(packer, message.Target);

            MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length);
            foreach (var arg in message.Arguments)
            {
                WriteArgument(arg, packer);
            }
        }

        // Layout: [type, headers, invocationId, item]
        private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer)
        {
            MessagePackBinary.WriteArrayHeader(packer, 4);
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.StreamItemMessageType);
            PackHeaders(packer, message.Headers);
            MessagePackBinary.WriteString(packer, message.InvocationId);
            WriteArgument(message.Item, packer);
        }

        /// <summary>
        /// Writes <paramref name="argument"/> using its runtime type and the configured resolver,
        /// or nil when it is <c>null</c>.
        /// </summary>
        private void WriteArgument(object argument, Stream stream)
        {
            if (argument == null)
            {
                MessagePackBinary.WriteNil(stream);
            }
            else
            {
                MessagePackSerializer.NonGeneric.Serialize(argument.GetType(), stream, argument, _resolver);
            }
        }

        // Layout: [type, headers, invocationId, resultKind, (error | result)?]
        // The trailing element is present only for error and non-void results.
        private void WriteCompletionMessage(CompletionMessage message, Stream packer)
        {
            var resultKind =
                message.Error != null ? ErrorResult :
                message.HasResult ? NonVoidResult :
                VoidResult;

            MessagePackBinary.WriteArrayHeader(packer, 4 + (resultKind != VoidResult ? 1 : 0));
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CompletionMessageType);
            PackHeaders(packer, message.Headers);
            MessagePackBinary.WriteString(packer, message.InvocationId);
            MessagePackBinary.WriteInt32(packer, resultKind);
            switch (resultKind)
            {
                case ErrorResult:
                    MessagePackBinary.WriteString(packer, message.Error);
                    break;
                case NonVoidResult:
                    WriteArgument(message.Result, packer);
                    break;
            }
        }

        // Layout: [type, headers, invocationId]
        private void WriteCancelInvocationMessage(CancelInvocationMessage message, Stream packer)
        {
            MessagePackBinary.WriteArrayHeader(packer, 3);
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CancelInvocationMessageType);
            PackHeaders(packer, message.Headers);
            MessagePackBinary.WriteString(packer, message.InvocationId);
        }

        // Layout: [type, error | nil]
        private void WriteCloseMessage(CloseMessage message, Stream packer)
        {
            MessagePackBinary.WriteArrayHeader(packer, 2);
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CloseMessageType);
            if (string.IsNullOrEmpty(message.Error))
            {
                MessagePackBinary.WriteNil(packer);
            }
            else
            {
                MessagePackBinary.WriteString(packer, message.Error);
            }
        }

        // Layout: [type]
        private void WritePingMessage(PingMessage pingMessage, Stream packer)
        {
            MessagePackBinary.WriteArrayHeader(packer, 1);
            MessagePackBinary.WriteInt32(packer, HubProtocolConstants.PingMessageType);
        }

        /// <summary>
        /// Writes <paramref name="headers"/> as a string-to-string map; <c>null</c> is written as an empty map.
        /// </summary>
        private void PackHeaders(Stream packer, IDictionary<string, string> headers)
        {
            if (headers != null)
            {
                MessagePackBinary.WriteMapHeader(packer, headers.Count);
                if (headers.Count > 0)
                {
                    foreach (var header in headers)
                    {
                        MessagePackBinary.WriteString(packer, header.Key);
                        MessagePackBinary.WriteString(packer, header.Value);
                    }
                }
            }
            else
            {
                MessagePackBinary.WriteMapHeader(packer, 0);
            }
        }

        // The Read* helpers below advance `offset` past the value on success and wrap any
        // MessagePack failure in an InvalidDataException naming the field being read.

        private static string ReadInvocationId(byte[] input, ref int offset)
        {
            return ReadString(input, ref offset, "invocationId");
        }

        private static int ReadInt32(byte[] input, ref int offset, string field)
        {
            Exception msgPackException = null;
            try
            {
                var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize);
                offset += readSize;
                return readInt;
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException);
        }

        private static string ReadString(byte[] input, ref int offset, string field)
        {
            Exception msgPackException = null;
            try
            {
                var readString = MessagePackBinary.ReadString(input, offset, out var readSize);
                offset += readSize;
                return readString;
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException);
        }

        // NOTE(review): not called anywhere in this class.
        private static bool ReadBoolean(byte[] input, ref int offset, string field)
        {
            Exception msgPackException = null;
            try
            {
                var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize);
                offset += readSize;
                return readBool;
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException);
        }

        private static long ReadMapLength(byte[] input, ref int offset, string field)
        {
            Exception msgPackException = null;
            try
            {
                var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize);
                offset += readSize;
                return readMap;
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException);
        }

        private static long ReadArrayLength(byte[] input, ref int offset, string field)
        {
            Exception msgPackException = null;
            try
            {
                var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize);
                offset += readSize;
                return readArray;
            }
            catch (Exception e)
            {
                msgPackException = e;
            }

            throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException);
        }

        /// <summary>
        /// Deserializes the next value as <paramref name="type"/>, then advances
        /// <paramref name="offset"/> past that MessagePack block.
        /// </summary>
        private static object DeserializeObject(byte[] input, ref int offset, Type type, string field, IFormatterResolver resolver)
        {
            Exception msgPackException = null;
            try
            {
                var obj = MessagePackSerializer.NonGeneric.Deserialize(type, new ArraySegment<byte>(input, offset, input.Length - offset), resolver);
                offset += MessagePackBinary.ReadNextBlock(input, offset);
                return obj;
            }
            catch (Exception ex)
            {
                msgPackException = ex;
            }

            throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException);
        }

        /// <summary>
        /// Returns a mutable copy of the default resolver list.
        /// </summary>
        internal static List<IFormatterResolver> CreateDefaultFormatterResolvers()
        {
            // Copy to allow users to add/remove resolvers without changing the static SignalRResolver list
            return new List<IFormatterResolver>(SignalRResolver.Resolvers);
        }

        /// <summary>
        /// Default resolver: tries <see cref="Resolvers"/> in order and caches the first non-null
        /// formatter per type in a generic static class.
        /// </summary>
        internal class SignalRResolver : IFormatterResolver
        {
            public static readonly IFormatterResolver Instance = new SignalRResolver();

            public static readonly IList<IFormatterResolver> Resolvers = new IFormatterResolver[]
            {
                MessagePack.Resolvers.DynamicEnumAsStringResolver.Instance,
                MessagePack.Resolvers.ContractlessStandardResolver.Instance,
            };

            public IMessagePackFormatter<T> GetFormatter<T>()
            {
                return Cache<T>.Formatter;
            }

            private static class Cache<T>
            {
                public static readonly IMessagePackFormatter<T> Formatter;

                static Cache()
                {
                    foreach (var resolver in Resolvers)
                    {
                        Formatter = resolver.GetFormatter<T>();
                        if (Formatter != null)
                        {
                            return;
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Resolver for user-customized resolver lists: returns the first non-null formatter from
        /// the list on every call (no caching), or <c>null</c> if none matches.
        /// </summary>
        internal class CombinedResolvers : IFormatterResolver
        {
            private readonly IList<IFormatterResolver> _resolvers;

            public CombinedResolvers(IList<IFormatterResolver> resolvers)
            {
                _resolvers = resolvers;
            }

            public IMessagePackFormatter<T> GetFormatter<T>()
            {
                foreach (var resolver in _resolvers)
                {
                    var formatter = resolver.GetFormatter<T>();
                    if (formatter != null)
                    {
                        return formatter;
                    }
                }

                return null;
            }
        }
    }
}