diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index 12b9269153..cdfc0f0e16 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -16,6 +16,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private const int StreamItemMessageType = 2; private const int CompletionMessageType = 3; + private const int ErrorResult = 1; + private const int VoidResult = 2; + private const int NonVoidResult = 3; + public string Name => "messagepack"; public ProtocolType Type => ProtocolType.Binary; @@ -38,10 +42,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private static HubMessage ParseMessage(Stream input, IInvocationBinder binder) { var unpacker = Unpacker.Create(input); - if (!unpacker.ReadInt32(out var messageType)) - { - throw new FormatException("Message type is missing."); - } + var messageType = ReadInt32(unpacker, "messageType"); switch (messageType) { @@ -90,18 +91,27 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private static CompletionMessage CreateCompletionMessage(Unpacker unpacker, IInvocationBinder binder) { var invocationId = ReadInvocationId(unpacker); - var error = ReadString(unpacker, "error"); + var resultKind = ReadInt32(unpacker, "resultKind"); - var hasResult = false; + string error = null; object result = null; - if (error == null) + var hasResult = false; + + switch(resultKind) { - hasResult = ReadBoolean(unpacker, "hasResult"); - if (hasResult) - { + case ErrorResult: + error = ReadString(unpacker, "error"); + break; + case NonVoidResult: var itemType = binder.GetReturnType(invocationId); result = DeserializeObject(unpacker, itemType, "argument"); - } + hasResult = true; + break; + case VoidResult: + hasResult = false; + break; + default: + throw new FormatException("Invalid invocation result kind."); } return new CompletionMessage(invocationId, error, result, hasResult); @@ -153,16 +163,22 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private void WriteCompletionMessage(CompletionMessage completionMessage, Packer packer, Stream output) { + var resultKind = + completionMessage.Error != null ? ErrorResult : + completionMessage.HasResult ? NonVoidResult : + VoidResult; + packer.Pack(CompletionMessageType); packer.PackString(completionMessage.InvocationId); - packer.PackString(completionMessage.Error); - if (completionMessage.Error == null) + packer.Pack(resultKind); + switch (resultKind) { - packer.Pack(completionMessage.HasResult); - if (completionMessage.HasResult) - { + case ErrorResult: + packer.PackString(completionMessage.Error); + break; + case NonVoidResult: packer.PackObject(completionMessage.Result); - } + break; } } @@ -171,6 +187,24 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return ReadString(unpacker, "invocationId"); } + private static int ReadInt32(Unpacker unpacker, string field) + { + Exception msgPackException = null; + try + { + if (unpacker.ReadInt32(out var value)) + { + return value; + } + } + catch (Exception e) + { + msgPackException = e; + } + + throw new FormatException($"Reading '{field}' as Int32 failed.", msgPackException); + } + private static string ReadString(Unpacker unpacker, string field) { Exception msgPackException = null; diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index fed9cbf6e4..0ce3294aa6 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -73,7 +73,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol public static IEnumerable InvalidPayloads => new[] { - new object[] { new byte[0], "Message type is missing." }, + new object[] { new byte[0], "Reading 'messageType' as Int32 failed." }, + new object[] { new byte[] { 0xc2 } , "Reading 'messageType' as Int32 failed." }, // message type is not int new object[] { new byte[] { 0x0a } , "Invalid message type: 10." }, // InvocationMessage @@ -100,13 +101,15 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol // CompletionMessage new object[] { new byte[] { 0x03 }, "Reading 'invocationId' as String failed." }, // 0xc2 is Bool false new object[] { new byte[] { 0x03, 0xc2 }, "Reading 'invocationId' as String failed." }, // 0xc2 is Bool false - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc2 }, "Reading 'error' as String failed." }, // 0xc2 is Bool false - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xa1 }, "Reading 'error' as String failed." }, // error is cut - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc0 }, "Reading 'hasResult' as Boolean failed." }, // hasResult missing - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc0, 0xa0 }, "Reading 'hasResult' as Boolean failed." }, // 0xa0 is string - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc0, 0xc3 }, "Deserializing object of the `String` type for 'argument' failed." }, // result missing - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc0, 0xc3, 0xa9 }, "Deserializing object of the `String` type for 'argument' failed." }, // result is cut - new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc0, 0xc3, 0x00 }, "Deserializing object of the `String` type for 'argument' failed." } // return type mismatch + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0xc2 }, "Reading 'resultKind' as Int32 failed." }, // result kind is not int + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x0f }, "Invalid invocation result kind." }, // result kind is out of range + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x01 }, "Reading 'error' as String failed." }, // error result but no error + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x01, 0xa1 }, "Reading 'error' as String failed." }, // error is cut + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x03 }, "Deserializing object of the `String` type for 'argument' failed." }, // non void result but result missing + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x03, 0xa9 }, "Deserializing object of the `String` type for 'argument' failed." }, // result is cut + new object[] { new byte[] { 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x03, 0x00 }, "Deserializing object of the `String` type for 'argument' failed." }, // return type mismatch + + // TODO: ReadAsInt32 and no int32 value }; [Theory] @@ -132,8 +135,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [InlineData(new object[] { new byte[] { - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x03, 0xa1, 0x78, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x03, 0xa1, 0x78, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1 }, 2 })] public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount)