From cfaa123eb88954053de58bc3c98cd2ab01272915 Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Tue, 27 Mar 2018 23:02:07 -0700 Subject: [PATCH] IConnection refactoring (#1718) - IConnection is now single-use and HubConnection creates a new instance for reconnecting - IConnection is just a Pipe now, no more events --- .../ts/FunctionalTests/FunctionalTests.csproj | 24 + clients/ts/FunctionalTests/tsconfig.json | 1 + samples/ClientSample/RawSample.cs | 65 +- samples/SocketsSample/web.config | 4 +- samples/SocketsSample/wwwroot/sockets.html | 2 +- .../HubConnection.Log.cs | 241 +++- .../HubConnection.cs | 1009 +++++++++++------ .../HubConnectionBuilder.cs | 5 +- .../InvocationRequest.cs | 2 +- .../Properties/AssemblyInfo.cs | 6 + .../Internal/Formatters/TextMessageParser.cs | 20 +- .../IConnection.cs | 16 +- .../DefaultTransportFactory.cs | 1 + .../HttpConnection.Log.cs | 245 ++-- .../HttpConnection.cs | 905 +++++---------- .../HttpConnectionExtensions.cs | 20 - .../HttpOptions.cs | 2 +- .../LongPollingTransport.Log.cs | 19 +- .../{ => Internal}/LongPollingTransport.cs | 4 +- .../ServerSentEventsTransport.Log.cs | 4 +- .../ServerSentEventsTransport.cs | 2 +- .../Internal/TaskQueue.cs | 71 -- .../{ => Internal}/WebSocketsTransport.Log.cs | 2 +- .../{ => Internal}/WebSocketsTransport.cs | 4 +- .../HubConnectionTests.cs | 204 ++-- .../HttpConnectionTests.AbortAsync.cs | 130 --- ...HttpConnectionTests.ConnectionLifecycle.cs | 343 ++---- .../HttpConnectionTests.Helpers.cs | 90 +- .../HttpConnectionTests.Negotiate.cs | 20 +- .../HttpConnectionTests.OnReceived.cs | 109 -- ...nc.cs => HttpConnectionTests.Transport.cs} | 103 +- .../HttpConnectionTests.cs | 89 +- .../HubConnectionExtensionsTests.cs | 204 ---- .../HubConnectionProtocolTests.cs | 429 ------- .../HubConnectionTests.ConnectionLifecycle.cs | 351 ++++++ .../HubConnectionTests.Extensions.cs | 202 ++++ .../HubConnectionTests.Helpers.cs | 16 + .../HubConnectionTests.Protocol.cs | 408 +++++++ .../HubConnectionTests.cs | 179 +-- .../LongPollingTransportTests.cs | 8 +- .../ServerSentEventsTransportTests.cs | 2 +- .../SyncPoint.cs | 80 ++ .../TaskQueueTests.cs | 56 - .../TestConnection.cs | 232 ++-- .../TestHttpMessageHandler.cs | 2 +- .../TestTransport.cs | 42 +- .../ChannelExtensions.cs | 25 +- ...soft.AspNetCore.SignalR.Tests.Utils.csproj | 4 + .../PipeCompletionExtensions.cs | 44 + .../PipeReaderExtensions.cs | 18 +- .../ServerFixture.cs | 20 +- .../TaskExtensions.cs | 8 +- .../DefaultTransportFactoryTests.cs | 1 + .../EndToEndTests.cs | 81 +- .../WebSocketsTransportTests.cs | 1 + 55 files changed, 2992 insertions(+), 3183 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR.Client.Core/Properties/AssemblyInfo.cs delete mode 100644 src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs rename src/Microsoft.AspNetCore.Sockets.Client.Http/{ => Internal}/LongPollingTransport.Log.cs (80%) rename src/Microsoft.AspNetCore.Sockets.Client.Http/{ => Internal}/LongPollingTransport.cs (97%) rename src/Microsoft.AspNetCore.Sockets.Client.Http/{ => Internal}/ServerSentEventsTransport.Log.cs (97%) rename src/Microsoft.AspNetCore.Sockets.Client.Http/{ => Internal}/ServerSentEventsTransport.cs (99%) delete mode 100644 src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs rename src/Microsoft.AspNetCore.Sockets.Client.Http/{ => Internal}/WebSocketsTransport.Log.cs (99%) rename src/Microsoft.AspNetCore.Sockets.Client.Http/{ => Internal}/WebSocketsTransport.cs (98%) delete mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs delete mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs rename test/Microsoft.AspNetCore.SignalR.Client.Tests/{HttpConnectionTests.SendAsync.cs => HttpConnectionTests.Transport.cs} (57%) delete mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs delete mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Extensions.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Helpers.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs delete mode 100644 test/Microsoft.AspNetCore.SignalR.Client.Tests/TaskQueueTests.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeCompletionExtensions.cs diff --git a/clients/ts/FunctionalTests/FunctionalTests.csproj b/clients/ts/FunctionalTests/FunctionalTests.csproj index 4a51804075..250728200b 100644 --- a/clients/ts/FunctionalTests/FunctionalTests.csproj +++ b/clients/ts/FunctionalTests/FunctionalTests.csproj @@ -5,6 +5,18 @@ True + + + + + + + + + + + + @@ -25,6 +37,18 @@ + + + + + + + + + + + + diff --git a/clients/ts/FunctionalTests/tsconfig.json b/clients/ts/FunctionalTests/tsconfig.json index a9c8306f13..b82243eeea 100644 --- a/clients/ts/FunctionalTests/tsconfig.json +++ b/clients/ts/FunctionalTests/tsconfig.json @@ -1,4 +1,5 @@ { + "compileOnSave": false, "compilerOptions": { "noImplicitAny": false, "noEmitOnError": true, diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index 5ac8c1a6cb..b72804e7a9 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -2,15 +2,16 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Diagnostics; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; using System.Linq; -using System.Net.Http; using System.Text; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.Extensions.CommandLineUtils; using Microsoft.Extensions.Logging; @@ -35,36 +36,25 @@ namespace ClientSample baseUrl = string.IsNullOrEmpty(baseUrl) ? "http://localhost:5000/chat" : baseUrl; var loggerFactory = new LoggerFactory(); - var logger = loggerFactory.CreateLogger(); Console.WriteLine($"Connecting to {baseUrl}..."); var connection = new HttpConnection(new Uri(baseUrl), loggerFactory); try { - var closeTcs = new TaskCompletionSource(); - connection.Closed += e => closeTcs.SetResult(null); - connection.OnReceived(data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}")); await connection.StartAsync(TransferFormat.Text); Console.WriteLine($"Connected to {baseUrl}"); - var cts = new CancellationTokenSource(); - Console.CancelKeyPress += async (sender, a) => + var shutdown = new TaskCompletionSource(); + Console.CancelKeyPress += (sender, a) => { a.Cancel = true; - await connection.DisposeAsync(); + shutdown.TrySetResult(null); }; - while (!closeTcs.Task.IsCompleted) - { - var line = await Task.Run(() => Console.ReadLine(), cts.Token); + _ = ReceiveLoop(Console.Out, connection.Transport.Input); + _ = SendLoop(Console.In, connection.Transport.Output); - if (line == null) - { - break; - } - - await connection.SendAsync(Encoding.UTF8.GetBytes(line), cts.Token); - } + await shutdown.Task; } catch (AggregateException aex) when (aex.InnerExceptions.All(e => e is OperationCanceledException)) { @@ -78,5 +68,40 @@ namespace ClientSample } return 0; } + + private static async Task ReceiveLoop(TextWriter output, PipeReader input) + { + while (true) + { + var result = await input.ReadAsync(); + var buffer = result.Buffer; + + try + { + if (!buffer.IsEmpty) + { + await output.WriteLineAsync(Encoding.UTF8.GetString(buffer.ToArray())); + } + else if (result.IsCompleted) + { + // No more data, and the pipe is complete + break; + } + } + finally + { + input.AdvanceTo(buffer.End); + } + } + } + + private static async Task SendLoop(TextReader input, PipeWriter output) + { + while (true) + { + var result = await input.ReadLineAsync(); + await output.WriteAsync(Encoding.UTF8.GetBytes(result)); + } + } } } diff --git a/samples/SocketsSample/web.config b/samples/SocketsSample/web.config index 8700b60c05..5defc1eb37 100644 --- a/samples/SocketsSample/web.config +++ b/samples/SocketsSample/web.config @@ -7,6 +7,8 @@ - + + + \ No newline at end of file diff --git a/samples/SocketsSample/wwwroot/sockets.html b/samples/SocketsSample/wwwroot/sockets.html index 3adc944092..0cbd7f6264 100644 --- a/samples/SocketsSample/wwwroot/sockets.html +++ b/samples/SocketsSample/wwwroot/sockets.html @@ -38,7 +38,7 @@ event.preventDefault(); }); - connection.start().then(function () { + connection.start(signalR.TransferFormat.Text).then(function () { console.log("Opened"); }, function () { console.log("Error opening connection"); diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.Log.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.Log.cs index 1b8c7bb35d..4edad81a03 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.Log.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.Log.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR.Client @@ -17,29 +18,29 @@ namespace Microsoft.AspNetCore.SignalR.Client private static readonly Action _preparingBlockingInvocation = LoggerMessage.Define(LogLevel.Trace, new EventId(2, "PreparingBlockingInvocation"), "Preparing blocking invocation '{InvocationId}' of '{Target}', with return type '{ReturnType}' and {ArgumentCount} argument(s)."); - private static readonly Action _registerInvocation = - LoggerMessage.Define(LogLevel.Debug, new EventId(3, "RegisterInvocation"), "Registering Invocation ID '{InvocationId}' for tracking."); + private static readonly Action _registeringInvocation = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, "RegisteringInvocation"), "Registering Invocation ID '{InvocationId}' for tracking."); - private static readonly Action _issueInvocation = - LoggerMessage.Define(LogLevel.Trace, new EventId(4, "IssueInvocation"), "Issuing Invocation '{InvocationId}': {ReturnType} {MethodName}({Args})."); + private static readonly Action _issuingInvocation = + LoggerMessage.Define(LogLevel.Trace, new EventId(4, "IssuingInvocation"), "Issuing Invocation '{InvocationId}': {ReturnType} {MethodName}({Args})."); - private static readonly Action _sendInvocation = - LoggerMessage.Define(LogLevel.Debug, new EventId(5, "SendInvocation"), "Sending Invocation '{InvocationId}'."); + private static readonly Action _sendingMessage = + LoggerMessage.Define(LogLevel.Debug, new EventId(5, "SendingMessage"), "Sending {MessageType} message '{InvocationId}'."); - private static readonly Action _sendInvocationCompleted = - LoggerMessage.Define(LogLevel.Debug, new EventId(6, "SendInvocationCompleted"), "Sending Invocation '{InvocationId}' completed."); + private static readonly Action _messageSent = + LoggerMessage.Define(LogLevel.Debug, new EventId(6, "MessageSent"), "Sending {MessageType} message '{InvocationId}' completed."); - private static readonly Action _sendInvocationFailed = - LoggerMessage.Define(LogLevel.Error, new EventId(7, "SendInvocationFailed"), "Sending Invocation '{InvocationId}' failed."); + private static readonly Action _failedToSendInvocation = + LoggerMessage.Define(LogLevel.Error, new EventId(7, "FailedToSendInvocation"), "Sending Invocation '{InvocationId}' failed."); private static readonly Action _receivedInvocation = LoggerMessage.Define(LogLevel.Trace, new EventId(8, "ReceivedInvocation"), "Received Invocation '{InvocationId}': {MethodName}({Args})."); - private static readonly Action _dropCompletionMessage = - LoggerMessage.Define(LogLevel.Warning, new EventId(9, "DropCompletionMessage"), "Dropped unsolicited Completion message for invocation '{InvocationId}'."); + private static readonly Action _droppedCompletionMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(9, "DroppedCompletionMessage"), "Dropped unsolicited Completion message for invocation '{InvocationId}'."); - private static readonly Action _dropStreamMessage = - LoggerMessage.Define(LogLevel.Warning, new EventId(10, "DropStreamMessage"), "Dropped unsolicited StreamItem message for invocation '{InvocationId}'."); + private static readonly Action _droppedStreamMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(10, "DroppedStreamMessage"), "Dropped unsolicited StreamItem message for invocation '{InvocationId}'."); private static readonly Action _shutdownConnection = LoggerMessage.Define(LogLevel.Trace, new EventId(11, "ShutdownConnection"), "Shutting down connection."); @@ -47,8 +48,8 @@ namespace Microsoft.AspNetCore.SignalR.Client private static readonly Action _shutdownWithError = LoggerMessage.Define(LogLevel.Error, new EventId(12, "ShutdownWithError"), "Connection is shutting down due to an error."); - private static readonly Action _removeInvocation = - LoggerMessage.Define(LogLevel.Trace, new EventId(13, "RemoveInvocation"), "Removing pending invocation {InvocationId}."); + private static readonly Action _removingInvocation = + LoggerMessage.Define(LogLevel.Trace, new EventId(13, "RemovingInvocation"), "Removing pending invocation {InvocationId}."); private static readonly Action _missingHandler = LoggerMessage.Define(LogLevel.Warning, new EventId(14, "MissingHandler"), "Failed to find handler for '{Target}' method."); @@ -68,11 +69,11 @@ namespace Microsoft.AspNetCore.SignalR.Client private static readonly Action _cancelingInvocationCompletion = LoggerMessage.Define(LogLevel.Trace, new EventId(19, "CancelingInvocationCompletion"), "Canceling dispatch of Completion message for Invocation {InvocationId}. The invocation was canceled."); - private static readonly Action _cancelingCompletion = - LoggerMessage.Define(LogLevel.Trace, new EventId(20, "CancelingCompletion"), "Canceling dispatch of Completion message for Invocation {InvocationId}. The invocation was canceled."); + private static readonly Action _releasingConnectionLock = + LoggerMessage.Define(LogLevel.Trace, new EventId(20, "ReleasingConnectionLock"), "Releasing Connection Lock in {MethodName} ({FilePath}:{LineNumber})."); - private static readonly Action _invokeAfterTermination = - LoggerMessage.Define(LogLevel.Error, new EventId(21, "InvokeAfterTermination"), "Invoke for Invocation '{InvocationId}' was called after the connection was terminated."); + private static readonly Action _stopped = + LoggerMessage.Define(LogLevel.Debug, new EventId(21, "Stopped"), "HubConnection stopped."); private static readonly Action _invocationAlreadyInUse = LoggerMessage.Define(LogLevel.Critical, new EventId(22, "InvocationAlreadyInUse"), "Invocation ID '{InvocationId}' is already in use."); @@ -125,6 +126,60 @@ namespace Microsoft.AspNetCore.SignalR.Client private static readonly Action _receivedCloseWithError = LoggerMessage.Define(LogLevel.Error, new EventId(38, "ReceivedCloseWithError"), "Received close message with an error: {Error}"); + private static readonly Action _handshakeComplete = + LoggerMessage.Define(LogLevel.Debug, new EventId(39, "HandshakeComplete"), "Handshake with server complete."); + + private static readonly Action _registeringHandler = + LoggerMessage.Define(LogLevel.Debug, new EventId(40, "RegisteringHandler"), "Registering handler for client method '{MethodName}'."); + + private static readonly Action _starting = + LoggerMessage.Define(LogLevel.Debug, new EventId(41, "Starting"), "Starting HubConnection."); + + private static readonly Action _waitingOnConnectionLock = + LoggerMessage.Define(LogLevel.Trace, new EventId(42, "WaitingOnConnectionLock"), "Waiting on Connection Lock in {MethodName} ({FilePath}:{LineNumber})."); + + private static readonly Action _errorStartingConnection = + LoggerMessage.Define(LogLevel.Error, new EventId(43, "ErrorStartingConnection"), "Error starting connection."); + + private static readonly Action _started = + LoggerMessage.Define(LogLevel.Information, new EventId(44, "Started"), "HubConnection started."); + + private static readonly Action _sendingCancellation = + LoggerMessage.Define(LogLevel.Debug, new EventId(45, "SendingCancellation"), "Sending Cancellation for Invocation '{InvocationId}'."); + + private static readonly Action _cancelingOutstandingInvocations = + LoggerMessage.Define(LogLevel.Debug, new EventId(46, "CancelingOutstandingInvocations"), "Canceling all outstanding invocations."); + + private static readonly Action _receiveLoopStarting = + LoggerMessage.Define(LogLevel.Debug, new EventId(47, "ReceiveLoopStarting"), "Receive loop starting."); + + private static readonly Action _startingServerTimeoutTimer = + LoggerMessage.Define(LogLevel.Debug, new EventId(48, "StartingServerTimeoutTimer"), "Starting server timeout timer. Duration: {ServerTimeout:0.00}ms"); + + private static readonly Action _notUsingServerTimeout = + LoggerMessage.Define(LogLevel.Debug, new EventId(49, "NotUsingServerTimeout"), "Not using server timeout because the transport inherently tracks server availability."); + + private static readonly Action _serverDisconnectedWithError = + LoggerMessage.Define(LogLevel.Error, new EventId(50, "ServerDisconnectedWithError"), "The server connection was terminated with an error."); + + private static readonly Action _invokingClosedEventHandler = + LoggerMessage.Define(LogLevel.Debug, new EventId(51, "InvokingClosedEventHandler"), "Invoking the Closed event handler."); + + private static readonly Action _stopping = + LoggerMessage.Define(LogLevel.Debug, new EventId(52, "Stopping"), "Stopping HubConnection."); + + private static readonly Action _terminatingReceiveLoop = + LoggerMessage.Define(LogLevel.Debug, new EventId(53, "TerminatingReceiveLoop"), "Terminating receive loop."); + + private static readonly Action _waitingForReceiveLoopToTerminate = + LoggerMessage.Define(LogLevel.Debug, new EventId(54, "WaitingForReceiveLoopToTerminate"), "Waiting for the receive loop to terminate."); + + private static readonly Action _unableToSendCancellation = + LoggerMessage.Define(LogLevel.Trace, new EventId(55, "UnableToSendCancellation"), "Unable to send cancellation for invocation '{InvocationId}'. The connection is inactive."); + + private static readonly Action _processingMessage = + LoggerMessage.Define(LogLevel.Debug, new EventId(56, "ProcessingMessage"), "Processing {MessageLength} byte message from server."); + public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count) { _preparingNonBlockingInvocation(logger, target, count, null); @@ -140,33 +195,39 @@ namespace Microsoft.AspNetCore.SignalR.Client _preparingStreamingInvocation(logger, invocationId, target, returnType, count, null); } - public static void RegisterInvocation(ILogger logger, string invocationId) + public static void RegisteringInvocation(ILogger logger, string invocationId) { - _registerInvocation(logger, invocationId, null); + _registeringInvocation(logger, invocationId, null); } - public static void IssueInvocation(ILogger logger, string invocationId, string returnType, string methodName, object[] args) + public static void IssuingInvocation(ILogger logger, string invocationId, string returnType, string methodName, object[] args) { if (logger.IsEnabled(LogLevel.Trace)) { var argsList = args == null ? string.Empty : string.Join(", ", args.Select(a => a?.GetType().FullName ?? "(null)")); - _issueInvocation(logger, invocationId, returnType, methodName, argsList, null); + _issuingInvocation(logger, invocationId, returnType, methodName, argsList, null); } } - public static void SendInvocation(ILogger logger, string invocationId) + public static void SendingMessage(ILogger logger, HubInvocationMessage message) { - _sendInvocation(logger, invocationId, null); + if (logger.IsEnabled(LogLevel.Debug)) + { + _sendingMessage(logger, message.GetType().Name, message.InvocationId, null); + } } - public static void SendInvocationCompleted(ILogger logger, string invocationId) + public static void MessageSent(ILogger logger, HubInvocationMessage message) { - _sendInvocationCompleted(logger, invocationId, null); + if (logger.IsEnabled(LogLevel.Debug)) + { + _messageSent(logger, message.GetType().Name, message.InvocationId, null); + } } - public static void SendInvocationFailed(ILogger logger, string invocationId, Exception exception) + public static void FailedToSendInvocation(ILogger logger, string invocationId, Exception exception) { - _sendInvocationFailed(logger, invocationId, exception); + _failedToSendInvocation(logger, invocationId, exception); } public static void ReceivedInvocation(ILogger logger, string invocationId, string methodName, object[] args) @@ -178,14 +239,14 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - public static void DropCompletionMessage(ILogger logger, string invocationId) + public static void DroppedCompletionMessage(ILogger logger, string invocationId) { - _dropCompletionMessage(logger, invocationId, null); + _droppedCompletionMessage(logger, invocationId, null); } - public static void DropStreamMessage(ILogger logger, string invocationId) + public static void DroppedStreamMessage(ILogger logger, string invocationId) { - _dropStreamMessage(logger, invocationId, null); + _droppedStreamMessage(logger, invocationId, null); } public static void ShutdownConnection(ILogger logger) @@ -198,9 +259,9 @@ namespace Microsoft.AspNetCore.SignalR.Client _shutdownWithError(logger, exception); } - public static void RemoveInvocation(ILogger logger, string invocationId) + public static void RemovingInvocation(ILogger logger, string invocationId) { - _removeInvocation(logger, invocationId, null); + _removingInvocation(logger, invocationId, null); } public static void MissingHandler(ILogger logger, string target) @@ -233,14 +294,9 @@ namespace Microsoft.AspNetCore.SignalR.Client _cancelingInvocationCompletion(logger, invocationId, null); } - public static void CancelingCompletion(ILogger logger, string invocationId) + public static void Stopped(ILogger logger) { - _cancelingCompletion(logger, invocationId, null); - } - - public static void InvokeAfterTermination(ILogger logger, string invocationId) - { - _invokeAfterTermination(logger, invocationId, null); + _stopped(logger, null); } public static void InvocationAlreadyInUse(ILogger logger, string invocationId) @@ -322,6 +378,105 @@ namespace Microsoft.AspNetCore.SignalR.Client { _receivedCloseWithError(logger, error, null); } + + public static void HandshakeComplete(ILogger logger) + { + _handshakeComplete(logger, null); + } + + public static void RegisteringHandler(ILogger logger, string methodName) + { + _registeringHandler(logger, methodName, null); + } + + public static void Starting(ILogger logger) + { + _starting(logger, null); + } + + public static void ErrorStartingConnection(ILogger logger, Exception ex) + { + _errorStartingConnection(logger, ex); + } + + public static void Started(ILogger logger) + { + _started(logger, null); + } + + public static void SendingCancellation(ILogger logger, string invocationId) + { + _sendingCancellation(logger, invocationId, null); + } + + public static void CancelingOutstandingInvocations(ILogger logger) + { + _cancelingOutstandingInvocations(logger, null); + } + + public static void ReceiveLoopStarting(ILogger logger) + { + _receiveLoopStarting(logger, null); + } + + public static void StartingServerTimeoutTimer(ILogger logger, TimeSpan serverTimeout) + { + if (logger.IsEnabled(LogLevel.Debug)) + { + _startingServerTimeoutTimer(logger, serverTimeout.TotalMilliseconds, null); + } + } + + public static void NotUsingServerTimeout(ILogger logger) + { + _notUsingServerTimeout(logger, null); + } + + public static void ServerDisconnectedWithError(ILogger logger, Exception ex) + { + _serverDisconnectedWithError(logger, ex); + } + + public static void InvokingClosedEventHandler(ILogger logger) + { + _invokingClosedEventHandler(logger, null); + } + + public static void Stopping(ILogger logger) + { + _stopping(logger, null); + } + + public static void TerminatingReceiveLoop(ILogger logger) + { + _terminatingReceiveLoop(logger, null); + } + + public static void WaitingForReceiveLoopToTerminate(ILogger logger) + { + _waitingForReceiveLoopToTerminate(logger, null); + } + + public static void ProcessingMessage(ILogger logger, long length) + { + _processingMessage(logger, length, null); + } + + public static void WaitingOnConnectionLock(ILogger logger, string memberName, string filePath, int lineNumber) + { + _waitingOnConnectionLock(logger, memberName, filePath, lineNumber, null); + } + + public static void ReleasingConnectionLock(ILogger logger, string memberName, string filePath, int lineNumber) + { + _releasingConnectionLock(logger, memberName, filePath, lineNumber, null); + } + + public static void UnableToSendCancellation(ILogger logger, string invocationId) + { + _unableToSendCancellation(logger, invocationId, null); + } } } } + diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index af3d8e2e05..c588fa4a9f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -2,9 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Diagnostics; using System.IO; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; @@ -23,22 +27,20 @@ namespace Microsoft.AspNetCore.SignalR.Client { public static readonly TimeSpan DefaultServerTimeout = TimeSpan.FromSeconds(30); // Server ping rate is 15 sec, this is 2 times that. + // This lock protects the connection state. + private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1); + + // Persistent across all connections private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; - private readonly IConnection _connection; private readonly IHubProtocol _protocol; - private readonly HubBinder _binder; - - private readonly object _pendingCallsLock = new object(); - private readonly Dictionary _pendingCalls = new Dictionary(); + private readonly Func _connectionFactory; private readonly ConcurrentDictionary> _handlers = new ConcurrentDictionary>(); - private CancellationTokenSource _connectionActive; + private bool _disposed; - private int _nextId; - private volatile bool _startCalled; - private readonly Timer _timeoutTimer; - private bool _needKeepAlive; - private bool _receivedHandshakeResponse; + // Transient state to a connection + private readonly object _pendingCallsLock = new object(); + private ConnectionState _connectionState; public event Action Closed; @@ -48,103 +50,46 @@ namespace Microsoft.AspNetCore.SignalR.Client /// public TimeSpan ServerTimeout { get; set; } = DefaultServerTimeout; - public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory) + public HubConnection(Func connectionFactory, IHubProtocol protocol) : this(connectionFactory, protocol, NullLoggerFactory.Instance) { - if (connection == null) - { - throw new ArgumentNullException(nameof(connection)); - } + } - if (protocol == null) - { - throw new ArgumentNullException(nameof(protocol)); - } + public HubConnection(Func connectionFactory, IHubProtocol protocol, ILoggerFactory loggerFactory) + { + _connectionFactory = connectionFactory ?? throw new ArgumentNullException(nameof(connectionFactory)); + _protocol = protocol ?? throw new ArgumentNullException(nameof(protocol)); - _connection = connection; - _binder = new HubBinder(this); - _protocol = protocol; _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); - _connection.OnReceived((data, state) => ((HubConnection)state).OnDataReceivedAsync(data), this); - _connection.Closed += Shutdown; - - // Create the timer for timeout, but disabled by default (we enable it when started). - _timeoutTimer = new Timer(state => ((HubConnection)state).TimeoutElapsed(), this, Timeout.Infinite, Timeout.Infinite); } public async Task StartAsync() { - try + CheckDisposed(); + await StartAsyncCore().ForceAsync(); + } + + public async Task StopAsync() + { + CheckDisposed(); + await StopAsyncCore(disposing: false).ForceAsync(); + } + + public async Task DisposeAsync() + { + if (!_disposed) { - await StartAsyncCore().ForceAsync(); - } - finally - { - _startCalled = true; + await StopAsyncCore(disposing: true).ForceAsync(); } } - private void TimeoutElapsed() - { - _connection.AbortAsync(new TimeoutException($"Server timeout ({ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server.")); - } - - private void ResetTimeoutTimer() - { - if (_needKeepAlive) - { - Log.ResettingKeepAliveTimer(_logger); - - // If the connection is disposed while this callback is firing, or if the callback is fired after dispose - // (which can happen because of some races), this will throw ObjectDisposedException. That's OK, because - // we don't need the timer anyway. - try - { - _timeoutTimer.Change(ServerTimeout, Timeout.InfiniteTimeSpan); - } - catch (ObjectDisposedException) - { - // This is OK! - } - } - } - - private async Task StartAsyncCore() - { - await _connection.StartAsync(_protocol.TransferFormat); - _needKeepAlive = _connection.Features.Get() == null; - _receivedHandshakeResponse = false; - - Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); - - _connectionActive = new CancellationTokenSource(); - using (var memoryStream = new MemoryStream()) - { - Log.SendingHubHandshake(_logger); - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryStream); - await _connection.SendAsync(memoryStream.ToArray(), _connectionActive.Token); - } - - ResetTimeoutTimer(); - } - - public async Task StopAsync() => await StopAsyncCore().ForceAsync(); - - private Task StopAsyncCore() => _connection.StopAsync(); - - public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); - - private async Task DisposeAsyncCore() - { - await _connection.DisposeAsync(); - - // Dispose the timer AFTER shutting down the connection. - _timeoutTimer.Dispose(); - } - - // TODO: Client return values/tasks? public IDisposable On(string methodName, Type[] parameterTypes, Func handler, object state) { + Log.RegisteringHandler(_logger, methodName); + + CheckDisposed(); + + // It's OK to be disposed while registering a callback, we'll just never call the callback anyway (as with all the callbacks registered before disposal). var invocationHandler = new InvocationHandler(parameterTypes, handler, state); var invocationList = _handlers.AddOrUpdate(methodName, _ => new List { invocationHandler }, (_, invocations) => @@ -159,88 +104,215 @@ namespace Microsoft.AspNetCore.SignalR.Client return new Subscription(invocationHandler, invocationList); } - public async Task> StreamAsChannelAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) + public async Task> StreamAsChannelAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) => + await StreamAsChannelAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync(); + + public async Task InvokeAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) => + await InvokeAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync(); + + // REVIEW: We don't generally use cancellation tokens when writing to a pipe because the asynchrony is only the result of backpressure. + // However, this would be the only "invocation" method _without_ a cancellation token... which is odd. + public async Task SendAsync(string methodName, object[] args, CancellationToken cancellationToken = default) => + await SendAsyncCore(methodName, args, cancellationToken).ForceAsync(); + + private async Task StartAsyncCore() { - return await StreamAsChannelAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync(); + await WaitConnectionLockAsync(); + try + { + if (_connectionState != null) + { + // We're already connected + return; + } + + CheckDisposed(); + + Log.Starting(_logger); + + // Start the connection + var connection = _connectionFactory(); + await connection.StartAsync(_protocol.TransferFormat); + _connectionState = new ConnectionState(connection, this); + + // From here on, if an error occurs we need to shut down the connection because + // we still own it. + try + { + Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); + await HandshakeAsync(); + } + catch (Exception ex) + { + Log.ErrorStartingConnection(_logger, ex); + + // Can't have any invocations to cancel, we're in the lock. + await _connectionState.Connection.DisposeAsync(); + throw; + } + + _connectionState.ReceiveTask = ReceiveLoop(_connectionState); + Log.Started(_logger); + } + finally + { + ReleaseConnectionLock(); + } + } + + // This method does both Dispose and Start, the 'disposing' flag indicates which. + // The behaviors are nearly identical, except that the _disposed flag is set in the lock + // if we're disposing. + private async Task StopAsyncCore(bool disposing) + { + // Block a Start from happening until we've finished capturing the connection state. + ConnectionState connectionState; + await WaitConnectionLockAsync(); + try + { + if (disposing && _disposed) + { + // DisposeAsync should be idempotent. + return; + } + + CheckDisposed(); + connectionState = _connectionState; + + // Set the stopping flag so that any invocations after this get a useful error message instead of + // silently failing or throwing an error about the pipe being completed. + if (connectionState != null) + { + connectionState.Stopping = true; + } + + if (disposing) + { + _disposed = true; + } + } + finally + { + ReleaseConnectionLock(); + } + + // Now stop the connection we captured + if (connectionState != null) + { + await connectionState.StopAsync(ServerTimeout); + } } private async Task> StreamAsChannelAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { - if (!_startCalled) + async Task OnStreamCancelled(InvocationRequest irq) { - throw new InvalidOperationException($"The '{nameof(StreamAsChannelAsync)}' method cannot be called before the connection has been started."); - } - - var invokeCts = new CancellationTokenSource(); - var irq = InvocationRequest.Stream(invokeCts.Token, returnType, GetNextId(), _loggerFactory, this, out var channel); - // After InvokeCore we don't want the irq cancellation token to be triggered. - // The stream invocation will be canceled by the CancelInvocationMessage, connection closing, or channel finishing. - using (cancellationToken.Register(token => ((CancellationTokenSource)token).Cancel(), invokeCts)) - { - await InvokeStreamCore(methodName, irq, args); - } - - if (cancellationToken.CanBeCanceled) - { - cancellationToken.Register(state => + // We need to take the connection lock in order to ensure we a) have a connection and b) are the only one accessing the write end of the pipe. + await WaitConnectionLockAsync(); + try { - var invocationReq = (InvocationRequest)state; - if (!invocationReq.HubConnection._connectionActive.IsCancellationRequested) + if (_connectionState != null) { + Log.SendingCancellation(_logger, irq.InvocationId); + // Fire and forget, if it fails that means we aren't connected anymore. - _ = invocationReq.HubConnection.SendHubMessage(new CancelInvocationMessage(invocationReq.InvocationId), invocationReq); - - if (invocationReq.HubConnection.TryRemoveInvocation(invocationReq.InvocationId, out _)) - { - invocationReq.Complete(CompletionMessage.Empty(irq.InvocationId)); - } - - invocationReq.Dispose(); + _ = SendHubMessage(new CancelInvocationMessage(irq.InvocationId), irq.CancellationToken); } - }, irq); + else + { + Log.UnableToSendCancellation(_logger, irq.InvocationId); + } + } + finally + { + ReleaseConnectionLock(); + } + + // Cancel the invocation + irq.Dispose(); + } + + CheckDisposed(); + await WaitConnectionLockAsync(); + + ChannelReader channel; + try + { + CheckDisposed(); + CheckConnectionActive(nameof(StreamAsChannelAsync)); + + var irq = InvocationRequest.Stream(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out channel); + await InvokeStreamCore(methodName, irq, args, cancellationToken); + + if (cancellationToken.CanBeCanceled) + { + cancellationToken.Register(state => _ = OnStreamCancelled((InvocationRequest)state), irq); + } + } + finally + { + ReleaseConnectionLock(); } return channel; } - public async Task InvokeAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) => - await InvokeAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync(); private async Task InvokeAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { - if (!_startCalled) + CheckDisposed(); + await WaitConnectionLockAsync(); + + Task invocationTask; + try { - throw new InvalidOperationException($"The '{nameof(InvokeAsync)}' method cannot be called before the connection has been started."); + CheckDisposed(); + CheckConnectionActive(nameof(InvokeAsync)); + + var irq = InvocationRequest.Invoke(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out invocationTask); + await InvokeCore(methodName, irq, args, cancellationToken); + } + finally + { + ReleaseConnectionLock(); } - var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, this, out var task); - await InvokeCore(methodName, irq, args); - return await task; + // Wait for this outside the lock, because it won't complete until the server responds. + return await invocationTask; } - private Task InvokeCore(string methodName, InvocationRequest irq, object[] args) + private async Task InvokeCore(string methodName, InvocationRequest irq, object[] args, CancellationToken cancellationToken) { - ThrowIfConnectionTerminated(irq.InvocationId); + AssertConnectionValid(); + Log.PreparingBlockingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length); // Client invocations are always blocking var invocationMessage = new InvocationMessage(irq.InvocationId, target: methodName, argumentBindingException: null, arguments: args); - Log.RegisterInvocation(_logger, invocationMessage.InvocationId); + Log.RegisteringInvocation(_logger, invocationMessage.InvocationId); - AddInvocation(irq); + _connectionState.AddInvocation(irq); // Trace the full invocation - Log.IssueInvocation(_logger, invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); + Log.IssuingInvocation(_logger, invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); - // We don't need to wait for this to complete. It will signal back to the invocation request. - return SendHubMessage(invocationMessage, irq); + try + { + await SendHubMessage(invocationMessage, cancellationToken); + } + catch (Exception ex) + { + Log.FailedToSendInvocation(_logger, invocationMessage.InvocationId, ex); + _connectionState.TryRemoveInvocation(invocationMessage.InvocationId, out _); + irq.Fail(ex); + } } - private Task InvokeStreamCore(string methodName, InvocationRequest irq, object[] args) + private async Task InvokeStreamCore(string methodName, InvocationRequest irq, object[] args, CancellationToken cancellationToken) { - ThrowIfConnectionTerminated(irq.InvocationId); + AssertConnectionValid(); Log.PreparingStreamingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length); @@ -248,92 +320,72 @@ namespace Microsoft.AspNetCore.SignalR.Client argumentBindingException: null, arguments: args); // I just want an excuse to use 'irq' as a variable name... - Log.RegisterInvocation(_logger, invocationMessage.InvocationId); + Log.RegisteringInvocation(_logger, invocationMessage.InvocationId); - AddInvocation(irq); + _connectionState.AddInvocation(irq); // Trace the full invocation - Log.IssueInvocation(_logger, invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); + Log.IssuingInvocation(_logger, invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); - // We don't need to wait for this to complete. It will signal back to the invocation request. - return SendHubMessage(invocationMessage, irq); - } - - private async Task SendHubMessage(HubInvocationMessage hubMessage, InvocationRequest irq) - { try { - var payload = _protocol.WriteToArray(hubMessage); - Log.SendInvocation(_logger, hubMessage.InvocationId); - - await _connection.SendAsync(payload, irq.CancellationToken); - Log.SendInvocationCompleted(_logger, hubMessage.InvocationId); + await SendHubMessage(invocationMessage, cancellationToken); } catch (Exception ex) { - Log.SendInvocationFailed(_logger, hubMessage.InvocationId, ex); + Log.FailedToSendInvocation(_logger, invocationMessage.InvocationId, ex); + _connectionState.TryRemoveInvocation(invocationMessage.InvocationId, out _); irq.Fail(ex); - TryRemoveInvocation(hubMessage.InvocationId, out _); } } - public async Task SendAsync(string methodName, object[] args, CancellationToken cancellationToken = default) => - await SendAsyncCore(methodName, args, cancellationToken).ForceAsync(); + private async Task SendHubMessage(HubInvocationMessage hubMessage, CancellationToken cancellationToken = default) + { + AssertConnectionValid(); + + var payload = _protocol.WriteToArray(hubMessage); + + Log.SendingMessage(_logger, hubMessage); + // REVIEW: If a token is passed in and is cancelled during FlushAsync it seems to break .Complete()... + await WriteAsync(payload, CancellationToken.None); + Log.MessageSent(_logger, hubMessage); + } private async Task SendAsyncCore(string methodName, object[] args, CancellationToken cancellationToken) { - if (!_startCalled) - { - throw new InvalidOperationException($"The '{nameof(SendAsync)}' method cannot be called before the connection has been started."); - } - - var invocationMessage = new InvocationMessage(null, target: methodName, - argumentBindingException: null, arguments: args); - - ThrowIfConnectionTerminated(invocationMessage.InvocationId); + CheckDisposed(); + await WaitConnectionLockAsync(); try { + CheckDisposed(); + CheckConnectionActive(nameof(SendAsync)); + Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length); - var payload = _protocol.WriteToArray(invocationMessage); - Log.SendInvocation(_logger, invocationMessage.InvocationId); + var invocationMessage = new InvocationMessage(null, target: methodName, + argumentBindingException: null, arguments: args); - await _connection.SendAsync(payload, cancellationToken); - Log.SendInvocationCompleted(_logger, invocationMessage.InvocationId); + await SendHubMessage(invocationMessage, cancellationToken); } - catch (Exception ex) + finally { - Log.SendInvocationFailed(_logger, invocationMessage.InvocationId, ex); - throw; + ReleaseConnectionLock(); } } - private async Task OnDataReceivedAsync(byte[] data) + private async Task<(bool close, Exception exception)> ProcessMessagesAsync(ReadOnlySequence buffer, ConnectionState connectionState) { - ResetTimeoutTimer(); + Log.ProcessingMessage(_logger, buffer.Length); + + // TODO: Don't ToArray it :) + var data = buffer.ToArray(); var currentData = new ReadOnlyMemory(data); Log.ParsingMessages(_logger, currentData.Length); - // first message received must be handshake response - if (!_receivedHandshakeResponse) - { - // process handshake and return left over data to parse additional messages - if (!ProcessHandshakeResponse(ref currentData)) - { - return; - } - - _receivedHandshakeResponse = true; - if (currentData.IsEmpty) - { - return; - } - } - var messages = new List(); - if (_protocol.TryParseMessages(currentData, _binder, messages)) + if (_protocol.TryParseMessages(currentData, connectionState, messages)) { Log.ReceivingMessages(_logger, messages.Count); foreach (var message in messages) @@ -344,38 +396,39 @@ namespace Microsoft.AspNetCore.SignalR.Client case InvocationMessage invocation: Log.ReceivedInvocation(_logger, invocation.InvocationId, invocation.Target, invocation.ArgumentBindingException != null ? null : invocation.Arguments); - await DispatchInvocationAsync(invocation, _connectionActive.Token); + await DispatchInvocationAsync(invocation); break; case CompletionMessage completion: - if (!TryRemoveInvocation(completion.InvocationId, out irq)) + if (!connectionState.TryRemoveInvocation(completion.InvocationId, out irq)) { - Log.DropCompletionMessage(_logger, completion.InvocationId); - return; + Log.DroppedCompletionMessage(_logger, completion.InvocationId); + } + else + { + DispatchInvocationCompletion(completion, irq); + irq.Dispose(); } - DispatchInvocationCompletion(completion, irq); - irq.Dispose(); break; case StreamItemMessage streamItem: // Complete the invocation with an error, we don't support streaming (yet) - if (!TryGetInvocation(streamItem.InvocationId, out irq)) + if (!connectionState.TryGetInvocation(streamItem.InvocationId, out irq)) { - Log.DropStreamMessage(_logger, streamItem.InvocationId); - return; + Log.DroppedStreamMessage(_logger, streamItem.InvocationId); + return (close: false, exception: null); } - DispatchInvocationStreamItemAsync(streamItem, irq); + await DispatchInvocationStreamItemAsync(streamItem, irq); break; case CloseMessage close: if (string.IsNullOrEmpty(close.Error)) { Log.ReceivedClose(_logger); - Shutdown(); + return (close: true, exception: null); } else { Log.ReceivedCloseWithError(_logger, close.Error); - Shutdown(new InvalidOperationException(close.Error)); + return (close: true, exception: new HubException($"The server closed the connection with the following error: {close.Error}")); } - break; case PingMessage _: Log.ReceivedPing(_logger); // Nothing to do on receipt of a ping. @@ -390,85 +443,11 @@ namespace Microsoft.AspNetCore.SignalR.Client { Log.FailedParsing(_logger, data.Length); } + + return (close: false, exception: null); } - private bool ProcessHandshakeResponse(ref ReadOnlyMemory data) - { - HandshakeResponseMessage message; - - try - { - // read first message out of the incoming data - if (!TextMessageParser.TryParseMessage(ref data, out var payload)) - { - throw new InvalidDataException("Unable to parse payload as a handshake response message."); - } - - message = HandshakeProtocol.ParseResponseMessage(payload); - } - catch (Exception ex) - { - // shutdown if we're unable to read handshake - Log.ErrorReceivingHandshakeResponse(_logger, ex); - Shutdown(ex); - return false; - } - - if (!string.IsNullOrEmpty(message.Error)) - { - // shutdown if handshake returns an error - Log.HandshakeServerError(_logger, message.Error); - Shutdown(); - return false; - } - - return true; - } - - private void Shutdown(Exception exception = null) - { - // check if connection has already been shutdown - if (_connectionActive.IsCancellationRequested) - { - return; - } - - Log.ShutdownConnection(_logger); - if (exception != null) - { - Log.ShutdownWithError(_logger, exception); - } - - lock (_pendingCallsLock) - { - // We cancel inside the lock to make sure everyone who was part-way through registering an invocation - // completes. This also ensures that nobody will add things to _pendingCalls after we leave this block - // because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock) - _connectionActive.Cancel(); - - foreach (var outstandingCall in _pendingCalls.Values) - { - Log.RemoveInvocation(_logger, outstandingCall.InvocationId); - if (exception != null) - { - outstandingCall.Fail(exception); - } - outstandingCall.Dispose(); - } - _pendingCalls.Clear(); - } - - try - { - Closed?.Invoke(exception); - } - catch (Exception ex) - { - Log.ErrorDuringClosedEvent(_logger, ex); - } - } - - private async Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken) + private async Task DispatchInvocationAsync(InvocationMessage invocation) { // Find the handler if (!_handlers.TryGetValue(invocation.Target, out var handlers)) @@ -499,9 +478,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - // This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel - // and there's nobody to actually wait for us to finish. - private async void DispatchInvocationStreamItemAsync(StreamItemMessage streamItem, InvocationRequest irq) + private async Task DispatchInvocationStreamItemAsync(StreamItemMessage streamItem, InvocationRequest irq) { Log.ReceivedStreamItem(_logger, streamItem.InvocationId); @@ -529,60 +506,272 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private void ThrowIfConnectionTerminated(string invocationId) + private void CheckDisposed() { - if (_connectionActive.Token.IsCancellationRequested) + if (_disposed) { - Log.InvokeAfterTermination(_logger, invocationId); - throw new InvalidOperationException("Connection has been terminated."); + throw new ObjectDisposedException(nameof(HubConnection)); } } - private string GetNextId() => Interlocked.Increment(ref _nextId).ToString(); - - private void AddInvocation(InvocationRequest irq) + private async Task HandshakeAsync() { - lock (_pendingCallsLock) + // Send the Handshake request + using (var memoryStream = new MemoryStream()) { - ThrowIfConnectionTerminated(irq.InvocationId); - if (_pendingCalls.ContainsKey(irq.InvocationId)) + Log.SendingHubHandshake(_logger); + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryStream); + var result = await WriteAsync(memoryStream.ToArray(), CancellationToken.None); + + if (result.IsCompleted) { - Log.InvocationAlreadyInUse(_logger, irq.InvocationId); - throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use."); + // The other side disconnected + throw new InvalidOperationException("The server disconnected before the handshake was completed"); } - else + } + + try + { + while (true) { - _pendingCalls.Add(irq.InvocationId, irq); + var result = await _connectionState.Connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.Start; + + try + { + // Read first message out of the incoming data + if (!buffer.IsEmpty && TextMessageParser.TryParseMessage(ref buffer, out var payload)) + { + // Buffer was advanced to the end of the message by TryParseMessage + consumed = buffer.Start; + var message = HandshakeProtocol.ParseResponseMessage(payload.ToArray()); + + if (!string.IsNullOrEmpty(message.Error)) + { + Log.HandshakeServerError(_logger, message.Error); + throw new HubException( + $"Unable to complete handshake with the server due to an error: {message.Error}"); + } + + break; + } + else if (result.IsCompleted) + { + // Not enough data, and we won't be getting any more data. + throw new InvalidOperationException( + "The server disconnected before sending a handshake response"); + } + } + finally + { + _connectionState.Connection.Transport.Input.AdvanceTo(consumed); + } } } + catch (Exception ex) + { + // shutdown if we're unable to read handshake + Log.ErrorReceivingHandshakeResponse(_logger, ex); + throw; + } + + Log.HandshakeComplete(_logger); + } + + private async Task ReceiveLoop(ConnectionState connectionState) + { + // We hold a local capture of the connection state because StopAsync may dump out the current one. + // We'll be locking any time we want to check back in to the "active" connection state. + + Log.ReceiveLoopStarting(_logger); + + var timeoutTimer = StartTimeoutTimer(connectionState); + + try + { + while (true) + { + var result = await connectionState.Connection.Transport.Input.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.End; // TODO: Support partial messages + var examined = buffer.End; + + try + { + if (result.IsCanceled) + { + // We were cancelled. Possibly because we were stopped gracefully + break; + } + else if (!buffer.IsEmpty) + { + ResetTimeoutTimer(timeoutTimer); + + // We have data, process it + var (close, exception) = await ProcessMessagesAsync(buffer, connectionState); + if (close) + { + // Closing because we got a close frame, possibly with an error in it. + connectionState.CloseException = exception; + break; + } + } + else if (result.IsCompleted) + { + break; + } + } + finally + { + connectionState.Connection.Transport.Input.AdvanceTo(consumed, examined); + } + } + } + catch (Exception ex) + { + Log.ServerDisconnectedWithError(_logger, ex); + connectionState.CloseException = ex; + } + + // Clear the connectionState field + await WaitConnectionLockAsync(); + try + { + SafeAssert(ReferenceEquals(_connectionState, connectionState), + "Someone other than ReceiveLoop cleared the connection state!"); + _connectionState = null; + } + finally + { + ReleaseConnectionLock(); + } + + // Stop the timeout timer. + timeoutTimer?.Dispose(); + + // Dispose the connection + await connectionState.Connection.DisposeAsync(); + + // Cancel any outstanding invocations within the connection lock + connectionState.CancelOutstandingInvocations(connectionState.CloseException); + + if (connectionState.CloseException != null) + { + Log.ShutdownWithError(_logger, connectionState.CloseException); + } + else + { + Log.ShutdownConnection(_logger); + } + + // Fire-and-forget the closed event + RunClosedEvent(connectionState.CloseException); + } + + private void RunClosedEvent(Exception closeException) + { + _ = Task.Run(() => + { + try + { + Log.InvokingClosedEventHandler(_logger); + Closed?.Invoke(closeException); + } + catch (Exception ex) + { + Log.ErrorDuringClosedEvent(_logger, ex); + } + }); + } + + private void ResetTimeoutTimer(Timer timeoutTimer) + { + if (timeoutTimer != null) + { + Log.ResettingKeepAliveTimer(_logger); + timeoutTimer.Change(ServerTimeout, Timeout.InfiniteTimeSpan); + } } - private bool TryGetInvocation(string invocationId, out InvocationRequest irq) + private Timer StartTimeoutTimer(ConnectionState connectionState) { - lock (_pendingCallsLock) + // Check if we need keep-alive + Timer timeoutTimer = null; + if (connectionState.Connection.Features.Get() == null) { - ThrowIfConnectionTerminated(invocationId); - return _pendingCalls.TryGetValue(invocationId, out irq); + Log.StartingServerTimeoutTimer(_logger, ServerTimeout); + timeoutTimer = new Timer( + state => OnTimeout((ConnectionState)state), + connectionState, + dueTime: ServerTimeout, + period: Timeout.InfiniteTimeSpan); + } + else + { + Log.NotUsingServerTimeout(_logger); + } + + return timeoutTimer; + } + + private void OnTimeout(ConnectionState connectionState) + { + if (!Debugger.IsAttached) + { + connectionState.CloseException = new TimeoutException( + $"Server timeout ({ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server."); + connectionState.Connection.Transport.Input.CancelPendingRead(); } } - private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq) + private ValueTask WriteAsync(byte[] payload, CancellationToken cancellationToken = default) { - lock (_pendingCallsLock) + AssertConnectionValid(); + return _connectionState.Connection.Transport.Output.WriteAsync(payload, cancellationToken); + } + + private void CheckConnectionActive(string methodName) + { + if (_connectionState == null || _connectionState.Stopping) { - ThrowIfConnectionTerminated(invocationId); - if (_pendingCalls.TryGetValue(invocationId, out irq)) - { - _pendingCalls.Remove(invocationId); - return true; - } - else - { - return false; - } + throw new InvalidOperationException($"The '{methodName}' method cannot be called if the connection is not active"); } } + // Debug.Assert plays havoc with Unit Tests. But I want something that I can "assert" only in Debug builds. + [Conditional("DEBUG")] + private static void SafeAssert(bool condition, string message, [CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) + { + if (!condition) + { + throw new Exception($"Assertion failed in {memberName}, at {fileName}:{lineNumber}: {message}"); + } + } + + [Conditional("DEBUG")] + private void AssertInConnectionLock([CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) => SafeAssert(_connectionLock.CurrentCount == 0, "We're not in the Connection Lock!", memberName, fileName, lineNumber); + + [Conditional("DEBUG")] + private void AssertConnectionValid([CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) + { + AssertInConnectionLock(memberName, fileName, lineNumber); + SafeAssert(_connectionState != null, "We don't have a connection!", memberName, fileName, lineNumber); + } + + private Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + { + Log.WaitingOnConnectionLock(_logger, memberName, filePath, lineNumber); + return _connectionLock.WaitAsync(); + } + + private void ReleaseConnectionLock([CallerMemberName] string memberName = null, + [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + { + Log.ReleasingConnectionLock(_logger, memberName, filePath, lineNumber); + _connectionLock.Release(); + } + private class Subscription : IDisposable { private readonly InvocationHandler _handler; @@ -603,45 +792,6 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private class HubBinder : IInvocationBinder - { - private HubConnection _connection; - - public HubBinder(HubConnection connection) - { - _connection = connection; - } - - public Type GetReturnType(string invocationId) - { - if (!_connection._pendingCalls.TryGetValue(invocationId, out var irq)) - { - Log.ReceivedUnexpectedResponse(_connection._logger, invocationId); - return null; - } - return irq.ResultType; - } - - public IReadOnlyList GetParameterTypes(string methodName) - { - if (!_connection._handlers.TryGetValue(methodName, out var handlers)) - { - Log.MissingHandler(_connection._logger, methodName); - return Type.EmptyTypes; - } - - // We use the parameter types of the first handler - lock (handlers) - { - if (handlers.Count > 0) - { - return handlers[0].ParameterTypes; - } - throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'"); - } - } - } - private struct InvocationHandler { public Type[] ParameterTypes { get; } @@ -660,5 +810,158 @@ namespace Microsoft.AspNetCore.SignalR.Client return _callback(parameters, _state); } } + + // Represents all the transient state about a connection + // This includes binding information because return type binding depends upon _pendingCalls + private class ConnectionState : IInvocationBinder + { + private volatile bool _stopping; + private readonly HubConnection _hubConnection; + + private TaskCompletionSource _stopTcs; + private readonly object _lock = new object(); + private readonly Dictionary _pendingCalls = new Dictionary(); + private int _nextId; + + public IConnection Connection { get; } + public Task ReceiveTask { get; set; } + public Exception CloseException { get; set; } + + public bool Stopping + { + get => _stopping; + set => _stopping = value; + } + + public ConnectionState(IConnection connection, HubConnection hubConnection) + { + _hubConnection = hubConnection; + Connection = connection; + } + + public string GetNextId() => Interlocked.Increment(ref _nextId).ToString(); + + public void AddInvocation(InvocationRequest irq) + { + lock (_lock) + { + if (_pendingCalls.ContainsKey(irq.InvocationId)) + { + Log.InvocationAlreadyInUse(_hubConnection._logger, irq.InvocationId); + throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use."); + } + else + { + _pendingCalls.Add(irq.InvocationId, irq); + } + } + } + + public bool TryGetInvocation(string invocationId, out InvocationRequest irq) + { + lock (_lock) + { + return _pendingCalls.TryGetValue(invocationId, out irq); + } + } + + public bool TryRemoveInvocation(string invocationId, out InvocationRequest irq) + { + lock (_lock) + { + if (_pendingCalls.TryGetValue(invocationId, out irq)) + { + _pendingCalls.Remove(invocationId); + return true; + } + else + { + return false; + } + } + } + + public void CancelOutstandingInvocations(Exception exception) + { + Log.CancelingOutstandingInvocations(_hubConnection._logger); + + lock (_lock) + { + foreach (var outstandingCall in _pendingCalls.Values) + { + Log.RemovingInvocation(_hubConnection._logger, outstandingCall.InvocationId); + if (exception != null) + { + outstandingCall.Fail(exception); + } + outstandingCall.Dispose(); + } + _pendingCalls.Clear(); + } + } + + public Task StopAsync(TimeSpan timeout) + { + // We want multiple StopAsync calls on the same connection state + // to wait for the same "stop" to complete. + lock (_lock) + { + if (_stopTcs != null) + { + return _stopTcs.Task; + } + else + { + _stopTcs = new TaskCompletionSource(); + return StopAsyncCore(timeout); + } + } + } + + private async Task StopAsyncCore(TimeSpan timeout) + { + Log.Stopping(_hubConnection._logger); + + // Complete our write pipe, which should cause everything to shut down + Log.TerminatingReceiveLoop(_hubConnection._logger); + Connection.Transport.Input.CancelPendingRead(); + + // Wait ServerTimeout for the server or transport to shut down. + Log.WaitingForReceiveLoopToTerminate(_hubConnection._logger); + await ReceiveTask; + + Log.Stopped(_hubConnection._logger); + _stopTcs.TrySetResult(null); + } + + Type IInvocationBinder.GetReturnType(string invocationId) + { + if (!TryGetInvocation(invocationId, out var irq)) + { + Log.ReceivedUnexpectedResponse(_hubConnection._logger, invocationId); + return null; + } + return irq.ResultType; + } + + IReadOnlyList IInvocationBinder.GetParameterTypes(string methodName) + { + if (!_hubConnection._handlers.TryGetValue(methodName, out var handlers)) + { + Log.MissingHandler(_hubConnection._logger, methodName); + return Type.EmptyTypes; + } + + // We use the parameter types of the first handler + lock (handlers) + { + if (handlers.Count > 0) + { + return handlers[0].ParameterTypes; + } + throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'"); + } + } + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs index 9be6b2993b..b1cdf7d7e1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionBuilder.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -43,12 +43,11 @@ namespace Microsoft.AspNetCore.SignalR.Client } IHubConnectionBuilder builder = this; - var connection = _connectionFactoryDelegate(); var loggerFactory = builder.GetLoggerFactory(); var hubProtocol = builder.GetHubProtocol(); - return new HubConnection(connection, hubProtocol ?? new JsonHubProtocol(), loggerFactory); + return new HubConnection(_connectionFactoryDelegate, hubProtocol ?? new JsonHubProtocol(), loggerFactory); } [EditorBrowsable(EditorBrowsableState.Never)] diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs index 1e935cc075..da39ebe2e1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/InvocationRequest.cs @@ -121,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Client protected override void Cancel() { - _channel.Writer.TryComplete(new OperationCanceledException("Invocation terminated")); + _channel.Writer.TryComplete(new OperationCanceledException()); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Properties/AssemblyInfo.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/Properties/AssemblyInfo.cs new file mode 100644 index 0000000000..8bc7094d90 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Properties/AssemblyInfo.cs @@ -0,0 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Runtime.CompilerServices; + +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Client.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs index 1cdfe687a6..b7512d827c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Formatters/TextMessageParser.cs @@ -1,12 +1,30 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; namespace Microsoft.AspNetCore.SignalR.Internal.Formatters { public static class TextMessageParser { + public static bool TryParseMessage(ref ReadOnlySequence buffer, out ReadOnlySequence payload) + { + var position = buffer.PositionOf(TextMessageFormatter.RecordSeparator); + if (position == null) + { + payload = default; + return false; + } + + payload = buffer.Slice(0, position.Value); + + // Skip record separator + buffer = buffer.Slice(buffer.GetPosition(1, position.Value)); + + return true; + } + public static bool TryParseMessage(ref ReadOnlyMemory buffer, out ReadOnlyMemory payload) { var index = buffer.Span.IndexOf(TextMessageFormatter.RecordSeparator); diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs index b5605c96f8..44142424c8 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; @@ -11,16 +12,11 @@ namespace Microsoft.AspNetCore.Sockets.Client { public interface IConnection { - Task StartAsync(TransferFormat transferFormat); - Task SendAsync(byte[] data, CancellationToken cancellationToken); - Task StopAsync(); - Task DisposeAsync(); - Task AbortAsync(Exception ex); - - IDisposable OnReceived(Func callback, object state); - - event Action Closed; - + IDuplexPipe Transport { get; } IFeatureCollection Features { get; } + + Task StartAsync(); + Task StartAsync(TransferFormat transferFormat); + Task DisposeAsync(); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs index 2f1f93c8da..e91a9e965b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/DefaultTransportFactory.cs @@ -4,6 +4,7 @@ using System; using System.Net.Http; using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Sockets.Client diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.Log.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.Log.cs index bb6433c34a..9e0d56da7c 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.Log.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.Log.cs @@ -5,110 +5,91 @@ using System; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Http { public partial class HttpConnection { private static class Log { - private static readonly Action _httpConnectionStarting = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "HttpConnectionStarting"), "Starting connection."); + private static readonly Action _starting = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "Starting"), "Starting HttpConnection."); - private static readonly Action _httpConnectionClosed = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "HttpConnectionClosed"), "Connection was closed from a different thread."); + private static readonly Action _skippingStart = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "SkippingStart"), "Skipping start, connection is already started."); - private static readonly Action _startingTransport = - LoggerMessage.Define(LogLevel.Debug, new EventId(3, "StartingTransport"), "Starting transport '{Transport}' with Url: {Url}."); + private static readonly Action _started = + LoggerMessage.Define(LogLevel.Information, new EventId(3, "Started"), "HttpConnection Started."); - private static readonly Action _processRemainingMessages = - LoggerMessage.Define(LogLevel.Debug, new EventId(4, "ProcessRemainingMessages"), "Ensuring all outstanding messages are processed."); - - private static readonly Action _drainEvents = - LoggerMessage.Define(LogLevel.Debug, new EventId(5, "DrainEvents"), "Draining event queue."); - - private static readonly Action _completeClosed = - LoggerMessage.Define(LogLevel.Debug, new EventId(6, "CompleteClosed"), "Completing Closed task."); - - private static readonly Action _establishingConnection = - LoggerMessage.Define(LogLevel.Debug, new EventId(7, "EstablishingConnection"), "Establishing Connection at: {Url}."); - - private static readonly Action _errorWithNegotiation = - LoggerMessage.Define(LogLevel.Error, new EventId(8, "ErrorWithNegotiation"), "Failed to start connection. Error getting negotiation response from '{Url}'."); - - private static readonly Action _errorStartingTransport = - LoggerMessage.Define(LogLevel.Error, new EventId(9, "ErrorStartingTransport"), "Failed to start connection. Error starting transport '{Transport}'."); - - private static readonly Action _httpReceiveStarted = - LoggerMessage.Define(LogLevel.Trace, new EventId(10, "HttpReceiveStarted"), "Beginning receive loop."); - - private static readonly Action _skipRaisingReceiveEvent = - LoggerMessage.Define(LogLevel.Debug, new EventId(11, "SkipRaisingReceiveEvent"), "Message received but connection is not connected. Skipping raising Received event."); - - private static readonly Action _scheduleReceiveEvent = - LoggerMessage.Define(LogLevel.Debug, new EventId(12, "ScheduleReceiveEvent"), "Scheduling raising Received event."); - - private static readonly Action _raiseReceiveEvent = - LoggerMessage.Define(LogLevel.Debug, new EventId(13, "RaiseReceiveEvent"), "Raising Received event."); - - private static readonly Action _failedReadingMessage = - LoggerMessage.Define(LogLevel.Debug, new EventId(14, "FailedReadingMessage"), "Could not read message."); - - private static readonly Action _errorReceiving = - LoggerMessage.Define(LogLevel.Error, new EventId(15, "ErrorReceiving"), "Error receiving message."); - - private static readonly Action _endReceive = - LoggerMessage.Define(LogLevel.Trace, new EventId(16, "EndReceive"), "Ending receive loop."); - - private static readonly Action _sendingMessage = - LoggerMessage.Define(LogLevel.Debug, new EventId(17, "SendingMessage"), "Sending message."); - - private static readonly Action _stoppingClient = - LoggerMessage.Define(LogLevel.Information, new EventId(18, "StoppingClient"), "Stopping client."); - - private static readonly Action _exceptionThrownFromCallback = - LoggerMessage.Define(LogLevel.Error, new EventId(19, "ExceptionThrownFromCallback"), "An exception was thrown from the '{Callback}' callback."); - - private static readonly Action _disposingClient = - LoggerMessage.Define(LogLevel.Information, new EventId(20, "DisposingClient"), "Disposing client."); - - private static readonly Action _abortingClient = - LoggerMessage.Define(LogLevel.Error, new EventId(21, "AbortingClient"), "Aborting client."); - - private static readonly Action _errorDuringClosedEvent = - LoggerMessage.Define(LogLevel.Error, new EventId(22, "ErrorDuringClosedEvent"), "An exception was thrown in the handler for the Closed event."); - - private static readonly Action _skippingStop = - LoggerMessage.Define(LogLevel.Debug, new EventId(23, "SkippingStop"), "Skipping stop, connection is already stopped."); + private static readonly Action _disposingHttpConnection = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, "DisposingHttpConnection"), "Disposing HttpConnection."); private static readonly Action _skippingDispose = - LoggerMessage.Define(LogLevel.Debug, new EventId(24, "SkippingDispose"), "Skipping dispose, connection is already disposed."); + LoggerMessage.Define(LogLevel.Debug, new EventId(5, "SkippingDispose"), "Skipping dispose, connection is already disposed."); - private static readonly Action _connectionStateChanged = - LoggerMessage.Define(LogLevel.Debug, new EventId(25, "ConnectionStateChanged"), "Connection state changed from {PreviousState} to {NewState}."); + private static readonly Action _disposed = + LoggerMessage.Define(LogLevel.Information, new EventId(6, "Disposed"), "HttpConnection Disposed."); + + private static readonly Action _startingTransport = + LoggerMessage.Define(LogLevel.Debug, new EventId(7, "StartingTransport"), "Starting transport '{Transport}' with Url: {Url}."); + + private static readonly Action _establishingConnection = + LoggerMessage.Define(LogLevel.Debug, new EventId(8, "EstablishingConnection"), "Establishing connection with server at '{Url}'."); + + private static readonly Action _connectionEstablished = + LoggerMessage.Define(LogLevel.Debug, new EventId(9, "Established"), "Established connection '{ConnectionId}' with the server."); + + private static readonly Action _errorWithNegotiation = + LoggerMessage.Define(LogLevel.Error, new EventId(10, "ErrorWithNegotiation"), "Failed to start connection. Error getting negotiation response from '{Url}'."); + + private static readonly Action _errorStartingTransport = + LoggerMessage.Define(LogLevel.Error, new EventId(11, "ErrorStartingTransport"), "Failed to start connection. Error starting transport '{Transport}'."); private static readonly Action _transportNotSupported = - LoggerMessage.Define(LogLevel.Debug, new EventId(26, "TransportNotSupported"), "Skipping transport {TransportName} because it is not supported by this client."); + LoggerMessage.Define(LogLevel.Debug, new EventId(12, "TransportNotSupported"), "Skipping transport {TransportName} because it is not supported by this client."); private static readonly Action _transportDoesNotSupportTransferFormat = - LoggerMessage.Define(LogLevel.Debug, new EventId(27, "TransportDoesNotSupportTransferFormat"), "Skipping transport {TransportName} because it does not support the requested transfer format '{TransferFormat}'."); + LoggerMessage.Define(LogLevel.Debug, new EventId(13, "TransportDoesNotSupportTransferFormat"), "Skipping transport {TransportName} because it does not support the requested transfer format '{TransferFormat}'."); private static readonly Action _transportDisabledByClient = - LoggerMessage.Define(LogLevel.Debug, new EventId(28, "TransportDisabledByClient"), "Skipping transport {TransportName} because it was disabled by the client."); + LoggerMessage.Define(LogLevel.Debug, new EventId(14, "TransportDisabledByClient"), "Skipping transport {TransportName} because it was disabled by the client."); private static readonly Action _transportFailed = - LoggerMessage.Define(LogLevel.Debug, new EventId(29, "TransportFailed"), "Skipping transport {TransportName} because it failed to initialize."); + LoggerMessage.Define(LogLevel.Debug, new EventId(15, "TransportFailed"), "Skipping transport {TransportName} because it failed to initialize."); private static readonly Action _webSocketsNotSupportedByOperatingSystem = - LoggerMessage.Define(LogLevel.Debug, new EventId(30, "WebSocketsNotSupportedByOperatingSystem"), "Skipping WebSockets because they are not supported by the operating system."); + LoggerMessage.Define(LogLevel.Debug, new EventId(16, "WebSocketsNotSupportedByOperatingSystem"), "Skipping WebSockets because they are not supported by the operating system."); - public static void HttpConnectionStarting(ILogger logger) + private static readonly Action _transportThrewExceptionOnStop = + LoggerMessage.Define(LogLevel.Error, new EventId(17, "TransportThrewExceptionOnStop"), "The transport threw an exception while stopping."); + + public static void Starting(ILogger logger) { - _httpConnectionStarting(logger, null); + _starting(logger, null); } - public static void HttpConnectionClosed(ILogger logger) + public static void SkippingStart(ILogger logger) { - _httpConnectionClosed(logger, null); + _skippingStart(logger, null); + } + + public static void Started(ILogger logger) + { + _started(logger, null); + } + + public static void DisposingHttpConnection(ILogger logger) + { + _disposingHttpConnection(logger, null); + } + + public static void SkippingDispose(ILogger logger) + { + _skippingDispose(logger, null); + } + + public static void Disposed(ILogger logger) + { + _disposed(logger, null); } public static void StartingTransport(ILogger logger, TransportType transportType, Uri url) @@ -119,26 +100,16 @@ namespace Microsoft.AspNetCore.Sockets.Client } } - public static void ProcessRemainingMessages(ILogger logger) - { - _processRemainingMessages(logger, null); - } - - public static void DrainEvents(ILogger logger) - { - _drainEvents(logger, null); - } - - public static void CompleteClosed(ILogger logger) - { - _completeClosed(logger, null); - } - public static void EstablishingConnection(ILogger logger, Uri url) { _establishingConnection(logger, url, null); } + public static void ConnectionEstablished(ILogger logger, string connectionId) + { + _connectionEstablished(logger, connectionId, null); + } + public static void ErrorWithNegotiation(ILogger logger, Uri url, Exception exception) { _errorWithNegotiation(logger, url, exception); @@ -152,89 +123,6 @@ namespace Microsoft.AspNetCore.Sockets.Client } } - public static void HttpReceiveStarted(ILogger logger) - { - _httpReceiveStarted(logger, null); - } - - public static void SkipRaisingReceiveEvent(ILogger logger) - { - _skipRaisingReceiveEvent(logger, null); - } - - public static void ScheduleReceiveEvent(ILogger logger) - { - _scheduleReceiveEvent(logger, null); - } - - public static void RaiseReceiveEvent(ILogger logger) - { - _raiseReceiveEvent(logger, null); - } - - public static void FailedReadingMessage(ILogger logger) - { - _failedReadingMessage(logger, null); - } - - public static void ErrorReceiving(ILogger logger, Exception exception) - { - _errorReceiving(logger, exception); - } - - public static void EndReceive(ILogger logger) - { - _endReceive(logger, null); - } - - public static void SendingMessage(ILogger logger) - { - _sendingMessage(logger, null); - } - - public static void AbortingClient(ILogger logger, Exception ex) - { - _abortingClient(logger, ex); - } - - public static void StoppingClient(ILogger logger) - { - _stoppingClient(logger, null); - } - - public static void DisposingClient(ILogger logger) - { - _disposingClient(logger, null); - } - - public static void SkippingDispose(ILogger logger) - { - _skippingDispose(logger, null); - } - - public static void ConnectionStateChanged(ILogger logger, HttpConnection.ConnectionState previousState, HttpConnection.ConnectionState newState) - { - if (logger.IsEnabled(LogLevel.Debug)) - { - _connectionStateChanged(logger, previousState.ToString(), newState.ToString(), null); - } - } - - public static void SkippingStop(ILogger logger) - { - _skippingStop(logger, null); - } - - public static void ExceptionThrownFromCallback(ILogger logger, string callbackName, Exception exception) - { - _exceptionThrownFromCallback(logger, callbackName, exception); - } - - public static void ErrorDuringClosedEvent(ILogger logger, Exception exception) - { - _errorDuringClosedEvent(logger, exception); - } - public static void TransportNotSupported(ILogger logger, string transport) { _transportNotSupported(logger, transport, null); @@ -268,6 +156,11 @@ namespace Microsoft.AspNetCore.Sockets.Client { _webSocketsNotSupportedByOperatingSystem(logger, null); } + + public static void TransportThrewExceptionOnStop(ILogger logger, Exception ex) + { + _transportThrewExceptionOnStop(logger, ex); + } } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 80d60342d2..4a96c840da 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -2,8 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; -using System.Collections.Generic; using System.IO; using System.IO.Pipelines; using System.Linq; @@ -11,18 +9,16 @@ using System.Net.Http; using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Client.Http.Internal; -using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Sockets.Http.Internal; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Newtonsoft.Json; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Http { public partial class HttpConnection : IConnection { @@ -31,36 +27,40 @@ namespace Microsoft.AspNetCore.Sockets.Client private static readonly Version Windows8Version = new Version(6, 2); #endif - private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; - private volatile ConnectionState _connectionState = ConnectionState.Disconnected; - private readonly object _stateChangeLock = new object(); + private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1); + private bool _started; + private bool _disposed; + + private IDuplexPipe _transportPipe; - private volatile IDuplexPipe _transportChannel; private readonly HttpClient _httpClient; private readonly HttpOptions _httpOptions; - private volatile ITransport _transport; - private volatile Task _receiveLoopTask; - private TaskCompletionSource _startTcs; - private TaskCompletionSource _closeTcs; - private TaskQueue _eventQueue; + private ITransport _transport; private readonly ITransportFactory _transportFactory; private string _connectionId; - private Exception _abortException; - private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5); - private PipeReader Input => _transportChannel.Input; - private PipeWriter Output => _transportChannel.Output; - private readonly List _callbacks = new List(); private readonly TransportType _requestedTransportType = TransportType.All; private readonly ConnectionLogScope _logScope; private readonly IDisposable _scopeDisposable; + private readonly ILoggerFactory _loggerFactory; public Uri Url { get; } - public IFeatureCollection Features { get; } = new FeatureCollection(); + public IDuplexPipe Transport + { + get + { + CheckDisposed(); + if (_transportPipe == null) + { + throw new InvalidOperationException($"Cannot access the {nameof(Transport)} pipe before the connection has started."); + } + return _transportPipe; + } + } - public event Action Closed; + public IFeatureCollection Features { get; } = new FeatureCollection(); public HttpConnection(Uri url) : this(url, TransportType.All) @@ -84,8 +84,8 @@ namespace Microsoft.AspNetCore.Sockets.Client public HttpConnection(Uri url, TransportType transportType, ILoggerFactory loggerFactory, HttpOptions httpOptions) { Url = url ?? throw new ArgumentNullException(nameof(url)); - _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + _logger = _loggerFactory.CreateLogger(); _httpOptions = httpOptions; @@ -100,6 +100,277 @@ namespace Microsoft.AspNetCore.Sockets.Client _scopeDisposable = _logger.BeginScope(_logScope); } + public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactory loggerFactory, HttpOptions httpOptions) + { + Url = url ?? throw new ArgumentNullException(nameof(url)); + _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + _logger = _loggerFactory.CreateLogger(); + _httpOptions = httpOptions; + _httpClient = CreateHttpClient(); + _transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory)); + _logScope = new ConnectionLogScope(); + _scopeDisposable = _logger.BeginScope(_logScope); + } + + public Task StartAsync() => StartAsync(TransferFormat.Binary); + + public async Task StartAsync(TransferFormat transferFormat) + { + await StartAsyncCore(transferFormat).ForceAsync(); + } + + private async Task StartAsyncCore(TransferFormat transferFormat) + { + CheckDisposed(); + + if (_started) + { + Log.SkippingStart(_logger); + return; + } + + await _connectionLock.WaitAsync(); + try + { + CheckDisposed(); + + if (_started) + { + Log.SkippingStart(_logger); + return; + } + + Log.Starting(_logger); + + await SelectAndStartTransport(transferFormat); + + _started = true; + Log.Started(_logger); + } + finally + { + _connectionLock.Release(); + } + } + + public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); + + private async Task DisposeAsyncCore(Exception exception = null) + { + if (_disposed) + { + return; + } + + await _connectionLock.WaitAsync(); + try + { + if (!_disposed && _started) + { + Log.DisposingHttpConnection(_logger); + + // Complete our ends of the pipes. + _transportPipe.Input.Complete(exception); + _transportPipe.Output.Complete(exception); + + // Stop the transport, but we don't care if it throws. + // The transport should also have completed the pipe with this exception. + try + { + await _transport.StopAsync(); + } + catch (Exception ex) + { + Log.TransportThrewExceptionOnStop(_logger, ex); + } + + Log.Disposed(_logger); + } + else + { + Log.SkippingDispose(_logger); + } + } + finally + { + // We want to do these things even if the WaitForWriterToComplete/WaitForReaderToComplete fails + if (!_disposed) + { + _scopeDisposable.Dispose(); + _disposed = true; + } + + _connectionLock.Release(); + } + } + + private async Task SelectAndStartTransport(TransferFormat transferFormat) + { + if (_requestedTransportType == TransportType.WebSockets) + { + Log.StartingTransport(_logger, _requestedTransportType, Url); + await StartTransport(Url, _requestedTransportType, transferFormat); + } + else + { + var negotiationResponse = await GetNegotiationResponse(); + + // This should only need to happen once + var connectUrl = CreateConnectUrl(Url, negotiationResponse.ConnectionId); + + // We're going to search for the transfer format as a string because we don't want to parse + // all the transfer formats in the negotiation response, and we want to allow transfer formats + // we don't understand in the negotiate response. + var transferFormatString = transferFormat.ToString(); + + foreach (var transport in negotiationResponse.AvailableTransports) + { + if (!Enum.TryParse(transport.Transport, out var transportType)) + { + Log.TransportNotSupported(_logger, transport.Transport); + continue; + } + + if (transportType == TransportType.WebSockets && !IsWebSocketsSupported()) + { + Log.WebSocketsNotSupportedByOperatingSystem(_logger); + continue; + } + + try + { + if ((transportType & _requestedTransportType) == 0) + { + Log.TransportDisabledByClient(_logger, transportType); + } + else if (!transport.TransferFormats.Contains(transferFormatString, StringComparer.Ordinal)) + { + Log.TransportDoesNotSupportTransferFormat(_logger, transportType, transferFormat); + } + else + { + // The negotiation response gets cleared in the fallback scenario. + if (negotiationResponse == null) + { + negotiationResponse = await GetNegotiationResponse(); + connectUrl = CreateConnectUrl(Url, negotiationResponse.ConnectionId); + } + + Log.StartingTransport(_logger, transportType, connectUrl); + await StartTransport(connectUrl, transportType, transferFormat); + break; + } + } + catch (Exception ex) + { + Log.TransportFailed(_logger, transportType, ex); + // Try the next transport + // Clear the negotiation response so we know to re-negotiate. + negotiationResponse = null; + } + } + } + + if (_transport == null) + { + throw new InvalidOperationException("Unable to connect to the server with any of the available transports."); + } + } + + private async Task Negotiate(Uri url, HttpClient httpClient, ILogger logger) + { + try + { + // Get a connection ID from the server + Log.EstablishingConnection(logger, url); + var urlBuilder = new UriBuilder(url); + if (!urlBuilder.Path.EndsWith("/")) + { + urlBuilder.Path += "/"; + } + urlBuilder.Path += "negotiate"; + + using (var request = new HttpRequestMessage(HttpMethod.Post, urlBuilder.Uri)) + { + // Corefx changed the default version and High Sierra curlhandler tries to upgrade request + request.Version = new Version(1, 1); + SendUtils.PrepareHttpRequest(request, _httpOptions); + + using (var response = await httpClient.SendAsync(request)) + { + response.EnsureSuccessStatusCode(); + var negotiateResponse = await ParseNegotiateResponse(response); + Log.ConnectionEstablished(_logger, negotiateResponse.ConnectionId); + return negotiateResponse; + } + } + } + catch (Exception ex) + { + Log.ErrorWithNegotiation(logger, url, ex); + throw; + } + } + + private static async Task ParseNegotiateResponse(HttpResponseMessage response) + { + NegotiationResponse negotiationResponse; + using (var reader = new JsonTextReader(new StreamReader(await response.Content.ReadAsStreamAsync()))) + { + try + { + negotiationResponse = new JsonSerializer().Deserialize(reader); + } + catch (Exception ex) + { + throw new FormatException("Invalid negotiation response received.", ex); + } + } + + if (negotiationResponse == null) + { + throw new FormatException("Invalid negotiation response received."); + } + + return negotiationResponse; + } + + private static Uri CreateConnectUrl(Uri url, string connectionId) + { + if (string.IsNullOrWhiteSpace(connectionId)) + { + throw new FormatException("Invalid connection id."); + } + + return Utils.AppendQueryString(url, "id=" + connectionId); + } + + private async Task StartTransport(Uri connectUrl, TransportType transportType, TransferFormat transferFormat) + { + // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) + var options = new PipeOptions(writerScheduler: PipeScheduler.ThreadPool, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + + // Construct the transport + var transport = _transportFactory.CreateTransport(transportType); + + // Start the transport, giving it one end of the pipe + try + { + await transport.StartAsync(connectUrl, pair.Application, transferFormat, this); + } + catch (Exception ex) + { + Log.ErrorStartingTransport(_logger, _transport, ex); + _transport = null; + throw; + } + + // We successfully started, set the transport properties (we don't want to set these until the transport is definitely running). + _transport = transport; + _transportPipe = pair.Transport; + } + private HttpClient CreateHttpClient() { var httpClientHandler = new HttpClientHandler(); @@ -148,582 +419,11 @@ namespace Microsoft.AspNetCore.Sockets.Client return httpClient; } - public HttpConnection(Uri url, ITransportFactory transportFactory, ILoggerFactory loggerFactory, HttpOptions httpOptions) + private void CheckDisposed() { - Url = url ?? throw new ArgumentNullException(nameof(url)); - _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; - _logger = _loggerFactory.CreateLogger(); - _httpOptions = httpOptions; - _httpClient = CreateHttpClient(); - _transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory)); - _logScope = new ConnectionLogScope(); - _scopeDisposable = _logger.BeginScope(_logScope); - } - - public Task StartAsync() => StartAsync(TransferFormat.Binary); - public async Task StartAsync(TransferFormat transferFormat) => await StartAsyncCore(transferFormat).ForceAsync(); - - private Task StartAsyncCore(TransferFormat transferFormat) - { - if (ChangeState(from: ConnectionState.Disconnected, to: ConnectionState.Connecting) != ConnectionState.Disconnected) + if (_disposed) { - return Task.FromException( - new InvalidOperationException($"Cannot start a connection that is not in the {nameof(ConnectionState.Disconnected)} state.")); - } - - _startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - _eventQueue = new TaskQueue(); - - StartAsyncInternal(transferFormat) - .ContinueWith(t => - { - var abortException = _abortException; - if (t.IsFaulted || abortException != null) - { - _startTcs.SetException(_abortException ?? t.Exception.InnerException); - } - else if (t.IsCanceled) - { - _startTcs.SetCanceled(); - } - else - { - _startTcs.SetResult(null); - } - }); - - return _startTcs.Task; - } - - private async Task GetNegotiationResponse() - { - var negotiationResponse = await Negotiate(Url, _httpClient, _logger); - _connectionId = negotiationResponse.ConnectionId; - _logScope.ConnectionId = _connectionId; - return negotiationResponse; - } - - private async Task StartAsyncInternal(TransferFormat transferFormat) - { - Log.HttpConnectionStarting(_logger); - - try - { - var connectUrl = Url; - if (_requestedTransportType == TransportType.WebSockets) - { - // if we're running on Windows 7 this could throw because the OS does not support web sockets - Log.StartingTransport(_logger, _requestedTransportType, connectUrl); - await StartTransport(connectUrl, _requestedTransportType, transferFormat); - } - else - { - var negotiationResponse = await GetNegotiationResponse(); - - // Connection is being disposed while start was in progress - if (_connectionState == ConnectionState.Disposed) - { - Log.HttpConnectionClosed(_logger); - return; - } - - // This should only need to happen once - connectUrl = CreateConnectUrl(Url, negotiationResponse.ConnectionId); - - // We're going to search for the transfer format as a string because we don't want to parse - // all the transfer formats in the negotiation response, and we want to allow transfer formats - // we don't understand in the negotiate response. - var transferFormatString = transferFormat.ToString(); - - foreach (var transport in negotiationResponse.AvailableTransports) - { - if (!Enum.TryParse(transport.Transport, out var transportType)) - { - Log.TransportNotSupported(_logger, transport.Transport); - continue; - } - - if (transportType == TransportType.WebSockets && !IsWebSocketsSupported()) - { - Log.WebSocketsNotSupportedByOperatingSystem(_logger); - continue; - } - - try - { - if ((transportType & _requestedTransportType) == 0) - { - Log.TransportDisabledByClient(_logger, transportType); - } - else if (!transport.TransferFormats.Contains(transferFormatString, StringComparer.Ordinal)) - { - Log.TransportDoesNotSupportTransferFormat(_logger, transportType, transferFormat); - } - else - { - // The negotiation response gets cleared in the fallback scenario. - if (negotiationResponse == null) - { - negotiationResponse = await GetNegotiationResponse(); - connectUrl = CreateConnectUrl(Url, negotiationResponse.ConnectionId); - } - - Log.StartingTransport(_logger, transportType, connectUrl); - await StartTransport(connectUrl, transportType, transferFormat); - break; - } - } - catch (Exception ex) - { - Log.TransportFailed(_logger, transportType, ex); - // Try the next transport - // Clear the negotiation response so we know to re-negotiate. - negotiationResponse = null; - } - } - } - - if (_transport == null) - { - throw new InvalidOperationException("Unable to connect to the server with any of the available transports."); - } - } - - catch - { - // The connection can now be either in the Connecting or Disposed state - only change the state to - // Disconnected if the connection was in the Connecting state to not resurrect a Disposed connection - ChangeState(from: ConnectionState.Connecting, to: ConnectionState.Disconnected); - throw; - } - - // if the connection is not in the Connecting state here it means the user called DisposeAsync while - // the connection was starting - if (ChangeState(from: ConnectionState.Connecting, to: ConnectionState.Connected) == ConnectionState.Connecting) - { - _closeTcs = new TaskCompletionSource(); - - Input.OnWriterCompleted(async (exception, state) => - { - // Grab the exception and then clear it. - // See comment at AbortAsync for more discussion on the thread-safety - // StartAsync can't be called until the ChangeState below, so we're OK. - var abortException = _abortException; - _abortException = null; - - // There is an inherent race between receive and close. Removing the last message from the channel - // makes Input.Completion task completed and runs this continuation. We need to await _receiveLoopTask - // to make sure that the message removed from the channel is processed before we drain the queue. - // There is a short window between we start the channel and assign the _receiveLoopTask a value. - // To make sure that _receiveLoopTask can be awaited (i.e. is not null) we need to await _startTask. - Log.ProcessRemainingMessages(_logger); - - await _startTcs.Task; - await _receiveLoopTask; - - Log.DrainEvents(_logger); - - await Task.WhenAny(_eventQueue.Drain().NoThrow(), Task.Delay(_eventQueueDrainTimeout)); - - Log.CompleteClosed(_logger); - _logScope.ConnectionId = null; - - // At this point the connection can be either in the Connected or Disposed state. The state should be changed - // to the Disconnected state only if it was in the Connected state. - // From this point on, StartAsync can be called at any time. - ChangeState(from: ConnectionState.Connected, to: ConnectionState.Disconnected); - - _closeTcs.SetResult(null); - - try - { - if (exception != null) - { - Closed?.Invoke(exception); - } - else - { - // Call the closed event. If there was an abort exception, it will be flowed forward - // However, if there wasn't, this will just be null and we're good - Closed?.Invoke(abortException); - } - } - catch (Exception ex) - { - // Suppress (but log) the exception, this is user code - Log.ErrorDuringClosedEvent(_logger, ex); - } - - }, null); - - _receiveLoopTask = ReceiveAsync(); - } - } - - private async Task Negotiate(Uri url, HttpClient httpClient, ILogger logger) - { - try - { - // Get a connection ID from the server - Log.EstablishingConnection(logger, url); - var urlBuilder = new UriBuilder(url); - if (!urlBuilder.Path.EndsWith("/")) - { - urlBuilder.Path += "/"; - } - urlBuilder.Path += "negotiate"; - - using (var request = new HttpRequestMessage(HttpMethod.Post, urlBuilder.Uri)) - { - // Corefx changed the default version and High Sierra curlhandler tries to upgrade request - request.Version = new Version(1, 1); - SendUtils.PrepareHttpRequest(request, _httpOptions); - - using (var response = await httpClient.SendAsync(request)) - { - response.EnsureSuccessStatusCode(); - return await ParseNegotiateResponse(response, logger); - } - } - } - catch (Exception ex) - { - Log.ErrorWithNegotiation(logger, url, ex); - throw; - } - } - - private static async Task ParseNegotiateResponse(HttpResponseMessage response, ILogger logger) - { - NegotiationResponse negotiationResponse; - using (var reader = new JsonTextReader(new StreamReader(await response.Content.ReadAsStreamAsync()))) - { - try - { - negotiationResponse = new JsonSerializer().Deserialize(reader); - } - catch (Exception ex) - { - throw new FormatException("Invalid negotiation response received.", ex); - } - } - - if (negotiationResponse == null) - { - throw new FormatException("Invalid negotiation response received."); - } - - return negotiationResponse; - } - - private static Uri CreateConnectUrl(Uri url, string connectionId) - { - if (string.IsNullOrWhiteSpace(connectionId)) - { - throw new FormatException("Invalid connection id."); - } - - return Utils.AppendQueryString(url, "id=" + connectionId); - } - - private async Task StartTransport(Uri connectUrl, TransportType transportType, TransferFormat transferFormat) - { - var options = new PipeOptions(writerScheduler: PipeScheduler.Inline, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); - var pair = DuplexPipe.CreateConnectionPair(options, options); - _transportChannel = pair.Transport; - _transport = _transportFactory.CreateTransport(transportType); - - // Start the transport, giving it one end of the pipeline - try - { - await _transport.StartAsync(connectUrl, pair.Application, transferFormat, this); - } - catch (Exception ex) - { - Log.ErrorStartingTransport(_logger, _transport, ex); - _transport = null; - throw; - } - } - - private async Task ReceiveAsync() - { - try - { - Log.HttpReceiveStarted(_logger); - - while (true) - { - if (_connectionState != ConnectionState.Connected) - { - Log.SkipRaisingReceiveEvent(_logger); - - break; - } - - var result = await Input.ReadAsync(); - var buffer = result.Buffer; - - try - { - if (!buffer.IsEmpty) - { - Log.ScheduleReceiveEvent(_logger); - var data = buffer.ToArray(); - - _ = _eventQueue.Enqueue(async () => - { - Log.RaiseReceiveEvent(_logger); - - // Copying the callbacks to avoid concurrency issues - ReceiveCallback[] callbackCopies; - lock (_callbacks) - { - callbackCopies = new ReceiveCallback[_callbacks.Count]; - _callbacks.CopyTo(callbackCopies); - } - - foreach (var callbackObject in callbackCopies) - { - try - { - await callbackObject.InvokeAsync(data); - } - catch (Exception ex) - { - Log.ExceptionThrownFromCallback(_logger, nameof(OnReceived), ex); - } - } - }); - - } - else if (result.IsCompleted) - { - break; - } - } - finally - { - Input.AdvanceTo(buffer.End); - } - } - } - catch (Exception ex) - { - Input.Complete(ex); - - Log.ErrorReceiving(_logger, ex); - } - finally - { - Input.Complete(); - } - - Log.EndReceive(_logger); - } - - public async Task SendAsync(byte[] data, CancellationToken cancellationToken = default) => - await SendAsyncCore(data, cancellationToken).ForceAsync(); - - private async Task SendAsyncCore(byte[] data, CancellationToken cancellationToken) - { - if (data == null) - { - throw new ArgumentNullException(nameof(data)); - } - - if (_connectionState != ConnectionState.Connected) - { - throw new InvalidOperationException( - "Cannot send messages when the connection is not in the Connected state."); - } - - Log.SendingMessage(_logger); - - cancellationToken.ThrowIfCancellationRequested(); - - await Output.WriteAsync(data); - } - - // AbortAsync creates a few thread-safety races that we are OK with. - // 1. If the transport shuts down gracefully after AbortAsync is called but BEFORE _abortException is called, then the - // Closed event will not receive the Abort exception. This is OK because technically the transport was shut down gracefully - // before it was aborted - // 2. If the transport is closed gracefully and then AbortAsync is called before it captures the _abortException value - // the graceful shutdown could be turned into an abort. However, again, this is an inherent race between two different conditions: - // The transport shutting down because the server went away, and the user requesting the Abort - // 3. Finally, because this is an instance field, there is a possible race around accidentally re-using _abortException in the restarted - // connection. The scenario here is: AbortAsync(someException); StartAsync(); CloseAsync(); Where the _abortException value from the - // first AbortAsync call is still set at the time CloseAsync gets to calling the Closed event. However, this can't happen because the - // StartAsync method can't be called until the connection state is changed to Disconnected, which happens AFTER the close code - // captures and resets _abortException. - public async Task AbortAsync(Exception exception) => await StopAsyncCore(exception ?? throw new ArgumentNullException(nameof(exception))).ForceAsync(); - - public async Task StopAsync() => await StopAsyncCore(exception: null).ForceAsync(); - - private async Task StopAsyncCore(Exception exception) - { - lock (_stateChangeLock) - { - if (!(_connectionState == ConnectionState.Connecting || _connectionState == ConnectionState.Connected)) - { - Log.SkippingStop(_logger); - return; - } - } - - // Note that this method can be called at the same time when the connection is being closed from the server - // side due to an error. We are resilient to this since we merely try to close the channel here and the - // channel can be closed only once. As a result the continuation that does actual job and raises the Closed - // event runs always only once. - - // See comment at AbortAsync for more discussion on the thread-safety of this. - _abortException = exception; - - Log.StoppingClient(_logger); - - try - { - await _startTcs.Task; - } - catch - { - // We only await the start task to make sure that StartAsync completed. The - // _startTask is returned to the user and they should handle exceptions. - } - - TaskCompletionSource closeTcs = null; - Task receiveLoopTask = null; - ITransport transport = null; - - lock (_stateChangeLock) - { - // Copy locals in lock to prevent a race when the server closes the connection and StopAsync is called - // at the same time - if (_connectionState != ConnectionState.Connected) - { - // If not Connected then someone else disconnected while StopAsync was in progress, we can now NO-OP - return; - } - - // Create locals of relevant member variables to prevent a race when Closed event triggers a connect - // while StopAsync is still running - closeTcs = _closeTcs; - receiveLoopTask = _receiveLoopTask; - transport = _transport; - } - - if (_transportChannel != null) - { - Output.Complete(); - } - - if (transport != null) - { - await transport.StopAsync(); - } - - if (receiveLoopTask != null) - { - await receiveLoopTask; - } - - if (closeTcs != null) - { - await closeTcs.Task; - } - } - - public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync(); - - private async Task DisposeAsyncCore() - { - // This will no-op if we're already stopped - await StopAsyncCore(exception: null); - - if (ChangeState(to: ConnectionState.Disposed) == ConnectionState.Disposed) - { - Log.SkippingDispose(_logger); - - // the connection was already disposed - return; - } - - Log.DisposingClient(_logger); - - _httpClient?.Dispose(); - _scopeDisposable.Dispose(); - } - - public IDisposable OnReceived(Func callback, object state) - { - var receiveCallback = new ReceiveCallback(callback, state); - lock (_callbacks) - { - _callbacks.Add(receiveCallback); - } - return new Subscription(receiveCallback, _callbacks); - } - - private class ReceiveCallback - { - private readonly Func _callback; - private readonly object _state; - - public ReceiveCallback(Func callback, object state) - { - _callback = callback; - _state = state; - } - - public Task InvokeAsync(byte[] data) - { - return _callback(data, _state); - } - } - - private class Subscription : IDisposable - { - private readonly ReceiveCallback _receiveCallback; - private readonly List _callbacks; - public Subscription(ReceiveCallback callback, List callbacks) - { - _receiveCallback = callback; - _callbacks = callbacks; - } - - public void Dispose() - { - lock (_callbacks) - { - _callbacks.Remove(_receiveCallback); - } - } - } - - private ConnectionState ChangeState(ConnectionState from, ConnectionState to) - { - lock (_stateChangeLock) - { - var state = _connectionState; - if (_connectionState == from) - { - _connectionState = to; - } - - Log.ConnectionStateChanged(_logger, state, to); - return state; - } - } - - private ConnectionState ChangeState(ConnectionState to) - { - lock (_stateChangeLock) - { - var state = _connectionState; - _connectionState = to; - Log.ConnectionStateChanged(_logger, state, to); - return state; + throw new ObjectDisposedException(nameof(HttpConnection)); } } @@ -747,13 +447,12 @@ namespace Microsoft.AspNetCore.Sockets.Client #endif } - // Internal because it's used by logging to avoid ToStringing prematurely. - internal enum ConnectionState + private async Task GetNegotiationResponse() { - Disconnected, - Connecting, - Connected, - Disposed + var negotiationResponse = await Negotiate(Url, _httpClient, _logger); + _connectionId = negotiationResponse.ConnectionId; + _logScope.ConnectionId = _connectionId; + return negotiationResponse; } private class NegotiationResponse diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs deleted file mode 100644 index 490b6dc3de..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.Sockets.Client -{ - public static partial class HttpConnectionExtensions - { - public static IDisposable OnReceived(this HttpConnection connection, Func callback) - { - return connection.OnReceived((data, state) => - { - var currentCallback = (Func)state; - return currentCallback(data); - }, callback); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs index c9c269185e..9eeac8276d 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs @@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http /// /// Gets or sets a delegate that will be invoked with the object used - /// by the to configure the WebSocket. + /// to configure the WebSocket when using the WebSockets transport. /// /// /// This delegate is invoked after headers from and the access token from diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.Log.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.Log.cs similarity index 80% rename from src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.Log.cs rename to src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.Log.cs index 6d0df576e5..c41df896d6 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.Log.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.Log.cs @@ -1,11 +1,12 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Net.Http; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Internal { public partial class LongPollingTransport { @@ -38,6 +39,11 @@ namespace Microsoft.AspNetCore.Sockets.Client private static readonly Action _errorPolling = LoggerMessage.Define(LogLevel.Error, new EventId(9, "ErrorPolling"), "Error while polling '{PollUrl}'."); + // long? does properly format as "(null)" when null. + private static readonly Action _pollResponseReceived = + LoggerMessage.Define(LogLevel.Trace, new EventId(10, "PollResponseReceived"), + "Poll response with status code {StatusCode} received from server. Content length: {ContentLength}."); + // EventIds 100 - 106 used in SendUtils public static void StartTransport(ILogger logger, TransferFormat transferFormat) @@ -84,6 +90,15 @@ namespace Microsoft.AspNetCore.Sockets.Client { _errorPolling(logger, pollUrl, exception); } + + public static void PollResponseReceived(ILogger logger, HttpResponseMessage response) + { + if (logger.IsEnabled(LogLevel.Trace)) + { + _pollResponseReceived(logger, (int)response.StatusCode, + response.Content.Headers.ContentLength ?? -1, null); + } + } } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs similarity index 97% rename from src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs rename to src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs index f702a2f33c..21925cac62 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/LongPollingTransport.cs @@ -13,7 +13,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.AspNetCore.Connections; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Internal { public partial class LongPollingTransport : ITransport { @@ -107,6 +107,8 @@ namespace Microsoft.AspNetCore.Sockets.Client continue; } + Log.PollResponseReceived(_logger, response); + response.EnsureSuccessStatusCode(); if (response.StatusCode == HttpStatusCode.NoContent || cancellationToken.IsCancellationRequested) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.Log.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/ServerSentEventsTransport.Log.cs similarity index 97% rename from src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.Log.cs rename to src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/ServerSentEventsTransport.Log.cs index 10414a88e5..e644722362 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.Log.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/ServerSentEventsTransport.Log.cs @@ -1,11 +1,11 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Internal { public partial class ServerSentEventsTransport { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/ServerSentEventsTransport.cs similarity index 99% rename from src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs rename to src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/ServerSentEventsTransport.cs index 922d5fddbd..0146bdfa11 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/ServerSentEventsTransport.cs @@ -13,7 +13,7 @@ using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Internal { public partial class ServerSentEventsTransport : ITransport { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs deleted file mode 100644 index 10cf204c88..0000000000 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs +++ /dev/null @@ -1,71 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.Sockets.Client.Internal -{ - // Allows serial queuing of Task instances - // The tasks are not called on the current synchronization context - - public sealed class TaskQueue - { - private readonly object _lockObj = new object(); - private Task _lastQueuedTask; - private volatile bool _drained; - - public TaskQueue() - : this(Task.CompletedTask) - { } - - public TaskQueue(Task initialTask) - { - _lastQueuedTask = initialTask; - } - - public bool IsDrained - { - get { return _drained; } - } - - public Task Enqueue(Func taskFunc) - { - return Enqueue(s => taskFunc(), null); - } - - public Task Enqueue(Func taskFunc, object state) - { - lock (_lockObj) - { - if (_drained) - { - return _lastQueuedTask; - } - - var newTask = _lastQueuedTask.ContinueWith((t, s1) => - { - if (t.IsFaulted || t.IsCanceled) - { - return t; - } - - return taskFunc(s1) ?? Task.CompletedTask; - }, - state).Unwrap(); - _lastQueuedTask = newTask; - return newTask; - } - } - - public Task Drain() - { - lock (_lockObj) - { - _drained = true; - - return _lastQueuedTask; - } - } - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.Log.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/WebSocketsTransport.Log.cs similarity index 99% rename from src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.Log.cs rename to src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/WebSocketsTransport.Log.cs index 062c6c9a0f..42fefbe8ba 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.Log.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/WebSocketsTransport.Log.cs @@ -6,7 +6,7 @@ using System.Net.WebSockets; using Microsoft.AspNetCore.Connections; using Microsoft.Extensions.Logging; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Internal { public partial class WebSocketsTransport { diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/WebSocketsTransport.cs similarity index 98% rename from src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs rename to src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/WebSocketsTransport.cs index 8e3981be43..e5e71d3b75 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/WebSocketsTransport.cs @@ -14,7 +14,7 @@ using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -namespace Microsoft.AspNetCore.Sockets.Client +namespace Microsoft.AspNetCore.Sockets.Client.Internal { public partial class WebSocketsTransport : ITransport { @@ -124,7 +124,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { using (socket) { - // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + // Begin sending and receiving. var receiving = StartReceiving(socket); var sending = StartSending(socket); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 64b25af727..eeaaae07f8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Http; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; @@ -43,8 +44,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task CheckFixedMessage(IHubProtocol protocol, TransportType transportType, string path) + public async Task CheckFixedMessage(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, $"{nameof(CheckFixedMessage)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { var connection = new HubConnectionBuilder() @@ -64,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -76,13 +78,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task CanSendAndReceiveMessage(IHubProtocol protocol, TransportType transportType, string path) + public async Task CanSendAndReceiveMessage(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, $"{nameof(CanSendAndReceiveMessage)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { const string originalMessage = "SignalR"; - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -93,7 +95,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -105,13 +107,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task CanStopAndStartConnection(IHubProtocol protocol, TransportType transportType, string path) + public async Task CanStopAndStartConnection(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanStopAndStartConnection)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { const string originalMessage = "SignalR"; - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -124,7 +126,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -134,15 +136,17 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Fact] - public async Task CanStartConnectionFromClosedEvent() + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task CanStartConnectionFromClosedEvent(string protocolName, TransportType transportType, string path) { - using (StartLog(out var loggerFactory)) + var protocol = HubProtocols[protocolName]; + using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanStartConnectionFromClosedEvent)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { var logger = loggerFactory.CreateLogger(); const string originalMessage = "SignalR"; - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + "/default"), loggerFactory); - var connection = new HubConnection(httpConnection, new JsonHubProtocol(), loggerFactory); + + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, "/default", transportType), new JsonHubProtocol(), loggerFactory); var restartTcs = new TaskCompletionSource(); connection.Closed += async e => { @@ -175,7 +179,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -185,16 +189,21 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + private Func GetHttpConnectionFactory(ILoggerFactory loggerFactory, string path, TransportType transportType) + { + return () => new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task MethodsAreCaseInsensitive(IHubProtocol protocol, TransportType transportType, string path) + public async Task MethodsAreCaseInsensitive(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, $"{nameof(MethodsAreCaseInsensitive)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { const string originalMessage = "SignalR"; var uriString = "http://test/" + path; - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -205,7 +214,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -217,14 +226,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task CanInvokeClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) + public async Task CanInvokeClientMethodFromServer(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanInvokeClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { const string originalMessage = "SignalR"; - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -238,7 +247,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -250,12 +259,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task InvokeNonExistantClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) + public async Task InvokeNonExistantClientMethodFromServer(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(InvokeNonExistantClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); var closeTcs = new TaskCompletionSource(); connection.Closed += e => { @@ -290,12 +299,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task CanStreamClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) + public async Task CanStreamClientMethodFromServer(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(CanStreamClientMethodFromServer)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -307,7 +316,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -319,12 +328,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task CanCloseStreamMethodEarly(IHubProtocol protocol, TransportType transportType, string path) + public async Task CanCloseStreamMethodEarly(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, $"{nameof(CanCloseStreamMethodEarly)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -333,16 +342,21 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests var channel = await connection.StreamAsChannelAsync("Stream", 1000, cts.Token).OrTimeout(); + // Wait for the server to start streaming items await channel.WaitToReadAsync().AsTask().OrTimeout(); + cts.Cancel(); - var results = await channel.ReadAllAsync().OrTimeout(); + var results = await channel.ReadAllAsync(suppressExceptions: true).OrTimeout(); Assert.True(results.Count > 0 && results.Count < 1000); + + // We should have been canceled. + await Assert.ThrowsAsync(() => channel.Completion); } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -354,12 +368,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task StreamDoesNotStartIfTokenAlreadyCanceled(IHubProtocol protocol, TransportType transportType, string path) + public async Task StreamDoesNotStartIfTokenAlreadyCanceled(string protocolName, TransportType transportType, string path) { - using (StartLog(out var loggerFactory, $"{nameof(StreamDoesNotStartIfTokenAlreadyCanceled)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) + var protocol = HubProtocols[protocolName]; + using (StartLog(out var loggerFactory, LogLevel.Trace, $"{nameof(StreamDoesNotStartIfTokenAlreadyCanceled)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = + new HubConnection( + GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, + loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -369,11 +386,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests var channel = await connection.StreamAsChannelAsync("Stream", 5, cts.Token).OrTimeout(); - await Assert.ThrowsAnyAsync(() => channel.WaitToReadAsync().AsTask().OrTimeout()); + await Assert.ThrowsAnyAsync(() => + channel.WaitToReadAsync().AsTask().OrTimeout()); } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -385,12 +403,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ExceptionFromStreamingSentToClient(IHubProtocol protocol, TransportType transportType, string path) + public async Task ExceptionFromStreamingSentToClient(string protocolName, TransportType transportType, string path) { + var protocol = HubProtocols[protocolName]; using (StartLog(out var loggerFactory, $"{nameof(ExceptionFromStreamingSentToClient)}_{protocol.Name}_{transportType}_{path.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + path), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, protocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, path, transportType), protocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -401,7 +419,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -413,12 +431,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionIfHubMethodCannotBeResolved(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionIfHubMethodCannotBeResolved(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionIfHubMethodCannotBeResolved)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -428,7 +446,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -440,12 +458,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionOnHubMethodArgumentCountMismatch(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionOnHubMethodArgumentCountMismatch(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionOnHubMethodArgumentCountMismatch)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -455,7 +473,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -467,12 +485,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionOnHubMethodArgumentTypeMismatch(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionOnHubMethodArgumentTypeMismatch(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionOnHubMethodArgumentTypeMismatch)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -482,7 +500,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -494,12 +512,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionIfStreamingHubMethodCannotBeResolved(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionIfStreamingHubMethodCannotBeResolved(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionIfStreamingHubMethodCannotBeResolved)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -510,7 +528,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -522,13 +540,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionOnStreamingHubMethodArgumentCountMismatch(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionOnStreamingHubMethodArgumentCountMismatch(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionOnStreamingHubMethodArgumentCountMismatch)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { loggerFactory.AddConsole(LogLevel.Trace); - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -539,7 +557,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -551,12 +569,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionOnStreamingHubMethodArgumentTypeMismatch(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionOnStreamingHubMethodArgumentTypeMismatch(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionOnStreamingHubMethodArgumentTypeMismatch)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -567,7 +585,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -579,12 +597,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionIfNonStreamMethodInvokedWithStreamAsync(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionIfNonStreamMethodInvokedWithStreamAsync(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionIfNonStreamMethodInvokedWithStreamAsync)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -594,7 +612,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -606,12 +624,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionIfStreamMethodInvokedWithInvoke(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionIfStreamMethodInvokedWithInvoke(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionIfStreamMethodInvokedWithInvoke)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -621,7 +639,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -633,12 +651,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] - public async Task ServerThrowsHubExceptionIfBuildingAsyncEnumeratorIsNotPossible(IHubProtocol hubProtocol, TransportType transportType, string hubPath) + public async Task ServerThrowsHubExceptionIfBuildingAsyncEnumeratorIsNotPossible(string hubProtocolName, TransportType transportType, string hubPath) { + var hubProtocol = HubProtocols[hubProtocolName]; using (StartLog(out var loggerFactory, $"{nameof(ServerThrowsHubExceptionIfBuildingAsyncEnumeratorIsNotPossible)}_{hubProtocol.Name}_{transportType}_{hubPath.TrimStart('/')}")) { - var httpConnection = new HttpConnection(new Uri(_serverFixture.Url + hubPath), transportType, loggerFactory); - var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory); + var connection = new HubConnection(GetHttpConnectionFactory(loggerFactory, hubPath, transportType), hubProtocol, loggerFactory); try { await connection.StartAsync().OrTimeout(); @@ -648,7 +666,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -682,7 +700,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -713,7 +731,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -747,7 +765,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -785,7 +803,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -814,7 +832,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } catch (Exception ex) { - loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); throw; } finally @@ -834,9 +852,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { foreach (var hubPath in HubPaths) { - if (!(protocol is MessagePackHubProtocol) || transport != TransportType.ServerSentEvents) + if (!(protocol.Value is MessagePackHubProtocol) || transport != TransportType.ServerSentEvents) { - yield return new object[] { protocol, transport, hubPath }; + yield return new object[] { protocol.Key, transport, hubPath }; } } } @@ -847,11 +865,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests // This list excludes "special" hub paths like "default-nowebsockets" which exist for specific tests. public static string[] HubPaths = new[] { "/default", "/dynamic", "/hubT" }; - public static IEnumerable HubProtocols => - new IHubProtocol[] + public static Dictionary HubProtocols => + new Dictionary { - new JsonHubProtocol(), - new MessagePackHubProtocol(), + { "json", new JsonHubProtocol() }, + { "messagepack", new MessagePackHubProtocol() }, }; public static IEnumerable TransportTypes() diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs deleted file mode 100644 index 20b9d457d5..0000000000 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.AbortAsync.cs +++ /dev/null @@ -1,130 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Sockets; -using Xunit; - -namespace Microsoft.AspNetCore.SignalR.Client.Tests -{ - public partial class HttpConnectionTests - { - // Nested class for grouping - public class AbortAsync - { - [Fact] - public Task AbortAsyncTriggersClosedEventWithException() - { - return WithConnectionAsync(CreateConnection(), async (connection, closed) => - { - // Start the connection - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // Abort with an error - var expected = new Exception("Ruh roh!"); - await connection.AbortAsync(expected).OrTimeout(); - - // Verify that it is thrown - var actual = await Assert.ThrowsAsync(async () => await closed.OrTimeout()); - Assert.Same(expected, actual); - }); - } - - [Fact] - public Task AbortAsyncWhileStoppingTriggersClosedEventWithException() - { - return WithConnectionAsync(CreateConnection(transport: new TestTransport(onTransportStop: SyncPoint.Create(2, out var syncPoints))), async (connection, closed) => - { - // Start the connection - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // Stop normally - var stopTask = connection.StopAsync().OrTimeout(); - - // Wait to reach the first sync point - await syncPoints[0].WaitForSyncPoint().OrTimeout(); - - // Abort with an error - var expected = new Exception("Ruh roh!"); - var abortTask = connection.AbortAsync(expected).OrTimeout(); - - // Wait for the sync point to hit again - await syncPoints[1].WaitForSyncPoint().OrTimeout(); - - // Release sync point 0 - syncPoints[0].Continue(); - - // We should close with the error from Abort (because it was set by the call to Abort even though Stop triggered the close) - var actual = await Assert.ThrowsAsync(async () => await closed.OrTimeout()); - Assert.Same(expected, actual); - - // Clean-up - syncPoints[1].Continue(); - await Task.WhenAll(stopTask, abortTask).OrTimeout(); - }); - } - - [Fact] - public Task StopAsyncWhileAbortingTriggersClosedEventWithoutException() - { - return WithConnectionAsync(CreateConnection(transport: new TestTransport(onTransportStop: SyncPoint.Create(2, out var syncPoints))), async (connection, closed) => - { - // Start the connection - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // Abort with an error - var expected = new Exception("Ruh roh!"); - var abortTask = connection.AbortAsync(expected).OrTimeout(); - - // Wait to reach the first sync point - await syncPoints[0].WaitForSyncPoint().OrTimeout(); - - // Stop normally, without a sync point. - // This should clear the exception, meaning Closed will not "throw" - syncPoints[1].Continue(); - await connection.StopAsync(); - await closed.OrTimeout(); - - // Clean-up - syncPoints[0].Continue(); - await abortTask.OrTimeout(); - }); - } - - [Fact] - public Task StartAsyncCannotBeCalledWhileAbortAsyncInProgress() - { - return WithConnectionAsync(CreateConnection(transport: new TestTransport(onTransportStop: SyncPoint.Create(out var syncPoint))), async (connection, closed) => - { - // Start the connection - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // Abort with an error - var expected = new Exception("Ruh roh!"); - var abortTask = connection.AbortAsync(expected).OrTimeout(); - - // Wait to reach the first sync point - await syncPoint.WaitForSyncPoint().OrTimeout(); - - var ex = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text).OrTimeout()); - Assert.Equal("Cannot start a connection that is not in the Disconnected state.", ex.Message); - - // Release the sync point and wait for close to complete - // (it will throw the abort exception) - syncPoint.Continue(); - await abortTask.OrTimeout(); - var actual = await Assert.ThrowsAsync(() => closed.OrTimeout()); - Assert.Same(expected, actual); - - // We can start now - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // And we can stop without getting the abort exception. - await connection.StopAsync().OrTimeout(); - }); - } - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs index 310e777db2..bf60248e66 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs @@ -2,13 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Net; using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.Extensions.Logging.Testing; using Xunit; using Xunit.Abstractions; @@ -24,92 +25,58 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task CannotStartRunningConnection() + public async Task CanStartStartedConnection() { using (StartLog(out var loggerFactory)) { - await WithConnectionAsync(CreateConnection(loggerFactory: loggerFactory), async (connection, closed) => + await WithConnectionAsync(CreateConnection(loggerFactory: loggerFactory), async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); - var exception = - await Assert.ThrowsAsync( - async () => await connection.StartAsync(TransferFormat.Text).OrTimeout()); - Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); + await connection.StartAsync(TransferFormat.Text).OrTimeout(); }); } } + [Fact] + public async Task CanStartStartingConnection() + { + using (StartLog(out var loggerFactory)) + { + await WithConnectionAsync( + CreateConnection(loggerFactory: loggerFactory, transport: new TestTransport(onTransportStart: SyncPoint.Create(out var syncPoint))), + async (connection) => + { + var firstStart = connection.StartAsync(TransferFormat.Text).OrTimeout(); + await syncPoint.WaitForSyncPoint(); + var secondStart = connection.StartAsync(TransferFormat.Text).OrTimeout(); + syncPoint.Continue(); + + await firstStart; + await secondStart; + }); + } + } [Fact] - public async Task CannotStartConnectionDisposedAfterStarting() + public async Task CannotStartConnectionOnceDisposed() { using (StartLog(out var loggerFactory)) { await WithConnectionAsync( CreateConnection(loggerFactory: loggerFactory), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); await connection.DisposeAsync(); var exception = - await Assert.ThrowsAsync( + await Assert.ThrowsAsync( async () => await connection.StartAsync(TransferFormat.Text).OrTimeout()); - Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); + Assert.Equal(nameof(HttpConnection), exception.ObjectName); }); } } - [Fact] - public async Task CannotStartDisposedConnection() - { - using (StartLog(out var loggerFactory)) - { - await WithConnectionAsync( - CreateConnection(loggerFactory: loggerFactory), - async (connection, closed) => - { - await connection.DisposeAsync(); - var exception = - await Assert.ThrowsAsync( - async () => await connection.StartAsync(TransferFormat.Text).OrTimeout()); - - Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); - }); - } - } - - [Fact] - public async Task CanDisposeStartingConnection() - { - using (StartLog(out var loggerFactory)) - { - await WithConnectionAsync( - CreateConnection( - loggerFactory: loggerFactory, - transport: new TestTransport( - onTransportStart: SyncPoint.Create(out var transportStart), - onTransportStop: SyncPoint.Create(out var transportStop))), - async (connection, closed) => - { - // Start the connection and wait for the transport to start up. - var startTask = connection.StartAsync(TransferFormat.Text); - await transportStart.WaitForSyncPoint().OrTimeout(); - - // While the transport is starting, dispose the connection - var disposeTask = connection.DisposeAsync(); - transportStart.Continue(); // We need to release StartAsync, because Dispose waits for it. - - // Wait for start to finish, as that has to finish before the transport will be stopped. - await startTask.OrTimeout(); - - // Then release DisposeAsync (via the transport StopAsync call) - await transportStop.WaitForSyncPoint().OrTimeout(); - transportStop.Continue(); - }); - } - } - [Theory] [InlineData(2)] [InlineData(3)] @@ -138,7 +105,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests CreateConnection( loggerFactory: loggerFactory, transport: new TestTransport(onTransportStart: OnTransportStart)), - async (connection, closed) => + async (connection) => { Assert.Equal(0, startCounter); await connection.StartAsync(TransferFormat.Text); @@ -164,7 +131,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests CreateConnection( loggerFactory: loggerFactory, transport: new TestTransport(onTransportStart: OnTransportStart)), - async (connection, closed) => + async (connection) => { var ex = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text)); Assert.Equal("Unable to connect to the server with any of the available transports.", ex.Message); @@ -174,66 +141,115 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task CanStartStoppedConnection() + public async Task CanDisposeUnstartedConnection() { using (StartLog(out var loggerFactory)) { await WithConnectionAsync( CreateConnection(loggerFactory: loggerFactory), - async (connection, closed) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.StopAsync().OrTimeout(); - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - }); + async (connection) => + { + await connection.DisposeAsync(); + + }); } } [Fact] - public async Task CanStopStartingConnection() + public async Task CanDisposeStartingConnection() { using (StartLog(out var loggerFactory)) { await WithConnectionAsync( CreateConnection( loggerFactory: loggerFactory, - transport: new TestTransport(onTransportStart: SyncPoint.Create(out var transportStart))), - async (connection, closed) => - { - // Start and wait for the transport to start up. - var startTask = connection.StartAsync(TransferFormat.Text); - await transportStart.WaitForSyncPoint().OrTimeout(); + transport: new TestTransport( + onTransportStart: SyncPoint.Create(out var transportStart), + onTransportStop: SyncPoint.Create(out var transportStop))), + async (connection) => + { + // Start the connection and wait for the transport to start up. + var startTask = connection.StartAsync(TransferFormat.Text); + await transportStart.WaitForSyncPoint().OrTimeout(); - // Stop the connection while it's starting - var stopTask = connection.StopAsync(); - transportStart.Continue(); // We need to release Start in order for Stop to begin working. + // While the transport is starting, dispose the connection + var disposeTask = connection.DisposeAsync().OrTimeout(); + transportStart.Continue(); // We need to release StartAsync, because Dispose waits for it. - // Wait for start to finish, which will allow stop to finish and the connection to close. - await startTask.OrTimeout(); - await stopTask.OrTimeout(); - await closed.OrTimeout(); - }); + // Wait for start to finish, as that has to finish before the transport will be stopped. + await startTask.OrTimeout(); + + // Then release DisposeAsync (via the transport StopAsync call) + await transportStop.WaitForSyncPoint().OrTimeout(); + transportStop.Continue(); + + // Dispose should finish + await disposeTask; + }); } } [Fact] - public async Task StoppingStoppingConnectionNoOps() + public async Task CanDisposeDisposingConnection() { using (StartLog(out var loggerFactory)) { await WithConnectionAsync( - CreateConnection(loggerFactory: loggerFactory), - async (connection, closed) => + CreateConnection( + loggerFactory: loggerFactory, + transport: new TestTransport(onTransportStop: SyncPoint.Create(out var transportStop))), + async (connection) => { + // Start the connection await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await Task.WhenAll(connection.StopAsync(), connection.StopAsync()).OrTimeout(); - await closed.OrTimeout(); + + // Dispose the connection + var stopTask = connection.DisposeAsync().OrTimeout(); + + // Once the transport starts shutting down + await transportStop.WaitForSyncPoint(); + Assert.False(stopTask.IsCompleted); + + // Start disposing again, and then let the first dispose continue + var disposeTask = connection.DisposeAsync().OrTimeout(); + transportStop.Continue(); + + // Wait for the tasks to complete + await stopTask.OrTimeout(); + await disposeTask.OrTimeout(); + + // We should be disposed and thus unable to restart. + await AssertDisposedAsync(connection); }); } } [Fact] - public async Task CanStartConnectionAfterConnectionStoppedWithError() + public async Task TransportIsStoppedWhenConnectionIsDisposed() + { + var testHttpHandler = new TestHttpMessageHandler(); + + using (var httpClient = new HttpClient(testHttpHandler)) + { + var testTransport = new TestTransport(); + await WithConnectionAsync( + CreateConnection(transport: testTransport), + async (connection) => + { + // Start the transport + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + Assert.NotNull(testTransport.Receiving); + Assert.False(testTransport.Receiving.IsCompleted); + + // Stop the connection, and we should stop the transport + await connection.DisposeAsync().OrTimeout(); + await testTransport.Receiving.OrTimeout(); + }); + } + } + + [Fact] + public async Task TransportPipeIsCompletedWhenErrorOccursInTransport() { using (StartLog(out var loggerFactory)) { @@ -257,119 +273,17 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(httpHandler, loggerFactory), - async (connection, closed) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.SendAsync(new byte[] { 0x42 }).OrTimeout(); - - // Wait for the connection to close, because the send failed. - await Assert.ThrowsAsync(() => closed.OrTimeout()); - - // Start it up again - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - }); - } - } - - [Fact] - public async Task DisposedStoppingConnectionDisposesConnection() - { - using (StartLog(out var loggerFactory)) - { - await WithConnectionAsync( - CreateConnection( - loggerFactory: loggerFactory, - transport: new TestTransport(onTransportStop: SyncPoint.Create(out var transportStop))), - async (connection, closed) => - { - // Start the connection - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // Stop the connection - var stopTask = connection.StopAsync().OrTimeout(); - - // Once the transport starts shutting down - await transportStop.WaitForSyncPoint(); - - // Start disposing and allow it to finish shutting down - var disposeTask = connection.DisposeAsync().OrTimeout(); - transportStop.Continue(); - - // Wait for the tasks to complete - await stopTask.OrTimeout(); - await closed.OrTimeout(); - await disposeTask.OrTimeout(); - - // We should be disposed and thus unable to restart. - var exception = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text).OrTimeout()); - Assert.Equal("Cannot start a connection that is not in the Disconnected state.", exception.Message); - }); - } - } - - [Fact] - public async Task CanDisposeStoppedConnection() - { - using (StartLog(out var loggerFactory)) - { - await WithConnectionAsync( - CreateConnection(loggerFactory: loggerFactory), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.StopAsync().OrTimeout(); - await closed.OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + await connection.Transport.Output.WriteAsync(new byte[] { 0x42 }).OrTimeout(); + + // We should get the exception in the transport input completion. + await Assert.ThrowsAsync(() => connection.Transport.Input.WaitForWriterToComplete()); }); } } - [Fact] - public Task ClosedEventRaisedWhenTheClientIsDisposed() - { - return WithConnectionAsync( - CreateConnection(), - async (connection, closed) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - await closed.OrTimeout(); - }); - } - - [Fact] - public async Task ConnectionClosedWhenTransportFails() - { - var testTransport = new TestTransport(); - - var expected = new Exception("Whoops!"); - - await WithConnectionAsync( - CreateConnection(transport: testTransport), - async (connection, closed) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - testTransport.Application.Output.Complete(expected); - var actual = await Assert.ThrowsAsync(() => closed.OrTimeout()); - Assert.Same(expected, actual); - - var sendException = await Assert.ThrowsAsync(() => connection.SendAsync(new byte[0]).OrTimeout()); - Assert.Equal("Cannot send messages when the connection is not in the Connected state.", sendException.Message); - }); - } - - [Fact] - public Task ClosedEventNotRaisedWhenTheClientIsStoppedButWasNeverStarted() - { - return WithConnectionAsync( - CreateConnection(), - async (connection, closed) => - { - await connection.DisposeAsync().OrTimeout(); - Assert.False(closed.IsCompleted); - }); - } - [Fact] public async Task SSEWontStartIfSuccessfulConnectionIsNotEstablished() { @@ -386,7 +300,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(httpHandler, loggerFactory: loggerFactory, url: null, transport: sse), - async (connection, closed) => + async (connection) => { await Assert.ThrowsAsync( () => connection.StartAsync(TransferFormat.Text).OrTimeout()); @@ -412,7 +326,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(httpHandler, loggerFactory: loggerFactory, url: null, transport: sse), - async (connection, closed) => + async (connection) => { var startTask = connection.StartAsync(TransferFormat.Text).OrTimeout(); Assert.False(connectResponseTcs.Task.IsCompleted); @@ -423,30 +337,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } } - [Fact] - public async Task TransportIsStoppedWhenConnectionIsStopped() + private static async Task AssertDisposedAsync(HttpConnection connection) { - var testHttpHandler = new TestHttpMessageHandler(); - - // Just keep returning data when polled - testHttpHandler.OnLongPoll(_ => ResponseUtils.CreateResponse(HttpStatusCode.OK)); - - using (var httpClient = new HttpClient(testHttpHandler)) - { - var longPollingTransport = new LongPollingTransport(httpClient); - await WithConnectionAsync( - CreateConnection(transport: longPollingTransport), - async (connection, closed) => - { - // Start the transport - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - Assert.False(longPollingTransport.Running.IsCompleted, "Expected that the transport would still be running"); - - // Stop the connection, and we should stop the transport - await connection.StopAsync().OrTimeout(); - await longPollingTransport.Running.OrTimeout(); - }); - } + var exception = + await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text).OrTimeout()); + Assert.Equal(nameof(HttpConnection), exception.ObjectName); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs index c691020420..39137ac18b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs @@ -43,102 +43,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } } - private static async Task WithConnectionAsync(HttpConnection connection, Func body) + private static async Task WithConnectionAsync(HttpConnection connection, Func body) { try { - var closedTcs = new TaskCompletionSource(); - connection.Closed += ex => - { - if (ex != null) - { - closedTcs.SetException(ex); - } - else - { - closedTcs.SetResult(null); - } - }; - // Using OrTimeout here will hide any timeout issues in the test :(. - await body(connection, closedTcs.Task); + await body(connection); } finally { await connection.DisposeAsync().OrTimeout(); } } - - // Possibly useful as a general-purpose async testing helper? - private class SyncPoint - { - private TaskCompletionSource _atSyncPoint = new TaskCompletionSource(); - private TaskCompletionSource _continueFromSyncPoint = new TaskCompletionSource(); - - /// - /// Waits for the code-under-test to reach . - /// - /// - public Task WaitForSyncPoint() => _atSyncPoint.Task; - - /// - /// Releases the code-under-test to continue past where it waited for . - /// - public void Continue() => _continueFromSyncPoint.TrySetResult(null); - - /// - /// Used by the code-under-test to wait for the test code to sync up. - /// - /// - /// This code will unblock and then block waiting for to be called. - /// - /// - public Task WaitToContinue() - { - _atSyncPoint.TrySetResult(null); - return _continueFromSyncPoint.Task; - } - - public static Func Create(out SyncPoint syncPoint) - { - var handler = Create(1, out var syncPoints); - syncPoint = syncPoints[0]; - return handler; - } - - /// - /// Creates a re-entrant function that waits for sync points in sequence. - /// - /// The number of sync points to expect - /// The objects that can be used to coordinate the sync point - /// - public static Func Create(int count, out SyncPoint[] syncPoints) - { - // Need to use a local so the closure can capture it. You can't use out vars in a closure. - var localSyncPoints = new SyncPoint[count]; - for (var i = 0; i < count; i += 1) - { - localSyncPoints[i] = new SyncPoint(); - } - - syncPoints = localSyncPoints; - - var counter = 0; - return () => - { - if (counter >= localSyncPoints.Length) - { - return Task.CompletedTask; - } - else - { - var syncPoint = localSyncPoints[counter]; - - counter += 1; - return syncPoint.WaitToContinue(); - } - }; - } - } } } + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs index cdf3921efb..a4da4dc7f0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs @@ -70,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(testHttpHandler, url: requestedUrl), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); }); @@ -95,17 +95,17 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests new { transport = "QuantumEntanglement", - transferFormats = new string[] { "Qbits" }, + transferFormats = new[] { "Qbits" }, }, new { transport = "CarrierPigeon", - transferFormats = new string[] { "Text" }, + transferFormats = new[] { "Text" }, }, new { transport = "LongPolling", - transferFormats = new string[] { "Text", "Binary" } + transferFormats = new[] { "Text", "Binary" } }, } })); @@ -118,7 +118,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(testHttpHandler, transportFactory: transportFactory.Object), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Binary).OrTimeout(); }); @@ -141,17 +141,17 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests new { transport = "WebSockets", - transferFormats = new string[] { "Qbits" }, + transferFormats = new[] { "Qbits" }, }, new { transport = "ServerSentEvents", - transferFormats = new string[] { "Text" }, + transferFormats = new[] { "Text" }, }, new { transport = "LongPolling", - transferFormats = new string[] { "Text", "Binary" } + transferFormats = new[] { "Text", "Binary" } }, } })); @@ -164,7 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(testHttpHandler, transportFactory: transportFactory.Object), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Binary).OrTimeout(); }); @@ -178,7 +178,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(testHttpHandler), - async (connection, closed) => + async (connection) => { var exception = await Assert.ThrowsAsync( () => connection.StartAsync(TransferFormat.Text).OrTimeout()); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs deleted file mode 100644 index 3f71fe400b..0000000000 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.OnReceived.cs +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Net; -using System.Text; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Client.Tests; -using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Sockets; -using Xunit; - -namespace Microsoft.AspNetCore.SignalR.Client.Tests -{ - public partial class HttpConnectionTests - { - public class OnReceived - { - [Fact] - public async Task CanReceiveData() - { - var testHttpHandler = new TestHttpMessageHandler(); - - testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.OK, "42")); - testHttpHandler.OnSocketSend((_, __) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); - - await WithConnectionAsync( - CreateConnection(testHttpHandler), - async (connection, closed) => - { - var receiveTcs = new TaskCompletionSource(); - connection.OnReceived((data, state) => - { - var tcs = ((TaskCompletionSource)state); - tcs.TrySetResult(Encoding.UTF8.GetString(data)); - return Task.CompletedTask; - }, receiveTcs); - - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - Assert.Contains("42", await receiveTcs.Task.OrTimeout()); - }); - } - - [Fact] - public async Task CanReceiveDataEvenIfExceptionThrownFromPreviousReceivedEvent() - { - var testHttpHandler = new TestHttpMessageHandler(); - - testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.OK, "42")); - testHttpHandler.OnSocketSend((_, __) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); - - await WithConnectionAsync( - CreateConnection(testHttpHandler), - async (connection, closed) => - { - var receiveTcs = new TaskCompletionSource(); - var receivedRaised = false; - connection.OnReceived((data, state) => - { - if (!receivedRaised) - { - receivedRaised = true; - return Task.FromException(new InvalidOperationException()); - } - - receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); - return Task.CompletedTask; - }, receiveTcs); - - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - Assert.Contains("42", await receiveTcs.Task.OrTimeout()); - Assert.True(receivedRaised); - }); - } - - [Fact] - public async Task CanReceiveDataEvenIfExceptionThrownSynchronouslyFromPreviousReceivedEvent() - { - var testHttpHandler = new TestHttpMessageHandler(); - - testHttpHandler.OnLongPoll(cancellationToken => ResponseUtils.CreateResponse(HttpStatusCode.OK, "42")); - testHttpHandler.OnSocketSend((_, __) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); - - await WithConnectionAsync( - CreateConnection(testHttpHandler), - async (connection, closed) => - { - var receiveTcs = new TaskCompletionSource(); - var receivedRaised = false; - connection.OnReceived((data, state) => - { - if (!receivedRaised) - { - receivedRaised = true; - throw new InvalidOperationException(); - } - - receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); - return Task.CompletedTask; - }, receiveTcs); - - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - Assert.Contains("42", await receiveTcs.Task.OrTimeout()); - Assert.True(receivedRaised); - }); - } - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs similarity index 57% rename from test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs rename to test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs index 968be6b051..a4ced7e517 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.SendAsync.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs @@ -2,20 +2,53 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO.Pipelines; using System.Net; using System.Net.Http; +using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Client.Http; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HttpConnectionTests { - public class SendAsync + public class Transport { + [Fact] + public async Task CanReceiveData() + { + var testHttpHandler = new TestHttpMessageHandler(); + + // Set the long poll up to return a single message over a few polls. + var requestCount = 0; + var messageFragments = new[] {"This ", "is ", "a ", "test"}; + testHttpHandler.OnLongPoll(cancellationToken => + { + if (requestCount >= messageFragments.Length) + { + return ResponseUtils.CreateResponse(HttpStatusCode.NoContent); + } + + var resp = ResponseUtils.CreateResponse(HttpStatusCode.OK, messageFragments[requestCount]); + requestCount += 1; + return resp; + }); + testHttpHandler.OnSocketSend((_, __) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + await WithConnectionAsync( + CreateConnection(testHttpHandler), + async (connection) => + { + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + Assert.Contains("This is a test", Encoding.UTF8.GetString(await connection.Transport.Input.ReadAllAsync())); + }); + } + [Fact] public async Task CanSendData() { @@ -36,11 +69,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(testHttpHandler), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.SendAsync(data).OrTimeout(); + await connection.Transport.Output.WriteAsync(data).OrTimeout(); Assert.Equal(data, await sendTcs.Task.OrTimeout()); @@ -53,74 +86,44 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { return WithConnectionAsync( CreateConnection(), - async (connection, closed) => + async (connection) => { var exception = await Assert.ThrowsAsync( - () => connection.SendAsync(new byte[0]).OrTimeout()); - Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message); + () => connection.Transport.Output.WriteAsync(new byte[0]).OrTimeout()); + Assert.Equal($"Cannot access the {nameof(Transport)} pipe before the connection has started.", exception.Message); }); } [Fact] - public Task SendThrowsIfConnectionIsStopped() + public Task TransportPipeCannotBeAccessedAfterConnectionIsDisposed() { return WithConnectionAsync( CreateConnection(), - async (connection, closed) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.StopAsync().OrTimeout(); - - var exception = await Assert.ThrowsAsync( - () => connection.SendAsync(new byte[0]).OrTimeout()); - Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message); - }); - } - - [Fact] - public Task SendThrowsIfConnectionIsDisposed() - { - return WithConnectionAsync( - CreateConnection(), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); await connection.DisposeAsync().OrTimeout(); - var exception = await Assert.ThrowsAsync( - () => connection.SendAsync(new byte[0]).OrTimeout()); - Assert.Equal("Cannot send messages when the connection is not in the Connected state.", exception.Message); + var exception = await Assert.ThrowsAsync( + () => connection.Transport.Output.WriteAsync(new byte[0]).OrTimeout()); + Assert.Equal(nameof(HttpConnection), exception.ObjectName); }); } [Fact] - public async Task ExceptionOnSendAsyncClosesWithError() + public Task TransportIsShutDownAfterDispose() { - var testHttpHandler = new TestHttpMessageHandler(); - - var longPollTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - testHttpHandler.OnLongPoll(cancellationToken => - { - cancellationToken.Register(() => longPollTcs.TrySetResult(null)); - - return longPollTcs.Task; - }); - - testHttpHandler.OnSocketSend((buf, cancellationToken) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError)); - }); - - await WithConnectionAsync( - CreateConnection(testHttpHandler), - async (connection, closed) => + var transport = new TestTransport(); + return WithConnectionAsync( + CreateConnection(transport: transport), + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); + await connection.DisposeAsync().OrTimeout(); - await connection.SendAsync(new byte[] { 0 }).OrTimeout(); - - var exception = await Assert.ThrowsAsync(() => closed.OrTimeout()); + // This will throw OperationCancelledException if it's forcibly terminated + // which we don't want + await transport.Receiving.OrTimeout(); }); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 27227c1b4a..381ec22886 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -2,17 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Net; using System.Net.Http; using System.Security.Cryptography.X509Certificates; -using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Client.Http; -using Microsoft.AspNetCore.Sockets.Client.Http.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Logging.Testing; @@ -54,87 +50,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests () => new HttpConnection(new Uri("http://fakeuri.org/"), requestedTransportType)); } - [Fact] - public async Task EventsAreNotRunningOnMainLoop() - { - var testTransport = new TestTransport(); - - await WithConnectionAsync( - CreateConnection(transport: testTransport), - async (connection, closed) => - { - // Block up the OnReceived callback until we finish the test. - var onReceived = new SyncPoint(); - connection.OnReceived(_ => onReceived.WaitToContinue().OrTimeout()); - - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - - // This will trigger the received callback - await testTransport.Application.Output.WriteAsync(new byte[] { 1 }); - - // Wait to hit the sync point. We are now blocking up the TaskQueue - await onReceived.WaitForSyncPoint().OrTimeout(); - - // Now we write something else and we want to test that the HttpConnection receive loop is still - // removing items from the channel even though OnReceived is blocked up. - await testTransport.Application.Output.WriteAsync(new byte[] { 1 }); - - // Now that we've written, we wait for WaitToReadAsync to return an INCOMPLETE task. It will do so - // once HttpConnection reads the message. We also use a CTS to timeout in case the loop is indeed blocked - var cts = new CancellationTokenSource(); - cts.CancelAfter(TimeSpan.FromSeconds(5)); - while (testTransport.Application.Input.WaitToReadAsync().IsCompleted && !cts.IsCancellationRequested) - { - // Yield to allow the HttpConnection to dequeue the message - await Task.Yield(); - } - - // If we exited because we were cancelled, throw. - cts.Token.ThrowIfCancellationRequested(); - - // We're free! Unblock onreceived - onReceived.Continue(); - }); - } - - [Fact] - public async Task EventQueueTimeout() - { - using (StartLog(out var loggerFactory)) - { - var logger = loggerFactory.CreateLogger(); - - var testTransport = new TestTransport(); - - await WithConnectionAsync( - CreateConnection(transport: testTransport), - async (connection, closed) => - { - var onReceived = new SyncPoint(); - connection.OnReceived(_ => onReceived.WaitToContinue().OrTimeout()); - - logger.LogInformation("Starting connection"); - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - logger.LogInformation("Started connection"); - - await testTransport.Application.Output.WriteAsync(new byte[] { 1 }); - await onReceived.WaitForSyncPoint().OrTimeout(); - - // Dispose should complete, even though the receive callbacks are completely blocked up. - logger.LogInformation("Disposing connection"); - await connection.DisposeAsync().OrTimeout(TimeSpan.FromSeconds(10)); - logger.LogInformation("Disposed connection"); - - // Clear up blocked tasks. - onReceived.Continue(); - }); - } - } - [Fact] public async Task HttpOptionsSetOntoHttpClientHandler() { - var testHttpHandler = new TestHttpMessageHandler(); + var testHttpHandler = TestHttpMessageHandler.CreateDefault(); var negotiateUrlTcs = new TaskCompletionSource(); testHttpHandler.OnNegotiate((request, cancellationToken) => @@ -146,7 +65,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests HttpClientHandler httpClientHandler = null; - HttpOptions httpOptions = new HttpOptions(); + var httpOptions = new HttpOptions(); httpOptions.HttpMessageHandler = inner => { httpClientHandler = (HttpClientHandler)inner; @@ -161,7 +80,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await WithConnectionAsync( CreateConnection(httpOptions, url: "http://fakeuri.org/"), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); }); @@ -198,7 +117,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await WithConnectionAsync( CreateConnection(httpOptions, loggerFactory: mockLoggerFactory.Object, url: "http://fakeuri.org/"), - async (connection, closed) => + async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); }); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs deleted file mode 100644 index 950c68e752..0000000000 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs +++ /dev/null @@ -1,204 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.Extensions.Logging; -using Newtonsoft.Json; -using Xunit; - -namespace Microsoft.AspNetCore.SignalR.Client.Tests -{ - public class HubConnectionExtensionsTests - { - [Fact] - public async Task On() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - () => tcs.SetResult(new object[0])), - new object[0]); - } - - [Fact] - public async Task OnT1() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - r => tcs.SetResult(new object[] { r })), - new object[] { 42 }); - } - - [Fact] - public async Task OnT2() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2) => tcs.SetResult(new object[] { r1, r2 })), - new object[] { 42, "abc" }); - } - - [Fact] - public async Task OnT3() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2, r3) => tcs.SetResult(new object[] { r1, r2, r3 })), - new object[] { 42, "abc", 24.0f }); - } - - [Fact] - public async Task OnT4() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2, r3, r4) => tcs.SetResult(new object[] { r1, r2, r3, r4 })), - new object[] { 42, "abc", 24.0f, 10d }); - } - - [Fact] - public async Task OnT5() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2, r3, r4, r5) => tcs.SetResult(new object[] { r1, r2, r3, r4, r5 })), - new object[] { 42, "abc", 24.0f, 10d, "123" }); - } - - [Fact] - public async Task OnT6() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2, r3, r4, r5, r6) => tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6 })), - new object[] { 42, "abc", 24.0f, 10d, "123", 24 }); - } - - [Fact] - public async Task OnT7() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2, r3, r4, r5, r6, r7) => tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7 })), - new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c' }); - } - - [Fact] - public async Task OnT8() - { - await InvokeOn( - (hubConnection, tcs) => hubConnection.On("Foo", - (r1, r2, r3, r4, r5, r6, r7, r8) => tcs.SetResult(new object[] { r1, r2, r3, r4, r5, r6, r7, r8 })), - new object[] { 42, "abc", 24.0f, 10d, "123", 24, 'c', "XYZ" }); - } - - private async Task InvokeOn(Action> onAction, object[] args) - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - var handlerTcs = new TaskCompletionSource(); - try - { - onAction(hubConnection, handlerTcs); - await hubConnection.StartAsync(); - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - await connection.ReceiveJsonMessage( - new - { - invocationId = "1", - type = 1, - target = "Foo", - arguments = args - }).OrTimeout(); - - var result = await handlerTcs.Task.OrTimeout(); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task ConnectionNotClosedOnCallbackArgumentCountMismatch() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - var receiveTcs = new TaskCompletionSource(); - - try - { - hubConnection.On("Foo", r => { receiveTcs.SetResult(r); }); - await hubConnection.StartAsync().OrTimeout(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - await connection.ReceiveJsonMessage( - new - { - invocationId = "1", - type = 1, - target = "Foo", - arguments = new object[] { 42, "42" } - }).OrTimeout(); - - await connection.ReceiveJsonMessage( - new - { - invocationId = "2", - type = 1, - target = "Foo", - arguments = new object[] { 42 } - }).OrTimeout(); - - Assert.Equal(42, await receiveTcs.Task.OrTimeout()); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task ConnectionNotClosedOnCallbackArgumentTypeMismatch() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - var receiveTcs = new TaskCompletionSource(); - - try - { - hubConnection.On("Foo", r => { receiveTcs.SetResult(r); }); - await hubConnection.StartAsync().OrTimeout(); - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - await connection.ReceiveJsonMessage( - new - { - invocationId = "1", - type = 1, - target = "Foo", - arguments = new object[] { "xxx" } - }).OrTimeout(); - - await connection.ReceiveJsonMessage( - new - { - invocationId = "2", - type = 1, - target = "Foo", - arguments = new object[] { 42 } - }).OrTimeout(); - - Assert.Equal(42, await receiveTcs.Task.OrTimeout()); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - } - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs deleted file mode 100644 index c8453ccc58..0000000000 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ /dev/null @@ -1,429 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Channels; -using System.Threading.Tasks; -using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.Extensions.Logging; -using Xunit; - -namespace Microsoft.AspNetCore.SignalR.Client.Tests -{ - // This includes tests that verify HubConnection conforms to the Hub Protocol, without setting up a full server (even TestServer). - // We can also have more control over the messages we send to HubConnection in order to ensure that protocol errors and other quirks - // don't cause problems. - public class HubConnectionProtocolTests - { - [Fact] - public async Task SendAsyncSendsANonBlockingInvocationMessage() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var invokeTask = hubConnection.SendAsync("Foo"); - - var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); - - Assert.Equal("{\"type\":1,\"target\":\"Foo\",\"arguments\":[]}\u001e", invokeMessage); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task ClientSendsHandshakeMessageWhenStartingConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - var handshakeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); - - Assert.Equal("{\"protocol\":\"json\",\"version\":1}\u001e", handshakeMessage); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task InvokeSendsAnInvocationMessage() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var invokeTask = hubConnection.InvokeAsync("Foo"); - - var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); - - Assert.Equal("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Foo\",\"arguments\":[]}\u001e", invokeMessage); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task ReceiveCloseMessageWithoutErrorWillCloseHubConnection() - { - TaskCompletionSource closedTcs = new TaskCompletionSource(); - - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - hubConnection.Closed += e => closedTcs.SetResult(e); - - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - await connection.ReceiveJsonMessage(new { type = 7 }).OrTimeout(); - - Exception closeException = await closedTcs.Task.OrTimeout(); - Assert.Null(closeException); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task ReceiveCloseMessageWithErrorWillCloseHubConnection() - { - TaskCompletionSource closedTcs = new TaskCompletionSource(); - - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - hubConnection.Closed += e => closedTcs.SetResult(e); - - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - await connection.ReceiveJsonMessage(new { type = 7, error = "Error!" }).OrTimeout(); - - Exception closeException = await closedTcs.Task.OrTimeout(); - Assert.NotNull(closeException); - Assert.Equal("Error!", closeException.Message); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task StreamSendsAnInvocationMessage() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var channel = await hubConnection.StreamAsChannelAsync("Foo"); - - var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); - - Assert.Equal("{\"type\":4,\"invocationId\":\"1\",\"target\":\"Foo\",\"arguments\":[]}\u001e", invokeMessage); - - // Complete the channel - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); - await channel.Completion; - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task InvokeCompletedWhenCompletionMessageReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var invokeTask = hubConnection.InvokeAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); - - await invokeTask.OrTimeout(); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task StreamCompletesWhenCompletionMessageIsReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var channel = await hubConnection.StreamAsChannelAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); - - Assert.Empty(await channel.ReadAllAsync()); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task InvokeYieldsResultWhenCompletionMessageReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var invokeTask = hubConnection.InvokeAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout(); - - Assert.Equal(42, await invokeTask.OrTimeout()); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var invokeTask = hubConnection.InvokeAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); - - var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); - Assert.Equal("An error occurred", ex.Message); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task StreamFailsIfCompletionMessageHasPayload() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var channel = await hubConnection.StreamAsChannelAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = "Oops" }).OrTimeout(); - - var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); - Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task StreamFailsWithExceptionWhenCompletionWithErrorReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var channel = await hubConnection.StreamAsChannelAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); - - var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); - Assert.Equal("An error occurred", ex.Message); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task InvokeFailsWithErrorWhenStreamingItemReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var invokeTask = hubConnection.InvokeAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = 42 }).OrTimeout(); - - var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); - Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsChannelAsync' method.", ex.Message); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task StreamYieldsItemsAsTheyArrive() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - var channel = await hubConnection.StreamAsChannelAsync("Foo"); - - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "1" }).OrTimeout(); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "2" }).OrTimeout(); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "3" }).OrTimeout(); - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); - - var notifications = await channel.ReadAllAsync().OrTimeout(); - - Assert.Equal(new[] { "1", "2", "3", }, notifications.ToArray()); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task HandlerRegisteredWithOnIsFiredWhenInvocationReceived() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - var handlerCalled = new TaskCompletionSource(); - try - { - await hubConnection.StartAsync(); - - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - hubConnection.On("Foo", (r1, r2, r3) => handlerCalled.TrySetResult(new object[] { r1, r2, r3 })); - - var args = new object[] { 1, "Foo", 2.0f }; - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 1, target = "Foo", arguments = args }).OrTimeout(); - - Assert.Equal(args, await handlerCalled.Task.OrTimeout()); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - - [Fact] - public async Task AcceptsPingMessages() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, - new JsonHubProtocol(), new LoggerFactory()); - - try - { - await hubConnection.StartAsync().OrTimeout(); - - // Ignore handshake message - await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - // Send an invocation - var invokeTask = hubConnection.InvokeAsync("Foo"); - - // Receive the ping mid-invocation so we can see that the rest of the flow works fine - await connection.ReceiveJsonMessage(new { type = 6 }).OrTimeout(); - - // Receive a completion - await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); - - // Ensure the invokeTask completes properly - await invokeTask.OrTimeout(); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); - } - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs new file mode 100644 index 0000000000..37a51a4337 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs @@ -0,0 +1,351 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets.Client; +using Newtonsoft.Json.Linq; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public partial class HubConnectionTests + { + public class ConnectionLifecycle + { + // This tactic (using names and a dictionary) allows non-serializable data (like a Func) to be used in a theory AND get it to show in the new hierarchical view in Test Explorer as separate tests you can run individually. + private static readonly IDictionary> MethodsThatRequireActiveConnection = new Dictionary>() + { + { nameof(HubConnection.InvokeAsync), (connection) => connection.InvokeAsync("Foo") }, + { nameof(HubConnection.SendAsync), (connection) => connection.SendAsync("Foo") }, + { nameof(HubConnection.StreamAsChannelAsync), (connection) => connection.StreamAsChannelAsync("Foo") }, + }; + + public static IEnumerable MethodsNamesThatRequireActiveConnection => MethodsThatRequireActiveConnection.Keys.Select(k => new object[] { k }); + + [Fact] + public async Task StartAsyncStartsTheUnderlyingConnection() + { + var testConnection = new TestConnection(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + await connection.StartAsync(); + Assert.True(testConnection.Started.IsCompleted); + }); + } + + [Fact] + public async Task StartAsyncWaitsForPreviousStartIfAlreadyStarting() + { + // Set up StartAsync to wait on the syncPoint when starting + var testConnection = new TestConnection(onStart: SyncPoint.Create(out var syncPoint)); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + var firstStart = connection.StartAsync().OrTimeout(); + Assert.False(firstStart.IsCompleted); + + // Wait for us to be in IConnection.StartAsync + await syncPoint.WaitForSyncPoint(); + + // Try starting again + var secondStart = connection.StartAsync().OrTimeout(); + Assert.False(secondStart.IsCompleted); + + // Release the sync point + syncPoint.Continue(); + + // Both starts should finish fine + await firstStart; + await secondStart; + }); + } + + [Fact] + public async Task StartingAfterStopCreatesANewConnection() + { + // Set up StartAsync to wait on the syncPoint when starting + var createCount = 0; + IConnection ConnectionFactory() + { + createCount += 1; + return new TestConnection(); + } + + await AsyncUsing(new HubConnection(ConnectionFactory, new JsonHubProtocol()), async connection => + { + await connection.StartAsync().OrTimeout(); + Assert.Equal(1, createCount); + await connection.StopAsync().OrTimeout(); + + await connection.StartAsync().OrTimeout(); + Assert.Equal(2, createCount); + }); + } + + [Fact] + public async Task StartingDuringStopCreatesANewConnection() + { + // Set up StartAsync to wait on the syncPoint when starting + var createCount = 0; + var onDisposeForFirstConnection = SyncPoint.Create(out var syncPoint); + IConnection ConnectionFactory() + { + createCount += 1; + return new TestConnection(onDispose: createCount == 1 ? onDisposeForFirstConnection : null); + } + + await AsyncUsing(new HubConnection(ConnectionFactory, new JsonHubProtocol()), async connection => + { + await connection.StartAsync().OrTimeout(); + Assert.Equal(1, createCount); + + var stopTask = connection.StopAsync().OrTimeout(); + + // Wait to hit DisposeAsync on TestConnection (which should be after StopAsync has cleared the connection state) + await syncPoint.WaitForSyncPoint(); + + // We should be able to start now, and StopAsync hasn't completed, nor will it complete while Starting + Assert.False(stopTask.IsCompleted); + await connection.StartAsync().OrTimeout(); + Assert.False(stopTask.IsCompleted); + + // When we release the sync point, the StopAsync task will finish + syncPoint.Continue(); + await stopTask; + }); + } + + [Theory] + [MemberData(nameof(MethodsNamesThatRequireActiveConnection))] + public async Task MethodsThatRequireStartedConnectionFailIfConnectionNotYetStarted(string name) + { + var method = MethodsThatRequireActiveConnection[name]; + + var testConnection = new TestConnection(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + var ex = await Assert.ThrowsAsync(() => method(connection)); + Assert.Equal($"The '{name}' method cannot be called if the connection is not active", ex.Message); + }); + } + + [Theory] + [MemberData(nameof(MethodsNamesThatRequireActiveConnection))] + public async Task MethodsThatRequireStartedConnectionWaitForStartIfConnectionIsCurrentlyStarting(string name) + { + var method = MethodsThatRequireActiveConnection[name]; + + // Set up StartAsync to wait on the syncPoint when starting + var testConnection = new TestConnection(onStart: SyncPoint.Create(out var syncPoint)); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + // Start, and wait for the sync point to be hit + var startTask = connection.StartAsync().OrTimeout(); + Assert.False(startTask.IsCompleted); + await syncPoint.WaitForSyncPoint(); + + // Run the method, but it will be waiting for the lock + var targetTask = method(connection).OrTimeout(); + + // Release the SyncPoint + syncPoint.Continue(); + + // Wait for start to finish + await startTask; + + // We need some special logic to ensure InvokeAsync completes. + if (string.Equals(name, nameof(HubConnection.InvokeAsync))) + { + await ForceLastInvocationToComplete(testConnection); + } + + // Wait for the method to complete. + await targetTask; + }); + } + + [Fact] + public async Task StopAsyncStopsConnection() + { + var testConnection = new TestConnection(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + await connection.StartAsync().OrTimeout(); + Assert.True(testConnection.Started.IsCompleted); + + await connection.StopAsync().OrTimeout(); + Assert.True(testConnection.Disposed.IsCompleted); + }); + } + + [Fact] + public async Task StopAsyncNoOpsIfConnectionNotYetStarted() + { + var testConnection = new TestConnection(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + await connection.StopAsync().OrTimeout(); + Assert.False(testConnection.Disposed.IsCompleted); + }); + } + + [Fact] + public async Task StopAsyncNoOpsIfConnectionAlreadyStopped() + { + var testConnection = new TestConnection(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + await connection.StartAsync().OrTimeout(); + Assert.True(testConnection.Started.IsCompleted); + + await connection.StopAsync().OrTimeout(); + Assert.True(testConnection.Disposed.IsCompleted); + + await connection.StopAsync().OrTimeout(); + }); + } + + [Fact] + public async Task CompletingTheTransportSideMarksConnectionAsClosed() + { + var testConnection = new TestConnection(); + var closed = new TaskCompletionSource(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + connection.Closed += (e) => closed.TrySetResult(null); + await connection.StartAsync().OrTimeout(); + Assert.True(testConnection.Started.IsCompleted); + + // Complete the transport side and wait for the connection to close + testConnection.CompleteFromTransport(); + await closed.Task.OrTimeout(); + + // We should be stopped now + var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo").OrTimeout()); + Assert.Equal($"The '{nameof(HubConnection.SendAsync)}' method cannot be called if the connection is not active", ex.Message); + }); + } + + [Fact] + public async Task TransportCompletionWhileShuttingDownIsNoOp() + { + var testConnection = new TestConnection(); + var testConnectionClosed = new TaskCompletionSource(); + var connectionClosed = new TaskCompletionSource(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + // We're hooking the TestConnection shutting down here because the HubConnection one will be blocked on the lock + testConnection.Transport.Input.OnWriterCompleted((_, __) => testConnectionClosed.TrySetResult(null), null); + connection.Closed += (e) => connectionClosed.TrySetResult(null); + + await connection.StartAsync().OrTimeout(); + Assert.True(testConnection.Started.IsCompleted); + + // Start shutting down and complete the transport side + var stopTask = connection.StopAsync().OrTimeout(); + testConnection.CompleteFromTransport(); + + // Wait for the connection to close. + await testConnectionClosed.Task.OrTimeout(); + + // The stop should be completed. + await stopTask; + + // The HubConnection should now be closed. + await connectionClosed.Task.OrTimeout(); + + // We should be stopped now + var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo").OrTimeout()); + Assert.Equal($"The '{nameof(HubConnection.SendAsync)}' method cannot be called if the connection is not active", ex.Message); + + Assert.Equal(1, testConnection.DisposeCount); + }); + } + + [Fact] + public async Task StopAsyncDuringUnderlyingConnectionCloseWaitsAndNoOps() + { + var testConnection = new TestConnection(); + var connectionClosed = new TaskCompletionSource(); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + connection.Closed += (e) => connectionClosed.TrySetResult(null); + + await connection.StartAsync().OrTimeout(); + Assert.True(testConnection.Started.IsCompleted); + + // Complete the transport side and wait for the connection to close + testConnection.CompleteFromTransport(); + + // Start stopping manually (these can't be synchronized by a Sync Point because the transport is disposed outside the lock) + var stopTask = connection.StopAsync().OrTimeout(); + + await testConnection.Disposed.OrTimeout(); + + // Wait for the stop task to complete and the closed event to fire + await stopTask; + await connectionClosed.Task.OrTimeout(); + + // We should be stopped now + var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo").OrTimeout()); + Assert.Equal($"The '{nameof(HubConnection.SendAsync)}' method cannot be called if the connection is not active", ex.Message); + }); + } + + [Theory] + [MemberData(nameof(MethodsNamesThatRequireActiveConnection))] + public async Task MethodsThatRequireActiveConnectionWaitForStopAndFailIfConnectionIsCurrentlyStopping(string methodName) + { + var method = MethodsThatRequireActiveConnection[methodName]; + + // Set up StartAsync to wait on the syncPoint when starting + var testConnection = new TestConnection(onDispose: SyncPoint.Create(out var syncPoint)); + await AsyncUsing(new HubConnection(() => testConnection, new JsonHubProtocol()), async connection => + { + await connection.StartAsync().OrTimeout(); + + // Stop and invoke the method. These two aren't synchronizable via a Sync Point any more because the transport is disposed + // outside the lock :( + var disposeTask = connection.StopAsync().OrTimeout(); + var targetTask = method(connection).OrTimeout(); + + // Release the sync point + syncPoint.Continue(); + + // Wait for the method to complete, with an expected error. + var ex = await Assert.ThrowsAsync(() => targetTask); + Assert.Equal($"The '{methodName}' method cannot be called if the connection is not active", ex.Message); + + await disposeTask; + }); + } + + private static async Task ForceLastInvocationToComplete(TestConnection testConnection) + { + // We need to "complete" the invocation + var message = await testConnection.ReadSentTextMessageAsync(); + var json = JObject.Parse(message); // Gotta remove the record separator. + await testConnection.ReceiveJsonMessage(new + { + type = HubProtocolConstants.CompletionMessageType, + invocationId = json["invocationId"], + }); + } + + // A helper that we wouldn't want to use in product code, but is fine for testing until IAsyncDisposable arrives :) + private static async Task AsyncUsing(HubConnection connection, Func action) + { + try + { + await action(connection); + } + finally + { + // Dispose isn't under test here, so fire and forget so that errors/timeouts here don't cause + // test errors that mask the real errors. + _ = connection.DisposeAsync(); + } + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Extensions.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Extensions.cs new file mode 100644 index 0000000000..7d47e53c0f --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Extensions.cs @@ -0,0 +1,202 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; +using Newtonsoft.Json; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public partial class HubConnectionTests + { + public class Extensions + { + [Fact] + public async Task On() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + () => tcs.SetResult(new object[0])), + new object[0]); + } + + [Fact] + public async Task OnT1() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + r => tcs.SetResult(new object[] {r})), + new object[] {42}); + } + + [Fact] + public async Task OnT2() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2) => tcs.SetResult(new object[] {r1, r2})), + new object[] {42, "abc"}); + } + + [Fact] + public async Task OnT3() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3) => tcs.SetResult(new object[] {r1, r2, r3})), + new object[] {42, "abc", 24.0f}); + } + + [Fact] + public async Task OnT4() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4) => tcs.SetResult(new object[] {r1, r2, r3, r4})), + new object[] {42, "abc", 24.0f, 10d}); + } + + [Fact] + public async Task OnT5() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5) => tcs.SetResult(new object[] {r1, r2, r3, r4, r5})), + new object[] {42, "abc", 24.0f, 10d, "123"}); + } + + [Fact] + public async Task OnT6() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6) => tcs.SetResult(new object[] {r1, r2, r3, r4, r5, r6})), + new object[] {42, "abc", 24.0f, 10d, "123", 24}); + } + + [Fact] + public async Task OnT7() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7) => tcs.SetResult(new object[] {r1, r2, r3, r4, r5, r6, r7})), + new object[] {42, "abc", 24.0f, 10d, "123", 24, 'c'}); + } + + [Fact] + public async Task OnT8() + { + await InvokeOn( + (hubConnection, tcs) => hubConnection.On("Foo", + (r1, r2, r3, r4, r5, r6, r7, r8) => tcs.SetResult(new object[] {r1, r2, r3, r4, r5, r6, r7, r8})), + new object[] {42, "abc", 24.0f, 10d, "123", 24, 'c', "XYZ"}); + } + + private async Task InvokeOn(Action> onAction, object[] args) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + var handlerTcs = new TaskCompletionSource(); + try + { + onAction(hubConnection, handlerTcs); + await hubConnection.StartAsync(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "1", + type = 1, + target = "Foo", + arguments = args + }).OrTimeout(); + + await handlerTcs.Task.OrTimeout(); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task ConnectionNotClosedOnCallbackArgumentCountMismatch() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + var receiveTcs = new TaskCompletionSource(); + + try + { + hubConnection.On("Foo", r => { receiveTcs.SetResult(r); }); + await hubConnection.StartAsync().OrTimeout(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "1", + type = 1, + target = "Foo", + arguments = new object[] {42, "42"} + }).OrTimeout(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "2", + type = 1, + target = "Foo", + arguments = new object[] {42} + }).OrTimeout(); + + Assert.Equal(42, await receiveTcs.Task.OrTimeout()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task ConnectionNotClosedOnCallbackArgumentTypeMismatch() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + var receiveTcs = new TaskCompletionSource(); + + try + { + hubConnection.On("Foo", r => { receiveTcs.SetResult(r); }); + await hubConnection.StartAsync().OrTimeout(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "1", + type = 1, + target = "Foo", + arguments = new object[] {"xxx"} + }).OrTimeout(); + + await connection.ReceiveJsonMessage( + new + { + invocationId = "2", + type = 1, + target = "Foo", + arguments = new object[] {42} + }).OrTimeout(); + + Assert.Equal(42, await receiveTcs.Task.OrTimeout()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Helpers.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Helpers.cs new file mode 100644 index 0000000000..6aa20afff0 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Helpers.cs @@ -0,0 +1,16 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public partial class HubConnectionTests + { + private static HubConnection CreateHubConnection(TestConnection connection, IHubProtocol protocol = null) + { + return new HubConnection(() => connection, protocol ?? new JsonHubProtocol(), new LoggerFactory()); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs new file mode 100644 index 0000000000..2cabce8d80 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.Protocol.cs @@ -0,0 +1,408 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using System.Threading.Channels; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + // This includes tests that verify HubConnection conforms to the Hub Protocol, without setting up a full server (even TestServer). + // We can also have more control over the messages we send to HubConnection in order to ensure that protocol errors and other quirks + // don't cause problems. + public partial class HubConnectionTests + { + public class Protocol + { + [Fact] + public async Task SendAsyncSendsANonBlockingInvocationMessage() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var invokeTask = hubConnection.SendAsync("Foo").OrTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) + Assert.Equal("{\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task ClientSendsHandshakeMessageWhenStartingConnection() + { + var connection = new TestConnection(autoNegotiate: false); + var hubConnection = CreateHubConnection(connection); + try + { + // We can't await StartAsync because it depends on the negotiate process! + var startTask = hubConnection.StartAsync().OrTimeout(); + + var handshakeMessage = await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + + // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) + Assert.Equal("{\"protocol\":\"json\",\"version\":1}", handshakeMessage); + + await startTask; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeSendsAnInvocationMessage() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) + Assert.Equal("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task ReceiveCloseMessageWithoutErrorWillCloseHubConnection() + { + TaskCompletionSource closedTcs = new TaskCompletionSource(); + + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + hubConnection.Closed += e => closedTcs.SetResult(e); + + try + { + await hubConnection.StartAsync().OrTimeout(); + + await connection.ReceiveJsonMessage(new {type = 7}).OrTimeout(); + + Exception closeException = await closedTcs.Task.OrTimeout(); + Assert.Null(closeException); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task ReceiveCloseMessageWithErrorWillCloseHubConnection() + { + TaskCompletionSource closedTcs = new TaskCompletionSource(); + + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + hubConnection.Closed += e => closedTcs.SetResult(e); + + try + { + await hubConnection.StartAsync().OrTimeout(); + + await connection.ReceiveJsonMessage(new {type = 7, error = "Error!"}).OrTimeout(); + + Exception closeException = await closedTcs.Task.OrTimeout(); + Assert.NotNull(closeException); + Assert.Equal("The server closed the connection with the following error: Error!", closeException.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamSendsAnInvocationMessage() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); + + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) + Assert.Equal("{\"type\":4,\"invocationId\":\"1\",\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + + // Complete the channel + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + await channel.Completion; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeCompletedWhenCompletionMessageReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + + await invokeTask.OrTimeout(); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamCompletesWhenCompletionMessageIsReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + + Assert.Empty(await channel.ReadAllAsync()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeYieldsResultWhenCompletionMessageReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, result = 42}).OrTimeout(); + + Assert.Equal(42, await invokeTask.OrTimeout()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, error = "An error occurred"}).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("An error occurred", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamFailsIfCompletionMessageHasPayload() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, result = "Oops"}).OrTimeout(); + + var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); + Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamFailsWithExceptionWhenCompletionWithErrorReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3, error = "An error occurred"}).OrTimeout(); + + var ex = await Assert.ThrowsAsync(async () => await channel.ReadAllAsync().OrTimeout()); + Assert.Equal("An error occurred", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task InvokeFailsWithErrorWhenStreamingItemReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = 42}).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => invokeTask).OrTimeout(); + Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsChannelAsync' method.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StreamYieldsItemsAsTheyArrive() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var channel = await hubConnection.StreamAsChannelAsync("Foo").OrTimeout(); + + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = "1"}).OrTimeout(); + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = "2"}).OrTimeout(); + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 2, item = "3"}).OrTimeout(); + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + + var notifications = await channel.ReadAllAsync().OrTimeout(); + + Assert.Equal(new[] {"1", "2", "3",}, notifications.ToArray()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task HandlerRegisteredWithOnIsFiredWhenInvocationReceived() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + var handlerCalled = new TaskCompletionSource(); + try + { + await hubConnection.StartAsync().OrTimeout(); + + hubConnection.On("Foo", (r1, r2, r3) => handlerCalled.TrySetResult(new object[] {r1, r2, r3})); + + var args = new object[] {1, "Foo", 2.0f}; + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 1, target = "Foo", arguments = args}).OrTimeout(); + + Assert.Equal(args, await handlerCalled.Task.OrTimeout()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task AcceptsPingMessages() + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection); + + try + { + await hubConnection.StartAsync().OrTimeout(); + + // Send an invocation + var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + + // Receive the ping mid-invocation so we can see that the rest of the flow works fine + await connection.ReceiveJsonMessage(new {type = 6}).OrTimeout(); + + // Receive a completion + await connection.ReceiveJsonMessage(new {invocationId = "1", type = 3}).OrTimeout(); + + // Ensure the invokeTask completes properly + await invokeTask.OrTimeout(); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 8206aa4dfd..d043e00c1b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -16,44 +16,17 @@ using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests { - public class HubConnectionTests + public partial class HubConnectionTests { - [Fact] - public async Task StartAsyncCallsConnectionStart() - { - var connection = new Mock(); - var protocol = new Mock(); - protocol.SetupGet(p => p.TransferFormat).Returns(TransferFormat.Text); - connection.SetupGet(p => p.Features).Returns(new FeatureCollection()); - connection.Setup(m => m.StartAsync(TransferFormat.Text)).Returns(Task.CompletedTask).Verifiable(); - var hubConnection = new HubConnection(connection.Object, protocol.Object, null); - await hubConnection.StartAsync(); - - connection.Verify(c => c.StartAsync(TransferFormat.Text), Times.Once()); - } - - [Fact] - public async Task DisposeAsyncCallsConnectionStart() - { - var connection = new Mock(); - connection.Setup(m => m.Features).Returns(new FeatureCollection()); - connection.Setup(m => m.StartAsync(TransferFormat.Text)).Verifiable(); - var hubConnection = new HubConnection(connection.Object, Mock.Of(), null); - await hubConnection.DisposeAsync(); - - connection.Verify(c => c.DisposeAsync(), Times.Once()); - } - [Fact] public async Task InvokeThrowsIfSerializingMessageFails() { var exception = new InvalidOperationException(); - var mockProtocol = MockHubProtocol.Throw(exception); - var hubConnection = new HubConnection(new TestConnection(), mockProtocol, null); - await hubConnection.StartAsync(); + var hubConnection = CreateHubConnection(new TestConnection(), protocol: MockHubProtocol.Throw(exception)); + await hubConnection.StartAsync().OrTimeout(); var actualException = - await Assert.ThrowsAsync(async () => await hubConnection.InvokeAsync("test")); + await Assert.ThrowsAsync(async () => await hubConnection.InvokeAsync("test").OrTimeout()); Assert.Same(exception, actualException); } @@ -61,133 +34,49 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task SendAsyncThrowsIfSerializingMessageFails() { var exception = new InvalidOperationException(); - var mockProtocol = MockHubProtocol.Throw(exception); - var hubConnection = new HubConnection(new TestConnection(), mockProtocol, null); - await hubConnection.StartAsync(); + var hubConnection = CreateHubConnection(new TestConnection(), protocol: MockHubProtocol.Throw(exception)); + await hubConnection.StartAsync().OrTimeout(); var actualException = - await Assert.ThrowsAsync(async () => await hubConnection.SendAsync("test")); + await Assert.ThrowsAsync(async () => await hubConnection.SendAsync("test").OrTimeout()); Assert.Same(exception, actualException); } [Fact] public async Task ClosedEventRaisedWhenTheClientIsStopped() { - var hubConnection = new HubConnection(new TestConnection(), Mock.Of(), null); + var hubConnection = new HubConnection(() => new TestConnection(), Mock.Of(), null); var closedEventTcs = new TaskCompletionSource(); hubConnection.Closed += e => closedEventTcs.SetResult(e); await hubConnection.StartAsync().OrTimeout(); - await hubConnection.DisposeAsync().OrTimeout(); + await hubConnection.StopAsync().OrTimeout(); Assert.Null(await closedEventTcs.Task); } - [Fact] - public async Task CannotCallInvokeOnNotStartedHubConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - - var exception = await Assert.ThrowsAsync( - () => hubConnection.InvokeAsync("test")); - - Assert.Equal("The 'InvokeAsync' method cannot be called before the connection has been started.", exception.Message); - } - - [Fact] - public async Task CannotCallInvokeOnClosedHubConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - - await hubConnection.StartAsync(); - await hubConnection.DisposeAsync(); - var exception = await Assert.ThrowsAsync( - () => hubConnection.InvokeAsync("test")); - - Assert.Equal("Connection has been terminated.", exception.Message); - } - - [Fact] - public async Task CannotCallSendOnNotStartedHubConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - - var exception = await Assert.ThrowsAsync( - () => hubConnection.SendAsync("test")); - - Assert.Equal("The 'SendAsync' method cannot be called before the connection has been started.", exception.Message); - } - - [Fact] - public async Task CannotCallSendOnClosedHubConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - - await hubConnection.StartAsync(); - await hubConnection.DisposeAsync(); - var exception = await Assert.ThrowsAsync(() => hubConnection.SendAsync("test")); - - Assert.Equal("Connection has been terminated.", exception.Message); - } - - [Fact] - public async Task CannotCallStreamOnClosedHubConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - - await hubConnection.StartAsync(); - await hubConnection.DisposeAsync(); - var exception = await Assert.ThrowsAsync( - () => hubConnection.StreamAsChannelAsync("test")); - - Assert.Equal("Connection has been terminated.", exception.Message); - } - - [Fact] - public async Task CannotCallStreamOnNotStartedHubConnection() - { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - - var exception = await Assert.ThrowsAsync( - () => hubConnection.StreamAsChannelAsync("test")); - - Assert.Equal("The 'StreamAsChannelAsync' method cannot be called before the connection has been started.", exception.Message); - } - [Fact] public async Task PendingInvocationsAreCancelledWhenConnectionClosesCleanly() { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); + var hubConnection = CreateHubConnection(new TestConnection()); - await hubConnection.StartAsync(); - var invokeTask = hubConnection.InvokeAsync("testMethod"); - await hubConnection.DisposeAsync(); + await hubConnection.StartAsync().OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("testMethod").OrTimeout(); + await hubConnection.StopAsync().OrTimeout(); await Assert.ThrowsAsync(async () => await invokeTask); } [Fact] - public async Task PendingInvocationsAreTerminatedWithExceptionWhenConnectionClosesDueToError() + public async Task PendingInvocationsAreTerminatedWithExceptionWhenTransportCompletesWithError() { - var mockConnection = new Mock(); - mockConnection.SetupGet(p => p.Features).Returns(new FeatureCollection()); - mockConnection - .Setup(m => m.DisposeAsync()) - .Returns(Task.FromResult(null)); + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, protocol: Mock.Of()); - var hubConnection = new HubConnection(mockConnection.Object, Mock.Of(), new LoggerFactory()); - - await hubConnection.StartAsync(); - var invokeTask = hubConnection.InvokeAsync("testMethod"); + await hubConnection.StartAsync().OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("testMethod").OrTimeout(); var exception = new InvalidOperationException(); - mockConnection.Raise(m => m.Closed += null, exception); + connection.CompleteFromTransport(exception); var actualException = await Assert.ThrowsAsync(async () => await invokeTask); Assert.Equal(exception, actualException); @@ -196,9 +85,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [Fact] public async Task ConnectionTerminatedIfServerTimeoutIntervalElapsesWithNoMessages() { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - + var hubConnection = CreateHubConnection(new TestConnection()); hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(100); var closeTcs = new TaskCompletionSource(); @@ -211,18 +98,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task OnReceivedAfterTimerDisposedDoesNotThrow() + public async Task PendingInvocationsAreTerminatedIfServerTimeoutIntervalElapsesWithNoMessages() { - var connection = new TestConnection(); - var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory()); - await hubConnection.StartAsync().OrTimeout(); - await hubConnection.DisposeAsync().OrTimeout(); + var hubConnection = CreateHubConnection(new TestConnection()); + hubConnection.ServerTimeout = TimeSpan.FromMilliseconds(500); - // Fire callbacks, they shouldn't fail - foreach (var registration in connection.Callbacks) - { - await registration.InvokeAsync(new byte[0]); - } + await hubConnection.StartAsync().OrTimeout(); + + // Start an invocation (but we won't complete it) + var invokeTask = hubConnection.InvokeAsync("Method").OrTimeout(); + + var exception = await Assert.ThrowsAsync(() => invokeTask); + Assert.Equal("Server timeout (500.00ms) elapsed without receiving a message from the server.", exception.Message); } // Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse @@ -231,9 +118,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests private HubInvocationMessage _parsed; private Exception _error; - public int ParseCalls { get; private set; } = 0; - public int WriteCalls { get; private set; } = 0; - public static MockHubProtocol ReturnOnParse(HubInvocationMessage parsed) { return new MockHubProtocol @@ -262,7 +146,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public bool TryParseMessages(ReadOnlyMemory input, IInvocationBinder binder, IList messages) { - ParseCalls += 1; if (_error != null) { throw _error; @@ -278,8 +161,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public void WriteMessage(HubMessage message, Stream output) { - WriteCalls += 1; - if (_error != null) { throw _error; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 658a4e16f4..2c62cd6ba6 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.IO.Pipelines; -using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; @@ -13,16 +12,15 @@ using System.Text; using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; +using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.SignalR.Client.Tests; -using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Sockets.Client.Internal; using Moq; using Moq.Protected; using Xunit; -namespace Microsoft.AspNetCore.Client.Tests +namespace Microsoft.AspNetCore.SignalR.Client.Tests { public class LongPollingTransportTests { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index 1d684ae82f..833693082f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -4,7 +4,6 @@ using System; using System.IO; using System.IO.Pipelines; -using System.Linq; using System.Net.Http; using System.Net.Http.Headers; using System.Reflection; @@ -17,6 +16,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Sockets.Client.Internal; using Moq; using Moq.Protected; using Xunit; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs new file mode 100644 index 0000000000..d39d24af55 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs @@ -0,0 +1,80 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + // Possibly useful as a general-purpose async testing helper? + public class SyncPoint + { + private readonly TaskCompletionSource _atSyncPoint = new TaskCompletionSource(); + private readonly TaskCompletionSource _continueFromSyncPoint = new TaskCompletionSource(); + + /// + /// Waits for the code-under-test to reach . + /// + /// + public Task WaitForSyncPoint() => _atSyncPoint.Task; + + /// + /// Releases the code-under-test to continue past where it waited for . + /// + public void Continue() => _continueFromSyncPoint.TrySetResult(null); + + /// + /// Used by the code-under-test to wait for the test code to sync up. + /// + /// + /// This code will unblock and then block waiting for to be called. + /// + /// + public Task WaitToContinue() + { + _atSyncPoint.TrySetResult(null); + return _continueFromSyncPoint.Task; + } + + public static Func Create(out SyncPoint syncPoint) + { + var handler = Create(1, out var syncPoints); + syncPoint = syncPoints[0]; + return handler; + } + + /// + /// Creates a re-entrant function that waits for sync points in sequence. + /// + /// The number of sync points to expect + /// The objects that can be used to coordinate the sync point + /// + public static Func Create(int count, out SyncPoint[] syncPoints) + { + // Need to use a local so the closure can capture it. You can't use out vars in a closure. + var localSyncPoints = new SyncPoint[count]; + for (var i = 0; i < count; i += 1) + { + localSyncPoints[i] = new SyncPoint(); + } + + syncPoints = localSyncPoints; + + var counter = 0; + return () => + { + if (counter >= localSyncPoints.Length) + { + return Task.CompletedTask; + } + else + { + var syncPoint = localSyncPoints[counter]; + + counter += 1; + return syncPoint.WaitToContinue(); + } + }; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TaskQueueTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TaskQueueTests.cs deleted file mode 100644 index f6765fc7bc..0000000000 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TaskQueueTests.cs +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets.Client.Internal; -using Xunit; - -namespace Microsoft.AspNetCore.Client.Tests -{ - public class TaskQueueTests - { - [Fact] - public async Task DrainingTaskQueueShutsQueueOff() - { - var queue = new TaskQueue(); - await queue.Enqueue(() => Task.CompletedTask); - await queue.Drain(); - - // This would throw if the task was queued successfully - await queue.Enqueue(() => Task.FromException(new Exception())); - } - - [Fact] - public async Task TaskQueueDoesNotQueueNewTasksIfPreviousTaskFaulted() - { - var exception = new Exception(); - var queue = new TaskQueue(); - var ignore = queue.Enqueue(() => Task.FromException(exception)); - var task = queue.Enqueue(() => Task.CompletedTask); - - var actual = await Assert.ThrowsAsync(async () => await task); - Assert.Same(exception, actual); - } - - [Fact] - public void TaskQueueRunsTasksInSequence() - { - var queue = new TaskQueue(); - int n = 0; - queue.Enqueue(() => - { - n = 1; - return Task.CompletedTask; - }); - - Task task = queue.Enqueue(() => - { - return Task.Delay(100).ContinueWith(t => n = 2); - }); - - task.Wait(); - Assert.Equal(2, n); - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index c2ffc164e4..f090512621 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -2,11 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; +using System.Buffers; using System.IO; +using System.IO.Pipelines; using System.Text; using System.Threading; -using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Connections; @@ -19,83 +19,64 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { internal class TestConnection : IConnection { - private TaskCompletionSource _started = new TaskCompletionSource(); - private TaskCompletionSource _disposed = new TaskCompletionSource(); + private readonly bool _autoNegotiate; + private readonly TaskCompletionSource _started = new TaskCompletionSource(); + private readonly TaskCompletionSource _disposed = new TaskCompletionSource(); - private Channel _sentMessages = Channel.CreateUnbounded(); - private Channel _receivedMessages = Channel.CreateUnbounded(); + private int _disposeCount = 0; - private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource(); - private Task _receiveLoop; - - public event Action Closed; public Task Started => _started.Task; public Task Disposed => _disposed.Task; - public ChannelReader SentMessages => _sentMessages.Reader; - public ChannelWriter ReceivedMessages => _receivedMessages.Writer; - private bool _closed; - private object _closedLock = new object(); + private readonly Func _onStart; + private readonly Func _onDispose; - public List Callbacks { get; } = new List(); + public IDuplexPipe Application { get; } + public IDuplexPipe Transport { get; } public IFeatureCollection Features { get; } = new FeatureCollection(); + public int DisposeCount => _disposeCount; - public TestConnection() + public TestConnection(Func onStart = null, Func onDispose = null, bool autoNegotiate = true) { - _receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token); + _autoNegotiate = autoNegotiate; + _onStart = onStart ?? (() => Task.CompletedTask); + _onDispose = onDispose ?? (() => Task.CompletedTask); + + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + Application = pair.Application; + Transport = pair.Transport; + + Application.Input.OnWriterCompleted((ex, _) => Application.Output.Complete(), null); } - public Task AbortAsync(Exception ex) => DisposeCoreAsync(ex); public Task DisposeAsync() => DisposeCoreAsync(); - // TestConnection isn't restartable - public Task StopAsync() => DisposeAsync(); + public Task StartAsync() => StartAsync(TransferFormat.Binary); - private Task DisposeCoreAsync(Exception ex = null) - { - TriggerClosed(ex); - _receiveShutdownToken.Cancel(); - return _receiveLoop; - } - - public async Task SendAsync(byte[] data, CancellationToken cancellationToken) - { - if (!_started.Task.IsCompleted) - { - throw new InvalidOperationException("Connection must be started before SendAsync can be called"); - } - - while (await _sentMessages.Writer.WaitToWriteAsync(cancellationToken)) - { - if (_sentMessages.Writer.TryWrite(data)) - { - return; - } - } - throw new ObjectDisposedException("Unable to send message, underlying channel was closed"); - } - - public Task StartAsync(TransferFormat transferFormat) + public async Task StartAsync(TransferFormat transferFormat) { _started.TrySetResult(null); - return Task.CompletedTask; + + await _onStart(); + + if (_autoNegotiate) + { + // We can't await this as it will block StartAsync which will block + // HubConnection.StartAsync which sends the Handshake in the first place! + _ = ReadHandshakeAndSendResponseAsync(); + } } - public async Task ReadHandshakeAndSendResponseAsync() + public async Task ReadHandshakeAndSendResponseAsync() { - await SentMessages.ReadAsync(); + var s = await ReadSentTextMessageAsync(); var output = new MemoryStream(); HandshakeProtocol.WriteResponseMessage(HandshakeResponseMessage.Empty, output); + await Application.Output.WriteAsync(output.ToArray()); - await _receivedMessages.Writer.WriteAsync(output.ToArray()); - } - - public async Task ReadSentTextMessageAsync() - { - var message = await SentMessages.ReadAsync(); - return Encoding.UTF8.GetString(message); + return s; } public Task ReceiveJsonMessage(object jsonObject) @@ -103,7 +84,51 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var json = JsonConvert.SerializeObject(jsonObject, Formatting.None); var bytes = FormatMessageToArray(Encoding.UTF8.GetBytes(json)); - return _receivedMessages.Writer.WriteAsync(bytes).AsTask(); + return Application.Output.WriteAsync(bytes).AsTask(); + } + + public async Task ReadSentTextMessageAsync() + { + // Read a single text message from the Application Input pipe + while (true) + { + var result = await Application.Input.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.Start; + + try + { + if (TextMessageParser.TryParseMessage(ref buffer, out var payload)) + { + consumed = buffer.Start; + return Encoding.UTF8.GetString(payload.ToArray()); + } + else if (result.IsCompleted) + { + throw new InvalidOperationException("Out of data!"); + } + } + finally + { + Application.Input.AdvanceTo(consumed); + } + } + } + + public void CompleteFromTransport(Exception ex = null) + { + Application.Output.Complete(ex); + } + + private async Task DisposeCoreAsync(Exception ex = null) + { + Interlocked.Increment(ref _disposeCount); + _disposed.TrySetResult(null); + await _onDispose(); + + // Simulate HttpConnection's behavior by Completing the Transport pipe. + Transport.Input.Complete(); + Transport.Output.Complete(); } private byte[] FormatMessageToArray(byte[] message) @@ -113,99 +138,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests TextMessageFormatter.WriteRecordSeparator(output); return output.ToArray(); } - - private async Task ReceiveLoopAsync(CancellationToken token) - { - try - { - while (!token.IsCancellationRequested) - { - while (await _receivedMessages.Reader.WaitToReadAsync(token)) - { - while (_receivedMessages.Reader.TryRead(out var message)) - { - ReceiveCallback[] callbackCopies; - lock (Callbacks) - { - callbackCopies = Callbacks.ToArray(); - } - - foreach (var callback in callbackCopies) - { - await callback.InvokeAsync(message); - } - } - } - } - TriggerClosed(); - } - catch (OperationCanceledException) - { - // Do nothing, we were just asked to shut down. - TriggerClosed(); - } - catch (Exception ex) - { - TriggerClosed(ex); - } - } - - private void TriggerClosed(Exception ex = null) - { - lock (_closedLock) - { - if (!_closed) - { - _closed = true; - Closed?.Invoke(ex); - } - } - } - - public IDisposable OnReceived(Func callback, object state) - { - var receiveCallBack = new ReceiveCallback(callback, state); - lock (Callbacks) - { - Callbacks.Add(receiveCallBack); - } - return new Subscription(receiveCallBack, Callbacks); - } - - public class ReceiveCallback - { - private readonly Func _callback; - private readonly object _state; - - public ReceiveCallback(Func callback, object state) - { - _callback = callback; - _state = state; - } - - public Task InvokeAsync(byte[] data) - { - return _callback(data, _state); - } - } - - private class Subscription : IDisposable - { - private readonly ReceiveCallback _callback; - private readonly List _callbacks; - public Subscription(ReceiveCallback callback, List callbacks) - { - _callback = callback; - _callbacks = callbacks; - } - - public void Dispose() - { - lock (_callbacks) - { - _callbacks.Remove(_callback); - } - } - } } } + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs index 348980bfa1..b255f00906 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs @@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await _handler(request, cancellationToken); } - public static HttpMessageHandler CreateDefault() + public static TestHttpMessageHandler CreateDefault() { var testHttpMessageHandler = new TestHttpMessageHandler(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs index 1c35dedb6f..4417be8593 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestTransport.cs @@ -14,6 +14,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public TransferFormat? Format { get; } public IDuplexPipe Application { get; private set; } + public Task Receiving { get; private set; } public TestTransport(Func onTransportStop = null, Func onTransportStart = null, TransferFormat transferFormat = TransferFormat.Text) { @@ -22,20 +23,51 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Format = transferFormat; } - public Task StartAsync(Uri url, IDuplexPipe application, TransferFormat transferFormat, IConnection connection) + public async Task StartAsync(Uri url, IDuplexPipe application, TransferFormat transferFormat, IConnection connection) { if ((Format & transferFormat) == 0) { throw new InvalidOperationException($"The '{transferFormat}' transfer format is not supported by this transport."); } Application = application; - return _startHandler(); + await _startHandler(); + + // Start a loop to read from the pipe + Receiving = ReceiveLoop(); + async Task ReceiveLoop() + { + while (true) + { + var result = await Application.Input.ReadAsync(); + if (result.IsCompleted) + { + break; + } + else if (result.IsCanceled) + { + // This is useful for detecting that the connection tried to gracefully terminate. + // If the Receiving task is faulted/cancelled, it means StopAsync was the thing that + // actually terminated the connection (not ideal, we want the transport pipe to + // shut down gracefully) + throw new OperationCanceledException(); + } + + Application.Input.AdvanceTo(result.Buffer.End); + } + + // Call the transport stop handler + await _stopHandler(); + + // Complete our end of the pipe + Application.Output.Complete(); + Application.Input.Complete(); + } } - public async Task StopAsync() + public Task StopAsync() { - await _stopHandler(); - Application.Output.Complete(); + Application.Input.CancelPendingRead(); + return Receiving; } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs index bf4afa8979..aae0f20c0c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ChannelExtensions.cs @@ -3,24 +3,35 @@ using System.Collections.Generic; using System.Threading.Tasks; +using Xunit; namespace System.Threading.Channels { public static class ChannelExtensions { - public static async Task> ReadAllAsync(this ChannelReader channel) + public static async Task> ReadAllAsync(this ChannelReader channel, bool suppressExceptions = false) { var list = new List(); - while (await channel.WaitToReadAsync()) + try { - while (channel.TryRead(out var item)) + while (await channel.WaitToReadAsync()) { - list.Add(item); + while (channel.TryRead(out var item)) + { + list.Add(item); + } + } + + // Manifest any error from channel.Completion (which should be completed now) + if (!suppressExceptions) + { + await channel.Completion; } } - - // Manifest any error from channel.Completion (which should be completed now) - await channel.Completion; + catch (Exception) when (suppressExceptions) + { + // Suppress the exception + } return list; } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj index 4e4b124472..cd73175dce 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/Microsoft.AspNetCore.SignalR.Tests.Utils.csproj @@ -21,4 +21,8 @@ + + + + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeCompletionExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeCompletionExtensions.cs new file mode 100644 index 0000000000..4d8a5aeba3 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeCompletionExtensions.cs @@ -0,0 +1,44 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace System.IO.Pipelines +{ + public static class PipeCompletionExtensions + { + public static Task WaitForWriterToComplete(this PipeReader reader) + { + var tcs = new TaskCompletionSource(); + reader.OnWriterCompleted((ex, state) => + { + if (ex != null) + { + ((TaskCompletionSource)state).TrySetException(ex); + } + else + { + ((TaskCompletionSource)state).TrySetResult(null); + } + }, tcs); + return tcs.Task; + } + + public static Task WaitForReaderToComplete(this PipeWriter writer) + { + var tcs = new TaskCompletionSource(); + writer.OnReaderCompleted((ex, state) => + { + if (ex != null) + { + ((TaskCompletionSource)state).TrySetException(ex); + } + else + { + ((TaskCompletionSource)state).TrySetResult(null); + } + }, tcs); + return tcs.Task; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs index 11b0da30b4..bb68a0f7c1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/PipeReaderExtensions.cs @@ -61,6 +61,7 @@ namespace System.IO.Pipelines pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); continue; } + pipeReader.AdvanceTo(result.Buffer.GetPosition(numBytes)); break; } @@ -72,19 +73,14 @@ namespace System.IO.Pipelines { var result = await pipeReader.ReadAsync(); - try + if (result.IsCompleted) { - if (result.IsCompleted) - { - return result.Buffer.ToArray(); - } - } - finally - { - // Consume nothing, just wait for everything - pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); + return result.Buffer.ToArray(); } + + // Consume nothing, just wait for everything + pipeReader.AdvanceTo(result.Buffer.Start, result.Buffer.End); } } } -} +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs index 822772e244..805bedffaa 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/ServerFixture.cs @@ -33,12 +33,25 @@ namespace Microsoft.AspNetCore.SignalR.Tests public string Url { get; private set; } - public ServerFixture() + public ServerFixture() : this(loggerFactory: null) + { + } + + public ServerFixture(ILoggerFactory loggerFactory) { _logSinkProvider = new LogSinkProvider(); - var testLog = AssemblyTestLog.ForAssembly(typeof(TStartup).Assembly); - _logToken = testLog.StartTestLog(null, $"{nameof(ServerFixture)}_{typeof(TStartup).Name}", out _loggerFactory, "ServerFixture"); + if (loggerFactory == null) + { + var testLog = AssemblyTestLog.ForAssembly(typeof(TStartup).Assembly); + _logToken = testLog.StartTestLog(null, $"{nameof(ServerFixture)}_{typeof(TStartup).Name}", + out _loggerFactory, "ServerFixture"); + } + else + { + _loggerFactory = loggerFactory; + } + _logger = _loggerFactory.CreateLogger>(); StartServer(); @@ -51,6 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests _host = new WebHostBuilder() .ConfigureLogging(builder => builder + .SetMinimumLevel(LogLevel.Debug) .AddProvider(_logSinkProvider) .AddProvider(new ForwardingLoggerProvider(_loggerFactory))) .UseStartup(typeof(TStartup)) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs index 65f252f56a..2ff2279245 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TaskExtensions.cs @@ -35,6 +35,12 @@ namespace System.Threading.Tasks await task; } + public static Task OrTimeout(this ValueTask task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) => + OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); + + public static Task OrTimeout(this ValueTask task, TimeSpan timeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) => + task.AsTask().OrTimeout(timeout, memberName, filePath, lineNumber); + public static Task OrTimeout(this Task task, int milliseconds = DefaultTimeout, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int? lineNumber = null) { return OrTimeout(task, new TimeSpan(0, 0, 0, 0, milliseconds), memberName, filePath, lineNumber); @@ -61,7 +67,7 @@ namespace System.Threading.Tasks public static async Task OrThrowIfOtherFails(this Task task, Task otherTask) { var completed = await Task.WhenAny(task, otherTask); - if(completed == otherTask && otherTask.IsFaulted) + if (completed == otherTask && otherTask.IsFaulted) { // Manifest the exception otherTask.GetAwaiter().GetResult(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs index 02c0f7ae27..2904a6cb89 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs @@ -5,6 +5,7 @@ using System; using System.Net.Http; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; using Xunit; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 92e64fdfe8..98e88fe958 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -119,7 +119,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [ConditionalFact] [WebSocketsSupportedCondition] - public async Task HTTPRequestsNotSentWhenWebSocketsTransportRequested() + public async Task HttpRequestsNotSentWhenWebSocketsTransportRequested() { using (StartLog(out var loggerFactory)) { @@ -136,19 +136,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests try { - var receiveTcs = new TaskCompletionSource(); - connection.OnReceived((data, state) => - { - var tcs = (TaskCompletionSource)state; - tcs.TrySetResult(data); - return Task.CompletedTask; - }, receiveTcs); - var message = new byte[] { 42 }; await connection.StartAsync(TransferFormat.Binary).OrTimeout(); - await connection.SendAsync(message).OrTimeout(); - var receivedData = await receiveTcs.Task.OrTimeout(); + await connection.Transport.Output.WriteAsync(message).OrTimeout(); + + var receivedData = await connection.Transport.Input.ReadAllAsync(); Assert.Equal(message, receivedData); } catch (Exception ex) @@ -179,28 +172,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests var connection = new HttpConnection(new Uri(url), transportType, loggerFactory); try { - var closeTcs = new TaskCompletionSource(); - connection.Closed += e => - { - if (e != null) - { - closeTcs.SetException(e); - } - else - { - closeTcs.SetResult(null); - } - }; - - var receiveTcs = new TaskCompletionSource(); - connection.OnReceived((data, state) => - { - logger.LogInformation("Received {length} byte message", data.Length); - var tcs = (TaskCompletionSource)state; - tcs.TrySetResult(Encoding.UTF8.GetString(data)); - return Task.CompletedTask; - }, receiveTcs); - logger.LogInformation("Starting connection to {url}", url); await connection.StartAsync(requestedTransferFormat).OrTimeout(); logger.LogInformation("Started connection to {url}", url); @@ -210,7 +181,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests logger.LogInformation("Sending {length} byte message", bytes.Length); try { - await connection.SendAsync(bytes).OrTimeout(); + await connection.Transport.Output.WriteAsync(bytes).OrTimeout(); } catch (OperationCanceledException) { @@ -220,12 +191,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests // Our solution to this is to just catch OperationCanceledException from the sent message if the race happens // because we know the send went through, and its safe to check the response. } - logger.LogInformation("Sent message", bytes.Length); + logger.LogInformation("Sent message"); logger.LogInformation("Receiving message"); - Assert.Equal(message, await receiveTcs.Task.OrTimeout()); + Assert.Equal(message, Encoding.UTF8.GetString(await connection.Transport.Input.ReadAllAsync())); logger.LogInformation("Completed receive"); - await closeTcs.Task.OrTimeout(); } catch (Exception ex) { @@ -264,27 +234,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests try { - var receiveTcs = new TaskCompletionSource(); - connection.OnReceived((data, state) => - { - logger.LogInformation("Received {length} byte message", data.Length); - var tcs = (TaskCompletionSource)state; - tcs.TrySetResult(data); - return Task.CompletedTask; - }, receiveTcs); - logger.LogInformation("Starting connection to {url}", url); await connection.StartAsync(TransferFormat.Binary).OrTimeout(); logger.LogInformation("Started connection to {url}", url); var bytes = Encoding.UTF8.GetBytes(message); logger.LogInformation("Sending {length} byte message", bytes.Length); - await connection.SendAsync(bytes).OrTimeout(); - logger.LogInformation("Sent message", bytes.Length); + await connection.Transport.Output.WriteAsync(bytes).OrTimeout(); + logger.LogInformation("Sent message"); logger.LogInformation("Receiving message"); // Big timeout here because it can take a while to receive all the bytes - var receivedData = await receiveTcs.Task.OrTimeout(TimeSpan.FromSeconds(30)); + var receivedData = await connection.Transport.Input.ReadAllAsync(); Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); logger.LogInformation("Completed receive"); } @@ -406,26 +367,26 @@ namespace Microsoft.AspNetCore.SignalR.Tests private class FakeTransport : ITransport { - public string prevConnectionId = null; - private int tries = 0; + private int _tries; + private string _prevConnectionId = null; private IDuplexPipe _application; public Task StartAsync(Uri url, IDuplexPipe application, TransferFormat transferFormat, IConnection connection) { _application = application; - tries++; - Assert.True(QueryHelpers.ParseQuery(url.Query.ToString()).TryGetValue("id", out var id)); - if (prevConnectionId == null) + _tries++; + Assert.True(QueryHelpers.ParseQuery(url.Query).TryGetValue("id", out var id)); + if (_prevConnectionId == null) { - prevConnectionId = id; + _prevConnectionId = id; } else { - Assert.True(prevConnectionId != id); - prevConnectionId = id; + Assert.True(_prevConnectionId != id); + _prevConnectionId = id; } - if (tries < 3) + if (_tries < 3) { throw new Exception(); } @@ -462,11 +423,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { foreach (var transport in TransportTypes) { - yield return new object[] { transport[0], TransferFormat.Text }; + yield return new[] { transport[0], TransferFormat.Text }; if ((TransportType)transport[0] != TransportType.ServerSentEvents) { - yield return new object[] { transport[0], TransferFormat.Binary }; + yield return new[] { transport[0], TransferFormat.Binary }; } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index d718ac7c9c..039d37ba85 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -17,6 +17,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Client.Http; +using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging.Testing; using Moq;