From 74b318b3e4a8c0239d86c4071dbe272e63ead206 Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Wed, 26 Apr 2017 15:04:48 -0700 Subject: [PATCH] Support binary messages in SSE parser (#418) --- .../Internal/Formatters/MessageFormatUtils.cs | 41 ++++++++++ .../ServerSentEventsMessageParser.cs | 43 ++++++---- .../Internal/Formatters/TextMessageParser.cs | 35 +------- .../Formatters/ServerSentEventsParserTests.cs | 79 +++++++++---------- 4 files changed, 110 insertions(+), 88 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatUtils.cs diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatUtils.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatUtils.cs new file mode 100644 index 0000000000..a217680fe3 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/MessageFormatUtils.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Binary; + +namespace Microsoft.AspNetCore.Sockets.Internal.Formatters +{ + internal static class MessageFormatUtils + { + public static byte[] DecodePayload(byte[] inputPayload) + { + if (inputPayload.Length > 0) + { + // Determine the output size + // Every 4 Base64 characters represents 3 bytes + var decodedLength = (inputPayload.Length / 4) * 3; + + // Subtract padding bytes + if (inputPayload[inputPayload.Length - 1] == '=') + { + decodedLength -= 1; + } + if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=') + { + decodedLength -= 1; + } + + // Allocate a new buffer to decode to + var decodeBuffer = new byte[decodedLength]; + if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength) + { + throw new FormatException("Invalid Base64 payload"); + } + return decodeBuffer; + } + + return inputPayload; + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageParser.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageParser.cs index e8e014812a..3d22d6a5df 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageParser.cs +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/ServerSentEventsMessageParser.cs @@ -32,7 +32,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters examined = buffer.End; message = new Message(); var reader = new ReadableBufferReader(buffer); - _messageType = MessageType.Text; var start = consumed; var end = examined; @@ -83,6 +82,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters EnsureStartsWithDataPrefix(line); } + var payload = Array.Empty(); switch (_internalParserState) { case InternalParseState.ReadMessageType: @@ -94,7 +94,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters consumed = lineEnd; break; case InternalParseState.ReadMessagePayload: - // Slice away the 'data: ' var payloadLength = line.Length - (_dataPrefix.Length + _sseLineEnding.Length); var newData = line.Slice(_dataPrefix.Length, payloadLength).ToArray(); @@ -104,42 +103,56 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters consumed = lineEnd; break; case InternalParseState.ReadEndOfMessage: - if (_data.Count > 0) + if (_data.Count == 1) + { + payload = _data[0]; + } + else if (_data.Count > 1) { // Find the final size of the payload var payloadSize = 0; foreach (var dataLine in _data) { - payloadSize += dataLine.Length + _newLine.Length; + payloadSize += dataLine.Length; } - // Allocate space in the paylod buffer for the data and the new lines. - // Subtract newLine length because we don't want a trailing newline. - var payload = new byte[payloadSize - _newLine.Length]; + if (_messageType != MessageType.Binary) + { + payloadSize += _newLine.Length*_data.Count; + + // Allocate space in the paylod buffer for the data and the new lines. + // Subtract newLine length because we don't want a trailing newline. + payload = new byte[payloadSize - _newLine.Length]; + } + else + { + payload = new byte[payloadSize]; + } var offset = 0; foreach (var dataLine in _data) { dataLine.CopyTo(payload, offset); offset += dataLine.Length; - if (offset < payload.Length) + if (offset < payload.Length && _messageType != MessageType.Binary) { _newLine.CopyTo(payload, offset); offset += _newLine.Length; } } - message = new Message(payload, _messageType); - } - else - { - // Empty message - message = new Message(Array.Empty(), _messageType); } + if (_messageType == MessageType.Binary) + { + payload = MessageFormatUtils.DecodePayload(payload); + } + + message = new Message(payload, _messageType); consumed = lineEnd; examined = consumed; return ParseResult.Completed; } + if (reader.Peek() == ByteCR) { _internalParserState = InternalParseState.ReadEndOfMessage; @@ -188,7 +201,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters case ByteT: return MessageType.Text; case ByteB: - throw new NotSupportedException("Support for binary messages has not been implemented yet"); + return MessageType.Binary; case ByteC: return MessageType.Close; case ByteE: diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs index bd48403f67..1bd033aac4 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs +++ b/src/Microsoft.AspNetCore.Sockets.Common/Internal/Formatters/TextMessageParser.cs @@ -155,7 +155,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters if (_state.Read == _state.Length) { - _state.Payload = DecodePayload(_state.Payload); + if (_state.MessageType == MessageType.Binary) + { + _state.Payload = MessageFormatUtils.DecodePayload(_state.Payload); + } _state.Phase = ParsePhase.PayloadComplete; } @@ -169,36 +172,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters } } - private byte[] DecodePayload(byte[] inputPayload) - { - if (_state.MessageType == MessageType.Binary && inputPayload.Length > 0) - { - // Determine the output size - // Every 4 Base64 characters represents 3 bytes - var decodedLength = (inputPayload.Length / 4) * 3; - - // Subtract padding bytes - if (inputPayload[inputPayload.Length - 1] == '=') - { - decodedLength -= 1; - } - if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=') - { - decodedLength -= 1; - } - - // Allocate a new buffer to decode to - var decodeBuffer = new byte[decodedLength]; - if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength) - { - throw new FormatException("Invalid Base64 payload"); - } - return decodeBuffer; - } - - return inputPayload; - } - private static bool TryParseType(byte type, out MessageType messageType) { switch ((char)type) diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsParserTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsParserTests.cs index 7a0b4dd65f..2a6f3e1788 100644 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsParserTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Internal/Formatters/ServerSentEventsParserTests.cs @@ -14,13 +14,16 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters public class ServerSentEventsParserTests { [Theory] - [InlineData("data: T\r\n\r\n", "")] - [InlineData("data: T\r\ndata: \r\r\n\r\n", "\r")] - [InlineData("data: T\r\ndata: A\rB\r\n\r\n", "A\rB")] - [InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\ndata: Hello, World\r\n\r\ndata: ", "Hello, World")] - public void ParseSSEMessageSuccessCases(string encodedMessage, string expectedMessage) + [InlineData("data: T\r\n\r\n", "", MessageType.Text)] + [InlineData("data: B\r\n\r\n", "", MessageType.Binary)] + [InlineData("data: T\r\ndata: \r\r\n\r\n", "\r", MessageType.Text)] + [InlineData("data: T\r\ndata: A\rB\r\n\r\n", "A\rB", MessageType.Text)] + [InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\ndata: Hello, World\r\n\r\ndata: ", "Hello, World", MessageType.Text)] + [InlineData("data: B\r\ndata: SGVsbG8sIFdvcmxk\r\n\r\n", "Hello, World", MessageType.Binary)] + [InlineData("data: B\r\ndata: SGVsbG8g\r\ndata: V29ybGQ=\r\n\r\n", "Hello World", MessageType.Binary)] + public void ParseSSEMessageSuccessCases(string encodedMessage, string expectedMessage, MessageType messageType) { var buffer = Encoding.UTF8.GetBytes(encodedMessage); var readableBuffer = ReadableBuffer.Create(buffer); @@ -28,7 +31,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters var parseResult = parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message); Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult); - Assert.Equal(MessageType.Text, message.Type); + Assert.Equal(messageType, message.Type); Assert.Equal(consumed, examined); var result = Encoding.UTF8.GetString(message.Payload); @@ -52,6 +55,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters [InlineData("data: T\r\ndata: Hello, World\r\n\r\\", "Expected a \\r\\n frame ending")] [InlineData("data: T\r\ndata: Major\r\ndata: Key\rndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")] [InlineData("data: T\r\ndata: Major\r\ndata: Key\r\ndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")] + [InlineData("data: B\r\n SGVsbG8sIFdvcmxk\r\n\r\n", "Expected the message prefix 'data: '")] public void ParseSSEMessageFailureCases(string encodedMessage, string expectedExceptionMessage) { var buffer = Encoding.UTF8.GetBytes(encodedMessage); @@ -72,6 +76,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters [InlineData("data: T\r\ndata: Hello, World\r")] [InlineData("data: T\r\ndata: Hello, World\r\n")] [InlineData("data: T\r\ndata: Hello, World\r\n\r")] + [InlineData("data: B\r\ndata: SGVsbG8sIFd")] public void ParseSSEMessageIncompleteParseResult(string encodedMessage) { var buffer = Encoding.UTF8.GetBytes(encodedMessage); @@ -84,17 +89,18 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters } [Theory] - [InlineData("d", "ata: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r", "\ndata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\n", "data: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\nd", "ata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\ndata: ", "Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\ndata: Hello, World", "\r\n\r\n", "Hello, World")] - [InlineData("data: T\r\ndata: Hello, World\r\n", "\r\n", "Hello, World")] - [InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: ", "T\r\ndata: Hello, World\r\n\r\n", "Hello, World")] - public async Task ParseMessageAcrossMultipleReadsSuccess(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage) + [InlineData("d", "ata: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r", "\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\n", "data: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\nd", "ata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\ndata: ", "Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\ndata: Hello, World", "\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T\r\ndata: Hello, World\r\n", "\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: ", "T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)] + [InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\r\n", "Hello, World", MessageType.Binary)] + public async Task ParseMessageAcrossMultipleReadsSuccess(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage, MessageType expectedMessageType) { using (var pipeFactory = new PipeFactory()) { @@ -117,7 +123,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters parseResult = parser.ParseMessage(result.Buffer, out consumed, out examined, out message); Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult); - Assert.Equal(MessageType.Text, message.Type); + Assert.Equal(expectedMessageType, message.Type); Assert.Equal(consumed, examined); var resultMessage = Encoding.UTF8.GetString(message.Payload); @@ -140,6 +146,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters [InlineData("data: T\r\nda", "ta: Hello\n, World\r\n\r\n", "Unexpected '\n' in message. A '\n' character can only be used as part of the newline sequence '\r\n'")] [InlineData("data:", " data: \r\n", "Unknown message type: 'd'")] [InlineData("data: ", "T\r\ndata: Major\r\ndata: Key\r\ndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")] + [InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\n\n", "There was an error in the frame format")] public async Task ParseMessageAcrossMultipleReadsFailure(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage) { using (var pipeFactory = new PipeFactory()) @@ -167,15 +174,15 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters } } - [Fact] - public async Task ParseMultipleMessages() + [Theory] + [InlineData("data: T\r\ndata: foo\r\n\r\n", "data: T\r\ndata: bar\r\n\r\n", MessageType.Text)] + [InlineData("data: B\r\ndata: Zm9v\r\n\r\n", "data: B\r\ndata: YmFy\r\n\r\n", MessageType.Binary)] + public async Task ParseMultipleMessagesText(string message1, string message2, MessageType expectedMessageType) { using (var pipeFactory = new PipeFactory()) { var pipe = pipeFactory.Create(); - var message1 = "data: T\r\ndata: foo\r\n\r\n"; - var message2 = "data: T\r\ndata: bar\r\n\r\n"; // Read the first part of the message await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes(message1 + message2)); @@ -184,7 +191,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters var parseResult = parser.ParseMessage(result.Buffer, out var consumed, out var examined, out var message); Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult); - Assert.Equal(MessageType.Text, message.Type); + Assert.Equal(expectedMessageType, message.Type); Assert.Equal("foo", Encoding.UTF8.GetString(message.Payload)); Assert.Equal(consumed, result.Buffer.Move(result.Buffer.Start, message1.Length)); pipe.Reader.Advance(consumed, examined); @@ -195,7 +202,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters result = await pipe.Reader.ReadAsync(); parseResult = parser.ParseMessage(result.Buffer, out consumed, out examined, out message); Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult); - Assert.Equal(MessageType.Text, message.Type); + Assert.Equal(expectedMessageType, message.Type); Assert.Equal("bar", Encoding.UTF8.GetString(message.Payload)); pipe.Reader.Advance(consumed, examined); } @@ -205,14 +212,14 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters { get { - yield return new object[] { "data: T\r\ndata: Shaolin\r\ndata: Fantastic\r\n\r\n", "Shaolin" + Environment.NewLine + " Fantastic" }; - yield return new object[] { "data: T\r\ndata: The\r\ndata: Get\r\ndata: Down\r\n\r\n", "The" + Environment.NewLine + "Get" + Environment.NewLine + "Down" }; + yield return new object[] { "data: T\r\ndata: Shaolin\r\ndata: Fantastic\r\n\r\n", "Shaolin" + Environment.NewLine + " Fantastic", MessageType.Text }; + yield return new object[] { "data: T\r\ndata: The\r\ndata: Get\r\ndata: Down\r\n\r\n", "The" + Environment.NewLine + "Get" + Environment.NewLine + "Down", MessageType.Text }; } } [Theory] [MemberData(nameof(MultilineMessages))] - public void ParseMessagesWithMultipleDataLines(string encodedMessage, string expectedMessage) + public void ParseMessagesWithMultipleDataLines(string encodedMessage, string expectedMessage, MessageType expectedMessageType) { var buffer = Encoding.UTF8.GetBytes(encodedMessage); var readableBuffer = ReadableBuffer.Create(buffer); @@ -220,23 +227,11 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters var parseResult = parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message); Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult); - Assert.Equal(MessageType.Text, message.Type); + Assert.Equal(expectedMessageType, message.Type); Assert.Equal(consumed, examined); var result = Encoding.UTF8.GetString(message.Payload); Assert.Equal(expectedMessage, result); } - - [Fact] - public void ParseSSEMessageBinaryNotSupported() - { - var encodedMessage = "data: B\r\ndata: \r\n\r\n"; - var buffer = Encoding.UTF8.GetBytes(encodedMessage); - var readableBuffer = ReadableBuffer.Create(buffer); - var parser = new ServerSentEventsMessageParser(); - - var ex = Assert.Throws(() => { parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message); }); - Assert.Equal("Support for binary messages has not been implemented yet", ex.Message); - } } }