diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs index 4a9b3ecab8..4b4fa91740 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs @@ -10,8 +10,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters { public static class BinaryMessageFormatter { - public static bool TryWriteMessage(ReadOnlySpan payload, MemoryStream output) + public static bool TryWriteMessage(ReadOnlySpan payload, Stream output) { + // TODO: Optimize for size - (e.g. use Varints) var length = sizeof(long); var buffer = ArrayPool.Shared.Rent(length); BufferWriter.WriteBigEndian(buffer, payload.Length); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index b2892914f9..8ba9abb8bc 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.IO; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; using MsgPack; using MsgPack.Serialization; @@ -19,9 +20,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { messages = new List(); - using (var memoryStream = new MemoryStream(input.ToArray())) + var messageParser = new BinaryMessageParser(); + + while (messageParser.TryParseMessage(ref input, out var payload)) { - messages.Add(ParseMessage(memoryStream, binder)); + using (var memoryStream = new MemoryStream(payload.ToArray())) + { + messages.Add(ParseMessage(memoryStream, binder)); + } } return messages.Count > 0; @@ -99,8 +105,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return new CompletionMessage(invocationId, error, result, hasResult); } - // TODO: when to return false? public bool TryWriteMessage(HubMessage message, Stream output) + { + using (var memoryStream = new MemoryStream()) + { + WriteMessage(message, memoryStream); + return BinaryMessageFormatter.TryWriteMessage(new ReadOnlySpan(memoryStream.ToArray()), output); + } + } + + private void WriteMessage(HubMessage message, Stream output) { var packer = Packer.Create(output); switch (message) @@ -117,8 +131,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol default: throw new FormatException($"Unexpected message type: {message.GetType().Name}"); } - - return true; } private static void WriteInvocationMessage(InvocationMessage invocationMessage, Packer packer, Stream output) diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs new file mode 100644 index 0000000000..e786540c13 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/CompositeTestBinder.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol +{ + public class CompositeTestBinder : IInvocationBinder + { + private readonly HubMessage[] _hubMessages; + private int index = 0; + + public CompositeTestBinder(HubMessage[] hubMessages) + { + _hubMessages = hubMessages; + } + + public Type[] GetParameterTypes(string methodName) + { + index++; + return new TestBinder(_hubMessages[index - 1]).GetParameterTypes(methodName); + } + + public Type GetReturnType(string invocationId) + { + index++; + return new TestBinder(_hubMessages[index - 1]).GetReturnType(invocationId); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index aaed947e00..c83cf388ed 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -3,9 +3,9 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using MsgPack; using Xunit; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol @@ -16,45 +16,59 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol public static IEnumerable TestMessages => new[] { - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ false, "method") }, - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method") }, - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new object[] { null } ) }, - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42) }, - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string") }, - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()) }, - new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new[] { new CustomObject(), new CustomObject() }) }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ false, "method") } }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method") } }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new object[] { null }) } }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42) } }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string") } }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()) } }, + new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new[] { new CustomObject(), new CustomObject() }) } }, - new object[]{ new CompletionMessage("xyz", error: "Error not found!", result: null, hasResult: false) }, - new object[]{ new CompletionMessage("xyz", error: null, result: null, hasResult: false) }, - new object[]{ new CompletionMessage("xyz", error: null, result: null, hasResult: true) }, - new object[]{ new CompletionMessage("xyz", error: null, result: 42, hasResult: true) }, - new object[]{ new CompletionMessage("xyz", error: null, result: 42.0f, hasResult: true) }, - new object[]{ new CompletionMessage("xyz", error: null, result: "string", hasResult: true) }, - new object[]{ new CompletionMessage("xyz", error: null, result: true, hasResult: true) }, - new object[]{ new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) }, - new object[]{ new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) }, + new object[] { new[] { new CompletionMessage("xyz", error: "Error not found!", result: null, hasResult: false) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: null, hasResult: false) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: null, hasResult: true) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: 42, hasResult: true) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: 42.0f, hasResult: true) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: "string", hasResult: true) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: true, hasResult: true) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) } }, + new object[] { new[] { new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) } }, - new object[]{ new StreamItemMessage("xyz", null)}, - new object[]{ new StreamItemMessage("xyz", 42)}, - new object[]{ new StreamItemMessage("xyz", 42.0f)}, - new object[]{ new StreamItemMessage("xyz", "string")}, - new object[]{ new StreamItemMessage("xyz", true)}, - new object[]{ new StreamItemMessage("xyz", new CustomObject())}, - new object[]{ new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() })} + new object[] { new[] { new StreamItemMessage("xyz", null) } }, + new object[] { new[] { new StreamItemMessage("xyz", 42) } }, + new object[] { new[] { new StreamItemMessage("xyz", 42.0f) } }, + new object[] { new[] { new StreamItemMessage("xyz", "string") } }, + new object[] { new[] { new StreamItemMessage("xyz", true) } }, + new object[] { new[] { new StreamItemMessage("xyz", new CustomObject()) } }, + new object[] { new[] { new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() }) } }, + + new object[] + { + new HubMessage[] + { + new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()), + new CompletionMessage("xyz", error: null, result: 42, hasResult: true), + new StreamItemMessage("xyz", null), + new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) + } + } }; [Theory] [MemberData(nameof(TestMessages))] - public void CanRoundTripInvocationMessage(HubMessage hubMessage) + public void CanRoundTripInvocationMessage(HubMessage[] hubMessages) { using (var memoryStream = new MemoryStream()) { - _hubProtocol.TryWriteMessage(hubMessage, memoryStream); - _hubProtocol.TryParseMessages( - new ReadOnlySpan(memoryStream.ToArray()), new TestBinder(hubMessage), out var messages); + foreach (var hubMessage in hubMessages) + { + _hubProtocol.TryWriteMessage(hubMessage, memoryStream); + } - Assert.Equal(1, messages.Count); - Assert.Equal(hubMessage, messages[0], TestHubMessageEqualityComparer.Instance); + _hubProtocol.TryParseMessages( + new ReadOnlySpan(memoryStream.ToArray()), new CompositeTestBinder(hubMessages), out var messages); + + Assert.Equal(hubMessages, messages, TestHubMessageEqualityComparer.Instance); } } @@ -100,11 +114,36 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [MemberData(nameof(InvalidPayloads))] public void ParserThrowsForInvalidMessages(byte[] payload, string expectedExceptionMessage) { + var payloadSize = payload.Length; + Debug.Assert(payloadSize <= 0xff, "This test does not support payloads larger than 255"); + + // prefix payload with the size + var buffer = new byte[8 + payloadSize]; + buffer[7] = (byte)(payloadSize & 0xff); + Array.Copy(payload, 0, buffer, 8, payloadSize); + var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); var exception = Assert.Throws(() => - _hubProtocol.TryParseMessages(new ReadOnlySpan(payload), binder, out var messages)); + _hubProtocol.TryParseMessages(new ReadOnlySpan(buffer), binder, out var messages)); Assert.Equal(expectedExceptionMessage, exception.Message); } + + [Theory] + [InlineData(new object[] { new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x01 }, 0 })] + [InlineData(new object[] { + new byte[] + { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1 + }, 2 })] + public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount) + { + var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); + var result = _hubProtocol.TryParseMessages(payload, binder, out var messages); + Assert.True(result || messages.Count == 0); + Assert.Equal(expectedMessagesCount, messages.Count); + } } }