diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 5905722050..2a3eb0f1a0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -430,7 +430,7 @@ namespace Microsoft.AspNetCore.SignalR.Client return; } - //TODO: Optimize this! + // TODO: Optimize this! // Copying the callbacks to avoid concurrency issues InvocationHandler[] copiedHandlers; lock (handlers) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs index e74fe6e9d9..050f746d0a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs @@ -10,16 +10,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders { public class Base64Encoder : IDataEncoder { - public ReadOnlySpan Decode(byte[] payload) + public bool TryDecode(ref ReadOnlySpan buffer, out ReadOnlySpan data) { - ReadOnlySpan buffer = payload; - LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message); - - Span decoded = new byte[Base64.GetMaxDecodedFromUtf8Length(message.Length)]; - var status = Base64.DecodeFromUtf8(message, decoded, out _, out var written); - Debug.Assert(status == OperationStatus.Done); - - return decoded.Slice(0, written); + if (LengthPrefixedTextMessageParser.TryParseMessage(ref buffer, out var message)) + { + Span decoded = new byte[Base64.GetMaxDecodedFromUtf8Length(message.Length)]; + var status = Base64.DecodeFromUtf8(message, decoded, out _, out var written); + Debug.Assert(status == OperationStatus.Done); + data = decoded.Slice(0, written); + return true; + } + return false; } private const int Int32OverflowLength = 10; diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs index 8ca5f05ad6..f7f7146076 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs @@ -8,6 +8,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders public interface IDataEncoder { byte[] Encode(byte[] payload); - ReadOnlySpan Decode(byte[] payload); + bool TryDecode(ref ReadOnlySpan buffer, out ReadOnlySpan data); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs index 06621470e7..e66970f9cf 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs @@ -7,9 +7,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders { public class PassThroughEncoder : IDataEncoder { - public ReadOnlySpan Decode(byte[] payload) + public bool TryDecode(ref ReadOnlySpan buffer, out ReadOnlySpan data) { - return payload; + data = buffer; + buffer = Array.Empty(); + return true; } public byte[] Encode(byte[] payload) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs index 36283170db..9bb8ef6cd5 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/HubProtocolReaderWriter.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Buffers; using System.Collections; using System.Collections.Generic; @@ -32,8 +33,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal public bool ReadMessages(byte[] input, IInvocationBinder binder, out IList messages) { - var buffer = _dataEncoder.Decode(input); - return _hubProtocol.TryParseMessages(buffer, binder, out messages); + messages = new List(); + ReadOnlySpan span = input; + while (span.Length > 0 && _dataEncoder.TryDecode(ref span, out var data)) + { + _hubProtocol.TryParseMessages(data, binder, messages); + } + return messages.Count > 0; } public byte[] WriteMessage(HubMessage hubMessage) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs index 1dc73512b8..f00c76247f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -13,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol ProtocolType Type { get; } - bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages); + bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, IList messages); void WriteMessage(HubMessage message, Stream output); } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index fd3fb4ac6e..8eeaaec772 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -43,10 +43,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public ProtocolType Type => ProtocolType.Text; - public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages) + public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, IList messages) { - messages = new List(); - while (TextMessageParser.TryParseMessage(ref input, out var payload)) { // TODO: Need a span-native JSON parser! diff --git a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs index c91b6bdc01..641df71fc1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Protocols.MsgPack/Internal/Protocol/MessagePackHubProtocol.cs @@ -35,10 +35,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol SerializationContext = options.Value.SerializationContext; } - public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages) + public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, IList messages) { - messages = new List(); - while (BinaryMessageParser.TryParseMessage(ref input, out var payload)) { using (var memoryStream = new MemoryStream(payload.ToArray())) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 5a58227945..fb52860572 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -3,12 +3,11 @@ using System; using System.Collections.Generic; -using System.Diagnostics; using System.IO; +using System.IO.Pipelines; using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using System.Threading.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; @@ -31,7 +30,7 @@ namespace Microsoft.AspNetCore.Sockets.Client private volatile ConnectionState _connectionState = ConnectionState.Disconnected; private readonly object _stateChangeLock = new object(); - private volatile ChannelConnection _transportChannel; + private volatile IDuplexPipe _transportChannel; private readonly HttpClient _httpClient; private readonly HttpOptions _httpOptions; private volatile ITransport _transport; @@ -43,8 +42,8 @@ namespace Microsoft.AspNetCore.Sockets.Client private string _connectionId; private Exception _abortException; private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5); - private ChannelReader Input => _transportChannel.Input; - private ChannelWriter Output => _transportChannel.Output; + private PipeReader Input => _transportChannel.Input; + private PipeWriter Output => _transportChannel.Output; private readonly List _callbacks = new List(); private readonly TransportType _requestedTransportType = TransportType.All; private readonly ConnectionLogScope _logScope; @@ -187,7 +186,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { _closeTcs = new TaskCompletionSource(); - _ = Input.Completion.ContinueWith(async t => + Input.OnWriterCompleted(async (exception, state) => { // Grab the exception and then clear it. // See comment at AbortAsync for more discussion on the thread-safety @@ -221,9 +220,9 @@ namespace Microsoft.AspNetCore.Sockets.Client try { - if (t.IsFaulted) + if (exception != null) { - Closed?.Invoke(t.Exception.InnerException); + Closed?.Invoke(exception); } else { @@ -237,7 +236,8 @@ namespace Microsoft.AspNetCore.Sockets.Client // Suppress (but log) the exception, this is user code _logger.ErrorDuringClosedEvent(ex); } - }); + + }, null); _receiveLoopTask = ReceiveAsync(); } @@ -325,15 +325,14 @@ namespace Microsoft.AspNetCore.Sockets.Client private async Task StartTransport(Uri connectUrl) { - var applicationToTransport = Channel.CreateUnbounded(); - var transportToApplication = Channel.CreateUnbounded(); - var applicationSide = ChannelConnection.Create(applicationToTransport, transportToApplication); - _transportChannel = ChannelConnection.Create(transportToApplication, applicationToTransport); + var options = new PipeOptions(readerScheduler: PipeScheduler.ThreadPool); + var pair = DuplexPipe.CreateConnectionPair(options, options); + _transportChannel = pair.Transport; // Start the transport, giving it one end of the pipeline try { - await _transport.StartAsync(connectUrl, applicationSide, GetTransferMode(), this); + await _transport.StartAsync(connectUrl, pair.Application, GetTransferMode(), this); // actual transfer mode can differ from the one that was requested so set it on the feature if (!_transport.Mode.HasValue) @@ -379,57 +378,72 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.HttpReceiveStarted(); - while (await Input.WaitToReadAsync()) + while (true) { if (_connectionState != ConnectionState.Connected) { _logger.SkipRaisingReceiveEvent(); - // drain - Input.TryRead(out _); - continue; + + break; } - if (Input.TryRead(out var buffer)) + var result = await Input.ReadAsync(); + var buffer = result.Buffer; + + try { - _logger.ScheduleReceiveEvent(); - _ = _eventQueue.Enqueue(async () => + if (!buffer.IsEmpty) { - _logger.RaiseReceiveEvent(); + _logger.ScheduleReceiveEvent(); + var data = buffer.ToArray(); - // Copying the callbacks to avoid concurrency issues - ReceiveCallback[] callbackCopies; - lock (_callbacks) + _ = _eventQueue.Enqueue(async () => { - callbackCopies = new ReceiveCallback[_callbacks.Count]; - _callbacks.CopyTo(callbackCopies); - } + _logger.RaiseReceiveEvent(); - foreach (var callbackObject in callbackCopies) - { - try + // Copying the callbacks to avoid concurrency issues + ReceiveCallback[] callbackCopies; + lock (_callbacks) { - await callbackObject.InvokeAsync(buffer); + callbackCopies = new ReceiveCallback[_callbacks.Count]; + _callbacks.CopyTo(callbackCopies); } - catch (Exception ex) + + foreach (var callbackObject in callbackCopies) { - _logger.ExceptionThrownFromCallback(nameof(OnReceived), ex); + try + { + await callbackObject.InvokeAsync(data); + } + catch (Exception ex) + { + _logger.ExceptionThrownFromCallback(nameof(OnReceived), ex); + } } - } - }); + }); + + } + else if (result.IsCompleted) + { + break; + } } - else + finally { - _logger.FailedReadingMessage(); + Input.AdvanceTo(buffer.End); } } - - await Input.Completion; } catch (Exception ex) { - Output.TryComplete(ex); + Input.Complete(ex); + _logger.ErrorReceiving(ex); } + finally + { + Input.Complete(); + } _logger.EndReceive(); } @@ -450,23 +464,11 @@ namespace Microsoft.AspNetCore.Sockets.Client "Cannot send messages when the connection is not in the Connected state."); } - // TaskCreationOptions.RunContinuationsAsynchronously ensures that continuations awaiting - // SendAsync (i.e. user's code) are not running on the same thread as the code that sets - // TaskCompletionSource result. This way we prevent from user's code blocking our channel - // send loop. - var sendTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var message = new SendMessage(data, sendTcs); - _logger.SendingMessage(); - while (await Output.WaitToWriteAsync(cancellationToken)) - { - if (Output.TryWrite(message)) - { - await sendTcs.Task; - break; - } - } + cancellationToken.ThrowIfCancellationRequested(); + + await Output.WriteAsync(data); } // AbortAsync creates a few thread-safety races that we are OK with. @@ -539,7 +541,7 @@ namespace Microsoft.AspNetCore.Sockets.Client if (_transportChannel != null) { - Output.TryComplete(); + Output.Complete(); } if (transport != null) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs index 0587ef0684..a01da0c378 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs @@ -3,13 +3,13 @@ using System; using System.Threading.Tasks; -using System.Threading.Channels; +using System.IO.Pipelines; namespace Microsoft.AspNetCore.Sockets.Client { public interface ITransport { - Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, IConnection connection); + Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection); Task StopAsync(); TransferMode? Mode { get; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransportFactory.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransportFactory.cs index 0d7fe168cf..f701951f67 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransportFactory.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransportFactory.cs @@ -1,8 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Net.Http; - namespace Microsoft.AspNetCore.Sockets.Client { public interface ITransportFactory diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs index 870175a819..d44264fa59 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs @@ -47,8 +47,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal private static readonly Action _messageToApp = LoggerMessage.Define(LogLevel.Debug, new EventId(12, nameof(MessageToApp)), "Passing message to application. Payload size: {count}."); - private static readonly Action _receivedFromApp = - LoggerMessage.Define(LogLevel.Debug, new EventId(13, nameof(ReceivedFromApp)), "Received message from application. Payload size: {count}."); + private static readonly Action _receivedFromApp = + LoggerMessage.Define(LogLevel.Debug, new EventId(13, nameof(ReceivedFromApp)), "Received message from application. Payload size: {count}."); private static readonly Action _sendMessageCanceled = LoggerMessage.Define(LogLevel.Information, new EventId(14, nameof(SendMessageCanceled)), "Sending a message canceled."); @@ -66,8 +66,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal LoggerMessage.Define(LogLevel.Debug, new EventId(18, nameof(CancelMessage)), "Canceled passing message to application."); // Category: ServerSentEventsTransport and LongPollingTransport - private static readonly Action _sendingMessages = - LoggerMessage.Define(LogLevel.Debug, new EventId(10, nameof(SendingMessages)), "Sending {count} message(s) to the server using url: {url}."); + private static readonly Action _sendingMessages = + LoggerMessage.Define(LogLevel.Debug, new EventId(10, nameof(SendingMessages)), "Sending {count} bytes to the server using url: {url}."); private static readonly Action _sentSuccessfully = LoggerMessage.Define(LogLevel.Debug, new EventId(11, nameof(SentSuccessfully)), "Message(s) sent successfully."); @@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal _sendStarted(logger, null); } - public static void ReceivedFromApp(this ILogger logger, int count) + public static void ReceivedFromApp(this ILogger logger, long count) { _receivedFromApp(logger, count, null); } @@ -261,7 +261,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal _cancelMessage(logger, null); } - public static void SendingMessages(this ILogger logger, int count, Uri url) + public static void SendingMessages(this ILogger logger, long count, Uri url) { _sendingMessages(logger, count, url, null); } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs index 0dfc8135e9..67178be7fc 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs @@ -2,10 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Net; using System.Net.Http; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; @@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.Sockets.Client private readonly HttpClient _httpClient; private readonly HttpOptions _httpOptions; private readonly ILogger _logger; - private Channel _application; + private IDuplexPipe _application; private Task _sender; private Task _poller; @@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, IConnection connection) + public Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection) { if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) { @@ -62,7 +62,8 @@ namespace Microsoft.AspNetCore.Sockets.Client Running = Task.WhenAll(_sender, _poller).ContinueWith(t => { _logger.TransportStopped(t.Exception?.InnerException); - _application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + _application.Output.Complete(t.Exception?.InnerException); + _application.Input.Complete(); return t; }).Unwrap(); @@ -122,17 +123,11 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.ReceivedMessages(); - // Until Pipeline starts natively supporting BytesReader, this is the easiest way to do this. + // TODO: Use CopyToAsync here var payload = await response.Content.ReadAsByteArrayAsync(); if (payload.Length > 0) { - while (!_application.Writer.TryWrite(payload)) - { - if (cancellationToken.IsCancellationRequested || !await _application.Writer.WaitToWriteAsync(cancellationToken)) - { - return; - } - } + await _application.Output.WriteAsync(payload); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs index 0e1771b52c..4a4b7631d2 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs @@ -2,12 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.IO; +using System.IO.Pipelines; using System.Net.Http; using System.Threading; using System.Threading.Tasks; -using System.Threading.Channels; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging; @@ -16,87 +14,61 @@ namespace Microsoft.AspNetCore.Sockets.Client { internal static class SendUtils { - public static async Task SendMessages(Uri sendUrl, Channel application, HttpClient httpClient, + public static async Task SendMessages(Uri sendUrl, IDuplexPipe application, HttpClient httpClient, HttpOptions httpOptions, CancellationTokenSource transportCts, ILogger logger) { logger.SendStarted(); - IList messages = null; + try { - while (await application.Reader.WaitToReadAsync(transportCts.Token)) + while (true) { - // Grab as many messages as we can from the channel - messages = new List(); - while (!transportCts.IsCancellationRequested && application.Reader.TryRead(out SendMessage message)) + var result = await application.Input.ReadAsync(transportCts.Token); + var buffer = result.Buffer; + + try { - messages.Add(message); - } + // Grab as many messages as we can from the channel - transportCts.Token.ThrowIfCancellationRequested(); - if (messages.Count > 0) - { - logger.SendingMessages(messages.Count, sendUrl); - - // Send them in a single post - var request = new HttpRequestMessage(HttpMethod.Post, sendUrl); - PrepareHttpRequest(request, httpOptions); - - // TODO: We can probably use a pipeline here or some kind of pooled memory. - // But where do we get the pool from? ArrayBufferPool.Instance? - var memoryStream = new MemoryStream(); - - foreach (var message in messages) + transportCts.Token.ThrowIfCancellationRequested(); + if (!buffer.IsEmpty) { - if (message.Payload != null) - { - memoryStream.Write(message.Payload, 0, message.Payload.Length); - } + logger.SendingMessages(buffer.Length, sendUrl); + + // Send them in a single post + var request = new HttpRequestMessage(HttpMethod.Post, sendUrl); + PrepareHttpRequest(request, httpOptions); + + // TODO: Use a custom stream implementation over the ReadOnlyBuffer + request.Content = new ByteArrayContent(buffer.ToArray()); + + var response = await httpClient.SendAsync(request, transportCts.Token); + response.EnsureSuccessStatusCode(); + + logger.SentSuccessfully(); } - - memoryStream.Position = 0; - - // Set the, now filled, stream as the content - request.Content = new StreamContent(memoryStream); - - var response = await httpClient.SendAsync(request, transportCts.Token); - response.EnsureSuccessStatusCode(); - - logger.SentSuccessfully(); - foreach (var message in messages) + else if (result.IsCompleted) { - message.SendResult?.TrySetResult(null); + break; + } + else + { + logger.NoMessages(); } } - else + finally { - logger.NoMessages(); + application.Input.AdvanceTo(buffer.End); } } } catch (OperationCanceledException) { - // transport is being closed - if (messages != null) - { - foreach (var message in messages) - { - // This will no-op for any messages that were already marked as completed. - message.SendResult?.TrySetCanceled(); - } - } logger.SendCanceled(); } catch (Exception ex) { logger.ErrorSending(sendUrl, ex); - if (messages != null) - { - foreach (var message in messages) - { - // This will no-op for any messages that were already marked as completed. - message.SendResult?.TrySetException(ex); - } - } throw; } finally diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index 864c98247c..f0b6c0be7c 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -2,13 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.IO.Pipelines; using System.Net.Http; using System.Net.Http.Headers; +using System.Text; using System.Threading; using System.Threading.Tasks; -using System.Threading.Channels; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Sockets.Internal.Formatters; @@ -19,14 +18,13 @@ namespace Microsoft.AspNetCore.Sockets.Client { public class ServerSentEventsTransport : ITransport { - private static readonly MemoryPool _memoryPool = new MemoryPool(); private readonly HttpClient _httpClient; private readonly HttpOptions _httpOptions; private readonly ILogger _logger; private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly ServerSentEventsMessageParser _parser = new ServerSentEventsMessageParser(); - private Channel _application; + private IDuplexPipe _application; public Task Running { get; private set; } = Task.CompletedTask; @@ -48,7 +46,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, IConnection connection) + public Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection) { if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) { @@ -66,15 +64,16 @@ namespace Microsoft.AspNetCore.Sockets.Client Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t => { _logger.TransportStopped(t.Exception?.InnerException); + _application.Output.Complete(t.Exception?.InnerException); + _application.Input.Complete(); - _application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); return t; }).Unwrap(); return Task.CompletedTask; } - private async Task OpenConnection(Channel application, Uri url, CancellationToken cancellationToken) + private async Task OpenConnection(IDuplexPipe application, Uri url, CancellationToken cancellationToken) { _logger.StartReceive(); @@ -83,59 +82,60 @@ namespace Microsoft.AspNetCore.Sockets.Client request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); - var stream = await response.Content.ReadAsStreamAsync(); - var pipelineReader = StreamPipeConnection.CreateReader(new PipeOptions(_memoryPool), stream); - var readCancellationRegistration = cancellationToken.Register( - reader => ((PipeReader)reader).CancelPendingRead(), pipelineReader); - try + using (var stream = await response.Content.ReadAsStreamAsync()) { - while (true) + var pipelineReader = StreamPipeConnection.CreateReader(PipeOptions.Default, stream); + var readCancellationRegistration = cancellationToken.Register( + reader => ((PipeReader)reader).CancelPendingRead(), pipelineReader); + try { - var result = await pipelineReader.ReadAsync(); - var input = result.Buffer; - if (result.IsCancelled || (input.IsEmpty && result.IsCompleted)) + while (true) { - _logger.EventStreamEnded(); - break; - } - - var consumed = input.Start; - var examined = input.End; - - try - { - var parseResult = _parser.ParseMessage(input, out consumed, out examined, out var buffer); - - switch (parseResult) + var result = await pipelineReader.ReadAsync(); + var input = result.Buffer; + if (result.IsCancelled || (input.IsEmpty && result.IsCompleted)) { - case ServerSentEventsMessageParser.ParseResult.Completed: - _application.Writer.TryWrite(buffer); - _parser.Reset(); - break; - case ServerSentEventsMessageParser.ParseResult.Incomplete: - if (result.IsCompleted) - { - throw new FormatException("Incomplete message."); - } - break; + _logger.EventStreamEnded(); + break; + } + + var consumed = input.Start; + var examined = input.End; + + try + { + var parseResult = _parser.ParseMessage(input, out consumed, out examined, out var buffer); + + switch (parseResult) + { + case ServerSentEventsMessageParser.ParseResult.Completed: + await _application.Output.WriteAsync(buffer); + _parser.Reset(); + break; + case ServerSentEventsMessageParser.ParseResult.Incomplete: + if (result.IsCompleted) + { + throw new FormatException("Incomplete message."); + } + break; + } + } + finally + { + pipelineReader.AdvanceTo(consumed, examined); } } - finally - { - pipelineReader.AdvanceTo(consumed, examined); - } } - } - catch (OperationCanceledException) - { - _logger.ReceiveCanceled(); - } - finally - { - readCancellationRegistration.Dispose(); - _transportCts.Cancel(); - stream.Dispose(); - _logger.ReceiveStopped(); + catch (OperationCanceledException) + { + _logger.ReceiveCanceled(); + } + finally + { + readCancellationRegistration.Dispose(); + _transportCts.Cancel(); + _logger.ReceiveStopped(); + } } } @@ -143,7 +143,6 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.TransportStopping(); _transportCts.Cancel(); - _application.Writer.TryComplete(); try { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index d1eadab32d..d338b5eb5b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -2,11 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Diagnostics; +using System.IO.Pipelines; using System.Net.WebSockets; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Sockets.Client.Internal; @@ -18,7 +16,7 @@ namespace Microsoft.AspNetCore.Sockets.Client public class WebSocketsTransport : ITransport { private readonly ClientWebSocket _webSocket; - private Channel _application; + private IDuplexPipe _application; private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly CancellationTokenSource _receiveCts = new CancellationTokenSource(); private readonly ILogger _logger; @@ -54,7 +52,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public async Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, IConnection connection) + public async Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection) { if (url == null) { @@ -77,8 +75,8 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.StartTransport(Mode.Value); await Connect(url); - var sendTask = SendMessages(url); - var receiveTask = ReceiveMessages(url); + var sendTask = SendMessages(); + var receiveTask = ReceiveMessages(); // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 @@ -86,84 +84,48 @@ namespace Microsoft.AspNetCore.Sockets.Client { _webSocket.Dispose(); _logger.TransportStopped(t.Exception?.InnerException); - _application.Writer.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + + _application.Output.Complete(t.Exception?.InnerException); + _application.Input.Complete(); + return t; }).Unwrap(); } - private async Task ReceiveMessages(Uri pollUrl) + private async Task ReceiveMessages() { _logger.StartReceive(); try { - while (!_receiveCts.Token.IsCancellationRequested) + while (true) { - const int bufferSize = 4096; - var totalBytes = 0; - var incomingMessage = new List>(); - WebSocketReceiveResult receiveResult; - do + var memory = _application.Output.GetMemory(); + + // REVIEW: Use new Memory websocket APIs on .NET Core 2.1 + memory.TryGetArray(out var arraySegment); + + // Exceptions are handled above where the send and receive tasks are being run. + var receiveResult = await _webSocket.ReceiveAsync(arraySegment, _receiveCts.Token); + if (receiveResult.MessageType == WebSocketMessageType.Close) { - var buffer = new ArraySegment(new byte[bufferSize]); + _logger.WebSocketClosed(receiveResult.CloseStatus); - //Exceptions are handled above where the send and receive tasks are being run. - receiveResult = await _webSocket.ReceiveAsync(buffer, _receiveCts.Token); - if (receiveResult.MessageType == WebSocketMessageType.Close) + if (receiveResult.CloseStatus != WebSocketCloseStatus.NormalClosure) { - _logger.WebSocketClosed(receiveResult.CloseStatus); - - _application.Writer.Complete( - receiveResult.CloseStatus == WebSocketCloseStatus.NormalClosure - ? null - : new InvalidOperationException( - $"Websocket closed with error: {receiveResult.CloseStatus}.")); - return; + throw new InvalidOperationException($"Websocket closed with error: {receiveResult.CloseStatus}."); } - _logger.MessageReceived(receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); - - var truncBuffer = new ArraySegment(buffer.Array, 0, receiveResult.Count); - incomingMessage.Add(truncBuffer); - totalBytes += receiveResult.Count; - } while (!receiveResult.EndOfMessage); - - //Making sure the message type is either text or binary - Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type"); - - var messageBuffer = new byte[totalBytes]; - if (incomingMessage.Count > 1) - { - var offset = 0; - for (var i = 0; i < incomingMessage.Count; i++) - { - Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count); - offset += incomingMessage[i].Count; - } - } - else - { - Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, messageBuffer, 0, incomingMessage[0].Count); + return; } - try + _logger.MessageReceived(receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); + + _application.Output.Advance(receiveResult.Count); + + if (receiveResult.EndOfMessage) { - if (!_transportCts.Token.IsCancellationRequested) - { - _logger.MessageToApp(messageBuffer.Length); - while (await _application.Writer.WaitToWriteAsync(_transportCts.Token)) - { - if (_application.Writer.TryWrite(messageBuffer)) - { - incomingMessage.Clear(); - break; - } - } - } - } - catch (OperationCanceledException) - { - _logger.CancelMessage(); + await _application.Output.FlushAsync(_transportCts.Token); } } } @@ -173,12 +135,13 @@ namespace Microsoft.AspNetCore.Sockets.Client } finally { + // We're done writing _logger.ReceiveStopped(); _transportCts.Cancel(); } } - private async Task SendMessages(Uri sendUrl) + private async Task SendMessages() { _logger.SendStarted(); @@ -189,32 +152,38 @@ namespace Microsoft.AspNetCore.Sockets.Client try { - while (await _application.Reader.WaitToReadAsync(_transportCts.Token)) + while (true) { - while (_application.Reader.TryRead(out SendMessage message)) + var result = await _application.Input.ReadAsync(_transportCts.Token); + var buffer = result.Buffer; + try { - try + if (!buffer.IsEmpty) { - _logger.ReceivedFromApp(message.Payload.Length); + _logger.ReceivedFromApp(buffer.Length); - await _webSocket.SendAsync(new ArraySegment(message.Payload), webSocketMessageType, true, _transportCts.Token); - - message.SendResult.SetResult(null); + await _webSocket.SendAsync(new ArraySegment(buffer.ToArray()), webSocketMessageType, true, _transportCts.Token); } - catch (OperationCanceledException) + else if (result.IsCompleted) { - _logger.SendMessageCanceled(); - message.SendResult.SetCanceled(); - await CloseWebSocket(); break; } - catch (Exception ex) - { - _logger.ErrorSendingMessage(ex); - message.SendResult.SetException(ex); - await CloseWebSocket(); - throw; - } + } + catch (OperationCanceledException) + { + _logger.SendMessageCanceled(); + await CloseWebSocket(); + break; + } + catch (Exception ex) + { + _logger.ErrorSendingMessage(ex); + await CloseWebSocket(); + throw; + } + finally + { + _application.Input.AdvanceTo(buffer.End); } } } diff --git a/test/Common/TaskExtensions.cs b/test/Common/TaskExtensions.cs index b26dd2b1b0..4621def836 100644 --- a/test/Common/TaskExtensions.cs +++ b/test/Common/TaskExtensions.cs @@ -20,11 +20,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { - var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout)); + var cts = new CancellationTokenSource(); + var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); if (completed != task) { throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); } + cts.Cancel(); await task; } @@ -36,11 +38,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests.Common public static async Task OrTimeout(this Task task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { - var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout)); + var cts = new CancellationTokenSource(); + var completed = await Task.WhenAny(task, Task.Delay(Debugger.IsAttached ? Timeout.InfiniteTimeSpan : timeout, cts.Token)); if (completed != task) { throw new TimeoutException(GetMessage(memberName, filePath, lineNumber)); } + cts.Cancel(); return await task; } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 3f624232e9..aaf4c72f20 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -219,7 +219,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task CanInvokeClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) { - using (StartLog(out var loggerFactory, $"{nameof(CanInvokeClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) + using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanInvokeClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { const string originalMessage = "SignalR"; @@ -252,7 +252,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task InvokeNonExistantClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) { - using (StartLog(out var loggerFactory, $"{nameof(InvokeNonExistantClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) + using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(InvokeNonExistantClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); var connection = new HubConnection(httpConnection, protocol, loggerFactory); @@ -292,7 +292,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task CanStreamClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) { - using (StartLog(out var loggerFactory, $"{nameof(CanStreamClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) + using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanStreamClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); var connection = new HubConnection(httpConnection, protocol, loggerFactory); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs index a05c0eff67..dcf796195f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs @@ -214,7 +214,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var httpHandler = new TestHttpMessageHandler(); var longPollResult = new TaskCompletionSource(); - httpHandler.OnLongPoll(cancellationToken => longPollResult.Task.OrTimeout()); + httpHandler.OnLongPoll(cancellationToken => + { + cancellationToken.Register(() => + { + longPollResult.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + return longPollResult.Task; + }); httpHandler.OnSocketSend((data, _) => { @@ -227,9 +234,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests async (connection, closed) => { await connection.StartAsync().OrTimeout(); - await Assert.ThrowsAsync(() => connection.SendAsync(new byte[] { 0x42 }).OrTimeout()); - - longPollResult.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + await connection.SendAsync(new byte[] { 0x42 }).OrTimeout(); // Wait for the connection to close, because the send failed. await Assert.ThrowsAsync(() => closed.OrTimeout()); @@ -318,7 +323,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests async (connection, closed) => { await connection.StartAsync().OrTimeout(); - testTransport.Application.Writer.TryComplete(expected); + testTransport.Application.Output.Complete(expected); var actual = await Assert.ThrowsAsync(() => closed.OrTimeout()); Assert.Same(expected, actual); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs index f922fb2716..91b7a94cb0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs @@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }, receiveTcs); await connection.StartAsync().OrTimeout(); - Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + Assert.Contains("42", await receiveTcs.Task.OrTimeout()); }); } @@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }, receiveTcs); await connection.StartAsync().OrTimeout(); - Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + Assert.Contains("42", await receiveTcs.Task.OrTimeout()); Assert.True(receivedRaised); }); } @@ -98,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }, receiveTcs); await connection.StartAsync().OrTimeout(); - Assert.Equal("42", await receiveTcs.Task.OrTimeout()); + Assert.Contains("42", await receiveTcs.Task.OrTimeout()); Assert.True(receivedRaised); }); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs index beee62e7ed..7494a962f0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs @@ -92,13 +92,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task CallerReceivesExceptionsFromSendAsync() + public async Task ExceptionOnSendAsyncClosesWithError() { var testHttpHandler = new TestHttpMessageHandler(); - var longPollTcs = new TaskCompletionSource(); + var longPollTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - testHttpHandler.OnLongPoll(cancellationToken => longPollTcs.Task); + testHttpHandler.OnLongPoll(cancellationToken => + { + cancellationToken.Register(() => longPollTcs.TrySetResult(null)); + + return longPollTcs.Task; + }); testHttpHandler.OnSocketSend((buf, cancellationToken) => { @@ -111,10 +116,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await connection.StartAsync().OrTimeout(); - var exception = await Assert.ThrowsAsync( - async () => await connection.SendAsync(new byte[0]).OrTimeout()); + await connection.SendAsync(new byte[] { 0 }).OrTimeout(); - longPollTcs.TrySetResult(null); + var exception = await Assert.ThrowsAsync(() => closed.OrTimeout()); }); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 8f2b70df44..f8cf843588 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; @@ -62,20 +63,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.StartAsync().OrTimeout(); // This will trigger the received callback - testTransport.Application.Writer.TryWrite(Array.Empty()); + await testTransport.Application.Output.WriteAsync(new byte[] { 1 }); // Wait to hit the sync point. We are now blocking up the TaskQueue await onReceived.WaitForSyncPoint().OrTimeout(); // Now we write something else and we want to test that the HttpConnection receive loop is still // removing items from the channel even though OnReceived is blocked up. - testTransport.Application.Writer.TryWrite(Array.Empty()); + await testTransport.Application.Output.WriteAsync(new byte[] { 1 }); // Now that we've written, we wait for WaitToReadAsync to return an INCOMPLETE task. It will do so // once HttpConnection reads the message. We also use a CTS to timeout in case the loop is indeed blocked var cts = new CancellationTokenSource(); cts.CancelAfter(TimeSpan.FromSeconds(5)); - while (testTransport.Application.Reader.WaitToReadAsync().IsCompleted && !cts.IsCancellationRequested) + while (testTransport.Application.Input.WaitToReadAsync().IsCompleted && !cts.IsCancellationRequested) { // Yield to allow the HttpConnection to dequeue the message await Task.Yield(); @@ -109,7 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.StartAsync().OrTimeout(); logger.LogInformation("Started connection"); - testTransport.Application.Writer.TryWrite(Array.Empty()); + await testTransport.Application.Output.WriteAsync(new byte[] { 1 }); await onReceived.WaitForSyncPoint().OrTimeout(); // Dispose should complete, even though the receive callbacks are completely blocked up. diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 72bc1e37b6..b876176338 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -251,10 +251,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public ProtocolType Type => ProtocolType.Binary; - public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, out IList messages) + public bool TryParseMessages(ReadOnlySpan input, IInvocationBinder binder, IList messages) { - messages = new List(); - ParseCalls += 1; if (_error != null) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 918db74227..2b319b3f74 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO.Pipelines; using System.Net; using System.Net.Http; using System.Text; @@ -12,7 +13,6 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Client.Tests; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; -using Microsoft.AspNetCore.Sockets.Internal; using Moq; using Moq.Protected; using Xunit; @@ -41,10 +41,8 @@ namespace Microsoft.AspNetCore.Client.Tests try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); transportActiveTask = longPollingTransport.Running; @@ -77,13 +75,14 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = ChannelConnection.Create(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); await longPollingTransport.Running.OrTimeout(); - Assert.True(transportToConnection.Reader.Completion.IsCompleted); + + Assert.True(pair.Transport.Input.TryRead(out var result)); + Assert.True(result.IsCompleted); + pair.Transport.Input.AdvanceTo(result.Buffer.End); } finally { @@ -130,17 +129,12 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); - var data = await transportToConnection.Reader.ReadAllAsync().OrTimeout(); + var data = await pair.Transport.Input.ReadAllAsync().OrTimeout(); await longPollingTransport.Running.OrTimeout(); - Assert.True(transportToConnection.Reader.Completion.IsCompleted); - Assert.Equal(2, data.Count); - Assert.Equal(Encoding.UTF8.GetBytes("Hello"), data[0]); - Assert.Equal(Encoding.UTF8.GetBytes("World"), data[1]); + Assert.Equal(Encoding.UTF8.GetBytes("HelloWorld"), data); } finally { @@ -166,13 +160,19 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); var exception = - await Assert.ThrowsAsync(async () => await transportToConnection.Reader.Completion.OrTimeout()); + await Assert.ThrowsAsync(async () => + { + async Task ReadAsync() + { + await pair.Transport.Input.ReadAsync(); + } + + await ReadAsync().OrTimeout(); + }); Assert.Contains(" 500 ", exception.Message); } finally @@ -202,21 +202,14 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); - await connectionToTransport.Writer.WriteAsync(new SendMessage()); + await pair.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); await Assert.ThrowsAsync(async () => await longPollingTransport.Running.OrTimeout()); - // The channel needs to be drained for the Completion task to be completed - while (transportToConnection.Reader.TryRead(out var message)) - { - } - - var exception = await Assert.ThrowsAsync(async () => await transportToConnection.Reader.Completion); + var exception = await Assert.ThrowsAsync(async () => await pair.Transport.Input.ReadAllAsync().OrTimeout()); Assert.Contains(" 500 ", exception.Message); } finally @@ -243,17 +236,14 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); - connectionToTransport.Writer.Complete(); + pair.Transport.Output.Complete(); await longPollingTransport.Running.OrTimeout(); - await longPollingTransport.Running.OrTimeout(); - await connectionToTransport.Reader.Completion.OrTimeout(); + await pair.Transport.Input.ReadAllAsync().OrTimeout(); } finally { @@ -292,32 +282,22 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); // Wait for the transport to finish await longPollingTransport.Running.OrTimeout(); // Pull Messages out of the channel - var messages = new List(); - while (await transportToConnection.Reader.WaitToReadAsync()) - { - while (transportToConnection.Reader.TryRead(out var message)) - { - messages.Add(message); - } - } + var message = await pair.Transport.Input.ReadAllAsync(); // Check the provided request Assert.Equal(2, sentRequests.Count); // Check the messages received - Assert.Single(messages); - Assert.Equal(message1Payload, messages[0]); + Assert.Equal(message1Payload, message); } finally { @@ -350,24 +330,19 @@ namespace Microsoft.AspNetCore.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - - var tcs1 = new TaskCompletionSource(); - var tcs2 = new TaskCompletionSource(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); // Pre-queue some messages - await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("Hello"), tcs1)).OrTimeout(); - await connectionToTransport.Writer.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout(); + await pair.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + await pair.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("World")); // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); - connectionToTransport.Writer.Complete(); + pair.Transport.Output.Complete(); await longPollingTransport.Running.OrTimeout(); - await connectionToTransport.Reader.Completion.OrTimeout(); + await pair.Transport.Input.ReadAllAsync(); Assert.Single(sentRequests); Assert.Equal(new byte[] { (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o', (byte)'W', (byte)'o', (byte)'r', (byte)'l', (byte)'d' @@ -400,12 +375,10 @@ namespace Microsoft.AspNetCore.Client.Tests try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); Assert.Null(longPollingTransport.Mode); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connection: new TestConnection()); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, transferMode, connection: new TestConnection()); Assert.Equal(transferMode, longPollingTransport.Mode); } finally @@ -466,10 +439,8 @@ namespace Microsoft.AspNetCore.Client.Tests try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connection: new TestConnection()); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Binary, connection: new TestConnection()); var completedTask = await Task.WhenAny(completionTcs.Task, longPollingTransport.Running).OrTimeout(); Assert.Equal(completionTcs.Task, completedTask); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index 97de64ec80..22a89d5e3e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.IO.Pipelines; using System.Net.Http; using System.Net.Http.Headers; using System.Text; @@ -10,10 +11,8 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; -using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; -using Microsoft.AspNetCore.Sockets.Internal; using Moq; using Moq.Protected; using Xunit; @@ -51,11 +50,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of()).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); await sseTransport.StopAsync().OrTimeout(); @@ -102,15 +99,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of()).OrTimeout(); transportActiveTask = sseTransport.Running; Assert.False(transportActiveTask.IsCompleted); - var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout(); + var message = await pair.Transport.Input.ReadSingleAsync().OrTimeout(); Assert.Equal("3:abc", Encoding.ASCII.GetString(message)); } finally @@ -150,11 +146,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of()).OrTimeout(); var exception = await Assert.ThrowsAsync(() => sseTransport.Running.OrTimeout()); Assert.Equal("Incomplete message.", exception.Message); @@ -195,18 +189,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of()).OrTimeout(); await eventStreamTcs.Task; - var sendTcs = new TaskCompletionSource(); - Assert.True(connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs))); + await pair.Transport.Output.WriteAsync(new byte[] { 0x42 }); - var exception = await Assert.ThrowsAsync(() => sendTcs.Task.OrTimeout()); + var exception = await Assert.ThrowsAsync(() => pair.Transport.Input.ReadAllAsync().OrTimeout()); Assert.Contains("500", exception.Message); Assert.Same(exception, await Assert.ThrowsAsync(() => sseTransport.Running.OrTimeout())); @@ -242,15 +233,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of()).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); - connectionToTransport.Writer.TryComplete(null); + pair.Transport.Output.Complete(); await sseTransport.Running.OrTimeout(); } @@ -272,13 +261,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); - var message = await transportToConnection.Reader.ReadAsync().AsTask().OrTimeout(); + await sseTransport.StartAsync( + new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text, connection: Mock.Of()).OrTimeout(); + + var message = await pair.Transport.Input.ReadSingleAsync().OrTimeout(); Assert.Equal("3:abc", Encoding.ASCII.GetString(message)); await sseTransport.Running.OrTimeout(); @@ -302,11 +290,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); Assert.Null(sseTransport.Mode); - await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connection: Mock.Of()).OrTimeout(); + + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + await sseTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, transferMode, connection: Mock.Of()).OrTimeout(); Assert.Equal(TransferMode.Text, sseTransport.Mode); await sseTransport.StopAsync().OrTimeout(); } @@ -327,9 +314,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { var sseTransport = new ServerSentEventsTransport(httpClient); - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); var exception = await Assert.ThrowsAsync(() => sseTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connection: Mock.Of())); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs index a5ae73ecbf..ddbb37bf82 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs @@ -1,4 +1,5 @@ using System; +using System.IO.Pipelines; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; @@ -12,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests private readonly Func _startHandler; public TransferMode? Mode { get; } - public Channel Application { get; private set; } + public IDuplexPipe Application { get; private set; } public TestTransport(Func onTransportStop = null, Func onTransportStart = null, TransferMode transferMode = TransferMode.Text) { @@ -21,7 +22,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Mode = transferMode; } - public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, IConnection connection) + public Task StartAsync(Uri url, IDuplexPipe application, TransferMode requestedTransferMode, IConnection connection) { Application = application; return _startHandler(); @@ -30,7 +31,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task StopAsync() { await _stopHandler(); - Application.Writer.TryComplete(); + Application.Output.Complete(); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/Base64EncoderTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/Base64EncoderTests.cs index 1d6e1b6d6d..0db0079ae5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/Base64EncoderTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Encoders/Base64EncoderTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.Text; using Xunit; @@ -22,11 +23,31 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Encoders [MemberData(nameof(Payloads))] public void VerifyEncode(string payload, string encoded) { - var encodedMessage = Encoding.UTF8.GetBytes(encoded); - var decodedMessage = Encoding.UTF8.GetString(new Base64Encoder().Decode(encodedMessage).ToArray()); + ReadOnlySpan encodedMessage = Encoding.UTF8.GetBytes(encoded); + var encoder = new Base64Encoder(); + encoder.TryDecode(ref encodedMessage, out var data); + var decodedMessage = Encoding.UTF8.GetString(data.ToArray()); Assert.Equal(payload, decodedMessage); } + [Fact] + public void CanParseMultipleMessages() + { + ReadOnlySpan data = Encoding.UTF8.GetBytes("28:QQpSDUMNCjtERUYxMjM0NTY3ODkw;4:QUJD;4:QUJD;"); + var encoder = new Base64Encoder(); + Assert.True(encoder.TryDecode(ref data, out var payload1)); + Assert.True(encoder.TryDecode(ref data, out var payload2)); + Assert.True(encoder.TryDecode(ref data, out var payload3)); + Assert.False(encoder.TryDecode(ref data, out var payload4)); + Assert.Equal(0, data.Length); + var payload1Value = Encoding.UTF8.GetString(payload1.ToArray()); + var payload2Value = Encoding.UTF8.GetString(payload2.ToArray()); + var payload3Value = Encoding.UTF8.GetString(payload3.ToArray()); + Assert.Equal("A\nR\rC\r\n;DEF1234567890", payload1Value); + Assert.Equal("ABC", payload2Value); + Assert.Equal("ABC", payload3Value); + } + public static IEnumerable Payloads => new object[][] { diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index e9039a9fb0..21535851b9 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -125,7 +125,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol var binder = new TestBinder(expectedMessage); var protocol = new JsonHubProtocol(Options.Create(protocolOptions)); - protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages); + var messages = new List(); + protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages); Assert.Equal(expectedMessage, messages[0], TestHubMessageEqualityComparer.Instance); } @@ -174,7 +175,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol var binder = new TestBinder(Array.Empty(), typeof(object)); var protocol = new JsonHubProtocol(); - var ex = Assert.Throws(() => protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages)); + var messages = new List(); + var ex = Assert.Throws(() => protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages)); Assert.Equal(expectedMessage, ex.Message); } @@ -189,7 +191,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool)); var protocol = new JsonHubProtocol(); - protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages); + var messages = new List(); + protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, messages); var ex = Assert.Throws(() => ((HubMethodInvocationMessage)messages[0]).Arguments); Assert.Equal(expectedMessage, ex.Message); } diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index 315bf73fc7..762a79df88 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -349,9 +349,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol // Parse the input fully now. bytes = Frame(bytes); var protocol = new MessagePackHubProtocol(); - Assert.True(protocol.TryParseMessages(bytes, new TestBinder(testData.Message), out var messages)); + var messages = new List(); + Assert.True(protocol.TryParseMessages(bytes, new TestBinder(testData.Message), messages)); - Assert.Equal(1, messages.Count); + Assert.Single(messages); Assert.Equal(testData.Message, messages[0], TestHubMessageEqualityComparer.Instance); } @@ -419,7 +420,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol { var buffer = Frame(Pack(testData.Encoded)); var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); - var exception = Assert.Throws(() => _hubProtocol.TryParseMessages(buffer, binder, out var messages)); + var messages = new List(); + var exception = Assert.Throws(() => _hubProtocol.TryParseMessages(buffer, binder, messages)); Assert.Equal(testData.ErrorMessage, exception.Message); } @@ -447,7 +449,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol { var buffer = Frame(Pack(testData.Encoded)); var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); - _hubProtocol.TryParseMessages(buffer, binder, out var messages); + var messages = new List(); + _hubProtocol.TryParseMessages(buffer, binder, messages); var exception = Assert.Throws(() => ((HubMethodInvocationMessage)messages[0]).Arguments); Assert.Equal(testData.ErrorMessage, exception.Message); @@ -458,7 +461,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol public void ParserDoesNotConsumePartialData(byte[] payload, int expectedMessagesCount) { var binder = new TestBinder(new[] { typeof(string) }, typeof(string)); - var result = _hubProtocol.TryParseMessages(payload, binder, out var messages); + var messages = new List(); + var result = _hubProtocol.TryParseMessages(payload, binder, messages); Assert.True(result || messages.Count == 0); Assert.Equal(expectedMessagesCount, messages.Count); } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs new file mode 100644 index 0000000000..5098d4b5df --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs @@ -0,0 +1,75 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace System.IO.Pipelines +{ + public static class PipeReaderExtensions + { + public static async Task WaitToReadAsync(this PipeReader pipeReader) + { + while (true) + { + var result = await pipeReader.ReadAsync(); + + try + { + if (!result.Buffer.IsEmpty) + { + return true; + } + + if (result.IsCompleted) + { + return false; + } + } + finally + { + // Consume nothing, just wait for everything + pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); + } + } + } + + public static async Task ReadSingleAsync(this PipeReader pipeReader) + { + while (true) + { + var result = await pipeReader.ReadAsync(); + + try + { + return result.Buffer.ToArray(); + } + finally + { + pipeReader.AdvanceTo(result.Buffer.End); + } + } + } + + public static async Task ReadAllAsync(this PipeReader pipeReader) + { + while (true) + { + var result = await pipeReader.ReadAsync(); + + try + { + if (result.IsCompleted) + { + return result.Buffer.ToArray(); + } + } + finally + { + // Consume nothing, just wait for everything + pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); + } + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 80d7ad05fd..3fa1dd8307 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -260,7 +260,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [MemberData(nameof(MessageSizesData))] public async Task ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport(string message) { - using (StartLog(out var loggerFactory, testName: $"ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport_{message.Length}")) + using (StartLog(out var loggerFactory, LogLevel.Trace, testName: $"ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport_{message.Length}")) { var logger = loggerFactory.CreateLogger(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 0e250e26cf..86a884cc33 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -2,11 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Threading.Channels; +using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; -using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging.Testing; using Moq; @@ -36,12 +35,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartLog(out var loggerFactory)) { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application, TransferMode.Binary, connection: Mock.Of()).OrTimeout(); await webSocketsTransport.StopAsync().OrTimeout(); await webSocketsTransport.Running.OrTimeout(); @@ -54,14 +50,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartLog(out var loggerFactory)) { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application, TransferMode.Binary, connection: Mock.Of()); - connectionToTransport.Writer.TryComplete(); + pair.Transport.Output.Complete(); await webSocketsTransport.Running.OrTimeout(TimeSpan.FromSeconds(10)); } } @@ -74,33 +67,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartLog(out var loggerFactory)) { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connection: Mock.Of()); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application, transferMode, connection: Mock.Of()); - var sendTcs = new TaskCompletionSource(); - connectionToTransport.Writer.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); - try - { - await sendTcs.Task; - } - catch (OperationCanceledException) - { - // Because the server and client are run in the same process there is a race where websocket.SendAsync - // can send a message but before returning be suspended allowing the server to run the EchoEndpoint and - // send a close frame which triggers a cancellation token on the client and cancels the websocket.SendAsync. - // Our solution to this is to just catch OperationCanceledException from the sent message if the race happens - // because we know the send went through, and its safe to check the response. - } + await pair.Transport.Output.WriteAsync(new byte[] { 0x42 }); // The echo endpoint closes the connection immediately after sending response which should stop the transport await webSocketsTransport.Running.OrTimeout(); - Assert.True(transportToConnection.Reader.TryRead(out var buffer)); - Assert.Equal(new byte[] { 0x42 }, buffer); + Assert.True(pair.Transport.Input.TryRead(out var result)); + Assert.Equal(new byte[] { 0x42 }, result.Buffer.ToArray()); + pair.Transport.Input.AdvanceTo(result.Buffer.End); } } @@ -112,14 +90,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartLog(out var loggerFactory)) { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); Assert.Null(webSocketsTransport.Mode); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), pair.Application, transferMode, connection: Mock.Of()).OrTimeout(); Assert.Equal(transferMode, webSocketsTransport.Mode); @@ -134,13 +109,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartLog(out var loggerFactory)) { - var connectionToTransport = Channel.CreateUnbounded(); - var transportToConnection = Channel.CreateUnbounded(); - var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); var webSocketsTransport = new WebSocketsTransport(httpOptions: null, loggerFactory: loggerFactory); var exception = await Assert.ThrowsAsync(() => - webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text | TransferMode.Binary, connection: Mock.Of())); + webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), pair.Application, TransferMode.Text | TransferMode.Binary, connection: Mock.Of())); Assert.Contains("Invalid transfer mode.", exception.Message); Assert.Equal("requestedTransferMode", exception.ParamName);