Support binary messages in SSE parser (#418)
This commit is contained in:
parent
3006d315cc
commit
74b318b3e4
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Binary;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
||||
{
|
||||
internal static class MessageFormatUtils
|
||||
{
|
||||
public static byte[] DecodePayload(byte[] inputPayload)
|
||||
{
|
||||
if (inputPayload.Length > 0)
|
||||
{
|
||||
// Determine the output size
|
||||
// Every 4 Base64 characters represents 3 bytes
|
||||
var decodedLength = (inputPayload.Length / 4) * 3;
|
||||
|
||||
// Subtract padding bytes
|
||||
if (inputPayload[inputPayload.Length - 1] == '=')
|
||||
{
|
||||
decodedLength -= 1;
|
||||
}
|
||||
if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=')
|
||||
{
|
||||
decodedLength -= 1;
|
||||
}
|
||||
|
||||
// Allocate a new buffer to decode to
|
||||
var decodeBuffer = new byte[decodedLength];
|
||||
if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength)
|
||||
{
|
||||
throw new FormatException("Invalid Base64 payload");
|
||||
}
|
||||
return decodeBuffer;
|
||||
}
|
||||
|
||||
return inputPayload;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -32,7 +32,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
examined = buffer.End;
|
||||
message = new Message();
|
||||
var reader = new ReadableBufferReader(buffer);
|
||||
_messageType = MessageType.Text;
|
||||
|
||||
var start = consumed;
|
||||
var end = examined;
|
||||
|
|
@ -83,6 +82,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
EnsureStartsWithDataPrefix(line);
|
||||
}
|
||||
|
||||
var payload = Array.Empty<byte>();
|
||||
switch (_internalParserState)
|
||||
{
|
||||
case InternalParseState.ReadMessageType:
|
||||
|
|
@ -94,7 +94,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
consumed = lineEnd;
|
||||
break;
|
||||
case InternalParseState.ReadMessagePayload:
|
||||
|
||||
// Slice away the 'data: '
|
||||
var payloadLength = line.Length - (_dataPrefix.Length + _sseLineEnding.Length);
|
||||
var newData = line.Slice(_dataPrefix.Length, payloadLength).ToArray();
|
||||
|
|
@ -104,42 +103,56 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
consumed = lineEnd;
|
||||
break;
|
||||
case InternalParseState.ReadEndOfMessage:
|
||||
if (_data.Count > 0)
|
||||
if (_data.Count == 1)
|
||||
{
|
||||
payload = _data[0];
|
||||
}
|
||||
else if (_data.Count > 1)
|
||||
{
|
||||
// Find the final size of the payload
|
||||
var payloadSize = 0;
|
||||
foreach (var dataLine in _data)
|
||||
{
|
||||
payloadSize += dataLine.Length + _newLine.Length;
|
||||
payloadSize += dataLine.Length;
|
||||
}
|
||||
|
||||
// Allocate space in the paylod buffer for the data and the new lines.
|
||||
// Subtract newLine length because we don't want a trailing newline.
|
||||
var payload = new byte[payloadSize - _newLine.Length];
|
||||
if (_messageType != MessageType.Binary)
|
||||
{
|
||||
payloadSize += _newLine.Length*_data.Count;
|
||||
|
||||
// Allocate space in the paylod buffer for the data and the new lines.
|
||||
// Subtract newLine length because we don't want a trailing newline.
|
||||
payload = new byte[payloadSize - _newLine.Length];
|
||||
}
|
||||
else
|
||||
{
|
||||
payload = new byte[payloadSize];
|
||||
}
|
||||
|
||||
var offset = 0;
|
||||
foreach (var dataLine in _data)
|
||||
{
|
||||
dataLine.CopyTo(payload, offset);
|
||||
offset += dataLine.Length;
|
||||
if (offset < payload.Length)
|
||||
if (offset < payload.Length && _messageType != MessageType.Binary)
|
||||
{
|
||||
_newLine.CopyTo(payload, offset);
|
||||
offset += _newLine.Length;
|
||||
}
|
||||
}
|
||||
message = new Message(payload, _messageType);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Empty message
|
||||
message = new Message(Array.Empty<byte>(), _messageType);
|
||||
}
|
||||
|
||||
if (_messageType == MessageType.Binary)
|
||||
{
|
||||
payload = MessageFormatUtils.DecodePayload(payload);
|
||||
}
|
||||
|
||||
message = new Message(payload, _messageType);
|
||||
consumed = lineEnd;
|
||||
examined = consumed;
|
||||
return ParseResult.Completed;
|
||||
}
|
||||
|
||||
if (reader.Peek() == ByteCR)
|
||||
{
|
||||
_internalParserState = InternalParseState.ReadEndOfMessage;
|
||||
|
|
@ -188,7 +201,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
case ByteT:
|
||||
return MessageType.Text;
|
||||
case ByteB:
|
||||
throw new NotSupportedException("Support for binary messages has not been implemented yet");
|
||||
return MessageType.Binary;
|
||||
case ByteC:
|
||||
return MessageType.Close;
|
||||
case ByteE:
|
||||
|
|
|
|||
|
|
@ -155,7 +155,10 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
|
||||
if (_state.Read == _state.Length)
|
||||
{
|
||||
_state.Payload = DecodePayload(_state.Payload);
|
||||
if (_state.MessageType == MessageType.Binary)
|
||||
{
|
||||
_state.Payload = MessageFormatUtils.DecodePayload(_state.Payload);
|
||||
}
|
||||
|
||||
_state.Phase = ParsePhase.PayloadComplete;
|
||||
}
|
||||
|
|
@ -169,36 +172,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Formatters
|
|||
}
|
||||
}
|
||||
|
||||
private byte[] DecodePayload(byte[] inputPayload)
|
||||
{
|
||||
if (_state.MessageType == MessageType.Binary && inputPayload.Length > 0)
|
||||
{
|
||||
// Determine the output size
|
||||
// Every 4 Base64 characters represents 3 bytes
|
||||
var decodedLength = (inputPayload.Length / 4) * 3;
|
||||
|
||||
// Subtract padding bytes
|
||||
if (inputPayload[inputPayload.Length - 1] == '=')
|
||||
{
|
||||
decodedLength -= 1;
|
||||
}
|
||||
if (inputPayload.Length > 1 && inputPayload[inputPayload.Length - 2] == '=')
|
||||
{
|
||||
decodedLength -= 1;
|
||||
}
|
||||
|
||||
// Allocate a new buffer to decode to
|
||||
var decodeBuffer = new byte[decodedLength];
|
||||
if (Base64.Decode(inputPayload, decodeBuffer) != decodedLength)
|
||||
{
|
||||
throw new FormatException("Invalid Base64 payload");
|
||||
}
|
||||
return decodeBuffer;
|
||||
}
|
||||
|
||||
return inputPayload;
|
||||
}
|
||||
|
||||
private static bool TryParseType(byte type, out MessageType messageType)
|
||||
{
|
||||
switch ((char)type)
|
||||
|
|
|
|||
|
|
@ -14,13 +14,16 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
public class ServerSentEventsParserTests
|
||||
{
|
||||
[Theory]
|
||||
[InlineData("data: T\r\n\r\n", "")]
|
||||
[InlineData("data: T\r\ndata: \r\r\n\r\n", "\r")]
|
||||
[InlineData("data: T\r\ndata: A\rB\r\n\r\n", "A\rB")]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\ndata: ", "Hello, World")]
|
||||
public void ParseSSEMessageSuccessCases(string encodedMessage, string expectedMessage)
|
||||
[InlineData("data: T\r\n\r\n", "", MessageType.Text)]
|
||||
[InlineData("data: B\r\n\r\n", "", MessageType.Binary)]
|
||||
[InlineData("data: T\r\ndata: \r\r\n\r\n", "\r", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: A\rB\r\n\r\n", "A\rB", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\ndata: ", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: B\r\ndata: SGVsbG8sIFdvcmxk\r\n\r\n", "Hello, World", MessageType.Binary)]
|
||||
[InlineData("data: B\r\ndata: SGVsbG8g\r\ndata: V29ybGQ=\r\n\r\n", "Hello World", MessageType.Binary)]
|
||||
public void ParseSSEMessageSuccessCases(string encodedMessage, string expectedMessage, MessageType messageType)
|
||||
{
|
||||
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
|
||||
var readableBuffer = ReadableBuffer.Create(buffer);
|
||||
|
|
@ -28,7 +31,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
|
||||
var parseResult = parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message);
|
||||
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
|
||||
Assert.Equal(MessageType.Text, message.Type);
|
||||
Assert.Equal(messageType, message.Type);
|
||||
Assert.Equal(consumed, examined);
|
||||
|
||||
var result = Encoding.UTF8.GetString(message.Payload);
|
||||
|
|
@ -52,6 +55,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
[InlineData("data: T\r\ndata: Hello, World\r\n\r\\", "Expected a \\r\\n frame ending")]
|
||||
[InlineData("data: T\r\ndata: Major\r\ndata: Key\rndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")]
|
||||
[InlineData("data: T\r\ndata: Major\r\ndata: Key\r\ndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")]
|
||||
[InlineData("data: B\r\n SGVsbG8sIFdvcmxk\r\n\r\n", "Expected the message prefix 'data: '")]
|
||||
public void ParseSSEMessageFailureCases(string encodedMessage, string expectedExceptionMessage)
|
||||
{
|
||||
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
|
||||
|
|
@ -72,6 +76,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
[InlineData("data: T\r\ndata: Hello, World\r")]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n")]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n\r")]
|
||||
[InlineData("data: B\r\ndata: SGVsbG8sIFd")]
|
||||
public void ParseSSEMessageIncompleteParseResult(string encodedMessage)
|
||||
{
|
||||
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
|
||||
|
|
@ -84,17 +89,18 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData("d", "ata: T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r", "\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\n", "data: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\nd", "ata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\ndata: ", "Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\ndata: Hello, World", "\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n", "\r\n", "Hello, World")]
|
||||
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
[InlineData("data: ", "T\r\ndata: Hello, World\r\n\r\n", "Hello, World")]
|
||||
public async Task ParseMessageAcrossMultipleReadsSuccess(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage)
|
||||
[InlineData("d", "ata: T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r", "\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\n", "data: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\nd", "ata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: ", "Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: Hello, World", "\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T\r\ndata: Hello, World\r\n", "\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: T", "\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: ", "T\r\ndata: Hello, World\r\n\r\n", "Hello, World", MessageType.Text)]
|
||||
[InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\r\n", "Hello, World", MessageType.Binary)]
|
||||
public async Task ParseMessageAcrossMultipleReadsSuccess(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage, MessageType expectedMessageType)
|
||||
{
|
||||
using (var pipeFactory = new PipeFactory())
|
||||
{
|
||||
|
|
@ -117,7 +123,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
|
||||
parseResult = parser.ParseMessage(result.Buffer, out consumed, out examined, out message);
|
||||
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
|
||||
Assert.Equal(MessageType.Text, message.Type);
|
||||
Assert.Equal(expectedMessageType, message.Type);
|
||||
Assert.Equal(consumed, examined);
|
||||
|
||||
var resultMessage = Encoding.UTF8.GetString(message.Payload);
|
||||
|
|
@ -140,6 +146,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
[InlineData("data: T\r\nda", "ta: Hello\n, World\r\n\r\n", "Unexpected '\n' in message. A '\n' character can only be used as part of the newline sequence '\r\n'")]
|
||||
[InlineData("data:", " data: \r\n", "Unknown message type: 'd'")]
|
||||
[InlineData("data: ", "T\r\ndata: Major\r\ndata: Key\r\ndata: Alert\r\n\r\\", "Expected a \\r\\n frame ending")]
|
||||
[InlineData("data: B\r\ndata: SGVs", "bG8sIFdvcmxk\r\n\n\n", "There was an error in the frame format")]
|
||||
public async Task ParseMessageAcrossMultipleReadsFailure(string encodedMessagePart1, string encodedMessagePart2, string expectedMessage)
|
||||
{
|
||||
using (var pipeFactory = new PipeFactory())
|
||||
|
|
@ -167,15 +174,15 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ParseMultipleMessages()
|
||||
[Theory]
|
||||
[InlineData("data: T\r\ndata: foo\r\n\r\n", "data: T\r\ndata: bar\r\n\r\n", MessageType.Text)]
|
||||
[InlineData("data: B\r\ndata: Zm9v\r\n\r\n", "data: B\r\ndata: YmFy\r\n\r\n", MessageType.Binary)]
|
||||
public async Task ParseMultipleMessagesText(string message1, string message2, MessageType expectedMessageType)
|
||||
{
|
||||
using (var pipeFactory = new PipeFactory())
|
||||
{
|
||||
var pipe = pipeFactory.Create();
|
||||
|
||||
var message1 = "data: T\r\ndata: foo\r\n\r\n";
|
||||
var message2 = "data: T\r\ndata: bar\r\n\r\n";
|
||||
// Read the first part of the message
|
||||
await pipe.Writer.WriteAsync(Encoding.UTF8.GetBytes(message1 + message2));
|
||||
|
||||
|
|
@ -184,7 +191,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
|
||||
var parseResult = parser.ParseMessage(result.Buffer, out var consumed, out var examined, out var message);
|
||||
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
|
||||
Assert.Equal(MessageType.Text, message.Type);
|
||||
Assert.Equal(expectedMessageType, message.Type);
|
||||
Assert.Equal("foo", Encoding.UTF8.GetString(message.Payload));
|
||||
Assert.Equal(consumed, result.Buffer.Move(result.Buffer.Start, message1.Length));
|
||||
pipe.Reader.Advance(consumed, examined);
|
||||
|
|
@ -195,7 +202,7 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
result = await pipe.Reader.ReadAsync();
|
||||
parseResult = parser.ParseMessage(result.Buffer, out consumed, out examined, out message);
|
||||
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
|
||||
Assert.Equal(MessageType.Text, message.Type);
|
||||
Assert.Equal(expectedMessageType, message.Type);
|
||||
Assert.Equal("bar", Encoding.UTF8.GetString(message.Payload));
|
||||
pipe.Reader.Advance(consumed, examined);
|
||||
}
|
||||
|
|
@ -205,14 +212,14 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
{
|
||||
get
|
||||
{
|
||||
yield return new object[] { "data: T\r\ndata: Shaolin\r\ndata: Fantastic\r\n\r\n", "Shaolin" + Environment.NewLine + " Fantastic" };
|
||||
yield return new object[] { "data: T\r\ndata: The\r\ndata: Get\r\ndata: Down\r\n\r\n", "The" + Environment.NewLine + "Get" + Environment.NewLine + "Down" };
|
||||
yield return new object[] { "data: T\r\ndata: Shaolin\r\ndata: Fantastic\r\n\r\n", "Shaolin" + Environment.NewLine + " Fantastic", MessageType.Text };
|
||||
yield return new object[] { "data: T\r\ndata: The\r\ndata: Get\r\ndata: Down\r\n\r\n", "The" + Environment.NewLine + "Get" + Environment.NewLine + "Down", MessageType.Text };
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(MultilineMessages))]
|
||||
public void ParseMessagesWithMultipleDataLines(string encodedMessage, string expectedMessage)
|
||||
public void ParseMessagesWithMultipleDataLines(string encodedMessage, string expectedMessage, MessageType expectedMessageType)
|
||||
{
|
||||
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
|
||||
var readableBuffer = ReadableBuffer.Create(buffer);
|
||||
|
|
@ -220,23 +227,11 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters
|
|||
|
||||
var parseResult = parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message);
|
||||
Assert.Equal(ServerSentEventsMessageParser.ParseResult.Completed, parseResult);
|
||||
Assert.Equal(MessageType.Text, message.Type);
|
||||
Assert.Equal(expectedMessageType, message.Type);
|
||||
Assert.Equal(consumed, examined);
|
||||
|
||||
var result = Encoding.UTF8.GetString(message.Payload);
|
||||
Assert.Equal(expectedMessage, result);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ParseSSEMessageBinaryNotSupported()
|
||||
{
|
||||
var encodedMessage = "data: B\r\ndata: \r\n\r\n";
|
||||
var buffer = Encoding.UTF8.GetBytes(encodedMessage);
|
||||
var readableBuffer = ReadableBuffer.Create(buffer);
|
||||
var parser = new ServerSentEventsMessageParser();
|
||||
|
||||
var ex = Assert.Throws<NotSupportedException>(() => { parser.ParseMessage(readableBuffer, out var consumed, out var examined, out Message message); });
|
||||
Assert.Equal("Support for binary messages has not been implemented yet", ex.Message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue