ThreadStatic MemoryBufferWriter (#1821)

This commit is contained in:
BrennanConroy 2018-04-02 11:25:04 -07:00 committed by GitHub
parent 6640f14e35
commit 5ce672dfe6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 195 additions and 53 deletions

View File

@ -30,11 +30,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
using (var memoryBufferWriter = new MemoryBufferWriter())
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
_handshakeRequestResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
}
finally
{
MemoryBufferWriter.Return(memoryBufferWriter);
}
_pipe = new TestDuplexPipe();

View File

@ -24,7 +24,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
var handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
@ -32,6 +33,10 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
_pipe = new TestDuplexPipe();
_pipe.AddReadResult(new ValueTask<ReadResult>(handshakeResponseResult));
}
finally
{
MemoryBufferWriter.Return(writer);
}
_tcs = new TaskCompletionSource<ReadResult>();
_pipe.AddReadResult(new ValueTask<ReadResult>(_tcs.Task));

View File

@ -26,11 +26,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
[GlobalSetup]
public void GlobalSetup()
{
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(writer.ToArray()), false, false);
}
finally
{
MemoryBufferWriter.Return(writer);
}
_pipe = new TestDuplexPipe();

View File

@ -24,22 +24,32 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
var buffer = new byte[MessageLength];
Random.NextBytes(buffer);
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, writer);
writer.Write(buffer);
_binaryInput = writer.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
}
buffer = new byte[MessageLength];
Random.NextBytes(buffer);
using (var writer = new MemoryBufferWriter())
writer = MemoryBufferWriter.Get();
try
{
writer.Write(buffer);
TextMessageFormatter.WriteRecordSeparator(writer);
_textInput = writer.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
[Benchmark]

View File

@ -1,43 +1,92 @@
using System;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public sealed class MemoryBufferWriter : IBufferWriter<byte>, IDisposable
public sealed class MemoryBufferWriter : IBufferWriter<byte>
{
[ThreadStatic]
private static MemoryBufferWriter _cachedInstance;
#if DEBUG
private bool _inUse;
#endif
private readonly int _segmentSize;
private int _bytesWritten;
internal List<byte[]> Segments { get; }
internal int Position { get; private set; }
private List<byte[]> _segments;
private int _position;
public MemoryBufferWriter(int segmentSize = 2048)
private MemoryBufferWriter(int segmentSize = 2048)
{
_segmentSize = segmentSize;
Segments = new List<byte[]>();
_segments = new List<byte[]>();
}
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
public static MemoryBufferWriter Get()
{
var writer = _cachedInstance;
if (writer == null)
{
writer = new MemoryBufferWriter();
}
// Taken off the thread static
_cachedInstance = null;
#if DEBUG
if (writer._inUse)
{
throw new InvalidOperationException("The reader wasn't returned!");
}
writer._inUse = true;
#endif
return writer;
}
public static void Return(MemoryBufferWriter writer)
{
_cachedInstance = writer;
#if DEBUG
writer._inUse = false;
#endif
for (int i = 0; i < writer._segments.Count; i++)
{
ArrayPool<byte>.Shared.Return(writer._segments[i]);
}
writer._segments.Clear();
writer._bytesWritten = 0;
writer._position = 0;
}
public Memory<byte> CurrentSegment => _segments[_segments.Count - 1];
public void Advance(int count)
{
_bytesWritten += count;
Position += count;
_position += count;
}
public Memory<byte> GetMemory(int sizeHint = 0)
{
// TODO: Use sizeHint
if (Segments.Count == 0 || Position == _segmentSize)
if (_segments.Count == 0 || _position == _segmentSize)
{
Segments.Add(ArrayPool<byte>.Shared.Rent(_segmentSize));
Position = 0;
_segments.Add(ArrayPool<byte>.Shared.Rent(_segmentSize));
_position = 0;
}
return CurrentSegment.Slice(Position, CurrentSegment.Length - Position);
// Cache property access
var currentSegment = CurrentSegment;
return currentSegment.Slice(_position, currentSegment.Length - _position);
}
public Span<byte> GetSpan(int sizeHint = 0)
@ -47,7 +96,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public byte[] ToArray()
{
if (Segments.Count == 0)
if (_segments.Count == 0)
{
return Array.Empty<byte>();
}
@ -57,26 +106,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
var totalWritten = 0;
// Copy full segments
for (int i = 0; i < Segments.Count - 1; i++)
for (int i = 0; i < _segments.Count - 1; i++)
{
Segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize));
_segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize));
totalWritten += _segmentSize;
}
// Copy current incomplete segment
CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position));
CurrentSegment.Slice(0, _position).CopyTo(result.AsMemory(totalWritten, _position));
return result;
}
public void Dispose()
{
for (int i = 0; i < Segments.Count; i++)
{
ArrayPool<byte>.Shared.Return(Segments[i]);
}
Segments.Clear();
}
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.IO;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
@ -14,6 +13,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
private object _lock = new object();
private List<SerializedMessage> _serializedMessages;
private SerializedMessage _message1;
private SerializedMessage _message2;
public byte[] WriteMessage(IHubProtocol protocol)
{
@ -24,9 +25,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
lock (_lock)
{
if (ReferenceEquals(_message1.Protocol, protocol))
{
return _message1.Message;
}
if (ReferenceEquals(_message2.Protocol, protocol))
{
return _message2.Message;
}
for (var i = 0; i < _serializedMessages?.Count; i++)
{
if (_serializedMessages[i].Protocol.Equals(protocol))
if (ReferenceEquals(_serializedMessages[i].Protocol, protocol))
{
return _serializedMessages[i].Message;
}
@ -34,17 +45,27 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
var bytes = protocol.WriteToArray(this);
if (_serializedMessages == null)
if (_message1.Protocol == null)
{
// Initialize with capacity 2 for the 2 built in protocols
_serializedMessages = new List<SerializedMessage>(2);
_message1 = new SerializedMessage(protocol, bytes);
}
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
// So we cap how many caches we store and worst case just serialize the message for every connection
if (_serializedMessages.Count < 10)
else if (_message2.Protocol == null)
{
_serializedMessages.Add(new SerializedMessage(protocol, bytes));
_message2 = new SerializedMessage(protocol, bytes);
}
else
{
if (_serializedMessages == null)
{
_serializedMessages = new List<SerializedMessage>();
}
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
// So we cap how many caches we store and worst case just serialize the message for every connection
if (_serializedMessages.Count < 10)
{
_serializedMessages.Add(new SerializedMessage(protocol, bytes));
}
}
return bytes;

View File

@ -7,11 +7,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message)
{
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
hubProtocol.WriteMessage(message, writer);
return writer.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
}
}

View File

@ -41,11 +41,16 @@ namespace Microsoft.AspNetCore.SignalR
static HubConnectionContext()
{
using (var memoryBufferWriter = new MemoryBufferWriter())
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter);
_successHandshakeResponseData = memoryBufferWriter.ToArray();
}
finally
{
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)
@ -108,9 +113,10 @@ namespace Microsoft.AspNetCore.SignalR
// So that we don't serialize the HubMessage for every single connection
var buffer = message.WriteMessage(Protocol);
_connectionContext.Transport.Output.Write(buffer);
var output = _connectionContext.Transport.Output;
output.Write(buffer);
return _connectionContext.Transport.Output.FlushAsync();
return output.FlushAsync();
}
catch (Exception ex)
{

View File

@ -75,11 +75,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var s = await ReadSentTextMessageAsync();
using (var output = new MemoryBufferWriter())
var output = MemoryBufferWriter.Get();
try
{
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output);
await Application.Output.WriteAsync(output.ToArray());
}
finally
{
MemoryBufferWriter.Return(output);
}
return s;
}

View File

@ -32,7 +32,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Encoding.UTF8.GetBytes("Hello,\r\nWorld!")
};
using (var writer = new MemoryBufferWriter()) // Use small chunks to test Advance/Enlarge and partial payload writing
var writer = MemoryBufferWriter.Get(); // Use small chunks to test Advance/Enlarge and partial payload writing
try
{
foreach (var message in messages)
{
@ -42,6 +43,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Assert.Equal(expectedEncoding, writer.ToArray());
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
[Theory]
@ -73,7 +78,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
[InlineData(new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })]
public void WriteBinaryMessage(byte[] encoded, byte[] payload)
{
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
@ -81,6 +87,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Assert.Equal(encoded, writer.ToArray());
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
[Theory]
@ -90,7 +100,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
public void WriteTextMessage(byte[] encoded, string payload)
{
var message = Encoding.UTF8.GetBytes(payload);
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer);
@ -98,13 +109,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Assert.Equal(encoded, writer.ToArray());
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
[Theory]
[MemberData(nameof(RandomPayloads))]
public void RoundTrippingTest(byte[] payload)
{
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer);
writer.Write(payload);
@ -112,6 +128,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped));
Assert.Equal(payload, roundtripped.ToArray());
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
public static IEnumerable<object[]> RandomPayloads()

View File

@ -114,13 +114,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var protocol = new JsonHubProtocol(Options.Create(protocolOptions));
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
protocol.WriteMessage(message, writer);
var json = Encoding.UTF8.GetString(writer.ToArray());
Assert.Equal(expectedOutput, json);
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
[Theory]

View File

@ -445,12 +445,17 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private static byte[] Frame(byte[] input)
{
using (var stream = new MemoryBufferWriter())
var stream = MemoryBufferWriter.Get();
try
{
BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream);
stream.Write(input);
return stream.ToArray();
}
finally
{
MemoryBufferWriter.Return(stream);
}
}
private static MessagePackObject Unpack(byte[] input)
@ -487,11 +492,16 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private static byte[] Write(HubMessage message)
{
var protocol = new MessagePackHubProtocol();
using (var writer = new MemoryBufferWriter())
var writer = MemoryBufferWriter.Get();
try
{
protocol.WriteMessage(message, writer);
return writer.ToArray();
}
finally
{
MemoryBufferWriter.Return(writer);
}
}
public class InvalidMessageData

View File

@ -66,11 +66,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
if (sendHandshakeRequestMessage)
{
using (var memoryBufferWriter = new MemoryBufferWriter())
var memoryBufferWriter = MemoryBufferWriter.Get();
try
{
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
}
finally
{
MemoryBufferWriter.Return(memoryBufferWriter);
}
}
var connection = handler.OnConnectedAsync(Connection);