From 5ce672dfe66584303d374d4a76077ea8cb73f785 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Mon, 2 Apr 2018 11:25:04 -0700 Subject: [PATCH] ThreadStatic MemoryBufferWriter (#1821) --- .../HubConnectionContextBenchmark.cs | 7 +- .../HubConnectionSendBenchmark.cs | 7 +- .../HubConnectionStartBenchmark.cs | 7 +- .../MessageParserBenchmark.cs | 14 ++- .../Internal/MemoryBufferWriter.cs | 90 +++++++++++++------ .../Internal/Protocol/HubMessage.cs | 41 ++++++--- .../Protocol/HubProtocolExtensions.cs | 7 +- .../HubConnectionContext.cs | 12 ++- .../TestConnection.cs | 7 +- .../Formatters/BinaryMessageFormatterTests.cs | 28 +++++- .../Internal/Protocol/JsonHubProtocolTests.cs | 7 +- .../Protocol/MessagePackHubProtocolTests.cs | 14 ++- .../TestClient.cs | 7 +- 13 files changed, 195 insertions(+), 53 deletions(-) diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs index 1ebf21ff8c..d8d70a608f 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs @@ -30,11 +30,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - using (var memoryBufferWriter = new MemoryBufferWriter()) + var memoryBufferWriter = MemoryBufferWriter.Get(); + try { HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter); _handshakeRequestResult = new ReadResult(new ReadOnlySequence(memoryBufferWriter.ToArray()), false, false); } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } _pipe = new TestDuplexPipe(); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs index ea1ed7cb97..6ed4d551fd 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs @@ -24,7 +24,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer); var handshakeResponseResult = new ReadResult(new ReadOnlySequence(writer.ToArray()), false, false); @@ -32,6 +33,10 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _pipe = new TestDuplexPipe(); _pipe.AddReadResult(new ValueTask(handshakeResponseResult)); } + finally + { + MemoryBufferWriter.Return(writer); + } _tcs = new TaskCompletionSource(); _pipe.AddReadResult(new ValueTask(_tcs.Task)); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs index f237383b15..d4a91e4f44 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs @@ -26,11 +26,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer); _handshakeResponseResult = new ReadResult(new ReadOnlySequence(writer.ToArray()), false, false); } + finally + { + MemoryBufferWriter.Return(writer); + } _pipe = new TestDuplexPipe(); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs index 4a0f5be890..8d6e67f351 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs @@ -24,22 +24,32 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { var buffer = new byte[MessageLength]; Random.NextBytes(buffer); - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, writer); writer.Write(buffer); _binaryInput = writer.ToArray(); } + finally + { + MemoryBufferWriter.Return(writer); + } buffer = new byte[MessageLength]; Random.NextBytes(buffer); - using (var writer = new MemoryBufferWriter()) + writer = MemoryBufferWriter.Get(); + try { writer.Write(buffer); TextMessageFormatter.WriteRecordSeparator(writer); _textInput = writer.ToArray(); } + finally + { + MemoryBufferWriter.Return(writer); + } } [Benchmark] diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs index ae41eea276..60ecf21e7f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs @@ -1,43 +1,92 @@ -using System; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; using System.Buffers; using System.Collections.Generic; namespace Microsoft.AspNetCore.SignalR.Internal { - public sealed class MemoryBufferWriter : IBufferWriter, IDisposable + public sealed class MemoryBufferWriter : IBufferWriter { + [ThreadStatic] + private static MemoryBufferWriter _cachedInstance; + +#if DEBUG + private bool _inUse; +#endif + private readonly int _segmentSize; private int _bytesWritten; - internal List Segments { get; } - internal int Position { get; private set; } + private List _segments; + private int _position; - public MemoryBufferWriter(int segmentSize = 2048) + private MemoryBufferWriter(int segmentSize = 2048) { _segmentSize = segmentSize; - Segments = new List(); + _segments = new List(); } - public Memory CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null; + public static MemoryBufferWriter Get() + { + var writer = _cachedInstance; + if (writer == null) + { + writer = new MemoryBufferWriter(); + } + + // Taken off the thread static + _cachedInstance = null; +#if DEBUG + if (writer._inUse) + { + throw new InvalidOperationException("The reader wasn't returned!"); + } + + writer._inUse = true; +#endif + + return writer; + } + + public static void Return(MemoryBufferWriter writer) + { + _cachedInstance = writer; +#if DEBUG + writer._inUse = false; +#endif + for (int i = 0; i < writer._segments.Count; i++) + { + ArrayPool.Shared.Return(writer._segments[i]); + } + writer._segments.Clear(); + writer._bytesWritten = 0; + writer._position = 0; + } + + public Memory CurrentSegment => _segments[_segments.Count - 1]; public void Advance(int count) { _bytesWritten += count; - Position += count; + _position += count; } public Memory GetMemory(int sizeHint = 0) { // TODO: Use sizeHint - if (Segments.Count == 0 || Position == _segmentSize) + if (_segments.Count == 0 || _position == _segmentSize) { - Segments.Add(ArrayPool.Shared.Rent(_segmentSize)); - Position = 0; + _segments.Add(ArrayPool.Shared.Rent(_segmentSize)); + _position = 0; } - return CurrentSegment.Slice(Position, CurrentSegment.Length - Position); + // Cache property access + var currentSegment = CurrentSegment; + return currentSegment.Slice(_position, currentSegment.Length - _position); } public Span GetSpan(int sizeHint = 0) @@ -47,7 +96,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal public byte[] ToArray() { - if (Segments.Count == 0) + if (_segments.Count == 0) { return Array.Empty(); } @@ -57,26 +106,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal var totalWritten = 0; // Copy full segments - for (int i = 0; i < Segments.Count - 1; i++) + for (int i = 0; i < _segments.Count - 1; i++) { - Segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize)); + _segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize)); totalWritten += _segmentSize; } // Copy current incomplete segment - CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position)); + CurrentSegment.Slice(0, _position).CopyTo(result.AsMemory(totalWritten, _position)); return result; } - - public void Dispose() - { - for (int i = 0; i < Segments.Count; i++) - { - ArrayPool.Shared.Return(Segments[i]); - } - Segments.Clear(); - } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs index c1ebe09cf8..9cb4aabf5a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; -using System.IO; namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { @@ -14,6 +13,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol private object _lock = new object(); private List _serializedMessages; + private SerializedMessage _message1; + private SerializedMessage _message2; public byte[] WriteMessage(IHubProtocol protocol) { @@ -24,9 +25,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol lock (_lock) { + if (ReferenceEquals(_message1.Protocol, protocol)) + { + return _message1.Message; + } + + if (ReferenceEquals(_message2.Protocol, protocol)) + { + return _message2.Message; + } + for (var i = 0; i < _serializedMessages?.Count; i++) { - if (_serializedMessages[i].Protocol.Equals(protocol)) + if (ReferenceEquals(_serializedMessages[i].Protocol, protocol)) { return _serializedMessages[i].Message; } @@ -34,17 +45,27 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol var bytes = protocol.WriteToArray(this); - if (_serializedMessages == null) + if (_message1.Protocol == null) { - // Initialize with capacity 2 for the 2 built in protocols - _serializedMessages = new List(2); + _message1 = new SerializedMessage(protocol, bytes); } - - // We don't want to balloon memory if someone writes a poor IHubProtocolResolver - // So we cap how many caches we store and worst case just serialize the message for every connection - if (_serializedMessages.Count < 10) + else if (_message2.Protocol == null) { - _serializedMessages.Add(new SerializedMessage(protocol, bytes)); + _message2 = new SerializedMessage(protocol, bytes); + } + else + { + if (_serializedMessages == null) + { + _serializedMessages = new List(); + } + + // We don't want to balloon memory if someone writes a poor IHubProtocolResolver + // So we cap how many caches we store and worst case just serialize the message for every connection + if (_serializedMessages.Count < 10) + { + _serializedMessages.Add(new SerializedMessage(protocol, bytes)); + } } return bytes; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs index 7b38156f1c..c3e4fba9bd 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs @@ -7,11 +7,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message) { - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { hubProtocol.WriteMessage(message, writer); return writer.ToArray(); } + finally + { + MemoryBufferWriter.Return(writer); + } } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index ce4b770ec0..7fb8d6a4d4 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -41,11 +41,16 @@ namespace Microsoft.AspNetCore.SignalR static HubConnectionContext() { - using (var memoryBufferWriter = new MemoryBufferWriter()) + var memoryBufferWriter = MemoryBufferWriter.Get(); + try { HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter); _successHandshakeResponseData = memoryBufferWriter.ToArray(); } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } } public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) @@ -108,9 +113,10 @@ namespace Microsoft.AspNetCore.SignalR // So that we don't serialize the HubMessage for every single connection var buffer = message.WriteMessage(Protocol); - _connectionContext.Transport.Output.Write(buffer); + var output = _connectionContext.Transport.Output; + output.Write(buffer); - return _connectionContext.Transport.Output.FlushAsync(); + return output.FlushAsync(); } catch (Exception ex) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 8bc98d4d22..b62fc3aaec 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -75,11 +75,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var s = await ReadSentTextMessageAsync(); - using (var output = new MemoryBufferWriter()) + var output = MemoryBufferWriter.Get(); + try { HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output); await Application.Output.WriteAsync(output.ToArray()); } + finally + { + MemoryBufferWriter.Return(output); + } return s; } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs index 9951bdcd00..f355d772ec 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs @@ -32,7 +32,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters Encoding.UTF8.GetBytes("Hello,\r\nWorld!") }; - using (var writer = new MemoryBufferWriter()) // Use small chunks to test Advance/Enlarge and partial payload writing + var writer = MemoryBufferWriter.Get(); // Use small chunks to test Advance/Enlarge and partial payload writing + try { foreach (var message in messages) { @@ -42,6 +43,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters Assert.Equal(expectedEncoding, writer.ToArray()); } + finally + { + MemoryBufferWriter.Return(writer); + } } [Theory] @@ -73,7 +78,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters [InlineData(new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] public void WriteBinaryMessage(byte[] encoded, byte[] payload) { - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer); @@ -81,6 +87,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters Assert.Equal(encoded, writer.ToArray()); } + finally + { + MemoryBufferWriter.Return(writer); + } } [Theory] @@ -90,7 +100,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters public void WriteTextMessage(byte[] encoded, string payload) { var message = Encoding.UTF8.GetBytes(payload); - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer); @@ -98,13 +109,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters Assert.Equal(encoded, writer.ToArray()); } + finally + { + MemoryBufferWriter.Return(writer); + } } [Theory] [MemberData(nameof(RandomPayloads))] public void RoundTrippingTest(byte[] payload) { - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer); writer.Write(payload); @@ -112,6 +128,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped)); Assert.Equal(payload, roundtripped.ToArray()); } + finally + { + MemoryBufferWriter.Return(writer); + } } public static IEnumerable RandomPayloads() diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index 4de3f5b482..ff321edff2 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -114,13 +114,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol var protocol = new JsonHubProtocol(Options.Create(protocolOptions)); - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { protocol.WriteMessage(message, writer); var json = Encoding.UTF8.GetString(writer.ToArray()); Assert.Equal(expectedOutput, json); } + finally + { + MemoryBufferWriter.Return(writer); + } } [Theory] diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index 34b6d81570..51a115f94f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -445,12 +445,17 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol private static byte[] Frame(byte[] input) { - using (var stream = new MemoryBufferWriter()) + var stream = MemoryBufferWriter.Get(); + try { BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream); stream.Write(input); return stream.ToArray(); } + finally + { + MemoryBufferWriter.Return(stream); + } } private static MessagePackObject Unpack(byte[] input) @@ -487,11 +492,16 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol private static byte[] Write(HubMessage message) { var protocol = new MessagePackHubProtocol(); - using (var writer = new MemoryBufferWriter()) + var writer = MemoryBufferWriter.Get(); + try { protocol.WriteMessage(message, writer); return writer.ToArray(); } + finally + { + MemoryBufferWriter.Return(writer); + } } public class InvalidMessageData diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 2a7d4a6a18..8a321c5886 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -66,11 +66,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests { if (sendHandshakeRequestMessage) { - using (var memoryBufferWriter = new MemoryBufferWriter()) + var memoryBufferWriter = MemoryBufferWriter.Get(); + try { HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } } var connection = handler.OnConnectedAsync(Connection);