Use IBufferWriter in IHubProtocol (#1791)

This commit is contained in:
BrennanConroy 2018-03-30 17:30:08 -07:00 committed by GitHub
parent eb7dc14c39
commit 903a9ea6a0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
22 changed files with 267 additions and 184 deletions

View File

@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
return false;
}
public void WriteMessage(HubMessage message, Stream output)
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
}
}

View File

@ -30,9 +30,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
using (var memoryBufferWriter = new MemoryBufferWriter())
{
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
}
_pipe = new TestDuplexPipe();

View File

@ -24,12 +24,14 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
var ms = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms);
var handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(ms.ToArray()), false, false);
_pipe = new TestDuplexPipe();
_pipe.AddReadResult(new ValueTask<ReadResult>(handshakeResponseResult));
using (var writer = new MemoryBufferWriter())
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
var handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
_pipe = new TestDuplexPipe();
_pipe.AddReadResult(new ValueTask<ReadResult>(handshakeResponseResult));
}
_tcs = new TaskCompletionSource<ReadResult>();
_pipe.AddReadResult(new ValueTask<ReadResult>(_tcs.Task));

View File

@ -26,9 +26,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
var ms = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(ms.ToArray()), false, false);
using (var writer = new MemoryBufferWriter())
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
}
_pipe = new TestDuplexPipe();

View File

@ -2,6 +2,7 @@ using System;
using System.Buffers;
using System.IO;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
@ -23,19 +24,22 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
var buffer = new byte[MessageLength];
Random.NextBytes(buffer);
var output = new MemoryStream();
BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, output);
output.Write(buffer, 0, buffer.Length);
_binaryInput = output.ToArray();
using (var writer = new MemoryBufferWriter())
{
BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, writer);
writer.Write(buffer);
_binaryInput = writer.ToArray();
}
buffer = new byte[MessageLength];
Random.NextBytes(buffer);
output = new MemoryStream();
output.Write(buffer, 0, buffer.Length);
TextMessageFormatter.WriteRecordSeparator(output);
using (var writer = new MemoryBufferWriter())
{
writer.Write(buffer);
TextMessageFormatter.WriteRecordSeparator(writer);
_textInput = output.ToArray();
_textInput = writer.ToArray();
}
}
[Benchmark]

View File

@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.Internal
internal static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter)
{
var writer = new JsonTextWriter(textWriter);
writer.ArrayPool = JsonArrayPool<char>.Shared;
// Don't close the output, leave closing to the caller
writer.CloseOutput = false;

View File

@ -340,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
AssertConnectionValid();
_protocol.WriteMessage(hubMessage, _connectionState.OutputStream);
_protocol.WriteMessage(hubMessage, _connectionState.Connection.Transport.Output);
Log.SendingMessage(_logger, hubMessage);
@ -826,7 +826,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
public IConnection Connection { get; }
public Task ReceiveTask { get; set; }
public Exception CloseException { get; set; }
public PipeWriterStream OutputStream { get; }
public bool Stopping
{
@ -838,7 +837,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
_hubConnection = hubConnection;
Connection = connection;
OutputStream = new PipeWriterStream(Connection.Transport.Output);
}
public string GetNextId() => Interlocked.Increment(ref _nextId).ToString();

View File

@ -2,22 +2,19 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Buffers;
namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
{
public static class BinaryMessageFormatter
{
public static void WriteLengthPrefix(long length, Stream output)
public static void WriteLengthPrefix(long length, IBufferWriter<byte> output)
{
// This code writes length prefix of the message as a VarInt. Read the comment in
// the BinaryMessageParser.TryParseMessage for details.
#if NETCOREAPP2_1
Span<byte> lenBuffer = stackalloc byte[5];
#else
var lenBuffer = new byte[5];
#endif
var lenNumBytes = 0;
do
{
@ -32,11 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
}
while (length > 0);
#if NETCOREAPP2_1
output.Write(lenBuffer.Slice(0, lenNumBytes));
#else
output.Write(lenBuffer, 0, lenNumBytes);
#endif
}
}
}

View File

@ -4,24 +4,26 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public sealed class MemoryBufferWriter : IBufferWriter<byte>
public sealed class MemoryBufferWriter : IBufferWriter<byte>, IDisposable
{
private readonly int _segmentSize;
private int _bytesWritten;
internal List<Memory<byte>> Segments { get; }
internal List<byte[]> Segments { get; }
internal int Position { get; private set; }
public MemoryBufferWriter(int segmentSize = 2048)
{
_segmentSize = segmentSize;
Segments = new List<Memory<byte>>();
Segments = new List<byte[]>();
}
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
public void Advance(int count)
{
_bytesWritten += count;
Position += count;
}
@ -31,8 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
if (Segments.Count == 0 || Position == _segmentSize)
{
// TODO: Rent memory from a pool
Segments.Add(new Memory<byte>(new byte[_segmentSize]));
Segments.Add(ArrayPool<byte>.Shared.Rent(_segmentSize));
Position = 0;
}
@ -51,17 +52,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return Array.Empty<byte>();
}
var totalLength = (Segments.Count - 1) * _segmentSize;
totalLength += Position;
var result = new byte[totalLength];
var result = new byte[_bytesWritten];
var totalWritten = 0;
// Copy full segments
for (int i = 0; i < Segments.Count - 1; i++)
{
Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize));
Segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize));
totalWritten += _segmentSize;
}
@ -71,5 +69,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return result;
}
public void Dispose()
{
for (int i = 0; i < Segments.Count; i++)
{
ArrayPool<byte>.Shared.Return(Segments[i]);
}
Segments.Clear();
}
}
}

View File

@ -1,18 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public static class HubProtocolExtensions
{
public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message)
{
using (var ms = new LimitArrayPoolWriteStream())
using (var writer = new MemoryBufferWriter())
{
hubProtocol.WriteMessage(message, ms);
return ms.ToArray();
hubProtocol.WriteMessage(message, writer);
return writer.ToArray();
}
}
}

View File

@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message);
void WriteMessage(HubMessage message, Stream output);
void WriteMessage(HubMessage message, IBufferWriter<byte> output);
bool IsVersionSupported(int version);
}

View File

@ -78,7 +78,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return message != null;
}
public void WriteMessage(HubMessage message, Stream output)
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
WriteMessageCore(message, output);
TextMessageFormatter.WriteRecordSeparator(output);
@ -340,50 +340,58 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
throw new JsonReaderException("Unexpected end when reading message headers");
}
private void WriteMessageCore(HubMessage message, Stream stream)
private void WriteMessageCore(HubMessage message, IBufferWriter<byte> stream)
{
using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true)))
var textWriter = Utf8BufferTextWriter.Get(stream);
try
{
writer.WriteStartObject();
switch (message)
using (var writer = JsonUtils.CreateJsonTextWriter(textWriter))
{
case InvocationMessage m:
WriteMessageType(writer, HubProtocolConstants.InvocationMessageType);
WriteHeaders(writer, m);
WriteInvocationMessage(m, writer);
break;
case StreamInvocationMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamInvocationMessageType);
WriteHeaders(writer, m);
WriteStreamInvocationMessage(m, writer);
break;
case StreamItemMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType);
WriteHeaders(writer, m);
WriteStreamItemMessage(m, writer);
break;
case CompletionMessage m:
WriteMessageType(writer, HubProtocolConstants.CompletionMessageType);
WriteHeaders(writer, m);
WriteCompletionMessage(m, writer);
break;
case CancelInvocationMessage m:
WriteMessageType(writer, HubProtocolConstants.CancelInvocationMessageType);
WriteHeaders(writer, m);
WriteCancelInvocationMessage(m, writer);
break;
case PingMessage _:
WriteMessageType(writer, HubProtocolConstants.PingMessageType);
break;
case CloseMessage m:
WriteMessageType(writer, HubProtocolConstants.CloseMessageType);
WriteCloseMessage(m, writer);
break;
default:
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
writer.WriteStartObject();
switch (message)
{
case InvocationMessage m:
WriteMessageType(writer, HubProtocolConstants.InvocationMessageType);
WriteHeaders(writer, m);
WriteInvocationMessage(m, writer);
break;
case StreamInvocationMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamInvocationMessageType);
WriteHeaders(writer, m);
WriteStreamInvocationMessage(m, writer);
break;
case StreamItemMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType);
WriteHeaders(writer, m);
WriteStreamItemMessage(m, writer);
break;
case CompletionMessage m:
WriteMessageType(writer, HubProtocolConstants.CompletionMessageType);
WriteHeaders(writer, m);
WriteCompletionMessage(m, writer);
break;
case CancelInvocationMessage m:
WriteMessageType(writer, HubProtocolConstants.CancelInvocationMessageType);
WriteHeaders(writer, m);
WriteCancelInvocationMessage(m, writer);
break;
case PingMessage _:
WriteMessageType(writer, HubProtocolConstants.PingMessageType);
break;
case CloseMessage m:
WriteMessageType(writer, HubProtocolConstants.CloseMessageType);
WriteCloseMessage(m, writer);
break;
default:
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
}
writer.WriteEndObject();
writer.Flush();
}
writer.WriteEndObject();
writer.Flush();
}
finally
{
Utf8BufferTextWriter.Return(textWriter);
}
}

View File

@ -5,7 +5,6 @@ using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
@ -77,12 +76,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public override void Write(char[] buffer, int index, int count)
{
WriteInternal(buffer, index, count);
WriteInternal(buffer.AsSpan(index, count));
}
public override void Write(char[] buffer)
{
WriteInternal(buffer, 0, buffer.Length);
WriteInternal(buffer);
}
public override void Write(char value)
@ -120,6 +119,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
}
public override void Write(string value)
{
WriteInternal(value.AsSpan());
}
private Span<byte> GetBuffer()
{
EnsureBuffer();
@ -142,11 +146,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
}
private void WriteInternal(char[] buffer, int index, int count)
private void WriteInternal(ReadOnlySpan<char> buffer)
{
var currentIndex = index;
var charsRemaining = count;
while (charsRemaining > 0)
while (buffer.Length > 0)
{
// The destination byte array might not be large enough so multiple writes are sometimes required
var destination = GetBuffer();
@ -154,20 +156,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
var bytesUsed = 0;
var charsUsed = 0;
#if NETCOREAPP2_1
_encoder.Convert(buffer.AsSpan(currentIndex, charsRemaining), destination, false, out charsUsed, out bytesUsed, out _);
_encoder.Convert(buffer, destination, false, out charsUsed, out bytesUsed, out _);
#else
unsafe
{
fixed (char* sourceChars = &buffer[currentIndex])
fixed (char* sourceChars = &MemoryMarshal.GetReference(buffer))
fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination))
{
_encoder.Convert(sourceChars, charsRemaining, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _);
_encoder.Convert(sourceChars, buffer.Length, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _);
}
}
#endif
charsRemaining -= charsUsed;
currentIndex += charsUsed;
buffer = buffer.Slice(charsUsed);
_memoryUsed += bytesUsed;
}
}

View File

@ -41,9 +41,11 @@ namespace Microsoft.AspNetCore.SignalR
static HubConnectionContext()
{
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter);
_successHandshakeResponseData = memoryBufferWriter.ToArray();
using (var memoryBufferWriter = new MemoryBufferWriter())
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter);
_successHandshakeResponseData = memoryBufferWriter.ToArray();
}
}
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)

View File

@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return destination;
}
public void WriteMessage(HubMessage message, Stream output)
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
using (var stream = new LimitArrayPoolWriteStream())
{
@ -271,7 +271,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
// Write length then message to output
BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output);
output.Write(buffer.Array, buffer.Offset, buffer.Count);
output.Write(buffer);
}
}

View File

@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(500);
await hubConnection.StartAsync().OrTimeout();
// Start an invocation (but we won't complete it)
var invokeTask = hubConnection.InvokeAsync("Method").OrTimeout();
@ -156,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
throw new InvalidOperationException("No Parsed Message provided");
}
public void WriteMessage(HubMessage message, Stream output)
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
if (_error != null)
{

View File

@ -75,9 +75,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var s = await ReadSentTextMessageAsync();
var output = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output);
await Application.Output.WriteAsync(output.ToArray());
using (var output = new MemoryBufferWriter())
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output);
await Application.Output.WriteAsync(output.ToArray());
}
return s;
}

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Xunit;
@ -31,20 +32,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Encoding.UTF8.GetBytes("Hello,\r\nWorld!")
};
var output = new MemoryStream(); // Use small chunks to test Advance/Enlarge and partial payload writing
foreach (var message in messages)
using (var writer = new MemoryBufferWriter()) // Use small chunks to test Advance/Enlarge and partial payload writing
{
BinaryMessageFormatter.WriteLengthPrefix(message.Length, output);
output.Write(message, 0, message.Length);
}
foreach (var message in messages)
{
BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer);
writer.Write(message);
}
Assert.Equal(expectedEncoding, output.ToArray());
Assert.Equal(expectedEncoding, writer.ToArray());
}
}
[Theory]
[InlineData(0, new byte[] { 0x00 }, new byte[0])]
[InlineData(0, new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
[InlineData(0, new byte[]
[InlineData(new byte[] { 0x00 }, new byte[0])]
[InlineData(new byte[]
{
0x80, 0x01, // Size - 128
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f,
@ -68,53 +70,45 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f
})]
[InlineData(4, new byte[] { 0x00 }, new byte[0])]
[InlineData(4, new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
public void WriteBinaryMessage(int offset, byte[] encoded, byte[] payload)
[InlineData(new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
public void WriteBinaryMessage(byte[] encoded, byte[] payload)
{
var output = new MemoryStream();
if (offset > 0)
using (var writer = new MemoryBufferWriter())
{
output.Seek(offset, SeekOrigin.Begin);
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
writer.Write(payload);
Assert.Equal(encoded, writer.ToArray());
}
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, output);
output.Write(payload, 0, payload.Length);
Assert.Equal(encoded, output.ToArray().Skip(offset));
}
[Theory]
[InlineData(0, new byte[] { 0x00 }, "")]
[InlineData(0, new byte[] { 0x03, 0x41, 0x42, 0x43 }, "ABC")]
[InlineData(0, new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")]
[InlineData(4, new byte[] { 0x00 }, "")]
public void WriteTextMessage(int offset, byte[] encoded, string payload)
[InlineData(new byte[] { 0x00 }, "")]
[InlineData(new byte[] { 0x03, 0x41, 0x42, 0x43 }, "ABC")]
[InlineData(new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")]
public void WriteTextMessage(byte[] encoded, string payload)
{
var message = Encoding.UTF8.GetBytes(payload);
var output = new MemoryStream();
if (offset > 0)
using (var writer = new MemoryBufferWriter())
{
output.Seek(offset, SeekOrigin.Begin);
BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer);
writer.Write(message);
Assert.Equal(encoded, writer.ToArray());
}
BinaryMessageFormatter.WriteLengthPrefix(message.Length, output);
output.Write(message, 0, message.Length);
Assert.Equal(encoded, output.ToArray().Skip(offset));
}
[Theory]
[MemberData(nameof(RandomPayloads))]
public void RoundTrippingTest(byte[] payload)
{
using (var ms = new MemoryStream())
using (var writer = new MemoryBufferWriter())
{
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, ms);
ms.Write(payload, 0, payload.Length);
var buffer = new ReadOnlySequence<byte>(ms.ToArray());
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
writer.Write(payload);
var buffer = new ReadOnlySequence<byte>(writer.ToArray());
Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped));
Assert.Equal(payload, roundtripped.ToArray());
}

View File

@ -15,6 +15,7 @@ using Xunit;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
using Microsoft.AspNetCore.SignalR.Internal;
using static HubMessageHelpers;
public class JsonHubProtocolTests
@ -108,10 +109,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var protocol = new JsonHubProtocol(Options.Create(protocolOptions));
using (var ms = new MemoryStream())
using (var writer = new MemoryBufferWriter())
{
protocol.WriteMessage(message, ms);
var json = Encoding.UTF8.GetString(ms.ToArray());
protocol.WriteMessage(message, writer);
var json = Encoding.UTF8.GetString(writer.ToArray());
Assert.Equal(expectedOutput, json);
}

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using MsgPack;
@ -444,10 +445,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private static byte[] Frame(byte[] input)
{
using (var stream = new MemoryStream())
using (var stream = new MemoryBufferWriter())
{
BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream);
stream.Write(input, 0, input.Length);
stream.Write(input);
return stream.ToArray();
}
}
@ -486,11 +487,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private static byte[] Write(HubMessage message)
{
var protocol = new MessagePackHubProtocol();
using (var stream = new MemoryStream())
using (var writer = new MemoryBufferWriter())
{
protocol.WriteMessage(message, stream);
stream.Flush();
return stream.ToArray();
protocol.WriteMessage(message, writer);
return writer.ToArray();
}
}

View File

@ -4,9 +4,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Xunit;
@ -17,8 +15,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact]
public void WriteChar_Unicode()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
var bufferWriter = new TestMemoryBufferWriter(4096);
var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write('[');
@ -57,8 +55,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact]
public void WriteChar_UnicodeLastChar()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
using (Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter())
var bufferWriter = new TestMemoryBufferWriter(4096);
using (var textWriter = new Utf8BufferTextWriter())
{
textWriter.SetWriter(bufferWriter);
@ -73,8 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact]
public void WriteChar_UnicodeAndRunOutOfBufferSpace()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
var bufferWriter = new TestMemoryBufferWriter(4096);
var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write('[');
@ -124,8 +122,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
char[] chars = fourCircles.ToCharArray();
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
var bufferWriter = new TestMemoryBufferWriter(4096);
var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write(chars, 0, 1);
@ -153,8 +151,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
char[] chars = fourCircles.ToCharArray();
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
var bufferWriter = new TestMemoryBufferWriter(4096);
var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write(chars[0]);
@ -178,8 +176,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact]
public void WriteCharArray_NonZeroStart()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
var bufferWriter = new TestMemoryBufferWriter(4096);
var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
char[] chars = "Hello world".ToCharArray();
@ -194,8 +192,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact]
public void WriteCharArray_AcrossMultipleBuffers()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(2);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
var bufferWriter = new TestMemoryBufferWriter(2);
var textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
char[] chars = "Hello world".ToCharArray();
@ -222,7 +220,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Fact]
public void GetAndReturnCachedBufferTextWriter()
{
MemoryBufferWriter bufferWriter1 = new MemoryBufferWriter();
var bufferWriter1 = new TestMemoryBufferWriter();
var textWriter1 = Utf8BufferTextWriter.Get(bufferWriter1);
textWriter1.Write("Hello");
@ -231,7 +229,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
Assert.Equal("Hello", Encoding.UTF8.GetString(bufferWriter1.ToArray()));
MemoryBufferWriter bufferWriter2 = new MemoryBufferWriter();
TestMemoryBufferWriter bufferWriter2 = new TestMemoryBufferWriter();
var textWriter2 = Utf8BufferTextWriter.Get(bufferWriter2);
textWriter2.Write("World");
@ -242,5 +240,74 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
Assert.Same(textWriter1, textWriter2);
}
private sealed class TestMemoryBufferWriter : IBufferWriter<byte>
{
private readonly int _segmentSize;
internal List<Memory<byte>> Segments { get; }
internal int Position { get; private set; }
public TestMemoryBufferWriter(int segmentSize = 2048)
{
_segmentSize = segmentSize;
Segments = new List<Memory<byte>>();
}
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
public void Advance(int count)
{
Position += count;
}
public Memory<byte> GetMemory(int sizeHint = 0)
{
// TODO: Use sizeHint
if (Segments.Count == 0 || Position == _segmentSize)
{
// TODO: Rent memory from a pool
Segments.Add(new Memory<byte>(new byte[_segmentSize]));
Position = 0;
}
return CurrentSegment.Slice(Position, CurrentSegment.Length - Position);
}
public Span<byte> GetSpan(int sizeHint = 0)
{
return GetMemory(sizeHint).Span;
}
public byte[] ToArray()
{
if (Segments.Count == 0)
{
return Array.Empty<byte>();
}
var totalLength = (Segments.Count - 1) * _segmentSize;
totalLength += Position;
var result = new byte[totalLength];
var totalWritten = 0;
// Copy full segments
for (int i = 0; i < Segments.Count - 1; i++)
{
Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize));
totalWritten += _segmentSize;
}
// Copy current incomplete segment
CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position));
return result;
}
}
}
}

View File

@ -66,9 +66,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
if (sendHandshakeRequestMessage)
{
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
using (var memoryBufferWriter = new MemoryBufferWriter())
{
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
}
}
var connection = handler.OnConnectedAsync(Connection);