Support binary messages in SSE parser (#418)

This commit is contained in:
Mikael Mengistu 2017-04-26 15:04:48 -07:00 committed by GitHub
parent 3006d315cc
commit 74b318b3e4
4 changed files with 110 additions and 88 deletions

View File

@ -0,0 +1,41 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Binary;
namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
{
internal static class MessageFormatUtils
{
public static byte[] DecodePayload(byte[] inputPayload)
{
if (inputPayload.Length > 0)
{
// Determine the output size
// Every 4 Base64 characters represents 3 bytes
var decodedLength = (inputPayload.Length / 4) * 3;
// Subtract padding bytes
if (inputPayload[inputPayload.Length - 1] == '=')
{
decodedLength -= 1;
}
if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=')
{
decodedLength -= 1;
}
// Allocate a new buffer to decode to
var decodeBuffer = new byte[decodedLength];
if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength)
{
throw new FormatException("Invalid Base64 payload");
}
return decodeBuffer;
}
return inputPayload;
}
}
}

View File

@ -32,7 +32,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
examined = buffer.End;
message = new Message();
var reader = new ReadableBufferReader(buffer);
_messageType = MessageType.Text;
var start = consumed;
var end = examined;
@ -83,6 +82,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
EnsureStartsWithDataPrefix(line);
}
var payload = Array.Empty<byte>();
switch (_internalParserState)
{
case InternalParseState.ReadMessageType:
@ -94,7 +94,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
consumed = lineEnd;
break;
case InternalParseState.ReadMessagePayload:
// Slice away the 'data: '
var payloadLength = line.Length - (_dataPrefix.Length + _sseLineEnding.Length);
var newData = line.Slice(_dataPrefix.Length, payloadLength).ToArray();
@ -104,42 +103,56 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
consumed = lineEnd;
break;
case InternalParseState.ReadEndOfMessage:
if (_data.Count > 0)
if (_data.Count == 1)
{
payload = _data[0];
}
else if (_data.Count > 1)
{
// Find the final size of the payload
var payloadSize = 0;
foreach (var dataLine in _data)
{
payloadSize += dataLine.Length + _newLine.Length;
payloadSize += dataLine.Length;
}
// Allocate space in the paylod buffer for the data and the new lines.
// Subtract newLine length because we don't want a trailing newline.
var payload = new byte[payloadSize - _newLine.Length];
if (_messageType != MessageType.Binary)
{
payloadSize += _newLine.Length*_data.Count;
// Allocate space in the paylod buffer for the data and the new lines.
// Subtract newLine length because we don't want a trailing newline.
payload = new byte[payloadSize - _newLine.Length];
}
else
{
payload = new byte[payloadSize];
}
var offset = 0;
foreach (var dataLine in _data)
{
dataLine.CopyTo(payload, offset);
offset += dataLine.Length;
if (offset < payload.Length)
if (offset < payload.Length && _messageType != MessageType.Binary)
{
_newLine.CopyTo(payload, offset);
offset += _newLine.Length;
}
}
message = new Message(payload, _messageType);
}
else
{
// Empty message
message = new Message(Array.Empty<byte>(), _messageType);
}
if (_messageType == MessageType.Binary)
{
payload = MessageFormatUtils.DecodePayload(payload);
}
message = new Message(payload, _messageType);
consumed = lineEnd;
examined = consumed;
return ParseResult.Completed;
}
if (reader.Peek() == ByteCR)
{
_internalParserState = InternalParseState.ReadEndOfMessage;
@ -188,7 +201,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
case ByteT:
return MessageType.Text;
case ByteB:
throw new NotSupportedException("Support for binary messages has not been implemented yet");
return MessageType.Binary;
case ByteC:
return MessageType.Close;
case ByteE:

View File

@ -155,7 +155,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
if (_state.Read == _state.Length)
{
_state.Payload = DecodePayload(_state.Payload);
if (_state.MessageType == MessageType.Binary)
{
_state.Payload = MessageFormatUtils.DecodePayload(_state.Payload);
}
_state.Phase = ParsePhase.PayloadComplete;
}
@ -169,36 +172,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
}
}
private byte[] DecodePayload(byte[] inputPayload)
{
if (_state.MessageType == MessageType.Binary && inputPayload.Length > 0)
{
// Determine the output size
// Every 4 Base64 characters represents 3 bytes
var decodedLength = (inputPayload.Length / 4) * 3;
// Subtract padding bytes
if (inputPayload[inputPayload.Length - 1] == '=')
{
decodedLength -= 1;
}
if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=')
{
decodedLength -= 1;
}
// Allocate a new buffer to decode to
var decodeBuffer = new byte[decodedLength];
if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength)
{
throw new FormatException("Invalid Base64 payload");
}
return decodeBuffer;
}
return inputPayload;
}
private static bool TryParseType(byte type, out MessageType messageType)
{
switch ((char)type)

View File

@ -14,13 +14,16 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
public class ServerSentEventsParserTests
{
[Theory]
[InlineData("data: T\r\n\r\n", "")]
[InlineData("data: T\r\ndata: \r\r\n\r\n", "\r")]
[InlineData("data: T\r\ndata: A\rB\r\n\r\n", "A\rB")]
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\ndata: Hello, World\r\n\r\ndata: ", "Hello, World")]
public void ParseSSEMessageSuccessCases(string encodedMessage, string expectedMessage)
[InlineData("data: T\r\n\r\n", "", MessageType.Text)]
[InlineData("data: B\r\n\r\n", "", MessageType.Binary)]
[InlineData("data: T\r\ndata: \r\r\n\r\n", "\r", MessageType.Text)]
[InlineData("data: T\r\ndata: A\rB\r\n\r\n", "A\rB", MessageType.Text)]
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\ndata: Hello, World\r\n\r\ndata: ", "Hello, World", MessageType.Text)]
[InlineData("data: B\r\ndata: SGVsbG8sIFdvcmxk\r\n\r\n", "Hello, World", MessageType.Binary)]
[InlineData("data: B\r\ndata: SGVsbG8g\r\ndata: V29ybGQ=\r\n\r\n", "Hello World", MessageType.Binary)]
public void ParseSSEMessageSuccessCases(string encodedMessage, string expectedMessage, MessageType messageType)
{
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
var readableBuffer = ReadableBuffer.Create(buffer);
@ -28,7 +31,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
var parseResult = parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message);
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
Assert.Equal(MessageType.Text, message.Type);
Assert.Equal(messageType, message.Type);
Assert.Equal(consumed, examined);
var result = Encoding.UTF8.GetString(message.Payload);
@ -52,6 +55,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
[InlineData("data: T\r\ndata: Hello, World\r\n\r\\", "Expected a \\r\\n frame ending")]
[InlineData("data: T\r\ndata: Major\r\ndata: Key\rndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")]
[InlineData("data: T\r\ndata: Major\r\ndata: Key\r\ndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")]
[InlineData("data: B\r\n SGVsbG8sIFdvcmxk\r\n\r\n", "Expected the message prefix 'data: '")]
public void ParseSSEMessageFailureCases(string encodedMessage, string expectedExceptionMessage)
{
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
@ -72,6 +76,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
[InlineData("data: T\r\ndata: Hello, World\r")]
[InlineData("data: T\r\ndata: Hello, World\r\n")]
[InlineData("data: T\r\ndata: Hello, World\r\n\r")]
[InlineData("data: B\r\ndata: SGVsbG8sIFd")]
public void ParseSSEMessageIncompleteParseResult(string encodedMessage)
{
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
@ -84,17 +89,18 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
}
[Theory]
[InlineData("d", "ata: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r", "\ndata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\n", "data: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\nd", "ata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\ndata: ", "Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\ndata: Hello, World", "\r\n\r\n", "Hello, World")]
[InlineData("data: T\r\ndata: Hello, World\r\n", "\r\n", "Hello, World")]
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: ", "T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
public async Task ParseMessageAcrossMultipleReadsSuccess(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage)
[InlineData("d", "ata: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r", "\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\n", "data: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\nd", "ata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\ndata: ", "Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\ndata: Hello, World", "\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T\r\ndata: Hello, World\r\n", "\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: ", "T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
[InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\r\n", "Hello, World", MessageType.Binary)]
public async Task ParseMessageAcrossMultipleReadsSuccess(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage, MessageType expectedMessageType)
{
using (var pipeFactory = new PipeFactory())
{
@ -117,7 +123,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
parseResult = parser.ParseMessage(result.Buffer, out consumed, out examined, out message);
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
Assert.Equal(MessageType.Text, message.Type);
Assert.Equal(expectedMessageType, message.Type);
Assert.Equal(consumed, examined);
var resultMessage = Encoding.UTF8.GetString(message.Payload);
@ -140,6 +146,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
[InlineData("data: T\r\nda", "ta: Hello\n, World\r\n\r\n", "Unexpected '\n' in message. A '\n' character can only be used as part of the newline sequence '\r\n'")]
[InlineData("data:", " data: \r\n", "Unknown message type: 'd'")]
[InlineData("data: ", "T\r\ndata: Major\r\ndata: Key\r\ndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")]
[InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\n\n", "There was an error in the frame format")]
public async Task ParseMessageAcrossMultipleReadsFailure(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage)
{
using (var pipeFactory = new PipeFactory())
@ -167,15 +174,15 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
}
}
[Fact]
public async Task ParseMultipleMessages()
[Theory]
[InlineData("data: T\r\ndata: foo\r\n\r\n", "data: T\r\ndata: bar\r\n\r\n", MessageType.Text)]
[InlineData("data: B\r\ndata: Zm9v\r\n\r\n", "data: B\r\ndata: YmFy\r\n\r\n", MessageType.Binary)]
public async Task ParseMultipleMessagesText(string message1, string message2, MessageType expectedMessageType)
{
using (var pipeFactory = new PipeFactory())
{
var pipe = pipeFactory.Create();
var message1 = "data: T\r\ndata: foo\r\n\r\n";
var message2 = "data: T\r\ndata: bar\r\n\r\n";
// Read the first part of the message
await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes(message1 + message2));
@ -184,7 +191,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
var parseResult = parser.ParseMessage(result.Buffer, out var consumed, out var examined, out var message);
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
Assert.Equal(MessageType.Text, message.Type);
Assert.Equal(expectedMessageType, message.Type);
Assert.Equal("foo", Encoding.UTF8.GetString(message.Payload));
Assert.Equal(consumed, result.Buffer.Move(result.Buffer.Start, message1.Length));
pipe.Reader.Advance(consumed, examined);
@ -195,7 +202,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
result = await pipe.Reader.ReadAsync();
parseResult = parser.ParseMessage(result.Buffer, out consumed, out examined, out message);
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
Assert.Equal(MessageType.Text, message.Type);
Assert.Equal(expectedMessageType, message.Type);
Assert.Equal("bar", Encoding.UTF8.GetString(message.Payload));
pipe.Reader.Advance(consumed, examined);
}
@ -205,14 +212,14 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
{
get
{
yield return new object[] { "data: T\r\ndata: Shaolin\r\ndata: Fantastic\r\n\r\n", "Shaolin" + Environment.NewLine + " Fantastic" };
yield return new object[] { "data: T\r\ndata: The\r\ndata: Get\r\ndata: Down\r\n\r\n", "The" + Environment.NewLine + "Get" + Environment.NewLine + "Down" };
yield return new object[] { "data: T\r\ndata: Shaolin\r\ndata: Fantastic\r\n\r\n", "Shaolin" + Environment.NewLine + " Fantastic", MessageType.Text };
yield return new object[] { "data: T\r\ndata: The\r\ndata: Get\r\ndata: Down\r\n\r\n", "The" + Environment.NewLine + "Get" + Environment.NewLine + "Down", MessageType.Text };
}
}
[Theory]
[MemberData(nameof(MultilineMessages))]
public void ParseMessagesWithMultipleDataLines(string encodedMessage, string expectedMessage)
public void ParseMessagesWithMultipleDataLines(string encodedMessage, string expectedMessage, MessageType expectedMessageType)
{
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
var readableBuffer = ReadableBuffer.Create(buffer);
@ -220,23 +227,11 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
var parseResult = parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message);
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
Assert.Equal(MessageType.Text, message.Type);
Assert.Equal(expectedMessageType, message.Type);
Assert.Equal(consumed, examined);
var result = Encoding.UTF8.GetString(message.Payload);
Assert.Equal(expectedMessage, result);
}
[Fact]
public void ParseSSEMessageBinaryNotSupported()
{
var encodedMessage = "data: B\r\ndata: \r\n\r\n";
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
var readableBuffer = ReadableBuffer.Create(buffer);
var parser = new ServerSentEventsMessageParser();
var ex = Assert.Throws<NotSupportedException>(() => { parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message); });
Assert.Equal("Support for binary messages has not been implemented yet", ex.Message);
}
}
}