diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs index 5fd40034a0..1ab081b281 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/ServerSentEventsMessageFormatter.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Text; +using System.Binary; namespace Microsoft.AspNetCore.Sockets.Formatters { @@ -76,9 +76,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters var writtenSoFar = 0; if (type == MessageType.Binary) { - // TODO: We're going to need to fix this as part of https://github.com/aspnet/SignalR/issues/192 - var message = Convert.ToBase64String(payload.ToArray()); - var encodedSize = DataPrefix.Length + Encoding.UTF8.GetByteCount(message) + Newline.Length; + var encodedSize = DataPrefix.Length + Base64.ComputeEncodedLength(payload.Length) + Newline.Length; if (buffer.Length < encodedSize) { bytesWritten = 0; @@ -87,9 +85,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters DataPrefix.CopyTo(buffer); buffer = buffer.Slice(DataPrefix.Length); - var array = Encoding.UTF8.GetBytes(message); - array.CopyTo(buffer); - buffer = buffer.Slice(array.Length); + var encodedLength = Base64.Encode(payload, buffer); + buffer = buffer.Slice(encodedLength); Newline.CopyTo(buffer); writtenSoFar += encodedSize; diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs index fc5d9b2599..b2eab314a3 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.Sockets.Common/Formatters/TextMessageFormatter.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Binary; using System.IO.Pipelines; using System.Text; @@ -38,7 +39,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters // We need at least 4 more characters of space (':', type flag, ':', and eventually the terminating ';') // We'll still need to double-check that we have space for the terminator after we write the payload, // but this way we can exit early if the buffer is way too small. - if (buffer.Length < 4) + if (buffer.Length < 4 + length) { bytesWritten = 0; return false; @@ -56,24 +57,22 @@ namespace Microsoft.AspNetCore.Sockets.Formatters // Payload if (message.Type == MessageType.Binary) { - // Encode the payload. For now, we make it an array and use the old-fashioned types because we need to mirror packages - // I've filed https://github.com/aspnet/SignalR/issues/192 to update this. -anurse - var payload = Convert.ToBase64String(message.Payload); - if (!TextEncoder.Utf8.TryEncode(payload, buffer, out int payloadWritten)) + // Encode the payload directly into the buffer + var writtenByPayload = Base64.Encode(message.Payload, buffer); + + // Check that we wrote enough. Length was already set (above) to the expected length in base64-encoded bytes + if (writtenByPayload < length) { bytesWritten = 0; return false; } - written += payloadWritten; - buffer = buffer.Slice(payloadWritten); + + // We did, advance the buffers and continue + buffer = buffer.Slice(writtenByPayload); + written += writtenByPayload; } else { - if (buffer.Length < message.Payload.Length) - { - bytesWritten = 0; - return false; - } message.Payload.CopyTo(buffer.Slice(0, message.Payload.Length)); written += message.Payload.Length; buffer = buffer.Slice(message.Payload.Length); @@ -114,8 +113,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters } // Check if there's enough space in the buffer to even bother continuing - // There are at least 4 characters we still expect to see: ':', type flag, ':', ';'. - if (buffer.Length < 4) + // There are at least 4 characters we still expect to see: ':', type flag, ':', ';', plus the (encoded) payload length. + if (buffer.Length < 4 + length) { message = default(Message); bytesConsumed = 0; @@ -149,39 +148,38 @@ namespace Microsoft.AspNetCore.Sockets.Formatters buffer = buffer.Slice(3); consumedSoFar += 3; - // We expect to see +1 more characters. Since is the exact number of bytes in the text (even if base64-encoded) - // and we expect to see the ';' - if (buffer.Length < length + 1) - { - message = default(Message); - bytesConsumed = 0; - return false; - } - // Grab the payload buffer - var payloadBuffer = buffer.Slice(0, length); + var payload = buffer.Slice(0, length); buffer = buffer.Slice(length); consumedSoFar += length; // Parse the payload. For now, we make it an array and use the old-fashioned types. // I've filed https://github.com/aspnet/SignalR/issues/192 to update this. -anurse - var payload = payloadBuffer.ToArray(); - if (messageType == MessageType.Binary) + if (messageType == MessageType.Binary && payload.Length > 0) { - byte[] decoded; - try + // Determine the output size + // Every 4 Base64 characters represents 3 bytes + var decodedLength = (int)((payload.Length / 4) * 3); + + // Subtract padding bytes + if (payload[payload.Length - 1] == '=') { - var str = Encoding.UTF8.GetString(payload); - decoded = Convert.FromBase64String(str); + decodedLength -= 1; } - catch + if (payload.Length > 1 && payload[payload.Length - 2] == '=') + { + decodedLength -= 1; + } + + // Allocate a new buffer to decode to + var decodeBuffer = new byte[decodedLength]; + if (Base64.Decode(payload, decodeBuffer) != decodedLength) { - // Decoding failure message = default(Message); bytesConsumed = 0; return false; } - payload = decoded; + payload = decodeBuffer; } // Verify the trailer @@ -192,7 +190,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters return false; } - message = new Message(payload, messageType); + message = new Message(payload.ToArray(), messageType); bytesConsumed = consumedSoFar + 1; return true; } diff --git a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj index 9da735ef4f..496aa61bbd 100644 --- a/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Common/Microsoft.AspNetCore.Sockets.Common.csproj @@ -12,6 +12,7 @@ + diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs index 2d01660d4d..8bc24921da 100644 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/ServerSentEventsMessageFormatterTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Text; using Microsoft.AspNetCore.Sockets.Tests; using Xunit; diff --git a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs index 06a1e4b5b6..0f67c4c89f 100644 --- a/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Common.Tests/Formatters/TextMessageFormatterTests.cs @@ -39,6 +39,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests [Theory] [InlineData("0:B:;", new byte[0])] [InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] + [InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] public void WriteBinaryMessage(string encoded, byte[] payload) { var message = MessageTestUtils.CreateMessage(payload); @@ -99,6 +101,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests [Theory] [InlineData("0:B:;", new byte[0])] [InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] + [InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] public void ReadBinaryMessage(string encoded, byte[] payload) { var buffer = Encoding.UTF8.GetBytes(encoded); @@ -153,8 +157,32 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests Assert.Equal(0, consumed); } + [Theory] + [InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + [InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })] + [InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })] + public void InsufficientWriteBufferSpaceBinary(byte[] payload) + { + const int ExpectedSize = 13; + var message = MessageTestUtils.CreateMessage(payload); + + byte[] buffer; + int bufferSize; + int written; + for (bufferSize = 0; bufferSize < ExpectedSize; bufferSize++) + { + buffer = new byte[bufferSize]; + Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written)); + Assert.Equal(0, written); + } + + buffer = new byte[bufferSize]; + Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written)); + Assert.Equal(ExpectedSize, written); + } + [Fact] - public void InsufficientWriteBufferSpace() + public void InsufficientWriteBufferSpaceText() { const int ExpectedSize = 9; var message = MessageTestUtils.CreateMessage("Test", MessageType.Text); @@ -162,7 +190,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests byte[] buffer; int bufferSize; int written; - for (bufferSize = 0; bufferSize < 9; bufferSize++) + for (bufferSize = 0; bufferSize < ExpectedSize; bufferSize++) { buffer = new byte[bufferSize]; Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written));