Message writing optimization (#1683)

This commit is contained in:
James Newton-King 2018-03-29 11:03:40 +13:00 committed by GitHub
parent 890c8566d6
commit 8c84518ecc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 1129 additions and 82 deletions

View File

@ -1,3 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading;
@ -29,7 +32,6 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
_hubLifetimeManager = new DefaultHubLifetimeManager<Hub>(NullLogger<DefaultHubLifetimeManager<Hub>>.Instance);
IHubProtocol protocol;
if (Protocol == "json")

View File

@ -1,4 +1,7 @@
using System;
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using BenchmarkDotNet.Attributes;

View File

@ -0,0 +1,88 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging.Abstractions;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
public class HubConnectionBenchmark
{
private HubConnection _hubConnection;
private TestDuplexPipe _pipe;
private ReadResult _handshakeResponseResult;
[GlobalSetup]
public void GlobalSetup()
{
var ms = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(ms.ToArray()), false, false);
_pipe = new TestDuplexPipe();
var connection = new TestConnection();
// prevents keep alive time being activated
connection.Features.Set<IConnectionInherentKeepAliveFeature>(new TestConnectionInherentKeepAliveFeature());
connection.Transport = _pipe;
_hubConnection = new HubConnection(() => connection, new JsonHubProtocol(), new NullLoggerFactory());
}
private void AddHandshakeResponse()
{
_pipe.AddReadResult(_handshakeResponseResult);
}
[Benchmark]
public async Task StartAsync()
{
AddHandshakeResponse();
await _hubConnection.StartAsync();
await _hubConnection.StopAsync();
}
}
public class TestConnectionInherentKeepAliveFeature : IConnectionInherentKeepAliveFeature
{
public TimeSpan KeepAliveInterval { get; } = TimeSpan.Zero;
}
public class TestConnection : IConnection
{
public Task StartAsync()
{
throw new NotImplementedException();
}
public Task StartAsync(TransferFormat transferFormat)
{
return Task.CompletedTask;
}
public Task DisposeAsync()
{
return Task.CompletedTask;
}
public IDuplexPipe Transport { get; set; }
public IFeatureCollection Features { get; } = new FeatureCollection();
}
}

View File

@ -0,0 +1,93 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging.Abstractions;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
public class HubConnectionContextBenchmark
{
private HubConnectionContext _hubConnectionContext;
private TestHubProtocolResolver _successHubProtocolResolver;
private TestHubProtocolResolver _failureHubProtocolResolver;
private TestUserIdProvider _userIdProvider;
private List<string> _supportedProtocols;
private TestDuplexPipe _pipe;
private ReadResult _handshakeResponseResult;
[GlobalSetup]
public void GlobalSetup()
{
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter);
_handshakeResponseResult = new ReadResult(new ReadOnlySequence<byte>(memoryBufferWriter.ToArray()), false, false);
_pipe = new TestDuplexPipe();
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _pipe, _pipe);
_hubConnectionContext = new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance);
_successHubProtocolResolver = new TestHubProtocolResolver(new JsonHubProtocol());
_failureHubProtocolResolver = new TestHubProtocolResolver(null);
_userIdProvider = new TestUserIdProvider();
_supportedProtocols = new List<string> {"json"};
}
[Benchmark]
public async Task SuccessHandshakeAsync()
{
_pipe.AddReadResult(_handshakeResponseResult);
await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _successHubProtocolResolver, _userIdProvider);
}
[Benchmark]
public async Task ErrorHandshakeAsync()
{
_pipe.AddReadResult(_handshakeResponseResult);
await _hubConnectionContext.HandshakeAsync(TimeSpan.FromSeconds(5), _supportedProtocols, _failureHubProtocolResolver, _userIdProvider);
}
}
public class TestUserIdProvider : IUserIdProvider
{
public string GetUserId(HubConnectionContext connection)
{
return "UserId!";
}
}
public class TestHubProtocolResolver : IHubProtocolResolver
{
private readonly IHubProtocol _instance;
public TestHubProtocolResolver(IHubProtocol instance)
{
_instance = instance;
}
public IHubProtocol GetProtocol(string protocolName, IList<string> supportedProtocols)
{
return _instance;
}
}
}

View File

@ -12,6 +12,7 @@
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Core\Microsoft.AspNetCore.SignalR.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Client.Core\Microsoft.AspNetCore.SignalR.Client.Core.csproj" />
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.SignalR.Protocols.MsgPack\Microsoft.AspNetCore.SignalR.Protocols.MsgPack.csproj" />
<PackageReference Include="BenchmarkDotNet" Version="$(BenchmarkDotNetPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.BenchmarkRunner.Sources" Version="$(MicrosoftAspNetCoreBenchmarkRunnerSourcesPackageVersion)" />

View File

@ -0,0 +1,27 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared
{
public class TestDuplexPipe : IDuplexPipe
{
private readonly TestPipeReader _input;
public PipeReader Input => _input;
public PipeWriter Output { get; }
public TestDuplexPipe()
{
_input = new TestPipeReader();
Output = new TestPipeWriter();
}
public void AddReadResult(ReadResult readResult)
{
_input.ReadResults.Add(readResult);
}
}
}

View File

@ -0,0 +1,62 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared
{
public class TestPipeReader : PipeReader
{
public List<ReadResult> ReadResults { get; }
public TestPipeReader()
{
ReadResults = new List<ReadResult>();
}
public override void AdvanceTo(SequencePosition consumed)
{
}
public override void AdvanceTo(SequencePosition consumed, SequencePosition examined)
{
}
public override void CancelPendingRead()
{
throw new NotImplementedException();
}
public override void Complete(Exception exception = null)
{
throw new NotImplementedException();
}
public override void OnWriterCompleted(Action<Exception, object> callback, object state)
{
throw new NotImplementedException();
}
public override ValueTask<ReadResult> ReadAsync(CancellationToken cancellationToken = new CancellationToken())
{
if (ReadResults.Count == 0)
{
return new ValueTask<ReadResult>(new ReadResult(default, false, true));
}
var result = ReadResults[0];
ReadResults.RemoveAt(0);
return new ValueTask<ReadResult>(result);
}
public override bool TryRead(out ReadResult result)
{
throw new NotImplementedException();
}
}
}

View File

@ -0,0 +1,50 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared
{
public class TestPipeWriter : PipeWriter
{
// huge buffer that should be large enough for writing any content
private readonly byte[] _buffer = new byte[10000];
public override void Advance(int bytes)
{
}
public override Memory<byte> GetMemory(int sizeHint = 0)
{
return _buffer;
}
public override Span<byte> GetSpan(int sizeHint = 0)
{
return _buffer;
}
public override void OnReaderCompleted(Action<Exception, object> callback, object state)
{
throw new NotImplementedException();
}
public override void CancelPendingFlush()
{
throw new NotImplementedException();
}
public override void Complete(Exception exception = null)
{
throw new NotImplementedException();
}
public override ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = new CancellationToken())
{
return default;
}
}
}

View File

@ -496,17 +496,17 @@ namespace Microsoft.AspNetCore.SignalR.Client
private async Task HandshakeAsync()
{
// Send the Handshake request
using (var memoryStream = new MemoryStream())
{
Log.SendingHubHandshake(_logger);
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryStream);
var result = await WriteAsync(memoryStream.ToArray(), CancellationToken.None);
Log.SendingHubHandshake(_logger);
if (result.IsCompleted)
{
// The other side disconnected
throw new InvalidOperationException("The server disconnected before the handshake was completed");
}
var handshakeRequest = new HandshakeRequestMessage(_protocol.Name, _protocol.Version);
HandshakeProtocol.WriteRequestMessage(handshakeRequest, _connectionState.Connection.Transport.Output);
var sendHandshakeResult = await _connectionState.Connection.Transport.Output.FlushAsync(CancellationToken.None);
if (sendHandshakeResult.IsCompleted)
{
// The other side disconnected
throw new InvalidOperationException("The server disconnected before the handshake was completed");
}
try
@ -667,18 +667,24 @@ namespace Microsoft.AspNetCore.SignalR.Client
private void RunClosedEvent(Exception closeException)
{
_ = Task.Run(() =>
var closed = Closed;
// There is no need to start a new task if there is no Closed event registered
if (closed != null)
{
try
_ = Task.Run(() =>
{
Log.InvokingClosedEventHandler(_logger);
Closed?.Invoke(closeException);
}
catch (Exception ex)
{
Log.ErrorDuringClosedEvent(_logger, ex);
}
});
try
{
Log.InvokingClosedEventHandler(_logger);
closed.Invoke(closeException);
}
catch (Exception ex)
{
Log.ErrorDuringClosedEvent(_logger, ex);
}
});
}
}
private void ResetTimeoutTimer(Timer timeoutTimer)

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Buffers;
using System.IO;
namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
@ -9,11 +10,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
{
// This record separator is supposed to be used only for JSON payloads where 0x1e character
// will not occur (is not a valid character) and therefore it is safe to not escape it
internal static readonly byte RecordSeparator = 0x1e;
public static readonly byte RecordSeparator = 0x1e;
public static void WriteRecordSeparator(Stream output)
{
output.WriteByte(RecordSeparator);
}
public static void WriteRecordSeparator(IBufferWriter<byte> output)
{
var buffer = output.GetSpan(1);
buffer[0] = RecordSeparator;
output.Advance(1);
}
}
}

View File

@ -0,0 +1,75 @@
using System;
using System.Buffers;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public sealed class MemoryBufferWriter : IBufferWriter<byte>
{
private readonly int _segmentSize;
internal List<Memory<byte>> Segments { get; }
internal int Position { get; private set; }
public MemoryBufferWriter(int segmentSize = 2048)
{
_segmentSize = segmentSize;
Segments = new List<Memory<byte>>();
}
public Memory<byte> CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null;
public void Advance(int count)
{
Position += count;
}
public Memory<byte> GetMemory(int sizeHint = 0)
{
// TODO: Use sizeHint
if (Segments.Count == 0 || Position == _segmentSize)
{
// TODO: Rent memory from a pool
Segments.Add(new Memory<byte>(new byte[_segmentSize]));
Position = 0;
}
return CurrentSegment.Slice(Position, CurrentSegment.Length - Position);
}
public Span<byte> GetSpan(int sizeHint = 0)
{
return GetMemory(sizeHint).Span;
}
public byte[] ToArray()
{
if (Segments.Count == 0)
{
return Array.Empty<byte>();
}
var totalLength = (Segments.Count - 1) * _segmentSize;
totalLength += Position;
var result = new byte[totalLength];
var totalWritten = 0;
// Copy full segments
for (int i = 0; i < Segments.Count - 1; i++)
{
Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize));
totalWritten += _segmentSize;
}
// Copy current incomplete segment
CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position));
return result;
}
}
}

View File

@ -13,47 +13,67 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public static class HandshakeProtocol
{
private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);
private const string ProtocolPropertyName = "protocol";
private const string ProtocolVersionName = "version";
private const string ErrorPropertyName = "error";
private const string TypePropertyName = "type";
public static void WriteRequestMessage(HandshakeRequestMessage requestMessage, Stream output)
public static void WriteRequestMessage(HandshakeRequestMessage requestMessage, IBufferWriter<byte> output)
{
using (var writer = CreateJsonTextWriter(output))
var textWriter = Utf8BufferTextWriter.Get(output);
try
{
writer.WriteStartObject();
writer.WritePropertyName(ProtocolPropertyName);
writer.WriteValue(requestMessage.Protocol);
writer.WritePropertyName(ProtocolVersionName);
writer.WriteValue(requestMessage.Version);
writer.WriteEndObject();
}
TextMessageFormatter.WriteRecordSeparator(output);
}
public static void WriteResponseMessage(HandshakeResponseMessage responseMessage, Stream output)
{
using (var writer = CreateJsonTextWriter(output))
{
writer.WriteStartObject();
if (!string.IsNullOrEmpty(responseMessage.Error))
using (var writer = CreateJsonTextWriter(textWriter))
{
writer.WritePropertyName(ErrorPropertyName);
writer.WriteValue(responseMessage.Error);
writer.WriteStartObject();
writer.WritePropertyName(ProtocolPropertyName);
writer.WriteValue(requestMessage.Protocol);
writer.WritePropertyName(ProtocolVersionName);
writer.WriteValue(requestMessage.Version);
writer.WriteEndObject();
writer.Flush();
}
writer.WriteEndObject();
}
finally
{
Utf8BufferTextWriter.Return(textWriter);
}
TextMessageFormatter.WriteRecordSeparator(output);
}
private static JsonTextWriter CreateJsonTextWriter(Stream output)
public static void WriteResponseMessage(HandshakeResponseMessage responseMessage, IBufferWriter<byte> output)
{
return new JsonTextWriter(new StreamWriter(output, _utf8NoBom, 1024, leaveOpen: true));
var textWriter = Utf8BufferTextWriter.Get(output);
try
{
using (var writer = CreateJsonTextWriter(textWriter))
{
writer.WriteStartObject();
if (!string.IsNullOrEmpty(responseMessage.Error))
{
writer.WritePropertyName(ErrorPropertyName);
writer.WriteValue(responseMessage.Error);
}
writer.WriteEndObject();
writer.Flush();
}
}
finally
{
Utf8BufferTextWriter.Return(textWriter);
}
TextMessageFormatter.WriteRecordSeparator(output);
}
private static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter)
{
var writer = new JsonTextWriter(textWriter);
writer.CloseOutput = false;
return writer;
}
public static bool TryParseResponseMessage(ref ReadOnlySequence<byte> buffer, out HandshakeResponseMessage responseMessage)

View File

@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message)
{
using (var ms = new MemoryStream())
using (var ms = new LimitArrayPoolWriteStream())
{
hubProtocol.WriteMessage(message, ms);
return ms.ToArray();

View File

@ -0,0 +1,163 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public sealed class LimitArrayPoolWriteStream : Stream
{
private const int MaxByteArrayLength = 0x7FFFFFC7;
private const int InitialLength = 256;
private readonly int _maxBufferSize;
private byte[] _buffer;
private int _length;
public LimitArrayPoolWriteStream() : this(MaxByteArrayLength) { }
public LimitArrayPoolWriteStream(int maxBufferSize) : this(maxBufferSize, InitialLength) { }
public LimitArrayPoolWriteStream(int maxBufferSize, long capacity)
{
if (capacity < InitialLength)
{
capacity = InitialLength;
}
else if (capacity > maxBufferSize)
{
throw CreateOverCapacityException(maxBufferSize);
}
_maxBufferSize = maxBufferSize;
_buffer = ArrayPool<byte>.Shared.Rent((int)capacity);
}
protected override void Dispose(bool disposing)
{
if (_buffer != null)
{
ArrayPool<byte>.Shared.Return(_buffer);
_buffer = null;
}
base.Dispose(disposing);
}
public ArraySegment<byte> GetBuffer() => new ArraySegment<byte>(_buffer, 0, _length);
public byte[] ToArray()
{
var arr = new byte[_length];
Buffer.BlockCopy(_buffer, 0, arr, 0, _length);
return arr;
}
private void EnsureCapacity(int value)
{
if ((uint)value > (uint)_maxBufferSize) // value cast handles overflow to negative as well
{
throw CreateOverCapacityException(_maxBufferSize);
}
else if (value > _buffer.Length)
{
Grow(value);
}
}
private void Grow(int value)
{
Debug.Assert(value > _buffer.Length);
// Extract the current buffer to be replaced.
byte[] currentBuffer = _buffer;
_buffer = null;
// Determine the capacity to request for the new buffer. It should be
// at least twice as long as the current one, if not more if the requested
// value is more than that. If the new value would put it longer than the max
// allowed byte array, than shrink to that (and if the required length is actually
// longer than that, we'll let the runtime throw).
uint twiceLength = 2 * (uint)currentBuffer.Length;
int newCapacity = twiceLength > MaxByteArrayLength ?
(value > MaxByteArrayLength ? value : MaxByteArrayLength) :
Math.Max(value, (int)twiceLength);
// Get a new buffer, copy the current one to it, return the current one, and
// set the new buffer as current.
byte[] newBuffer = ArrayPool<byte>.Shared.Rent(newCapacity);
Buffer.BlockCopy(currentBuffer, 0, newBuffer, 0, _length);
ArrayPool<byte>.Shared.Return(currentBuffer);
_buffer = newBuffer;
}
public override void Write(byte[] buffer, int offset, int count)
{
Debug.Assert(buffer != null);
Debug.Assert(offset >= 0);
Debug.Assert(count >= 0);
EnsureCapacity(_length + count);
Buffer.BlockCopy(buffer, offset, _buffer, _length, count);
_length += count;
}
#if NETCOREAPP2_1
public override void Write(ReadOnlySpan<byte> source)
{
EnsureCapacity(_length + source.Length);
source.CopyTo(new Span<byte>(_buffer, _length, source.Length));
_length += source.Length;
}
#endif
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
Write(buffer, offset, count);
return Task.CompletedTask;
}
#if NETCOREAPP2_1
public override ValueTask WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
{
Write(source.Span);
return default;
}
#endif
public override void WriteByte(byte value)
{
int newLength = _length + 1;
EnsureCapacity(newLength);
_buffer[_length] = value;
_length = newLength;
}
public override void Flush() { }
public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
public override long Length => _length;
public override bool CanWrite => true;
public override bool CanRead => false;
public override bool CanSeek => false;
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public override int Read(byte[] buffer, int offset, int count) { throw new NotSupportedException(); }
public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); }
public override void SetLength(long value) { throw new NotSupportedException(); }
private static Exception CreateOverCapacityException(int maxBufferSize)
{
return new InvalidOperationException($"Buffer size of {maxBufferSize} exceeded.");
}
}
}

View File

@ -11,8 +11,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
internal class Utf8BufferTextReader : TextReader
{
private readonly Decoder _decoder;
private ReadOnlySequence<byte> _utf8Buffer;
private Decoder _decoder;
[ThreadStatic]
private static Utf8BufferTextReader _cachedInstance;

View File

@ -0,0 +1,195 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
internal sealed class Utf8BufferTextWriter : TextWriter
{
private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false);
[ThreadStatic]
private static Utf8BufferTextWriter _cachedInstance;
private readonly Encoder _encoder;
private IBufferWriter<byte> _bufferWriter;
private Memory<byte> _memory;
private int _memoryUsed;
#if DEBUG
private bool _inUse;
#endif
public override Encoding Encoding => _utf8NoBom;
public Utf8BufferTextWriter()
{
_encoder = _utf8NoBom.GetEncoder();
}
public static Utf8BufferTextWriter Get(IBufferWriter<byte> bufferWriter)
{
var writer = _cachedInstance;
if (writer == null)
{
writer = new Utf8BufferTextWriter();
}
// Taken off the the thread static
_cachedInstance = null;
#if DEBUG
if (writer._inUse)
{
throw new InvalidOperationException("The writer wasn't returned!");
}
writer._inUse = true;
#endif
writer.SetWriter(bufferWriter);
return writer;
}
public static void Return(Utf8BufferTextWriter writer)
{
_cachedInstance = writer;
writer._encoder.Reset();
writer._memory = Memory<byte>.Empty;
writer._memoryUsed = 0;
writer._bufferWriter = null;
#if DEBUG
writer._inUse = false;
#endif
}
public void SetWriter(IBufferWriter<byte> bufferWriter)
{
_bufferWriter = bufferWriter;
}
public override void Write(char[] buffer, int index, int count)
{
WriteInternal(buffer, index, count);
}
public override void Write(char[] buffer)
{
WriteInternal(buffer, 0, buffer.Length);
}
public override void Write(char value)
{
var destination = GetBuffer();
if (value <= 127)
{
destination[0] = (byte)value;
_memoryUsed++;
}
else
{
// Json.NET only writes ASCII characters by themselves, e.g. {}[], etc
// this should be an exceptional case
var bytesUsed = 0;
var charsUsed = 0;
unsafe
{
#if NETCOREAPP2_1
_encoder.Convert(new Span<char>(&value, 1), destination, false, out charsUsed, out bytesUsed, out _);
#else
fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination))
{
_encoder.Convert(&value, 1, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _);
}
#endif
}
Debug.Assert(charsUsed == 1);
if (bytesUsed > 0)
{
_memoryUsed += bytesUsed;
}
}
}
private Span<byte> GetBuffer()
{
EnsureBuffer();
return _memory.Span.Slice(_memoryUsed, _memory.Length - _memoryUsed);
}
private void EnsureBuffer()
{
if (_memoryUsed == _memory.Length)
{
// Used up the memory from the buffer writer so advance and get more
if (_memoryUsed > 0)
{
_bufferWriter.Advance(_memoryUsed);
}
_memory = _bufferWriter.GetMemory();
_memoryUsed = 0;
}
}
private void WriteInternal(char[] buffer, int index, int count)
{
var currentIndex = index;
var charsRemaining = count;
while (charsRemaining > 0)
{
// The destination byte array might not be large enough so multiple writes are sometimes required
var destination = GetBuffer();
var bytesUsed = 0;
var charsUsed = 0;
#if NETCOREAPP2_1
_encoder.Convert(buffer.AsSpan(currentIndex, charsRemaining), destination, false, out charsUsed, out bytesUsed, out _);
#else
unsafe
{
fixed (char* sourceChars = &buffer[currentIndex])
fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination))
{
_encoder.Convert(sourceChars, charsRemaining, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _);
}
}
#endif
charsRemaining -= charsUsed;
currentIndex += charsUsed;
_memoryUsed += bytesUsed;
}
}
public override void Flush()
{
if (_memoryUsed > 0)
{
_bufferWriter.Advance(_memoryUsed);
_memory = _memory.Slice(_memoryUsed, _memory.Length - _memoryUsed);
_memoryUsed = 0;
}
}
protected override void Dispose(bool disposing)
{
base.Dispose(disposing);
if (disposing)
{
Flush();
}
}
}
}

View File

@ -3,5 +3,6 @@
using System.Runtime.CompilerServices;
[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Common.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")]
[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests.Utils, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")]
[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Common.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")]
[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Microbenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")]

View File

@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;
@ -25,7 +26,8 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubConnectionContext
{
private static Action<object> _abortedCallback = AbortConnection;
private static readonly Action<object> _abortedCallback = AbortConnection;
private static readonly byte[] _successHandshakeResponseData;
private readonly ConnectionContext _connectionContext;
private readonly ILogger _logger;
@ -37,6 +39,13 @@ namespace Microsoft.AspNetCore.SignalR
private long _lastSendTimestamp = Stopwatch.GetTimestamp();
private byte[] _cachedPingMessage;
static HubConnectionContext()
{
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter);
_successHandshakeResponseData = memoryBufferWriter.ToArray();
}
public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)
{
_connectionContext = connectionContext;
@ -185,10 +194,17 @@ namespace Microsoft.AspNetCore.SignalR
try
{
var ms = new MemoryStream();
HandshakeProtocol.WriteResponseMessage(message, ms);
if (message == HandshakeResponseMessage.Empty)
{
// success response is always an empty object so send cached data
_connectionContext.Transport.Output.Write(_successHandshakeResponseData);
}
else
{
HandshakeProtocol.WriteResponseMessage(message, _connectionContext.Transport.Output);
}
await _connectionContext.Transport.Output.WriteAsync(ms.ToArray());
await _connectionContext.Transport.Output.FlushAsync();
}
finally
{

View File

@ -263,22 +263,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public void WriteMessage(HubMessage message, Stream output)
{
// We're writing data into the memoryStream so that we can get the length prefix
using (var memoryStream = new MemoryStream())
using (var stream = new LimitArrayPoolWriteStream())
{
WriteMessageCore(message, memoryStream);
if (memoryStream.TryGetBuffer(out var buffer))
{
// Write the buffer directly
BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output);
output.Write(buffer.Array, buffer.Offset, buffer.Count);
}
else
{
BinaryMessageFormatter.WriteLengthPrefix(memoryStream.Length, output);
memoryStream.Position = 0;
memoryStream.CopyTo(output);
}
// Write message to a buffer so we can get its length
WriteMessageCore(message, stream);
var buffer = stream.GetBuffer();
// Write length then message to output
BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output);
output.Write(buffer.Array, buffer.Offset, buffer.Count);
}
}

View File

@ -213,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private async Task PublishAsync(string channel, IRedisMessage message)
{
byte[] payload;
using (var stream = new MemoryStream())
using (var stream = new LimitArrayPoolWriteStream())
using (var writer = new JsonTextWriter(new StreamWriter(stream)))
{
_serializer.Serialize(writer, message);
@ -316,8 +316,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
var groupChannel = _channelNamePrefix + ".group." + groupName;
GroupData group;
if (!_groups.TryGetValue(groupChannel, out group))
if (!_groups.TryGetValue(groupChannel, out var group))
{
return;
}

View File

@ -26,12 +26,12 @@ namespace Microsoft.AspNetCore.Sockets
ITransferFormatFeature,
IHttpContextFeature
{
private object _heartbeatLock = new object();
private readonly object _heartbeatLock = new object();
private List<(Action<object> handler, object state)> _heartbeatHandlers;
// This tcs exists so that multiple calls to DisposeAsync all wait asynchronously
// on the same task
private TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _disposeTcs = new TaskCompletionSource<object>();
/// <summary>
/// Creates the DefaultConnectionContext without Pipes to avoid upfront allocations.

View File

@ -10,6 +10,7 @@ using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets.Client;
@ -72,7 +73,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var s = await ReadSentTextMessageAsync();
var output = new MemoryStream();
var output = new MemoryBufferWriter();
HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output);
await Application.Output.WriteAsync(output.ToArray());

View File

@ -0,0 +1,246 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
public class Utf8BufferTextWriterTests
{
[Fact]
public void WriteChar_Unicode()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write('[');
textWriter.Flush();
Assert.Equal(1, bufferWriter.Position);
Assert.Equal((byte)'[', bufferWriter.CurrentSegment.Span[0]);
textWriter.Write('"');
textWriter.Flush();
Assert.Equal(2, bufferWriter.Position);
Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[1]);
textWriter.Write('\u00A3');
textWriter.Flush();
Assert.Equal(4, bufferWriter.Position);
textWriter.Write('\u00A3');
textWriter.Flush();
Assert.Equal(6, bufferWriter.Position);
textWriter.Write('"');
textWriter.Flush();
Assert.Equal(7, bufferWriter.Position);
Assert.Equal((byte)0xC2, bufferWriter.CurrentSegment.Span[2]);
Assert.Equal((byte)0xA3, bufferWriter.CurrentSegment.Span[3]);
Assert.Equal((byte)0xC2, bufferWriter.CurrentSegment.Span[4]);
Assert.Equal((byte)0xA3, bufferWriter.CurrentSegment.Span[5]);
Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[6]);
textWriter.Write(']');
textWriter.Flush();
Assert.Equal(8, bufferWriter.Position);
Assert.Equal((byte)']', bufferWriter.CurrentSegment.Span[7]);
}
[Fact]
public void WriteChar_UnicodeLastChar()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
using (Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter())
{
textWriter.SetWriter(bufferWriter);
textWriter.Write('\u00A3');
}
Assert.Equal(2, bufferWriter.Position);
Assert.Equal((byte)0xC2, bufferWriter.CurrentSegment.Span[0]);
Assert.Equal((byte)0xA3, bufferWriter.CurrentSegment.Span[1]);
}
[Fact]
public void WriteChar_UnicodeAndRunOutOfBufferSpace()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write('[');
textWriter.Flush();
Assert.Equal(1, bufferWriter.Position);
Assert.Equal((byte)'[', bufferWriter.CurrentSegment.Span[0]);
textWriter.Write('"');
textWriter.Flush();
Assert.Equal(2, bufferWriter.Position);
Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[1]);
for (int i = 0; i < 2000; i++)
{
textWriter.Write('\u00A3');
}
textWriter.Flush();
textWriter.Write('"');
textWriter.Flush();
Assert.Equal(4003, bufferWriter.Position);
Assert.Equal((byte)'"', bufferWriter.CurrentSegment.Span[4002]);
textWriter.Write(']');
textWriter.Flush();
Assert.Equal(4004, bufferWriter.Position);
string result = Encoding.UTF8.GetString(bufferWriter.CurrentSegment.Slice(0, bufferWriter.Position).ToArray());
Assert.Equal(2004, result.Length);
Assert.Equal('[', result[0]);
Assert.Equal('"', result[1]);
for (int i = 0; i < 2000; i++)
{
Assert.Equal('\u00A3', result[i + 2]);
}
Assert.Equal('"', result[2002]);
Assert.Equal(']', result[2003]);
}
[Fact]
public void WriteCharArray_SurrogatePairInMultipleCalls()
{
string fourCircles = char.ConvertFromUtf32(0x1F01C);
char[] chars = fourCircles.ToCharArray();
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write(chars, 0, 1);
textWriter.Flush();
// Surrogate buffered
Assert.Equal(0, bufferWriter.Position);
textWriter.Write(chars, 1, 1);
textWriter.Flush();
Assert.Equal(4, bufferWriter.Position);
byte[] expectedData = Encoding.UTF8.GetBytes(fourCircles);
byte[] actualData = bufferWriter.CurrentSegment.Slice(0, 4).ToArray();
Assert.Equal(expectedData, actualData);
}
[Fact]
public void WriteChar_SurrogatePairInMultipleCalls()
{
string fourCircles = char.ConvertFromUtf32(0x1F01C);
char[] chars = fourCircles.ToCharArray();
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
textWriter.Write(chars[0]);
textWriter.Flush();
// Surrogate buffered
Assert.Equal(0, bufferWriter.Position);
textWriter.Write(chars[1]);
textWriter.Flush();
Assert.Equal(4, bufferWriter.Position);
byte[] expectedData = Encoding.UTF8.GetBytes(fourCircles);
byte[] actualData = bufferWriter.CurrentSegment.Slice(0, 4).ToArray();
Assert.Equal(expectedData, actualData);
}
[Fact]
public void WriteCharArray_NonZeroStart()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
char[] chars = "Hello world".ToCharArray();
textWriter.Write(chars, 6, 1);
textWriter.Flush();
Assert.Equal(1, bufferWriter.Position);
Assert.Equal((byte)'w', bufferWriter.CurrentSegment.Span[0]);
}
[Fact]
public void WriteCharArray_AcrossMultipleBuffers()
{
MemoryBufferWriter bufferWriter = new MemoryBufferWriter(2);
Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter();
textWriter.SetWriter(bufferWriter);
char[] chars = "Hello world".ToCharArray();
textWriter.Write(chars);
textWriter.Flush();
Assert.Equal(6, bufferWriter.Segments.Count);
Assert.Equal(1, bufferWriter.Position);
Assert.Equal((byte)'H', bufferWriter.Segments[0].Span[0]);
Assert.Equal((byte)'e', bufferWriter.Segments[0].Span[1]);
Assert.Equal((byte)'l', bufferWriter.Segments[1].Span[0]);
Assert.Equal((byte)'l', bufferWriter.Segments[1].Span[1]);
Assert.Equal((byte)'o', bufferWriter.Segments[2].Span[0]);
Assert.Equal((byte)' ', bufferWriter.Segments[2].Span[1]);
Assert.Equal((byte)'w', bufferWriter.Segments[3].Span[0]);
Assert.Equal((byte)'o', bufferWriter.Segments[3].Span[1]);
Assert.Equal((byte)'r', bufferWriter.Segments[4].Span[0]);
Assert.Equal((byte)'l', bufferWriter.Segments[4].Span[1]);
Assert.Equal((byte)'d', bufferWriter.Segments[5].Span[0]);
}
[Fact]
public void GetAndReturnCachedBufferTextWriter()
{
MemoryBufferWriter bufferWriter1 = new MemoryBufferWriter();
var textWriter1 = Utf8BufferTextWriter.Get(bufferWriter1);
textWriter1.Write("Hello");
textWriter1.Flush();
Utf8BufferTextWriter.Return(textWriter1);
Assert.Equal("Hello", Encoding.UTF8.GetString(bufferWriter1.ToArray()));
MemoryBufferWriter bufferWriter2 = new MemoryBufferWriter();
var textWriter2 = Utf8BufferTextWriter.Get(bufferWriter2);
textWriter2.Write("World");
textWriter2.Flush();
Utf8BufferTextWriter.Return(textWriter2);
Assert.Equal("World", Encoding.UTF8.GetString(bufferWriter2.ToArray()));
Assert.Same(textWriter1, textWriter2);
}
}
}

View File

@ -69,11 +69,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
if (sendHandshakeRequestMessage)
{
using (var memoryStream = new MemoryStream())
{
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryStream);
await Connection.Application.Output.WriteAsync(memoryStream.ToArray());
}
var memoryBufferWriter = new MemoryBufferWriter();
HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter);
await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray());
}
var connection = handler.OnConnectedAsync(Connection);