Enabling length prefixing, parsing multiple messages

This commit is contained in:
Pawel Kadluczka 2017-06-26 10:56:06 -07:00
parent 3504337918
commit 71949129ea
4 changed files with 120 additions and 37 deletions

View File

@ -10,8 +10,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
{
public static class BinaryMessageFormatter
{
public static bool TryWriteMessage(ReadOnlySpan<byte> payload, MemoryStream output)
public static bool TryWriteMessage(ReadOnlySpan<byte> payload, Stream output)
{
// TODO: Optimize for size - (e.g. use Varints)
var length = sizeof(long);
var buffer = ArrayPool<byte>.Shared.Rent(length);
BufferWriter.WriteBigEndian<long>(buffer, payload.Length);

View File

@ -4,6 +4,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using MsgPack;
using MsgPack.Serialization;
@ -19,9 +20,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
messages = new List<HubMessage>();
using (var memoryStream = new MemoryStream(input.ToArray()))
var messageParser = new BinaryMessageParser();
while (messageParser.TryParseMessage(ref input, out var payload))
{
messages.Add(ParseMessage(memoryStream, binder));
using (var memoryStream = new MemoryStream(payload.ToArray()))
{
messages.Add(ParseMessage(memoryStream, binder));
}
}
return messages.Count > 0;
@ -99,8 +105,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return new CompletionMessage(invocationId, error, result, hasResult);
}
// TODO: when to return false?
public bool TryWriteMessage(HubMessage message, Stream output)
{
using (var memoryStream = new MemoryStream())
{
WriteMessage(message, memoryStream);
return BinaryMessageFormatter.TryWriteMessage(new ReadOnlySpan<byte>(memoryStream.ToArray()), output);
}
}
private void WriteMessage(HubMessage message, Stream output)
{
var packer = Packer.Create(output);
switch (message)
@ -117,8 +131,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
default:
throw new FormatException($"Unexpected message type: {message.GetType().Name}");
}
return true;
}
private static void WriteInvocationMessage(InvocationMessage invocationMessage, Packer packer, Stream output)

View File

@ -0,0 +1,31 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
public class CompositeTestBinder : IInvocationBinder
{
private readonly HubMessage[] _hubMessages;
private int index = 0;
public CompositeTestBinder(HubMessage[] hubMessages)
{
_hubMessages = hubMessages;
}
public Type[] GetParameterTypes(string methodName)
{
index++;
return new TestBinder(_hubMessages[index - 1]).GetParameterTypes(methodName);
}
public Type GetReturnType(string invocationId)
{
index++;
return new TestBinder(_hubMessages[index - 1]).GetReturnType(invocationId);
}
}
}

View File

@ -3,9 +3,9 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using MsgPack;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
@ -16,45 +16,59 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
public static IEnumerable<object[]> TestMessages => new[]
{
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ false, "method") },
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method") },
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new object[] { null } ) },
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42) },
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string") },
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()) },
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new[] { new CustomObject(), new CustomObject() }) },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ false, "method") } },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method") } },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new object[] { null }) } },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42) } },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string") } },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()) } },
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new[] { new CustomObject(), new CustomObject() }) } },
new object[]{ new CompletionMessage("xyz", error: "Error not found!", result: null, hasResult: false) },
new object[]{ new CompletionMessage("xyz", error: null, result: null, hasResult: false) },
new object[]{ new CompletionMessage("xyz", error: null, result: null, hasResult: true) },
new object[]{ new CompletionMessage("xyz", error: null, result: 42, hasResult: true) },
new object[]{ new CompletionMessage("xyz", error: null, result: 42.0f, hasResult: true) },
new object[]{ new CompletionMessage("xyz", error: null, result: "string", hasResult: true) },
new object[]{ new CompletionMessage("xyz", error: null, result: true, hasResult: true) },
new object[]{ new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) },
new object[]{ new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) },
new object[] { new[] { new CompletionMessage("xyz", error: "Error not found!", result: null, hasResult: false) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: null, hasResult: false) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: null, hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: 42, hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: 42.0f, hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: "string", hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: true, hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) } },
new object[]{ new StreamItemMessage("xyz", null)},
new object[]{ new StreamItemMessage("xyz", 42)},
new object[]{ new StreamItemMessage("xyz", 42.0f)},
new object[]{ new StreamItemMessage("xyz", "string")},
new object[]{ new StreamItemMessage("xyz", true)},
new object[]{ new StreamItemMessage("xyz", new CustomObject())},
new object[]{ new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() })}
new object[] { new[] { new StreamItemMessage("xyz", null) } },
new object[] { new[] { new StreamItemMessage("xyz", 42) } },
new object[] { new[] { new StreamItemMessage("xyz", 42.0f) } },
new object[] { new[] { new StreamItemMessage("xyz", "string") } },
new object[] { new[] { new StreamItemMessage("xyz", true) } },
new object[] { new[] { new StreamItemMessage("xyz", new CustomObject()) } },
new object[] { new[] { new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() }) } },
new object[]
{
new HubMessage[]
{
new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()),
new CompletionMessage("xyz", error: null, result: 42, hasResult: true),
new StreamItemMessage("xyz", null),
new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true)
}
}
};
[Theory]
[MemberData(nameof(TestMessages))]
public void CanRoundTripInvocationMessage(HubMessage hubMessage)
public void CanRoundTripInvocationMessage(HubMessage[] hubMessages)
{
using (var memoryStream = new MemoryStream())
{
_hubProtocol.TryWriteMessage(hubMessage, memoryStream);
_hubProtocol.TryParseMessages(
new ReadOnlySpan<byte>(memoryStream.ToArray()), new TestBinder(hubMessage), out var messages);
foreach (var hubMessage in hubMessages)
{
_hubProtocol.TryWriteMessage(hubMessage, memoryStream);
}
Assert.Equal(1, messages.Count);
Assert.Equal(hubMessage, messages[0], TestHubMessageEqualityComparer.Instance);
_hubProtocol.TryParseMessages(
new ReadOnlySpan<byte>(memoryStream.ToArray()), new CompositeTestBinder(hubMessages), out var messages);
Assert.Equal(hubMessages, messages, TestHubMessageEqualityComparer.Instance);
}
}
@ -100,11 +114,36 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[MemberData(nameof(InvalidPayloads))]
public void ParserThrowsForInvalidMessages(byte[] payload, string expectedExceptionMessage)
{
var payloadSize = payload.Length;
Debug.Assert(payloadSize <= 0xff, "This test does not support payloads larger than 255");
// prefix payload with the size
var buffer = new byte[8 + payloadSize];
buffer[7] = (byte)(payloadSize & 0xff);
Array.Copy(payload, 0, buffer, 8, payloadSize);
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var exception = Assert.Throws<FormatException>(() =>
_hubProtocol.TryParseMessages(new ReadOnlySpan<byte>(payload), binder, out var messages));
_hubProtocol.TryParseMessages(new ReadOnlySpan<byte>(buffer), binder, out var messages));
Assert.Equal(expectedExceptionMessage, exception.Message);
}
[Theory]
[InlineData(new object[] { new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x01 }, 0 })]
[InlineData(new object[] {
new byte[]
{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1
}, 2 })]
public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount)
{
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var result = _hubProtocol.TryParseMessages(payload, binder, out var messages);
Assert.True(result || messages.Count == 0);
Assert.Equal(expectedMessagesCount, messages.Count);
}
}
}