diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs index 3d4c99a1a0..c8c36a25e3 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs @@ -1,18 +1,18 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using System; +using System.Buffers; using System.Collections.Generic; -using System.IO; using System.Net; using System.Net.Http; using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets.Formatters; -using Microsoft.Extensions.Internal; -using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.AspNetCore.Sockets.Client { @@ -26,6 +26,8 @@ namespace Microsoft.AspNetCore.Sockets.Client private IChannelConnection _application; private Task _sender; private Task _poller; + private MessageParser _parser = new MessageParser(); + private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); public Task Running { get; private set; } = Task.CompletedTask; @@ -101,18 +103,22 @@ namespace Microsoft.AspNetCore.Sockets.Client } else { - _logger.LogDebug("Receive a message from the server"); + _logger.LogDebug("Received messages from the server"); - // Read the whole payload + // Until Pipeline starts natively supporting BytesReader, this is the easiest way to do this. var payload = await response.Content.ReadAsByteArrayAsync(); - - foreach (var message in ReadMessages(payload)) + if (payload.Length > 0) { - while (!_application.Output.TryWrite(message)) + var messages = ParsePayload(payload); + + foreach (var message in messages) { - if (cancellationToken.IsCancellationRequested || !await _application.Output.WaitToWriteAsync(cancellationToken)) + while (!_application.Output.TryWrite(message)) { - return; + if (cancellationToken.IsCancellationRequested || !await _application.Output.WaitToWriteAsync(cancellationToken)) + { + return; + } } } } @@ -137,26 +143,29 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.LogInformation("Receive loop stopped"); } - private IEnumerable ReadMessages(ReadOnlySpan payload) + private IList ParsePayload(byte[] payload) { - if (payload.Length == 0) + var reader = new BytesReader(payload); + var messageFormat = MessageParser.GetFormat(reader.Unread[0]); + reader.Advance(1); + + _parser.Reset(); + var messages = new List(); + while (_parser.TryParseMessage(ref reader, messageFormat, out var message)) { - yield break; + messages.Add(message); } - var messageFormat = MessageFormatter.GetFormat(payload[0]); - payload = payload.Slice(1); + // Since we pre-read the whole payload, we know that when this fails we have read everything. + // Once Pipelines natively support BytesReader, we could get into situations where the data for + // a message just isn't available yet. - while (payload.Length > 0) + // If there's still data, we hit an incomplete message + if (reader.Unread.Length > 0) { - if (!MessageFormatter.TryParseMessage(payload, messageFormat, out var message, out var consumed)) - { - throw new InvalidDataException("Invalid message payload from server"); - } - - payload = payload.Slice(consumed); - yield return message; + throw new FormatException("Incomplete message"); } + return messages; } private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken) diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/BinaryMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/BinaryMessageFormatter.cs deleted file mode 100644 index 50f6dbe861..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/BinaryMessageFormatter.cs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Binary; -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Sockets.Formatters -{ - internal static class BinaryMessageFormatter - { - private const byte TextTypeFlag = 0x00; - private const byte BinaryTypeFlag = 0x01; - private const byte ErrorTypeFlag = 0x02; - private const byte CloseTypeFlag = 0x03; - - internal static bool TryFormatMessage(Message message, Span buffer, out int bytesWritten) - { - // We can check the size needed right up front! - var sizeNeeded = sizeof(long) + 1 + message.Payload.Length; - if (buffer.Length < sizeNeeded) - { - bytesWritten = 0; - return false; - } - - buffer.WriteBigEndian((long)message.Payload.Length); - if (!TryFormatType(message.Type, buffer.Slice(sizeof(long), 1))) - { - bytesWritten = 0; - return false; - } - - buffer = buffer.Slice(sizeof(long) + 1); - - message.Payload.CopyTo(buffer); - bytesWritten = sizeNeeded; - return true; - } - - internal static bool TryParseMessage(ReadOnlySpan buffer, out Message message, out int bytesConsumed) - { - // Check if we have enough to read the size and type flag - if (buffer.Length < sizeof(long) + 1) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // REVIEW: The spec calls for 64-bit length but I'm thinking that's a little ridiculous. - // REVIEW: We don't really have a primitive for storing that much data. For now, I'm using it - // REVIEW: but throwing if the size is over 2GB. - var longLength = buffer.ReadBigEndian(); - if (longLength > Int32.MaxValue) - { - throw new FormatException("Messages over 2GB in size are not supported"); - } - var length = (int)longLength; - - if (!TryParseType(buffer[sizeof(long)], out var messageType)) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // Check if we actually have the whole payload - if (buffer.Length < sizeof(long) + 1 + length) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // Copy the payload into the buffer - // REVIEW: Copy! Noooooooooo! But how can we capture a segment of the span as an "Owned" reference? - // REVIEW: If we do have to copy, we should at least use a pooled buffer - var buf = new byte[length]; - buffer.Slice(sizeof(long) + 1, length).CopyTo(buf); - - message = new Message(buf, messageType, endOfMessage: true); - bytesConsumed = sizeof(long) + 1 + length; - return true; - } - - private static bool TryParseType(byte type, out MessageType messageType) - { - switch (type) - { - case TextTypeFlag: - messageType = MessageType.Text; - return true; - case BinaryTypeFlag: - messageType = MessageType.Binary; - return true; - case CloseTypeFlag: - messageType = MessageType.Close; - return true; - case ErrorTypeFlag: - messageType = MessageType.Error; - return true; - default: - messageType = default(MessageType); - return false; - } - } - - private static bool TryFormatType(MessageType type, Span buffer) - { - switch (type) - { - case MessageType.Text: - buffer[0] = TextTypeFlag; - return true; - case MessageType.Binary: - buffer[0] = BinaryTypeFlag; - return true; - case MessageType.Close: - buffer[0] = CloseTypeFlag; - return true; - case MessageType.Error: - buffer[0] = ErrorTypeFlag; - return true; - default: - return false; - } - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/MessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/MessageFormatter.cs deleted file mode 100644 index c8a4ee6229..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/MessageFormatter.cs +++ /dev/null @@ -1,79 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; - -namespace Microsoft.AspNetCore.Sockets.Formatters -{ - public static class MessageFormatter - { - public static readonly byte TextFormatIndicator = (byte)'T'; - public static readonly byte BinaryFormatIndicator = (byte)'B'; - - public static readonly string TextContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text"; - public static readonly string BinaryContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+binary"; - - public static bool TryFormatMessage(Message message, Span buffer, MessageFormat format, out int bytesWritten) - { - if (!message.EndOfMessage) - { - // This is truly an exceptional condition since we EXPECT callers to have already - // buffered incomplete messages and synthesized the correct, complete message before - // giving it to us. Hence we throw, instead of returning false. - throw new ArgumentException("Cannot format message where endOfMessage is false using this format", nameof(message)); - } - - return format == MessageFormat.Text ? - TextMessageFormatter.TryFormatMessage(message, buffer, out bytesWritten) : - BinaryMessageFormatter.TryFormatMessage(message, buffer, out bytesWritten); - } - - public static bool TryParseMessage(ReadOnlySpan buffer, MessageFormat format, out Message message, out int bytesConsumed) - { - return format == MessageFormat.Text ? - TextMessageFormatter.TryParseMessage(buffer, out message, out bytesConsumed) : - BinaryMessageFormatter.TryParseMessage(buffer, out message, out bytesConsumed); - } - - public static string GetContentType(MessageFormat messageFormat) - { - switch (messageFormat) - { - case MessageFormat.Text: return TextContentType; - case MessageFormat.Binary: return BinaryContentType; - default: throw new ArgumentException($"Invalid message format: {messageFormat}", nameof(messageFormat)); - } - } - - public static byte GetFormatIndicator(MessageFormat messageFormat) - { - switch (messageFormat) - { - case MessageFormat.Text: return TextFormatIndicator; - case MessageFormat.Binary: return BinaryFormatIndicator; - default: throw new ArgumentException($"Invalid message format: {messageFormat}", nameof(messageFormat)); - } - } - - public static MessageFormat GetFormat(byte formatIndicator) - { - // Can't use switch because our "constants" are not consts, they're "static readonly" (which is good, because they are public) - if (formatIndicator == TextFormatIndicator) - { - return MessageFormat.Text; - } - - if (formatIndicator == BinaryFormatIndicator) - { - return MessageFormat.Binary; - } - - throw new ArgumentException($"Invalid message format: 0x{formatIndicator:X}", nameof(formatIndicator)); - } - - public static bool TryParseMessage(ReadOnlySpan payload, object messageFormat) - { - throw new NotImplementedException(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs deleted file mode 100644 index 1ab081b281..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Binary; - -namespace Microsoft.AspNetCore.Sockets.Formatters -{ - public static class ServerSentEventsMessageFormatter - { - private static readonly Span DataPrefix = new byte[] { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' }; - private static readonly Span Newline = new byte[] { (byte)'\r', (byte)'\n' }; - - private const byte LineFeed = (byte)'\n'; - private const byte TextTypeFlag = (byte)'T'; - private const byte BinaryTypeFlag = (byte)'B'; - private const byte CloseTypeFlag = (byte)'C'; - private const byte ErrorTypeFlag = (byte)'E'; - - public static bool TryFormatMessage(Message message, Span buffer, out int bytesWritten) - { - if (!message.EndOfMessage) - { - // This is a truely exceptional condition since we EXPECT callers to have already - // buffered incomplete messages and synthesized the correct, complete message before - // giving it to us. Hence we throw, instead of returning false. - throw new InvalidOperationException("Cannot format message where endOfMessage is false using this format"); - } - - // Need at least: Length of 'data: ', one character type, one \r\n, and the trailing \r\n - if (buffer.Length < DataPrefix.Length + 1 + Newline.Length + Newline.Length) - { - bytesWritten = 0; - return false; - } - DataPrefix.CopyTo(buffer); - buffer = buffer.Slice(DataPrefix.Length); - if (!TryFormatType(buffer, message.Type)) - { - bytesWritten = 0; - return false; - } - buffer = buffer.Slice(1); - - Newline.CopyTo(buffer); - buffer = buffer.Slice(Newline.Length); - - // Write the payload - if (!TryFormatPayload(message.Payload, message.Type, buffer, out var writtenForPayload)) - { - bytesWritten = 0; - return false; - } - buffer = buffer.Slice(writtenForPayload); - - if (buffer.Length < Newline.Length) - { - bytesWritten = 0; - return false; - } - Newline.CopyTo(buffer); - - bytesWritten = DataPrefix.Length + Newline.Length + 1 + writtenForPayload + Newline.Length; - return true; - } - - private static bool TryFormatPayload(ReadOnlySpan payload, MessageType type, Span buffer, out int bytesWritten) - { - // Short-cut for empty payload - if (payload.Length == 0) - { - bytesWritten = 0; - return true; - } - - var writtenSoFar = 0; - if (type == MessageType.Binary) - { - var encodedSize = DataPrefix.Length + Base64.ComputeEncodedLength(payload.Length) + Newline.Length; - if (buffer.Length < encodedSize) - { - bytesWritten = 0; - return false; - } - DataPrefix.CopyTo(buffer); - buffer = buffer.Slice(DataPrefix.Length); - - var encodedLength = Base64.Encode(payload, buffer); - buffer = buffer.Slice(encodedLength); - - Newline.CopyTo(buffer); - writtenSoFar += encodedSize; - buffer.Slice(Newline.Length); - } - else - { - // We can't just use while(payload.Length > 0) because we need to write a blank final "data: " line - // if the payload ends in a newline. For example, consider the following payload: - // "Hello\n" - // It needs to be written as: - // data: Hello\r\n - // data: \r\n - // \r\n - // Since we slice past the newline when we find it, after writing "Hello" in the previous example, we'll - // end up with an empty payload buffer, BUT we need to write it as an empty 'data:' line, so we need - // to use a condition that ensure the only time we stop writing is when we write the slice after the final - // newline. - var keepWriting = true; - while (keepWriting) - { - // Seek to the end of buffer or newline - var sliceEnd = payload.IndexOf(LineFeed); - var nextSliceStart = sliceEnd + 1; - if (sliceEnd < 0) - { - sliceEnd = payload.Length; - nextSliceStart = sliceEnd + 1; - - // This is the last span - keepWriting = false; - } - if (sliceEnd > 0 && payload[sliceEnd - 1] == '\r') - { - sliceEnd--; - } - - var slice = payload.Slice(0, sliceEnd); - - if (nextSliceStart >= payload.Length) - { - payload = Span.Empty; - } - else - { - payload = payload.Slice(nextSliceStart); - } - - if (!TryFormatLine(slice, buffer, out var writtenByLine)) - { - bytesWritten = 0; - return false; - } - buffer = buffer.Slice(writtenByLine); - writtenSoFar += writtenByLine; - } - } - - bytesWritten = writtenSoFar; - return true; - } - - private static bool TryFormatLine(ReadOnlySpan line, Span buffer, out int bytesWritten) - { - // We're going to write the whole thing. HOWEVER, if the last byte is a '\r', we want to truncate it - // because it was the '\r' in a '\r\n' newline sequence - // This won't require an additional byte in the buffer because after this line we have to write a newline sequence anyway. - var writtenSoFar = 0; - if (buffer.Length < DataPrefix.Length + line.Length) - { - bytesWritten = 0; - return false; - } - DataPrefix.CopyTo(buffer); - writtenSoFar += DataPrefix.Length; - buffer = buffer.Slice(DataPrefix.Length); - - line.CopyTo(buffer); - var sliceTo = line.Length; - if (sliceTo > 0 && buffer[sliceTo - 1] == '\r') - { - sliceTo -= 1; - } - writtenSoFar += sliceTo; - buffer = buffer.Slice(sliceTo); - - if (buffer.Length < Newline.Length) - { - bytesWritten = 0; - return false; - } - writtenSoFar += Newline.Length; - Newline.CopyTo(buffer); - - bytesWritten = writtenSoFar; - return true; - } - - private static bool TryFormatType(Span buffer, MessageType type) - { - switch (type) - { - case MessageType.Text: - buffer[0] = TextTypeFlag; - return true; - case MessageType.Binary: - buffer[0] = BinaryTypeFlag; - return true; - case MessageType.Close: - buffer[0] = CloseTypeFlag; - return true; - case MessageType.Error: - buffer[0] = ErrorTypeFlag; - return true; - default: - return false; - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs deleted file mode 100644 index b2eab314a3..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs +++ /dev/null @@ -1,241 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Binary; -using System.IO.Pipelines; -using System.Text; - -namespace Microsoft.AspNetCore.Sockets.Formatters -{ - internal static class TextMessageFormatter - { - private const byte FieldDelimiter = (byte)':'; - private const byte MessageDelimiter = (byte)';'; - private const byte TextTypeFlag = (byte)'T'; - private const byte BinaryTypeFlag = (byte)'B'; - private const byte CloseTypeFlag = (byte)'C'; - private const byte ErrorTypeFlag = (byte)'E'; - - internal static bool TryFormatMessage(Message message, Span buffer, out int bytesWritten) - { - // Calculate the length, it's the number of characters for text messages, but number of base64 characters for binary - var length = message.Payload.Length; - if (message.Type == MessageType.Binary) - { - length = (int)(4 * Math.Ceiling(((double)message.Payload.Length / 3))); - } - - // Write the length as a string - int written = 0; - if (!length.TryFormat(buffer, out int lengthLen, default(TextFormat), TextEncoder.Utf8)) - { - bytesWritten = 0; - return false; - } - written += lengthLen; - buffer = buffer.Slice(lengthLen); - - // We need at least 4 more characters of space (':', type flag, ':', and eventually the terminating ';') - // We'll still need to double-check that we have space for the terminator after we write the payload, - // but this way we can exit early if the buffer is way too small. - if (buffer.Length < 4 + length) - { - bytesWritten = 0; - return false; - } - buffer[0] = FieldDelimiter; - if (!TryFormatType(message.Type, buffer.Slice(1, 1))) - { - bytesWritten = 0; - return false; - } - buffer[2] = FieldDelimiter; - buffer = buffer.Slice(3); - written += 3; - - // Payload - if (message.Type == MessageType.Binary) - { - // Encode the payload directly into the buffer - var writtenByPayload = Base64.Encode(message.Payload, buffer); - - // Check that we wrote enough. Length was already set (above) to the expected length in base64-encoded bytes - if (writtenByPayload < length) - { - bytesWritten = 0; - return false; - } - - // We did, advance the buffers and continue - buffer = buffer.Slice(writtenByPayload); - written += writtenByPayload; - } - else - { - message.Payload.CopyTo(buffer.Slice(0, message.Payload.Length)); - written += message.Payload.Length; - buffer = buffer.Slice(message.Payload.Length); - } - - // Terminator - if (buffer.Length < 1) - { - bytesWritten = 0; - return false; - } - buffer[0] = MessageDelimiter; - bytesWritten = written + 1; - return true; - } - - internal static bool TryParseMessage(ReadOnlySpan buffer, out Message message, out int bytesConsumed) - { - // Read until the first ':' to find the length - var consumedSoFar = 0; - var colonIndex = buffer.IndexOf(FieldDelimiter); - if (colonIndex < 0) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - consumedSoFar += colonIndex; - var lengthSpan = buffer.Slice(0, colonIndex); - buffer = buffer.Slice(colonIndex); - - // Parse the length - if (!PrimitiveParser.TryParseInt32(lengthSpan, out var length, out var consumedByLength, encoder: TextEncoder.Utf8) || consumedByLength < lengthSpan.Length) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // Check if there's enough space in the buffer to even bother continuing - // There are at least 4 characters we still expect to see: ':', type flag, ':', ';', plus the (encoded) payload length. - if (buffer.Length < 4 + length) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // Verify that we have the ':' after the type flag. - if (buffer[0] != FieldDelimiter) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // We already know that index 0 is the ':', so next is the type flag at index '1'. - if (!TryParseType(buffer[1], out var messageType)) - { - message = default(Message); - bytesConsumed = 0; - } - - // Verify that we have the ':' after the type flag. - if (buffer[2] != FieldDelimiter) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - // Slice off ':[Type]:' and check the remaining length - buffer = buffer.Slice(3); - consumedSoFar += 3; - - // Grab the payload buffer - var payload = buffer.Slice(0, length); - buffer = buffer.Slice(length); - consumedSoFar += length; - - // Parse the payload. For now, we make it an array and use the old-fashioned types. - // I've filed https://github.com/aspnet/SignalR/issues/192 to update this. -anurse - if (messageType == MessageType.Binary && payload.Length > 0) - { - // Determine the output size - // Every 4 Base64 characters represents 3 bytes - var decodedLength = (int)((payload.Length / 4) * 3); - - // Subtract padding bytes - if (payload[payload.Length - 1] == '=') - { - decodedLength -= 1; - } - if (payload.Length > 1 && payload[payload.Length - 2] == '=') - { - decodedLength -= 1; - } - - // Allocate a new buffer to decode to - var decodeBuffer = new byte[decodedLength]; - if (Base64.Decode(payload, decodeBuffer) != decodedLength) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - payload = decodeBuffer; - } - - // Verify the trailer - if (buffer.Length < 1 || buffer[0] != MessageDelimiter) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - - message = new Message(payload.ToArray(), messageType); - bytesConsumed = consumedSoFar + 1; - return true; - } - - private static bool TryParseType(byte type, out MessageType messageType) - { - switch (type) - { - case TextTypeFlag: - messageType = MessageType.Text; - return true; - case BinaryTypeFlag: - messageType = MessageType.Binary; - return true; - case CloseTypeFlag: - messageType = MessageType.Close; - return true; - case ErrorTypeFlag: - messageType = MessageType.Error; - return true; - default: - messageType = default(MessageType); - return false; - } - } - - private static bool TryFormatType(MessageType type, Span buffer) - { - switch (type) - { - case MessageType.Text: - buffer[0] = TextTypeFlag; - return true; - case MessageType.Binary: - buffer[0] = BinaryTypeFlag; - return true; - case MessageType.Close: - buffer[0] = CloseTypeFlag; - return true; - case MessageType.Error: - buffer[0] = ErrorTypeFlag; - return true; - default: - return false; - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BinaryMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BinaryMessageFormatter.cs new file mode 100644 index 0000000000..5a5e3c4372 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BinaryMessageFormatter.cs @@ -0,0 +1,56 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + internal static class BinaryMessageFormatter + { + internal const byte TextTypeFlag = 0x00; + internal const byte BinaryTypeFlag = 0x01; + internal const byte ErrorTypeFlag = 0x02; + internal const byte CloseTypeFlag = 0x03; + + public static bool TryWriteMessage(Message message, IOutput output) + { + var typeIndicator = GetTypeIndicator(message.Type); + + // Try to write the data + if (!output.TryWriteBigEndian((long)message.Payload.Length)) + { + return false; + } + + if (!output.TryWriteBigEndian(typeIndicator)) + { + return false; + } + + if (!output.TryWrite(message.Payload)) + { + return false; + } + + return true; + } + + private static byte GetTypeIndicator(MessageType type) + { + switch (type) + { + case MessageType.Text: + return TextTypeFlag; + case MessageType.Binary: + return BinaryTypeFlag; + case MessageType.Close: + return CloseTypeFlag; + case MessageType.Error: + return ErrorTypeFlag; + default: + throw new FormatException($"Invalid Message Type: {type}"); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BinaryMessageParser.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BinaryMessageParser.cs new file mode 100644 index 0000000000..3e549d8b19 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BinaryMessageParser.cs @@ -0,0 +1,114 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Binary; +using System.Buffers; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + internal class BinaryMessageParser + { + private ParserState _state; + + public void Reset() + { + _state = default(ParserState); + } + + public bool TryParseMessage(ref BytesReader buffer, out Message message) + { + if (_state.Length == null) + { + var length = buffer.TryReadBytes(sizeof(long))?.ToSingleSpan(); + if (length == null || length.Value.Length < sizeof(long)) + { + message = default(Message); + return false; + } + + var longLength = length.Value.ReadBigEndian(); + if (longLength > Int32.MaxValue) + { + throw new FormatException("Messages over 2GB in size are not supported"); + } + buffer.Advance(length.Value.Length); + _state.Length = (int)longLength; + } + + if (_state.MessageType == null) + { + if (buffer.Unread.Length == 0) + { + message = default(Message); + return false; + } + + var typeByte = buffer.Unread[0]; + + if (!TryParseType(typeByte, out var messageType)) + { + throw new FormatException($"Unknown type value: 0x{typeByte:X}"); + } + + buffer.Advance(1); + _state.MessageType = messageType; + } + + if (_state.Payload == null) + { + _state.Payload = new byte[_state.Length.Value]; + } + + while (_state.Read < _state.Payload.Length && buffer.Unread.Length > 0) + { + // Copy what we can from the current unread segment + var toCopy = Math.Min(_state.Payload.Length - _state.Read, buffer.Unread.Length); + buffer.Unread.Slice(0, toCopy).CopyTo(_state.Payload.Slice(_state.Read)); + _state.Read += toCopy; + buffer.Advance(toCopy); + } + + if (_state.Read == _state.Payload.Length) + { + message = new Message(_state.Payload, _state.MessageType.Value); + Reset(); + return true; + } + + // There's still more to read. + message = default(Message); + return false; + } + + private static bool TryParseType(byte type, out MessageType messageType) + { + switch (type) + { + case BinaryMessageFormatter.TextTypeFlag: + messageType = MessageType.Text; + return true; + case BinaryMessageFormatter.BinaryTypeFlag: + messageType = MessageType.Binary; + return true; + case BinaryMessageFormatter.CloseTypeFlag: + messageType = MessageType.Close; + return true; + case BinaryMessageFormatter.ErrorTypeFlag: + messageType = MessageType.Error; + return true; + default: + messageType = default(MessageType); + return false; + } + } + + private struct ParserState + { + public int? Length; + public MessageType? MessageType; + public byte[] Payload; + public int Read; + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BufferExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BufferExtensions.cs new file mode 100644 index 0000000000..c9c1f257dd --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/BufferExtensions.cs @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace System.Buffers +{ + internal static class BufferExtensions + { + public static ReadOnlySpan ToSingleSpan(this ReadOnlyBytes self) + { + if (self.Rest == null) + { + return self.First.Span; + } + else + { + return self.ToSpan(); + } + } + + public static ReadOnlyBytes? TryReadBytes(this BytesReader self, int count) + { + try + { + return self.ReadBytes(count); + } + catch (ArgumentOutOfRangeException) + { + return null; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs new file mode 100644 index 0000000000..19b5d3b1e4 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/IOutputExtensions.cs @@ -0,0 +1,59 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Binary; +using System.Runtime; +using System.Runtime.CompilerServices; + +namespace System.Buffers +{ + internal static class IOutputExtensions + { + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryWriteBigEndian<[Primitive] T>(this IOutput self, T value) where T : struct + { + var size = Unsafe.SizeOf(); + if (self.Buffer.Length < size) + { + self.Enlarge(size); + if (self.Buffer.Length < size) + { + return false; + } + } + + self.Buffer.WriteBigEndian(value); + self.Advance(size); + return true; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + public static bool TryWrite(this IOutput self, ReadOnlySpan data) + { + while (data.Length > 0) + { + if (self.Buffer.Length == 0) + { + self.Enlarge(data.Length); + if (self.Buffer.Length == 0) + { + // Failed to enlarge + return false; + } + } + + var toWrite = Math.Min(self.Buffer.Length, data.Length); + + // Slice based on what we can fit + var chunk = data.Slice(0, toWrite); + data = data.Slice(toWrite); + + // Copy the chunk + chunk.CopyTo(self.Buffer); + self.Advance(chunk.Length); + } + + return true; + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatter.cs new file mode 100644 index 0000000000..bad361588c --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatter.cs @@ -0,0 +1,52 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + public class MessageFormatter + { + public static readonly char TextFormatIndicator = 'T'; + public static readonly char BinaryFormatIndicator = 'B'; + + public static readonly string TextContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+text"; + public static readonly string BinaryContentType = "application/vnd.microsoft.aspnetcore.endpoint-messages.v1+binary"; + + public static bool TryWriteMessage(Message message, IOutput output, MessageFormat format) + { + if (!message.EndOfMessage) + { + // This is a truely exceptional condition since we EXPECT callers to have already + // buffered incomplete messages and synthesized the correct, complete message before + // giving it to us. Hence we throw, instead of returning false. + throw new ArgumentException("Cannot format message where endOfMessage is false using this format", nameof(message)); + } + + return format == MessageFormat.Text ? + TextMessageFormatter.TryWriteMessage(message, output) : + BinaryMessageFormatter.TryWriteMessage(message, output); + } + + public static string GetContentType(MessageFormat messageFormat) + { + switch (messageFormat) + { + case MessageFormat.Text: return TextContentType; + case MessageFormat.Binary: return BinaryContentType; + default: throw new ArgumentException($"Invalid message format: {messageFormat}", nameof(messageFormat)); + } + } + + public static char GetFormatIndicator(MessageFormat messageFormat) + { + switch (messageFormat) + { + case MessageFormat.Text: return TextFormatIndicator; + case MessageFormat.Binary: return BinaryFormatIndicator; + default: throw new ArgumentException($"Invalid message format: {messageFormat}", nameof(messageFormat)); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageParser.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageParser.cs new file mode 100644 index 0000000000..dcd8eef98f --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageParser.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + public class MessageParser + { + private TextMessageParser _textParser = new TextMessageParser(); + private BinaryMessageParser _binaryParser = new BinaryMessageParser(); + + public void Reset() + { + _textParser.Reset(); + _binaryParser.Reset(); + } + + public bool TryParseMessage(ref BytesReader buffer, MessageFormat format, out Message message) + { + return format == MessageFormat.Text ? + _textParser.TryParseMessage(ref buffer, out message) : + _binaryParser.TryParseMessage(ref buffer, out message); + } + + public static MessageFormat GetFormat(byte formatIndicator) + { + // Can't use switch because our "constants" are not consts, they're "static readonly" (which is good, because they are public) + if (formatIndicator == MessageFormatter.TextFormatIndicator) + { + return MessageFormat.Text; + } + + if (formatIndicator == MessageFormatter.BinaryFormatIndicator) + { + return MessageFormat.Binary; + } + + throw new ArgumentException($"Invalid message format: 0x{formatIndicator:X}", nameof(formatIndicator)); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageFormatter.cs new file mode 100644 index 0000000000..5b0fb4c292 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageFormatter.cs @@ -0,0 +1,163 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Binary; +using System.Buffers; +using System.Text; +using System.Text.Formatting; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + public static class ServerSentEventsMessageFormatter + { + private static readonly Span DataPrefix = new byte[] { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' }; + private static readonly Span Newline = new byte[] { (byte)'\r', (byte)'\n' }; + + private const byte LineFeed = (byte)'\n'; + private const char TextTypeFlag = 'T'; + private const char BinaryTypeFlag = 'B'; + private const char CloseTypeFlag = 'C'; + private const char ErrorTypeFlag = 'E'; + + public static bool TryWriteMessage(Message message, IOutput output) + { + if (!message.EndOfMessage) + { + // This is a truely exceptional condition since we EXPECT callers to have already + // buffered incomplete messages and synthesized the correct, complete message before + // giving it to us. Hence we throw, instead of returning false. + throw new InvalidOperationException("Cannot format message where endOfMessage is false using this format"); + } + + var typeIndicator = GetTypeIndicator(message.Type); + + // Write the Data Prefix + if (!output.TryWrite(DataPrefix)) + { + return false; + } + + // Write the type indicator + output.Append(typeIndicator, TextEncoder.Utf8); + + if (!output.TryWrite(Newline)) + { + return false; + } + + // Write the payload + if (!TryWritePayload(message.Payload, message.Type, output)) + { + return false; + } + + if (!output.TryWrite(Newline)) + { + return false; + } + + return true; + } + + private static bool TryWritePayload(ReadOnlySpan payload, MessageType type, IOutput output) + { + // Short-cut for empty payload + if (payload.Length == 0) + { + return true; + } + + if (type == MessageType.Binary) + { + // TODO: Base64 writer that works with IOutput would be amazing! + var arr = new byte[Base64.ComputeEncodedLength(payload.Length)]; + Base64.Encode(payload, arr); + return TryWriteLine(arr, output); + } + else + { + // We can't just use while(payload.Length > 0) because we need to write a blank final "data: " line + // if the payload ends in a newline. For example, consider the following payload: + // "Hello\n" + // It needs to be written as: + // data: Hello\r\n + // data: \r\n + // \r\n + // Since we slice past the newline when we find it, after writing "Hello" in the previous example, we'll + // end up with an empty payload buffer, BUT we need to write it as an empty 'data:' line, so we need + // to use a condition that ensure the only time we stop writing is when we write the slice after the final + // newline. + var keepWriting = true; + while (keepWriting) + { + // Seek to the end of buffer or newline + var sliceEnd = payload.IndexOf(LineFeed); + var nextSliceStart = sliceEnd + 1; + if (sliceEnd < 0) + { + sliceEnd = payload.Length; + nextSliceStart = sliceEnd + 1; + + // This is the last span + keepWriting = false; + } + if (sliceEnd > 0 && payload[sliceEnd - 1] == '\r') + { + sliceEnd--; + } + + var slice = payload.Slice(0, sliceEnd); + + if (nextSliceStart >= payload.Length) + { + payload = Span.Empty; + } + else + { + payload = payload.Slice(nextSliceStart); + } + + if (!TryWriteLine(slice, output)) + { + return false; + } + } + } + + return true; + } + + private static bool TryWriteLine(ReadOnlySpan line, IOutput output) + { + if (!output.TryWrite(DataPrefix)) + { + return false; + } + + if (!output.TryWrite(line)) + { + return false; + } + + if (!output.TryWrite(Newline)) + { + return false; + } + + return true; + } + + private static char GetTypeIndicator(MessageType type) + { + switch (type) + { + case MessageType.Text: return TextTypeFlag; + case MessageType.Binary: return BinaryTypeFlag; + case MessageType.Close: return CloseTypeFlag; + case MessageType.Error: return ErrorTypeFlag; + default: throw new FormatException($"Invalid Message Type: {type}"); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageFormatter.cs new file mode 100644 index 0000000000..a6287ff80f --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageFormatter.cs @@ -0,0 +1,85 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Binary; +using System.Buffers; +using System.Text; +using System.Text.Formatting; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + internal static class TextMessageFormatter + { + internal const char FieldDelimiter = ':'; + internal const char MessageDelimiter = ';'; + internal const char TextTypeFlag = 'T'; + internal const char BinaryTypeFlag = 'B'; + + internal const char CloseTypeFlag = 'C'; + internal const char ErrorTypeFlag = 'E'; + + public static bool TryWriteMessage(Message message, IOutput output) + { + // Calculate the length, it's the number of characters for text messages, but number of base64 characters for binary + var length = message.Payload.Length; + if (message.Type == MessageType.Binary) + { + length = Base64.ComputeEncodedLength(length); + } + + // Get the type indicator + var typeIndicator = GetTypeIndicator(message.Type); + + // Write the length as a string + output.Append(length, TextEncoder.Utf8); + + // Write the field delimiter ':' + output.Append(FieldDelimiter, TextEncoder.Utf8); + + // Write the type + output.Append(typeIndicator, TextEncoder.Utf8); + + // Write the field delimiter ':' + output.Append(FieldDelimiter, TextEncoder.Utf8); + + // Write the payload + if (!TryWritePayload(message, output, length)) + { + return false; + } + + // Terminator + output.Append(MessageDelimiter, TextEncoder.Utf8); + return true; + } + + private static bool TryWritePayload(Message message, IOutput output, int length) + { + // Payload + if (message.Type == MessageType.Binary) + { + // TODO: Base64 writer that works with IOutput would be amazing! + var arr = new byte[Base64.ComputeEncodedLength(message.Payload.Length)]; + Base64.Encode(message.Payload, arr); + return output.TryWrite(arr); + } + else + { + return output.TryWrite(message.Payload); + } + } + + private static char GetTypeIndicator(MessageType type) + { + switch (type) + { + case MessageType.Text: return TextTypeFlag; + case MessageType.Binary: return BinaryTypeFlag; + case MessageType.Close: return CloseTypeFlag; + case MessageType.Error: return ErrorTypeFlag; + default: throw new FormatException($"Invalid message type: {type}"); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs new file mode 100644 index 0000000000..764d4334d0 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs @@ -0,0 +1,240 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Binary; +using System.Buffers; +using System.Text; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + internal class TextMessageParser + { + private ParserState _state; + + public void Reset() + { + _state = default(ParserState); + } + + /// + /// Attempts to parse a message from the buffer. Returns 'false' if there is not enough data to complete a message. Throws an + /// exception if there is a format error in the provided data. + /// + public bool TryParseMessage(ref BytesReader buffer, out Message message) + { + while (buffer.Unread.Length > 0) + { + switch (_state.Phase) + { + case ParsePhase.ReadingLength: + if (!TryReadLength(ref buffer)) + { + message = default(Message); + return false; + } + + break; + case ParsePhase.LengthComplete: + if (!TryReadDelimiter(ref buffer, TextMessageFormatter.FieldDelimiter, ParsePhase.ReadingType, "length")) + { + message = default(Message); + return false; + } + + break; + case ParsePhase.ReadingType: + if (!TryReadType(ref buffer)) + { + message = default(Message); + return false; + } + + break; + case ParsePhase.TypeComplete: + if (!TryReadDelimiter(ref buffer, TextMessageFormatter.FieldDelimiter, ParsePhase.ReadingPayload, "type")) + { + message = default(Message); + return false; + } + + break; + case ParsePhase.ReadingPayload: + ReadPayload(ref buffer); + + break; + case ParsePhase.PayloadComplete: + if (!TryReadDelimiter(ref buffer, TextMessageFormatter.MessageDelimiter, ParsePhase.ReadingPayload, "payload")) + { + message = default(Message); + return false; + } + + // We're done! + message = new Message(_state.Payload, _state.MessageType); + Reset(); + return true; + default: + throw new InvalidOperationException($"Invalid parser phase: {_state.Phase}"); + } + } + + message = default(Message); + return false; + } + + private bool TryReadLength(ref BytesReader buffer) + { + // Read until the first ':' to find the length + var lengthSpan = buffer.ReadBytesUntil((byte)TextMessageFormatter.FieldDelimiter)?.ToSingleSpan(); + if (lengthSpan == null) + { + // Insufficient data + return false; + } + + // Parse the length + if (!PrimitiveParser.TryParseInt32(lengthSpan.Value, out var length, out var consumedByLength, encoder: TextEncoder.Utf8) || consumedByLength < lengthSpan.Value.Length) + { + if (TextEncoder.Utf8.TryDecode(lengthSpan.Value, out var lengthString, out _)) + { + throw new FormatException($"Invalid length: '{lengthString}'"); + } + + throw new FormatException("Invalid length"); + } + + _state.Length = length; + _state.Phase = ParsePhase.LengthComplete; + return true; + } + + private bool TryReadDelimiter(ref BytesReader buffer, char delimiter, ParsePhase nextPhase, string field) + { + if (buffer.Unread.Length == 0) + { + return false; + } + + if (buffer.Unread[0] != delimiter) + { + throw new FormatException($"Missing delimiter '{delimiter}' after {field}"); + } + buffer.Advance(1); + + _state.Phase = nextPhase; + return true; + } + + private bool TryReadType(ref BytesReader buffer) + { + if (buffer.Unread.Length == 0) + { + return false; + } + + if (!TryParseType(buffer.Unread[0], out _state.MessageType)) + { + throw new FormatException($"Unknown message type: '{(char)buffer.Unread[0]}'"); + } + + buffer.Advance(1); + _state.Phase = ParsePhase.TypeComplete; + return true; + } + + private void ReadPayload(ref BytesReader buffer) + { + if (_state.Payload == null) + { + _state.Payload = new byte[_state.Length]; + } + + if (_state.Read == _state.Length) + { + _state.Payload = DecodePayload(_state.Payload); + + _state.Phase = ParsePhase.PayloadComplete; + } + else + { + // Copy as much as possible from the Unread buffer + var toCopy = Math.Min(_state.Length, buffer.Unread.Length); + buffer.Unread.Slice(0, toCopy).CopyTo(_state.Payload.Slice(_state.Read)); + _state.Read += toCopy; + buffer.Advance(toCopy); + } + } + + private byte[] DecodePayload(byte[] inputPayload) + { + if (_state.MessageType == MessageType.Binary && inputPayload.Length > 0) + { + // Determine the output size + // Every 4 Base64 characters represents 3 bytes + var decodedLength = (inputPayload.Length / 4) * 3; + + // Subtract padding bytes + if (inputPayload[inputPayload.Length - 1] == '=') + { + decodedLength -= 1; + } + if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=') + { + decodedLength -= 1; + } + + // Allocate a new buffer to decode to + var decodeBuffer = new byte[decodedLength]; + if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength) + { + throw new FormatException("Invalid Base64 payload"); + } + return decodeBuffer; + } + + return inputPayload; + } + + private static bool TryParseType(byte type, out MessageType messageType) + { + switch ((char)type) + { + case TextMessageFormatter.TextTypeFlag: + messageType = MessageType.Text; + return true; + case TextMessageFormatter.BinaryTypeFlag: + messageType = MessageType.Binary; + return true; + case TextMessageFormatter.CloseTypeFlag: + messageType = MessageType.Close; + return true; + case TextMessageFormatter.ErrorTypeFlag: + messageType = MessageType.Error; + return true; + default: + messageType = default(MessageType); + return false; + } + } + + private struct ParserState + { + public ParsePhase Phase; + public int Length; + public MessageType MessageType; + public byte[] Payload; + public int Read; + } + + private enum ParsePhase + { + ReadingLength = 0, + LengthComplete, + ReadingType, + TypeComplete, + ReadingPayload, + PayloadComplete + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj index c9447d5fc0..6026e44f10 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj @@ -14,7 +14,8 @@ - + + diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs index 58bd2c4f0b..dc8f7997d8 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs @@ -1,14 +1,17 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Microsoft.Extensions.Logging; using System; using System.IO.Pipelines; +using System.IO.Pipelines.Text.Primitives; +using System.Text; +using System.Text.Formatting; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Sockets.Formatters; -using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Transports { @@ -43,39 +46,28 @@ namespace Microsoft.AspNetCore.Sockets.Transports context.Response.ContentType = MessageFormatter.GetContentType(messageFormat); var writer = context.Response.Body.AsPipelineWriter(); - var alloc = writer.Alloc(minimumSize: 1); - alloc.WriteBigEndian(MessageFormatter.GetFormatIndicator(messageFormat)); + var output = new PipelineTextOutput(writer, TextEncoder.Utf8); // We don't need the Encoder, but it's harmless to set. + + output.Append(MessageFormatter.GetFormatIndicator(messageFormat)); while (_application.TryRead(out var message)) { - var buffer = alloc.Memory.Span; - _logger.LogDebug("Writing {0} byte message to response", message.Payload.Length); - // Try to format the message - if (!MessageFormatter.TryFormatMessage(message, buffer, messageFormat, out var written)) + if (!MessageFormatter.TryWriteMessage(message, output, messageFormat)) { - // We need to expand the buffer - // REVIEW: I'm not sure I fully understand the "right" pattern here... - alloc.Ensure(MaxBufferSize); - buffer = alloc.Memory.Span; + // We ran out of space to write, even after trying to enlarge. + // This should only happen in a significant lack-of-memory scenario. - // Try one more time - if (!MessageFormatter.TryFormatMessage(message, buffer, messageFormat, out written)) - { - // Message too large - throw new InvalidOperationException($"Message is too large to write. Maximum allowed message size is: {MaxBufferSize}"); - } + // IOutput doesn't really have a way to write incremental + + // Throwing InvalidOperationException here, but it's not quite an invalid operation... + throw new InvalidOperationException("Ran out of space to format messages!"); } - // Update the buffer and commit - alloc.Advance(written); - alloc.Commit(); - alloc = writer.Alloc(); - buffer = alloc.Memory.Span; + // REVIEW: Flushing after each message? Good? Bad? We can't access Commit because it's hidden inside PipelineTextOutput + await output.FlushAsync(); } - - await alloc.FlushAsync(); } catch (OperationCanceledException) { diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs index 7ebacb6e56..8c7abff349 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs @@ -3,11 +3,13 @@ using System; using System.IO.Pipelines; +using System.IO.Pipelines.Text.Primitives; +using System.Text; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Sockets.Formatters; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Transports @@ -33,33 +35,27 @@ namespace Microsoft.AspNetCore.Sockets.Transports await context.Response.Body.FlushAsync(); var pipe = context.Response.Body.AsPipelineWriter(); + var output = new PipelineTextOutput(pipe, TextEncoder.Utf8); // We don't need the Encoder, but it's harmless to set. try { while (await _application.WaitToReadAsync(token)) { - var buffer = pipe.Alloc(); while (_application.TryRead(out var message)) { - if (!ServerSentEventsMessageFormatter.TryFormatMessage(message, buffer.Memory.Span, out var written)) + if (!ServerSentEventsMessageFormatter.TryWriteMessage(message, output)) { - // We need to expand the buffer - // REVIEW: I'm not sure I fully understand the "right" pattern here... - buffer.Ensure(LongPollingTransport.MaxBufferSize); + // We ran out of space to write, even after trying to enlarge. + // This should only happen in a significant lack-of-memory scenario. - // Try one more time - if (!ServerSentEventsMessageFormatter.TryFormatMessage(message, buffer.Memory.Span, out written)) - { - // Message too large - throw new InvalidOperationException($"Message is too large to write. Maximum allowed message size is: {LongPollingTransport.MaxBufferSize}"); - } + // IOutput doesn't really have a way to write incremental + + // Throwing InvalidOperationException here, but it's not quite an invalid operation... + throw new InvalidOperationException("Ran out of space to format messages!"); } - buffer.Advance(written); - buffer.Commit(); - buffer = pipe.Alloc(); - } - await buffer.FlushAsync(); + await output.FlushAsync(); + } } } catch (OperationCanceledException) diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/ByteArrayExtensions.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/ByteArrayExtensions.cs new file mode 100644 index 0000000000..5e5ad459f4 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/ByteArrayExtensions.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Buffers; +using System.Collections.Generic; +using System.Collections.Sequences; + +namespace System +{ + internal static class ByteArrayExtensions + { + public static ReadOnlyBytes ToChunkedReadOnlyBytes(this byte[] data, int chunkSize) + { + var chunks = new List(); + for (var i = 0; i < data.Length; i += chunkSize) + { + var thisChunkSize = Math.Min(chunkSize, data.Length - i); + var chunk = new byte[thisChunkSize]; + for (var j = 0; j < thisChunkSize; j++) + { + chunk[j] = data[i + j]; + } + chunks.Add(chunk); + } + + chunks.Reverse(); + + ReadOnlyBytes? bytes = null; + foreach (var chunk in chunks) + { + if (bytes == null) + { + bytes = new ReadOnlyBytes(chunk); + } + else + { + bytes = new ReadOnlyBytes(chunk, bytes); + } + } + return bytes.Value; + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/BinaryMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/BinaryMessageFormatterTests.cs deleted file mode 100644 index 07830a33c0..0000000000 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/BinaryMessageFormatterTests.cs +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.IO.Pipelines; -using Microsoft.AspNetCore.Sockets.Tests; -using Xunit; - -namespace Microsoft.AspNetCore.Sockets.Formatters.Tests -{ - public partial class BinaryMessageFormatterTests - { - [Fact] - public void WriteMultipleMessages() - { - var expectedEncoding = new byte[] - { - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - /* type: */ 0x01, // Binary - /* body: */ - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, - /* type: */ 0x00, // Text - /* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - /* type: */ 0x03, // Close - /* body: */ 0x41, - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, - /* type: */ 0x02, // Error - /* body: */ 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 - }; - - var messages = new[] - { - MessageTestUtils.CreateMessage(new byte[0]), - MessageTestUtils.CreateMessage("Hello,\r\nWorld!",MessageType.Text), - MessageTestUtils.CreateMessage("A", MessageType.Close), - MessageTestUtils.CreateMessage("Server Error", MessageType.Error) - }; - - var array = new byte[256]; - var buffer = array.Slice(); - var totalConsumed = 0; - foreach (var message in messages) - { - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Binary, out var consumed)); - buffer = buffer.Slice(consumed); - totalConsumed += consumed; - } - - Assert.Equal(expectedEncoding, array.Slice(0, totalConsumed).ToArray()); - } - - [Theory] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }, new byte[0])] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - public void WriteBinaryMessage(byte[] encoded, byte[] payload) - { - var message = MessageTestUtils.CreateMessage(payload); - var buffer = new byte[256]; - - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Binary, out var bytesWritten)); - - var encodedSpan = buffer.Slice(0, bytesWritten); - Assert.Equal(encoded, encodedSpan.ToArray()); - } - - [Theory] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, MessageType.Text, "")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x41, 0x42, 0x43 }, MessageType.Text, "ABC")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0B, 0x00, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, MessageType.Text, "A\nR\rC\r\n;DEF")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03 }, MessageType.Close, "")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x03, 0x43, 0x6F, 0x6E, 0x6E, 0x65, 0x63, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x64 }, MessageType.Close, "Connection Closed")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02 }, MessageType.Error, "")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x02, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 }, MessageType.Error, "Server Error")] - public void WriteTextMessage(byte[] encoded, MessageType messageType, string payload) - { - var message = MessageTestUtils.CreateMessage(payload, messageType); - var buffer = new byte[256]; - - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Binary, out var bytesWritten)); - - var encodedSpan = buffer.Slice(0, bytesWritten); - Assert.Equal(encoded, encodedSpan.ToArray()); - } - - [Fact] - public void WriteInvalidMessages() - { - var message = new Message(new byte[0], MessageType.Binary, endOfMessage: false); - var ex = Assert.Throws(() => - MessageFormatter.TryFormatMessage(message, Span.Empty, MessageFormat.Binary, out var written)); - Assert.Equal($"Cannot format message where endOfMessage is false using this format{Environment.NewLine}Parameter name: message", ex.Message); - Assert.Equal("message", ex.ParamName); - } - - [Theory] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, MessageType.Text, "")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x41, 0x42, 0x43 }, MessageType.Text, "ABC")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0B, 0x00, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, MessageType.Text, "A\nR\rC\r\n;DEF")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03 }, MessageType.Close, "")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x03, 0x43, 0x6F, 0x6E, 0x6E, 0x65, 0x63, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x64 }, MessageType.Close, "Connection Closed")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02 }, MessageType.Error, "")] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x02, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 }, MessageType.Error, "Server Error")] - public void ReadTextMessage(byte[] encoded, MessageType messageType, string payload) - { - Assert.True(MessageFormatter.TryParseMessage(encoded, MessageFormat.Binary, out var message, out var consumed)); - Assert.Equal(consumed, encoded.Length); - - MessageTestUtils.AssertMessage(message, messageType, payload); - } - - [Theory] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }, new byte[0])] - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - public void ReadBinaryMessage(byte[] encoded, byte[] payload) - { - Assert.True(MessageFormatter.TryParseMessage(encoded, MessageFormat.Binary, out var message, out var consumed)); - Assert.Equal(consumed, encoded.Length); - - MessageTestUtils.AssertMessage(message, MessageType.Binary, payload); - } - - [Fact] - public void ReadMultipleMessages() - { - var encoded = new byte[] - { - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - /* type: */ 0x01, // Binary - /* body: */ - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, - /* type: */ 0x00, // Text - /* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - /* type: */ 0x03, // Close - /* body: */ 0x41, - /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, - /* type: */ 0x02, // Error - /* body: */ 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 - }; - var buffer = encoded.Slice(); - - var messages = new List(); - var consumedTotal = 0; - while (MessageFormatter.TryParseMessage(buffer, MessageFormat.Binary, out var message, out var consumed)) - { - messages.Add(message); - consumedTotal += consumed; - buffer = buffer.Slice(consumed); - } - - Assert.Equal(consumedTotal, encoded.Length); - - Assert.Equal(4, messages.Count); - MessageTestUtils.AssertMessage(messages[0], MessageType.Binary, new byte[0]); - MessageTestUtils.AssertMessage(messages[1], MessageType.Text, "Hello,\r\nWorld!"); - MessageTestUtils.AssertMessage(messages[2], MessageType.Close, "A"); - MessageTestUtils.AssertMessage(messages[3], MessageType.Error, "Server Error"); - } - - [Theory] - [InlineData(new byte[0])] // Empty - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 })] // Just length - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00 })] // Not enough data for payload - [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04 })] // Invalid Type - public void ReadInvalidMessages(byte[] encoded) - { - Assert.False(MessageFormatter.TryParseMessage(encoded, MessageFormat.Binary, out var message, out var consumed)); - Assert.Equal(0, consumed); - } - - [Fact] - public void InsufficientWriteBufferSpace() - { - const int ExpectedSize = 13; - var message = MessageTestUtils.CreateMessage("Test", MessageType.Text); - - byte[] buffer; - int bufferSize; - int written; - for (bufferSize = 0; bufferSize < 13; bufferSize++) - { - buffer = new byte[bufferSize]; - Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Binary, out written)); - Assert.Equal(0, written); - } - - buffer = new byte[bufferSize]; - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Binary, out written)); - Assert.Equal(ExpectedSize, written); - } - } -} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs deleted file mode 100644 index becc3479d6..0000000000 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs +++ /dev/null @@ -1,206 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Text; -using Microsoft.AspNetCore.Sockets.Tests; -using Xunit; - -namespace Microsoft.AspNetCore.Sockets.Formatters.Tests -{ - public class TextMessageFormatterTests - { - [Fact] - public void WriteMultipleMessages() - { - const string expectedEncoding = "0:B:;14:T:Hello,\r\nWorld!;1:C:A;12:E:Server Error;"; - var messages = new[] - { - MessageTestUtils.CreateMessage(new byte[0]), - MessageTestUtils.CreateMessage("Hello,\r\nWorld!",MessageType.Text), - MessageTestUtils.CreateMessage("A", MessageType.Close), - MessageTestUtils.CreateMessage("Server Error", MessageType.Error) - }; - - var array = new byte[256]; - var buffer = array.Slice(); - var totalConsumed = 0; - foreach (var message in messages) - { - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out var consumed)); - buffer = buffer.Slice(consumed); - totalConsumed += consumed; - } - - Assert.Equal(expectedEncoding, Encoding.UTF8.GetString(array, 0, totalConsumed)); - } - - [Theory] - [InlineData("0:B:;", new byte[0])] - [InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - [InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] - [InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] - public void WriteBinaryMessage(string encoded, byte[] payload) - { - var message = MessageTestUtils.CreateMessage(payload); - var buffer = new byte[256]; - - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out var bytesWritten)); - - var encodedSpan = buffer.Slice(0, bytesWritten); - Assert.Equal(encoded, Encoding.UTF8.GetString(encodedSpan.ToArray())); - } - - [Theory] - [InlineData("0:T:;", MessageType.Text, "")] - [InlineData("3:T:ABC;", MessageType.Text, "ABC")] - [InlineData("11:T:A\nR\rC\r\n;DEF;", MessageType.Text, "A\nR\rC\r\n;DEF")] - [InlineData("0:C:;", MessageType.Close, "")] - [InlineData("17:C:Connection Closed;", MessageType.Close, "Connection Closed")] - [InlineData("0:E:;", MessageType.Error, "")] - [InlineData("12:E:Server Error;", MessageType.Error, "Server Error")] - public void WriteTextMessage(string encoded, MessageType messageType, string payload) - { - var message = MessageTestUtils.CreateMessage(payload, messageType); - var buffer = new byte[256]; - - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out var bytesWritten)); - - var encodedSpan = buffer.Slice(0, bytesWritten); - Assert.Equal(encoded, Encoding.UTF8.GetString(encodedSpan.ToArray())); - } - - [Fact] - public void WriteInvalidMessages() - { - var message = new Message(new byte[0], MessageType.Binary, endOfMessage: false); - var ex = Assert.Throws(() => - MessageFormatter.TryFormatMessage(message, Span.Empty, MessageFormat.Text, out var written)); - Assert.Equal($"Cannot format message where endOfMessage is false using this format{Environment.NewLine}Parameter name: message", ex.Message); - Assert.Equal("message", ex.ParamName); - } - - [Theory] - [InlineData("0:T:;", MessageType.Text, "")] - [InlineData("3:T:ABC;", MessageType.Text, "ABC")] - [InlineData("11:T:A\nR\rC\r\n;DEF;", MessageType.Text, "A\nR\rC\r\n;DEF")] - [InlineData("0:C:;", MessageType.Close, "")] - [InlineData("17:C:Connection Closed;", MessageType.Close, "Connection Closed")] - [InlineData("0:E:;", MessageType.Error, "")] - [InlineData("12:E:Server Error;", MessageType.Error, "Server Error")] - public void ReadTextMessage(string encoded, MessageType messageType, string payload) - { - var buffer = Encoding.UTF8.GetBytes(encoded); - - Assert.True(MessageFormatter.TryParseMessage(buffer, MessageFormat.Text, out var message, out var consumed)); - Assert.Equal(consumed, buffer.Length); - - MessageTestUtils.AssertMessage(message, messageType, payload); - } - - [Theory] - [InlineData("0:B:;", new byte[0])] - [InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - [InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] - [InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] - public void ReadBinaryMessage(string encoded, byte[] payload) - { - var buffer = Encoding.UTF8.GetBytes(encoded); - - Assert.True(MessageFormatter.TryParseMessage(buffer, MessageFormat.Text, out var message, out var consumed)); - Assert.Equal(consumed, buffer.Length); - - MessageTestUtils.AssertMessage(message, MessageType.Binary, payload); - } - - [Fact] - public void ReadMultipleMessages() - { - const string encoded = "0:B:;14:T:Hello,\r\nWorld!;1:C:A;12:E:Server Error;"; - var buffer = (Span)Encoding.UTF8.GetBytes(encoded); - - var messages = new List(); - var consumedTotal = 0; - while (MessageFormatter.TryParseMessage(buffer, MessageFormat.Text, out var message, out var consumed)) - { - messages.Add(message); - consumedTotal += consumed; - buffer = buffer.Slice(consumed); - } - - Assert.Equal(consumedTotal, Encoding.UTF8.GetByteCount(encoded)); - - Assert.Equal(4, messages.Count); - MessageTestUtils.AssertMessage(messages[0], MessageType.Binary, new byte[0]); - MessageTestUtils.AssertMessage(messages[1], MessageType.Text, "Hello,\r\nWorld!"); - MessageTestUtils.AssertMessage(messages[2], MessageType.Close, "A"); - MessageTestUtils.AssertMessage(messages[3], MessageType.Error, "Server Error"); - } - - [Theory] - [InlineData("")] - [InlineData("ABC")] - [InlineData("1230450945")] - [InlineData("12ab34:")] - [InlineData("1:asdf")] - [InlineData("1::")] - [InlineData("1:AB:")] - [InlineData("5:T:A")] - [InlineData("5:T:ABCDE")] - [InlineData("5:T:ABCDEF")] - [InlineData("5:X:ABCDEF")] - [InlineData("1029348109238412903849023841290834901283409128349018239048102394:X:ABCDEF")] - public void ReadInvalidMessages(string encoded) - { - var buffer = Encoding.UTF8.GetBytes(encoded); - Assert.False(MessageFormatter.TryParseMessage(buffer, MessageFormat.Text, out var message, out var consumed)); - Assert.Equal(0, consumed); - } - - [Theory] - [InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - [InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] - [InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] - public void InsufficientWriteBufferSpaceBinary(byte[] payload) - { - const int ExpectedSize = 13; - var message = MessageTestUtils.CreateMessage(payload); - - byte[] buffer; - int bufferSize; - int written; - for (bufferSize = 0; bufferSize < ExpectedSize; bufferSize++) - { - buffer = new byte[bufferSize]; - Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written)); - Assert.Equal(0, written); - } - - buffer = new byte[bufferSize]; - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written)); - Assert.Equal(ExpectedSize, written); - } - - [Fact] - public void InsufficientWriteBufferSpaceText() - { - const int ExpectedSize = 9; - var message = MessageTestUtils.CreateMessage("Test", MessageType.Text); - - byte[] buffer; - int bufferSize; - int written; - for (bufferSize = 0; bufferSize < ExpectedSize; bufferSize++) - { - buffer = new byte[bufferSize]; - Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written)); - Assert.Equal(0, written); - } - - buffer = new byte[bufferSize]; - Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written)); - Assert.Equal(ExpectedSize, written); - } - } -} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ArrayOutput.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ArrayOutput.cs new file mode 100644 index 0000000000..8b851a4486 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ArrayOutput.cs @@ -0,0 +1,75 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters +{ + internal class ArrayOutput : IOutput + { + private IList> _buffers = new List>(); + + private int _chunkSize; + private byte[] _activeBuffer; + private int _offset; + + public Span Buffer => _activeBuffer.Slice(_offset); + + public ArrayOutput(int chunkSize) + { + _chunkSize = chunkSize; + AdvanceChunk(); + } + + public void Advance(int bytes) + { + // Determine the new location + _offset += bytes; + Debug.Assert(_offset <= _activeBuffer.Length, "How did we write more data than we had space?"); + } + + public void Enlarge(int desiredBufferLength = 0) + { + if (desiredBufferLength == 0 || _activeBuffer.Length - _offset < desiredBufferLength) + { + AdvanceChunk(); + } + } + + public byte[] ToArray() + { + var totalLength = _buffers.Sum(b => b.Count) + _offset; + + var arr = new byte[totalLength]; + + int offset = 0; + foreach (var buffer in _buffers) + { + System.Buffer.BlockCopy(buffer.Array, 0, arr, offset, buffer.Count); + offset += buffer.Count; + } + + if (_offset > 0) + { + System.Buffer.BlockCopy(_activeBuffer, 0, arr, offset, _offset); + } + + return arr; + } + + private void AdvanceChunk() + { + if (_activeBuffer != null) + { + _buffers.Add(new ArraySegment(_activeBuffer, 0, _offset)); + } + + _activeBuffer = new byte[_chunkSize]; + _offset = 0; + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs new file mode 100644 index 0000000000..8e833c388f --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs @@ -0,0 +1,108 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters +{ + public partial class BinaryMessageFormatterTests + { + [Fact] + public void WriteMultipleMessages() + { + var expectedEncoding = new byte[] + { + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + /* type: */ 0x01, // Binary + /* body: */ + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, + /* type: */ 0x00, // Text + /* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + /* type: */ 0x03, // Close + /* body: */ 0x41, + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, + /* type: */ 0x02, // Error + /* body: */ 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 + }; + + var messages = new[] + { + MessageTestUtils.CreateMessage(new byte[0]), + MessageTestUtils.CreateMessage("Hello,\r\nWorld!",MessageType.Text), + MessageTestUtils.CreateMessage("A", MessageType.Close), + MessageTestUtils.CreateMessage("Server Error", MessageType.Error) + }; + + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing + foreach (var message in messages) + { + Assert.True(MessageFormatter.TryWriteMessage(message, output, MessageFormat.Binary)); + } + + Assert.Equal(expectedEncoding, output.ToArray()); + } + + [Theory] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }, new byte[0])] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData(4, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }, new byte[0])] + [InlineData(4, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData(0, 256, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }, new byte[0])] + [InlineData(0, 256, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + public void WriteBinaryMessage(int offset, int chunkSize, byte[] encoded, byte[] payload) + { + var message = MessageTestUtils.CreateMessage(payload); + var output = new ArrayOutput(chunkSize); + + if (offset > 0) + { + output.Advance(offset); + } + + Assert.True(MessageFormatter.TryWriteMessage(message, output, MessageFormat.Binary)); + + Assert.Equal(encoded, output.ToArray().Slice(offset).ToArray()); + } + + [Theory] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, MessageType.Text, "")] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x41, 0x42, 0x43 }, MessageType.Text, "ABC")] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0B, 0x00, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, MessageType.Text, "A\nR\rC\r\n;DEF")] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03 }, MessageType.Close, "")] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x03, 0x43, 0x6F, 0x6E, 0x6E, 0x65, 0x63, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x64 }, MessageType.Close, "Connection Closed")] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02 }, MessageType.Error, "")] + [InlineData(0, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x02, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 }, MessageType.Error, "Server Error")] + [InlineData(4, 8, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, MessageType.Text, "")] + [InlineData(0, 256, new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, MessageType.Text, "")] + public void WriteTextMessage(int offset, int chunkSize, byte[] encoded, MessageType messageType, string payload) + { + var message = MessageTestUtils.CreateMessage(payload, messageType); + var output = new ArrayOutput(chunkSize); + + if (offset > 0) + { + output.Advance(offset); + } + + Assert.True(MessageFormatter.TryWriteMessage(message, output, MessageFormat.Binary)); + + Assert.Equal(encoded, output.ToArray().Slice(offset).ToArray()); + } + + [Fact] + public void WriteInvalidMessages() + { + var message = new Message(new byte[0], MessageType.Binary, endOfMessage: false); + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing + var ex = Assert.Throws(() => + MessageFormatter.TryWriteMessage(message, output, MessageFormat.Binary)); + Assert.Equal($"Cannot format message where endOfMessage is false using this format{Environment.NewLine}Parameter name: message", ex.Message); + Assert.Equal("message", ex.ParamName); + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs new file mode 100644 index 0000000000..8909729571 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/BinaryMessageParserTests.cs @@ -0,0 +1,111 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Microsoft.AspNetCore.Sockets.Tests; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters +{ + public class BinaryMessageParserTests + { + [Theory] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, MessageType.Text, "")] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x41, 0x42, 0x43 }, MessageType.Text, "ABC")] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0B, 0x00, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, MessageType.Text, "A\nR\rC\r\n;DEF")] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03 }, MessageType.Close, "")] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x03, 0x43, 0x6F, 0x6E, 0x6E, 0x65, 0x63, 0x74, 0x69, 0x6F, 0x6E, 0x20, 0x43, 0x6C, 0x6F, 0x73, 0x65, 0x64 }, MessageType.Close, "Connection Closed")] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02 }, MessageType.Error, "")] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, 0x02, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 }, MessageType.Error, "Server Error")] + public void ReadTextMessage(byte[] encoded, MessageType messageType, string payload) + { + var parser = new MessageParser(); + var reader = new BytesReader(encoded); + Assert.True(parser.TryParseMessage(ref reader, MessageFormat.Binary, out var message)); + Assert.Equal(reader.Index, encoded.Length); + + MessageTestUtils.AssertMessage(message, messageType, payload); + } + + [Theory] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01 }, new byte[0])] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x01, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + public void ReadBinaryMessage(byte[] encoded, byte[] payload) + { + var parser = new MessageParser(); + var reader = new BytesReader(encoded); + Assert.True(parser.TryParseMessage(ref reader, MessageFormat.Binary, out var message)); + Assert.Equal(reader.Index, encoded.Length); + + MessageTestUtils.AssertMessage(message, MessageType.Binary, payload); + } + + [Theory] + [InlineData(0)] // No chunking + [InlineData(4)] + [InlineData(8)] + [InlineData(256)] + public void ReadMultipleMessages(int chunkSize) + { + var encoded = new byte[] + { + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + /* type: */ 0x01, // Binary + /* body: */ + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0E, + /* type: */ 0x00, // Text + /* body: */ 0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x2C, 0x0D, 0x0A, 0x57, 0x6F, 0x72, 0x6C, 0x64, 0x21, + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + /* type: */ 0x03, // Close + /* body: */ 0x41, + /* length: */ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0C, + /* type: */ 0x02, // Error + /* body: */ 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6F, 0x72 + }; + var parser = new MessageParser(); + var buffer = chunkSize > 0 ? + encoded.ToChunkedReadOnlyBytes(chunkSize) : + new ReadOnlyBytes(encoded); + var reader = new BytesReader(buffer); + + var messages = new List(); + while (parser.TryParseMessage(ref reader, MessageFormat.Binary, out var message)) + { + messages.Add(message); + } + + Assert.Equal(encoded.Length, reader.Index); + + Assert.Equal(4, messages.Count); + MessageTestUtils.AssertMessage(messages[0], MessageType.Binary, new byte[0]); + MessageTestUtils.AssertMessage(messages[1], MessageType.Text, "Hello,\r\nWorld!"); + MessageTestUtils.AssertMessage(messages[2], MessageType.Close, "A"); + MessageTestUtils.AssertMessage(messages[3], MessageType.Error, "Server Error"); + } + + [Theory] + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04 }, "Unknown type value: 0x4")] // Invalid Type + public void ReadInvalidMessages(byte[] encoded, string message) + { + var parser = new MessageParser(); + var reader = new BytesReader(new ReadOnlyBytes(encoded)); + var ex = Assert.Throws(() => parser.TryParseMessage(ref reader, MessageFormat.Binary, out _)); + Assert.Equal(message, ex.Message); + } + + [Theory] + [InlineData(new byte[0])] // Empty + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 })] // Just length + [InlineData(new byte[] { 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x00, 0x00, 0x00 })] // Not enough data for payload + public void ReadIncompleteMessages(byte[] encoded) + { + var parser = new MessageParser(); + var reader = new BytesReader(new ReadOnlyBytes(encoded)); + Assert.False(parser.TryParseMessage(ref reader, MessageFormat.Binary, out var message)); + Assert.Equal(encoded.Length, reader.Index); + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsMessageFormatterTests.cs similarity index 66% rename from test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs rename to test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsMessageFormatterTests.cs index 8bc24921da..3333c3f833 100644 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsMessageFormatterTests.cs @@ -3,40 +3,20 @@ using System; using System.Text; -using Microsoft.AspNetCore.Sockets.Tests; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Xunit; -namespace Microsoft.AspNetCore.Sockets.Formatters.Tests +namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters { public class ServerSentEventsMessageFormatterTests { - [Fact] - public void InsufficientWriteBufferSpace() - { - const int ExpectedSize = 23; - var message = MessageTestUtils.CreateMessage("Test", MessageType.Text); - - byte[] buffer; - int bufferSize; - int written; - for (bufferSize = 0; bufferSize < 23; bufferSize++) - { - buffer = new byte[bufferSize]; - Assert.False(ServerSentEventsMessageFormatter.TryFormatMessage(message, buffer, out written)); - Assert.Equal(0, written); - } - - buffer = new byte[bufferSize]; - Assert.True(ServerSentEventsMessageFormatter.TryFormatMessage(message, buffer, out written)); - Assert.Equal(ExpectedSize, written); - } - [Fact] public void WriteInvalidMessages() { var message = new Message(new byte[0], MessageType.Binary, endOfMessage: false); + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing var ex = Assert.Throws(() => - ServerSentEventsMessageFormatter.TryFormatMessage(message, Span.Empty, out var written)); + ServerSentEventsMessageFormatter.TryWriteMessage(message, output)); Assert.Equal("Cannot format message where endOfMessage is false using this format", ex.Message); } @@ -63,10 +43,10 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests { var message = MessageTestUtils.CreateMessage(payload, messageType); - var buffer = new byte[256]; - Assert.True(ServerSentEventsMessageFormatter.TryFormatMessage(message, buffer, out var written)); + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing + Assert.True(ServerSentEventsMessageFormatter.TryWriteMessage(message, output)); - Assert.Equal(encoded, Encoding.UTF8.GetString(buffer, 0, written)); + Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); } [Theory] @@ -76,10 +56,10 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests { var message = MessageTestUtils.CreateMessage(payload); - var buffer = new byte[256]; - Assert.True(ServerSentEventsMessageFormatter.TryFormatMessage(message, buffer, out var written)); + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing + Assert.True(ServerSentEventsMessageFormatter.TryWriteMessage(message, output)); - Assert.Equal(encoded, Encoding.UTF8.GetString(buffer, 0, written)); + Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs new file mode 100644 index 0000000000..dc5c1b0e05 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/TextMessageFormatterTests.cs @@ -0,0 +1,81 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Tests.Internal.Formatters +{ + public class TextMessageFormatterTests + { + [Fact] + public void WriteMultipleMessages() + { + const string expectedEncoding = "0:B:;14:T:Hello,\r\nWorld!;1:C:A;12:E:Server Error;"; + var messages = new[] + { + MessageTestUtils.CreateMessage(new byte[0]), + MessageTestUtils.CreateMessage("Hello,\r\nWorld!",MessageType.Text), + MessageTestUtils.CreateMessage("A", MessageType.Close), + MessageTestUtils.CreateMessage("Server Error", MessageType.Error) + }; + + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing + foreach (var message in messages) + { + Assert.True(MessageFormatter.TryWriteMessage(message, output, MessageFormat.Text)); + } + + Assert.Equal(expectedEncoding, Encoding.UTF8.GetString(output.ToArray())); + } + + [Theory] + [InlineData(8, "0:B:;", new byte[0])] + [InlineData(8, "8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData(8, "8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] + [InlineData(8, "8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] + [InlineData(256, "8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] + public void WriteBinaryMessage(int chunkSize, string encoded, byte[] payload) + { + var message = MessageTestUtils.CreateMessage(payload); + var output = new ArrayOutput(chunkSize); + + Assert.True(MessageFormatter.TryWriteMessage(message, output, MessageFormat.Text)); + + Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); + } + + [Theory] + [InlineData(8, "0:T:;", MessageType.Text, "")] + [InlineData(8, "3:T:ABC;", MessageType.Text, "ABC")] + [InlineData(8, "11:T:A\nR\rC\r\n;DEF;", MessageType.Text, "A\nR\rC\r\n;DEF")] + [InlineData(8, "0:C:;", MessageType.Close, "")] + [InlineData(8, "17:C:Connection Closed;", MessageType.Close, "Connection Closed")] + [InlineData(8, "0:E:;", MessageType.Error, "")] + [InlineData(8, "12:E:Server Error;", MessageType.Error, "Server Error")] + [InlineData(256, "11:T:A\nR\rC\r\n;DEF;", MessageType.Text, "A\nR\rC\r\n;DEF")] + public void WriteTextMessage(int chunkSize, string encoded, MessageType messageType, string payload) + { + var message = MessageTestUtils.CreateMessage(payload, messageType); + var output = new ArrayOutput(chunkSize); // Use small chunks to test Advance/Enlarge and partial payload writing + + Assert.True(MessageFormatter.TryWriteMessage(message, output, MessageFormat.Text)); + + Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); + } + + [Fact] + public void WriteInvalidMessages() + { + var message = new Message(new byte[0], MessageType.Binary, endOfMessage: false); + var output = new ArrayOutput(chunkSize: 8); // Use small chunks to test Advance/Enlarge and partial payload writing + var ex = Assert.Throws(() => + MessageFormatter.TryWriteMessage(message, output, MessageFormat.Text)); + Assert.Equal($"Cannot format message where endOfMessage is false using this format{Environment.NewLine}Parameter name: message", ex.Message); + Assert.Equal("message", ex.ParamName); + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/TextMessageParserTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/TextMessageParserTests.cs new file mode 100644 index 0000000000..330bafec91 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/TextMessageParserTests.cs @@ -0,0 +1,131 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.Sockets.Internal.Formatters; +using Microsoft.AspNetCore.Sockets.Tests; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters +{ + public class TextMessageParserTests + { + [Theory] + [InlineData("0:T:;", MessageType.Text, "")] + [InlineData("3:T:ABC;", MessageType.Text, "ABC")] + [InlineData("11:T:A\nR\rC\r\n;DEF;", MessageType.Text, "A\nR\rC\r\n;DEF")] + [InlineData("0:C:;", MessageType.Close, "")] + [InlineData("17:C:Connection Closed;", MessageType.Close, "Connection Closed")] + [InlineData("0:E:;", MessageType.Error, "")] + [InlineData("12:E:Server Error;", MessageType.Error, "Server Error")] + public void ReadTextMessage(string encoded, MessageType messageType, string payload) + { + var parser = new MessageParser(); + var buffer = Encoding.UTF8.GetBytes(encoded); + var reader = new BytesReader(buffer); + + Assert.True(parser.TryParseMessage(ref reader, MessageFormat.Text, out var message)); + Assert.Equal(reader.Index, buffer.Length); + + MessageTestUtils.AssertMessage(message, messageType, payload); + } + + [Theory] + [InlineData("0:B:;", new byte[0])] + [InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] + [InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] + public void ReadBinaryMessage(string encoded, byte[] payload) + { + var parser = new MessageParser(); + var buffer = Encoding.UTF8.GetBytes(encoded); + var reader = new BytesReader(buffer); + + Assert.True(parser.TryParseMessage(ref reader, MessageFormat.Text, out var message)); + Assert.Equal(reader.Index, buffer.Length); + + MessageTestUtils.AssertMessage(message, MessageType.Binary, payload); + } + + [Theory] + [InlineData(0)] // Not chunked + [InlineData(4)] + [InlineData(8)] + public void ReadMultipleMessages(int chunkSize) + { + const string encoded = "0:B:;14:T:Hello,\r\nWorld!;1:C:A;12:E:Server Error;"; + var parser = new MessageParser(); + var data = Encoding.UTF8.GetBytes(encoded); + var buffer = chunkSize > 0 ? + data.ToChunkedReadOnlyBytes(chunkSize) : + new ReadOnlyBytes(data); + + var reader = new BytesReader(buffer); + + var messages = new List(); + while (parser.TryParseMessage(ref reader, MessageFormat.Text, out var message)) + { + messages.Add(message); + } + + Assert.Equal(reader.Index, Encoding.UTF8.GetByteCount(encoded)); + + Assert.Equal(4, messages.Count); + MessageTestUtils.AssertMessage(messages[0], MessageType.Binary, new byte[0]); + MessageTestUtils.AssertMessage(messages[1], MessageType.Text, "Hello,\r\nWorld!"); + MessageTestUtils.AssertMessage(messages[2], MessageType.Close, "A"); + MessageTestUtils.AssertMessage(messages[3], MessageType.Error, "Server Error"); + } + + [Theory] + [InlineData("")] + [InlineData("ABC")] + [InlineData("1230450945")] + [InlineData("1:")] + [InlineData("10")] + [InlineData("5:T:A")] + [InlineData("5:T:ABCDE")] + public void ReadIncompleteMessages(string encoded) + { + var parser = new MessageParser(); + var buffer = Encoding.UTF8.GetBytes(encoded); + var reader = new BytesReader(buffer); + Assert.False(parser.TryParseMessage(ref reader, MessageFormat.Text, out _)); + } + + [Theory] + [InlineData("X:", "Invalid length: 'X'")] + [InlineData("5:X:ABCDEF", "Unknown message type: 'X'")] + [InlineData("1:asdf", "Unknown message type: 'a'")] + [InlineData("1::", "Unknown message type: ':'")] + [InlineData("1:AB:", "Unknown message type: 'A'")] + [InlineData("1:TA", "Missing delimiter ':' after type")] + [InlineData("1029348109238412903849023841290834901283409128349018239048102394:X:ABCDEF", "Invalid length: '1029348109238412903849023841290834901283409128349018239048102394'")] + [InlineData("12ab34:", "Invalid length: '12ab34'")] + [InlineData("5:T:ABCDEF", "Missing delimiter ';' after payload")] + public void ReadInvalidMessages(string encoded, string expectedMessage) + { + var parser = new MessageParser(); + var buffer = Encoding.UTF8.GetBytes(encoded); + var reader = new BytesReader(buffer); + var ex = Assert.Throws(() => parser.TryParseMessage(ref reader, MessageFormat.Text, out _)); + Assert.Equal(expectedMessage, ex.Message); + } + + [Fact] + public void ReadInvalidEncodedMessage() + { + var parser = new MessageParser(); + + // Invalid because first character is a UTF-8 "continuation" character + // We need to include the ':' so that + var buffer = new byte[] { 0x48, 0x65, 0x80, 0x6C, 0x6F, (byte)':' }; + var reader = new BytesReader(buffer); + var ex = Assert.Throws(() => parser.TryParseMessage(ref reader, MessageFormat.Text, out _)); + Assert.Equal("Invalid length", ex.Message); + } + } +}