Change IHubProtocol interface to support partial parsing (#1745)

- These are the finishing touches before we disable batching on the
C# client and on the server. We're changing the IHubProtocol interface to
modify the input buffer with what was consumed. We're also changing it
to parse a single message at a time to be match what output writing does.
- Added TryParseResponseMessage and made it look like TryParseRequestMessage
This commit is contained in:
David Fowler 2018-03-28 12:08:16 -07:00 committed by GitHub
parent 7a428534c3
commit 19b9dca268
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 265 additions and 312 deletions

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.IO.Pipelines; using System.IO.Pipelines;
@ -58,8 +59,9 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
return true; return true;
} }
public bool TryParseMessages(ReadOnlyMemory<byte> input, IInvocationBinder binder, IList<HubMessage> messages) public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{ {
message = null;
return false; return false;
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Attributes;
@ -59,8 +60,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[Benchmark] [Benchmark]
public void ReadSingleMessage() public void ReadSingleMessage()
{ {
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(_binaryInput);
if (!_hubProtocol.TryParseMessages(_binaryInput, _binder, messages)) if (!_hubProtocol.TryParseMessage(ref data, _binder, out _))
{ {
throw new InvalidOperationException("Failed to read message"); throw new InvalidOperationException("Failed to read message");
} }

View File

@ -1,4 +1,5 @@
using System; using System;
using System.Buffers;
using System.IO; using System.IO;
using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Formatters;
@ -40,8 +41,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[Benchmark] [Benchmark]
public void SingleBinaryMessage() public void SingleBinaryMessage()
{ {
ReadOnlyMemory<byte> buffer = _binaryInput; var data = new ReadOnlySequence<byte>(_binaryInput);
if (!BinaryMessageParser.TryParseMessage(ref buffer, out _)) if (!BinaryMessageParser.TryParseMessage(ref data, out _))
{ {
throw new InvalidOperationException("Failed to parse"); throw new InvalidOperationException("Failed to parse");
} }
@ -50,8 +51,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[Benchmark] [Benchmark]
public void SingleTextMessage() public void SingleTextMessage()
{ {
ReadOnlyMemory<byte> buffer = _textInput; var data = new ReadOnlySequence<byte>(_textInput);
if (!TextMessageParser.TryParseMessage(ref buffer, out _)) if (!TextMessageParser.TryParseMessage(ref data, out _))
{ {
throw new InvalidOperationException("Failed to parse"); throw new InvalidOperationException("Failed to parse");
} }

View File

@ -96,21 +96,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
private static readonly Action<ILogger, Exception> _sendingHubHandshake = private static readonly Action<ILogger, Exception> _sendingHubHandshake =
LoggerMessage.Define(LogLevel.Debug, new EventId(28, "SendingHubHandshake"), "Sending Hub Handshake."); LoggerMessage.Define(LogLevel.Debug, new EventId(28, "SendingHubHandshake"), "Sending Hub Handshake.");
private static readonly Action<ILogger, int, Exception> _parsingMessages =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(29, "ParsingMessages"), "Received {Count} bytes. Parsing message(s).");
private static readonly Action<ILogger, int, Exception> _receivingMessages =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(30, "ReceivingMessages"), "Received {MessageCount} message(s).");
private static readonly Action<ILogger, Exception> _receivedPing = private static readonly Action<ILogger, Exception> _receivedPing =
LoggerMessage.Define(LogLevel.Trace, new EventId(31, "ReceivedPing"), "Received a ping message."); LoggerMessage.Define(LogLevel.Trace, new EventId(31, "ReceivedPing"), "Received a ping message.");
private static readonly Action<ILogger, int, Exception> _processedMessages =
LoggerMessage.Define<int>(LogLevel.Debug, new EventId(32, "ProcessedMessages"), "Finished processing {MessageCount} message(s).");
private static readonly Action<ILogger, int, Exception> _failedParsing =
LoggerMessage.Define<int>(LogLevel.Warning, new EventId(33, "FailedParsing"), "No messages parsed from {Count} byte(s).");
private static readonly Action<ILogger, string, Exception> _errorInvokingClientSideMethod = private static readonly Action<ILogger, string, Exception> _errorInvokingClientSideMethod =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(34, "ErrorInvokingClientSideMethod"), "Invoking client side method '{MethodName}' failed."); LoggerMessage.Define<string>(LogLevel.Error, new EventId(34, "ErrorInvokingClientSideMethod"), "Invoking client side method '{MethodName}' failed.");
@ -329,31 +317,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
_sendingHubHandshake(logger, null); _sendingHubHandshake(logger, null);
} }
public static void ParsingMessages(ILogger logger, int byteCount)
{
_parsingMessages(logger, byteCount, null);
}
public static void ReceivingMessages(ILogger logger, int messageCount)
{
_receivingMessages(logger, messageCount, null);
}
public static void ReceivedPing(ILogger logger) public static void ReceivedPing(ILogger logger)
{ {
_receivedPing(logger, null); _receivedPing(logger, null);
} }
public static void ProcessedMessages(ILogger logger, int messageCount)
{
_processedMessages(logger, messageCount, null);
}
public static void FailedParsing(ILogger logger, int byteCount)
{
_failedParsing(logger, byteCount, null);
}
public static void ErrorInvokingClientSideMethod(ILogger logger, string methodName, Exception exception) public static void ErrorInvokingClientSideMethod(ILogger logger, string methodName, Exception exception)
{ {
_errorInvokingClientSideMethod(logger, methodName, exception); _errorInvokingClientSideMethod(logger, methodName, exception);

View File

@ -178,7 +178,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
CheckDisposed(); CheckDisposed();
connectionState = _connectionState; connectionState = _connectionState;
// Set the stopping flag so that any invocations after this get a useful error message instead of // Set the stopping flag so that any invocations after this get a useful error message instead of
// silently failing or throwing an error about the pipe being completed. // silently failing or throwing an error about the pipe being completed.
if (connectionState != null) if (connectionState != null)
@ -374,74 +374,53 @@ namespace Microsoft.AspNetCore.SignalR.Client
} }
} }
private async Task<(bool close, Exception exception)> ProcessMessagesAsync(ReadOnlySequence<byte> buffer, ConnectionState connectionState) private async Task<(bool close, Exception exception)> ProcessMessagesAsync(HubMessage message, ConnectionState connectionState)
{ {
Log.ProcessingMessage(_logger, buffer.Length); InvocationRequest irq;
switch (message)
// TODO: Don't ToArray it :)
var data = buffer.ToArray();
var currentData = new ReadOnlyMemory<byte>(data);
Log.ParsingMessages(_logger, currentData.Length);
var messages = new List<HubMessage>();
if (_protocol.TryParseMessages(currentData, connectionState, messages))
{ {
Log.ReceivingMessages(_logger, messages.Count); case InvocationMessage invocation:
foreach (var message in messages) Log.ReceivedInvocation(_logger, invocation.InvocationId, invocation.Target,
{ invocation.ArgumentBindingException != null ? null : invocation.Arguments);
InvocationRequest irq; await DispatchInvocationAsync(invocation);
switch (message) break;
case CompletionMessage completion:
if (!connectionState.TryRemoveInvocation(completion.InvocationId, out irq))
{ {
case InvocationMessage invocation: Log.DroppedCompletionMessage(_logger, completion.InvocationId);
Log.ReceivedInvocation(_logger, invocation.InvocationId, invocation.Target,
invocation.ArgumentBindingException != null ? null : invocation.Arguments);
await DispatchInvocationAsync(invocation);
break;
case CompletionMessage completion:
if (!connectionState.TryRemoveInvocation(completion.InvocationId, out irq))
{
Log.DroppedCompletionMessage(_logger, completion.InvocationId);
}
else
{
DispatchInvocationCompletion(completion, irq);
irq.Dispose();
}
break;
case StreamItemMessage streamItem:
// Complete the invocation with an error, we don't support streaming (yet)
if (!connectionState.TryGetInvocation(streamItem.InvocationId, out irq))
{
Log.DroppedStreamMessage(_logger, streamItem.InvocationId);
return (close: false, exception: null);
}
await DispatchInvocationStreamItemAsync(streamItem, irq);
break;
case CloseMessage close:
if (string.IsNullOrEmpty(close.Error))
{
Log.ReceivedClose(_logger);
return (close: true, exception: null);
}
else
{
Log.ReceivedCloseWithError(_logger, close.Error);
return (close: true, exception: new HubException($"The server closed the connection with the following error: {close.Error}"));
}
case PingMessage _:
Log.ReceivedPing(_logger);
// Nothing to do on receipt of a ping.
break;
default:
throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}");
} }
} else
Log.ProcessedMessages(_logger, messages.Count); {
} DispatchInvocationCompletion(completion, irq);
else irq.Dispose();
{ }
Log.FailedParsing(_logger, data.Length); break;
case StreamItemMessage streamItem:
// Complete the invocation with an error, we don't support streaming (yet)
if (!connectionState.TryGetInvocation(streamItem.InvocationId, out irq))
{
Log.DroppedStreamMessage(_logger, streamItem.InvocationId);
return (close: false, exception: null);
}
await DispatchInvocationStreamItemAsync(streamItem, irq);
break;
case CloseMessage close:
if (string.IsNullOrEmpty(close.Error))
{
Log.ReceivedClose(_logger);
return (close: true, exception: null);
}
else
{
Log.ReceivedCloseWithError(_logger, close.Error);
return (close: true, exception: new HubException($"The server closed the connection with the following error: {close.Error}"));
}
case PingMessage _:
Log.ReceivedPing(_logger);
// Nothing to do on receipt of a ping.
break;
default:
throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}");
} }
return (close: false, exception: null); return (close: false, exception: null);
@ -536,25 +515,23 @@ namespace Microsoft.AspNetCore.SignalR.Client
{ {
var result = await _connectionState.Connection.Transport.Input.ReadAsync(); var result = await _connectionState.Connection.Transport.Input.ReadAsync();
var buffer = result.Buffer; var buffer = result.Buffer;
var consumed = buffer.Start;
try try
{ {
// Read first message out of the incoming data // Read first message out of the incoming data
if (!buffer.IsEmpty && TextMessageParser.TryParseMessage(ref buffer, out var payload)) if (!buffer.IsEmpty)
{ {
// Buffer was advanced to the end of the message by TryParseMessage if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message))
consumed = buffer.Start;
var message = HandshakeProtocol.ParseResponseMessage(payload.ToArray());
if (!string.IsNullOrEmpty(message.Error))
{ {
Log.HandshakeServerError(_logger, message.Error); if (!string.IsNullOrEmpty(message.Error))
throw new HubException( {
$"Unable to complete handshake with the server due to an error: {message.Error}"); Log.HandshakeServerError(_logger, message.Error);
} throw new HubException(
$"Unable to complete handshake with the server due to an error: {message.Error}");
}
break; break;
}
} }
else if (result.IsCompleted) else if (result.IsCompleted)
{ {
@ -565,7 +542,10 @@ namespace Microsoft.AspNetCore.SignalR.Client
} }
finally finally
{ {
_connectionState.Connection.Transport.Input.AdvanceTo(consumed); // The buffer was sliced up to where it was consumed, so we can just advance to the start.
// We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data
// before yielding the read again.
_connectionState.Connection.Transport.Input.AdvanceTo(buffer.Start, buffer.End);
} }
} }
} }
@ -594,8 +574,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
{ {
var result = await connectionState.Connection.Transport.Input.ReadAsync(); var result = await connectionState.Connection.Transport.Input.ReadAsync();
var buffer = result.Buffer; var buffer = result.Buffer;
var consumed = buffer.End; // TODO: Support partial messages
var examined = buffer.End;
try try
{ {
@ -608,12 +586,27 @@ namespace Microsoft.AspNetCore.SignalR.Client
{ {
ResetTimeoutTimer(timeoutTimer); ResetTimeoutTimer(timeoutTimer);
// We have data, process it Log.ProcessingMessage(_logger, buffer.Length);
var (close, exception) = await ProcessMessagesAsync(buffer, connectionState);
var close = false;
while (_protocol.TryParseMessage(ref buffer, connectionState, out var message))
{
Exception exception;
// We have data, process it
(close, exception) = await ProcessMessagesAsync(message, connectionState);
if (close)
{
// Closing because we got a close frame, possibly with an error in it.
connectionState.CloseException = exception;
break;
}
}
// If we're closing stop everything
if (close) if (close)
{ {
// Closing because we got a close frame, possibly with an error in it.
connectionState.CloseException = exception;
break; break;
} }
} }
@ -624,7 +617,10 @@ namespace Microsoft.AspNetCore.SignalR.Client
} }
finally finally
{ {
connectionState.Connection.Transport.Input.AdvanceTo(consumed, examined); // The buffer was sliced up to where it was consumed, so we can just advance to the start.
// We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data
// before yielding the read again.
connectionState.Connection.Transport.Input.AdvanceTo(buffer.Start, buffer.End);
} }
} }
} }
@ -633,7 +629,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
Log.ServerDisconnectedWithError(_logger, ex); Log.ServerDisconnectedWithError(_logger, ex);
connectionState.CloseException = ex; connectionState.CloseException = ex;
} }
// Clear the connectionState field // Clear the connectionState field
await WaitConnectionLockAsync(); await WaitConnectionLockAsync();
try try

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
namespace Microsoft.AspNetCore.SignalR.Internal.Formatters namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
{ {
@ -9,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
{ {
private const int MaxLengthPrefixSize = 5; private const int MaxLengthPrefixSize = 5;
public static bool TryParseMessage(ref ReadOnlyMemory<byte> buffer, out ReadOnlyMemory<byte> payload) public static bool TryParseMessage(ref ReadOnlySequence<byte> buffer, out ReadOnlySequence<byte> payload)
{ {
if (buffer.IsEmpty) if (buffer.IsEmpty)
{ {
@ -33,7 +34,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
var numBytes = 0; var numBytes = 0;
var lengthPrefixBuffer = buffer.Slice(0, Math.Min(MaxLengthPrefixSize, buffer.Length)); var lengthPrefixBuffer = buffer.Slice(0, Math.Min(MaxLengthPrefixSize, buffer.Length));
var span = lengthPrefixBuffer.Span; var span = GetSpan(lengthPrefixBuffer);
byte byteRead; byte byteRead;
do do
@ -70,5 +71,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
buffer = buffer.Slice(numBytes + (int)length); buffer = buffer.Slice(numBytes + (int)length);
return true; return true;
} }
private static ReadOnlySpan<byte> GetSpan(in ReadOnlySequence<byte> lengthPrefixBuffer)
{
if (lengthPrefixBuffer.IsSingleSegment)
{
return lengthPrefixBuffer.First.Span;
}
// Should be rare
return lengthPrefixBuffer.ToArray();
}
} }
} }

View File

@ -24,22 +24,5 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
return true; return true;
} }
public static bool TryParseMessage(ref ReadOnlyMemory<byte> buffer, out ReadOnlyMemory<byte> payload)
{
var index = buffer.Span.IndexOf(TextMessageFormatter.RecordSeparator);
if (index == -1)
{
payload = default;
return false;
}
payload = buffer.Slice(0, index);
// Skip record separator
buffer = buffer.Slice(index + 1);
return true;
}
} }
} }

View File

@ -56,8 +56,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return new JsonTextWriter(new StreamWriter(output, _utf8NoBom, 1024, leaveOpen: true)); return new JsonTextWriter(new StreamWriter(output, _utf8NoBom, 1024, leaveOpen: true));
} }
public static HandshakeResponseMessage ParseResponseMessage(ReadOnlyMemory<byte> payload) public static bool TryParseResponseMessage(ref ReadOnlySequence<byte> buffer, out HandshakeResponseMessage responseMessage)
{ {
if (!TextMessageParser.TryParseMessage(ref buffer, out var payload))
{
responseMessage = null;
return false;
}
var textReader = Utf8BufferTextReader.Get(payload); var textReader = Utf8BufferTextReader.Get(payload);
try try
@ -76,7 +82,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
} }
var error = JsonUtils.GetOptionalProperty<string>(handshakeJObject, ErrorPropertyName); var error = JsonUtils.GetOptionalProperty<string>(handshakeJObject, ErrorPropertyName);
return new HandshakeResponseMessage(error); responseMessage = new HandshakeResponseMessage(error);
return true;
} }
} }
finally finally
@ -85,19 +92,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
} }
} }
public static bool TryParseRequestMessage(ReadOnlySequence<byte> buffer, out HandshakeRequestMessage requestMessage, out SequencePosition consumed, out SequencePosition examined) public static bool TryParseRequestMessage(ref ReadOnlySequence<byte> buffer, out HandshakeRequestMessage requestMessage)
{ {
if (!TryReadMessageIntoSingleMemory(buffer, out consumed, out examined, out var memory)) if (!TextMessageParser.TryParseMessage(ref buffer, out var payload))
{ {
requestMessage = null; requestMessage = null;
return false; return false;
} }
if (!TextMessageParser.TryParseMessage(ref memory, out var payload))
{
throw new InvalidDataException("Unable to parse payload as a handshake request message.");
}
var textReader = Utf8BufferTextReader.Get(payload); var textReader = Utf8BufferTextReader.Get(payload);
try try
{ {
@ -117,23 +119,5 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return true; return true;
} }
internal static bool TryReadMessageIntoSingleMemory(ReadOnlySequence<byte> buffer, out SequencePosition consumed, out SequencePosition examined, out ReadOnlyMemory<byte> memory)
{
var separator = buffer.PositionOf(TextMessageFormatter.RecordSeparator);
if (separator == null)
{
// Haven't seen the entire message so bail
consumed = buffer.Start;
examined = buffer.End;
memory = null;
return false;
}
consumed = buffer.GetPosition(1, separator.Value);
examined = consumed;
memory = buffer.IsSingleSegment ? buffer.First : buffer.ToArray();
return true;
}
} }
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
@ -16,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
TransferFormat TransferFormat { get; } TransferFormat TransferFormat { get; }
bool TryParseMessages(ReadOnlyMemory<byte> input, IInvocationBinder binder, IList<HubMessage> messages); bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message);
void WriteMessage(HubMessage message, Stream output); void WriteMessage(HubMessage message, Stream output);

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Runtime.ExceptionServices; using System.Runtime.ExceptionServices;
@ -54,27 +55,26 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return version == Version; return version == Version;
} }
public bool TryParseMessages(ReadOnlyMemory<byte> input, IInvocationBinder binder, IList<HubMessage> messages) public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{ {
while (TextMessageParser.TryParseMessage(ref input, out var payload)) if (!TextMessageParser.TryParseMessage(ref input, out var payload))
{ {
var textReader = Utf8BufferTextReader.Get(payload); message = null;
return false;
try
{
var message = ParseMessage(textReader, binder);
if (message != null)
{
messages.Add(message);
}
}
finally
{
Utf8BufferTextReader.Return(textReader);
}
} }
return messages.Count > 0; var textReader = Utf8BufferTextReader.Get(payload);
try
{
message = ParseMessage(textReader, binder);
}
finally
{
Utf8BufferTextReader.Return(textReader);
}
return message != null;
} }
public void WriteMessage(HubMessage message, Stream output) public void WriteMessage(HubMessage message, Stream output)

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.IO; using System.IO;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
@ -10,7 +11,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{ {
internal class Utf8BufferTextReader : TextReader internal class Utf8BufferTextReader : TextReader
{ {
private ReadOnlyMemory<byte> _utf8Buffer; private ReadOnlySequence<byte> _utf8Buffer;
private Decoder _decoder; private Decoder _decoder;
[ThreadStatic] [ThreadStatic]
@ -25,7 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
_decoder = Encoding.UTF8.GetDecoder(); _decoder = Encoding.UTF8.GetDecoder();
} }
public static Utf8BufferTextReader Get(ReadOnlyMemory<byte> utf8Buffer) public static Utf8BufferTextReader Get(in ReadOnlySequence<byte> utf8Buffer)
{ {
var reader = _cachedInstance; var reader = _cachedInstance;
if (reader == null) if (reader == null)
@ -55,7 +56,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
#endif #endif
} }
public void SetBuffer(ReadOnlyMemory<byte> utf8Buffer) public void SetBuffer(in ReadOnlySequence<byte> utf8Buffer)
{ {
_utf8Buffer = utf8Buffer; _utf8Buffer = utf8Buffer;
_decoder.Reset(); _decoder.Reset();
@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return 0; return 0;
} }
var source = _utf8Buffer.Span; var source = _utf8Buffer.First.Span;
var bytesUsed = 0; var bytesUsed = 0;
var charsUsed = 0; var charsUsed = 0;
#if NETCOREAPP2_1 #if NETCOREAPP2_1

View File

@ -221,15 +221,19 @@ namespace Microsoft.AspNetCore.SignalR
{ {
var result = await _connectionContext.Transport.Input.ReadAsync(cts.Token); var result = await _connectionContext.Transport.Input.ReadAsync(cts.Token);
var buffer = result.Buffer; var buffer = result.Buffer;
var consumed = buffer.End; var consumed = buffer.Start;
var examined = buffer.End; var examined = buffer.End;
try try
{ {
if (!buffer.IsEmpty) if (!buffer.IsEmpty)
{ {
if (HandshakeProtocol.TryParseRequestMessage(buffer, out var handshakeRequestMessage, out consumed, out examined)) if (HandshakeProtocol.TryParseRequestMessage(ref buffer, out var handshakeRequestMessage))
{ {
// We parsed the handshake
consumed = buffer.Start;
examined = consumed;
Protocol = protocolResolver.GetProtocol(handshakeRequestMessage.Protocol, supportedProtocols); Protocol = protocolResolver.GetProtocol(handshakeRequestMessage.Protocol, supportedProtocols);
if (Protocol == null) if (Protocol == null)
{ {
@ -277,6 +281,10 @@ namespace Microsoft.AspNetCore.SignalR
await WriteHandshakeResponseAsync(HandshakeResponseMessage.Empty); await WriteHandshakeResponseAsync(HandshakeResponseMessage.Empty);
return true; return true;
} }
else
{
_logger.LogInformation("Didn't parse the handshake");
}
} }
else if (result.IsCompleted) else if (result.IsCompleted)
{ {

View File

@ -164,24 +164,16 @@ namespace Microsoft.AspNetCore.SignalR
{ {
var result = await connection.Input.ReadAsync(connection.ConnectionAborted); var result = await connection.Input.ReadAsync(connection.ConnectionAborted);
var buffer = result.Buffer; var buffer = result.Buffer;
var consumed = buffer.End;
var examined = buffer.End;
try try
{ {
if (!buffer.IsEmpty) if (!buffer.IsEmpty)
{ {
var hubMessages = new List<HubMessage>(); while (connection.Protocol.TryParseMessage(ref buffer, _dispatcher, out var message))
// TODO: Make this incremental
if (connection.Protocol.TryParseMessages(buffer.ToArray(), _dispatcher, hubMessages))
{ {
foreach (var hubMessage in hubMessages) // Don't wait on the result of execution, continue processing other
{ // incoming messages on this connection.
// Don't wait on the result of execution, continue processing other _ = _dispatcher.DispatchMessageAsync(connection, message);
// incoming messages on this connection.
_ = _dispatcher.DispatchMessageAsync(connection, hubMessage);
}
} }
} }
else if (result.IsCompleted) else if (result.IsCompleted)
@ -191,7 +183,10 @@ namespace Microsoft.AspNetCore.SignalR
} }
finally finally
{ {
connection.Input.AdvanceTo(consumed, examined); // The buffer was sliced up to where it was consumed, so we can just advance to the start.
// We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data
// before yielding the read again.
connection.Input.AdvanceTo(buffer.Start, buffer.End);
} }
} }
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
@ -46,21 +47,33 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return version == Version; return version == Version;
} }
public bool TryParseMessages(ReadOnlyMemory<byte> input, IInvocationBinder binder, IList<HubMessage> messages) public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{ {
while (BinaryMessageParser.TryParseMessage(ref input, out var payload)) if (!BinaryMessageParser.TryParseMessage(ref input, out var payload))
{ {
var isArray = MemoryMarshal.TryGetArray(payload, out var arraySegment); message = null;
// This will never be false unless we started using un-managed buffers return false;
Debug.Assert(isArray);
var message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder);
if (message != null)
{
messages.Add(message);
}
} }
return messages.Count > 0; var arraySegment = GetArraySegment(payload);
message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder);
return message != null;
}
private static ArraySegment<byte> GetArraySegment(ReadOnlySequence<byte> input)
{
if (input.IsSingleSegment)
{
var isArray = MemoryMarshal.TryGetArray(input.First, out var arraySegment);
// This will never be false unless we started using un-managed buffers
Debug.Assert(isArray);
return arraySegment;
}
// Should be rare
return new ArraySegment<byte>(input.ToArray());
} }
private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder) private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder)

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -144,7 +145,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
return true; return true;
} }
public bool TryParseMessages(ReadOnlyMemory<byte> input, IInvocationBinder binder, IList<HubMessage> messages) public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{ {
if (_error != null) if (_error != null)
{ {
@ -152,7 +153,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
} }
if (_parsed != null) if (_parsed != null)
{ {
messages.Add(_parsed); message = _parsed;
return true; return true;
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
@ -113,7 +114,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters
{ {
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, ms); BinaryMessageFormatter.WriteLengthPrefix(payload.Length, ms);
ms.Write(payload, 0, payload.Length); ms.Write(payload, 0, payload.Length);
var buffer = new ReadOnlyMemory<byte>(ms.ToArray()); var buffer = new ReadOnlySequence<byte>(ms.ToArray());
Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped)); Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped));
Assert.Equal(payload, roundtripped.ToArray()); Assert.Equal(payload, roundtripped.ToArray());
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Formatters;
@ -17,7 +18,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[InlineData(new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")] [InlineData(new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")]
public void ReadMessage(byte[] encoded, string payload) public void ReadMessage(byte[] encoded, string payload)
{ {
ReadOnlyMemory<byte> span = encoded; var span = new ReadOnlySequence<byte>(encoded);
Assert.True(BinaryMessageParser.TryParseMessage(ref span, out var message)); Assert.True(BinaryMessageParser.TryParseMessage(ref span, out var message));
Assert.Equal(0, span.Length); Assert.Equal(0, span.Length);
@ -52,7 +53,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
})] })]
public void ReadBinaryMessage(byte[] encoded, byte[] payload) public void ReadBinaryMessage(byte[] encoded, byte[] payload)
{ {
ReadOnlyMemory< byte> span = encoded; var span = new ReadOnlySequence<byte>(encoded);
Assert.True(BinaryMessageParser.TryParseMessage(ref span, out var message)); Assert.True(BinaryMessageParser.TryParseMessage(ref span, out var message));
Assert.Equal(0, span.Length); Assert.Equal(0, span.Length);
Assert.Equal(payload, message.ToArray()); Assert.Equal(payload, message.ToArray());
@ -66,7 +67,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
{ {
var ex = Assert.Throws<FormatException>(() => var ex = Assert.Throws<FormatException>(() =>
{ {
var buffer = new ReadOnlyMemory<byte>(payload); var buffer = new ReadOnlySequence<byte>(payload);;
BinaryMessageParser.TryParseMessage(ref buffer, out var message); BinaryMessageParser.TryParseMessage(ref buffer, out var message);
}); });
Assert.Equal("Messages over 2GB in size are not supported.", ex.Message); Assert.Equal("Messages over 2GB in size are not supported.", ex.Message);
@ -79,7 +80,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[InlineData(new byte[] { 0x80 })] // size is cut [InlineData(new byte[] { 0x80 })] // size is cut
public void BinaryMessageParserReturnsFalseForPartialPayloads(byte[] payload) public void BinaryMessageParserReturnsFalseForPartialPayloads(byte[] payload)
{ {
var buffer = new ReadOnlyMemory<byte>(payload); var buffer = new ReadOnlySequence<byte>(payload);
Assert.False(BinaryMessageParser.TryParseMessage(ref buffer, out var message)); Assert.False(BinaryMessageParser.TryParseMessage(ref buffer, out var message));
} }
@ -94,7 +95,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
/* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, /* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21,
}; };
ReadOnlyMemory<byte> buffer = encoded; var buffer = new ReadOnlySequence<byte>(encoded);
var messages = new List<byte[]>(); var messages = new List<byte[]>();
while (BinaryMessageParser.TryParseMessage(ref buffer, out var message)) while (BinaryMessageParser.TryParseMessage(ref buffer, out var message))
{ {
@ -113,7 +114,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[InlineData(new byte[] { 0x09, 0x00, 0x00 })] // Not enough data for payload [InlineData(new byte[] { 0x09, 0x00, 0x00 })] // Not enough data for payload
public void ReadIncompleteMessages(byte[] encoded) public void ReadIncompleteMessages(byte[] encoded)
{ {
ReadOnlyMemory<byte> buffer = encoded; var buffer = new ReadOnlySequence<byte>(encoded);
Assert.False(BinaryMessageParser.TryParseMessage(ref buffer, out var message)); Assert.False(BinaryMessageParser.TryParseMessage(ref buffer, out var message));
Assert.Equal(encoded.Length, buffer.Length); Assert.Equal(encoded.Length, buffer.Length);
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Text; using System.Text;
using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Xunit; using Xunit;
@ -13,7 +14,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[Fact] [Fact]
public void ReadMessage() public void ReadMessage()
{ {
var message = new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes("ABC\u001e")); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("ABC\u001e"));
Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload));
Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray())); Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray()));
@ -23,14 +24,14 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[Fact] [Fact]
public void TryReadingIncompleteMessage() public void TryReadingIncompleteMessage()
{ {
var message = new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes("ABC")); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("ABC"));
Assert.False(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.False(TextMessageParser.TryParseMessage(ref message, out var payload));
} }
[Fact] [Fact]
public void TryReadingMultipleMessages() public void TryReadingMultipleMessages()
{ {
var message = new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e")); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e"));
Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload));
Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray())); Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray()));
Assert.True(TextMessageParser.TryParseMessage(ref message, out payload)); Assert.True(TextMessageParser.TryParseMessage(ref message, out payload));
@ -40,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[Fact] [Fact]
public void IncompleteTrailingMessage() public void IncompleteTrailingMessage()
{ {
var message = new ReadOnlyMemory<byte>(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e123")); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("ABC\u001eXYZ\u001e123"));
Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload)); Assert.True(TextMessageParser.TryParseMessage(ref message, out var payload));
Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray())); Assert.Equal("ABC", Encoding.UTF8.GetString(payload.ToArray()));
Assert.True(TextMessageParser.TryParseMessage(ref message, out payload)); Assert.True(TextMessageParser.TryParseMessage(ref message, out payload));

View File

@ -18,9 +18,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("{\"protocol\":null,\"version\":123}\u001e", null, 123)] [InlineData("{\"protocol\":null,\"version\":123}\u001e", null, 123)]
public void ParsingHandshakeRequestMessageSuccessForValidMessages(string json, string protocol, int version) public void ParsingHandshakeRequestMessageSuccessForValidMessages(string json, string protocol, int version)
{ {
var message = Encoding.UTF8.GetBytes(json); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(json));
Assert.True(HandshakeProtocol.TryParseRequestMessage(new ReadOnlySequence<byte>(message), out var deserializedMessage, out _, out _)); Assert.True(HandshakeProtocol.TryParseRequestMessage(ref message, out var deserializedMessage));
Assert.Equal(protocol, deserializedMessage.Protocol); Assert.Equal(protocol, deserializedMessage.Protocol);
Assert.Equal(version, deserializedMessage.Version); Assert.Equal(version, deserializedMessage.Version);
@ -33,19 +33,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("{}\u001e", null)] [InlineData("{}\u001e", null)]
public void ParsingHandshakeResponseMessageSuccessForValidMessages(string json, string error) public void ParsingHandshakeResponseMessageSuccessForValidMessages(string json, string error)
{ {
var message = Encoding.UTF8.GetBytes(json); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(json));
var response = HandshakeProtocol.ParseResponseMessage(message);
Assert.True(HandshakeProtocol.TryParseResponseMessage(ref message, out var response));
Assert.Equal(error, response.Error); Assert.Equal(error, response.Error);
} }
[Fact] [Fact]
public void ParsingHandshakeRequestNotCompleteReturnsFalse() public void ParsingHandshakeRequestNotCompleteReturnsFalse()
{ {
var message = Encoding.UTF8.GetBytes("42"); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("42"));
Assert.False(HandshakeProtocol.TryParseRequestMessage(new ReadOnlySequence<byte>(message), out _, out _, out _)); Assert.False(HandshakeProtocol.TryParseRequestMessage(ref message, out _));
} }
[Theory] [Theory]
@ -59,25 +58,25 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("{\"protocol\":null,\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")] [InlineData("{\"protocol\":null,\"version\":\"123\"}\u001e", "Expected 'version' to be of type Integer.")]
public void ParsingHandshakeRequestMessageThrowsForInvalidMessages(string payload, string expectedMessage) public void ParsingHandshakeRequestMessageThrowsForInvalidMessages(string payload, string expectedMessage)
{ {
var message = Encoding.UTF8.GetBytes(payload); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(payload));
var exception = Assert.Throws<InvalidDataException>(() => var exception = Assert.Throws<InvalidDataException>(() =>
Assert.True(HandshakeProtocol.TryParseRequestMessage(new ReadOnlySequence<byte>(message), out _, out _, out _))); Assert.True(HandshakeProtocol.TryParseRequestMessage(ref message, out _)));
Assert.Equal(expectedMessage, exception.Message); Assert.Equal(expectedMessage, exception.Message);
} }
[Theory] [Theory]
[InlineData("42", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")] [InlineData("42\u001e", "Unexpected JSON Token Type 'Integer'. Expected a JSON Object.")]
[InlineData("\"42\"", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")] [InlineData("\"42\"\u001e", "Unexpected JSON Token Type 'String'. Expected a JSON Object.")]
[InlineData("null", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")] [InlineData("null\u001e", "Unexpected JSON Token Type 'Null'. Expected a JSON Object.")]
[InlineData("[]", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")] [InlineData("[]\u001e", "Unexpected JSON Token Type 'Array'. Expected a JSON Object.")]
public void ParsingHandshakeResponseMessageThrowsForInvalidMessages(string payload, string expectedMessage) public void ParsingHandshakeResponseMessageThrowsForInvalidMessages(string payload, string expectedMessage)
{ {
var message = Encoding.UTF8.GetBytes(payload); var message = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(payload));
var exception = Assert.Throws<InvalidDataException>(() => var exception = Assert.Throws<InvalidDataException>(() =>
HandshakeProtocol.ParseResponseMessage(message)); HandshakeProtocol.TryParseRequestMessage(ref message, out _));
Assert.Equal(expectedMessage, exception.Message); Assert.Equal(expectedMessage, exception.Message);
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Text; using System.Text;
@ -133,10 +134,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(expectedMessage); var binder = new TestBinder(expectedMessage);
var protocol = new JsonHubProtocol(Options.Create(protocolOptions)); var protocol = new JsonHubProtocol(Options.Create(protocolOptions));
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(input));
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages); protocol.TryParseMessage(ref data, binder, out var message);
Assert.Equal(expectedMessage, messages[0], TestHubMessageEqualityComparer.Instance); Assert.Equal(expectedMessage, message, TestHubMessageEqualityComparer.Instance);
} }
[Theory] [Theory]
@ -183,8 +184,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(Array.Empty<Type>(), typeof(object)); var binder = new TestBinder(Array.Empty<Type>(), typeof(object));
var protocol = new JsonHubProtocol(); var protocol = new JsonHubProtocol();
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(input));
var ex = Assert.Throws<InvalidDataException>(() => protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages)); var ex = Assert.Throws<InvalidDataException>(() => protocol.TryParseMessage(ref data, binder, out var _));
Assert.Equal(expectedMessage, ex.Message); Assert.Equal(expectedMessage, ex.Message);
} }
@ -196,10 +197,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(expectedMessage); var binder = new TestBinder(expectedMessage);
var protocol = new JsonHubProtocol(); var protocol = new JsonHubProtocol();
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(input));
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages); protocol.TryParseMessage(ref data, binder, out var message);
Assert.Equal(expectedMessage, messages[0], TestHubMessageEqualityComparer.Instance); Assert.Equal(expectedMessage, message, TestHubMessageEqualityComparer.Instance);
} }
[Theory] [Theory]
@ -210,9 +211,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool)); var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
var protocol = new JsonHubProtocol(); var protocol = new JsonHubProtocol();
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(input));
Assert.True(protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages)); Assert.True(protocol.TryParseMessage(ref data, binder, out var message));
Assert.Single(messages); Assert.NotNull(message);
} }
[Theory] [Theory]
@ -228,9 +229,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool)); var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
var protocol = new JsonHubProtocol(); var protocol = new JsonHubProtocol();
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(input));
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages); protocol.TryParseMessage(ref data, binder, out var message);
var ex = Assert.Throws<InvalidDataException>(() => ((HubMethodInvocationMessage)messages[0]).Arguments); var ex = Assert.Throws<InvalidDataException>(() => ((HubMethodInvocationMessage)message).Arguments);
Assert.Equal(expectedMessage, ex.Message); Assert.Equal(expectedMessage, ex.Message);
} }

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
@ -286,11 +287,11 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
// Parse the input fully now. // Parse the input fully now.
bytes = Frame(bytes); bytes = Frame(bytes);
var protocol = new MessagePackHubProtocol(); var protocol = new MessagePackHubProtocol();
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(bytes);
Assert.True(protocol.TryParseMessages(bytes, new TestBinder(testData.Message), messages)); Assert.True(protocol.TryParseMessage(ref data, new TestBinder(testData.Message), out var message));
Assert.Single(messages); Assert.NotNull(message);
Assert.Equal(testData.Message, messages[0], TestHubMessageEqualityComparer.Instance); Assert.Equal(testData.Message, message, TestHubMessageEqualityComparer.Instance);
} }
[Fact] [Fact]
@ -308,11 +309,11 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
// Parse the input fully now. // Parse the input fully now.
bytes = Frame(bytes); bytes = Frame(bytes);
var protocol = new MessagePackHubProtocol(); var protocol = new MessagePackHubProtocol();
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(bytes);
Assert.True(protocol.TryParseMessages(bytes, new TestBinder(expectedMessage), messages)); Assert.True(protocol.TryParseMessage(ref data, new TestBinder(expectedMessage), out var message));
Assert.Single(messages); Assert.NotNull(message);
Assert.Equal(expectedMessage, messages[0], TestHubMessageEqualityComparer.Instance); Assert.Equal(expectedMessage, message, TestHubMessageEqualityComparer.Instance);
} }
[Theory] [Theory]
@ -325,7 +326,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
AssertMessages(testData.Encoded, bytes); AssertMessages(testData.Encoded, bytes);
// Unframe the message to check the binary encoding // Unframe the message to check the binary encoding
ReadOnlyMemory<byte> byteSpan = bytes; var byteSpan = new ReadOnlySequence<byte>(bytes);
Assert.True(BinaryMessageParser.TryParseMessage(ref byteSpan, out var unframed)); Assert.True(BinaryMessageParser.TryParseMessage(ref byteSpan, out var unframed));
// Check the baseline binary encoding, use Assert.True in order to configure the error message // Check the baseline binary encoding, use Assert.True in order to configure the error message
@ -380,8 +381,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{ {
var buffer = Frame(Pack(testData.Encoded)); var buffer = Frame(Pack(testData.Encoded));
var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(buffer);
var exception = Assert.Throws<FormatException>(() => _hubProtocol.TryParseMessages(buffer, binder, messages)); var exception = Assert.Throws<FormatException>(() => _hubProtocol.TryParseMessage(ref data, binder, out _));
Assert.Equal(testData.ErrorMessage, exception.Message); Assert.Equal(testData.ErrorMessage, exception.Message);
} }
@ -409,22 +410,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{ {
var buffer = Frame(Pack(testData.Encoded)); var buffer = Frame(Pack(testData.Encoded));
var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(buffer);
_hubProtocol.TryParseMessages(buffer, binder, messages); _hubProtocol.TryParseMessage(ref data, binder, out var message);
var exception = Assert.Throws<FormatException>(() => ((HubMethodInvocationMessage)messages[0]).Arguments); var exception = Assert.Throws<FormatException>(() => ((HubMethodInvocationMessage)message).Arguments);
Assert.Equal(testData.ErrorMessage, exception.Message); Assert.Equal(testData.ErrorMessage, exception.Message);
} }
[Theory] [Theory]
[InlineData(new object[] { new byte[] { 0x05, 0x01 }, 0 })] [InlineData(new byte[] { 0x05, 0x01 })]
public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount) public void ParserDoesNotConsumePartialData(byte[] payload)
{ {
var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
var messages = new List<HubMessage>(); var data = new ReadOnlySequence<byte>(payload);
var result = _hubProtocol.TryParseMessages(payload, binder, messages); var result = _hubProtocol.TryParseMessage(ref data, binder, out var message);
Assert.True(result || messages.Count == 0); Assert.Null(message);
Assert.Equal(expectedMessagesCount, messages.Count);
} }
[Fact] [Fact]
@ -434,9 +434,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
AssertMessages(Array(HubProtocolConstants.CompletionMessageType, Map(), "0", 3, Array(42)), result); AssertMessages(Array(HubProtocolConstants.CompletionMessageType, Map(), "0", 3, Array(42)), result);
} }
private static void AssertMessages(MessagePackObject expectedOutput, ReadOnlyMemory<byte> bytes) private static void AssertMessages(MessagePackObject expectedOutput, byte[] bytes)
{ {
Assert.True(BinaryMessageParser.TryParseMessage(ref bytes, out var message)); var data = new ReadOnlySequence<byte>(bytes);
Assert.True(BinaryMessageParser.TryParseMessage(ref data, out var message));
var obj = Unpack(message.ToArray()); var obj = Unpack(message.ToArray());
Assert.Equal(expectedOutput, obj); Assert.Equal(expectedOutput, obj);
} }

View File

@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void ReadingWhenCharBufferBigEnough() public void ReadingWhenCharBufferBigEnough()
{ {
var buffer = Encoding.UTF8.GetBytes("Hello World"); var buffer = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("Hello World"));
var reader = new Utf8BufferTextReader(); var reader = new Utf8BufferTextReader();
reader.SetBuffer(buffer); reader.SetBuffer(buffer);
@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void ReadingUnicodeWhenCharBufferBigEnough() public void ReadingUnicodeWhenCharBufferBigEnough()
{ {
var buffer = Encoding.UTF8.GetBytes("a\u00E4\u00E4\u00a9o"); var buffer = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("a\u00E4\u00E4\u00a9o"));
var reader = new Utf8BufferTextReader(); var reader = new Utf8BufferTextReader();
reader.SetBuffer(buffer); reader.SetBuffer(buffer);
@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void ReadingWhenCharBufferBigEnoughAndNotStartingFromZero() public void ReadingWhenCharBufferBigEnoughAndNotStartingFromZero()
{ {
var buffer = Encoding.UTF8.GetBytes("Hello World"); var buffer = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("Hello World"));
var reader = new Utf8BufferTextReader(); var reader = new Utf8BufferTextReader();
reader.SetBuffer(buffer); reader.SetBuffer(buffer);
@ -60,7 +60,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void ReadingWhenBufferTooSmall() public void ReadingWhenBufferTooSmall()
{ {
var buffer = Encoding.UTF8.GetBytes("Hello World"); var buffer = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("Hello World"));
var reader = new Utf8BufferTextReader(); var reader = new Utf8BufferTextReader();
reader.SetBuffer(buffer); reader.SetBuffer(buffer);
@ -92,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void ReadingUnicodeWhenBufferTooSmall() public void ReadingUnicodeWhenBufferTooSmall()
{ {
var buffer = Encoding.UTF8.GetBytes("\u00E4\u00E4\u00E5"); var buffer = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes("\u00E4\u00E4\u00E5"));
var reader = new Utf8BufferTextReader(); var reader = new Utf8BufferTextReader();
reader.SetBuffer(buffer); reader.SetBuffer(buffer);

View File

@ -27,7 +27,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
private readonly IHubProtocol _protocol; private readonly IHubProtocol _protocol;
private readonly IInvocationBinder _invocationBinder; private readonly IInvocationBinder _invocationBinder;
private readonly CancellationTokenSource _cts; private readonly CancellationTokenSource _cts;
private readonly Queue<HubMessage> _messages = new Queue<HubMessage>();
public DefaultConnectionContext Connection { get; } public DefaultConnectionContext Connection { get; }
public Task Connected => ((TaskCompletionSource<bool>)Connection.Items["ConnectedTask"]).Task; public Task Connected => ((TaskCompletionSource<bool>)Connection.Items["ConnectedTask"]).Task;
@ -84,7 +83,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
// note that the handshake response might not immediately be readable // note that the handshake response might not immediately be readable
// e.g. server is waiting for request, times out after configured duration, // e.g. server is waiting for request, times out after configured duration,
// and sends response with timeout error // and sends response with timeout error
HandshakeResponseMessage = (HandshakeResponseMessage) await ReadAsync(true).OrTimeout(); HandshakeResponseMessage = (HandshakeResponseMessage)await ReadAsync(true).OrTimeout();
} }
return connection; return connection;
@ -220,51 +219,35 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public HubMessage TryRead(bool isHandshake = false) public HubMessage TryRead(bool isHandshake = false)
{ {
if (_messages.Count > 0)
{
return _messages.Dequeue();
}
if (!Connection.Application.Input.TryRead(out var result)) if (!Connection.Application.Input.TryRead(out var result))
{ {
return null; return null;
} }
var buffer = result.Buffer; var buffer = result.Buffer;
var consumed = buffer.End;
var examined = consumed;
try try
{ {
if (!isHandshake) if (!isHandshake)
{ {
var messages = new List<HubMessage>(); if (_protocol.TryParseMessage(ref buffer, _invocationBinder, out var message))
if (_protocol.TryParseMessages(result.Buffer.ToArray(), _invocationBinder, messages))
{ {
foreach (var m in messages) return message;
{
_messages.Enqueue(m);
}
return _messages.Dequeue();
} }
} }
else else
{ {
HandshakeProtocol.TryReadMessageIntoSingleMemory(buffer, out consumed, out examined, out var data); // read first message out of the incoming data
if (!HandshakeProtocol.TryParseResponseMessage(ref buffer, out var responseMessage))
// read first message out of the incoming data
if (!TextMessageParser.TryParseMessage(ref data, out var payload))
{ {
throw new InvalidDataException("Unable to parse payload as a handshake response message."); throw new InvalidDataException("Unable to parse payload as a handshake response message.");
} }
return responseMessage;
return HandshakeProtocol.ParseResponseMessage(payload);
} }
} }
finally finally
{ {
Connection.Application.Input.AdvanceTo(consumed, examined); Connection.Application.Input.AdvanceTo(buffer.Start);
} }
return null; return null;