From 8c84518ecc6d93bacd7b8160572e2cab8b13b741 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Thu, 29 Mar 2018 11:03:40 +1300 Subject: [PATCH] Message writing optimization (#1683) --- .../BroadcastBenchmark.cs | 4 +- .../DefaultHubActivatorBenchmark.cs | 5 +- .../HubConnectionBenchmark.cs | 88 +++++++ .../HubConnectionContextBenchmark.cs | 93 +++++++ ....AspNetCore.SignalR.Microbenchmarks.csproj | 1 + .../Shared/TestDuplexPipe.cs | 27 ++ .../Shared/TestPipeReader.cs | 62 +++++ .../Shared/TestPipeWriter.cs | 50 ++++ .../HubConnection.cs | 46 ++-- .../Formatters/TextMessageFormatter.cs | 10 +- .../Internal/MemoryBufferWriter.cs | 75 ++++++ .../Internal/Protocol/HandshakeProtocol.cs | 72 +++-- .../Protocol/HubProtocolExtensions.cs | 2 +- .../Protocol/LimitArrayPoolWriteStream.cs | 163 ++++++++++++ .../Internal/Protocol/Utf8BufferTextReader.cs | 2 +- .../Internal/Protocol/Utf8BufferTextWriter.cs | 195 ++++++++++++++ .../Properties/AssemblyInfo.cs | 3 +- .../HubConnectionContext.cs | 24 +- .../Protocol/MessagePackHubProtocol.cs | 23 +- .../RedisHubLifetimeManager.cs | 5 +- .../HttpConnectionContext.cs | 4 +- .../TestConnection.cs | 3 +- .../Protocol/Utf8BufferTextWriterTests.cs | 246 ++++++++++++++++++ .../TestClient.cs | 8 +- 24 files changed, 1129 insertions(+), 82 deletions(-) create mode 100644 benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionBenchmark.cs create mode 100644 benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs create mode 100644 benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs create mode 100644 benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeReader.cs create mode 100644 benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs create mode 100644 src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index 61f570f4eb..4e6f1afaf4 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; using System.IO.Pipelines; using System.Threading; @@ -29,7 +32,6 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { _hubLifetimeManager = new DefaultHubLifetimeManager(NullLogger>.Instance); - IHubProtocol protocol; if (Protocol == "json") diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubActivatorBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubActivatorBenchmark.cs index f20abf1c72..84a9b10b7e 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubActivatorBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubActivatorBenchmark.cs @@ -1,4 +1,7 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Collections.Generic; using System.Text; using BenchmarkDotNet.Attributes; diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionBenchmark.cs new file mode 100644 index 0000000000..730142aef7 --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionBenchmark.cs @@ -0,0 +1,88 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Connections.Features; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared; +using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks +{ + public class HubConnectionBenchmark + { + private HubConnection _hubConnection; + private TestDuplexPipe _pipe; + private ReadResult _handshakeResponseResult; + + [GlobalSetup] + public void GlobalSetup() + { + var ms = new MemoryBufferWriter(); + HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms); + _handshakeResponseResult = new ReadResult(new ReadOnlySequence(ms.ToArray()), false, false); + + _pipe = new TestDuplexPipe(); + + var connection = new TestConnection(); + // prevents keep alive time being activated + connection.Features.Set(new TestConnectionInherentKeepAliveFeature()); + connection.Transport = _pipe; + + _hubConnection = new HubConnection(() => connection, new JsonHubProtocol(), new NullLoggerFactory()); + } + + private void AddHandshakeResponse() + { + _pipe.AddReadResult(_handshakeResponseResult); + } + + [Benchmark] + public async Task StartAsync() + { + AddHandshakeResponse(); + + await _hubConnection.StartAsync(); + await _hubConnection.StopAsync(); + } + } + + public class TestConnectionInherentKeepAliveFeature : IConnectionInherentKeepAliveFeature + { + public TimeSpan KeepAliveInterval { get; } = TimeSpan.Zero; + } + + public class TestConnection : IConnection + { + public Task StartAsync() + { + throw new NotImplementedException(); + } + + public Task StartAsync(TransferFormat transferFormat) + { + return Task.CompletedTask; + } + + public Task DisposeAsync() + { + return Task.CompletedTask; + } + + public IDuplexPipe Transport { get; set; } + + public IFeatureCollection Features { get; } = new FeatureCollection(); + } +} \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs new file mode 100644 index 0000000000..0108e3557c --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs @@ -0,0 +1,93 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.AspNetCore.SignalR.Core; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Formatters; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared; +using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks +{ + public class HubConnectionContextBenchmark + { + private HubConnectionContext _hubConnectionContext; + private TestHubProtocolResolver _successHubProtocolResolver; + private TestHubProtocolResolver _failureHubProtocolResolver; + private TestUserIdProvider _userIdProvider; + private List _supportedProtocols; + private TestDuplexPipe _pipe; + private ReadResult _handshakeResponseResult; + + [GlobalSetup] + public void GlobalSetup() + { + var memoryBufferWriter = new MemoryBufferWriter(); + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter); + _handshakeResponseResult = new ReadResult(new ReadOnlySequence(memoryBufferWriter.ToArray()), false, false); + + _pipe = new TestDuplexPipe(); + + var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _pipe, _pipe); + _hubConnectionContext = new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance); + + _successHubProtocolResolver = new TestHubProtocolResolver(new JsonHubProtocol()); + _failureHubProtocolResolver = new TestHubProtocolResolver(null); + _userIdProvider = new TestUserIdProvider(); + _supportedProtocols = new List {"json"}; + } + + [Benchmark] + public async Task SuccessHandshakeAsync() + { + _pipe.AddReadResult(_handshakeResponseResult); + + await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _successHubProtocolResolver, _userIdProvider); + } + + [Benchmark] + public async Task ErrorHandshakeAsync() + { + _pipe.AddReadResult(_handshakeResponseResult); + + await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _failureHubProtocolResolver, _userIdProvider); + } + } + + public class TestUserIdProvider : IUserIdProvider + { + public string GetUserId(HubConnectionContext connection) + { + return "UserId!"; + } + } + + public class TestHubProtocolResolver : IHubProtocolResolver + { + private readonly IHubProtocol _instance; + + public TestHubProtocolResolver(IHubProtocol instance) + { + _instance = instance; + } + + public IHubProtocol GetProtocol(string protocolName, IList supportedProtocols) + { + return _instance; + } + } +} \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj index 33d46900d9..a742b96718 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj @@ -12,6 +12,7 @@ + diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs new file mode 100644 index 0000000000..805bd9c170 --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared +{ + public class TestDuplexPipe : IDuplexPipe + { + private readonly TestPipeReader _input; + + public PipeReader Input => _input; + + public PipeWriter Output { get; } + + public TestDuplexPipe() + { + _input = new TestPipeReader(); + Output = new TestPipeWriter(); + } + + public void AddReadResult(ReadResult readResult) + { + _input.ReadResults.Add(readResult); + } + } +} \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeReader.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeReader.cs new file mode 100644 index 0000000000..eaa38767b9 --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeReader.cs @@ -0,0 +1,62 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared +{ + public class TestPipeReader : PipeReader + { + public List ReadResults { get; } + + public TestPipeReader() + { + ReadResults = new List(); + } + + public override void AdvanceTo(SequencePosition consumed) + { + } + + public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) + { + } + + public override void CancelPendingRead() + { + throw new NotImplementedException(); + } + + public override void Complete(Exception exception = null) + { + throw new NotImplementedException(); + } + + public override void OnWriterCompleted(Action callback, object state) + { + throw new NotImplementedException(); + } + + public override ValueTask ReadAsync(CancellationToken cancellationToken = new CancellationToken()) + { + if (ReadResults.Count == 0) + { + return new ValueTask(new ReadResult(default, false, true)); + } + + var result = ReadResults[0]; + ReadResults.RemoveAt(0); + + return new ValueTask(result); + } + + public override bool TryRead(out ReadResult result) + { + throw new NotImplementedException(); + } + } +} \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs new file mode 100644 index 0000000000..6bc98def01 --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs @@ -0,0 +1,50 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared +{ + public class TestPipeWriter : PipeWriter + { + // huge buffer that should be large enough for writing any content + private readonly byte[] _buffer = new byte[10000]; + + public override void Advance(int bytes) + { + } + + public override Memory GetMemory(int sizeHint = 0) + { + return _buffer; + } + + public override Span GetSpan(int sizeHint = 0) + { + return _buffer; + } + + public override void OnReaderCompleted(Action callback, object state) + { + throw new NotImplementedException(); + } + + public override void CancelPendingFlush() + { + throw new NotImplementedException(); + } + + public override void Complete(Exception exception = null) + { + throw new NotImplementedException(); + } + + public override ValueTask FlushAsync(CancellationToken cancellationToken = new CancellationToken()) + { + return default; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index a0d4c23f2e..def78d57d4 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -496,17 +496,17 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task HandshakeAsync() { // Send the Handshake request - using (var memoryStream = new MemoryStream()) - { - Log.SendingHubHandshake(_logger); - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryStream); - var result = await WriteAsync(memoryStream.ToArray(), CancellationToken.None); + Log.SendingHubHandshake(_logger); - if (result.IsCompleted) - { - // The other side disconnected - throw new InvalidOperationException("The server disconnected before the handshake was completed"); - } + var handshakeRequest = new HandshakeRequestMessage(_protocol.Name, _protocol.Version); + HandshakeProtocol.WriteRequestMessage(handshakeRequest, _connectionState.Connection.Transport.Output); + + var sendHandshakeResult = await _connectionState.Connection.Transport.Output.FlushAsync(CancellationToken.None); + + if (sendHandshakeResult.IsCompleted) + { + // The other side disconnected + throw new InvalidOperationException("The server disconnected before the handshake was completed"); } try @@ -667,18 +667,24 @@ namespace Microsoft.AspNetCore.SignalR.Client private void RunClosedEvent(Exception closeException) { - _ = Task.Run(() => + var closed = Closed; + + // There is no need to start a new task if there is no Closed event registered + if (closed != null) { - try + _ = Task.Run(() => { - Log.InvokingClosedEventHandler(_logger); - Closed?.Invoke(closeException); - } - catch (Exception ex) - { - Log.ErrorDuringClosedEvent(_logger, ex); - } - }); + try + { + Log.InvokingClosedEventHandler(_logger); + closed.Invoke(closeException); + } + catch (Exception ex) + { + Log.ErrorDuringClosedEvent(_logger, ex); + } + }); + } } private void ResetTimeoutTimer(Timer timeoutTimer) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs index 50c4cebcc7..223fd59d6a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageFormatter.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Buffers; using System.IO; namespace Microsoft.AspNetCore.SignalR.Internal.Formatters @@ -9,11 +10,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters { // This record separator is supposed to be used only for JSON payloads where 0x1e character // will not occur (is not a valid character) and therefore it is safe to not escape it - internal static readonly byte RecordSeparator = 0x1e; + public static readonly byte RecordSeparator = 0x1e; public static void WriteRecordSeparator(Stream output) { output.WriteByte(RecordSeparator); } + + public static void WriteRecordSeparator(IBufferWriter output) + { + var buffer = output.GetSpan(1); + buffer[0] = RecordSeparator; + output.Advance(1); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs new file mode 100644 index 0000000000..db667efd67 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs @@ -0,0 +1,75 @@ +using System; +using System.Buffers; +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public sealed class MemoryBufferWriter : IBufferWriter + { + private readonly int _segmentSize; + + internal List> Segments { get; } + internal int Position { get; private set; } + + public MemoryBufferWriter(int segmentSize = 2048) + { + _segmentSize = segmentSize; + + Segments = new List>(); + } + + public Memory CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null; + + public void Advance(int count) + { + Position += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + // TODO: Use sizeHint + + if (Segments.Count == 0 || Position == _segmentSize) + { + // TODO: Rent memory from a pool + Segments.Add(new Memory(new byte[_segmentSize])); + Position = 0; + } + + return CurrentSegment.Slice(Position, CurrentSegment.Length - Position); + } + + public Span GetSpan(int sizeHint = 0) + { + return GetMemory(sizeHint).Span; + } + + public byte[] ToArray() + { + if (Segments.Count == 0) + { + return Array.Empty(); + } + + var totalLength = (Segments.Count - 1) * _segmentSize; + totalLength += Position; + + var result = new byte[totalLength]; + + var totalWritten = 0; + + // Copy full segments + for (int i = 0; i < Segments.Count - 1; i++) + { + Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize)); + + totalWritten += _segmentSize; + } + + // Copy current incomplete segment + CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position)); + + return result; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs index a2c662775c..cdd61a3aec 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs @@ -13,47 +13,67 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public static class HandshakeProtocol { - private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false); - private const string ProtocolPropertyName = "protocol"; private const string ProtocolVersionName = "version"; private const string ErrorPropertyName = "error"; private const string TypePropertyName = "type"; - public static void WriteRequestMessage(HandshakeRequestMessage requestMessage, Stream output) + public static void WriteRequestMessage(HandshakeRequestMessage requestMessage, IBufferWriter output) { - using (var writer = CreateJsonTextWriter(output)) + var textWriter = Utf8BufferTextWriter.Get(output); + try { - writer.WriteStartObject(); - writer.WritePropertyName(ProtocolPropertyName); - writer.WriteValue(requestMessage.Protocol); - writer.WritePropertyName(ProtocolVersionName); - writer.WriteValue(requestMessage.Version); - writer.WriteEndObject(); - } - - TextMessageFormatter.WriteRecordSeparator(output); - } - - public static void WriteResponseMessage(HandshakeResponseMessage responseMessage, Stream output) - { - using (var writer = CreateJsonTextWriter(output)) - { - writer.WriteStartObject(); - if (!string.IsNullOrEmpty(responseMessage.Error)) + using (var writer = CreateJsonTextWriter(textWriter)) { - writer.WritePropertyName(ErrorPropertyName); - writer.WriteValue(responseMessage.Error); + writer.WriteStartObject(); + writer.WritePropertyName(ProtocolPropertyName); + writer.WriteValue(requestMessage.Protocol); + writer.WritePropertyName(ProtocolVersionName); + writer.WriteValue(requestMessage.Version); + writer.WriteEndObject(); + writer.Flush(); } - writer.WriteEndObject(); + } + finally + { + Utf8BufferTextWriter.Return(textWriter); } TextMessageFormatter.WriteRecordSeparator(output); } - private static JsonTextWriter CreateJsonTextWriter(Stream output) + public static void WriteResponseMessage(HandshakeResponseMessage responseMessage, IBufferWriter output) { - return new JsonTextWriter(new StreamWriter(output, _utf8NoBom, 1024, leaveOpen: true)); + var textWriter = Utf8BufferTextWriter.Get(output); + try + { + using (var writer = CreateJsonTextWriter(textWriter)) + { + writer.WriteStartObject(); + if (!string.IsNullOrEmpty(responseMessage.Error)) + { + writer.WritePropertyName(ErrorPropertyName); + writer.WriteValue(responseMessage.Error); + } + + writer.WriteEndObject(); + writer.Flush(); + } + } + finally + { + Utf8BufferTextWriter.Return(textWriter); + } + + TextMessageFormatter.WriteRecordSeparator(output); + } + + private static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter) + { + var writer = new JsonTextWriter(textWriter); + writer.CloseOutput = false; + + return writer; } public static bool TryParseResponseMessage(ref ReadOnlySequence buffer, out HandshakeResponseMessage responseMessage) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs index 7710af3a80..73ae855889 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs @@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message) { - using (var ms = new MemoryStream()) + using (var ms = new LimitArrayPoolWriteStream()) { hubProtocol.WriteMessage(message, ms); return ms.ToArray(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs new file mode 100644 index 0000000000..7a114f0ef6 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/LimitArrayPoolWriteStream.cs @@ -0,0 +1,163 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public sealed class LimitArrayPoolWriteStream : Stream + { + private const int MaxByteArrayLength = 0x7FFFFFC7; + private const int InitialLength = 256; + + private readonly int _maxBufferSize; + private byte[] _buffer; + private int _length; + + public LimitArrayPoolWriteStream() : this(MaxByteArrayLength) { } + + public LimitArrayPoolWriteStream(int maxBufferSize) : this(maxBufferSize, InitialLength) { } + + public LimitArrayPoolWriteStream(int maxBufferSize, long capacity) + { + if (capacity < InitialLength) + { + capacity = InitialLength; + } + else if (capacity > maxBufferSize) + { + throw CreateOverCapacityException(maxBufferSize); + } + + _maxBufferSize = maxBufferSize; + _buffer = ArrayPool.Shared.Rent((int)capacity); + } + + protected override void Dispose(bool disposing) + { + if (_buffer != null) + { + ArrayPool.Shared.Return(_buffer); + _buffer = null; + } + + base.Dispose(disposing); + } + + public ArraySegment GetBuffer() => new ArraySegment(_buffer, 0, _length); + + public byte[] ToArray() + { + var arr = new byte[_length]; + Buffer.BlockCopy(_buffer, 0, arr, 0, _length); + return arr; + } + + private void EnsureCapacity(int value) + { + if ((uint)value > (uint)_maxBufferSize) // value cast handles overflow to negative as well + { + throw CreateOverCapacityException(_maxBufferSize); + } + else if (value > _buffer.Length) + { + Grow(value); + } + } + + private void Grow(int value) + { + Debug.Assert(value > _buffer.Length); + + // Extract the current buffer to be replaced. + byte[] currentBuffer = _buffer; + _buffer = null; + + // Determine the capacity to request for the new buffer. It should be + // at least twice as long as the current one, if not more if the requested + // value is more than that. If the new value would put it longer than the max + // allowed byte array, than shrink to that (and if the required length is actually + // longer than that, we'll let the runtime throw). + uint twiceLength = 2 * (uint)currentBuffer.Length; + int newCapacity = twiceLength > MaxByteArrayLength ? + (value > MaxByteArrayLength ? value : MaxByteArrayLength) : + Math.Max(value, (int)twiceLength); + + // Get a new buffer, copy the current one to it, return the current one, and + // set the new buffer as current. + byte[] newBuffer = ArrayPool.Shared.Rent(newCapacity); + Buffer.BlockCopy(currentBuffer, 0, newBuffer, 0, _length); + ArrayPool.Shared.Return(currentBuffer); + _buffer = newBuffer; + } + + public override void Write(byte[] buffer, int offset, int count) + { + Debug.Assert(buffer != null); + Debug.Assert(offset >= 0); + Debug.Assert(count >= 0); + + EnsureCapacity(_length + count); + Buffer.BlockCopy(buffer, offset, _buffer, _length, count); + _length += count; + } + +#if NETCOREAPP2_1 + public override void Write(ReadOnlySpan source) + { + EnsureCapacity(_length + source.Length); + source.CopyTo(new Span(_buffer, _length, source.Length)); + _length += source.Length; + } +#endif + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + Write(buffer, offset, count); + return Task.CompletedTask; + } + +#if NETCOREAPP2_1 + public override ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + Write(source.Span); + return default; + } +#endif + + public override void WriteByte(byte value) + { + int newLength = _length + 1; + EnsureCapacity(newLength); + _buffer[_length] = value; + _length = newLength; + } + + public override void Flush() { } + public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; + + public override long Length => _length; + public override bool CanWrite => true; + public override bool CanRead => false; + public override bool CanSeek => false; + + public override long Position + { + get => throw new NotSupportedException(); + set => throw new NotSupportedException(); + } + public override int Read(byte[] buffer, int offset, int count) { throw new NotSupportedException(); } + public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); } + public override void SetLength(long value) { throw new NotSupportedException(); } + + private static Exception CreateOverCapacityException(int maxBufferSize) + { + return new InvalidOperationException($"Buffer size of {maxBufferSize} exceeded."); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextReader.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextReader.cs index 5a6cbb99bd..1dba50f2c6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextReader.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextReader.cs @@ -11,8 +11,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { internal class Utf8BufferTextReader : TextReader { + private readonly Decoder _decoder; private ReadOnlySequence _utf8Buffer; - private Decoder _decoder; [ThreadStatic] private static Utf8BufferTextReader _cachedInstance; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs new file mode 100644 index 0000000000..8ed80f4c27 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs @@ -0,0 +1,195 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Runtime.InteropServices; +using System.Text; + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + internal sealed class Utf8BufferTextWriter : TextWriter + { + private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false); + + [ThreadStatic] + private static Utf8BufferTextWriter _cachedInstance; + + private readonly Encoder _encoder; + private IBufferWriter _bufferWriter; + private Memory _memory; + private int _memoryUsed; + +#if DEBUG + private bool _inUse; +#endif + + public override Encoding Encoding => _utf8NoBom; + + public Utf8BufferTextWriter() + { + _encoder = _utf8NoBom.GetEncoder(); + } + + public static Utf8BufferTextWriter Get(IBufferWriter bufferWriter) + { + var writer = _cachedInstance; + if (writer == null) + { + writer = new Utf8BufferTextWriter(); + } + + // Taken off the the thread static + _cachedInstance = null; +#if DEBUG + if (writer._inUse) + { + throw new InvalidOperationException("The writer wasn't returned!"); + } + + writer._inUse = true; +#endif + writer.SetWriter(bufferWriter); + return writer; + } + + public static void Return(Utf8BufferTextWriter writer) + { + _cachedInstance = writer; + + writer._encoder.Reset(); + writer._memory = Memory.Empty; + writer._memoryUsed = 0; + writer._bufferWriter = null; + +#if DEBUG + writer._inUse = false; +#endif + } + + public void SetWriter(IBufferWriter bufferWriter) + { + _bufferWriter = bufferWriter; + } + + public override void Write(char[] buffer, int index, int count) + { + WriteInternal(buffer, index, count); + } + + public override void Write(char[] buffer) + { + WriteInternal(buffer, 0, buffer.Length); + } + + public override void Write(char value) + { + var destination = GetBuffer(); + + if (value <= 127) + { + destination[0] = (byte)value; + _memoryUsed++; + } + else + { + // Json.NET only writes ASCII characters by themselves, e.g. {}[], etc + // this should be an exceptional case + var bytesUsed = 0; + var charsUsed = 0; + unsafe + { +#if NETCOREAPP2_1 + _encoder.Convert(new Span(&value, 1), destination, false, out charsUsed, out bytesUsed, out _); +#else + fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination)) + { + _encoder.Convert(&value, 1, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _); + } +#endif + } + Debug.Assert(charsUsed == 1); + + if (bytesUsed > 0) + { + _memoryUsed += bytesUsed; + } + } + } + + private Span GetBuffer() + { + EnsureBuffer(); + + return _memory.Span.Slice(_memoryUsed, _memory.Length - _memoryUsed); + } + + private void EnsureBuffer() + { + if (_memoryUsed == _memory.Length) + { + // Used up the memory from the buffer writer so advance and get more + if (_memoryUsed > 0) + { + _bufferWriter.Advance(_memoryUsed); + } + + _memory = _bufferWriter.GetMemory(); + _memoryUsed = 0; + } + } + + private void WriteInternal(char[] buffer, int index, int count) + { + var currentIndex = index; + var charsRemaining = count; + while (charsRemaining > 0) + { + // The destination byte array might not be large enough so multiple writes are sometimes required + var destination = GetBuffer(); + + var bytesUsed = 0; + var charsUsed = 0; +#if NETCOREAPP2_1 + _encoder.Convert(buffer.AsSpan(currentIndex, charsRemaining), destination, false, out charsUsed, out bytesUsed, out _); +#else + unsafe + { + fixed (char* sourceChars = &buffer[currentIndex]) + fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination)) + { + _encoder.Convert(sourceChars, charsRemaining, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _); + } + } +#endif + + charsRemaining -= charsUsed; + currentIndex += charsUsed; + _memoryUsed += bytesUsed; + } + } + + public override void Flush() + { + if (_memoryUsed > 0) + { + _bufferWriter.Advance(_memoryUsed); + _memory = _memory.Slice(_memoryUsed, _memory.Length - _memoryUsed); + _memoryUsed = 0; + } + } + + protected override void Dispose(bool disposing) + { + base.Dispose(disposing); + + if (disposing) + { + Flush(); + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.SignalR.Common/Properties/AssemblyInfo.cs index 92e52dceeb..4a210459c0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Properties/AssemblyInfo.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Properties/AssemblyInfo.cs @@ -3,5 +3,6 @@ using System.Runtime.CompilerServices; +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Common.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] [assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests.Utils, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] -[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Common.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] \ No newline at end of file +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Microbenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 525cf55136..d9395b76f5 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; @@ -25,7 +26,8 @@ namespace Microsoft.AspNetCore.SignalR { public class HubConnectionContext { - private static Action _abortedCallback = AbortConnection; + private static readonly Action _abortedCallback = AbortConnection; + private static readonly byte[] _successHandshakeResponseData; private readonly ConnectionContext _connectionContext; private readonly ILogger _logger; @@ -37,6 +39,13 @@ namespace Microsoft.AspNetCore.SignalR private long _lastSendTimestamp = Stopwatch.GetTimestamp(); private byte[] _cachedPingMessage; + static HubConnectionContext() + { + var memoryBufferWriter = new MemoryBufferWriter(); + HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter); + _successHandshakeResponseData = memoryBufferWriter.ToArray(); + } + public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) { _connectionContext = connectionContext; @@ -185,10 +194,17 @@ namespace Microsoft.AspNetCore.SignalR try { - var ms = new MemoryStream(); - HandshakeProtocol.WriteResponseMessage(message, ms); + if (message == HandshakeResponseMessage.Empty) + { + // success response is always an empty object so send cached data + _connectionContext.Transport.Output.Write(_successHandshakeResponseData); + } + else + { + HandshakeProtocol.WriteResponseMessage(message, _connectionContext.Transport.Output); + } - await _connectionContext.Transport.Output.WriteAsync(ms.ToArray()); + await _connectionContext.Transport.Output.FlushAsync(); } finally { diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs index 2b9f130e08..05202de878 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs @@ -263,22 +263,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public void WriteMessage(HubMessage message, Stream output) { - // We're writing data into the memoryStream so that we can get the length prefix - using (var memoryStream = new MemoryStream()) + using (var stream = new LimitArrayPoolWriteStream()) { - WriteMessageCore(message, memoryStream); - if (memoryStream.TryGetBuffer(out var buffer)) - { - // Write the buffer directly - BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output); - output.Write(buffer.Array, buffer.Offset, buffer.Count); - } - else - { - BinaryMessageFormatter.WriteLengthPrefix(memoryStream.Length, output); - memoryStream.Position = 0; - memoryStream.CopyTo(output); - } + // Write message to a buffer so we can get its length + WriteMessageCore(message, stream); + var buffer = stream.GetBuffer(); + + // Write length then message to output + BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output); + output.Write(buffer.Array, buffer.Offset, buffer.Count); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 9fd3d2ca12..f3e5434b7c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -213,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis private async Task PublishAsync(string channel, IRedisMessage message) { byte[] payload; - using (var stream = new MemoryStream()) + using (var stream = new LimitArrayPoolWriteStream()) using (var writer = new JsonTextWriter(new StreamWriter(stream))) { _serializer.Serialize(writer, message); @@ -316,8 +316,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { var groupChannel = _channelNamePrefix + ".group." + groupName; - GroupData group; - if (!_groups.TryGetValue(groupChannel, out group)) + if (!_groups.TryGetValue(groupChannel, out var group)) { return; } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContext.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContext.cs index 56be6a94f5..982c9c4dd0 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContext.cs @@ -26,12 +26,12 @@ namespace Microsoft.AspNetCore.Sockets ITransferFormatFeature, IHttpContextFeature { - private object _heartbeatLock = new object(); + private readonly object _heartbeatLock = new object(); private List<(Action handler, object state)> _heartbeatHandlers; // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task - private TaskCompletionSource _disposeTcs = new TaskCompletionSource(); + private readonly TaskCompletionSource _disposeTcs = new TaskCompletionSource(); /// /// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations. diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index f090512621..56883a367e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -10,6 +10,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets.Client; @@ -72,7 +73,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var s = await ReadSentTextMessageAsync(); - var output = new MemoryStream(); + var output = new MemoryBufferWriter(); HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output); await Application.Output.WriteAsync(output.ToArray()); diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs new file mode 100644 index 0000000000..fa3e56f7b8 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs @@ -0,0 +1,246 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol +{ + public class Utf8BufferTextWriterTests + { + [Fact] + public void WriteChar_Unicode() + { + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); + Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + textWriter.SetWriter(bufferWriter); + + textWriter.Write('['); + textWriter.Flush(); + Assert.Equal(1, bufferWriter.Position); + Assert.Equal((byte)'[', bufferWriter.CurrentSegment.Span[0]); + + textWriter.Write('"'); + textWriter.Flush(); + Assert.Equal(2, bufferWriter.Position); + Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[1]); + + textWriter.Write('\u00A3'); + textWriter.Flush(); + Assert.Equal(4, bufferWriter.Position); + + textWriter.Write('\u00A3'); + textWriter.Flush(); + Assert.Equal(6, bufferWriter.Position); + + textWriter.Write('"'); + textWriter.Flush(); + Assert.Equal(7, bufferWriter.Position); + Assert.Equal((byte)0xC2, bufferWriter.CurrentSegment.Span[2]); + Assert.Equal((byte)0xA3, bufferWriter.CurrentSegment.Span[3]); + Assert.Equal((byte)0xC2, bufferWriter.CurrentSegment.Span[4]); + Assert.Equal((byte)0xA3, bufferWriter.CurrentSegment.Span[5]); + Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[6]); + + textWriter.Write(']'); + textWriter.Flush(); + Assert.Equal(8, bufferWriter.Position); + Assert.Equal((byte)']', bufferWriter.CurrentSegment.Span[7]); + } + + [Fact] + public void WriteChar_UnicodeLastChar() + { + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); + using (Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter()) + { + textWriter.SetWriter(bufferWriter); + + textWriter.Write('\u00A3'); + } + + Assert.Equal(2, bufferWriter.Position); + Assert.Equal((byte)0xC2, bufferWriter.CurrentSegment.Span[0]); + Assert.Equal((byte)0xA3, bufferWriter.CurrentSegment.Span[1]); + } + + [Fact] + public void WriteChar_UnicodeAndRunOutOfBufferSpace() + { + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); + Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + textWriter.SetWriter(bufferWriter); + + textWriter.Write('['); + textWriter.Flush(); + Assert.Equal(1, bufferWriter.Position); + Assert.Equal((byte)'[', bufferWriter.CurrentSegment.Span[0]); + + textWriter.Write('"'); + textWriter.Flush(); + Assert.Equal(2, bufferWriter.Position); + Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[1]); + + for (int i = 0; i < 2000; i++) + { + textWriter.Write('\u00A3'); + } + textWriter.Flush(); + + textWriter.Write('"'); + textWriter.Flush(); + Assert.Equal(4003, bufferWriter.Position); + Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[4002]); + + textWriter.Write(']'); + textWriter.Flush(); + Assert.Equal(4004, bufferWriter.Position); + + string result = Encoding.UTF8.GetString(bufferWriter.CurrentSegment.Slice(0, bufferWriter.Position).ToArray()); + Assert.Equal(2004, result.Length); + + Assert.Equal('[', result[0]); + Assert.Equal('"', result[1]); + + for (int i = 0; i < 2000; i++) + { + Assert.Equal('\u00A3', result[i + 2]); + } + + Assert.Equal('"', result[2002]); + Assert.Equal(']', result[2003]); + } + + [Fact] + public void WriteCharArray_SurrogatePairInMultipleCalls() + { + string fourCircles = char.ConvertFromUtf32(0x1F01C); + + char[] chars = fourCircles.ToCharArray(); + + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); + Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + textWriter.SetWriter(bufferWriter); + + textWriter.Write(chars, 0, 1); + textWriter.Flush(); + + // Surrogate buffered + Assert.Equal(0, bufferWriter.Position); + + textWriter.Write(chars, 1, 1); + textWriter.Flush(); + + Assert.Equal(4, bufferWriter.Position); + + byte[] expectedData = Encoding.UTF8.GetBytes(fourCircles); + + byte[] actualData = bufferWriter.CurrentSegment.Slice(0, 4).ToArray(); + + Assert.Equal(expectedData, actualData); + } + + [Fact] + public void WriteChar_SurrogatePairInMultipleCalls() + { + string fourCircles = char.ConvertFromUtf32(0x1F01C); + + char[] chars = fourCircles.ToCharArray(); + + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); + Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + textWriter.SetWriter(bufferWriter); + + textWriter.Write(chars[0]); + textWriter.Flush(); + + // Surrogate buffered + Assert.Equal(0, bufferWriter.Position); + + textWriter.Write(chars[1]); + textWriter.Flush(); + + Assert.Equal(4, bufferWriter.Position); + + byte[] expectedData = Encoding.UTF8.GetBytes(fourCircles); + + byte[] actualData = bufferWriter.CurrentSegment.Slice(0, 4).ToArray(); + + Assert.Equal(expectedData, actualData); + } + + [Fact] + public void WriteCharArray_NonZeroStart() + { + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); + Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + textWriter.SetWriter(bufferWriter); + + char[] chars = "Hello world".ToCharArray(); + + textWriter.Write(chars, 6, 1); + textWriter.Flush(); + + Assert.Equal(1, bufferWriter.Position); + Assert.Equal((byte)'w', bufferWriter.CurrentSegment.Span[0]); + } + + [Fact] + public void WriteCharArray_AcrossMultipleBuffers() + { + MemoryBufferWriter bufferWriter = new MemoryBufferWriter(2); + Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + textWriter.SetWriter(bufferWriter); + + char[] chars = "Hello world".ToCharArray(); + + textWriter.Write(chars); + textWriter.Flush(); + + Assert.Equal(6, bufferWriter.Segments.Count); + Assert.Equal(1, bufferWriter.Position); + + Assert.Equal((byte)'H', bufferWriter.Segments[0].Span[0]); + Assert.Equal((byte)'e', bufferWriter.Segments[0].Span[1]); + Assert.Equal((byte)'l', bufferWriter.Segments[1].Span[0]); + Assert.Equal((byte)'l', bufferWriter.Segments[1].Span[1]); + Assert.Equal((byte)'o', bufferWriter.Segments[2].Span[0]); + Assert.Equal((byte)' ', bufferWriter.Segments[2].Span[1]); + Assert.Equal((byte)'w', bufferWriter.Segments[3].Span[0]); + Assert.Equal((byte)'o', bufferWriter.Segments[3].Span[1]); + Assert.Equal((byte)'r', bufferWriter.Segments[4].Span[0]); + Assert.Equal((byte)'l', bufferWriter.Segments[4].Span[1]); + Assert.Equal((byte)'d', bufferWriter.Segments[5].Span[0]); + } + + [Fact] + public void GetAndReturnCachedBufferTextWriter() + { + MemoryBufferWriter bufferWriter1 = new MemoryBufferWriter(); + + var textWriter1 = Utf8BufferTextWriter.Get(bufferWriter1); + textWriter1.Write("Hello"); + textWriter1.Flush(); + Utf8BufferTextWriter.Return(textWriter1); + + Assert.Equal("Hello", Encoding.UTF8.GetString(bufferWriter1.ToArray())); + + MemoryBufferWriter bufferWriter2 = new MemoryBufferWriter(); + + var textWriter2 = Utf8BufferTextWriter.Get(bufferWriter2); + textWriter2.Write("World"); + textWriter2.Flush(); + Utf8BufferTextWriter.Return(textWriter2); + + Assert.Equal("World", Encoding.UTF8.GetString(bufferWriter2.ToArray())); + + Assert.Same(textWriter1, textWriter2); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 580eb44f81..8dacbe9b65 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -69,11 +69,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { if (sendHandshakeRequestMessage) { - using (var memoryStream = new MemoryStream()) - { - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryStream); - await Connection.Application.Output.WriteAsync(memoryStream.ToArray()); - } + var memoryBufferWriter = new MemoryBufferWriter(); + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); + await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); } var connection = handler.OnConnectedAsync(Connection);