From 49c01eefecf1dbd75e5536aa803d689390c6770a Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Fri, 24 May 2019 10:32:41 -0700 Subject: [PATCH] Log clearer handshake failures in SignalR .NET client (#10433) --- .../Client.Core/src/HubConnection.Log.cs | 26 +- .../csharp/Client.Core/src/HubConnection.cs | 66 +- ...HttpConnectionTests.ConnectionLifecycle.cs | 41 +- .../HubConnectionTests.ConnectionLifecycle.cs | 290 +-- .../UnitTests/HubConnectionTests.Helpers.cs | 10 +- .../UnitTests/HubConnectionTests.Protocol.cs | 80 +- .../UnitTests/HubConnectionTests.Reconnect.cs | 1682 +++++++++-------- 7 files changed, 1179 insertions(+), 1016 deletions(-) diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 7061e276b1..b37239869a 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -103,7 +103,7 @@ namespace Microsoft.AspNetCore.SignalR.Client LoggerMessage.Define(LogLevel.Error, new EventId(34, "ErrorInvokingClientSideMethod"), "Invoking client side method '{MethodName}' failed."); private static readonly Action _errorProcessingHandshakeResponse = - LoggerMessage.Define(LogLevel.Error, new EventId(35, "ErrorReceivingHandshakeResponse"), "Error processing the handshake response."); + LoggerMessage.Define(LogLevel.Error, new EventId(35, "ErrorReceivingHandshakeResponse"), "The underlying connection closed while processing the handshake response. See exception for details."); private static readonly Action _handshakeServerError = LoggerMessage.Define(LogLevel.Error, new EventId(36, "HandshakeServerError"), "Server returned handshake error: {Error}"); @@ -240,6 +240,15 @@ namespace Microsoft.AspNetCore.SignalR.Client private static readonly Action _attemptingStateTransition = LoggerMessage.Define(LogLevel.Trace, new EventId(80, "AttemptingStateTransition"), "The HubConnection is attempting to transition from the {ExpectedState} state to the {NewState} state."); + private static readonly Action _errorInvalidHandshakeResponse = + LoggerMessage.Define(LogLevel.Error, new EventId(81, "ErrorInvalidHandshakeResponse"), "Received an invalid handshake response."); + + private static readonly Action _errorHandshakeTimedOut = + LoggerMessage.Define(LogLevel.Error, new EventId(82, "ErrorHandshakeTimedOut"), "The handshake timed out after {HandshakeTimeoutSeconds} seconds."); + + private static readonly Action _errorHandshakeCanceled = + LoggerMessage.Define(LogLevel.Error, new EventId(83, "ErrorHandshakeCanceled"), "The handshake was canceled by the client."); + public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count) { _preparingNonBlockingInvocation(logger, target, count, null); @@ -640,6 +649,21 @@ namespace Microsoft.AspNetCore.SignalR.Client { _attemptingStateTransition(logger, expectedState, newState, null); } + + public static void ErrorInvalidHandshakeResponse(ILogger logger, Exception exception) + { + _errorInvalidHandshakeResponse(logger, exception); + } + + public static void ErrorHandshakeTimedOut(ILogger logger, TimeSpan handshakeTimeout, Exception exception) + { + _errorHandshakeTimedOut(logger, handshakeTimeout.TotalSeconds, exception); + } + + public static void ErrorHandshakeCanceled(ILogger logger, Exception exception) + { + _errorHandshakeCanceled(logger, exception); + } } } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 13ddb56792..0cde5b0dcc 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -253,9 +253,9 @@ namespace Microsoft.AspNetCore.SignalR.Client throw new InvalidOperationException($"The {nameof(HubConnection)} cannot be started while {nameof(StopAsync)} is running."); } - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, _state.StopCts.Token)) + using (CreateLinkedToken(cancellationToken, _state.StopCts.Token, out var linkedToken)) { - await StartAsyncCore(cancellationToken); + await StartAsyncCore(linkedToken); } _state.ChangeState(HubConnectionState.Connecting, HubConnectionState.Connected); @@ -1018,20 +1018,23 @@ namespace Microsoft.AspNetCore.SignalR.Client if (sendHandshakeResult.IsCompleted) { // The other side disconnected - throw new InvalidOperationException("The server disconnected before the handshake was completed"); + var ex = new IOException("The server disconnected before the handshake could be started."); + Log.ErrorReceivingHandshakeResponse(_logger, ex); + throw ex; } var input = startingConnectionState.Connection.Transport.Input; + using var handshakeCts = new CancellationTokenSource(HandshakeTimeout); + try { - using (var handshakeCts = new CancellationTokenSource(HandshakeTimeout)) // cancellationToken already contains _state.StopCts.Token, so we don't have to link it again - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, handshakeCts.Token)) + using (CreateLinkedToken(cancellationToken, handshakeCts.Token, out var linkedToken)) { while (true) { - var result = await input.ReadAsync(cts.Token); + var result = await input.ReadAsync(linkedToken); var buffer = result.Buffer; var consumed = buffer.Start; @@ -1057,6 +1060,7 @@ namespace Microsoft.AspNetCore.SignalR.Client $"Unable to complete handshake with the server due to an error: {message.Error}"); } + Log.HandshakeComplete(_logger); break; } } @@ -1075,17 +1079,34 @@ namespace Microsoft.AspNetCore.SignalR.Client } } } + catch (HubException) + { + // This was already logged as a HandshakeServerError + throw; + } + catch (InvalidDataException ex) + { + Log.ErrorInvalidHandshakeResponse(_logger, ex); + throw; + } + catch (OperationCanceledException ex) + { + if (handshakeCts.IsCancellationRequested) + { + Log.ErrorHandshakeTimedOut(_logger, HandshakeTimeout, ex); + } + else + { + Log.ErrorHandshakeCanceled(_logger, ex); + } - // shutdown if we're unable to read handshake - // Ignore HubException because we throw it when we receive a handshake response with an error - // And because we already have the error, we don't need to log that the handshake failed - catch (Exception ex) when (!(ex is HubException)) + throw; + } + catch (Exception ex) { Log.ErrorReceivingHandshakeResponse(_logger, ex); throw; } - - Log.HandshakeComplete(_logger); } private async Task ReceiveLoop(ConnectionState connectionState) @@ -1485,6 +1506,26 @@ namespace Microsoft.AspNetCore.SignalR.Client } } + private IDisposable CreateLinkedToken(CancellationToken token1, CancellationToken token2, out CancellationToken linkedToken) + { + if (!token1.CanBeCanceled) + { + linkedToken = token2; + return null; + } + else if (!token2.CanBeCanceled) + { + linkedToken = token1; + return null; + } + else + { + var cts = CancellationTokenSource.CreateLinkedTokenSource(token1, token2); + linkedToken = cts.Token; + return cts; + } + } + // Debug.Assert plays havoc with Unit Tests. But I want something that I can "assert" only in Debug builds. [Conditional("DEBUG")] private static void SafeAssert(bool condition, string message, [CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) @@ -1495,7 +1536,6 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private class Subscription : IDisposable { private readonly InvocationHandler _handler; diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs index c141ffbe8a..2088ef3927 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HttpConnectionTests.ConnectionLifecycle.cs @@ -43,13 +43,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests CreateConnection(loggerFactory: LoggerFactory, transport: new TestTransport(onTransportStart: SyncPoint.Create(out var syncPoint))), async (connection) => { - var firstStart = connection.StartAsync(TransferFormat.Text).OrTimeout(); - await syncPoint.WaitForSyncPoint(); - var secondStart = connection.StartAsync(TransferFormat.Text).OrTimeout(); + var firstStart = connection.StartAsync(TransferFormat.Text); + await syncPoint.WaitForSyncPoint().OrTimeout(); + var secondStart = connection.StartAsync(TransferFormat.Text); syncPoint.Continue(); - await firstStart; - await secondStart; + await firstStart.OrTimeout(); + await secondStart.OrTimeout(); }); } } @@ -64,10 +64,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests async (connection) => { await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.DisposeAsync(); + await connection.DisposeAsync().OrTimeout(); var exception = await Assert.ThrowsAsync( - async () => await connection.StartAsync(TransferFormat.Text).OrTimeout()); + async () => await connection.StartAsync(TransferFormat.Text)).OrTimeout(); Assert.Equal(nameof(HttpConnection), exception.ObjectName); }); @@ -121,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests async (connection) => { Assert.Equal(0, startCounter); - await connection.StartAsync(TransferFormat.Text); + await connection.StartAsync(TransferFormat.Text).OrTimeout(); Assert.Equal(passThreshold, startCounter); }); } @@ -154,7 +154,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests transport: new TestTransport(onTransportStart: OnTransportStart)), async (connection) => { - var ex = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text)); + var ex = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text)).OrTimeout(); Assert.Equal("Unable to connect to the server with any of the available transports. " + "(WebSockets failed: Transport failed to start) (ServerSentEvents failed: Transport failed to start) (LongPolling failed: Transport failed to start)", ex.Message); @@ -179,8 +179,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests CreateConnection(loggerFactory: LoggerFactory), async (connection) => { - await connection.DisposeAsync(); - + await connection.DisposeAsync().OrTimeout(); }); } } @@ -203,7 +202,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await transportStart.WaitForSyncPoint().OrTimeout(); // While the transport is starting, dispose the connection - var disposeTask = connection.DisposeAsync().OrTimeout(); + var disposeTask = connection.DisposeAsync(); transportStart.Continue(); // We need to release StartAsync, because Dispose waits for it. // Wait for start to finish, as that has to finish before the transport will be stopped. @@ -214,7 +213,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests transportStop.Continue(); // Dispose should finish - await disposeTask; + await disposeTask.OrTimeout(); }); } } @@ -234,14 +233,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.StartAsync(TransferFormat.Text).OrTimeout(); // Dispose the connection - var stopTask = connection.DisposeAsync().OrTimeout(); + var stopTask = connection.DisposeAsync(); // Once the transport starts shutting down - await transportStop.WaitForSyncPoint(); + await transportStop.WaitForSyncPoint().OrTimeout(); Assert.False(stopTask.IsCompleted); // Start disposing again, and then let the first dispose continue - var disposeTask = connection.DisposeAsync().OrTimeout(); + var disposeTask = connection.DisposeAsync(); transportStop.Continue(); // Wait for the tasks to complete @@ -249,7 +248,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await disposeTask.OrTimeout(); // We should be disposed and thus unable to restart. - await AssertDisposedAsync(connection); + await AssertDisposedAsync(connection).OrTimeout(); }); } } @@ -316,7 +315,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.Transport.Output.WriteAsync(new byte[] { 0x42 }).OrTimeout(); // We should get the exception in the transport input completion. - await Assert.ThrowsAsync(() => connection.Transport.Input.WaitForWriterToComplete()); + await Assert.ThrowsAsync(() => connection.Transport.Input.WaitForWriterToComplete()).OrTimeout(); }); } } @@ -371,11 +370,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests CreateConnection(httpHandler, loggerFactory: LoggerFactory, transport: sse), async (connection) => { - var startTask = connection.StartAsync(TransferFormat.Text).OrTimeout(); + var startTask = connection.StartAsync(TransferFormat.Text); Assert.False(connectResponseTcs.Task.IsCompleted); Assert.False(startTask.IsCompleted); connectResponseTcs.TrySetResult(null); - await startTask; + await startTask.OrTimeout(); }); } } @@ -383,7 +382,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests private static async Task AssertDisposedAsync(HttpConnection connection) { var exception = - await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text).OrTimeout()); + await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text)); Assert.Equal(nameof(HttpConnection), exception.ObjectName); } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs index f60b124d59..f48742474b 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs @@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Testing; using Newtonsoft.Json.Linq; using Xunit; @@ -18,7 +19,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HubConnectionTests { - public class ConnectionLifecycle + public class ConnectionLifecycle : VerifiableLoggedTest { // This tactic (using names and a dictionary) allows non-serializable data (like a Func) to be used in a theory AND get it to show in the new hierarchical view in Test Explorer as separate tests you can run individually. private static readonly IDictionary> MethodsThatRequireActiveConnection = new Dictionary>() @@ -30,27 +31,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public static IEnumerable MethodsNamesThatRequireActiveConnection => MethodsThatRequireActiveConnection.Keys.Select(k => new object[] { k }); - private HubConnection CreateHubConnection(TestConnection testConnection) - { - var builder = new HubConnectionBuilder(); - - var delegateConnectionFactory = new DelegateConnectionFactory( - testConnection.StartAsync, - connection => ((TestConnection)connection).DisposeAsync()); - builder.Services.AddSingleton(delegateConnectionFactory); - - return builder.Build(); - } - - private HubConnection CreateHubConnection(Func> connectDelegate, Func disposeDelegate) - { - var builder = new HubConnectionBuilder(); - - var delegateConnectionFactory = new DelegateConnectionFactory(connectDelegate, disposeDelegate); - builder.Services.AddSingleton(delegateConnectionFactory); - - return builder.Build(); - } [Fact] public async Task StartAsyncStartsTheUnderlyingConnection() @@ -60,7 +40,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { Assert.Equal(HubConnectionState.Disconnected, connection.State); - await connection.StartAsync(); + await connection.StartAsync().OrTimeout(); Assert.True(testConnection.Started.IsCompleted); Assert.Equal(HubConnectionState.Connected, connection.State); }); @@ -73,22 +53,22 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var testConnection = new TestConnection(onStart: SyncPoint.Create(out var syncPoint)); await AsyncUsing(CreateHubConnection(testConnection), async connection => { - var firstStart = connection.StartAsync().OrTimeout(); + var firstStart = connection.StartAsync(); Assert.False(firstStart.IsCompleted); // Wait for us to be in IConnectionFactory.ConnectAsync - await syncPoint.WaitForSyncPoint(); + await syncPoint.WaitForSyncPoint().OrTimeout(); // Try starting again - var secondStart = connection.StartAsync().OrTimeout(); + var secondStart = connection.StartAsync(); Assert.False(secondStart.IsCompleted); // Release the sync point syncPoint.Continue(); // The first start should finish fine, but the second throws an InvalidOperationException. - await firstStart; - await Assert.ThrowsAsync(() => secondStart); + await firstStart.OrTimeout(); + await Assert.ThrowsAsync(() => secondStart).OrTimeout(); }); } @@ -108,7 +88,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return ((TestConnection)connection).DisposeAsync(); } - await AsyncUsing(CreateHubConnection(ConnectionFactory, DisposeAsync), async connection => + var builder = new HubConnectionBuilder(); + var delegateConnectionFactory = new DelegateConnectionFactory(ConnectionFactory, DisposeAsync); + builder.Services.AddSingleton(delegateConnectionFactory); + + await AsyncUsing(builder.Build(), async connection => { Assert.Equal(HubConnectionState.Disconnected, connection.State); @@ -139,47 +123,74 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Task DisposeAsync(ConnectionContext connection) => ((TestConnection)connection).DisposeAsync(); - await AsyncUsing(CreateHubConnection(ConnectionFactory, DisposeAsync), async connection => + var builder = new HubConnectionBuilder(); + var delegateConnectionFactory = new DelegateConnectionFactory(ConnectionFactory, DisposeAsync); + builder.Services.AddSingleton(delegateConnectionFactory); + + await AsyncUsing(builder.Build(), async connection => { await connection.StartAsync().OrTimeout(); Assert.Equal(1, createCount); - var stopTask = connection.StopAsync().OrTimeout(); + var stopTask = connection.StopAsync(); // Wait to hit DisposeAsync on TestConnection (which should be after StopAsync has cleared the connection state) await syncPoint.WaitForSyncPoint().OrTimeout(); // We should not yet be able to start now because StopAsync hasn't completed Assert.False(stopTask.IsCompleted); - var startTask = connection.StartAsync().OrTimeout(); + var startTask = connection.StartAsync(); Assert.False(stopTask.IsCompleted); // When we release the sync point, the StopAsync task will finish syncPoint.Continue(); - await stopTask; + await stopTask.OrTimeout(); // Which will then allow StartAsync to finish. - await startTask; + await startTask.OrTimeout(); }); } [Fact] public async Task StartAsyncWithFailedHandshakeCanBeStopped() { - var testConnection = new TestConnection(autoHandshake: false); - await AsyncUsing(CreateHubConnection(testConnection), async connection => - { - testConnection.Transport.Input.Complete(); - try - { - await connection.StartAsync(); - } - catch - { } + var handshakeConnectionErrorLogged = false; - await connection.StopAsync(); - Assert.True(testConnection.Started.IsCompleted); - }); + bool ExpectedErrors(WriteContext writeContext) + { + if (writeContext.LoggerName == typeof(HubConnection).FullName) + { + if (writeContext.EventId.Name == "ErrorReceivingHandshakeResponse") + { + handshakeConnectionErrorLogged = true; + return true; + } + + return writeContext.EventId.Name == "ErrorStartingConnection"; + } + + return false; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var testConnection = new TestConnection(autoHandshake: false); + await AsyncUsing(CreateHubConnection(testConnection, loggerFactory: LoggerFactory), async connection => + { + testConnection.Transport.Input.Complete(); + try + { + await connection.StartAsync().OrTimeout(); + } + catch + { } + + await connection.StopAsync().OrTimeout(); + Assert.True(testConnection.Started.IsCompleted); + }); + } + + Assert.True(handshakeConnectionErrorLogged, "The connnection error during the handshake wasn't logged."); } [Theory] @@ -191,7 +202,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var testConnection = new TestConnection(); await AsyncUsing(CreateHubConnection(testConnection), async connection => { - var ex = await Assert.ThrowsAsync(() => method(connection)); + var ex = await Assert.ThrowsAsync(() => method(connection)).OrTimeout(); Assert.Equal($"The '{name}' method cannot be called if the connection is not active", ex.Message); }); } @@ -207,27 +218,27 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await AsyncUsing(CreateHubConnection(testConnection), async connection => { // Start, and wait for the sync point to be hit - var startTask = connection.StartAsync().OrTimeout(); + var startTask = connection.StartAsync(); Assert.False(startTask.IsCompleted); - await syncPoint.WaitForSyncPoint(); + await syncPoint.WaitForSyncPoint().OrTimeout(); // Run the method, but it will be waiting for the lock - var targetTask = method(connection).OrTimeout(); + var targetTask = method(connection); // Release the SyncPoint syncPoint.Continue(); // Wait for start to finish - await startTask; + await startTask.OrTimeout(); // We need some special logic to ensure InvokeAsync completes. if (string.Equals(name, nameof(HubConnection.InvokeCoreAsync))) { - await ForceLastInvocationToComplete(testConnection); + await ForceLastInvocationToComplete(testConnection).OrTimeout(); } // Wait for the method to complete. - await targetTask; + await targetTask.OrTimeout(); }); } @@ -239,9 +250,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await AsyncUsing(CreateHubConnection(testConnection), async connection => { // Start, and wait for the sync point to be hit - var startTask = connection.StartAsync().OrTimeout(); + var startTask = connection.StartAsync(); Assert.False(startTask.IsCompleted); - await syncPoint.WaitForSyncPoint(); + await syncPoint.WaitForSyncPoint().OrTimeout(); Assert.Equal(HubConnectionState.Connecting, connection.State); @@ -249,7 +260,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests syncPoint.Continue(); // Wait for start to finish - await startTask; + await startTask.OrTimeout(); Assert.Equal(HubConnectionState.Connected, connection.State); }); @@ -349,7 +360,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await closed.Task.OrTimeout(); // We should be stopped now - var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo").OrTimeout()); + var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo")).OrTimeout(); Assert.Equal($"The '{nameof(HubConnection.SendCoreAsync)}' method cannot be called if the connection is not active", ex.Message); }); } @@ -374,20 +385,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.True(testConnection.Started.IsCompleted); // Start shutting down and complete the transport side - var stopTask = connection.StopAsync().OrTimeout(); + var stopTask = connection.StopAsync(); testConnection.CompleteFromTransport(); // Wait for the connection to close. await testConnectionClosed.Task.OrTimeout(); // The stop should be completed. - await stopTask; + await stopTask.OrTimeout(); // The HubConnection should now be closed. await connectionClosed.Task.OrTimeout(); // We should be stopped now - var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo").OrTimeout()); + var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo")).OrTimeout(); Assert.Equal($"The '{nameof(HubConnection.SendCoreAsync)}' method cannot be called if the connection is not active", ex.Message); await testConnection.Disposed.OrTimeout(); @@ -416,16 +427,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests testConnection.CompleteFromTransport(); // Start stopping manually (these can't be synchronized by a Sync Point because the transport is disposed outside the lock) - var stopTask = connection.StopAsync().OrTimeout(); + var stopTask = connection.StopAsync(); await testConnection.Disposed.OrTimeout(); // Wait for the stop task to complete and the closed event to fire - await stopTask; + await stopTask.OrTimeout(); await connectionClosed.Task.OrTimeout(); // We should be stopped now - var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo").OrTimeout()); + var ex = await Assert.ThrowsAsync(() => connection.SendAsync("Foo")).OrTimeout(); Assert.Equal($"The '{nameof(HubConnection.SendCoreAsync)}' method cannot be called if the connection is not active", ex.Message); }); } @@ -444,89 +455,141 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // Stop and invoke the method. These two aren't synchronizable via a Sync Point any more because the transport is disposed // outside the lock :( - var disposeTask = connection.StopAsync().OrTimeout(); + var disposeTask = connection.StopAsync(); // Wait to hit DisposeAsync on TestConnection (which should be after StopAsync has cleared the connection state) await syncPoint.WaitForSyncPoint().OrTimeout(); - var targetTask = method(connection).OrTimeout(); + var targetTask = method(connection); // Release the sync point syncPoint.Continue(); // Wait for the method to complete, with an expected error. - var ex = await Assert.ThrowsAsync(() => targetTask); + var ex = await Assert.ThrowsAsync(() => targetTask).OrTimeout(); Assert.Equal($"The '{methodName}' method cannot be called if the connection is not active", ex.Message); - await disposeTask; + await disposeTask.OrTimeout(); }); } [Fact] public async Task ClientTimesoutWhenHandshakeResponseTakesTooLong() { - var connection = new TestConnection(autoHandshake: false); - var hubConnection = CreateHubConnection(connection); - try - { - hubConnection.HandshakeTimeout = TimeSpan.FromMilliseconds(1); + var handshakeTimeoutLogged = false; - await Assert.ThrowsAsync(() => hubConnection.StartAsync().OrTimeout()); - Assert.Equal(HubConnectionState.Disconnected, hubConnection.State); - } - finally + bool ExpectedErrors(WriteContext writeContext) { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + if (writeContext.LoggerName == typeof(HubConnection).FullName) + { + if (writeContext.EventId.Name == "ErrorHandshakeTimedOut") + { + handshakeTimeoutLogged = true; + return true; + } + + return writeContext.EventId.Name == "ErrorStartingConnection"; + } + + return false; } + + using (StartVerifiableLog(ExpectedErrors)) + { + var connection = new TestConnection(autoHandshake: false); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + try + { + hubConnection.HandshakeTimeout = TimeSpan.FromMilliseconds(1); + + await Assert.ThrowsAsync(() => hubConnection.StartAsync()).OrTimeout(); + Assert.Equal(HubConnectionState.Disconnected, hubConnection.State); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + Assert.True(handshakeTimeoutLogged, "The handshake timeout wasn't logged."); } [Fact] public async Task StartAsyncWithTriggeredCancellationTokenIsCanceled() { - var onStartCalled = false; - var connection = new TestConnection(onStart: () => + using (StartVerifiableLog()) { - onStartCalled = true; - return Task.CompletedTask; - }); - var hubConnection = CreateHubConnection(connection); - try - { - await Assert.ThrowsAsync(() => hubConnection.StartAsync(new CancellationToken(canceled: true)).OrTimeout()); - Assert.False(onStartCalled); - } - finally - { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + var onStartCalled = false; + var connection = new TestConnection(onStart: () => + { + onStartCalled = true; + return Task.CompletedTask; + }); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + try + { + await Assert.ThrowsAsync(() => hubConnection.StartAsync(new CancellationToken(canceled: true))).OrTimeout(); + Assert.False(onStartCalled); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } } } [Fact] public async Task StartAsyncCanTriggerCancellationTokenToCancelHandshake() { - var cts = new CancellationTokenSource(); - var connection = new TestConnection(onStart: () => - { - cts.Cancel(); - return Task.CompletedTask; - }, autoHandshake: false); - var hubConnection = CreateHubConnection(connection); - // We want to make sure the cancellation is because of the token passed to StartAsync - hubConnection.HandshakeTimeout = Timeout.InfiniteTimeSpan; - try - { - var startTask = hubConnection.StartAsync(cts.Token); - await Assert.ThrowsAnyAsync(() => startTask.OrTimeout()); + var handshakeCancellationLogged = false; - // We aren't worried about the exact message and it's localized so asserting it is non-trivial. - } - finally + bool ExpectedErrors(WriteContext writeContext) { - await hubConnection.DisposeAsync().OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + if (writeContext.LoggerName == typeof(HubConnection).FullName) + { + if (writeContext.EventId.Name == "ErrorHandshakeCanceled") + { + handshakeCancellationLogged = true; + return true; + } + + return writeContext.EventId.Name == "ErrorStartingConnection"; + } + + return false; } + + using (StartVerifiableLog(ExpectedErrors)) + { + var cts = new CancellationTokenSource(); + TestConnection connection = null; + + connection = new TestConnection(autoHandshake: false); + + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + // We want to make sure the cancellation is because of the token passed to StartAsync + hubConnection.HandshakeTimeout = Timeout.InfiniteTimeSpan; + + try + { + var startTask = hubConnection.StartAsync(cts.Token); + + await connection.ReadSentTextMessageAsync().OrTimeout(); + cts.Cancel(); + + // We aren't worried about the exact message and it's localized so asserting it is non-trivial. + await Assert.ThrowsAnyAsync(() => startTask).OrTimeout(); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + Assert.True(handshakeCancellationLogged, "The handshake cancellation wasn't logged."); } [Fact] @@ -574,13 +637,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { try { + // Using OrTimeout here will hide any timeout issues in the test :(. await action(connection); } finally { // Dispose isn't under test here, so fire and forget so that errors/timeouts here don't cause // test errors that mask the real errors. - _ = connection.DisposeAsync(); + _ = connection.DisposeAsync(); } } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs index 6aa6106f40..0ec53e6479 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.Extensions.DependencyInjection; @@ -7,7 +10,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HubConnectionTests { - private static HubConnection CreateHubConnection(TestConnection connection, IHubProtocol protocol = null, ILoggerFactory loggerFactory = null, IRetryPolicy reconnectPolicy = null) + private static HubConnection CreateHubConnection(TestConnection connection, IHubProtocol protocol = null, ILoggerFactory loggerFactory = null) { var builder = new HubConnectionBuilder(); @@ -27,11 +30,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests builder.Services.AddSingleton(protocol); } - if (reconnectPolicy != null) - { - builder.WithAutomaticReconnect(reconnectPolicy); - } - return builder.Build(); } } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs index 5ddcc51161..93cc3dd863 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Protocol.cs @@ -2,9 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; +using System.IO; using System.Threading.Channels; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Tests; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests @@ -14,7 +15,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // don't cause problems. public partial class HubConnectionTests { - public class Protocol + public class Protocol : VerifiableLoggedTest { [Fact] public async Task SendAsyncSendsANonBlockingInvocationMessage() @@ -31,6 +32,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) Assert.Equal("{\"type\":1,\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + + await invokeTask.OrTimeout(); } finally { @@ -47,14 +50,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { // We can't await StartAsync because it depends on the negotiate process! - var startTask = hubConnection.StartAsync().OrTimeout(); + var startTask = hubConnection.StartAsync(); var handshakeMessage = await connection.ReadHandshakeAndSendResponseAsync().OrTimeout(); // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) Assert.Equal("{\"protocol\":\"json\",\"version\":1}", handshakeMessage); - await startTask; + await startTask.OrTimeout(); } finally { @@ -63,6 +66,35 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } } + [Fact] + public async Task InvalidHandshakeResponseCausesStartToFail() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(autoHandshake: false); + var hubConnection = CreateHubConnection(connection); + try + { + // We can't await StartAsync because it depends on the negotiate process! + var startTask = hubConnection.StartAsync(); + + await connection.ReadSentTextMessageAsync().OrTimeout(); + + // The client expects the first message to be a handshake response, but a handshake response doesn't have a "type". + await connection.ReceiveJsonMessage(new { type = "foo" }).OrTimeout(); + + var ex = await Assert.ThrowsAsync(() => startTask).OrTimeout(); + + Assert.Equal("Expected a handshake response from the server.", ex.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Fact] public async Task ClientIsOkayReceivingMinorVersionInHandshake() { @@ -74,9 +106,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { var startTask = hubConnection.StartAsync(); - var message = await connection.ReadHandshakeAndSendResponseAsync(56); + var message = await connection.ReadHandshakeAndSendResponseAsync(56).OrTimeout(); - await startTask; + await startTask.OrTimeout(); } finally { @@ -94,12 +126,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("Foo"); var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); // ReadSentTextMessageAsync strips off the record separator (because it has use it as a separator now that we use Pipelines) Assert.Equal("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Foo\",\"arguments\":[]}", invokeMessage); + + Assert.Equal(TaskStatus.WaitingForActivation, invokeTask.Status); } finally { @@ -184,7 +218,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // Complete the channel await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); - await channel.Completion; + await channel.Completion.OrTimeout(); } finally { @@ -202,7 +236,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("Foo"); await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout(); @@ -246,7 +280,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("Foo"); await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = 42 }).OrTimeout(); @@ -268,7 +302,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("Foo"); await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); @@ -295,7 +329,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = "Oops" }).OrTimeout(); - var ex = await Assert.ThrowsAsync(async () => await channel.ReadAndCollectAllAsync().OrTimeout()); + var ex = await Assert.ThrowsAsync(() => channel.ReadAndCollectAllAsync()).OrTimeout(); Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message); } finally @@ -318,7 +352,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout(); - var ex = await Assert.ThrowsAsync(async () => await channel.ReadAndCollectAllAsync().OrTimeout()); + var ex = await Assert.ThrowsAsync(async () => await channel.ReadAndCollectAllAsync()).OrTimeout(); Assert.Equal("An error occurred", ex.Message); } finally @@ -337,7 +371,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await hubConnection.StartAsync().OrTimeout(); - var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("Foo"); await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = 42 }).OrTimeout(); @@ -479,7 +513,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await hubConnection.StartAsync().OrTimeout(); // Send an invocation - var invokeTask = hubConnection.InvokeAsync("Foo").OrTimeout(); + var invokeTask = hubConnection.InvokeAsync("Foo"); // Receive the ping mid-invocation so we can see that the rest of the flow works fine await connection.ReceiveJsonMessage(new { type = 6 }).OrTimeout(); @@ -506,15 +540,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var task = hubConnection.StartAsync(); - await connection.ReceiveTextAsync("{"); + await connection.ReceiveTextAsync("{").OrTimeout(); Assert.False(task.IsCompleted); - await connection.ReceiveTextAsync("}"); + await connection.ReceiveTextAsync("}").OrTimeout(); Assert.False(task.IsCompleted); - await connection.ReceiveTextAsync("\u001e"); + await connection.ReceiveTextAsync("\u001e").OrTimeout(); await task.OrTimeout(); } @@ -539,9 +573,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests tcs.TrySetResult(data); }); - await connection.ReceiveTextAsync(payload); + await connection.ReceiveTextAsync(payload).OrTimeout(); - await hubConnection.StartAsync(); + await hubConnection.StartAsync().OrTimeout(); var response = await tcs.Task.OrTimeout(); Assert.Equal("hello", response); @@ -568,15 +602,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await hubConnection.StartAsync().OrTimeout(); - await connection.ReceiveTextAsync("{\"type\":1, "); + await connection.ReceiveTextAsync("{\"type\":1, ").OrTimeout(); Assert.False(tcs.Task.IsCompleted); - await connection.ReceiveTextAsync("\"target\": \"Echo\", \"arguments\""); + await connection.ReceiveTextAsync("\"target\": \"Echo\", \"arguments\"").OrTimeout(); Assert.False(tcs.Task.IsCompleted); - await connection.ReceiveTextAsync(":[\"hello\"]}\u001e"); + await connection.ReceiveTextAsync(":[\"hello\"]}\u001e").OrTimeout(); var response = await tcs.Task.OrTimeout(); diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs index 8e6c684b7b..a039bbd0ef 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging.Testing; using Moq; @@ -17,897 +18,900 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HubConnectionTests { - [Fact] - public async Task ReconnectIsNotEnabledByDefault() + public class Reconnect : VerifiableLoggedTest { - bool ExpectedErrors(WriteContext writeContext) + [Fact] + public async Task IsNotEnabledByDefault() { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ShutdownWithError" || - writeContext.EventId.Name == "ServerDisconnectedWithError"); - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var exception = new Exception(); - - var testConnection = new TestConnection(); - await using var hubConnection = CreateHubConnection(testConnection, loggerFactory: LoggerFactory); - - var reconnectingCalled = false; - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => + bool ExpectedErrors(WriteContext writeContext) { - reconnectingCalled = true; - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - await hubConnection.StartAsync().OrTimeout(); - - testConnection.CompleteFromTransport(exception); - - Assert.Same(exception, await closedErrorTcs.Task.OrTimeout()); - Assert.False(reconnectingCalled); - } - } - - [Fact] - public async Task ReconnectCanBeOptedInto() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError"); - } - - var failReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = default(ReconnectingConnectionFactory); - var startCallCount = 0; - var originalConnectionId = "originalConnectionId"; - var reconnectedConnectionId = "reconnectedConnectionId"; - - async Task OnTestConnectionStart() - { - startCallCount++; - - // Only fail the first reconnect attempt. - if (startCallCount == 2) - { - await failReconnectTcs.Task; - } - - var testConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); - - // Change the connection id before reconnecting. - if (startCallCount == 3) - { - testConnection.ConnectionId = reconnectedConnectionId; - } - else - { - testConnection.ConnectionId = originalConnectionId; - } + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ShutdownWithError" || + writeContext.EventId.Name == "ServerDisconnectedWithError"); } - testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); - builder.Services.AddSingleton(testConnectionFactory); - - var retryContexts = new List(); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + using (StartVerifiableLog(ExpectedErrors)) { - retryContexts.Add(context); - return TimeSpan.Zero; - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + var exception = new Exception(); - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var testConnection = new TestConnection(); + await using var hubConnection = CreateHubConnection(testConnection, loggerFactory: LoggerFactory); - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; + var reconnectingCalled = false; + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - reconnectedConnectionIdTcs.SetResult(connectionId); - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - await hubConnection.StartAsync().OrTimeout(); - - Assert.Same(originalConnectionId, hubConnection.ConnectionId); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - - var reconnectException = new Exception(); - failReconnectTcs.SetException(reconnectException); - - Assert.Same(reconnectedConnectionId, await reconnectedConnectionIdTcs.Task.OrTimeout()); - - Assert.Equal(2, retryContexts.Count); - Assert.Same(reconnectException, retryContexts[1].RetryReason); - Assert.Equal(1, retryContexts[1].PreviousRetryCount); - Assert.True(TimeSpan.Zero <= retryContexts[1].ElapsedTime); - - await hubConnection.StopAsync().OrTimeout(); - - var closeError = await closedErrorTcs.Task.OrTimeout(); - Assert.Null(closeError); - Assert.Equal(1, reconnectingCount); - Assert.Equal(1, reconnectedCount); - } - } - - [Fact] - public async Task ReconnectStopsIfTheReconnectPolicyReturnsNull() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError"); - } - - var failReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var startCallCount = 0; - - Task OnTestConnectionStart() - { - startCallCount++; - - // Fail the first reconnect attempts. - if (startCallCount > 1) + hubConnection.Reconnecting += error => { - return failReconnectTcs.Task; - } + reconnectingCalled = true; + return Task.CompletedTask; + }; - return Task.CompletedTask; + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + testConnection.CompleteFromTransport(exception); + + Assert.Same(exception, await closedErrorTcs.Task.OrTimeout()); + Assert.False(reconnectingCalled); + } + } + + [Fact] + public async Task CanBeOptedInto() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError"); } - var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); - builder.Services.AddSingleton(testConnectionFactory); + var failReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var retryContexts = new List(); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + using (StartVerifiableLog(ExpectedErrors)) { - retryContexts.Add(context); - return context.PreviousRetryCount == 0 ? TimeSpan.Zero : (TimeSpan?)null; - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = default(ReconnectingConnectionFactory); + var startCallCount = 0; + var originalConnectionId = "originalConnectionId"; + var reconnectedConnectionId = "reconnectedConnectionId"; - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - await hubConnection.StartAsync().OrTimeout(); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - - var reconnectException = new Exception(); - failReconnectTcs.SetException(reconnectException); - - var closeError = await closedErrorTcs.Task.OrTimeout(); - Assert.IsType(closeError); - - Assert.Equal(2, retryContexts.Count); - Assert.Same(reconnectException, retryContexts[1].RetryReason); - Assert.Equal(1, retryContexts[1].PreviousRetryCount); - Assert.True(TimeSpan.Zero <= retryContexts[1].ElapsedTime); - - Assert.Equal(1, reconnectingCount); - Assert.Equal(0, reconnectedCount); - } - } - - [Fact] - public async Task ReconnectCanHappenMultipleTimes() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError"); - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = new ReconnectingConnectionFactory(); - builder.Services.AddSingleton(testConnectionFactory); - - var retryContexts = new List(); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => - { - retryContexts.Add(context); - return TimeSpan.Zero; - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); - - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - reconnectedConnectionIdTcs.SetResult(connectionId); - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - await hubConnection.StartAsync().OrTimeout(); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - - await reconnectedConnectionIdTcs.Task.OrTimeout(); - - Assert.Equal(1, reconnectingCount); - Assert.Equal(1, reconnectedCount); - Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); - - reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - var secondException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(secondException); - - Assert.Same(secondException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Equal(2, retryContexts.Count); - Assert.Same(secondException, retryContexts[1].RetryReason); - Assert.Equal(0, retryContexts[1].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[1].ElapsedTime); - - await reconnectedConnectionIdTcs.Task.OrTimeout(); - - Assert.Equal(2, reconnectingCount); - Assert.Equal(2, reconnectedCount); - Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); - - await hubConnection.StopAsync().OrTimeout(); - - var closeError = await closedErrorTcs.Task.OrTimeout(); - Assert.Null(closeError); - Assert.Equal(2, reconnectingCount); - Assert.Equal(2, reconnectedCount); - } - } - - [Fact] - public async Task ReconnectEventsNotFiredIfFirstRetryDelayIsNull() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - writeContext.EventId.Name == "ServerDisconnectedWithError"; - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = new ReconnectingConnectionFactory(); - builder.Services.AddSingleton(testConnectionFactory); - - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(null); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); - - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - await hubConnection.StartAsync().OrTimeout(); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - await closedErrorTcs.Task.OrTimeout(); - - Assert.Equal(0, reconnectingCount); - Assert.Equal(0, reconnectedCount); - } - } - - [Fact] - public async Task ReconnectDoesNotStartIfConnectionIsLostDuringInitialHandshake() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || - writeContext.EventId.Name == "ErrorStartingConnection"); - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); - builder.Services.AddSingleton(testConnectionFactory); - - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(null); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); - - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var closedCount = 0; - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedCount++; - return Task.CompletedTask; - }; - - var startTask = hubConnection.StartAsync().OrTimeout(); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - Assert.Same(firstException, await Assert.ThrowsAsync(() => startTask).OrTimeout()); - Assert.Equal(HubConnectionState.Disconnected, hubConnection.State); - Assert.Equal(0, reconnectingCount); - Assert.Equal(0, reconnectedCount); - Assert.Equal(0, closedCount); - } - } - - [Fact] - public async Task ReconnectContinuesIfConnectionLostDuringReconnectHandshake() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError" || - writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || - writeContext.EventId.Name == "ErrorStartingConnection"); - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); - builder.Services.AddSingleton(testConnectionFactory); - - var retryContexts = new List(); - var secondRetryDelayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => - { - retryContexts.Add(context); - - if (retryContexts.Count == 2) + async Task OnTestConnectionStart() { - secondRetryDelayTcs.SetResult(null); + startCallCount++; + + // Only fail the first reconnect attempt. + if (startCallCount == 2) + { + await failReconnectTcs.Task; + } + + var testConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + + // Change the connection id before reconnecting. + if (startCallCount == 3) + { + testConnection.ConnectionId = reconnectedConnectionId; + } + else + { + testConnection.ConnectionId = originalConnectionId; + } } - return TimeSpan.Zero; - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); + builder.Services.AddSingleton(testConnectionFactory); - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - reconnectedConnectionIdTcs.SetResult(connectionId); - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - var startTask = hubConnection.StartAsync(); - - // Complete handshake - var currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); - await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - - await startTask.OrTimeout(); - - var firstException = new Exception(); - currentTestConnection.CompleteFromTransport(firstException); - - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - - var secondException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(secondException); - - await secondRetryDelayTcs.Task.OrTimeout(); - - Assert.Equal(2, retryContexts.Count); - Assert.Same(secondException, retryContexts[1].RetryReason); - Assert.Equal(1, retryContexts[1].PreviousRetryCount); - Assert.True(TimeSpan.Zero <= retryContexts[0].ElapsedTime); - - // Complete handshake - currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); - await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - await reconnectedConnectionIdTcs.Task.OrTimeout(); - - Assert.Equal(1, reconnectingCount); - Assert.Equal(1, reconnectedCount); - Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); - - await hubConnection.StopAsync().OrTimeout(); - - var closeError = await closedErrorTcs.Task.OrTimeout(); - Assert.Null(closeError); - Assert.Equal(1, reconnectingCount); - Assert.Equal(1, reconnectedCount); - } - } - - [Fact] - public async Task ReconnectContinuesIfInvalidHandshakeResponse() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError" || - writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || - writeContext.EventId.Name == "HandshakeServerError" || - writeContext.EventId.Name == "ErrorStartingConnection"); - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); - builder.Services.AddSingleton(testConnectionFactory); - - var retryContexts = new List(); - var secondRetryDelayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => - { - retryContexts.Add(context); - - if (retryContexts.Count == 2) + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => { - secondRetryDelayTcs.SetResult(null); + retryContexts.Add(context); + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + Assert.Same(originalConnectionId, hubConnection.ConnectionId); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var reconnectException = new Exception(); + failReconnectTcs.SetException(reconnectException); + + Assert.Same(reconnectedConnectionId, await reconnectedConnectionIdTcs.Task.OrTimeout()); + + Assert.Equal(2, retryContexts.Count); + Assert.Same(reconnectException, retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[1].ElapsedTime); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + } + } + + [Fact] + public async Task StopsIfTheReconnectPolicyReturnsNull() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError"); + } + + var failReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var startCallCount = 0; + + Task OnTestConnectionStart() + { + startCallCount++; + + // Fail the first reconnect attempts. + if (startCallCount > 1) + { + return failReconnectTcs.Task; + } + + return Task.CompletedTask; } - return TimeSpan.Zero; - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); + builder.Services.AddSingleton(testConnectionFactory); - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return context.PreviousRetryCount == 0 ? TimeSpan.Zero : (TimeSpan?)null; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - reconnectedConnectionIdTcs.SetResult(connectionId); - return Task.CompletedTask; - }; + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; - var startTask = hubConnection.StartAsync(); + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; - // Complete handshake - var currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); - await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + await hubConnection.StartAsync().OrTimeout(); - await startTask.OrTimeout(); + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - var firstException = new Exception(); - currentTestConnection.CompleteFromTransport(firstException); + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + var reconnectException = new Exception(); + failReconnectTcs.SetException(reconnectException); - // Respond to handshake with error. - currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); - await currentTestConnection.ReadSentTextMessageAsync().OrTimeout(); + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.IsType(closeError); - var output = MemoryBufferWriter.Get(); - try - { - HandshakeProtocol.WriteResponseMessage(new HandshakeResponseMessage("Error!"), output); - await currentTestConnection.Application.Output.WriteAsync(output.ToArray()).OrTimeout(); + Assert.Equal(2, retryContexts.Count); + Assert.Same(reconnectException, retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[1].ElapsedTime); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(0, reconnectedCount); } - finally - { - MemoryBufferWriter.Return(output); - } - - await secondRetryDelayTcs.Task.OrTimeout(); - - Assert.Equal(2, retryContexts.Count); - Assert.IsType(retryContexts[1].RetryReason); - Assert.Equal(1, retryContexts[1].PreviousRetryCount); - Assert.True(TimeSpan.Zero <= retryContexts[0].ElapsedTime); - - // Complete handshake - - currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); - await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); - await reconnectedConnectionIdTcs.Task.OrTimeout(); - - Assert.Equal(1, reconnectingCount); - Assert.Equal(1, reconnectedCount); - Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); - - await hubConnection.StopAsync().OrTimeout(); - - var closeError = await closedErrorTcs.Task.OrTimeout(); - Assert.Null(closeError); - Assert.Equal(1, reconnectingCount); - Assert.Equal(1, reconnectedCount); - } - } - - [Fact] - public async Task ReconnectCanBeStoppedWhileRestartingUnderlyingConnection() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError" || - writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || - writeContext.EventId.Name == "ErrorStartingConnection"); } - using (StartVerifiableLog(ExpectedErrors)) + [Fact] + public async Task CanHappenMultipleTimes() { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var connectionStartTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - async Task OnTestConnectionStart() + bool ExpectedErrors(WriteContext writeContext) { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var secondException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(secondException); + + Assert.Same(secondException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Equal(2, retryContexts.Count); + Assert.Same(secondException, retryContexts[1].RetryReason); + Assert.Equal(0, retryContexts[1].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[1].ElapsedTime); + + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(2, reconnectingCount); + Assert.Equal(2, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(2, reconnectingCount); + Assert.Equal(2, reconnectedCount); + } + } + + [Fact] + public async Task EventsNotFiredIfFirstRetryDelayIsNull() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + writeContext.EventId.Name == "ServerDisconnectedWithError"; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(); + builder.Services.AddSingleton(testConnectionFactory); + + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(null); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + await closedErrorTcs.Task.OrTimeout(); + + Assert.Equal(0, reconnectingCount); + Assert.Equal(0, reconnectedCount); + } + } + + [Fact] + public async Task DoesNotStartIfConnectionIsLostDuringInitialHandshake() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); + builder.Services.AddSingleton(testConnectionFactory); + + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(null); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var closedCount = 0; + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedCount++; + return Task.CompletedTask; + }; + + var startTask = hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await Assert.ThrowsAsync(() => startTask).OrTimeout()); + Assert.Equal(HubConnectionState.Disconnected, hubConnection.State); + Assert.Equal(0, reconnectingCount); + Assert.Equal(0, reconnectedCount); + Assert.Equal(0, closedCount); + } + } + + [Fact] + public async Task ContinuesIfConnectionLostDuringReconnectHandshake() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var secondRetryDelayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + + if (retryContexts.Count == 2) + { + secondRetryDelayTcs.SetResult(null); + } + + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + var startTask = hubConnection.StartAsync(); + + // Complete handshake + var currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + + await startTask.OrTimeout(); + + var firstException = new Exception(); + currentTestConnection.CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var secondException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(secondException); + + await secondRetryDelayTcs.Task.OrTimeout(); + + Assert.Equal(2, retryContexts.Count); + Assert.Same(secondException, retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[0].ElapsedTime); + + // Complete handshake + currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + } + } + + [Fact] + public async Task ContinuesIfInvalidHandshakeResponse() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "HandshakeServerError" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var secondRetryDelayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + + if (retryContexts.Count == 2) + { + secondRetryDelayTcs.SetResult(null); + } + + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + var startTask = hubConnection.StartAsync(); + + // Complete handshake + var currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + + await startTask.OrTimeout(); + + var firstException = new Exception(); + currentTestConnection.CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + // Respond to handshake with error. + currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadSentTextMessageAsync().OrTimeout(); + + var output = MemoryBufferWriter.Get(); try { - await connectionStartTcs.Task; + HandshakeProtocol.WriteResponseMessage(new HandshakeResponseMessage("Error!"), output); + await currentTestConnection.Application.Output.WriteAsync(output.ToArray()).OrTimeout(); } finally { - connectionStartTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + MemoryBufferWriter.Return(output); + } + + await secondRetryDelayTcs.Task.OrTimeout(); + + Assert.Equal(2, retryContexts.Count); + Assert.IsType(retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[0].ElapsedTime); + + // Complete handshake + + currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + } + } + + [Fact] + public async Task CanBeStoppedWhileRestartingUnderlyingConnection() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorHandshakeCanceled" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var connectionStartTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + async Task OnTestConnectionStart() + { + try + { + await connectionStartTcs.Task; + } + finally + { + connectionStartTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + // Allow the first connection to start successfully. + connectionStartTcs.SetResult(null); + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var secondException = new Exception(); + var stopTask = hubConnection.StopAsync(); + connectionStartTcs.SetResult(null); + + Assert.IsType(await closedErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Equal(1, reconnectingCount); + Assert.Equal(0, reconnectedCount); + await stopTask.OrTimeout(); + } + } + + [Fact] + public async Task CanBeStoppedDuringRetryDelay() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + // Hopefully this test never takes over a minute. + return TimeSpan.FromMinutes(1); + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + // Allow the first connection to start successfully. + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + await hubConnection.StopAsync().OrTimeout(); + + Assert.IsType(await closedErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Equal(1, reconnectingCount); + Assert.Equal(0, reconnectedCount); + } + } + + private class ReconnectingConnectionFactory : IConnectionFactory + { + public readonly Func _testConnectionFactory; + public TaskCompletionSource _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + public ReconnectingConnectionFactory() + : this(() => new TestConnection()) + { + } + + public ReconnectingConnectionFactory(Func testConnectionFactory) + { + _testConnectionFactory = testConnectionFactory; + } + + public Task GetNextOrCurrentTestConnection() + { + return _testConnectionTcs.Task; + } + + public async Task ConnectAsync(TransferFormat transferFormat, CancellationToken cancellationToken = default) + { + var testConnection = _testConnectionFactory(); + + _testConnectionTcs.SetResult(testConnection); + + try + { + return await testConnection.StartAsync(transferFormat); + } + catch + { + _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + throw; } } - var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); - builder.Services.AddSingleton(testConnectionFactory); - - var retryContexts = new List(); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + public async Task DisposeAsync(ConnectionContext connection) { - retryContexts.Add(context); - return TimeSpan.Zero; - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + var disposingTestConnection = await _testConnectionTcs.Task; - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - // Allow the first connection to start successfully. - connectionStartTcs.SetResult(null); - await hubConnection.StartAsync().OrTimeout(); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - - var secondException = new Exception(); - var stopTask = hubConnection.StopAsync(); - connectionStartTcs.SetResult(null); - - Assert.IsType(await closedErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Equal(1, reconnectingCount); - Assert.Equal(0, reconnectedCount); - await stopTask.OrTimeout(); - } - } - - [Fact] - public async Task ReconnectCanBeStoppedDuringRetryDelay() - { - bool ExpectedErrors(WriteContext writeContext) - { - return writeContext.LoggerName == typeof(HubConnection).FullName && - (writeContext.EventId.Name == "ServerDisconnectedWithError" || - writeContext.EventId.Name == "ReconnectingWithError" || - writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || - writeContext.EventId.Name == "ErrorStartingConnection"); - } - - using (StartVerifiableLog(ExpectedErrors)) - { - var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); - var testConnectionFactory = new ReconnectingConnectionFactory(); - builder.Services.AddSingleton(testConnectionFactory); - - var retryContexts = new List(); - var mockReconnectPolicy = new Mock(); - mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => - { - retryContexts.Add(context); - // Hopefully this test never takes over a minute. - return TimeSpan.FromMinutes(1); - }); - builder.WithAutomaticReconnect(mockReconnectPolicy.Object); - - await using var hubConnection = builder.Build(); - var reconnectingCount = 0; - var reconnectedCount = 0; - var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - hubConnection.Reconnecting += error => - { - reconnectingCount++; - reconnectingErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - hubConnection.Reconnected += connectionId => - { - reconnectedCount++; - return Task.CompletedTask; - }; - - hubConnection.Closed += error => - { - closedErrorTcs.SetResult(error); - return Task.CompletedTask; - }; - - // Allow the first connection to start successfully. - await hubConnection.StartAsync().OrTimeout(); - - var firstException = new Exception(); - (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); - - Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Same(firstException, retryContexts[0].RetryReason); - Assert.Equal(0, retryContexts[0].PreviousRetryCount); - Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); - - await hubConnection.StopAsync().OrTimeout(); - - Assert.IsType(await closedErrorTcs.Task.OrTimeout()); - Assert.Single(retryContexts); - Assert.Equal(1, reconnectingCount); - Assert.Equal(0, reconnectedCount); - } - } - - private class ReconnectingConnectionFactory : IConnectionFactory - { - public readonly Func _testConnectionFactory; - public TaskCompletionSource _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - public ReconnectingConnectionFactory() - : this (() => new TestConnection()) - { - } - - public ReconnectingConnectionFactory(Func testConnectionFactory) - { - _testConnectionFactory = testConnectionFactory; - } - - public Task GetNextOrCurrentTestConnection() - { - return _testConnectionTcs.Task; - } - - public async Task ConnectAsync(TransferFormat transferFormat, CancellationToken cancellationToken = default) - { - var testConnection = _testConnectionFactory(); - - _testConnectionTcs.SetResult(testConnection); - - try - { - return await testConnection.StartAsync(transferFormat); - } - catch - { _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - throw; + + await disposingTestConnection.DisposeAsync(); } } - - public async Task DisposeAsync(ConnectionContext connection) - { - var disposingTestConnection = await _testConnectionTcs.Task; - - _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - - await disposingTestConnection.DisposeAsync(); - } } } }