ThreadStatic MemoryBufferWriter (#1821)
This commit is contained in:
parent
6640f14e35
commit
5ce672dfe6
|
|
@ -30,11 +30,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
[GlobalSetup]
|
||||
public void GlobalSetup()
|
||||
{
|
||||
using (var memoryBufferWriter = new MemoryBufferWriter())
|
||||
var memoryBufferWriter = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
|
||||
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(memoryBufferWriter);
|
||||
}
|
||||
|
||||
_pipe = new TestDuplexPipe();
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
[GlobalSetup]
|
||||
public void GlobalSetup()
|
||||
{
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
|
||||
var handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
|
||||
|
|
@ -32,6 +33,10 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
_pipe = new TestDuplexPipe();
|
||||
_pipe.AddReadResult(new ValueTask<ReadResult>(handshakeResponseResult));
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
|
||||
_tcs = new TaskCompletionSource<ReadResult>();
|
||||
_pipe.AddReadResult(new ValueTask<ReadResult>(_tcs.Task));
|
||||
|
|
|
|||
|
|
@ -26,11 +26,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
[GlobalSetup]
|
||||
public void GlobalSetup()
|
||||
{
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
|
||||
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
|
||||
_pipe = new TestDuplexPipe();
|
||||
|
||||
|
|
|
|||
|
|
@ -24,22 +24,32 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
|||
{
|
||||
var buffer = new byte[MessageLength];
|
||||
Random.NextBytes(buffer);
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, writer);
|
||||
writer.Write(buffer);
|
||||
_binaryInput = writer.ToArray();
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
|
||||
buffer = new byte[MessageLength];
|
||||
Random.NextBytes(buffer);
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
writer.Write(buffer);
|
||||
TextMessageFormatter.WriteRecordSeparator(writer);
|
||||
|
||||
_textInput = writer.ToArray();
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
|
|
|
|||
|
|
@ -1,43 +1,92 @@
|
|||
using System;
|
||||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
public sealed class MemoryBufferWriter : IBufferWriter<byte>, IDisposable
|
||||
public sealed class MemoryBufferWriter : IBufferWriter<byte>
|
||||
{
|
||||
[ThreadStatic]
|
||||
private static MemoryBufferWriter _cachedInstance;
|
||||
|
||||
#if DEBUG
|
||||
private bool _inUse;
|
||||
#endif
|
||||
|
||||
private readonly int _segmentSize;
|
||||
private int _bytesWritten;
|
||||
|
||||
internal List<byte[]> Segments { get; }
|
||||
internal int Position { get; private set; }
|
||||
private List<byte[]> _segments;
|
||||
private int _position;
|
||||
|
||||
public MemoryBufferWriter(int segmentSize = 2048)
|
||||
private MemoryBufferWriter(int segmentSize = 2048)
|
||||
{
|
||||
_segmentSize = segmentSize;
|
||||
|
||||
Segments = new List<byte[]>();
|
||||
_segments = new List<byte[]>();
|
||||
}
|
||||
|
||||
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
|
||||
public static MemoryBufferWriter Get()
|
||||
{
|
||||
var writer = _cachedInstance;
|
||||
if (writer == null)
|
||||
{
|
||||
writer = new MemoryBufferWriter();
|
||||
}
|
||||
|
||||
// Taken off the thread static
|
||||
_cachedInstance = null;
|
||||
#if DEBUG
|
||||
if (writer._inUse)
|
||||
{
|
||||
throw new InvalidOperationException("The reader wasn't returned!");
|
||||
}
|
||||
|
||||
writer._inUse = true;
|
||||
#endif
|
||||
|
||||
return writer;
|
||||
}
|
||||
|
||||
public static void Return(MemoryBufferWriter writer)
|
||||
{
|
||||
_cachedInstance = writer;
|
||||
#if DEBUG
|
||||
writer._inUse = false;
|
||||
#endif
|
||||
for (int i = 0; i < writer._segments.Count; i++)
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(writer._segments[i]);
|
||||
}
|
||||
writer._segments.Clear();
|
||||
writer._bytesWritten = 0;
|
||||
writer._position = 0;
|
||||
}
|
||||
|
||||
public Memory<byte> CurrentSegment => _segments[_segments.Count - 1];
|
||||
|
||||
public void Advance(int count)
|
||||
{
|
||||
_bytesWritten += count;
|
||||
Position += count;
|
||||
_position += count;
|
||||
}
|
||||
|
||||
public Memory<byte> GetMemory(int sizeHint = 0)
|
||||
{
|
||||
// TODO: Use sizeHint
|
||||
|
||||
if (Segments.Count == 0 || Position == _segmentSize)
|
||||
if (_segments.Count == 0 || _position == _segmentSize)
|
||||
{
|
||||
Segments.Add(ArrayPool<byte>.Shared.Rent(_segmentSize));
|
||||
Position = 0;
|
||||
_segments.Add(ArrayPool<byte>.Shared.Rent(_segmentSize));
|
||||
_position = 0;
|
||||
}
|
||||
|
||||
return CurrentSegment.Slice(Position, CurrentSegment.Length - Position);
|
||||
// Cache property access
|
||||
var currentSegment = CurrentSegment;
|
||||
return currentSegment.Slice(_position, currentSegment.Length - _position);
|
||||
}
|
||||
|
||||
public Span<byte> GetSpan(int sizeHint = 0)
|
||||
|
|
@ -47,7 +96,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
public byte[] ToArray()
|
||||
{
|
||||
if (Segments.Count == 0)
|
||||
if (_segments.Count == 0)
|
||||
{
|
||||
return Array.Empty<byte>();
|
||||
}
|
||||
|
|
@ -57,26 +106,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
var totalWritten = 0;
|
||||
|
||||
// Copy full segments
|
||||
for (int i = 0; i < Segments.Count - 1; i++)
|
||||
for (int i = 0; i < _segments.Count - 1; i++)
|
||||
{
|
||||
Segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize));
|
||||
_segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize));
|
||||
|
||||
totalWritten += _segmentSize;
|
||||
}
|
||||
|
||||
// Copy current incomplete segment
|
||||
CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position));
|
||||
CurrentSegment.Slice(0, _position).CopyTo(result.AsMemory(totalWritten, _position));
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
for (int i = 0; i < Segments.Count; i++)
|
||||
{
|
||||
ArrayPool<byte>.Shared.Return(Segments[i]);
|
||||
}
|
||||
Segments.Clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
{
|
||||
|
|
@ -14,6 +13,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
|
||||
private object _lock = new object();
|
||||
private List<SerializedMessage> _serializedMessages;
|
||||
private SerializedMessage _message1;
|
||||
private SerializedMessage _message2;
|
||||
|
||||
public byte[] WriteMessage(IHubProtocol protocol)
|
||||
{
|
||||
|
|
@ -24,9 +25,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
|
||||
lock (_lock)
|
||||
{
|
||||
if (ReferenceEquals(_message1.Protocol, protocol))
|
||||
{
|
||||
return _message1.Message;
|
||||
}
|
||||
|
||||
if (ReferenceEquals(_message2.Protocol, protocol))
|
||||
{
|
||||
return _message2.Message;
|
||||
}
|
||||
|
||||
for (var i = 0; i < _serializedMessages?.Count; i++)
|
||||
{
|
||||
if (_serializedMessages[i].Protocol.Equals(protocol))
|
||||
if (ReferenceEquals(_serializedMessages[i].Protocol, protocol))
|
||||
{
|
||||
return _serializedMessages[i].Message;
|
||||
}
|
||||
|
|
@ -34,17 +45,27 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
|
||||
var bytes = protocol.WriteToArray(this);
|
||||
|
||||
if (_serializedMessages == null)
|
||||
if (_message1.Protocol == null)
|
||||
{
|
||||
// Initialize with capacity 2 for the 2 built in protocols
|
||||
_serializedMessages = new List<SerializedMessage>(2);
|
||||
_message1 = new SerializedMessage(protocol, bytes);
|
||||
}
|
||||
|
||||
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
|
||||
// So we cap how many caches we store and worst case just serialize the message for every connection
|
||||
if (_serializedMessages.Count < 10)
|
||||
else if (_message2.Protocol == null)
|
||||
{
|
||||
_serializedMessages.Add(new SerializedMessage(protocol, bytes));
|
||||
_message2 = new SerializedMessage(protocol, bytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (_serializedMessages == null)
|
||||
{
|
||||
_serializedMessages = new List<SerializedMessage>();
|
||||
}
|
||||
|
||||
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
|
||||
// So we cap how many caches we store and worst case just serialize the message for every connection
|
||||
if (_serializedMessages.Count < 10)
|
||||
{
|
||||
_serializedMessages.Add(new SerializedMessage(protocol, bytes));
|
||||
}
|
||||
}
|
||||
|
||||
return bytes;
|
||||
|
|
|
|||
|
|
@ -7,11 +7,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
|||
{
|
||||
public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message)
|
||||
{
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
hubProtocol.WriteMessage(message, writer);
|
||||
return writer.ToArray();
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,11 +41,16 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
static HubConnectionContext()
|
||||
{
|
||||
using (var memoryBufferWriter = new MemoryBufferWriter())
|
||||
var memoryBufferWriter = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter);
|
||||
_successHandshakeResponseData = memoryBufferWriter.ToArray();
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(memoryBufferWriter);
|
||||
}
|
||||
}
|
||||
|
||||
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)
|
||||
|
|
@ -108,9 +113,10 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
// So that we don't serialize the HubMessage for every single connection
|
||||
var buffer = message.WriteMessage(Protocol);
|
||||
|
||||
_connectionContext.Transport.Output.Write(buffer);
|
||||
var output = _connectionContext.Transport.Output;
|
||||
output.Write(buffer);
|
||||
|
||||
return _connectionContext.Transport.Output.FlushAsync();
|
||||
return output.FlushAsync();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -75,11 +75,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
{
|
||||
var s = await ReadSentTextMessageAsync();
|
||||
|
||||
using (var output = new MemoryBufferWriter())
|
||||
var output = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output);
|
||||
await Application.Output.WriteAsync(output.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(output);
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
Encoding.UTF8.GetBytes("Hello,\r\nWorld!")
|
||||
};
|
||||
|
||||
using (var writer = new MemoryBufferWriter()) // Use small chunks to test Advance/Enlarge and partial payload writing
|
||||
var writer = MemoryBufferWriter.Get(); // Use small chunks to test Advance/Enlarge and partial payload writing
|
||||
try
|
||||
{
|
||||
foreach (var message in messages)
|
||||
{
|
||||
|
|
@ -42,6 +43,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
|
||||
Assert.Equal(expectedEncoding, writer.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
|
|
@ -73,7 +78,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
[InlineData(new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
|
||||
public void WriteBinaryMessage(byte[] encoded, byte[] payload)
|
||||
{
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
|
||||
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
|
||||
|
|
@ -81,6 +87,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
|
||||
Assert.Equal(encoded, writer.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
|
|
@ -90,7 +100,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
public void WriteTextMessage(byte[] encoded, string payload)
|
||||
{
|
||||
var message = Encoding.UTF8.GetBytes(payload);
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
|
||||
BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer);
|
||||
|
|
@ -98,13 +109,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
|
||||
Assert.Equal(encoded, writer.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(RandomPayloads))]
|
||||
public void RoundTrippingTest(byte[] payload)
|
||||
{
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
|
||||
writer.Write(payload);
|
||||
|
|
@ -112,6 +128,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
|
|||
Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped));
|
||||
Assert.Equal(payload, roundtripped.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
public static IEnumerable<object[]> RandomPayloads()
|
||||
|
|
|
|||
|
|
@ -114,13 +114,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
|||
|
||||
var protocol = new JsonHubProtocol(Options.Create(protocolOptions));
|
||||
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
protocol.WriteMessage(message, writer);
|
||||
var json = Encoding.UTF8.GetString(writer.ToArray());
|
||||
|
||||
Assert.Equal(expectedOutput, json);
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
|
|
|
|||
|
|
@ -445,12 +445,17 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
|||
|
||||
private static byte[] Frame(byte[] input)
|
||||
{
|
||||
using (var stream = new MemoryBufferWriter())
|
||||
var stream = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream);
|
||||
stream.Write(input);
|
||||
return stream.ToArray();
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(stream);
|
||||
}
|
||||
}
|
||||
|
||||
private static MessagePackObject Unpack(byte[] input)
|
||||
|
|
@ -487,11 +492,16 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
|||
private static byte[] Write(HubMessage message)
|
||||
{
|
||||
var protocol = new MessagePackHubProtocol();
|
||||
using (var writer = new MemoryBufferWriter())
|
||||
var writer = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
protocol.WriteMessage(message, writer);
|
||||
return writer.ToArray();
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(writer);
|
||||
}
|
||||
}
|
||||
|
||||
public class InvalidMessageData
|
||||
|
|
|
|||
|
|
@ -66,11 +66,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
if (sendHandshakeRequestMessage)
|
||||
{
|
||||
using (var memoryBufferWriter = new MemoryBufferWriter())
|
||||
var memoryBufferWriter = MemoryBufferWriter.Get();
|
||||
try
|
||||
{
|
||||
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
|
||||
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
MemoryBufferWriter.Return(memoryBufferWriter);
|
||||
}
|
||||
}
|
||||
|
||||
var connection = handler.OnConnectedAsync(Connection);
|
||||
|
|
|
|||
Loading…
Reference in New Issue