use new Base64 codec in Text and SSE Formatters (#231)

fix #192
This commit is contained in:
Andrew Stanton-Nurse 2017-02-22 16:26:58 -08:00 committed by GitHub
parent 9767dbd5c1
commit 08c550655a
5 changed files with 67 additions and 44 deletions

View File

@ -2,7 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Text;
using System.Binary;
namespace Microsoft.AspNetCore.Sockets.Formatters
{
@ -76,9 +76,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
var writtenSoFar = 0;
if (type == MessageType.Binary)
{
// TODO: We're going to need to fix this as part of https://github.com/aspnet/SignalR/issues/192
var message = Convert.ToBase64String(payload.ToArray());
var encodedSize = DataPrefix.Length + Encoding.UTF8.GetByteCount(message) + Newline.Length;
var encodedSize = DataPrefix.Length + Base64.ComputeEncodedLength(payload.Length) + Newline.Length;
if (buffer.Length < encodedSize)
{
bytesWritten = 0;
@ -87,9 +85,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
DataPrefix.CopyTo(buffer);
buffer = buffer.Slice(DataPrefix.Length);
var array = Encoding.UTF8.GetBytes(message);
array.CopyTo(buffer);
buffer = buffer.Slice(array.Length);
var encodedLength = Base64.Encode(payload, buffer);
buffer = buffer.Slice(encodedLength);
Newline.CopyTo(buffer);
writtenSoFar += encodedSize;

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Binary;
using System.IO.Pipelines;
using System.Text;
@ -38,7 +39,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
// We need at least 4 more characters of space (':', type flag, ':', and eventually the terminating ';')
// We'll still need to double-check that we have space for the terminator after we write the payload,
// but this way we can exit early if the buffer is way too small.
if (buffer.Length < 4)
if (buffer.Length < 4 + length)
{
bytesWritten = 0;
return false;
@ -56,24 +57,22 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
// Payload
if (message.Type == MessageType.Binary)
{
// Encode the payload. For now, we make it an array and use the old-fashioned types because we need to mirror packages
// I've filed https://github.com/aspnet/SignalR/issues/192 to update this. -anurse
var payload = Convert.ToBase64String(message.Payload);
if (!TextEncoder.Utf8.TryEncode(payload, buffer, out int payloadWritten))
// Encode the payload directly into the buffer
var writtenByPayload = Base64.Encode(message.Payload, buffer);
// Check that we wrote enough. Length was already set (above) to the expected length in base64-encoded bytes
if (writtenByPayload < length)
{
bytesWritten = 0;
return false;
}
written += payloadWritten;
buffer = buffer.Slice(payloadWritten);
// We did, advance the buffers and continue
buffer = buffer.Slice(writtenByPayload);
written += writtenByPayload;
}
else
{
if (buffer.Length < message.Payload.Length)
{
bytesWritten = 0;
return false;
}
message.Payload.CopyTo(buffer.Slice(0, message.Payload.Length));
written += message.Payload.Length;
buffer = buffer.Slice(message.Payload.Length);
@ -114,8 +113,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
}
// Check if there's enough space in the buffer to even bother continuing
// There are at least 4 characters we still expect to see: ':', type flag, ':', ';'.
if (buffer.Length < 4)
// There are at least 4 characters we still expect to see: ':', type flag, ':', ';', plus the (encoded) payload length.
if (buffer.Length < 4 + length)
{
message = default(Message);
bytesConsumed = 0;
@ -149,39 +148,38 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
buffer = buffer.Slice(3);
consumedSoFar += 3;
// We expect to see <length>+1 more characters. Since <length> is the exact number of bytes in the text (even if base64-encoded)
// and we expect to see the ';'
if (buffer.Length < length + 1)
{
message = default(Message);
bytesConsumed = 0;
return false;
}
// Grab the payload buffer
var payloadBuffer = buffer.Slice(0, length);
var payload = buffer.Slice(0, length);
buffer = buffer.Slice(length);
consumedSoFar += length;
// Parse the payload. For now, we make it an array and use the old-fashioned types.
// I've filed https://github.com/aspnet/SignalR/issues/192 to update this. -anurse
var payload = payloadBuffer.ToArray();
if (messageType == MessageType.Binary)
if (messageType == MessageType.Binary && payload.Length > 0)
{
byte[] decoded;
try
// Determine the output size
// Every 4 Base64 characters represents 3 bytes
var decodedLength = (int)((payload.Length / 4) * 3);
// Subtract padding bytes
if (payload[payload.Length - 1] == '=')
{
var str = Encoding.UTF8.GetString(payload);
decoded = Convert.FromBase64String(str);
decodedLength -= 1;
}
catch
if (payload.Length > 1 && payload[payload.Length - 2] == '=')
{
decodedLength -= 1;
}
// Allocate a new buffer to decode to
var decodeBuffer = new byte[decodedLength];
if (Base64.Decode(payload, decodeBuffer) != decodedLength)
{
// Decoding failure
message = default(Message);
bytesConsumed = 0;
return false;
}
payload = decoded;
payload = decodeBuffer;
}
// Verify the trailer
@ -192,7 +190,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters
return false;
}
message = new Message(payload, messageType);
message = new Message(payload.ToArray(), messageType);
bytesConsumed = consumedSoFar + 1;
return true;
}

View File

@ -12,6 +12,7 @@
</PropertyGroup>
<ItemGroup>
<PackageReference Include="System.Binary.Base64" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.Text.Primitives" Version="$(CoreFxLabsVersion)" />
<PackageReference Include="System.Threading.Tasks.Channels" Version="$(CoreFxLabsVersion)" />

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Text;
using Microsoft.AspNetCore.Sockets.Tests;
using Xunit;

View File

@ -39,6 +39,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests
[Theory]
[InlineData("0:B:;", new byte[0])]
[InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
[InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })]
[InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })]
public void WriteBinaryMessage(string encoded, byte[] payload)
{
var message = MessageTestUtils.CreateMessage(payload);
@ -99,6 +101,8 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests
[Theory]
[InlineData("0:B:;", new byte[0])]
[InlineData("8:B:q83vEg==;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
[InlineData("8:B:q83vEjQ=;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })]
[InlineData("8:B:q83vEjRW;", new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })]
public void ReadBinaryMessage(string encoded, byte[] payload)
{
var buffer = Encoding.UTF8.GetBytes(encoded);
@ -153,8 +157,32 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests
Assert.Equal(0, consumed);
}
[Theory]
[InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
[InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34 })]
[InlineData(new byte[] { 0xAB, 0xCD, 0xEF, 0x12, 0x34, 0x56 })]
public void InsufficientWriteBufferSpaceBinary(byte[] payload)
{
const int ExpectedSize = 13;
var message = MessageTestUtils.CreateMessage(payload);
byte[] buffer;
int bufferSize;
int written;
for (bufferSize = 0; bufferSize < ExpectedSize; bufferSize++)
{
buffer = new byte[bufferSize];
Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written));
Assert.Equal(0, written);
}
buffer = new byte[bufferSize];
Assert.True(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written));
Assert.Equal(ExpectedSize, written);
}
[Fact]
public void InsufficientWriteBufferSpace()
public void InsufficientWriteBufferSpaceText()
{
const int ExpectedSize = 9;
var message = MessageTestUtils.CreateMessage("Test", MessageType.Text);
@ -162,7 +190,7 @@ namespace Microsoft.AspNetCore.Sockets.Formatters.Tests
byte[] buffer;
int bufferSize;
int written;
for (bufferSize = 0; bufferSize < 9; bufferSize++)
for (bufferSize = 0; bufferSize < ExpectedSize; bufferSize++)
{
buffer = new byte[bufferSize];
Assert.False(MessageFormatter.TryFormatMessage(message, buffer, MessageFormat.Text, out written));