diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index cc524f714b..3d982dd51a 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -69,6 +69,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks public void WriteMessage(HubMessage message, IBufferWriter output) { } + + public byte[] GetMessageBytes(HubMessage message) + { + return HubProtocolExtensions.GetMessageBytes(this, message); + } } public class NoErrorHubConnectionContext : HubConnectionContext diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubProtocolBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubProtocolBenchmark.cs index 8647f62fdc..0853cde5ab 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubProtocolBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubProtocolBenchmark.cs @@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks break; } - _binaryInput = _hubProtocol.WriteToArray(_hubMessage); + _binaryInput = _hubProtocol.GetMessageBytes(_hubMessage); _binder = new TestBinder(_hubMessage); } @@ -67,7 +67,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Benchmark] public void WriteSingleMessage() { - var bytes = _hubProtocol.WriteToArray(_hubMessage); + var bytes = _hubProtocol.GetMessageBytes(_hubMessage); if (bytes.Length != _binaryInput.Length) { throw new InvalidOperationException("Failed to write message"); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs index cb21cce57d..31a8f268a5 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisHubLifetimeManagerBenchmark.cs @@ -194,6 +194,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _innerProtocol.WriteMessage(message, output); } + public byte[] GetMessageBytes(HubMessage message) + { + return HubProtocolExtensions.GetMessageBytes(this, message); + } + public bool IsVersionSupported(int version) { return _innerProtocol.IsVersionSupported(version); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs index 5eec1d66bf..af14d3de5d 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/RedisProtocolBenchmark.cs @@ -143,6 +143,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { output.Write(_fixedOutput); } + + public byte[] GetMessageBytes(HubMessage message) + { + return HubProtocolExtensions.GetMessageBytes(this, message); + } } } } diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs index f17fa7ef40..1636c4c713 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks } _parser = new ServerSentEventsMessageParser(); - _rawData = hubProtocol.WriteToArray(hubMessage); + _rawData = hubProtocol.GetMessageBytes(hubMessage); var ms = new MemoryStream(); ServerSentEventsMessageFormatter.WriteMessage(_rawData, ms); _sseFormattedData = ms.ToArray(); diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index d50b8966ac..adf5820429 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Runtime.CompilerServices; using System.Threading; @@ -217,6 +218,35 @@ namespace Microsoft.AspNetCore.Internal return result; } + public void CopyTo(Span span) + { + Debug.Assert(span.Length >= _bytesWritten); + + if (_currentSegment == null) + { + return; + } + + var totalWritten = 0; + + if (_fullSegments != null) + { + // Copy full segments + var count = _fullSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _fullSegments[i]; + segment.AsSpan().CopyTo(span.Slice(totalWritten)); + totalWritten += segment.Length; + } + } + + // Copy current incomplete segment + _currentSegment.AsSpan(0, _position).CopyTo(span.Slice(totalWritten)); + + Debug.Assert(_bytesWritten == totalWritten + _position); + } + public override void Flush() { } public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask; public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs index edfc0a1c0b..0ef7d0b5fe 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs @@ -10,15 +10,21 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters { public static void WriteLengthPrefix(long length, IBufferWriter output) { - // This code writes length prefix of the message as a VarInt. Read the comment in - // the BinaryMessageParser.TryParseMessage for details. - Span lenBuffer = stackalloc byte[5]; + var lenNumBytes = WriteLengthPrefix(length, lenBuffer); + + output.Write(lenBuffer.Slice(0, lenNumBytes)); + } + + public static int WriteLengthPrefix(long length, Span output) + { + // This code writes length prefix of the message as a VarInt. Read the comment in + // the BinaryMessageParser.TryParseMessage for details. var lenNumBytes = 0; do { - ref var current = ref lenBuffer[lenNumBytes]; + ref var current = ref output[lenNumBytes]; current = (byte)(length & 0x7f); length >>= 7; if (length > 0) @@ -29,7 +35,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters } while (length > 0); - output.Write(lenBuffer.Slice(0, lenNumBytes)); + return lenNumBytes; + } + + public static int LengthPrefixLength(long length) + { + var lenNumBytes = 0; + do + { + length >>= 7; + lenNumBytes++; + } + while (length > 0); + + return lenNumBytes; } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs index 2da713eb81..bf883730c1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs @@ -7,7 +7,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public static class HubProtocolExtensions { - public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message) + // Would work as default interface impl + public static byte[] GetMessageBytes(this IHubProtocol hubProtocol, HubMessage message) { var writer = MemoryBufferWriter.Get(); try diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs index 3100717470..a2f341deb5 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Buffers; -using System.IO; using Microsoft.AspNetCore.Connections; namespace Microsoft.AspNetCore.SignalR.Internal.Protocol @@ -19,6 +18,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol void WriteMessage(HubMessage message, IBufferWriter output); + byte[] GetMessageBytes(HubMessage message); + bool IsVersionSupported(int version); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 62dab627ad..032e39c788 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -340,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR transferFormatFeature.ActiveFormat = Protocol.TransferFormat; } - _cachedPingMessage = Protocol.WriteToArray(PingMessage.Instance); + _cachedPingMessage = Protocol.GetMessageBytes(PingMessage.Instance); UserIdentifier = userIdProvider.GetUserId(this); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs index b54de1dade..dcff449701 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/SerializedHubMessage.cs @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal "This message was received from another server that did not have the requested protocol available."); } - serialized = protocol.WriteToArray(Message); + serialized = protocol.GetMessageBytes(Message); SetCache(protocol.Name, serialized); } @@ -65,7 +65,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { writer.Write(protocol.Name); - var buffer = protocol.WriteToArray(message); + var buffer = protocol.GetMessageBytes(message); writer.Write(buffer.Length); writer.Write(buffer); } diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.Json/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Protocols.Json/Internal/Protocol/JsonHubProtocol.cs index c0b21f4cf7..4da52edd83 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.Json/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.Json/Internal/Protocol/JsonHubProtocol.cs @@ -82,6 +82,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol TextMessageFormatter.WriteRecordSeparator(output); } + public byte[] GetMessageBytes(HubMessage message) + { + return HubProtocolExtensions.GetMessageBytes(this, message); + } + private HubMessage ParseMessage(Utf8BufferTextReader textReader, IInvocationBinder binder) { try diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs index 7b98831bc8..e6e21625b5 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs @@ -303,6 +303,34 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } + public byte[] GetMessageBytes(HubMessage message) + { + var writer = MemoryBufferWriter.Get(); + + try + { + // Write message to a buffer so we can get its length + WriteMessageCore(message, writer); + + var dataLength = writer.Length; + var prefixLength = BinaryMessageFormatter.LengthPrefixLength(writer.Length); + + var array = new byte[dataLength + prefixLength]; + var span = array.AsSpan(); + + // Write length then message to output + var written = BinaryMessageFormatter.WriteLengthPrefix(writer.Length, span); + Debug.Assert(written == prefixLength); + writer.CopyTo(span.Slice(prefixLength)); + + return array; + } + finally + { + MemoryBufferWriter.Return(writer); + } + } + private void WriteMessageCore(HubMessage message, Stream packer) { switch (message) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 7f8086f7f8..3d909c6f71 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -167,6 +167,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests throw _error; } } + + public byte[] GetMessageBytes(HubMessage message) + { + return HubProtocolExtensions.GetMessageBytes(this, message); + } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs index 55976b3095..10bc0b41fc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs @@ -34,6 +34,18 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + [Fact] + public void WritingNotingGivesEmptyData_CopyTo() + { + using (var bufferWriter = new MemoryBufferWriter()) + { + Assert.Equal(0, bufferWriter.Length); + var data = new byte[bufferWriter.Length]; + bufferWriter.CopyTo(data); + Assert.Empty(data); + } + } + [Fact] public void WriteByteWorksAsFirstCall() { @@ -48,6 +60,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + [Fact] + public void WriteByteWorksAsFirstCall_CopyTo() + { + using (var bufferWriter = new MemoryBufferWriter()) + { + bufferWriter.WriteByte(234); + + Assert.Equal(1, bufferWriter.Length); + var data = new byte[bufferWriter.Length]; + + bufferWriter.CopyTo(data); + Assert.Equal(234, data[0]); + } + } + [Fact] public void WriteByteWorksIfFirstByteInNewSegment() { @@ -67,6 +94,27 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + [Fact] + public void WriteByteWorksIfFirstByteInNewSegment_CopyTo() + { + var inputSize = MinimumSegmentSize; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(16, bufferWriter.Length); + bufferWriter.WriteByte(16); + Assert.Equal(17, bufferWriter.Length); + + var data = new byte[bufferWriter.Length]; + + bufferWriter.CopyTo(data); + Assert.Equal(input, data.Take(16)); + Assert.Equal(16, data[16]); + } + } + [Fact] public void WriteByteWorksIfSegmentHasSpace() { @@ -88,6 +136,28 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + [Fact] + public void WriteByteWorksIfSegmentHasSpace_CopyTo() + { + var input = new byte[] { 11, 12, 13 }; + + using (var bufferWriter = new MemoryBufferWriter()) + { + bufferWriter.Write(input, 0, input.Length); + bufferWriter.WriteByte(14); + + Assert.Equal(4, bufferWriter.Length); + + var data = new byte[bufferWriter.Length]; + + bufferWriter.CopyTo(data); + Assert.Equal(11, data[0]); + Assert.Equal(12, data[1]); + Assert.Equal(13, data[2]); + Assert.Equal(14, data[3]); + } + } + [Fact] public void ToArrayWithExactlyFullSegmentsWorks() { @@ -104,6 +174,24 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + [Fact] + public void ToArrayWithExactlyFullSegmentsWorks_CopyTo() + { + var inputSize = MinimumSegmentSize * 2; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + var data = new byte[bufferWriter.Length]; + + bufferWriter.CopyTo(data); + Assert.Equal(input, data); + } + } + [Fact] public void ToArrayWithSomeFullSegmentsWorks() { @@ -120,6 +208,23 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + [Fact] + public void ToArrayWithSomeFullSegmentsWorks_CopyTo() + { + var inputSize = (MinimumSegmentSize * 2) + 1; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + var data = new byte[bufferWriter.Length]; + + bufferWriter.CopyTo(data); + Assert.Equal(input, data); + } + } + [Fact] public async Task CopyToAsyncWithExactlyFullSegmentsWorks() { @@ -177,6 +282,34 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + + [Fact] + public void CopyToWithExactlyFullSegmentsWorks_CopyTo() + { + var inputSize = MinimumSegmentSize * 2; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + using (var destination = new MemoryBufferWriter()) + { + bufferWriter.CopyTo(destination); + var data = new byte[bufferWriter.Length]; + + bufferWriter.CopyTo(data); + Assert.Equal(input, data); + + Array.Clear(data, 0, data.Length); + + destination.CopyTo(data); + Assert.Equal(input, data); + } + } + } + [Fact] public void CopyToWithSomeFullSegmentsWorks() { @@ -197,6 +330,34 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } } + + [Fact] + public void CopyToWithSomeFullSegmentsWorks_CopyTo() + { + var inputSize = (MinimumSegmentSize * 2) + 1; + var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray(); + + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + bufferWriter.Write(input, 0, input.Length); + Assert.Equal(input.Length, bufferWriter.Length); + + using (var destination = new MemoryBufferWriter()) + { + bufferWriter.CopyTo(destination); + var data = new byte[bufferWriter.Length]; + bufferWriter.CopyTo(data); + + Assert.Equal(input, data); + + Array.Clear(data, 0, data.Length); + + destination.CopyTo(data); + Assert.Equal(input, data); + } + } + } + #if NETCOREAPP2_1 [Fact] public void WriteSpanWorksAtNonZeroOffset() @@ -216,6 +377,25 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Equal(4, data[3]); } } + + [Fact] + public void WriteSpanWorksAtNonZeroOffset_CopyTo() + { + using (var bufferWriter = new MemoryBufferWriter()) + { + bufferWriter.WriteByte(1); + bufferWriter.Write(new byte[] { 2, 3, 4 }.AsSpan()); + + Assert.Equal(4, bufferWriter.Length); + + var data = new byte[bufferWriter.Length]; + bufferWriter.CopyTo(data); + Assert.Equal(1, data[0]); + Assert.Equal(2, data[1]); + Assert.Equal(3, data[2]); + Assert.Equal(4, data[3]); + } + } #endif [Fact] diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 61b5b8df5c..66359ff755 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -179,7 +179,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task SendHubMessageAsync(HubMessage message) { - var payload = _protocol.WriteToArray(message); + var payload = _protocol.GetMessageBytes(message); await Connection.Application.Output.WriteAsync(payload); return message is HubInvocationMessage hubMessage ? hubMessage.InvocationId : null;