diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 237b86771d..ab7bcc2a54 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks return false; } - public void WriteMessage(HubMessage message, Stream output) + public void WriteMessage(HubMessage message, IBufferWriter output) { } } diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs index ba161e8096..1ebf21ff8c 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionContextBenchmark.cs @@ -30,9 +30,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - var memoryBufferWriter = new MemoryBufferWriter(); - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter); - _handshakeRequestResult = new ReadResult(new ReadOnlySequence(memoryBufferWriter.ToArray()), false, false); + using (var memoryBufferWriter = new MemoryBufferWriter()) + { + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage("json", 1), memoryBufferWriter); + _handshakeRequestResult = new ReadResult(new ReadOnlySequence(memoryBufferWriter.ToArray()), false, false); + } _pipe = new TestDuplexPipe(); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs index fa0ff6a205..ea1ed7cb97 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionSendBenchmark.cs @@ -24,12 +24,14 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - var ms = new MemoryBufferWriter(); - HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms); - var handshakeResponseResult = new ReadResult(new ReadOnlySequence(ms.ToArray()), false, false); - - _pipe = new TestDuplexPipe(); - _pipe.AddReadResult(new ValueTask(handshakeResponseResult)); + using (var writer = new MemoryBufferWriter()) + { + HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer); + var handshakeResponseResult = new ReadResult(new ReadOnlySequence(writer.ToArray()), false, false); + + _pipe = new TestDuplexPipe(); + _pipe.AddReadResult(new ValueTask(handshakeResponseResult)); + } _tcs = new TaskCompletionSource(); _pipe.AddReadResult(new ValueTask(_tcs.Task)); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs index 7a29f3635c..f237383b15 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/HubConnectionStartBenchmark.cs @@ -26,9 +26,11 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - var ms = new MemoryBufferWriter(); - HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, ms); - _handshakeResponseResult = new ReadResult(new ReadOnlySequence(ms.ToArray()), false, false); + using (var writer = new MemoryBufferWriter()) + { + HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, writer); + _handshakeResponseResult = new ReadResult(new ReadOnlySequence(writer.ToArray()), false, false); + } _pipe = new TestDuplexPipe(); diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs index 9d6ccc0c4d..4a0f5be890 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/MessageParserBenchmark.cs @@ -2,6 +2,7 @@ using System; using System.Buffers; using System.IO; using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Formatters; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks @@ -23,19 +24,22 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { var buffer = new byte[MessageLength]; Random.NextBytes(buffer); - var output = new MemoryStream(); - BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, output); - output.Write(buffer, 0, buffer.Length); - - _binaryInput = output.ToArray(); + using (var writer = new MemoryBufferWriter()) + { + BinaryMessageFormatter.WriteLengthPrefix(buffer.Length, writer); + writer.Write(buffer); + _binaryInput = writer.ToArray(); + } buffer = new byte[MessageLength]; Random.NextBytes(buffer); - output = new MemoryStream(); - output.Write(buffer, 0, buffer.Length); - TextMessageFormatter.WriteRecordSeparator(output); + using (var writer = new MemoryBufferWriter()) + { + writer.Write(buffer); + TextMessageFormatter.WriteRecordSeparator(writer); - _textInput = output.ToArray(); + _textInput = writer.ToArray(); + } } [Benchmark] diff --git a/src/Common/JsonUtils.cs b/src/Common/JsonUtils.cs index 97851fd8b1..bd37090763 100644 --- a/src/Common/JsonUtils.cs +++ b/src/Common/JsonUtils.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.Internal internal static JsonTextWriter CreateJsonTextWriter(TextWriter textWriter) { var writer = new JsonTextWriter(textWriter); - + writer.ArrayPool = JsonArrayPool.Shared; // Don't close the output, leave closing to the caller writer.CloseOutput = false; diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 90f16c24af..8d1c76b415 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -340,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { AssertConnectionValid(); - _protocol.WriteMessage(hubMessage, _connectionState.OutputStream); + _protocol.WriteMessage(hubMessage, _connectionState.Connection.Transport.Output); Log.SendingMessage(_logger, hubMessage); @@ -826,7 +826,6 @@ namespace Microsoft.AspNetCore.SignalR.Client public IConnection Connection { get; } public Task ReceiveTask { get; set; } public Exception CloseException { get; set; } - public PipeWriterStream OutputStream { get; } public bool Stopping { @@ -838,7 +837,6 @@ namespace Microsoft.AspNetCore.SignalR.Client { _hubConnection = hubConnection; Connection = connection; - OutputStream = new PipeWriterStream(Connection.Transport.Output); } public string GetNextId() => Interlocked.Increment(ref _nextId).ToString(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs index 8c89ec9fce..edfc0a1c0b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/BinaryMessageFormatter.cs @@ -2,22 +2,19 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO; +using System.Buffers; namespace Microsoft.AspNetCore.SignalR.Internal.Formatters { public static class BinaryMessageFormatter { - public static void WriteLengthPrefix(long length, Stream output) + public static void WriteLengthPrefix(long length, IBufferWriter output) { // This code writes length prefix of the message as a VarInt. Read the comment in // the BinaryMessageParser.TryParseMessage for details. -#if NETCOREAPP2_1 Span lenBuffer = stackalloc byte[5]; -#else - var lenBuffer = new byte[5]; -#endif + var lenNumBytes = 0; do { @@ -32,11 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters } while (length > 0); -#if NETCOREAPP2_1 output.Write(lenBuffer.Slice(0, lenNumBytes)); -#else - output.Write(lenBuffer, 0, lenNumBytes); -#endif } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs index db667efd67..ae41eea276 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/MemoryBufferWriter.cs @@ -4,24 +4,26 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.SignalR.Internal { - public sealed class MemoryBufferWriter : IBufferWriter + public sealed class MemoryBufferWriter : IBufferWriter, IDisposable { private readonly int _segmentSize; + private int _bytesWritten; - internal List> Segments { get; } + internal List Segments { get; } internal int Position { get; private set; } public MemoryBufferWriter(int segmentSize = 2048) { _segmentSize = segmentSize; - Segments = new List>(); + Segments = new List(); } public Memory CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null; public void Advance(int count) { + _bytesWritten += count; Position += count; } @@ -31,8 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal if (Segments.Count == 0 || Position == _segmentSize) { - // TODO: Rent memory from a pool - Segments.Add(new Memory(new byte[_segmentSize])); + Segments.Add(ArrayPool.Shared.Rent(_segmentSize)); Position = 0; } @@ -51,17 +52,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal return Array.Empty(); } - var totalLength = (Segments.Count - 1) * _segmentSize; - totalLength += Position; - - var result = new byte[totalLength]; + var result = new byte[_bytesWritten]; var totalWritten = 0; // Copy full segments for (int i = 0; i < Segments.Count - 1; i++) { - Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize)); + Segments[i].AsMemory().CopyTo(result.AsMemory(totalWritten, _segmentSize)); totalWritten += _segmentSize; } @@ -71,5 +69,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal return result; } + + public void Dispose() + { + for (int i = 0; i < Segments.Count; i++) + { + ArrayPool.Shared.Return(Segments[i]); + } + Segments.Clear(); + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs index 73ae855889..7b38156f1c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubProtocolExtensions.cs @@ -1,18 +1,16 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.IO; - namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { public static class HubProtocolExtensions { public static byte[] WriteToArray(this IHubProtocol hubProtocol, HubMessage message) { - using (var ms = new LimitArrayPoolWriteStream()) + using (var writer = new MemoryBufferWriter()) { - hubProtocol.WriteMessage(message, ms); - return ms.ToArray(); + hubProtocol.WriteMessage(message, writer); + return writer.ToArray(); } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs index f05481b20c..3100717470 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message); - void WriteMessage(HubMessage message, Stream output); + void WriteMessage(HubMessage message, IBufferWriter output); bool IsVersionSupported(int version); } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index b16f1644c5..a95bad8d61 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -78,7 +78,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return message != null; } - public void WriteMessage(HubMessage message, Stream output) + public void WriteMessage(HubMessage message, IBufferWriter output) { WriteMessageCore(message, output); TextMessageFormatter.WriteRecordSeparator(output); @@ -340,50 +340,58 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol throw new JsonReaderException("Unexpected end when reading message headers"); } - private void WriteMessageCore(HubMessage message, Stream stream) + private void WriteMessageCore(HubMessage message, IBufferWriter stream) { - using (var writer = JsonUtils.CreateJsonTextWriter(new StreamWriter(stream, _utf8NoBom, 1024, leaveOpen: true))) + var textWriter = Utf8BufferTextWriter.Get(stream); + try { - writer.WriteStartObject(); - switch (message) + using (var writer = JsonUtils.CreateJsonTextWriter(textWriter)) { - case InvocationMessage m: - WriteMessageType(writer, HubProtocolConstants.InvocationMessageType); - WriteHeaders(writer, m); - WriteInvocationMessage(m, writer); - break; - case StreamInvocationMessage m: - WriteMessageType(writer, HubProtocolConstants.StreamInvocationMessageType); - WriteHeaders(writer, m); - WriteStreamInvocationMessage(m, writer); - break; - case StreamItemMessage m: - WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType); - WriteHeaders(writer, m); - WriteStreamItemMessage(m, writer); - break; - case CompletionMessage m: - WriteMessageType(writer, HubProtocolConstants.CompletionMessageType); - WriteHeaders(writer, m); - WriteCompletionMessage(m, writer); - break; - case CancelInvocationMessage m: - WriteMessageType(writer, HubProtocolConstants.CancelInvocationMessageType); - WriteHeaders(writer, m); - WriteCancelInvocationMessage(m, writer); - break; - case PingMessage _: - WriteMessageType(writer, HubProtocolConstants.PingMessageType); - break; - case CloseMessage m: - WriteMessageType(writer, HubProtocolConstants.CloseMessageType); - WriteCloseMessage(m, writer); - break; - default: - throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); + writer.WriteStartObject(); + switch (message) + { + case InvocationMessage m: + WriteMessageType(writer, HubProtocolConstants.InvocationMessageType); + WriteHeaders(writer, m); + WriteInvocationMessage(m, writer); + break; + case StreamInvocationMessage m: + WriteMessageType(writer, HubProtocolConstants.StreamInvocationMessageType); + WriteHeaders(writer, m); + WriteStreamInvocationMessage(m, writer); + break; + case StreamItemMessage m: + WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType); + WriteHeaders(writer, m); + WriteStreamItemMessage(m, writer); + break; + case CompletionMessage m: + WriteMessageType(writer, HubProtocolConstants.CompletionMessageType); + WriteHeaders(writer, m); + WriteCompletionMessage(m, writer); + break; + case CancelInvocationMessage m: + WriteMessageType(writer, HubProtocolConstants.CancelInvocationMessageType); + WriteHeaders(writer, m); + WriteCancelInvocationMessage(m, writer); + break; + case PingMessage _: + WriteMessageType(writer, HubProtocolConstants.PingMessageType); + break; + case CloseMessage m: + WriteMessageType(writer, HubProtocolConstants.CloseMessageType); + WriteCloseMessage(m, writer); + break; + default: + throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); + } + writer.WriteEndObject(); + writer.Flush(); } - writer.WriteEndObject(); - writer.Flush(); + } + finally + { + Utf8BufferTextWriter.Return(textWriter); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs index 8ed80f4c27..1d98f99c89 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/Utf8BufferTextWriter.cs @@ -5,7 +5,6 @@ using System; using System.Buffers; using System.Diagnostics; using System.IO; -using System.Linq; using System.Runtime.InteropServices; using System.Text; @@ -77,12 +76,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public override void Write(char[] buffer, int index, int count) { - WriteInternal(buffer, index, count); + WriteInternal(buffer.AsSpan(index, count)); } public override void Write(char[] buffer) { - WriteInternal(buffer, 0, buffer.Length); + WriteInternal(buffer); } public override void Write(char value) @@ -120,6 +119,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } + public override void Write(string value) + { + WriteInternal(value.AsSpan()); + } + private Span GetBuffer() { EnsureBuffer(); @@ -142,11 +146,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol } } - private void WriteInternal(char[] buffer, int index, int count) + private void WriteInternal(ReadOnlySpan buffer) { - var currentIndex = index; - var charsRemaining = count; - while (charsRemaining > 0) + while (buffer.Length > 0) { // The destination byte array might not be large enough so multiple writes are sometimes required var destination = GetBuffer(); @@ -154,20 +156,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol var bytesUsed = 0; var charsUsed = 0; #if NETCOREAPP2_1 - _encoder.Convert(buffer.AsSpan(currentIndex, charsRemaining), destination, false, out charsUsed, out bytesUsed, out _); + _encoder.Convert(buffer, destination, false, out charsUsed, out bytesUsed, out _); #else unsafe { - fixed (char* sourceChars = &buffer[currentIndex]) + fixed (char* sourceChars = &MemoryMarshal.GetReference(buffer)) fixed (byte* destinationBytes = &MemoryMarshal.GetReference(destination)) { - _encoder.Convert(sourceChars, charsRemaining, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _); + _encoder.Convert(sourceChars, buffer.Length, destinationBytes, destination.Length, false, out charsUsed, out bytesUsed, out _); } } #endif - charsRemaining -= charsUsed; - currentIndex += charsUsed; + buffer = buffer.Slice(charsUsed); _memoryUsed += bytesUsed; } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 930e8406f7..ce4b770ec0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -41,9 +41,11 @@ namespace Microsoft.AspNetCore.SignalR static HubConnectionContext() { - var memoryBufferWriter = new MemoryBufferWriter(); - HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter); - _successHandshakeResponseData = memoryBufferWriter.ToArray(); + using (var memoryBufferWriter = new MemoryBufferWriter()) + { + HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, memoryBufferWriter); + _successHandshakeResponseData = memoryBufferWriter.ToArray(); + } } public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs index 05202de878..e0be88c850 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs @@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol return destination; } - public void WriteMessage(HubMessage message, Stream output) + public void WriteMessage(HubMessage message, IBufferWriter output) { using (var stream = new LimitArrayPoolWriteStream()) { @@ -271,7 +271,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol // Write length then message to output BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output); - output.Write(buffer.Array, buffer.Offset, buffer.Count); + output.Write(buffer); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 7fe884500f..d7701603b9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(500); await hubConnection.StartAsync().OrTimeout(); - + // Start an invocation (but we won't complete it) var invokeTask = hubConnection.InvokeAsync("Method").OrTimeout(); @@ -156,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests throw new InvalidOperationException("No Parsed Message provided"); } - public void WriteMessage(HubMessage message, Stream output) + public void WriteMessage(HubMessage message, IBufferWriter output) { if (_error != null) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index a97a8f0d1a..8bc98d4d22 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -75,9 +75,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var s = await ReadSentTextMessageAsync(); - var output = new MemoryBufferWriter(); - HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output); - await Application.Output.WriteAsync(output.ToArray()); + using (var output = new MemoryBufferWriter()) + { + HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output); + await Application.Output.WriteAsync(output.ToArray()); + } return s; } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs index 3d14e94458..9951bdcd00 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Formatters/BinaryMessageFormatterTests.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Xunit; @@ -31,20 +32,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters Encoding.UTF8.GetBytes("Hello,\r\nWorld!") }; - var output = new MemoryStream(); // Use small chunks to test Advance/Enlarge and partial payload writing - foreach (var message in messages) + using (var writer = new MemoryBufferWriter()) // Use small chunks to test Advance/Enlarge and partial payload writing { - BinaryMessageFormatter.WriteLengthPrefix(message.Length, output); - output.Write(message, 0, message.Length); - } + foreach (var message in messages) + { + BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer); + writer.Write(message); + } - Assert.Equal(expectedEncoding, output.ToArray()); + Assert.Equal(expectedEncoding, writer.ToArray()); + } } [Theory] - [InlineData(0, new byte[] { 0x00 }, new byte[0])] - [InlineData(0, new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - [InlineData(0, new byte[] + [InlineData(new byte[] { 0x00 }, new byte[0])] + [InlineData(new byte[] { 0x80, 0x01, // Size - 128 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f, @@ -68,53 +70,45 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f })] - [InlineData(4, new byte[] { 0x00 }, new byte[0])] - [InlineData(4, new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] - public void WriteBinaryMessage(int offset, byte[] encoded, byte[] payload) + [InlineData(new byte[] { 0x04, 0xAB, 0xCD, 0xEF, 0x12 }, new byte[] { 0xAB, 0xCD, 0xEF, 0x12 })] + public void WriteBinaryMessage(byte[] encoded, byte[] payload) { - var output = new MemoryStream(); - - if (offset > 0) + using (var writer = new MemoryBufferWriter()) { - output.Seek(offset, SeekOrigin.Begin); + + BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer); + writer.Write(payload); + + Assert.Equal(encoded, writer.ToArray()); } - - BinaryMessageFormatter.WriteLengthPrefix(payload.Length, output); - output.Write(payload, 0, payload.Length); - - Assert.Equal(encoded, output.ToArray().Skip(offset)); } [Theory] - [InlineData(0, new byte[] { 0x00 }, "")] - [InlineData(0, new byte[] { 0x03, 0x41, 0x42, 0x43 }, "ABC")] - [InlineData(0, new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")] - [InlineData(4, new byte[] { 0x00 }, "")] - public void WriteTextMessage(int offset, byte[] encoded, string payload) + [InlineData(new byte[] { 0x00 }, "")] + [InlineData(new byte[] { 0x03, 0x41, 0x42, 0x43 }, "ABC")] + [InlineData(new byte[] { 0x0B, 0x41, 0x0A, 0x52, 0x0D, 0x43, 0x0D, 0x0A, 0x3B, 0x44, 0x45, 0x46 }, "A\nR\rC\r\n;DEF")] + public void WriteTextMessage(byte[] encoded, string payload) { var message = Encoding.UTF8.GetBytes(payload); - var output = new MemoryStream(); - - if (offset > 0) + using (var writer = new MemoryBufferWriter()) { - output.Seek(offset, SeekOrigin.Begin); + + BinaryMessageFormatter.WriteLengthPrefix(message.Length, writer); + writer.Write(message); + + Assert.Equal(encoded, writer.ToArray()); } - - BinaryMessageFormatter.WriteLengthPrefix(message.Length, output); - output.Write(message, 0, message.Length); - - Assert.Equal(encoded, output.ToArray().Skip(offset)); } [Theory] [MemberData(nameof(RandomPayloads))] public void RoundTrippingTest(byte[] payload) { - using (var ms = new MemoryStream()) + using (var writer = new MemoryBufferWriter()) { - BinaryMessageFormatter.WriteLengthPrefix(payload.Length, ms); - ms.Write(payload, 0, payload.Length); - var buffer = new ReadOnlySequence(ms.ToArray()); + BinaryMessageFormatter.WriteLengthPrefix(payload.Length, writer); + writer.Write(payload); + var buffer = new ReadOnlySequence(writer.ToArray()); Assert.True(BinaryMessageParser.TryParseMessage(ref buffer, out var roundtripped)); Assert.Equal(payload, roundtripped.ToArray()); } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index 8fd29e4c3b..f6b81cf398 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -15,6 +15,7 @@ using Xunit; namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol { + using Microsoft.AspNetCore.SignalR.Internal; using static HubMessageHelpers; public class JsonHubProtocolTests @@ -108,10 +109,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol var protocol = new JsonHubProtocol(Options.Create(protocolOptions)); - using (var ms = new MemoryStream()) + using (var writer = new MemoryBufferWriter()) { - protocol.WriteMessage(message, ms); - var json = Encoding.UTF8.GetString(ms.ToArray()); + protocol.WriteMessage(message, writer); + var json = Encoding.UTF8.GetString(writer.ToArray()); Assert.Equal(expectedOutput, json); } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index b1fccbcc2a..2693c74908 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Text; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Formatters; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using MsgPack; @@ -444,10 +445,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol private static byte[] Frame(byte[] input) { - using (var stream = new MemoryStream()) + using (var stream = new MemoryBufferWriter()) { BinaryMessageFormatter.WriteLengthPrefix(input.Length, stream); - stream.Write(input, 0, input.Length); + stream.Write(input); return stream.ToArray(); } } @@ -486,11 +487,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol private static byte[] Write(HubMessage message) { var protocol = new MessagePackHubProtocol(); - using (var stream = new MemoryStream()) + using (var writer = new MemoryBufferWriter()) { - protocol.WriteMessage(message, stream); - stream.Flush(); - return stream.ToArray(); + protocol.WriteMessage(message, writer); + return writer.ToArray(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs index fa3e56f7b8..b47b71c723 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs @@ -4,9 +4,7 @@ using System; using System.Buffers; using System.Collections.Generic; -using System.Linq; using System.Text; -using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Xunit; @@ -17,8 +15,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Fact] public void WriteChar_Unicode() { - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); - Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + var bufferWriter = new TestMemoryBufferWriter(4096); + var textWriter = new Utf8BufferTextWriter(); textWriter.SetWriter(bufferWriter); textWriter.Write('['); @@ -57,8 +55,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Fact] public void WriteChar_UnicodeLastChar() { - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); - using (Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter()) + var bufferWriter = new TestMemoryBufferWriter(4096); + using (var textWriter = new Utf8BufferTextWriter()) { textWriter.SetWriter(bufferWriter); @@ -73,8 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Fact] public void WriteChar_UnicodeAndRunOutOfBufferSpace() { - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); - Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + var bufferWriter = new TestMemoryBufferWriter(4096); + var textWriter = new Utf8BufferTextWriter(); textWriter.SetWriter(bufferWriter); textWriter.Write('['); @@ -124,8 +122,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol char[] chars = fourCircles.ToCharArray(); - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); - Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + var bufferWriter = new TestMemoryBufferWriter(4096); + var textWriter = new Utf8BufferTextWriter(); textWriter.SetWriter(bufferWriter); textWriter.Write(chars, 0, 1); @@ -153,8 +151,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol char[] chars = fourCircles.ToCharArray(); - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); - Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + var bufferWriter = new TestMemoryBufferWriter(4096); + var textWriter = new Utf8BufferTextWriter(); textWriter.SetWriter(bufferWriter); textWriter.Write(chars[0]); @@ -178,8 +176,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Fact] public void WriteCharArray_NonZeroStart() { - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(4096); - Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + var bufferWriter = new TestMemoryBufferWriter(4096); + var textWriter = new Utf8BufferTextWriter(); textWriter.SetWriter(bufferWriter); char[] chars = "Hello world".ToCharArray(); @@ -194,8 +192,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Fact] public void WriteCharArray_AcrossMultipleBuffers() { - MemoryBufferWriter bufferWriter = new MemoryBufferWriter(2); - Utf8BufferTextWriter textWriter = new Utf8BufferTextWriter(); + var bufferWriter = new TestMemoryBufferWriter(2); + var textWriter = new Utf8BufferTextWriter(); textWriter.SetWriter(bufferWriter); char[] chars = "Hello world".ToCharArray(); @@ -222,7 +220,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol [Fact] public void GetAndReturnCachedBufferTextWriter() { - MemoryBufferWriter bufferWriter1 = new MemoryBufferWriter(); + var bufferWriter1 = new TestMemoryBufferWriter(); var textWriter1 = Utf8BufferTextWriter.Get(bufferWriter1); textWriter1.Write("Hello"); @@ -231,7 +229,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Equal("Hello", Encoding.UTF8.GetString(bufferWriter1.ToArray())); - MemoryBufferWriter bufferWriter2 = new MemoryBufferWriter(); + TestMemoryBufferWriter bufferWriter2 = new TestMemoryBufferWriter(); var textWriter2 = Utf8BufferTextWriter.Get(bufferWriter2); textWriter2.Write("World"); @@ -242,5 +240,74 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Same(textWriter1, textWriter2); } + + private sealed class TestMemoryBufferWriter : IBufferWriter + { + private readonly int _segmentSize; + + internal List> Segments { get; } + internal int Position { get; private set; } + + public TestMemoryBufferWriter(int segmentSize = 2048) + { + _segmentSize = segmentSize; + + Segments = new List>(); + } + + public Memory CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null; + + public void Advance(int count) + { + Position += count; + } + + public Memory GetMemory(int sizeHint = 0) + { + // TODO: Use sizeHint + + if (Segments.Count == 0 || Position == _segmentSize) + { + // TODO: Rent memory from a pool + Segments.Add(new Memory(new byte[_segmentSize])); + Position = 0; + } + + return CurrentSegment.Slice(Position, CurrentSegment.Length - Position); + } + + public Span GetSpan(int sizeHint = 0) + { + return GetMemory(sizeHint).Span; + } + + public byte[] ToArray() + { + if (Segments.Count == 0) + { + return Array.Empty(); + } + + var totalLength = (Segments.Count - 1) * _segmentSize; + totalLength += Position; + + var result = new byte[totalLength]; + + var totalWritten = 0; + + // Copy full segments + for (int i = 0; i < Segments.Count - 1; i++) + { + Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize)); + + totalWritten += _segmentSize; + } + + // Copy current incomplete segment + CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position)); + + return result; + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 331d6ec7ce..2a7d4a6a18 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -66,9 +66,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { if (sendHandshakeRequestMessage) { - var memoryBufferWriter = new MemoryBufferWriter(); - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); - await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); + using (var memoryBufferWriter = new MemoryBufferWriter()) + { + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); + await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); + } } var connection = handler.OnConnectedAsync(Connection);