Update MessagePack to v2.0 for SignalR (#18133)

This commit is contained in:
TeBeCo 2020-01-24 22:43:26 +01:00 committed by Brennan
parent b375410e64
commit 0ace9d91b6
6 changed files with 292 additions and 418 deletions

View File

@ -199,7 +199,7 @@
<SystemCommandlineExperimentalPackageVersion>0.3.0-alpha.19317.1</SystemCommandlineExperimentalPackageVersion>
<SystemComponentModelPackageVersion>4.3.0</SystemComponentModelPackageVersion>
<SystemNetHttpPackageVersion>4.3.2</SystemNetHttpPackageVersion>
<SystemThreadingTasksExtensionsPackageVersion>4.5.2</SystemThreadingTasksExtensionsPackageVersion>
<SystemThreadingTasksExtensionsPackageVersion>4.5.3</SystemThreadingTasksExtensionsPackageVersion>
<!-- Packages developed by @aspnet, but manually updated as necessary. -->
<LibuvPackageVersion>1.10.0</LibuvPackageVersion>
<MicrosoftAspNetWebApiClientPackageVersion>5.2.6</MicrosoftAspNetWebApiClientPackageVersion>
@ -242,7 +242,7 @@
<IdentityServer4PackageVersion>3.0.0</IdentityServer4PackageVersion>
<IdentityServer4StoragePackageVersion>3.0.0</IdentityServer4StoragePackageVersion>
<IdentityServer4EntityFrameworkStoragePackageVersion>3.0.0</IdentityServer4EntityFrameworkStoragePackageVersion>
<MessagePackPackageVersion>1.7.3.7</MessagePackPackageVersion>
<MessagePackPackageVersion>2.0.335</MessagePackPackageVersion>
<MoqPackageVersion>4.10.0</MoqPackageVersion>
<MonoCecilPackageVersion>0.10.1</MonoCecilPackageVersion>
<NewtonsoftJsonBsonPackageVersion>1.0.2</NewtonsoftJsonBsonPackageVersion>

View File

@ -6,10 +6,11 @@ using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using MessagePack;
using MessagePack.Formatters;
using MessagePack.Resolvers;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.Options;
@ -25,8 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
private const int VoidResult = 2;
private const int NonVoidResult = 3;
private IFormatterResolver _resolver;
private MessagePackSerializerOptions _msgPackSerializerOptions;
private static readonly string ProtocolName = "messagepack";
private static readonly int ProtocolVersion = 1;
@ -62,7 +62,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
// with the provided resolvers
if (options.FormatterResolvers.Count != SignalRResolver.Resolvers.Count)
{
_resolver = new CombinedResolvers(options.FormatterResolvers);
var resolver = CompositeResolver.Create(Array.Empty<IMessagePackFormatter>(), (IReadOnlyList<IFormatterResolver>)options.FormatterResolvers);
_msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver);
return;
}
@ -71,13 +73,14 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
// check if the user customized the resolvers
if (options.FormatterResolvers[i] != SignalRResolver.Resolvers[i])
{
_resolver = new CombinedResolvers(options.FormatterResolvers);
var resolver = CompositeResolver.Create(Array.Empty<IMessagePackFormatter>(), (IReadOnlyList<IFormatterResolver>)options.FormatterResolvers);
_msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver);
return;
}
}
// Use optimized cached resolver if the default is chosen
_resolver = SignalRResolver.Instance;
_msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(SignalRResolver.Instance);
}
/// <inheritdoc />
@ -95,59 +98,43 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return false;
}
var arraySegment = GetArraySegment(payload);
message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder, _resolver);
var reader = new MessagePackReader(payload);
message = ParseMessage(ref reader, binder, _msgPackSerializerOptions);
return true;
}
private static ArraySegment<byte> GetArraySegment(in ReadOnlySequence<byte> input)
private static HubMessage ParseMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
{
if (input.IsSingleSegment)
{
var isArray = MemoryMarshal.TryGetArray(input.First, out var arraySegment);
// This will never be false unless we started using un-managed buffers
Debug.Assert(isArray);
return arraySegment;
}
var itemCount = reader.ReadArrayHeader();
// Should be rare
return new ArraySegment<byte>(input.ToArray());
}
private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver)
{
var itemCount = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize);
startOffset += readSize;
var messageType = ReadInt32(input, ref startOffset, "messageType");
var messageType = ReadInt32(ref reader, "messageType");
switch (messageType)
{
case HubProtocolConstants.InvocationMessageType:
return CreateInvocationMessage(input, ref startOffset, binder, resolver, itemCount);
return CreateInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount);
case HubProtocolConstants.StreamInvocationMessageType:
return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver, itemCount);
return CreateStreamInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount);
case HubProtocolConstants.StreamItemMessageType:
return CreateStreamItemMessage(input, ref startOffset, binder, resolver);
return CreateStreamItemMessage(ref reader, binder, msgPackSerializerOptions);
case HubProtocolConstants.CompletionMessageType:
return CreateCompletionMessage(input, ref startOffset, binder, resolver);
return CreateCompletionMessage(ref reader, binder, msgPackSerializerOptions);
case HubProtocolConstants.CancelInvocationMessageType:
return CreateCancelInvocationMessage(input, ref startOffset);
return CreateCancelInvocationMessage(ref reader);
case HubProtocolConstants.PingMessageType:
return PingMessage.Instance;
case HubProtocolConstants.CloseMessageType:
return CreateCloseMessage(input, ref startOffset, itemCount);
return CreateCloseMessage(ref reader, itemCount);
default:
// Future protocol changes can add message types, old clients can ignore them
return null;
}
}
private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount)
private static HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount)
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var headers = ReadHeaders(ref reader);
var invocationId = ReadInvocationId(ref reader);
// For MsgPack, we represent an empty invocation ID as an empty string,
// so we need to normalize that to "null", which is what indicates a non-blocking invocation.
@ -156,13 +143,13 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
invocationId = null;
}
var target = ReadString(input, ref offset, "target");
var target = ReadString(ref reader, "target");
object[] arguments = null;
try
{
var parameterTypes = binder.GetParameterTypes(target);
arguments = BindArguments(input, ref offset, parameterTypes, resolver);
arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions);
}
catch (Exception ex)
{
@ -173,23 +160,23 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
// Previous clients will send 5 items, so we check if they sent a stream array or not
if (itemCount > 5)
{
streams = ReadStreamIds(input, ref offset);
streams = ReadStreamIds(ref reader);
}
return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams));
}
private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount)
private static HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount)
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var target = ReadString(input, ref offset, "target");
var headers = ReadHeaders(ref reader);
var invocationId = ReadInvocationId(ref reader);
var target = ReadString(ref reader, "target");
object[] arguments = null;
try
{
var parameterTypes = binder.GetParameterTypes(target);
arguments = BindArguments(input, ref offset, parameterTypes, resolver);
arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions);
}
catch (Exception ex)
{
@ -200,21 +187,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
// Previous clients will send 5 items, so we check if they sent a stream array or not
if (itemCount > 5)
{
streams = ReadStreamIds(input, ref offset);
streams = ReadStreamIds(ref reader);
}
return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
}
private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
private static HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var headers = ReadHeaders(ref reader);
var invocationId = ReadInvocationId(ref reader);
object value;
try
{
var itemType = binder.GetStreamItemType(invocationId);
value = DeserializeObject(input, ref offset, itemType, "item", resolver);
value = DeserializeObject(ref reader, itemType, "item", msgPackSerializerOptions);
}
catch (Exception ex)
{
@ -224,11 +211,11 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
}
private static CompletionMessage CreateCompletionMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
private static CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions)
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var resultKind = ReadInt32(input, ref offset, "resultKind");
var headers = ReadHeaders(ref reader);
var invocationId = ReadInvocationId(ref reader);
var resultKind = ReadInt32(ref reader, "resultKind");
string error = null;
object result = null;
@ -237,11 +224,11 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
switch (resultKind)
{
case ErrorResult:
error = ReadString(input, ref offset, "error");
error = ReadString(ref reader, "error");
break;
case NonVoidResult:
var itemType = binder.GetReturnType(invocationId);
result = DeserializeObject(input, ref offset, itemType, "argument", resolver);
result = DeserializeObject(ref reader, itemType, "argument", msgPackSerializerOptions);
hasResult = true;
break;
case VoidResult:
@ -254,21 +241,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult));
}
private static CancelInvocationMessage CreateCancelInvocationMessage(byte[] input, ref int offset)
private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader)
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var headers = ReadHeaders(ref reader);
var invocationId = ReadInvocationId(ref reader);
return ApplyHeaders(headers, new CancelInvocationMessage(invocationId));
}
private static CloseMessage CreateCloseMessage(byte[] input, ref int offset, int itemCount)
private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount)
{
var error = ReadString(input, ref offset, "error");
var error = ReadString(ref reader, "error");
var allowReconnect = false;
if (itemCount > 2)
{
allowReconnect = ReadBoolean(input, ref offset, "allowReconnect");
allowReconnect = ReadBoolean(ref reader, "allowReconnect");
}
// An empty string is still an error
@ -280,17 +267,17 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return new CloseMessage(error, allowReconnect);
}
private static Dictionary<string, string> ReadHeaders(byte[] input, ref int offset)
private static Dictionary<string, string> ReadHeaders(ref MessagePackReader reader)
{
var headerCount = ReadMapLength(input, ref offset, "headers");
var headerCount = ReadMapLength(ref reader, "headers");
if (headerCount > 0)
{
var headers = new Dictionary<string, string>(StringComparer.Ordinal);
for (var i = 0; i < headerCount; i++)
{
var key = ReadString(input, ref offset, $"headers[{i}].Key");
var value = ReadString(input, ref offset, $"headers[{i}].Value");
var key = ReadString(ref reader, $"headers[{i}].Key");
var value = ReadString(ref reader, $"headers[{i}].Value");
headers.Add(key, value);
}
return headers;
@ -301,9 +288,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
}
}
private static string[] ReadStreamIds(byte[] input, ref int offset)
private static string[] ReadStreamIds(ref MessagePackReader reader)
{
var streamIdCount = ReadArrayLength(input, ref offset, "streamIds");
var streamIdCount = ReadArrayLength(ref reader, "streamIds");
List<string> streams = null;
if (streamIdCount > 0)
@ -311,17 +298,16 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
streams = new List<string>();
for (var i = 0; i < streamIdCount; i++)
{
streams.Add(MessagePackBinary.ReadString(input, offset, out var read));
offset += read;
streams.Add(reader.ReadString());
}
}
return streams?.ToArray();
}
private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyList<Type> parameterTypes, IFormatterResolver resolver)
private static object[] BindArguments(ref MessagePackReader reader, IReadOnlyList<Type> parameterTypes, MessagePackSerializerOptions msgPackSerializerOptions)
{
var argumentCount = ReadArrayLength(input, ref offset, "arguments");
var argumentCount = ReadArrayLength(ref reader, "arguments");
if (parameterTypes.Count != argumentCount)
{
@ -334,7 +320,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
var arguments = new object[argumentCount];
for (var i = 0; i < argumentCount; i++)
{
arguments[i] = DeserializeObject(input, ref offset, parameterTypes[i], "argument", resolver);
arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument", msgPackSerializerOptions);
}
return arguments;
@ -358,339 +344,314 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
/// <inheritdoc />
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
var writer = MemoryBufferWriter.Get();
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
var writer = new MessagePackWriter(memoryBufferWriter);
// Write message to a buffer so we can get its length
WriteMessageCore(message, writer);
WriteMessageCore(message, ref writer);
// Write length then message to output
BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output);
writer.CopyTo(output);
BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output);
memoryBufferWriter.CopyTo(output);
}
finally
{
MemoryBufferWriter.Return(writer);
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
/// <inheritdoc />
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
{
var writer = MemoryBufferWriter.Get();
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
// Write message to a buffer so we can get its length
WriteMessageCore(message, writer);
var writer = new MessagePackWriter(memoryBufferWriter);
var dataLength = writer.Length;
var prefixLength = BinaryMessageFormatter.LengthPrefixLength(writer.Length);
// Write message to a buffer so we can get its length
WriteMessageCore(message, ref writer);
var dataLength = memoryBufferWriter.Length;
var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length);
var array = new byte[dataLength + prefixLength];
var span = array.AsSpan();
// Write length then message to output
var written = BinaryMessageFormatter.WriteLengthPrefix(writer.Length, span);
var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span);
Debug.Assert(written == prefixLength);
writer.CopyTo(span.Slice(prefixLength));
memoryBufferWriter.CopyTo(span.Slice(prefixLength));
return array;
}
finally
{
MemoryBufferWriter.Return(writer);
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
private void WriteMessageCore(HubMessage message, Stream packer)
private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer)
{
switch (message)
{
case InvocationMessage invocationMessage:
WriteInvocationMessage(invocationMessage, packer);
WriteInvocationMessage(invocationMessage, ref writer);
break;
case StreamInvocationMessage streamInvocationMessage:
WriteStreamInvocationMessage(streamInvocationMessage, packer);
WriteStreamInvocationMessage(streamInvocationMessage, ref writer);
break;
case StreamItemMessage streamItemMessage:
WriteStreamingItemMessage(streamItemMessage, packer);
WriteStreamingItemMessage(streamItemMessage, ref writer);
break;
case CompletionMessage completionMessage:
WriteCompletionMessage(completionMessage, packer);
WriteCompletionMessage(completionMessage, ref writer);
break;
case CancelInvocationMessage cancelInvocationMessage:
WriteCancelInvocationMessage(cancelInvocationMessage, packer);
WriteCancelInvocationMessage(cancelInvocationMessage, ref writer);
break;
case PingMessage pingMessage:
WritePingMessage(pingMessage, packer);
WritePingMessage(pingMessage, ref writer);
break;
case CloseMessage closeMessage:
WriteCloseMessage(closeMessage, packer);
WriteCloseMessage(closeMessage, ref writer);
break;
default:
throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
}
writer.Flush();
}
private void WriteInvocationMessage(InvocationMessage message, Stream packer)
private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer)
{
MessagePackBinary.WriteArrayHeader(packer, 6);
writer.WriteArrayHeader(6);
MessagePackBinary.WriteInt32(packer, HubProtocolConstants.InvocationMessageType);
PackHeaders(packer, message.Headers);
writer.Write(HubProtocolConstants.InvocationMessageType);
PackHeaders(message.Headers, ref writer);
if (string.IsNullOrEmpty(message.InvocationId))
{
MessagePackBinary.WriteNil(packer);
writer.WriteNil();
}
else
{
MessagePackBinary.WriteString(packer, message.InvocationId);
writer.Write(message.InvocationId);
}
MessagePackBinary.WriteString(packer, message.Target);
MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length);
writer.Write(message.Target);
writer.WriteArrayHeader(message.Arguments.Length);
foreach (var arg in message.Arguments)
{
WriteArgument(arg, packer);
WriteArgument(arg, ref writer);
}
WriteStreamIds(message.StreamIds, packer);
WriteStreamIds(message.StreamIds, ref writer);
}
private void WriteStreamInvocationMessage(StreamInvocationMessage message, Stream packer)
private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer)
{
MessagePackBinary.WriteArrayHeader(packer, 6);
writer.WriteArrayHeader(6);
MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamInvocationMessageType);
PackHeaders(packer, message.Headers);
MessagePackBinary.WriteString(packer, message.InvocationId);
MessagePackBinary.WriteString(packer, message.Target);
writer.Write(HubProtocolConstants.StreamInvocationMessageType);
PackHeaders(message.Headers, ref writer);
writer.Write(message.InvocationId);
writer.Write(message.Target);
MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length);
writer.WriteArrayHeader(message.Arguments.Length);
foreach (var arg in message.Arguments)
{
WriteArgument(arg, packer);
WriteArgument(arg, ref writer);
}
WriteStreamIds(message.StreamIds, packer);
WriteStreamIds(message.StreamIds, ref writer);
}
private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer)
private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer)
{
MessagePackBinary.WriteArrayHeader(packer, 4);
MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamItemMessageType);
PackHeaders(packer, message.Headers);
MessagePackBinary.WriteString(packer, message.InvocationId);
WriteArgument(message.Item, packer);
writer.WriteArrayHeader(4);
writer.Write(HubProtocolConstants.StreamItemMessageType);
PackHeaders(message.Headers, ref writer);
writer.Write(message.InvocationId);
WriteArgument(message.Item, ref writer);
}
private void WriteArgument(object argument, Stream stream)
private void WriteArgument(object argument, ref MessagePackWriter writer)
{
if (argument == null)
{
MessagePackBinary.WriteNil(stream);
writer.WriteNil();
}
else
{
MessagePackSerializer.NonGeneric.Serialize(argument.GetType(), stream, argument, _resolver);
MessagePackSerializer.Serialize(argument.GetType(), ref writer, argument, _msgPackSerializerOptions);
}
}
private void WriteStreamIds(string[] streamIds, Stream packer)
private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer)
{
if (streamIds != null)
{
MessagePackBinary.WriteArrayHeader(packer, streamIds.Length);
writer.WriteArrayHeader(streamIds.Length);
foreach (var streamId in streamIds)
{
MessagePackBinary.WriteString(packer, streamId);
writer.Write(streamId);
}
}
else
{
MessagePackBinary.WriteArrayHeader(packer, 0);
writer.WriteArrayHeader(0);
}
}
private void WriteCompletionMessage(CompletionMessage message, Stream packer)
private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer)
{
var resultKind =
message.Error != null ? ErrorResult :
message.HasResult ? NonVoidResult :
VoidResult;
MessagePackBinary.WriteArrayHeader(packer, 4 + (resultKind != VoidResult ? 1 : 0));
MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CompletionMessageType);
PackHeaders(packer, message.Headers);
MessagePackBinary.WriteString(packer, message.InvocationId);
MessagePackBinary.WriteInt32(packer, resultKind);
writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0));
writer.Write(HubProtocolConstants.CompletionMessageType);
PackHeaders(message.Headers, ref writer);
writer.Write(message.InvocationId);
writer.Write(resultKind);
switch (resultKind)
{
case ErrorResult:
MessagePackBinary.WriteString(packer, message.Error);
writer.Write(message.Error);
break;
case NonVoidResult:
WriteArgument(message.Result, packer);
WriteArgument(message.Result, ref writer);
break;
}
}
private void WriteCancelInvocationMessage(CancelInvocationMessage message, Stream packer)
private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer)
{
MessagePackBinary.WriteArrayHeader(packer, 3);
MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CancelInvocationMessageType);
PackHeaders(packer, message.Headers);
MessagePackBinary.WriteString(packer, message.InvocationId);
writer.WriteArrayHeader(3);
writer.Write(HubProtocolConstants.CancelInvocationMessageType);
PackHeaders(message.Headers, ref writer);
writer.Write(message.InvocationId);
}
private void WriteCloseMessage(CloseMessage message, Stream packer)
private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer)
{
MessagePackBinary.WriteArrayHeader(packer, 3);
MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CloseMessageType);
writer.WriteArrayHeader(3);
writer.Write(HubProtocolConstants.CloseMessageType);
if (string.IsNullOrEmpty(message.Error))
{
MessagePackBinary.WriteNil(packer);
writer.WriteNil();
}
else
{
MessagePackBinary.WriteString(packer, message.Error);
writer.Write(message.Error);
}
MessagePackBinary.WriteBoolean(packer, message.AllowReconnect);
writer.Write(message.AllowReconnect);
}
private void WritePingMessage(PingMessage pingMessage, Stream packer)
private void WritePingMessage(PingMessage pingMessage, ref MessagePackWriter writer)
{
MessagePackBinary.WriteArrayHeader(packer, 1);
MessagePackBinary.WriteInt32(packer, HubProtocolConstants.PingMessageType);
writer.WriteArrayHeader(1);
writer.Write(HubProtocolConstants.PingMessageType);
}
private void PackHeaders(Stream packer, IDictionary<string, string> headers)
private void PackHeaders(IDictionary<string, string> headers, ref MessagePackWriter writer)
{
if (headers != null)
{
MessagePackBinary.WriteMapHeader(packer, headers.Count);
writer.WriteMapHeader(headers.Count);
if (headers.Count > 0)
{
foreach (var header in headers)
{
MessagePackBinary.WriteString(packer, header.Key);
MessagePackBinary.WriteString(packer, header.Value);
writer.Write(header.Key);
writer.Write(header.Value);
}
}
}
else
{
MessagePackBinary.WriteMapHeader(packer, 0);
writer.WriteMapHeader(0);
}
}
private static string ReadInvocationId(byte[] input, ref int offset)
{
return ReadString(input, ref offset, "invocationId");
}
private static string ReadInvocationId(ref MessagePackReader reader) =>
ReadString(ref reader, "invocationId");
private static bool ReadBoolean(byte[] input, ref int offset, string field)
private static bool ReadBoolean(ref MessagePackReader reader, string field)
{
Exception msgPackException = null;
try
{
var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize);
offset += readSize;
return readBool;
}
catch (Exception e)
{
msgPackException = e;
}
throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException);
}
private static int ReadInt32(byte[] input, ref int offset, string field)
{
Exception msgPackException = null;
try
{
var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize);
offset += readSize;
return readInt;
}
catch (Exception e)
{
msgPackException = e;
}
throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException);
}
private static string ReadString(byte[] input, ref int offset, string field)
{
Exception msgPackException = null;
try
{
var readString = MessagePackBinary.ReadString(input, offset, out var readSize);
offset += readSize;
return readString;
}
catch (Exception e)
{
msgPackException = e;
}
throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException);
}
private static long ReadMapLength(byte[] input, ref int offset, string field)
{
Exception msgPackException = null;
try
{
var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize);
offset += readSize;
return readMap;
}
catch (Exception e)
{
msgPackException = e;
}
throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException);
}
private static long ReadArrayLength(byte[] input, ref int offset, string field)
{
Exception msgPackException = null;
try
{
var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize);
offset += readSize;
return readArray;
}
catch (Exception e)
{
msgPackException = e;
}
throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException);
}
private static object DeserializeObject(byte[] input, ref int offset, Type type, string field, IFormatterResolver resolver)
{
Exception msgPackException = null;
try
{
var obj = MessagePackSerializer.NonGeneric.Deserialize(type, new ArraySegment<byte>(input, offset, input.Length - offset), resolver);
offset += MessagePackBinary.ReadNextBlock(input, offset);
return obj;
return reader.ReadBoolean();
}
catch (Exception ex)
{
msgPackException = ex;
throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex);
}
}
private static int ReadInt32(ref MessagePackReader reader, string field)
{
try
{
return reader.ReadInt32();
}
catch (Exception ex)
{
throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex);
}
}
private static string ReadString(ref MessagePackReader reader, string field)
{
try
{
return reader.ReadString();
}
catch (Exception ex)
{
throw new InvalidDataException($"Reading '{field}' as String failed.", ex);
}
}
private static long ReadMapLength(ref MessagePackReader reader, string field)
{
try
{
return reader.ReadMapHeader();
}
catch (Exception ex)
{
throw new InvalidDataException($"Reading map length for '{field}' failed.", ex);
}
throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException);
}
private static long ReadArrayLength(ref MessagePackReader reader, string field)
{
try
{
return reader.ReadArrayHeader();
}
catch (Exception ex)
{
throw new InvalidDataException($"Reading array length for '{field}' failed.", ex);
}
}
private static object DeserializeObject(ref MessagePackReader reader, Type type, string field, MessagePackSerializerOptions msgPackSerializerOptions)
{
try
{
return MessagePackSerializer.Deserialize(type, ref reader, msgPackSerializerOptions);
}
catch (Exception ex)
{
throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex);
}
}
internal static List<IFormatterResolver> CreateDefaultFormatterResolvers()
@ -703,10 +664,10 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
{
public static readonly IFormatterResolver Instance = new SignalRResolver();
public static readonly IList<IFormatterResolver> Resolvers = new[]
public static readonly IList<IFormatterResolver> Resolvers = new IFormatterResolver[]
{
MessagePack.Resolvers.DynamicEnumAsStringResolver.Instance,
MessagePack.Resolvers.ContractlessStandardResolver.Instance,
DynamicEnumAsStringResolver.Instance,
ContractlessStandardResolver.Instance,
};
public IMessagePackFormatter<T> GetFormatter<T>()
@ -731,30 +692,5 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
}
}
}
// Support for users making their own Formatter lists
internal class CombinedResolvers : IFormatterResolver
{
private readonly IList<IFormatterResolver> _resolvers;
public CombinedResolvers(IList<IFormatterResolver> resolvers)
{
_resolvers = resolvers;
}
public IMessagePackFormatter<T> GetFormatter<T>()
{
foreach (var resolver in _resolvers)
{
var formatter = resolver.GetFormatter<T>();
if (formatter != null)
{
return formatter;
}
}
return null;
}
}
}
}

View File

@ -26,15 +26,20 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
AssertMessages(new byte[] { ArrayBytes(5), 3, 0x80, StringBytes(1), (byte)'0', 0x03, ArrayBytes(1), 42 }, result);
}
[Fact]
public void WriteAndParseDateTimeConvertsToUTC()
[Theory]
[InlineData(DateTimeKind.Utc)]
[InlineData(DateTimeKind.Local)]
[InlineData(DateTimeKind.Unspecified)]
public void WriteAndParseDateTimeConvertsToUTC(DateTimeKind dateTimeKind)
{
var dateTime = new DateTime(2018, 4, 9);
// The messagepack Timestamp format always converts input DateTime to Utc if they are passed as "DateTimeKind.Local" :
// https://github.com/neuecc/MessagePack-CSharp/pull/520/files#diff-ed970b3daebc708ce49f55d418075979
var originalDateTime = new DateTime(2018, 4, 9, 0, 0, 0, dateTimeKind);
var writer = MemoryBufferWriter.Get();
try
{
HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", dateTime), writer);
HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", originalDateTime), writer);
var bytes = new ReadOnlySequence<byte>(writer.ToArray());
HubProtocol.TryParseMessage(ref bytes, new TestBinder(typeof(DateTime)), out var hubMessage);
@ -44,7 +49,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
// The messagepack Timestamp format specifies that time is stored as seconds since 1970-01-01 UTC
// so the library has no choice but to store the time as UTC
// https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type
Assert.Equal(dateTime.ToUniversalTime(), resultDateTime);
// So If the original DateTiem was a "Local" one, we create a new DateTime equivalent to the original one but converted to Utc
var expectedUtcDateTime = (originalDateTime.Kind == DateTimeKind.Local) ? originalDateTime.ToUniversalTime() : originalDateTime;
Assert.Equal(expectedUtcDateTime, resultDateTime);
}
finally
{

View File

@ -2512,33 +2512,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
private class StringFormatter<T> : IMessagePackFormatter<T>
{
public T Deserialize(byte[] bytes, int offset, IFormatterResolver formatterResolver, out int readSize)
public T Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options)
{
// this method isn't used in our tests
readSize = 0;
return default;
}
public int Serialize(ref byte[] bytes, int offset, T value, IFormatterResolver formatterResolver)
public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializerOptions options)
{
// string of size 15
bytes[offset] = 0xAF;
bytes[offset + 1] = (byte)'f';
bytes[offset + 2] = (byte)'o';
bytes[offset + 3] = (byte)'r';
bytes[offset + 4] = (byte)'m';
bytes[offset + 5] = (byte)'a';
bytes[offset + 6] = (byte)'t';
bytes[offset + 7] = (byte)'t';
bytes[offset + 8] = (byte)'e';
bytes[offset + 9] = (byte)'d';
bytes[offset + 10] = (byte)'S';
bytes[offset + 11] = (byte)'t';
bytes[offset + 12] = (byte)'r';
bytes[offset + 13] = (byte)'i';
bytes[offset + 14] = (byte)'n';
bytes[offset + 15] = (byte)'g';
return 16;
writer.Write("formattedString");
}
}
}

View File

@ -1,68 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
using MessagePack;
namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
{
internal static class MessagePackUtil
{
public static int ReadArrayHeader(ref ReadOnlyMemory<byte> data)
{
var arr = GetArray(data);
var val = MessagePackBinary.ReadArrayHeader(arr.Array, arr.Offset, out var readSize);
data = data.Slice(readSize);
return val;
}
public static int ReadMapHeader(ref ReadOnlyMemory<byte> data)
{
var arr = GetArray(data);
var val = MessagePackBinary.ReadMapHeader(arr.Array, arr.Offset, out var readSize);
data = data.Slice(readSize);
return val;
}
public static string ReadString(ref ReadOnlyMemory<byte> data)
{
var arr = GetArray(data);
var val = MessagePackBinary.ReadString(arr.Array, arr.Offset, out var readSize);
data = data.Slice(readSize);
return val;
}
public static byte[] ReadBytes(ref ReadOnlyMemory<byte> data)
{
var arr = GetArray(data);
var val = MessagePackBinary.ReadBytes(arr.Array, arr.Offset, out var readSize);
data = data.Slice(readSize);
return val;
}
public static int ReadInt32(ref ReadOnlyMemory<byte> data)
{
var arr = GetArray(data);
var val = MessagePackBinary.ReadInt32(arr.Array, arr.Offset, out var readSize);
data = data.Slice(readSize);
return val;
}
public static byte ReadByte(ref ReadOnlyMemory<byte> data)
{
var arr = GetArray(data);
var val = MessagePackBinary.ReadByte(arr.Array, arr.Offset, out var readSize);
data = data.Slice(readSize);
return val;
}
private static ArraySegment<byte> GetArray(ReadOnlyMemory<byte> data)
{
var isArray = MemoryMarshal.TryGetArray(data, out var array);
Debug.Assert(isArray);
return array;
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
@ -43,30 +44,33 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
// * [The output of WriteSerializedHubMessage, which is an 'arr']
// Any additional items are discarded.
var writer = MemoryBufferWriter.Get();
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
MessagePackBinary.WriteArrayHeader(writer, 2);
var writer = new MessagePackWriter(memoryBufferWriter);
writer.WriteArrayHeader(2);
if (excludedConnectionIds != null && excludedConnectionIds.Count > 0)
{
MessagePackBinary.WriteArrayHeader(writer, excludedConnectionIds.Count);
writer.WriteArrayHeader(excludedConnectionIds.Count);
foreach (var id in excludedConnectionIds)
{
MessagePackBinary.WriteString(writer, id);
writer.Write(id);
}
}
else
{
MessagePackBinary.WriteArrayHeader(writer, 0);
writer.WriteArrayHeader(0);
}
WriteHubMessage(writer, new InvocationMessage(methodName, args));
return writer.ToArray();
WriteHubMessage(ref writer, new InvocationMessage(methodName, args));
writer.Flush();
return memoryBufferWriter.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
@ -80,21 +84,24 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
// * A 'str': The connection Id
// Any additional items are discarded.
var writer = MemoryBufferWriter.Get();
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
MessagePackBinary.WriteArrayHeader(writer, 5);
MessagePackBinary.WriteInt32(writer, command.Id);
MessagePackBinary.WriteString(writer, command.ServerName);
MessagePackBinary.WriteByte(writer, (byte)command.Action);
MessagePackBinary.WriteString(writer, command.GroupName);
MessagePackBinary.WriteString(writer, command.ConnectionId);
var writer = new MessagePackWriter(memoryBufferWriter);
return writer.ToArray();
writer.WriteArrayHeader(5);
writer.Write(command.Id);
writer.Write(command.ServerName);
writer.Write((byte)command.Action);
writer.Write(command.GroupName);
writer.Write(command.ConnectionId);
writer.Flush();
return memoryBufferWriter.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
@ -104,101 +111,110 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal
// * An 'int': The Id of the command being acknowledged.
// Any additional items are discarded.
var writer = MemoryBufferWriter.Get();
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
MessagePackBinary.WriteArrayHeader(writer, 1);
MessagePackBinary.WriteInt32(writer, messageId);
var writer = new MessagePackWriter(memoryBufferWriter);
return writer.ToArray();
writer.WriteArrayHeader(1);
writer.Write(messageId);
writer.Flush();
return memoryBufferWriter.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
public RedisInvocation ReadInvocation(ReadOnlyMemory<byte> data)
{
// See WriteInvocation for the format
ValidateArraySize(ref data, 2, "Invocation");
var reader = new MessagePackReader(data);
ValidateArraySize(ref reader, 2, "Invocation");
// Read excluded Ids
IReadOnlyList<string> excludedConnectionIds = null;
var idCount = MessagePackUtil.ReadArrayHeader(ref data);
var idCount = reader.ReadArrayHeader();
if (idCount > 0)
{
var ids = new string[idCount];
for (var i = 0; i < idCount; i++)
{
ids[i] = MessagePackUtil.ReadString(ref data);
ids[i] = reader.ReadString();
}
excludedConnectionIds = ids;
}
// Read payload
var message = ReadSerializedHubMessage(ref data);
var message = ReadSerializedHubMessage(ref reader);
return new RedisInvocation(message, excludedConnectionIds);
}
public RedisGroupCommand ReadGroupCommand(ReadOnlyMemory<byte> data)
{
// See WriteGroupCommand for format.
ValidateArraySize(ref data, 5, "GroupCommand");
var reader = new MessagePackReader(data);
var id = MessagePackUtil.ReadInt32(ref data);
var serverName = MessagePackUtil.ReadString(ref data);
var action = (GroupAction)MessagePackUtil.ReadByte(ref data);
var groupName = MessagePackUtil.ReadString(ref data);
var connectionId = MessagePackUtil.ReadString(ref data);
// See WriteGroupCommand for format.
ValidateArraySize(ref reader, 5, "GroupCommand");
var id = reader.ReadInt32();
var serverName = reader.ReadString();
var action = (GroupAction)reader.ReadByte();
var groupName = reader.ReadString();
var connectionId = reader.ReadString();
return new RedisGroupCommand(id, serverName, action, groupName, connectionId);
}
public int ReadAck(ReadOnlyMemory<byte> data)
{
var reader = new MessagePackReader(data);
// See WriteAck for format
ValidateArraySize(ref data, 1, "Ack");
return MessagePackUtil.ReadInt32(ref data);
ValidateArraySize(ref reader, 1, "Ack");
return reader.ReadInt32();
}
private void WriteHubMessage(Stream stream, HubMessage message)
private void WriteHubMessage(ref MessagePackWriter writer, HubMessage message)
{
// Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str')
// and the values are the serialized blob (as a MessagePack 'bin').
var serializedHubMessages = _messageSerializer.SerializeMessage(message);
MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count);
writer.WriteMapHeader(serializedHubMessages.Count);
foreach (var serializedMessage in serializedHubMessages)
{
MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName);
writer.Write(serializedMessage.ProtocolName);
var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array);
Debug.Assert(isArray);
MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count);
writer.Write(array);
}
}
public static SerializedHubMessage ReadSerializedHubMessage(ref ReadOnlyMemory<byte> data)
public static SerializedHubMessage ReadSerializedHubMessage(ref MessagePackReader reader)
{
var count = MessagePackUtil.ReadMapHeader(ref data);
var count = reader.ReadMapHeader();
var serializations = new SerializedMessage[count];
for (var i = 0; i < count; i++)
{
var protocol = MessagePackUtil.ReadString(ref data);
var serialized = MessagePackUtil.ReadBytes(ref data);
var protocol = reader.ReadString();
var serialized = reader.ReadBytes()?.ToArray() ?? Array.Empty<byte>();
serializations[i] = new SerializedMessage(protocol, serialized);
}
return new SerializedHubMessage(serializations);
}
private static void ValidateArraySize(ref ReadOnlyMemory<byte> data, int expectedLength, string messageType)
private static void ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType)
{
var length = MessagePackUtil.ReadArrayHeader(ref data);
var length = reader.ReadArrayHeader();
if (length < expectedLength)
{