Merge pull request #1798 from aspnet/release/2.1

Use IBufferWriter in IHubProtocol (#1791)
This commit is contained in:
BrennanConroy 2018-03-30 17:30:30 -07:00 committed by GitHub
commit 90aa48c09f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 267 additions and 184 deletions

View File

@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
return false; return false;
} }
public void WriteMessage(HubMessage message, Stream output) public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{ {
} }
} }

View File

@ -30,9 +30,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup] [GlobalSetup]
public void GlobalSetup() public void GlobalSetup()
{ {
var memoryBufferWriter = new MemoryBufferWriter(); using (var memoryBufferWriter = new MemoryBufferWriter())
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter); {
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false); HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
}
_pipe = new TestDuplexPipe(); _pipe = new TestDuplexPipe();

View File

@ -24,12 +24,14 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup] [GlobalSetup]
public void GlobalSetup() public void GlobalSetup()
{ {
var ms = new MemoryBufferWriter(); using (var writer = new MemoryBufferWriter())
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms); {
var handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(ms.ToArray()), false, false); HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
var handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
_pipe = new TestDuplexPipe();
_pipe.AddReadResult(new ValueTask<ReadResult>(handshakeResponseResult)); _pipe = new TestDuplexPipe();
_pipe.AddReadResult(new ValueTask<ReadResult>(handshakeResponseResult));
}
_tcs = new TaskCompletionSource<ReadResult>(); _tcs = new TaskCompletionSource<ReadResult>();
_pipe.AddReadResult(new ValueTask<ReadResult>(_tcs.Task)); _pipe.AddReadResult(new ValueTask<ReadResult>(_tcs.Task));

View File

@ -26,9 +26,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup] [GlobalSetup]
public void GlobalSetup() public void GlobalSetup()
{ {
var ms = new MemoryBufferWriter(); using (var writer = new MemoryBufferWriter())
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms); {
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(ms.ToArray()), false, false); HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
}
_pipe = new TestDuplexPipe(); _pipe = new TestDuplexPipe();

View File

@ -2,6 +2,7 @@ using System;
using System.Buffers; using System.Buffers;
using System.IO; using System.IO;
using BenchmarkDotNet.Attributes; using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Formatters;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
@ -23,19 +24,22 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{ {
var buffer = new byte[MessageLength]; var buffer = new byte[MessageLength];
Random.NextBytes(buffer); Random.NextBytes(buffer);
var output = new MemoryStream(); using (var writer = new MemoryBufferWriter())
BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, output); {
output.Write(buffer, 0, buffer.Length); BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, writer);
writer.Write(buffer);
_binaryInput = output.ToArray(); _binaryInput = writer.ToArray();
}
buffer = new byte[MessageLength]; buffer = new byte[MessageLength];
Random.NextBytes(buffer); Random.NextBytes(buffer);
output = new MemoryStream(); using (var writer = new MemoryBufferWriter())
output.Write(buffer, 0, buffer.Length); {
TextMessageFormatter.WriteRecordSeparator(output); writer.Write(buffer);
TextMessageFormatter.WriteRecordSeparator(writer);
_textInput = output.ToArray(); _textInput = writer.ToArray();
}
} }
[Benchmark] [Benchmark]

View File

@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.Internal
internal static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter) internal static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter)
{ {
var writer = new JsonTextWriter(textWriter); var writer = new JsonTextWriter(textWriter);
writer.ArrayPool = JsonArrayPool<char>.Shared;
// Don't close the output, leave closing to the caller // Don't close the output, leave closing to the caller
writer.CloseOutput = false; writer.CloseOutput = false;

View File

@ -340,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{ {
AssertConnectionValid(); AssertConnectionValid();
_protocol.WriteMessage(hubMessage, _connectionState.OutputStream); _protocol.WriteMessage(hubMessage, _connectionState.Connection.Transport.Output);
Log.SendingMessage(_logger, hubMessage); Log.SendingMessage(_logger, hubMessage);
@ -826,7 +826,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
public IConnection Connection { get; } public IConnection Connection { get; }
public Task ReceiveTask { get; set; } public Task ReceiveTask { get; set; }
public Exception CloseException { get; set; } public Exception CloseException { get; set; }
public PipeWriterStream OutputStream { get; }
public bool Stopping public bool Stopping
{ {
@ -838,7 +837,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
{ {
_hubConnection = hubConnection; _hubConnection = hubConnection;
Connection = connection; Connection = connection;
OutputStream = new PipeWriterStream(Connection.Transport.Output);
} }
public string GetNextId() => Interlocked.Increment(ref _nextId).ToString(); public string GetNextId() => Interlocked.Increment(ref _nextId).ToString();

View File

@ -2,22 +2,19 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.IO; using System.Buffers;
namespace Microsoft.AspNetCore.SignalR.Internal.Formatters namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
{ {
public static class BinaryMessageFormatter public static class BinaryMessageFormatter
{ {
public static void WriteLengthPrefix(long length, Stream output) public static void WriteLengthPrefix(long length, IBufferWriter<byte> output)
{ {
// This code writes length prefix of the message as a VarInt. Read the comment in // This code writes length prefix of the message as a VarInt. Read the comment in
// the BinaryMessageParser.TryParseMessage for details. // the BinaryMessageParser.TryParseMessage for details.
#if NETCOREAPP2_1
Span<byte> lenBuffer = stackalloc byte[5]; Span<byte> lenBuffer = stackalloc byte[5];
#else
var lenBuffer = new byte[5];
#endif
var lenNumBytes = 0; var lenNumBytes = 0;
do do
{ {
@ -32,11 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
} }
while (length > 0); while (length > 0);
#if NETCOREAPP2_1
output.Write(lenBuffer.Slice(0, lenNumBytes)); output.Write(lenBuffer.Slice(0, lenNumBytes));
#else
output.Write(lenBuffer, 0, lenNumBytes);
#endif
} }
} }
} }

View File

@ -4,24 +4,26 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR.Internal namespace Microsoft.AspNetCore.SignalR.Internal
{ {
public sealed class MemoryBufferWriter : IBufferWriter<byte> public sealed class MemoryBufferWriter : IBufferWriter<byte>, IDisposable
{ {
private readonly int _segmentSize; private readonly int _segmentSize;
private int _bytesWritten;
internal List<Memory<byte>> Segments { get; } internal List<byte[]> Segments { get; }
internal int Position { get; private set; } internal int Position { get; private set; }
public MemoryBufferWriter(int segmentSize = 2048) public MemoryBufferWriter(int segmentSize = 2048)
{ {
_segmentSize = segmentSize; _segmentSize = segmentSize;
Segments = new List<Memory<byte>>(); Segments = new List<byte[]>();
} }
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null; public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
public void Advance(int count) public void Advance(int count)
{ {
_bytesWritten += count;
Position += count; Position += count;
} }
@ -31,8 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
if (Segments.Count == 0 || Position == _segmentSize) if (Segments.Count == 0 || Position == _segmentSize)
{ {
// TODO: Rent memory from a pool Segments.Add(ArrayPool<byte>.Shared.Rent(_segmentSize));
Segments.Add(new Memory<byte>(new byte[_segmentSize]));
Position = 0; Position = 0;
} }
@ -51,17 +52,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return Array.Empty<byte>(); return Array.Empty<byte>();
} }
var totalLength = (Segments.Count - 1) * _segmentSize; var result = new byte[_bytesWritten];
totalLength += Position;
var result = new byte[totalLength];
var totalWritten = 0; var totalWritten = 0;
// Copy full segments // Copy full segments
for (int i = 0; i < Segments.Count - 1; i++) for (int i = 0; i < Segments.Count - 1; i++)
{ {
Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize)); Segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize));
totalWritten += _segmentSize; totalWritten += _segmentSize;
} }
@ -71,5 +69,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return result; return result;
} }
public void Dispose()
{
for (int i = 0; i < Segments.Count; i++)
{
ArrayPool<byte>.Shared.Return(Segments[i]);
}
Segments.Clear();
}
} }
} }

View File

@ -1,18 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{ {
public static class HubProtocolExtensions public static class HubProtocolExtensions
{ {
public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message) public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message)
{ {
using (var ms = new LimitArrayPoolWriteStream()) using (var writer = new MemoryBufferWriter())
{ {
hubProtocol.WriteMessage(message, ms); hubProtocol.WriteMessage(message, writer);
return ms.ToArray(); return writer.ToArray();
} }
} }
} }

View File

@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message); bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message);
void WriteMessage(HubMessage message, Stream output); void WriteMessage(HubMessage message, IBufferWriter<byte> output);
bool IsVersionSupported(int version); bool IsVersionSupported(int version);
} }

View File

@ -78,7 +78,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return message != null; return message != null;
} }
public void WriteMessage(HubMessage message, Stream output) public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{ {
WriteMessageCore(message, output); WriteMessageCore(message, output);
TextMessageFormatter.WriteRecordSeparator(output); TextMessageFormatter.WriteRecordSeparator(output);
@ -340,50 +340,58 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
throw new JsonReaderException("Unexpected end when reading message headers"); throw new JsonReaderException("Unexpected end when reading message headers");
} }
private void WriteMessageCore(HubMessage message, Stream stream) private void WriteMessageCore(HubMessage message, IBufferWriter<byte> stream)
{ {
using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true))) var textWriter = Utf8BufferTextWriter.Get(stream);
try
{ {
writer.WriteStartObject(); using (var writer = JsonUtils.CreateJsonTextWriter(textWriter))
switch (message)
{ {
case InvocationMessage m: writer.WriteStartObject();
WriteMessageType(writer, HubProtocolConstants.InvocationMessageType); switch (message)
WriteHeaders(writer, m); {
WriteInvocationMessage(m, writer); case InvocationMessage m:
break; WriteMessageType(writer, HubProtocolConstants.InvocationMessageType);
case StreamInvocationMessage m: WriteHeaders(writer, m);
WriteMessageType(writer, HubProtocolConstants.StreamInvocationMessageType); WriteInvocationMessage(m, writer);
WriteHeaders(writer, m); break;
WriteStreamInvocationMessage(m, writer); case StreamInvocationMessage m:
break; WriteMessageType(writer, HubProtocolConstants.StreamInvocationMessageType);
case StreamItemMessage m: WriteHeaders(writer, m);
WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType); WriteStreamInvocationMessage(m, writer);
WriteHeaders(writer, m); break;
WriteStreamItemMessage(m, writer); case StreamItemMessage m:
break; WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType);
case CompletionMessage m: WriteHeaders(writer, m);
WriteMessageType(writer, HubProtocolConstants.CompletionMessageType); WriteStreamItemMessage(m, writer);
WriteHeaders(writer, m); break;
WriteCompletionMessage(m, writer); case CompletionMessage m:
break; WriteMessageType(writer, HubProtocolConstants.CompletionMessageType);
case CancelInvocationMessage m: WriteHeaders(writer, m);
WriteMessageType(writer, HubProtocolConstants.CancelInvocationMessageType); WriteCompletionMessage(m, writer);
WriteHeaders(writer, m); break;
WriteCancelInvocationMessage(m, writer); case CancelInvocationMessage m:
break; WriteMessageType(writer, HubProtocolConstants.CancelInvocationMessageType);
case PingMessage _: WriteHeaders(writer, m);
WriteMessageType(writer, HubProtocolConstants.PingMessageType); WriteCancelInvocationMessage(m, writer);
break; break;
case CloseMessage m: case PingMessage _:
WriteMessageType(writer, HubProtocolConstants.CloseMessageType); WriteMessageType(writer, HubProtocolConstants.PingMessageType);
WriteCloseMessage(m, writer); break;
break; case CloseMessage m:
default: WriteMessageType(writer, HubProtocolConstants.CloseMessageType);
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); WriteCloseMessage(m, writer);
break;
default:
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
}
writer.WriteEndObject();
writer.Flush();
} }
writer.WriteEndObject(); }
writer.Flush(); finally
{
Utf8BufferTextWriter.Return(textWriter);
} }
} }

View File

@ -5,7 +5,6 @@ using System;
using System.Buffers; using System.Buffers;
using System.Diagnostics; using System.Diagnostics;
using System.IO; using System.IO;
using System.Linq;
using System.Runtime.InteropServices; using System.Runtime.InteropServices;
using System.Text; using System.Text;
@ -77,12 +76,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public override void Write(char[] buffer, int index, int count) public override void Write(char[] buffer, int index, int count)
{ {
WriteInternal(buffer, index, count); WriteInternal(buffer.AsSpan(index, count));
} }
public override void Write(char[] buffer) public override void Write(char[] buffer)
{ {
WriteInternal(buffer, 0, buffer.Length); WriteInternal(buffer);
} }
public override void Write(char value) public override void Write(char value)
@ -120,6 +119,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
} }
} }
public override void Write(string value)
{
WriteInternal(value.AsSpan());
}
private Span<byte> GetBuffer() private Span<byte> GetBuffer()
{ {
EnsureBuffer(); EnsureBuffer();
@ -142,11 +146,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
} }
} }
private void WriteInternal(char[] buffer, int index, int count) private void WriteInternal(ReadOnlySpan<char> buffer)
{ {
var currentIndex = index; while (buffer.Length > 0)
var charsRemaining = count;
while (charsRemaining > 0)
{ {
// The destination byte array might not be large enough so multiple writes are sometimes required // The destination byte array might not be large enough so multiple writes are sometimes required
var destination = GetBuffer(); var destination = GetBuffer();
@ -154,20 +156,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
var bytesUsed = 0; var bytesUsed = 0;
var charsUsed = 0; var charsUsed = 0;
#if NETCOREAPP2_1 #if NETCOREAPP2_1
_encoder.Convert(buffer.AsSpan(currentIndex, charsRemaining), destination, false, out charsUsed, out bytesUsed, out _); _encoder.Convert(buffer, destination, false, out charsUsed, out bytesUsed, out _);
#else #else
unsafe unsafe
{ {
fixed (char* sourceChars = &buffer[currentIndex]) fixed (char* sourceChars = &MemoryMarshal.GetReference(buffer))
fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination)) fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination))
{ {
_encoder.Convert(sourceChars, charsRemaining, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _); _encoder.Convert(sourceChars, buffer.Length, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _);
} }
} }
#endif #endif
charsRemaining -= charsUsed; buffer = buffer.Slice(charsUsed);
currentIndex += charsUsed;
_memoryUsed += bytesUsed; _memoryUsed += bytesUsed;
} }
} }

View File

@ -41,9 +41,11 @@ namespace Microsoft.AspNetCore.SignalR
static HubConnectionContext() static HubConnectionContext()
{ {
var memoryBufferWriter = new MemoryBufferWriter(); using (var memoryBufferWriter = new MemoryBufferWriter())
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter); {
_successHandshakeResponseData = memoryBufferWriter.ToArray(); HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter);
_successHandshakeResponseData = memoryBufferWriter.ToArray();
}
} }
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)

View File

@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return destination; return destination;
} }
public void WriteMessage(HubMessage message, Stream output) public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{ {
using (var stream = new LimitArrayPoolWriteStream()) using (var stream = new LimitArrayPoolWriteStream())
{ {
@ -271,7 +271,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
// Write length then message to output // Write length then message to output
BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output); BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output);
output.Write(buffer.Array, buffer.Offset, buffer.Count); output.Write(buffer);
} }
} }

View File

@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(500); hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(500);
await hubConnection.StartAsync().OrTimeout(); await hubConnection.StartAsync().OrTimeout();
// Start an invocation (but we won't complete it) // Start an invocation (but we won't complete it)
var invokeTask = hubConnection.InvokeAsync("Method").OrTimeout(); var invokeTask = hubConnection.InvokeAsync("Method").OrTimeout();
@ -156,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
throw new InvalidOperationException("No Parsed Message provided"); throw new InvalidOperationException("No Parsed Message provided");
} }
public void WriteMessage(HubMessage message, Stream output) public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{ {
if (_error != null) if (_error != null)
{ {

View File

@ -75,9 +75,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{ {
var s = await ReadSentTextMessageAsync(); var s = await ReadSentTextMessageAsync();
var output = new MemoryBufferWriter(); using (var output = new MemoryBufferWriter())
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output); {
await Application.Output.WriteAsync(output.ToArray()); HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output);
await Application.Output.WriteAsync(output.ToArray());
}
return s; return s;
} }

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Xunit; using Xunit;
@ -31,20 +32,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Encoding.UTF8.GetBytes("Hello,\r\nWorld!") Encoding.UTF8.GetBytes("Hello,\r\nWorld!")
}; };
var output = new MemoryStream(); // Use small chunks to test Advance/Enlarge and partial payload writing using (var writer = new MemoryBufferWriter()) // Use small chunks to test Advance/Enlarge and partial payload writing
foreach (var message in messages)
{ {
BinaryMessageFormatter.WriteLengthPrefix(message.Length, output); foreach (var message in messages)
output.Write(message, 0, message.Length); {
} BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer);
writer.Write(message);
}
Assert.Equal(expectedEncoding, output.ToArray()); Assert.Equal(expectedEncoding, writer.ToArray());
}
} }
[Theory] [Theory]
[InlineData(0, new byte[] { 0x00 }, new byte[0])] [InlineData(new byte[] { 0x00 }, new byte[0])]
[InlineData(0, new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] [InlineData(new byte[]
[InlineData(0, new byte[]
{ {
0x80, 0x01, // Size - 128 0x80, 0x01, // Size - 128
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
@ -68,53 +70,45 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f
})] })]
[InlineData(4, new byte[] { 0x00 }, new byte[0])] [InlineData(new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
[InlineData(4, new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] public void WriteBinaryMessage(byte[] encoded, byte[] payload)
public void WriteBinaryMessage(int offset, byte[] encoded, byte[] payload)
{ {
var output = new MemoryStream(); using (var writer = new MemoryBufferWriter())
if (offset > 0)
{ {
output.Seek(offset, SeekOrigin.Begin);
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
writer.Write(payload);
Assert.Equal(encoded, writer.ToArray());
} }
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, output);
output.Write(payload, 0, payload.Length);
Assert.Equal(encoded, output.ToArray().Skip(offset));
} }
[Theory] [Theory]
[InlineData(0, new byte[] { 0x00 }, "")] [InlineData(new byte[] { 0x00 }, "")]
[InlineData(0, new byte[] { 0x03, 0x41, 0x42, 0x43 }, "ABC")] [InlineData(new byte[] { 0x03, 0x41, 0x42, 0x43 }, "ABC")]
[InlineData(0, new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")] [InlineData(new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")]
[InlineData(4, new byte[] { 0x00 }, "")] public void WriteTextMessage(byte[] encoded, string payload)
public void WriteTextMessage(int offset, byte[] encoded, string payload)
{ {
var message = Encoding.UTF8.GetBytes(payload); var message = Encoding.UTF8.GetBytes(payload);
var output = new MemoryStream(); using (var writer = new MemoryBufferWriter())
if (offset > 0)
{ {
output.Seek(offset, SeekOrigin.Begin);
BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer);
writer.Write(message);
Assert.Equal(encoded, writer.ToArray());
} }
BinaryMessageFormatter.WriteLengthPrefix(message.Length, output);
output.Write(message, 0, message.Length);
Assert.Equal(encoded, output.ToArray().Skip(offset));
} }
[Theory] [Theory]
[MemberData(nameof(RandomPayloads))] [MemberData(nameof(RandomPayloads))]
public void RoundTrippingTest(byte[] payload) public void RoundTrippingTest(byte[] payload)
{ {
using (var ms = new MemoryStream()) using (var writer = new MemoryBufferWriter())
{ {
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, ms); BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
ms.Write(payload, 0, payload.Length); writer.Write(payload);
var buffer = new ReadOnlySequence<byte>(ms.ToArray()); var buffer = new ReadOnlySequence<byte>(writer.ToArray());
Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped)); Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped));
Assert.Equal(payload, roundtripped.ToArray()); Assert.Equal(payload, roundtripped.ToArray());
} }

View File

@ -15,6 +15,7 @@ using Xunit;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{ {
using Microsoft.AspNetCore.SignalR.Internal;
using static HubMessageHelpers; using static HubMessageHelpers;
public class JsonHubProtocolTests public class JsonHubProtocolTests
@ -108,10 +109,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var protocol = new JsonHubProtocol(Options.Create(protocolOptions)); var protocol = new JsonHubProtocol(Options.Create(protocolOptions));
using (var ms = new MemoryStream()) using (var writer = new MemoryBufferWriter())
{ {
protocol.WriteMessage(message, ms); protocol.WriteMessage(message, writer);
var json = Encoding.UTF8.GetString(ms.ToArray()); var json = Encoding.UTF8.GetString(writer.ToArray());
Assert.Equal(expectedOutput, json); Assert.Equal(expectedOutput, json);
} }

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
using System.Text; using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using MsgPack; using MsgPack;
@ -444,10 +445,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private static byte[] Frame(byte[] input) private static byte[] Frame(byte[] input)
{ {
using (var stream = new MemoryStream()) using (var stream = new MemoryBufferWriter())
{ {
BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream); BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream);
stream.Write(input, 0, input.Length); stream.Write(input);
return stream.ToArray(); return stream.ToArray();
} }
} }
@ -486,11 +487,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private static byte[] Write(HubMessage message) private static byte[] Write(HubMessage message)
{ {
var protocol = new MessagePackHubProtocol(); var protocol = new MessagePackHubProtocol();
using (var stream = new MemoryStream()) using (var writer = new MemoryBufferWriter())
{ {
protocol.WriteMessage(message, stream); protocol.WriteMessage(message, writer);
stream.Flush(); return writer.ToArray();
return stream.ToArray();
} }
} }

View File

@ -4,9 +4,7 @@
using System; using System;
using System.Buffers; using System.Buffers;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using System.Text; using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Xunit; using Xunit;
@ -17,8 +15,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void WriteChar_Unicode() public void WriteChar_Unicode()
{ {
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); var bufferWriter = new TestMemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
textWriter.Write('['); textWriter.Write('[');
@ -57,8 +55,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void WriteChar_UnicodeLastChar() public void WriteChar_UnicodeLastChar()
{ {
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); var bufferWriter = new TestMemoryBufferWriter(4096);
using (Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter()) using (var textWriter = new Utf8BufferTextWriter())
{ {
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
@ -73,8 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void WriteChar_UnicodeAndRunOutOfBufferSpace() public void WriteChar_UnicodeAndRunOutOfBufferSpace()
{ {
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); var bufferWriter = new TestMemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
textWriter.Write('['); textWriter.Write('[');
@ -124,8 +122,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
char[] chars = fourCircles.ToCharArray(); char[] chars = fourCircles.ToCharArray();
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); var bufferWriter = new TestMemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
textWriter.Write(chars, 0, 1); textWriter.Write(chars, 0, 1);
@ -153,8 +151,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
char[] chars = fourCircles.ToCharArray(); char[] chars = fourCircles.ToCharArray();
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); var bufferWriter = new TestMemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
textWriter.Write(chars[0]); textWriter.Write(chars[0]);
@ -178,8 +176,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void WriteCharArray_NonZeroStart() public void WriteCharArray_NonZeroStart()
{ {
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); var bufferWriter = new TestMemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
char[] chars = "Hello world".ToCharArray(); char[] chars = "Hello world".ToCharArray();
@ -194,8 +192,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void WriteCharArray_AcrossMultipleBuffers() public void WriteCharArray_AcrossMultipleBuffers()
{ {
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(2); var bufferWriter = new TestMemoryBufferWriter(2);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter); textWriter.SetWriter(bufferWriter);
char[] chars = "Hello world".ToCharArray(); char[] chars = "Hello world".ToCharArray();
@ -222,7 +220,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact] [Fact]
public void GetAndReturnCachedBufferTextWriter() public void GetAndReturnCachedBufferTextWriter()
{ {
MemoryBufferWriter bufferWriter1 = new MemoryBufferWriter(); var bufferWriter1 = new TestMemoryBufferWriter();
var textWriter1 = Utf8BufferTextWriter.Get(bufferWriter1); var textWriter1 = Utf8BufferTextWriter.Get(bufferWriter1);
textWriter1.Write("Hello"); textWriter1.Write("Hello");
@ -231,7 +229,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
Assert.Equal("Hello", Encoding.UTF8.GetString(bufferWriter1.ToArray())); Assert.Equal("Hello", Encoding.UTF8.GetString(bufferWriter1.ToArray()));
MemoryBufferWriter bufferWriter2 = new MemoryBufferWriter(); TestMemoryBufferWriter bufferWriter2 = new TestMemoryBufferWriter();
var textWriter2 = Utf8BufferTextWriter.Get(bufferWriter2); var textWriter2 = Utf8BufferTextWriter.Get(bufferWriter2);
textWriter2.Write("World"); textWriter2.Write("World");
@ -242,5 +240,74 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
Assert.Same(textWriter1, textWriter2); Assert.Same(textWriter1, textWriter2);
} }
private sealed class TestMemoryBufferWriter : IBufferWriter<byte>
{
private readonly int _segmentSize;
internal List<Memory<byte>> Segments { get; }
internal int Position { get; private set; }
public TestMemoryBufferWriter(int segmentSize = 2048)
{
_segmentSize = segmentSize;
Segments = new List<Memory<byte>>();
}
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
public void Advance(int count)
{
Position += count;
}
public Memory<byte> GetMemory(int sizeHint = 0)
{
// TODO: Use sizeHint
if (Segments.Count == 0 || Position == _segmentSize)
{
// TODO: Rent memory from a pool
Segments.Add(new Memory<byte>(new byte[_segmentSize]));
Position = 0;
}
return CurrentSegment.Slice(Position, CurrentSegment.Length - Position);
}
public Span<byte> GetSpan(int sizeHint = 0)
{
return GetMemory(sizeHint).Span;
}
public byte[] ToArray()
{
if (Segments.Count == 0)
{
return Array.Empty<byte>();
}
var totalLength = (Segments.Count - 1) * _segmentSize;
totalLength += Position;
var result = new byte[totalLength];
var totalWritten = 0;
// Copy full segments
for (int i = 0; i < Segments.Count - 1; i++)
{
Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize));
totalWritten += _segmentSize;
}
// Copy current incomplete segment
CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position));
return result;
}
}
} }
} }

View File

@ -66,9 +66,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{ {
if (sendHandshakeRequestMessage) if (sendHandshakeRequestMessage)
{ {
var memoryBufferWriter = new MemoryBufferWriter(); using (var memoryBufferWriter = new MemoryBufferWriter())
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); {
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
}
} }
var connection = handler.OnConnectedAsync(Connection); var connection = handler.OnConnectedAsync(Connection);