Enabling length prefixing, parsing multiple messages
This commit is contained in:
parent
3504337918
commit
71949129ea
|
|
@ -10,8 +10,9 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
{
|
||||
public static class BinaryMessageFormatter
|
||||
{
|
||||
public static bool TryWriteMessage(ReadOnlySpan<byte> payload, MemoryStream output)
|
||||
public static bool TryWriteMessage(ReadOnlySpan<byte> payload, Stream output)
|
||||
{
|
||||
// TODO: Optimize for size - (e.g. use Varints)
|
||||
var length = sizeof(long);
|
||||
var buffer = ArrayPool<byte>.Shared.Rent(length);
|
||||
BufferWriter.WriteBigEndian<long>(buffer, payload.Length);
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using MsgPack;
|
||||
using MsgPack.Serialization;
|
||||
|
||||
|
|
@ -19,9 +20,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
{
|
||||
messages = new List<HubMessage>();
|
||||
|
||||
using (var memoryStream = new MemoryStream(input.ToArray()))
|
||||
var messageParser = new BinaryMessageParser();
|
||||
|
||||
while (messageParser.TryParseMessage(ref input, out var payload))
|
||||
{
|
||||
messages.Add(ParseMessage(memoryStream, binder));
|
||||
using (var memoryStream = new MemoryStream(payload.ToArray()))
|
||||
{
|
||||
messages.Add(ParseMessage(memoryStream, binder));
|
||||
}
|
||||
}
|
||||
|
||||
return messages.Count > 0;
|
||||
|
|
@ -99,8 +105,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
return new CompletionMessage(invocationId, error, result, hasResult);
|
||||
}
|
||||
|
||||
// TODO: when to return false?
|
||||
public bool TryWriteMessage(HubMessage message, Stream output)
|
||||
{
|
||||
using (var memoryStream = new MemoryStream())
|
||||
{
|
||||
WriteMessage(message, memoryStream);
|
||||
return BinaryMessageFormatter.TryWriteMessage(new ReadOnlySpan<byte>(memoryStream.ToArray()), output);
|
||||
}
|
||||
}
|
||||
|
||||
private void WriteMessage(HubMessage message, Stream output)
|
||||
{
|
||||
var packer = Packer.Create(output);
|
||||
switch (message)
|
||||
|
|
@ -117,8 +131,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
default:
|
||||
throw new FormatException($"Unexpected message type: {message.GetType().Name}");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private static void WriteInvocationMessage(InvocationMessage invocationMessage, Packer packer, Stream output)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Text;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
||||
{
|
||||
public class CompositeTestBinder : IInvocationBinder
|
||||
{
|
||||
private readonly HubMessage[] _hubMessages;
|
||||
private int index = 0;
|
||||
|
||||
public CompositeTestBinder(HubMessage[] hubMessages)
|
||||
{
|
||||
_hubMessages = hubMessages;
|
||||
}
|
||||
|
||||
public Type[] GetParameterTypes(string methodName)
|
||||
{
|
||||
index++;
|
||||
return new TestBinder(_hubMessages[index - 1]).GetParameterTypes(methodName);
|
||||
}
|
||||
|
||||
public Type GetReturnType(string invocationId)
|
||||
{
|
||||
index++;
|
||||
return new TestBinder(_hubMessages[index - 1]).GetReturnType(invocationId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3,9 +3,9 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using MsgPack;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
||||
|
|
@ -16,45 +16,59 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
|||
|
||||
public static IEnumerable<object[]> TestMessages => new[]
|
||||
{
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ false, "method") },
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method") },
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new object[] { null } ) },
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42) },
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string") },
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()) },
|
||||
new object[]{ new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new[] { new CustomObject(), new CustomObject() }) },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ false, "method") } },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method") } },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new object[] { null }) } },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42) } },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string") } },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()) } },
|
||||
new object[] { new[] { new InvocationMessage("xyz", /*nonBlocking*/ true, "method", new[] { new CustomObject(), new CustomObject() }) } },
|
||||
|
||||
new object[]{ new CompletionMessage("xyz", error: "Error not found!", result: null, hasResult: false) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: null, hasResult: false) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: null, hasResult: true) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: 42, hasResult: true) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: 42.0f, hasResult: true) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: "string", hasResult: true) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: true, hasResult: true) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) },
|
||||
new object[]{ new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: "Error not found!", result: null, hasResult: false) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: null, hasResult: false) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: null, hasResult: true) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: 42, hasResult: true) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: 42.0f, hasResult: true) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: "string", hasResult: true) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: true, hasResult: true) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) } },
|
||||
new object[] { new[] { new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) } },
|
||||
|
||||
new object[]{ new StreamItemMessage("xyz", null)},
|
||||
new object[]{ new StreamItemMessage("xyz", 42)},
|
||||
new object[]{ new StreamItemMessage("xyz", 42.0f)},
|
||||
new object[]{ new StreamItemMessage("xyz", "string")},
|
||||
new object[]{ new StreamItemMessage("xyz", true)},
|
||||
new object[]{ new StreamItemMessage("xyz", new CustomObject())},
|
||||
new object[]{ new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() })}
|
||||
new object[] { new[] { new StreamItemMessage("xyz", null) } },
|
||||
new object[] { new[] { new StreamItemMessage("xyz", 42) } },
|
||||
new object[] { new[] { new StreamItemMessage("xyz", 42.0f) } },
|
||||
new object[] { new[] { new StreamItemMessage("xyz", "string") } },
|
||||
new object[] { new[] { new StreamItemMessage("xyz", true) } },
|
||||
new object[] { new[] { new StreamItemMessage("xyz", new CustomObject()) } },
|
||||
new object[] { new[] { new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() }) } },
|
||||
|
||||
new object[]
|
||||
{
|
||||
new HubMessage[]
|
||||
{
|
||||
new InvocationMessage("xyz", /*nonBlocking*/ true, "method", 42, "string", new CustomObject()),
|
||||
new CompletionMessage("xyz", error: null, result: 42, hasResult: true),
|
||||
new StreamItemMessage("xyz", null),
|
||||
new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(TestMessages))]
|
||||
public void CanRoundTripInvocationMessage(HubMessage hubMessage)
|
||||
public void CanRoundTripInvocationMessage(HubMessage[] hubMessages)
|
||||
{
|
||||
using (var memoryStream = new MemoryStream())
|
||||
{
|
||||
_hubProtocol.TryWriteMessage(hubMessage, memoryStream);
|
||||
_hubProtocol.TryParseMessages(
|
||||
new ReadOnlySpan<byte>(memoryStream.ToArray()), new TestBinder(hubMessage), out var messages);
|
||||
foreach (var hubMessage in hubMessages)
|
||||
{
|
||||
_hubProtocol.TryWriteMessage(hubMessage, memoryStream);
|
||||
}
|
||||
|
||||
Assert.Equal(1, messages.Count);
|
||||
Assert.Equal(hubMessage, messages[0], TestHubMessageEqualityComparer.Instance);
|
||||
_hubProtocol.TryParseMessages(
|
||||
new ReadOnlySpan<byte>(memoryStream.ToArray()), new CompositeTestBinder(hubMessages), out var messages);
|
||||
|
||||
Assert.Equal(hubMessages, messages, TestHubMessageEqualityComparer.Instance);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -100,11 +114,36 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
|||
[MemberData(nameof(InvalidPayloads))]
|
||||
public void ParserThrowsForInvalidMessages(byte[] payload, string expectedExceptionMessage)
|
||||
{
|
||||
var payloadSize = payload.Length;
|
||||
Debug.Assert(payloadSize <= 0xff, "This test does not support payloads larger than 255");
|
||||
|
||||
// prefix payload with the size
|
||||
var buffer = new byte[8 + payloadSize];
|
||||
buffer[7] = (byte)(payloadSize & 0xff);
|
||||
Array.Copy(payload, 0, buffer, 8, payloadSize);
|
||||
|
||||
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
|
||||
var exception = Assert.Throws<FormatException>(() =>
|
||||
_hubProtocol.TryParseMessages(new ReadOnlySpan<byte>(payload), binder, out var messages));
|
||||
_hubProtocol.TryParseMessages(new ReadOnlySpan<byte>(buffer), binder, out var messages));
|
||||
|
||||
Assert.Equal(expectedExceptionMessage, exception.Message);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(new object[] { new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x01 }, 0 })]
|
||||
[InlineData(new object[] {
|
||||
new byte[]
|
||||
{
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1, 0x78, 0xa1, 0x45,
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x03, 0xa1
|
||||
}, 2 })]
|
||||
public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount)
|
||||
{
|
||||
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
|
||||
var result = _hubProtocol.TryParseMessages(payload, binder, out var messages);
|
||||
Assert.True(result || messages.Count == 0);
|
||||
Assert.Equal(expectedMessagesCount, messages.Count);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue