diff --git a/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netcoreapp3.0.cs b/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netcoreapp3.0.cs index e773e46696..4124f33b55 100644 --- a/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netcoreapp3.0.cs +++ b/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netcoreapp3.0.cs @@ -26,12 +26,15 @@ namespace Microsoft.AspNetCore.SignalR.Client public static readonly System.TimeSpan DefaultServerTimeout; public HubConnection(Microsoft.AspNetCore.SignalR.Client.IConnectionFactory connectionFactory, Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol protocol, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { } public HubConnection(Microsoft.AspNetCore.SignalR.Client.IConnectionFactory connectionFactory, Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol protocol, System.IServiceProvider serviceProvider, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { } + public HubConnection(Microsoft.AspNetCore.SignalR.Client.IConnectionFactory connectionFactory, Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol protocol, System.IServiceProvider serviceProvider, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory, Microsoft.AspNetCore.SignalR.Client.IRetryPolicy reconnectPolicy) { } public string ConnectionId { get { throw null; } } public System.TimeSpan HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.TimeSpan KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.TimeSpan ServerTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } 
[System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public Microsoft.AspNetCore.SignalR.Client.HubConnectionState State { get { throw null; } } public event System.Func Closed { add { } remove { } } + public event System.Func Reconnected { add { } remove { } } + public event System.Func Reconnecting { add { } remove { } } [System.Diagnostics.DebuggerStepThroughAttribute] public System.Threading.Tasks.Task DisposeAsync() { throw null; } [System.Diagnostics.DebuggerStepThroughAttribute] @@ -65,6 +68,9 @@ namespace Microsoft.AspNetCore.SignalR.Client public static partial class HubConnectionBuilderExtensions { public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder ConfigureLogging(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder, System.Action configureLogging) { throw null; } + public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder WithAutomaticReconnect(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder) { throw null; } + public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder WithAutomaticReconnect(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder, Microsoft.AspNetCore.SignalR.Client.IRetryPolicy retryPolicy) { throw null; } + public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder WithAutomaticReconnect(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder, System.TimeSpan[] reconnectDelays) { throw null; } } public static partial class HubConnectionExtensions { @@ -152,6 +158,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { Disconnected = 0, Connected = 1, + Connecting = 2, + Reconnecting = 3, } public partial interface IConnectionFactory { @@ -162,4 +170,15 @@ namespace Microsoft.AspNetCore.SignalR.Client { Microsoft.AspNetCore.SignalR.Client.HubConnection Build(); } + public partial interface IRetryPolicy + { + System.TimeSpan? 
NextRetryDelay(Microsoft.AspNetCore.SignalR.Client.RetryContext retryContext); + } + public sealed partial class RetryContext + { + public RetryContext() { } + public System.TimeSpan ElapsedTime { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public long PreviousRetryCount { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public System.Exception RetryReason { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + } } diff --git a/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netstandard2.0.cs b/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netstandard2.0.cs index 640b7c85e0..8e8854ac6a 100644 --- a/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netstandard2.0.cs +++ b/src/SignalR/clients/csharp/Client.Core/ref/Microsoft.AspNetCore.SignalR.Client.Core.netstandard2.0.cs @@ -26,12 +26,15 @@ namespace Microsoft.AspNetCore.SignalR.Client public static readonly System.TimeSpan DefaultServerTimeout; public HubConnection(Microsoft.AspNetCore.SignalR.Client.IConnectionFactory connectionFactory, Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol protocol, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { } public HubConnection(Microsoft.AspNetCore.SignalR.Client.IConnectionFactory connectionFactory, Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol protocol, System.IServiceProvider serviceProvider, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { } + public HubConnection(Microsoft.AspNetCore.SignalR.Client.IConnectionFactory connectionFactory, Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol protocol, System.IServiceProvider 
serviceProvider, Microsoft.Extensions.Logging.ILoggerFactory loggerFactory, Microsoft.AspNetCore.SignalR.Client.IRetryPolicy reconnectPolicy) { } public string ConnectionId { get { throw null; } } public System.TimeSpan HandshakeTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.TimeSpan KeepAliveInterval { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public System.TimeSpan ServerTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public Microsoft.AspNetCore.SignalR.Client.HubConnectionState State { get { throw null; } } public event System.Func Closed { add { } remove { } } + public event System.Func Reconnected { add { } remove { } } + public event System.Func Reconnecting { add { } remove { } } [System.Diagnostics.DebuggerStepThroughAttribute] public System.Threading.Tasks.Task DisposeAsync() { throw null; } [System.Diagnostics.DebuggerStepThroughAttribute] @@ -64,6 +67,9 @@ namespace Microsoft.AspNetCore.SignalR.Client public static partial class HubConnectionBuilderExtensions { public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder ConfigureLogging(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder, System.Action configureLogging) { throw null; } + public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder WithAutomaticReconnect(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder) { throw null; } + public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder WithAutomaticReconnect(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder, Microsoft.AspNetCore.SignalR.Client.IRetryPolicy retryPolicy) { 
throw null; } + public static Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder WithAutomaticReconnect(this Microsoft.AspNetCore.SignalR.Client.IHubConnectionBuilder hubConnectionBuilder, System.TimeSpan[] reconnectDelays) { throw null; } } public static partial class HubConnectionExtensions { @@ -140,6 +146,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { Disconnected = 0, Connected = 1, + Connecting = 2, + Reconnecting = 3, } public partial interface IConnectionFactory { @@ -150,4 +158,15 @@ namespace Microsoft.AspNetCore.SignalR.Client { Microsoft.AspNetCore.SignalR.Client.HubConnection Build(); } + public partial interface IRetryPolicy + { + System.TimeSpan? NextRetryDelay(Microsoft.AspNetCore.SignalR.Client.RetryContext retryContext); + } + public sealed partial class RetryContext + { + public RetryContext() { } + public System.TimeSpan ElapsedTime { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public long PreviousRetryCount { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + public System.Exception RetryReason { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 8eafaedb20..7061e276b1 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -198,6 +198,48 @@ namespace Microsoft.AspNetCore.SignalR.Client private static readonly Action _completingStream = LoggerMessage.Define(LogLevel.Trace, new EventId(66, "CompletingStream"), "Sending completion message for stream '{StreamId}'."); + private static 
readonly Action _stateTransitionFailed = + LoggerMessage.Define(LogLevel.Error, new EventId(67, "StateTransitionFailed"), "The HubConnection failed to transition from the {ExpectedState} state to the {NewState} state because it was actually in the {ActualState} state."); + + private static readonly Action _reconnecting = + LoggerMessage.Define(LogLevel.Information, new EventId(68, "Reconnecting"), "HubConnection reconnecting."); + + private static readonly Action _reconnectingWithError = + LoggerMessage.Define(LogLevel.Error, new EventId(69, "ReconnectingWithError"), "HubConnection reconnecting due to an error."); + + private static readonly Action _reconnected = + LoggerMessage.Define(LogLevel.Information, new EventId(70, "Reconnected"), "HubConnection reconnected successfully after {ReconnectAttempts} attempts and {ElapsedTime} elapsed."); + + private static readonly Action _reconnectAttemptsExhausted = + LoggerMessage.Define(LogLevel.Information, new EventId(71, "ReconnectAttemptsExhausted"), "Reconnect retries have been exhausted after {ReconnectAttempts} failed attempts and {ElapsedTime} elapsed. 
Disconnecting."); + + private static readonly Action _awaitingReconnectRetryDelay = + LoggerMessage.Define(LogLevel.Trace, new EventId(72, "AwaitingReconnectRetryDelay"), "Reconnect attempt number {ReconnectAttempts} will start in {RetryDelay}."); + + private static readonly Action _reconnectAttemptFailed = + LoggerMessage.Define(LogLevel.Trace, new EventId(73, "ReconnectAttemptFailed"), "Reconnect attempt failed."); + + private static readonly Action _errorDuringReconnectingEvent = + LoggerMessage.Define(LogLevel.Error, new EventId(74, "ErrorDuringReconnectingEvent"), "An exception was thrown in the handler for the Reconnecting event."); + + private static readonly Action _errorDuringReconnectedEvent = + LoggerMessage.Define(LogLevel.Error, new EventId(75, "ErrorDuringReconnectedEvent"), "An exception was thrown in the handler for the Reconnected event."); + + private static readonly Action _errorDuringNextRetryDelay = + LoggerMessage.Define(LogLevel.Error, new EventId(76, "ErrorDuringNextRetryDelay"), $"An exception was thrown from {nameof(IRetryPolicy)}.{nameof(IRetryPolicy.NextRetryDelay)}()."); + + private static readonly Action _firstReconnectRetryDelayNull = + LoggerMessage.Define(LogLevel.Warning, new EventId(77, "FirstReconnectRetryDelayNull"), "Connection not reconnecting because the IRetryPolicy returned null on the first reconnect attempt."); + + private static readonly Action _reconnectingStoppedDuringRetryDelay = + LoggerMessage.Define(LogLevel.Trace, new EventId(78, "ReconnectingStoppedDueToStateChangeDuringRetryDelay"), "Connection stopped during reconnect delay. Done reconnecting."); + + private static readonly Action _reconnectingStoppedDuringReconnectAttempt = + LoggerMessage.Define(LogLevel.Trace, new EventId(79, "ReconnectingStoppedDueToStateChangeDuringReconnectAttempt"), "Connection stopped during reconnect attempt. 
Done reconnecting."); + + private static readonly Action _attemptingStateTransition = + LoggerMessage.Define(LogLevel.Trace, new EventId(80, "AttemptingStateTransition"), "The HubConnection is attempting to transition from the {ExpectedState} state to the {NewState} state."); + public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count) { _preparingNonBlockingInvocation(logger, target, count, null); @@ -528,6 +570,76 @@ namespace Microsoft.AspNetCore.SignalR.Client { _completingStream(logger, streamId, null); } + + public static void StateTransitionFailed(ILogger logger, HubConnectionState expectedState, HubConnectionState newState, HubConnectionState actualState) + { + _stateTransitionFailed(logger, expectedState, newState, actualState, null); + } + + public static void Reconnecting(ILogger logger) + { + _reconnecting(logger, null); + } + + public static void ReconnectingWithError(ILogger logger, Exception exception) + { + _reconnectingWithError(logger, exception); + } + + public static void Reconnected(ILogger logger, long reconnectAttempts, TimeSpan elapsedTime) + { + _reconnected(logger, reconnectAttempts, elapsedTime, null); + } + + public static void ReconnectAttemptsExhausted(ILogger logger, long reconnectAttempts, TimeSpan elapsedTime) + { + _reconnectAttemptsExhausted(logger, reconnectAttempts, elapsedTime, null); + } + + public static void AwaitingReconnectRetryDelay(ILogger logger, long reconnectAttempts, TimeSpan retryDelay) + { + _awaitingReconnectRetryDelay(logger, reconnectAttempts, retryDelay, null); + } + + public static void ReconnectAttemptFailed(ILogger logger, Exception exception) + { + _reconnectAttemptFailed(logger, exception); + } + + public static void ErrorDuringReconnectingEvent(ILogger logger, Exception exception) + { + _errorDuringReconnectingEvent(logger, exception); + } + + public static void ErrorDuringReconnectedEvent(ILogger logger, Exception exception) + { + _errorDuringReconnectedEvent(logger, 
exception); + } + + public static void ErrorDuringNextRetryDelay(ILogger logger, Exception exception) + { + _errorDuringNextRetryDelay(logger, exception); + } + + public static void FirstReconnectRetryDelayNull(ILogger logger) + { + _firstReconnectRetryDelayNull(logger, null); + } + + public static void ReconnectingStoppedDuringRetryDelay(ILogger logger) + { + _reconnectingStoppedDuringRetryDelay(logger, null); + } + + public static void ReconnectingStoppedDuringReconnectAttempt(ILogger logger) + { + _reconnectingStoppedDuringReconnectAttempt(logger, null); + } + + public static void AttemptingStateTransition(ILogger logger, HubConnectionState expectedState, HubConnectionState newState) + { + _attemptingStateTransition(logger, expectedState, newState, null); + } } } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index cdbd7f7caa..560d32978e 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -41,28 +41,6 @@ namespace Microsoft.AspNetCore.SignalR.Client // This lock protects the connection state. 
private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1); - private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems")); -#if NETCOREAPP3_0 - private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendIAsyncEnumerableStreamItems")); -#endif - // Persistent across all connections - private readonly ILoggerFactory _loggerFactory; - private readonly ILogger _logger; - private readonly IHubProtocol _protocol; - private readonly IServiceProvider _serviceProvider; - private readonly IConnectionFactory _connectionFactory; - private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary(StringComparer.Ordinal); - - private long _nextActivationServerTimeout; - private long _nextActivationSendPing; - private bool _disposed; - private bool _hasInherentKeepAlive; - private string _connectionId; - - private CancellationToken _uploadStreamToken; - - private readonly ConnectionLogScope _logScope; - // The receive loop has a single reader and single writer at a time so optimize the channel for that private static readonly UnboundedChannelOptions _receiveLoopOptions = new UnboundedChannelOptions { @@ -70,8 +48,24 @@ namespace Microsoft.AspNetCore.SignalR.Client SingleWriter = true }; - // Transient state to a connection - private ConnectionState _connectionState; + private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals(nameof(SendStreamItems))); +#if NETCOREAPP3_0 + private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals(nameof(SendIAsyncEnumerableStreamItems))); +#endif + // Persistent 
across all connections + private readonly ILoggerFactory _loggerFactory; + private readonly ILogger _logger; + private readonly ConnectionLogScope _logScope; + private readonly IHubProtocol _protocol; + private readonly IServiceProvider _serviceProvider; + private readonly IConnectionFactory _connectionFactory; + private readonly IRetryPolicy _reconnectPolicy; + private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary(StringComparer.Ordinal); + + // Holds all mutable state other than user-defined handlers and settable properties. + private readonly ReconnectingConnectionState _state; + + private bool _disposed; /// /// Occurs when the connection is closed. The connection could be closed due to an error or due to either the server or client intentionally @@ -95,13 +89,49 @@ namespace Microsoft.AspNetCore.SignalR.Client /// } /// else /// { - /// Console.WriteLine($"Connection closed due to an error: {exception.Message}"); + /// Console.WriteLine($"Connection closed due to an error: {exception}"); /// } /// }; /// /// public event Func Closed; + /// + /// Occurs when the starts reconnecting after losing its underlying connection. + /// + /// + /// The that occurred will be passed in as the sole argument to this handler. + /// + /// + /// The following example attaches a handler to the event, and checks the provided argument to log the error. + /// + /// + /// connection.Reconnecting += (exception) => + /// { + /// Console.WriteLine($"Connection started reconnecting due to an error: {exception}"); + /// }; + /// + /// + public event Func Reconnecting; + + /// + /// Occurs when the successfully reconnects after losing its underlying connection. + /// + /// + /// The parameter will be the 's new ConnectionId or null if negotiation was skipped. + /// + /// + /// The following example attaches a handler to the event, and checks the provided argument to log the ConnectionId. 
+ /// + /// + /// connection.Reconnected += (connectionId) => + /// { + /// Console.WriteLine($"Connection successfully reconnected. The ConnectionId is now: {connectionId}"); + /// }; + /// + /// + public event Func Reconnected; + // internal for testing purposes internal TimeSpan TickRate { get; set; } = TimeSpan.FromSeconds(1); @@ -131,24 +161,31 @@ namespace Microsoft.AspNetCore.SignalR.Client /// This value will be null if the negotiation step is skipped via HttpConnectionOptions or if the WebSockets transport is explicitly specified because the /// client skips negotiation in that case as well. /// - public string ConnectionId => _connectionId; + public string ConnectionId => _state.CurrentConnectionStateUnsynchronized?.Connection.ConnectionId; /// /// Indicates the state of the to the server. /// - public HubConnectionState State - { - get - { - // Copy reference for thread-safety - var connectionState = _connectionState; - if (connectionState == null || connectionState.Stopped) - { - return HubConnectionState.Disconnected; - } + public HubConnectionState State => _state.OverallState; - return HubConnectionState.Connected; - } + /// + /// Initializes a new instance of the class. + /// + /// The used to create a connection each time is called. + /// The used by the connection. + /// An containing the services provided to this instance. + /// The logger factory. + /// + /// The that controls the timing and number of reconnect attempts. + /// The will not reconnect if the is null. + /// + /// + /// The used to initialize the connection will be disposed when the connection is disposed. 
+ /// + public HubConnection(IConnectionFactory connectionFactory, IHubProtocol protocol, IServiceProvider serviceProvider, ILoggerFactory loggerFactory, IRetryPolicy reconnectPolicy) + : this(connectionFactory, protocol, serviceProvider, loggerFactory) + { + _reconnectPolicy = reconnectPolicy; } /// @@ -180,6 +217,7 @@ namespace Microsoft.AspNetCore.SignalR.Client _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); + _state = new ReconnectingConnectionState(_logger); _logScope = new ConnectionLogScope(); } @@ -194,7 +232,43 @@ namespace Microsoft.AspNetCore.SignalR.Client CheckDisposed(); using (_logger.BeginScope(_logScope)) { - await StartAsyncCore(cancellationToken).ForceAsync(); + await StartAsyncInner(cancellationToken).ForceAsync(); + } + } + + private async Task StartAsyncInner(CancellationToken cancellationToken = default) + { + await _state.WaitConnectionLockAsync(); + try + { + if (!_state.TryChangeState(HubConnectionState.Disconnected, HubConnectionState.Connecting)) + { + throw new InvalidOperationException($"The {nameof(HubConnection)} cannot be started if it is not in the {nameof(HubConnectionState.Disconnected)} state."); + } + + // The StopCts is canceled at the start of StopAsync should be reset every time the connection finishes stopping. + // If this token is currently canceled, it means that StartAsync was called while StopAsync was still running. 
+ if (_state.StopCts.Token.IsCancellationRequested) + { + throw new InvalidOperationException($"The {nameof(HubConnection)} cannot be started while {nameof(StopAsync)} is running."); + } + + await StartAsyncCore(cancellationToken); + + _state.ChangeState(HubConnectionState.Connecting, HubConnectionState.Connected); + } + catch + { + if (_state.TryChangeState(HubConnectionState.Connecting, HubConnectionState.Disconnected)) + { + _state.StopCts = new CancellationTokenSource(); + } + + throw; + } + finally + { + _state.ReleaseConnectionLock(); } } @@ -338,53 +412,40 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task StartAsyncCore(CancellationToken cancellationToken) { - await WaitConnectionLockAsync(); + _state.AssertInConnectionLock(); + SafeAssert(_state.CurrentConnectionStateUnsynchronized == null, "We already have a connection!"); + + cancellationToken.ThrowIfCancellationRequested(); + + CheckDisposed(); + + Log.Starting(_logger); + + // Start the connection + var connection = await _connectionFactory.ConnectAsync(_protocol.TransferFormat); + var startingConnectionState = new ConnectionState(connection, this); + + // From here on, if an error occurs we need to shut down the connection because + // we still own it. try { - if (_connectionState != null) - { - // We're already connected - return; - } - - cancellationToken.ThrowIfCancellationRequested(); - - CheckDisposed(); - - Log.Starting(_logger); - - // Start the connection - var connection = await _connectionFactory.ConnectAsync(_protocol.TransferFormat); - _connectionId = connection.ConnectionId; - var startingConnectionState = new ConnectionState(connection, this); - _hasInherentKeepAlive = connection.Features.Get()?.HasInherentKeepAlive ?? false; - - // From here on, if an error occurs we need to shut down the connection because - // we still own it. 
- try - { - Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); - await HandshakeAsync(startingConnectionState, cancellationToken); - } - catch (Exception ex) - { - Log.ErrorStartingConnection(_logger, ex); - - // Can't have any invocations to cancel, we're in the lock. - await CloseAsync(startingConnectionState.Connection); - throw; - } - - // Set this at the end to avoid setting internal state until the connection is real - _connectionState = startingConnectionState; - _connectionState.ReceiveTask = ReceiveLoop(_connectionState); - - Log.Started(_logger); + Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); + await HandshakeAsync(startingConnectionState, cancellationToken); } - finally + catch (Exception ex) { - ReleaseConnectionLock(); + Log.ErrorStartingConnection(_logger, ex); + + // Can't have any invocations to cancel, we're in the lock. + await CloseAsync(startingConnectionState.Connection); + throw; } + + // Set this at the end to avoid setting internal state until the connection is real + _state.CurrentConnectionStateUnsynchronized = startingConnectionState; + startingConnectionState.ReceiveTask = ReceiveLoop(startingConnectionState); + + Log.Started(_logger); } private Task CloseAsync(ConnectionContext connection) @@ -397,9 +458,31 @@ namespace Microsoft.AspNetCore.SignalR.Client // if we're disposing. private async Task StopAsyncCore(bool disposing) { - // Block a Start from happening until we've finished capturing the connection state. + // StartAsync acquires the connection lock for the duration of the handshake. + // ReconnectAsync also acquires the connection lock for reconnect attempts and handshakes. + // Cancel the StopCts without acquiring the lock so we can short-circuit it. + _state.StopCts.Cancel(); + + // Potentially wait for StartAsync to finish, and block a new StartAsync from + // starting until we've finished stopping. 
+ await _state.WaitConnectionLockAsync(); + + // Ensure that ReconnectingState.ReconnectTask is not accessed outside of the lock. + var reconnectTask = _state.ReconnectTask; + + if (reconnectTask.Status != TaskStatus.RanToCompletion) + { + // Let the current reconnect attempts finish if necessary without the lock. + // Otherwise, ReconnectAsync will stall forever acquiring the lock. + // It should never throw, even if the reconnect attempts fail. + // The StopCts should prevent the HubConnection from restarting until it is reset. + _state.ReleaseConnectionLock(); + await reconnectTask; + await _state.WaitConnectionLockAsync(); + } + ConnectionState connectionState; - await WaitConnectionLockAsync(); + try { if (disposing && _disposed) @@ -409,7 +492,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } CheckDisposed(); - connectionState = _connectionState; + connectionState = _state.CurrentConnectionStateUnsynchronized; // Set the stopping flag so that any invocations after this get a useful error message instead of // silently failing or throwing an error about the pipe being completed. @@ -423,12 +506,10 @@ namespace Microsoft.AspNetCore.SignalR.Client (_serviceProvider as IDisposable)?.Dispose(); _disposed = true; } - - _connectionId = null; } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); } // Now stop the connection we captured @@ -439,7 +520,6 @@ namespace Microsoft.AspNetCore.SignalR.Client } #if NETCOREAPP3_0 - /// /// Invokes a streaming hub method on the server using the specified method name, return type and arguments. /// @@ -483,15 +563,15 @@ namespace Microsoft.AspNetCore.SignalR.Client async Task OnStreamCanceled(InvocationRequest irq) { // We need to take the connection lock in order to ensure we a) have a connection and b) are the only one accessing the write end of the pipe. 
- await WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(); try { - if (_connectionState != null) + if (_state.CurrentConnectionStateUnsynchronized != null) { Log.SendingCancellation(_logger, irq.InvocationId); // Fire and forget, if it fails that means we aren't connected anymore. - _ = SendHubMessage(new CancelInvocationMessage(irq.InvocationId), irq.CancellationToken); + _ = SendHubMessage(_state.CurrentConnectionStateUnsynchronized, new CancelInvocationMessage(irq.InvocationId), irq.CancellationToken); } else { @@ -500,44 +580,46 @@ namespace Microsoft.AspNetCore.SignalR.Client } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); } // Cancel the invocation irq.Dispose(); } - var readers = PackageStreamingParams(ref args, out var streamIds); + var readers = default(Dictionary); CheckDisposed(); - await WaitConnectionLockAsync(); + var connectionState = await _state.WaitForActiveConnectionAsync(nameof(StreamAsChannelCoreAsync)); ChannelReader channel; try { CheckDisposed(); - CheckConnectionActive(nameof(StreamAsChannelCoreAsync)); cancellationToken.ThrowIfCancellationRequested(); + readers = PackageStreamingParams(connectionState, ref args, out var streamIds); + // I just want an excuse to use 'irq' as a variable name... 
- var irq = InvocationRequest.Stream(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out channel); - await InvokeStreamCore(methodName, irq, args, streamIds?.ToArray(), cancellationToken); + var irq = InvocationRequest.Stream(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, out channel); + await InvokeStreamCore(connectionState, methodName, irq, args, streamIds?.ToArray(), cancellationToken); if (cancellationToken.CanBeCanceled) { cancellationToken.Register(state => _ = OnStreamCanceled((InvocationRequest)state), irq); } + + LaunchStreams(connectionState, readers, cancellationToken); } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); } - LaunchStreams(readers, cancellationToken); return channel; } - private Dictionary PackageStreamingParams(ref object[] args, out List streamIds) + private Dictionary PackageStreamingParams(ConnectionState connectionState, ref object[] args, out List streamIds) { Dictionary readers = null; streamIds = null; @@ -552,7 +634,7 @@ namespace Microsoft.AspNetCore.SignalR.Client readers = new Dictionary(); } - var id = _connectionState.GetNextId(); + var id = connectionState.GetNextId(); readers[id] = args[i]; if (streamIds == null) @@ -574,13 +656,14 @@ namespace Microsoft.AspNetCore.SignalR.Client return readers; } - private void LaunchStreams(Dictionary readers, CancellationToken cancellationToken) + private void LaunchStreams(ConnectionState connectionState, Dictionary readers, CancellationToken cancellationToken) { if (readers == null) { // if there were no streaming parameters then readers is never initialized return; } + foreach (var kvp in readers) { var reader = kvp.Value; @@ -593,18 +676,18 @@ namespace Microsoft.AspNetCore.SignalR.Client { _ = _sendIAsyncStreamItemsMethod .MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1").GetGenericArguments()) - .Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken 
}); + .Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken }); continue; } #endif _ = _sendStreamItemsMethod .MakeGenericMethod(reader.GetType().GetGenericArguments()) - .Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken }); + .Invoke(this, new object[] { connectionState, kvp.Key.ToString(), reader, cancellationToken }); } } // this is called via reflection using the `_sendStreamItems` field - private Task SendStreamItems(string streamId, ChannelReader reader, CancellationToken token) + private Task SendStreamItems(ConnectionState connectionState, string streamId, ChannelReader reader, CancellationToken token) { async Task ReadChannelStream(CancellationTokenSource tokenSource) { @@ -612,18 +695,18 @@ namespace Microsoft.AspNetCore.SignalR.Client { while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item)) { - await SendWithLock(new StreamItemMessage(streamId, item)); + await SendWithLock(connectionState, new StreamItemMessage(streamId, item)); Log.SendingStreamItem(_logger, streamId); } } } - return CommonStreaming(streamId, token, ReadChannelStream); + return CommonStreaming(connectionState, streamId, token, ReadChannelStream); } #if NETCOREAPP3_0 // this is called via reflection using the `_sendIAsyncStreamItemsMethod` field - private Task SendIAsyncEnumerableStreamItems(string streamId, IAsyncEnumerable stream, CancellationToken token) + private Task SendIAsyncEnumerableStreamItems(ConnectionState connectionState, string streamId, IAsyncEnumerable stream, CancellationToken token) { async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource) { @@ -631,18 +714,20 @@ namespace Microsoft.AspNetCore.SignalR.Client await foreach (var streamValue in streamValues) { - await SendWithLock(new StreamItemMessage(streamId, streamValue)); + await SendWithLock(connectionState, new StreamItemMessage(streamId, streamValue)); Log.SendingStreamItem(_logger, streamId); } } - return 
CommonStreaming(streamId, token, ReadAsyncEnumerableStream); + return CommonStreaming(connectionState, streamId, token, ReadAsyncEnumerableStream); } #endif - private async Task CommonStreaming(string streamId, CancellationToken token, Func createAndConsumeStream) + private async Task CommonStreaming(ConnectionState connectionState, string streamId, CancellationToken token, Func createAndConsumeStream) { - var cts = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token); + // It's safe to access connectionState.UploadStreamToken as we still have the connection lock + _state.AssertInConnectionLock(); + var cts = CancellationTokenSource.CreateLinkedTokenSource(connectionState.UploadStreamToken, token); Log.StartingStream(_logger, streamId); string responseError = null; @@ -657,37 +742,39 @@ namespace Microsoft.AspNetCore.SignalR.Client } Log.CompletingStream(_logger, streamId); - await SendWithLock(CompletionMessage.WithError(streamId, responseError)); + + await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cts.Token); } private async Task InvokeCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { - var readers = PackageStreamingParams(ref args, out var streamIds); + var readers = default(Dictionary); CheckDisposed(); - await WaitConnectionLockAsync(); + var connectionState = await _state.WaitForActiveConnectionAsync(nameof(InvokeCoreAsync)); Task invocationTask; try { CheckDisposed(); - CheckConnectionActive(nameof(InvokeCoreAsync)); - var irq = InvocationRequest.Invoke(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out invocationTask); - await InvokeCore(methodName, irq, args, streamIds?.ToArray(), cancellationToken); + readers = PackageStreamingParams(connectionState, ref args, out var streamIds); + + var irq = InvocationRequest.Invoke(cancellationToken, returnType, connectionState.GetNextId(), _loggerFactory, this, out 
invocationTask); + await InvokeCore(connectionState, methodName, irq, args, streamIds?.ToArray(), cancellationToken); + + LaunchStreams(connectionState, readers, cancellationToken); } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); } - LaunchStreams(readers, cancellationToken); - // Wait for this outside the lock, because it won't complete until the server responds return await invocationTask; } - private async Task InvokeCore(string methodName, InvocationRequest irq, object[] args, string[] streams, CancellationToken cancellationToken) + private async Task InvokeCore(ConnectionState connectionState, string methodName, InvocationRequest irq, object[] args, string[] streams, CancellationToken cancellationToken) { Log.PreparingBlockingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length); @@ -695,26 +782,26 @@ namespace Microsoft.AspNetCore.SignalR.Client var invocationMessage = new InvocationMessage(irq.InvocationId, methodName, args, streams); Log.RegisteringInvocation(_logger, invocationMessage.InvocationId); - _connectionState.AddInvocation(irq); + connectionState.AddInvocation(irq); // Trace the full invocation Log.IssuingInvocation(_logger, invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); try { - await SendHubMessage(invocationMessage, cancellationToken); + await SendHubMessage(connectionState, invocationMessage, cancellationToken); } catch (Exception ex) { Log.FailedToSendInvocation(_logger, invocationMessage.InvocationId, ex); - _connectionState.TryRemoveInvocation(invocationMessage.InvocationId, out _); + connectionState.TryRemoveInvocation(invocationMessage.InvocationId, out _); irq.Fail(ex); } } - private async Task InvokeStreamCore(string methodName, InvocationRequest irq, object[] args, string[] streams, CancellationToken cancellationToken) + private async Task InvokeStreamCore(ConnectionState connectionState, string methodName, InvocationRequest irq, object[] args, 
string[] streams, CancellationToken cancellationToken) { - AssertConnectionValid(); + _state.AssertConnectionValid(); Log.PreparingStreamingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length); @@ -722,70 +809,84 @@ namespace Microsoft.AspNetCore.SignalR.Client Log.RegisteringInvocation(_logger, invocationMessage.InvocationId); - _connectionState.AddInvocation(irq); + connectionState.AddInvocation(irq); // Trace the full invocation Log.IssuingInvocation(_logger, invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args); try { - await SendHubMessage(invocationMessage, cancellationToken); + await SendHubMessage(connectionState, invocationMessage, cancellationToken); } catch (Exception ex) { Log.FailedToSendInvocation(_logger, invocationMessage.InvocationId, ex); - _connectionState.TryRemoveInvocation(invocationMessage.InvocationId, out _); + connectionState.TryRemoveInvocation(invocationMessage.InvocationId, out _); irq.Fail(ex); } } - private async Task SendHubMessage(HubMessage hubMessage, CancellationToken cancellationToken = default) + private async Task SendHubMessage(ConnectionState connectionState, HubMessage hubMessage, CancellationToken cancellationToken = default) { - AssertConnectionValid(); - - _protocol.WriteMessage(hubMessage, _connectionState.Connection.Transport.Output); + _state.AssertConnectionValid(); + _protocol.WriteMessage(hubMessage, connectionState.Connection.Transport.Output); Log.SendingMessage(_logger, hubMessage); // REVIEW: If a token is passed in and is canceled during FlushAsync it seems to break .Complete()... 
- await _connectionState.Connection.Transport.Output.FlushAsync(); + await connectionState.Connection.Transport.Output.FlushAsync(); Log.MessageSent(_logger, hubMessage); // We've sent a message, so don't ping for a while - ResetSendPing(); + connectionState.ResetSendPing(); } private async Task SendCoreAsyncCore(string methodName, object[] args, CancellationToken cancellationToken) { - var readers = PackageStreamingParams(ref args, out var streamIds); + var readers = default(Dictionary); - Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length); - var invocationMessage = new InvocationMessage(null, methodName, args, streamIds?.ToArray()); - await SendWithLock(invocationMessage, callerName: nameof(SendCoreAsync)); - - LaunchStreams(readers, cancellationToken); - } - - private async Task SendWithLock(HubMessage message, CancellationToken cancellationToken = default, [CallerMemberName] string callerName = "") - { CheckDisposed(); - await WaitConnectionLockAsync(); + var connectionState = await _state.WaitForActiveConnectionAsync(nameof(SendCoreAsync)); try { - CheckConnectionActive(callerName); CheckDisposed(); - await SendHubMessage(message, cancellationToken); + + readers = PackageStreamingParams(connectionState, ref args, out var streamIds); + + Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length); + var invocationMessage = new InvocationMessage(null, methodName, args, streamIds?.ToArray()); + await SendHubMessage(connectionState, invocationMessage, cancellationToken); + + LaunchStreams(connectionState, readers, cancellationToken); } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); + } + } + + private async Task SendWithLock(ConnectionState expectedConnectionState, HubMessage message, CancellationToken cancellationToken = default, [CallerMemberName] string callerName = "") + { + CheckDisposed(); + var connectionState = await _state.WaitForActiveConnectionAsync(callerName); + try + { + CheckDisposed(); + + 
SafeAssert(ReferenceEquals(expectedConnectionState, connectionState), "The connection state changed unexpectedly!"); + + await SendHubMessage(connectionState, message, cancellationToken); + } + finally + { + _state.ReleaseConnectionLock(); } } private async Task<(bool close, Exception exception)> ProcessMessagesAsync(HubMessage message, ConnectionState connectionState, ChannelWriter invocationMessageWriter) { Log.ResettingKeepAliveTimer(_logger); - ResetTimeout(); + connectionState.ResetTimeout(); InvocationRequest irq; switch (message) @@ -922,7 +1023,7 @@ namespace Microsoft.AspNetCore.SignalR.Client try { using (var handshakeCts = new CancellationTokenSource(HandshakeTimeout)) - using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, handshakeCts.Token)) + using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, handshakeCts.Token, _state.StopCts.Token)) { while (true) { @@ -987,16 +1088,17 @@ namespace Microsoft.AspNetCore.SignalR.Client { // We hold a local capture of the connection state because StopAsync may dump out the current one. // We'll be locking any time we want to check back in to the "active" connection state. 
+ _state.AssertInConnectionLock(); Log.ReceiveLoopStarting(_logger); // Performs periodic tasks -- here sending pings and checking timeout // Disposed with `timer.Stop()` in the finally block below var timer = new TimerAwaitable(TickRate, TickRate); - var timerTask = TimerLoop(timer); + var timerTask = connectionState.TimerLoop(timer); var uploadStreamSource = new CancellationTokenSource(); - _uploadStreamToken = uploadStreamSource.Token; + connectionState.UploadStreamToken = uploadStreamSource.Token; var invocationMessageChannel = Channel.CreateUnbounded(_receiveLoopOptions); var invocationMessageReceiveTask = StartProcessingInvocationMessages(invocationMessageChannel.Reader); @@ -1043,6 +1145,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { // Closing because we got a close frame, possibly with an error in it. connectionState.CloseException = exception; + connectionState.Stopping = true; break; } } @@ -1084,144 +1187,298 @@ namespace Microsoft.AspNetCore.SignalR.Client timer.Stop(); await timerTask; uploadStreamSource.Cancel(); + await HandleConnectionClose(connectionState); } + } + // Internal for testing + internal Task RunTimerActions() + { + // Don't bother acquiring the connection lock. This is only called from tests. + return _state.CurrentConnectionStateUnsynchronized.RunTimerActions(); + } + + // Internal for testing + internal void OnServerTimeout() + { + // Don't bother acquiring the connection lock. This is only called from tests. 
+ _state.CurrentConnectionStateUnsynchronized.OnServerTimeout(); + } + + private async Task HandleConnectionClose(ConnectionState connectionState) + { // Clear the connectionState field - await WaitConnectionLockAsync(); + await _state.WaitConnectionLockAsync(); try { - SafeAssert(ReferenceEquals(_connectionState, connectionState), + SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, connectionState), "Someone other than ReceiveLoop cleared the connection state!"); - _connectionState = null; + _state.CurrentConnectionStateUnsynchronized = null; + + // Dispose the connection + await CloseAsync(connectionState.Connection); + + // Cancel any outstanding invocations within the connection lock + connectionState.CancelOutstandingInvocations(connectionState.CloseException); + + if (connectionState.Stopping || _reconnectPolicy == null) + { + if (connectionState.CloseException != null) + { + Log.ShutdownWithError(_logger, connectionState.CloseException); + } + else + { + Log.ShutdownConnection(_logger); + } + + _state.ChangeState(HubConnectionState.Connected, HubConnectionState.Disconnected); + CompleteClose(connectionState.CloseException); + } + else + { + _state.ReconnectTask = ReconnectAsync(connectionState.CloseException); + } } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); } + } - // Dispose the connection - await CloseAsync(connectionState.Connection); - - // Cancel any outstanding invocations within the connection lock - connectionState.CancelOutstandingInvocations(connectionState.CloseException); - - if (connectionState.CloseException != null) - { - Log.ShutdownWithError(_logger, connectionState.CloseException); - } - else - { - Log.ShutdownConnection(_logger); - } + private void CompleteClose(Exception closeException) + { + _state.AssertInConnectionLock(); + _state.StopCts = new CancellationTokenSource(); + RunCloseEvent(closeException); + } + private void RunCloseEvent(Exception closeException) + { var closed = Closed; 
+ async Task RunClosedEventAsync() + { + // Dispatch to the thread pool before we invoke the user callback + await AwaitableThreadPool.Yield(); + + try + { + Log.InvokingClosedEventHandler(_logger); + await closed.Invoke(closeException); + } + catch (Exception ex) + { + Log.ErrorDuringClosedEvent(_logger, ex); + } + } + // There is no need to start a new task if there is no Closed event registered if (closed != null) { // Fire-and-forget the closed event - _ = RunClosedEvent(closed, connectionState.CloseException); + _ = RunClosedEventAsync(); } } - private void ResetSendPing() + private async Task ReconnectAsync(Exception closeException) { - Volatile.Write(ref _nextActivationSendPing, (DateTime.UtcNow + KeepAliveInterval).Ticks); - } + var previousReconnectAttempts = 0; + var reconnectStartTime = DateTime.UtcNow; + var retryReason = closeException; + var nextRetryDelay = GetNextRetryDelay(previousReconnectAttempts++, TimeSpan.Zero, retryReason); + + // We still have the connection lock from the caller, HandleConnectionClose. 
+ _state.AssertInConnectionLock(); - private void ResetTimeout() - { - Volatile.Write(ref _nextActivationServerTimeout, (DateTime.UtcNow + ServerTimeout).Ticks); - } - - private async Task TimerLoop(TimerAwaitable timer) - { - // Tell the server we intend to ping - // Old clients never ping, and shouldn't be timed out - // So ping to tell the server that we should be timed out if we stop - await SendHubMessage(PingMessage.Instance); - - // initialize the timers - timer.Start(); - ResetTimeout(); - ResetSendPing(); - - using (timer) + if (nextRetryDelay == null) { - // await returns True until `timer.Stop()` is called in the `finally` block of `ReceiveLoop` - while (await timer) - { - await RunTimerActions(); - } - } - } + Log.FirstReconnectRetryDelayNull(_logger); - // Internal for testing - internal async Task RunTimerActions() - { - if (!_hasInherentKeepAlive && DateTime.UtcNow.Ticks > Volatile.Read(ref _nextActivationServerTimeout)) - { - OnServerTimeout(); - } + _state.ChangeState(HubConnectionState.Connected, HubConnectionState.Disconnected); - if (DateTime.UtcNow.Ticks > Volatile.Read(ref _nextActivationSendPing)) - { - await PingServer(); - } - } - - private void OnServerTimeout() - { - _connectionState.CloseException = new TimeoutException( - $"Server timeout ({ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server."); - _connectionState.Connection.Transport.Input.CancelPendingRead(); - } - - private async Task PingServer() - { - if (_disposed || !_connectionLock.Wait(0)) - { - Log.UnableToAcquireConnectionLockForPing(_logger); + CompleteClose(closeException); return; } - Log.AcquiredConnectionLockForPing(_logger); + _state.ChangeState(HubConnectionState.Connected, HubConnectionState.Reconnecting); - try + if (closeException != null) { - if (_disposed || _connectionState == null || _connectionState.Stopping) + Log.ReconnectingWithError(_logger, closeException); + } + else + { + Log.Reconnecting(_logger); + } + + 
RunReconnectingEvent(closeException); + + while (nextRetryDelay != null) + { + Log.AwaitingReconnectRetryDelay(_logger, previousReconnectAttempts, nextRetryDelay.Value); + + try { + await Task.Delay(nextRetryDelay.Value, _state.StopCts.Token); + } + catch (OperationCanceledException ex) + { + Log.ReconnectingStoppedDuringRetryDelay(_logger); + + await _state.WaitConnectionLockAsync(); + try + { + _state.ChangeState(HubConnectionState.Reconnecting, HubConnectionState.Disconnected); + + CompleteClose(GetOperationCanceledException("Connection stopped during reconnect delay. Done reconnecting.", ex, _state.StopCts.Token)); + } + finally + { + _state.ReleaseConnectionLock(); + } + return; } - await SendHubMessage(PingMessage.Instance); + + await _state.WaitConnectionLockAsync(); + try + { + SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), + "Someone other than Reconnect set the connection state!"); + + // HandshakeAsync already checks ReconnectingConnectionState.StopCts.Token. + await StartAsyncCore(CancellationToken.None); + + Log.Reconnected(_logger, previousReconnectAttempts, DateTime.UtcNow - reconnectStartTime); + + _state.ChangeState(HubConnectionState.Reconnecting, HubConnectionState.Connected); + + RunReconnectedEvent(); + return; + } + catch (Exception ex) + { + retryReason = ex; + + Log.ReconnectAttemptFailed(_logger, ex); + + if (_state.StopCts.IsCancellationRequested) + { + Log.ReconnectingStoppedDuringReconnectAttempt(_logger); + + _state.ChangeState(HubConnectionState.Reconnecting, HubConnectionState.Disconnected); + + CompleteClose(GetOperationCanceledException("Connection stopped during reconnect attempt. 
Done reconnecting.", ex, _state.StopCts.Token)); + return; + } + } + finally + { + _state.ReleaseConnectionLock(); + } + + nextRetryDelay = GetNextRetryDelay(previousReconnectAttempts++, DateTime.UtcNow - reconnectStartTime, retryReason); + } + + await _state.WaitConnectionLockAsync(); + try + { + SafeAssert(ReferenceEquals(_state.CurrentConnectionStateUnsynchronized, null), + "Someone other than Reconnect set the connection state!"); + + var elapsedTime = DateTime.UtcNow - reconnectStartTime; + Log.ReconnectAttemptsExhausted(_logger, previousReconnectAttempts, elapsedTime); + + _state.ChangeState(HubConnectionState.Reconnecting, HubConnectionState.Disconnected); + + var message = $"Reconnect retries have been exhausted after {previousReconnectAttempts} failed attempts and {elapsedTime} elapsed. Disconnecting."; + CompleteClose(new OperationCanceledException(message)); } finally { - ReleaseConnectionLock(); + _state.ReleaseConnectionLock(); } } - private async Task RunClosedEvent(Func closed, Exception closeException) + private TimeSpan? 
GetNextRetryDelay(long previousRetryCount, TimeSpan elapsedTime, Exception retryReason) { - // Dispatch to the thread pool before we invoke the user callback - await AwaitableThreadPool.Yield(); - try { - Log.InvokingClosedEventHandler(_logger); - await closed.Invoke(closeException); + return _reconnectPolicy.NextRetryDelay(new RetryContext + { + PreviousRetryCount = previousRetryCount, + ElapsedTime = elapsedTime, + RetryReason = retryReason, + }); } catch (Exception ex) { - Log.ErrorDuringClosedEvent(_logger, ex); + Log.ErrorDuringNextRetryDelay(_logger, ex); + return null; } } - private void CheckConnectionActive(string methodName) + private OperationCanceledException GetOperationCanceledException(string message, Exception innerException, CancellationToken cancellationToken) { - if (_connectionState == null || _connectionState.Stopping) +#if NETCOREAPP3_0 + return new OperationCanceledException(message, innerException, _state.StopCts.Token); +#else + return new OperationCanceledException(message, innerException); +#endif + } + + private void RunReconnectingEvent(Exception closeException) + { + var reconnecting = Reconnecting; + + async Task RunReconnectingEventAsync() { - throw new InvalidOperationException($"The '{methodName}' method cannot be called if the connection is not active"); + // Dispatch to the thread pool before we invoke the user callback + await AwaitableThreadPool.Yield(); + + try + { + await reconnecting.Invoke(closeException); + } + catch (Exception ex) + { + Log.ErrorDuringReconnectingEvent(_logger, ex); + } + } + + // There is no need to start a new task if there is no Reconnecting event registered + if (reconnecting != null) + { + // Fire-and-forget the reconnecting event + _ = RunReconnectingEventAsync(); + } + } + + private void RunReconnectedEvent() + { + var reconnected = Reconnected; + + async Task RunReconnectedEventAsync() + { + // Dispatch to the thread pool before we invoke the user callback + await AwaitableThreadPool.Yield(); + + try + 
{ + await reconnected.Invoke(ConnectionId); + } + catch (Exception ex) + { + Log.ErrorDuringReconnectedEvent(_logger, ex); + } + } + + // There is no need to start a new task if there is no Reconnected event registered + if (reconnected != null) + { + // Fire-and-forget the reconnected event + _ = RunReconnectedEventAsync(); } } @@ -1235,28 +1492,6 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - [Conditional("DEBUG")] - private void AssertInConnectionLock([CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) => SafeAssert(_connectionLock.CurrentCount == 0, "We're not in the Connection Lock!", memberName, fileName, lineNumber); - - [Conditional("DEBUG")] - private void AssertConnectionValid([CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) - { - AssertInConnectionLock(memberName, fileName, lineNumber); - SafeAssert(_connectionState != null, "We don't have a connection!", memberName, fileName, lineNumber); - } - - private Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) - { - Log.WaitingOnConnectionLock(_logger, memberName, filePath, lineNumber); - return _connectionLock.WaitAsync(); - } - - private void ReleaseConnectionLock([CallerMemberName] string memberName = null, - [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) - { - Log.ReleasingConnectionLock(_logger, memberName, filePath, lineNumber); - _connectionLock.Release(); - } private class Subscription : IDisposable { @@ -1345,21 +1580,27 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - //TODO: Refactor all transient state about the connection into the ConnectionState class. 
private class ConnectionState : IInvocationBinder { - private volatile bool _stopping; private readonly HubConnection _hubConnection; + private readonly ILogger _logger; - private TaskCompletionSource _stopTcs; private readonly object _lock = new object(); private readonly Dictionary _pendingCalls = new Dictionary(StringComparer.Ordinal); + private TaskCompletionSource _stopTcs; + + private volatile bool _stopping; + private int _nextInvocationId; - private int _nextStreamId; + + private long _nextActivationServerTimeout; + private long _nextActivationSendPing; + private bool _hasInherentKeepAlive; public ConnectionContext Connection { get; } public Task ReceiveTask { get; set; } public Exception CloseException { get; set; } + public CancellationToken UploadStreamToken { get; set; } public bool Stopping { @@ -1367,17 +1608,18 @@ namespace Microsoft.AspNetCore.SignalR.Client set => _stopping = value; } - public bool Stopped => _stopTcs?.Task.Status == TaskStatus.RanToCompletion; - public ConnectionState(ConnectionContext connection, HubConnection hubConnection) { + Connection = connection; + _hubConnection = hubConnection; _hubConnection._logScope.ConnectionId = connection.ConnectionId; - Connection = connection; + + _logger = _hubConnection._logger; + _hasInherentKeepAlive = connection.Features.Get()?.HasInherentKeepAlive ?? 
false; } - public string GetNextId() => Interlocked.Increment(ref _nextInvocationId).ToString(CultureInfo.InvariantCulture); - public string GetNextStreamId() => Interlocked.Increment(ref _nextStreamId).ToString(CultureInfo.InvariantCulture); + public string GetNextId() => (++_nextInvocationId).ToString(CultureInfo.InvariantCulture); public void AddInvocation(InvocationRequest irq) { @@ -1385,7 +1627,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { if (_pendingCalls.ContainsKey(irq.InvocationId)) { - Log.InvocationAlreadyInUse(_hubConnection._logger, irq.InvocationId); + Log.InvocationAlreadyInUse(_logger, irq.InvocationId); throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use."); } else @@ -1421,13 +1663,13 @@ namespace Microsoft.AspNetCore.SignalR.Client public void CancelOutstandingInvocations(Exception exception) { - Log.CancelingOutstandingInvocations(_hubConnection._logger); + Log.CancelingOutstandingInvocations(_logger); lock (_lock) { foreach (var outstandingCall in _pendingCalls.Values) { - Log.RemovingInvocation(_hubConnection._logger, outstandingCall.InvocationId); + Log.RemovingInvocation(_logger, outstandingCall.InvocationId); if (exception != null) { outstandingCall.Fail(exception); @@ -1458,27 +1700,103 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task StopAsyncCore() { - Log.Stopping(_hubConnection._logger); + Log.Stopping(_logger); // Complete our write pipe, which should cause everything to shut down - Log.TerminatingReceiveLoop(_hubConnection._logger); + Log.TerminatingReceiveLoop(_logger); Connection.Transport.Input.CancelPendingRead(); // Wait ServerTimeout for the server or transport to shut down. 
- Log.WaitingForReceiveLoopToTerminate(_hubConnection._logger); + Log.WaitingForReceiveLoopToTerminate(_logger); await ReceiveTask; - Log.Stopped(_hubConnection._logger); + Log.Stopped(_logger); _hubConnection._logScope.ConnectionId = null; _stopTcs.TrySetResult(null); } + public async Task TimerLoop(TimerAwaitable timer) + { + // Tell the server we intend to ping. + // Old clients never ping, and shouldn't be timed out, so ping to tell the server that we should be timed out if we stop. + // The TimerLoop is started from the ReceiveLoop with the connection lock still acquired. + _hubConnection._state.AssertInConnectionLock(); + await _hubConnection.SendHubMessage(this, PingMessage.Instance); + + // initialize the timers + timer.Start(); + ResetTimeout(); + ResetSendPing(); + + using (timer) + { + // await returns True until `timer.Stop()` is called in the `finally` block of `ReceiveLoop` + while (await timer) + { + await RunTimerActions(); + } + } + } + + public void ResetSendPing() + { + Volatile.Write(ref _nextActivationSendPing, (DateTime.UtcNow + _hubConnection.KeepAliveInterval).Ticks); + } + + public void ResetTimeout() + { + Volatile.Write(ref _nextActivationServerTimeout, (DateTime.UtcNow + _hubConnection.ServerTimeout).Ticks); + } + + // Internal for testing + internal async Task RunTimerActions() + { + if (!_hasInherentKeepAlive && DateTime.UtcNow.Ticks > Volatile.Read(ref _nextActivationServerTimeout)) + { + OnServerTimeout(); + } + + if (DateTime.UtcNow.Ticks > Volatile.Read(ref _nextActivationSendPing) && !Stopping) + { + if (!_hubConnection._state.TryAcquireConnectionLock()) + { + Log.UnableToAcquireConnectionLockForPing(_logger); + return; + } + + Log.AcquiredConnectionLockForPing(_logger); + + try + { + if (_hubConnection._state.CurrentConnectionStateUnsynchronized != null) + { + SafeAssert(ReferenceEquals(_hubConnection._state.CurrentConnectionStateUnsynchronized, this), + "Something reset the connection state before the timer loop completed!"); + 
+ await _hubConnection.SendHubMessage(this, PingMessage.Instance); + } + } + finally + { + _hubConnection._state.ReleaseConnectionLock(); + } + } + } + + // Internal for testing + internal void OnServerTimeout() + { + CloseException = new TimeoutException( + $"Server timeout ({_hubConnection.ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server."); + Connection.Transport.Input.CancelPendingRead(); + } + Type IInvocationBinder.GetReturnType(string invocationId) { if (!TryGetInvocation(invocationId, out var irq)) { - Log.ReceivedUnexpectedResponse(_hubConnection._logger, invocationId); + Log.ReceivedUnexpectedResponse(_logger, invocationId); return null; } return irq.ResultType; @@ -1490,7 +1808,7 @@ namespace Microsoft.AspNetCore.SignalR.Client // literally the same code as the above method if (!TryGetInvocation(invocationId, out var irq)) { - Log.ReceivedUnexpectedResponse(_hubConnection._logger, invocationId); + Log.ReceivedUnexpectedResponse(_logger, invocationId); return null; } return irq.ResultType; @@ -1500,7 +1818,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { if (!_hubConnection._handlers.TryGetValue(methodName, out var invocationHandlerList)) { - Log.MissingHandler(_hubConnection._logger, methodName); + Log.MissingHandler(_logger, methodName); return Type.EmptyTypes; } @@ -1513,5 +1831,92 @@ namespace Microsoft.AspNetCore.SignalR.Client throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'"); } } + + private class ReconnectingConnectionState + { + // This lock protects the connection state. 
+ private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1); + + private readonly ILogger _logger; + + public ReconnectingConnectionState(ILogger logger) + { + _logger = logger; + StopCts = new CancellationTokenSource(); + ReconnectTask = Task.CompletedTask; + } + + public ConnectionState CurrentConnectionStateUnsynchronized { get; set; } + + public HubConnectionState OverallState { get; private set; } + + public CancellationTokenSource StopCts { get; set; } = new CancellationTokenSource(); + + public Task ReconnectTask { get; set; } = Task.CompletedTask; + + public void ChangeState(HubConnectionState expectedState, HubConnectionState newState) + { + if (!TryChangeState(expectedState, newState)) + { + Log.StateTransitionFailed(_logger, expectedState, newState, OverallState); + throw new InvalidOperationException($"The HubConnection failed to transition from the '{expectedState}' state to the '{newState}' state because it was actually in the '{OverallState}' state."); + } + } + + public bool TryChangeState(HubConnectionState expectedState, HubConnectionState newState) + { + AssertInConnectionLock(); + + Log.AttemptingStateTransition(_logger, expectedState, newState); + + if (OverallState != expectedState) + { + return false; + } + + OverallState = newState; + return true; + } + + [Conditional("DEBUG")] + public void AssertInConnectionLock([CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) => SafeAssert(_connectionLock.CurrentCount == 0, "We're not in the Connection Lock!", memberName, fileName, lineNumber); + + [Conditional("DEBUG")] + public void AssertConnectionValid([CallerMemberName] string memberName = null, [CallerFilePath] string fileName = null, [CallerLineNumber] int lineNumber = 0) + { + AssertInConnectionLock(memberName, fileName, lineNumber); + SafeAssert(CurrentConnectionStateUnsynchronized != null, "We don't have a connection!", memberName, fileName, lineNumber); + 
} + + public Task WaitConnectionLockAsync([CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + { + Log.WaitingOnConnectionLock(_logger, memberName, filePath, lineNumber); + return _connectionLock.WaitAsync(); + } + + public bool TryAcquireConnectionLock() + { + return _connectionLock.Wait(0); + } + + public async Task WaitForActiveConnectionAsync(string methodName, [CallerMemberName] string memberName = null, [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + { + await WaitConnectionLockAsync(); + + if (CurrentConnectionStateUnsynchronized == null || CurrentConnectionStateUnsynchronized.Stopping) + { + throw new InvalidOperationException($"The '{methodName}' method cannot be called if the connection is not active"); + } + + return CurrentConnectionStateUnsynchronized; + } + + public void ReleaseConnectionLock([CallerMemberName] string memberName = null, + [CallerFilePath] string filePath = null, [CallerLineNumber] int lineNumber = 0) + { + Log.ReleasingConnectionLock(_logger, memberName, filePath, lineNumber); + _connectionLock.Release(); + } + } } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionBuilderExtensions.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionBuilderExtensions.cs index 132d236c15..d56f254171 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionBuilderExtensions.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionBuilderExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. 
using System; +using Microsoft.AspNetCore.SignalR.Client.Internal; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; @@ -23,5 +24,44 @@ namespace Microsoft.AspNetCore.SignalR.Client hubConnectionBuilder.Services.AddLogging(configureLogging); return hubConnectionBuilder; } + + /// + /// Configures the to automatically attempt to reconnect if the connection is lost. + /// The client will wait the default 0, 2, 10 and 30 seconds respectively before trying up to four reconnect attempts. + /// + /// The to configure. + /// The same instance of the for chaining. + public static IHubConnectionBuilder WithAutomaticReconnect(this IHubConnectionBuilder hubConnectionBuilder) + { + hubConnectionBuilder.Services.AddSingleton(new DefaultRetryPolicy()); + return hubConnectionBuilder; + } + + /// + /// Configures the to automatically attempt to reconnect if the connection is lost. + /// + /// The to configure. + /// + /// An array containing the delays before trying each reconnect attempt. + /// The length of the array represents how many failed reconnect attempts it takes before the client will stop attempting to reconnect. + /// + /// The same instance of the for chaining. + public static IHubConnectionBuilder WithAutomaticReconnect(this IHubConnectionBuilder hubConnectionBuilder, TimeSpan[] reconnectDelays) + { + hubConnectionBuilder.Services.AddSingleton(new DefaultRetryPolicy(reconnectDelays)); + return hubConnectionBuilder; + } + + /// + /// Configures the to automatically attempt to reconnect if the connection is lost. + /// + /// The to configure. + /// An that controls the timing and number of reconnect attempts. + /// The same instance of the for chaining. 
+ public static IHubConnectionBuilder WithAutomaticReconnect(this IHubConnectionBuilder hubConnectionBuilder, IRetryPolicy retryPolicy) + { + hubConnectionBuilder.Services.AddSingleton(retryPolicy); + return hubConnectionBuilder; + } } -} \ No newline at end of file +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionState.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionState.cs index 7a230bd70e..16d1d72e94 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnectionState.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnectionState.cs @@ -15,6 +15,14 @@ namespace Microsoft.AspNetCore.SignalR.Client /// /// The hub connection is connected. /// - Connected + Connected, + /// + /// The hub connection is connecting. + /// + Connecting, + /// + /// The hub connection is reconnecting. + /// + Reconnecting, } -} \ No newline at end of file +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/IRetryPolicy.cs b/src/SignalR/clients/csharp/Client.Core/src/IRetryPolicy.cs new file mode 100644 index 0000000000..54bea71128 --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/IRetryPolicy.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Client +{ + /// + /// An abstraction that controls when the client attempts to reconnect and how many times it does so. + /// + public interface IRetryPolicy + { + /// + /// If passed to , + /// this will be called after the transport loses a connection to determine if and for how long to wait before the next reconnect attempt. + /// + /// + /// Information related to the next possible reconnect attempt including the number of consecutive failed retries so far, time spent + /// reconnecting so far and the error that led to this reconnect attempt. 
+ /// + /// + /// A representing the amount of time to wait from now before starting the next reconnect attempt. + /// tells the client to stop retrying and close. + /// + TimeSpan? NextRetryDelay(RetryContext retryContext); + } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/Internal/DefaultRetryPolicy.cs b/src/SignalR/clients/csharp/Client.Core/src/Internal/DefaultRetryPolicy.cs new file mode 100644 index 0000000000..2c6573cc44 --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/Internal/DefaultRetryPolicy.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.SignalR.Client.Internal +{ + internal class DefaultRetryPolicy : IRetryPolicy + { + internal static TimeSpan?[] DEFAULT_RETRY_DELAYS_IN_MILLISECONDS = new TimeSpan?[] + { + TimeSpan.Zero, + TimeSpan.FromSeconds(2), + TimeSpan.FromSeconds(10), + TimeSpan.FromSeconds(30), + null, + }; + + private TimeSpan?[] _retryDelays; + + public DefaultRetryPolicy() + { + _retryDelays = DEFAULT_RETRY_DELAYS_IN_MILLISECONDS; + } + + public DefaultRetryPolicy(TimeSpan[] retryDelays) + { + _retryDelays = new TimeSpan?[retryDelays.Length + 1]; + + for (int i = 0; i < retryDelays.Length; i++) + { + _retryDelays[i] = retryDelays[i]; + } + } + + public TimeSpan? 
NextRetryDelay(RetryContext retryContext) + { + return _retryDelays[retryContext.PreviousRetryCount]; + } + } +} diff --git a/src/SignalR/clients/csharp/Client.Core/src/Properties/AssemblyInfo.cs b/src/SignalR/clients/csharp/Client.Core/src/Properties/AssemblyInfo.cs index 8bc7094d90..8376590f76 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/Properties/AssemblyInfo.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/Properties/AssemblyInfo.cs @@ -3,4 +3,5 @@ using System.Runtime.CompilerServices; +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Client.FunctionalTests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] [assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Client.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/SignalR/clients/csharp/Client.Core/src/RetryContext.cs b/src/SignalR/clients/csharp/Client.Core/src/RetryContext.cs new file mode 100644 index 0000000000..667fb79c6f --- /dev/null +++ b/src/SignalR/clients/csharp/Client.Core/src/RetryContext.cs @@ -0,0 +1,29 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. 
+ +using System; + +namespace Microsoft.AspNetCore.SignalR.Client +{ + /// + /// The context passed to to help the policy determine + /// how long to wait before the next retry and whether there should be another retry at all. + /// + public sealed class RetryContext + { + /// + /// The number of consecutive failed retries so far. + /// + public long PreviousRetryCount { get; set; } + + /// + /// The amount of time spent retrying so far. + /// + public TimeSpan ElapsedTime { get; set; } + + /// + /// The error precipitating the current retry if any. + /// + public Exception RetryReason { get; set; } + } +} diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 79cead3405..0ce0c0bbe8 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -36,11 +36,25 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests string path = null, HttpTransportType? transportType = null, IHubProtocol protocol = null, - ILoggerFactory loggerFactory = null) + ILoggerFactory loggerFactory = null, + bool withAutomaticReconnect = false) { var hubConnectionBuilder = new HubConnectionBuilder(); - hubConnectionBuilder.Services.AddSingleton(protocol); - hubConnectionBuilder.WithLoggerFactory(loggerFactory); + + if (protocol != null) + { + hubConnectionBuilder.Services.AddSingleton(protocol); + } + + if (loggerFactory != null) + { + hubConnectionBuilder.WithLoggerFactory(loggerFactory); + } + + if (withAutomaticReconnect) + { + hubConnectionBuilder.WithAutomaticReconnect(); + } var delegateConnectionFactory = new DelegateConnectionFactory( GetHttpConnectionFactory(url, loggerFactory, path, transportType ?? 
HttpTransportType.LongPolling | HttpTransportType.WebSockets | HttpTransportType.ServerSentEvents), @@ -1617,6 +1631,195 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task CanAutomaticallyReconnect(HttpTransportType transportType) + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + writeContext.EventId.Name == "ReconnectingWithError"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var connection = CreateHubConnection( + server.Url, + path: HubPaths.First(), + transportType: transportType, + loggerFactory: LoggerFactory, + withAutomaticReconnect: true); + + try + { + var echoMessage = "test"; + var reconnectingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + connection.Reconnecting += _ => + { + reconnectingTcs.SetResult(null); + return Task.CompletedTask; + }; + + connection.Reconnected += connectionId => + { + reconnectedTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + + await connection.StartAsync().OrTimeout(); + var initialConnectionId = connection.ConnectionId; + + connection.OnServerTimeout(); + + await reconnectingTcs.Task.OrTimeout(); + var newConnectionId = await reconnectedTcs.Task.OrTimeout(); + Assert.NotEqual(initialConnectionId, newConnectionId); + Assert.Equal(connection.ConnectionId, newConnectionId); + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), echoMessage).OrTimeout(); + Assert.Equal(echoMessage, result); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Fact] + public async Task CanAutomaticallyReconnectAfterRedirect() + { 
+ bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + writeContext.EventId.Name == "ReconnectingWithError"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var connection = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) + .WithUrl(server.Url + "/redirect") + .WithAutomaticReconnect() + .Build(); + + try + { + var echoMessage = "test"; + var reconnectingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + connection.Reconnecting += _ => + { + reconnectingTcs.SetResult(null); + return Task.CompletedTask; + }; + + connection.Reconnected += connectionId => + { + reconnectedTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + await connection.StartAsync().OrTimeout(); + var initialConnectionId = connection.ConnectionId; + + connection.OnServerTimeout(); + + await reconnectingTcs.Task.OrTimeout(); + var newConnectionId = await reconnectedTcs.Task.OrTimeout(); + Assert.NotEqual(initialConnectionId, newConnectionId); + Assert.Equal(connection.ConnectionId, newConnectionId); + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), echoMessage).OrTimeout(); + Assert.Equal(echoMessage, result); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + + [Fact] + public async Task CanAutomaticallyReconnectAfterSkippingNegotiation() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + writeContext.EventId.Name == "ReconnectingWithError"; + } + + using (StartServer(out var server, ExpectedErrors)) + { + var connectionBuilder = new HubConnectionBuilder() + .WithLoggerFactory(LoggerFactory) 
+ .WithUrl(server.Url + HubPaths.First(), HttpTransportType.WebSockets) + .WithAutomaticReconnect(); + + connectionBuilder.Services.Configure(o => + { + o.SkipNegotiation = true; + }); + + var connection = connectionBuilder.Build(); + + try + { + var echoMessage = "test"; + var reconnectingTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + connection.Reconnecting += _ => + { + reconnectingTcs.SetResult(null); + return Task.CompletedTask; + }; + + connection.Reconnected += connectionId => + { + reconnectedTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + await connection.StartAsync().OrTimeout(); + Assert.Null(connection.ConnectionId); + + connection.OnServerTimeout(); + + await reconnectingTcs.Task.OrTimeout(); + var newConnectionId = await reconnectedTcs.Task.OrTimeout(); + Assert.Null(newConnectionId); + Assert.Null(connection.ConnectionId); + + var result = await connection.InvokeAsync(nameof(TestHub.Echo), echoMessage).OrTimeout(); + Assert.Equal(echoMessage, result); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + private class PollTrackingMessageHandler : DelegatingHandler { public Task ActivePoll { get; private set; } diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs index 61e34fab5c..f60b124d59 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.ConnectionLifecycle.cs @@ -67,7 +67,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task 
StartAsyncWaitsForPreviousStartIfAlreadyStarting() + public async Task StartAsyncThrowsIfPreviousStartIsAlreadyStarting() { // Set up StartAsync to wait on the syncPoint when starting var testConnection = new TestConnection(onStart: SyncPoint.Create(out var syncPoint)); @@ -86,9 +86,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // Release the sync point syncPoint.Continue(); - // Both starts should finish fine + // The first start should finish fine, but the second throws an InvalidOperationException. await firstStart; - await secondStart; + await Assert.ThrowsAsync(() => secondStart); }); } @@ -147,16 +147,19 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var stopTask = connection.StopAsync().OrTimeout(); // Wait to hit DisposeAsync on TestConnection (which should be after StopAsync has cleared the connection state) - await syncPoint.WaitForSyncPoint(); + await syncPoint.WaitForSyncPoint().OrTimeout(); - // We should be able to start now, and StopAsync hasn't completed, nor will it complete while Starting + // We should not yet be able to start now because StopAsync hasn't completed Assert.False(stopTask.IsCompleted); - await connection.StartAsync().OrTimeout(); + var startTask = connection.StartAsync().OrTimeout(); Assert.False(stopTask.IsCompleted); // When we release the sync point, the StopAsync task will finish syncPoint.Continue(); await stopTask; + + // Which will then allow StartAsync to finish. + await startTask; }); } @@ -240,7 +243,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.False(startTask.IsCompleted); await syncPoint.WaitForSyncPoint(); - Assert.Equal(HubConnectionState.Disconnected, connection.State); + Assert.Equal(HubConnectionState.Connecting, connection.State); // Release the SyncPoint syncPoint.Continue(); @@ -442,6 +445,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // Stop and invoke the method. 
These two aren't synchronizable via a Sync Point any more because the transport is disposed // outside the lock :( var disposeTask = connection.StopAsync().OrTimeout(); + + // Wait to hit DisposeAsync on TestConnection (which should be after StopAsync has cleared the connection state) + await syncPoint.WaitForSyncPoint().OrTimeout(); + var targetTask = method(connection).OrTimeout(); // Release the sync point diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs index 76cf7959ca..6aa6106f40 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Helpers.cs @@ -7,7 +7,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HubConnectionTests { - private static HubConnection CreateHubConnection(TestConnection connection, IHubProtocol protocol = null, ILoggerFactory loggerFactory = null) + private static HubConnection CreateHubConnection(TestConnection connection, IHubProtocol protocol = null, ILoggerFactory loggerFactory = null, IRetryPolicy reconnectPolicy = null) { var builder = new HubConnectionBuilder(); @@ -27,7 +27,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests builder.Services.AddSingleton(protocol); } + if (reconnectPolicy != null) + { + builder.WithAutomaticReconnect(reconnectPolicy); + } + return builder.Build(); } } -} \ No newline at end of file +} diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs new file mode 100644 index 0000000000..8e6c684b7b --- /dev/null +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.Reconnect.cs @@ -0,0 +1,913 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. 
See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging.Testing; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public partial class HubConnectionTests + { + [Fact] + public async Task ReconnectIsNotEnabledByDefault() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ShutdownWithError" || + writeContext.EventId.Name == "ServerDisconnectedWithError"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var exception = new Exception(); + + var testConnection = new TestConnection(); + await using var hubConnection = CreateHubConnection(testConnection, loggerFactory: LoggerFactory); + + var reconnectingCalled = false; + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCalled = true; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + testConnection.CompleteFromTransport(exception); + + Assert.Same(exception, await closedErrorTcs.Task.OrTimeout()); + Assert.False(reconnectingCalled); + } + } + + [Fact] + public async Task ReconnectCanBeOptedInto() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError"); + } + + var failReconnectTcs = new 
TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = default(ReconnectingConnectionFactory); + var startCallCount = 0; + var originalConnectionId = "originalConnectionId"; + var reconnectedConnectionId = "reconnectedConnectionId"; + + async Task OnTestConnectionStart() + { + startCallCount++; + + // Only fail the first reconnect attempt. + if (startCallCount == 2) + { + await failReconnectTcs.Task; + } + + var testConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + + // Change the connection id before reconnecting. + if (startCallCount == 3) + { + testConnection.ConnectionId = reconnectedConnectionId; + } + else + { + testConnection.ConnectionId = originalConnectionId; + } + } + + testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + 
reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + Assert.Same(originalConnectionId, hubConnection.ConnectionId); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var reconnectException = new Exception(); + failReconnectTcs.SetException(reconnectException); + + Assert.Same(reconnectedConnectionId, await reconnectedConnectionIdTcs.Task.OrTimeout()); + + Assert.Equal(2, retryContexts.Count); + Assert.Same(reconnectException, retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[1].ElapsedTime); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + } + } + + [Fact] + public async Task ReconnectStopsIfTheReconnectPolicyReturnsNull() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError"); + } + + var failReconnectTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var startCallCount = 0; + + Task 
OnTestConnectionStart() + { + startCallCount++; + + // Fail the first reconnect attempts. + if (startCallCount > 1) + { + return failReconnectTcs.Task; + } + + return Task.CompletedTask; + } + + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return context.PreviousRetryCount == 0 ? TimeSpan.Zero : (TimeSpan?)null; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var reconnectException = new Exception(); + failReconnectTcs.SetException(reconnectException); + + var closeError = await 
closedErrorTcs.Task.OrTimeout(); + Assert.IsType(closeError); + + Assert.Equal(2, retryContexts.Count); + Assert.Same(reconnectException, retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[1].ElapsedTime); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(0, reconnectedCount); + } + } + + [Fact] + public async Task ReconnectCanHappenMultipleTimes() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error 
=> + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var secondException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(secondException); + + Assert.Same(secondException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Equal(2, retryContexts.Count); + Assert.Same(secondException, retryContexts[1].RetryReason); + Assert.Equal(0, retryContexts[1].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[1].ElapsedTime); + + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(2, reconnectingCount); + Assert.Equal(2, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(2, reconnectingCount); + Assert.Equal(2, reconnectedCount); + } + } + + [Fact] + public async Task ReconnectEventsNotFiredIfFirstRetryDelayIsNull() + { + bool ExpectedErrors(WriteContext writeContext) + { + 
return writeContext.LoggerName == typeof(HubConnection).FullName && + writeContext.EventId.Name == "ServerDisconnectedWithError"; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(); + builder.Services.AddSingleton(testConnectionFactory); + + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(null); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + await closedErrorTcs.Task.OrTimeout(); + + Assert.Equal(0, reconnectingCount); + Assert.Equal(0, reconnectedCount); + } + } + + [Fact] + public async Task ReconnectDoesNotStartIfConnectionIsLostDuringInitialHandshake() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new 
TestConnection(autoHandshake: false)); + builder.Services.AddSingleton(testConnectionFactory); + + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(null); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var closedCount = 0; + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedCount++; + return Task.CompletedTask; + }; + + var startTask = hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await Assert.ThrowsAsync(() => startTask).OrTimeout()); + Assert.Equal(HubConnectionState.Disconnected, hubConnection.State); + Assert.Equal(0, reconnectingCount); + Assert.Equal(0, reconnectedCount); + Assert.Equal(0, closedCount); + } + } + + [Fact] + public async Task ReconnectContinuesIfConnectionLostDuringReconnectHandshake() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); 
+ var secondRetryDelayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + + if (retryContexts.Count == 2) + { + secondRetryDelayTcs.SetResult(null); + } + + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + var startTask = hubConnection.StartAsync(); + + // Complete handshake + var currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + + await startTask.OrTimeout(); + + var firstException = new Exception(); + currentTestConnection.CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var secondException = new Exception(); 
+ (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(secondException); + + await secondRetryDelayTcs.Task.OrTimeout(); + + Assert.Equal(2, retryContexts.Count); + Assert.Same(secondException, retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[0].ElapsedTime); + + // Complete handshake + currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + } + } + + [Fact] + public async Task ReconnectContinuesIfInvalidHandshakeResponse() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "HandshakeServerError" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(autoHandshake: false)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var secondRetryDelayTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var mockReconnectPolicy = new Mock(); + 
mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + + if (retryContexts.Count == 2) + { + secondRetryDelayTcs.SetResult(null); + } + + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var reconnectedConnectionIdTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + reconnectedConnectionIdTcs.SetResult(connectionId); + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + var startTask = hubConnection.StartAsync(); + + // Complete handshake + var currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + + await startTask.OrTimeout(); + + var firstException = new Exception(); + currentTestConnection.CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + // Respond to handshake with error. 
+ currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadSentTextMessageAsync().OrTimeout(); + + var output = MemoryBufferWriter.Get(); + try + { + HandshakeProtocol.WriteResponseMessage(new HandshakeResponseMessage("Error!"), output); + await currentTestConnection.Application.Output.WriteAsync(output.ToArray()).OrTimeout(); + } + finally + { + MemoryBufferWriter.Return(output); + } + + await secondRetryDelayTcs.Task.OrTimeout(); + + Assert.Equal(2, retryContexts.Count); + Assert.IsType(retryContexts[1].RetryReason); + Assert.Equal(1, retryContexts[1].PreviousRetryCount); + Assert.True(TimeSpan.Zero <= retryContexts[0].ElapsedTime); + + // Complete handshake + + currentTestConnection = await testConnectionFactory.GetNextOrCurrentTestConnection(); + await currentTestConnection.ReadHandshakeAndSendResponseAsync().OrTimeout(); + await reconnectedConnectionIdTcs.Task.OrTimeout(); + + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + Assert.Equal(TaskStatus.WaitingForActivation, closedErrorTcs.Task.Status); + + await hubConnection.StopAsync().OrTimeout(); + + var closeError = await closedErrorTcs.Task.OrTimeout(); + Assert.Null(closeError); + Assert.Equal(1, reconnectingCount); + Assert.Equal(1, reconnectedCount); + } + } + + [Fact] + public async Task ReconnectCanBeStoppedWhileRestartingUnderlyingConnection() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var connectionStartTcs = new 
TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + async Task OnTestConnectionStart() + { + try + { + await connectionStartTcs.Task; + } + finally + { + connectionStartTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + } + } + + var testConnectionFactory = new ReconnectingConnectionFactory(() => new TestConnection(OnTestConnectionStart)); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + return TimeSpan.Zero; + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + // Allow the first connection to start successfully. 
+ connectionStartTcs.SetResult(null); + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + var secondException = new Exception(); + var stopTask = hubConnection.StopAsync(); + connectionStartTcs.SetResult(null); + + Assert.IsType(await closedErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Equal(1, reconnectingCount); + Assert.Equal(0, reconnectedCount); + await stopTask.OrTimeout(); + } + } + + [Fact] + public async Task ReconnectCanBeStoppedDuringRetryDelay() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HubConnection).FullName && + (writeContext.EventId.Name == "ServerDisconnectedWithError" || + writeContext.EventId.Name == "ReconnectingWithError" || + writeContext.EventId.Name == "ErrorReceivingHandshakeResponse" || + writeContext.EventId.Name == "ErrorStartingConnection"); + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var builder = new HubConnectionBuilder().WithLoggerFactory(LoggerFactory); + var testConnectionFactory = new ReconnectingConnectionFactory(); + builder.Services.AddSingleton(testConnectionFactory); + + var retryContexts = new List(); + var mockReconnectPolicy = new Mock(); + mockReconnectPolicy.Setup(p => p.NextRetryDelay(It.IsAny())).Returns(context => + { + retryContexts.Add(context); + // Hopefully this test never takes over a minute. 
+ return TimeSpan.FromMinutes(1); + }); + builder.WithAutomaticReconnect(mockReconnectPolicy.Object); + + await using var hubConnection = builder.Build(); + var reconnectingCount = 0; + var reconnectedCount = 0; + var reconnectingErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var closedErrorTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + hubConnection.Reconnecting += error => + { + reconnectingCount++; + reconnectingErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + hubConnection.Reconnected += connectionId => + { + reconnectedCount++; + return Task.CompletedTask; + }; + + hubConnection.Closed += error => + { + closedErrorTcs.SetResult(error); + return Task.CompletedTask; + }; + + // Allow the first connection to start successfully. + await hubConnection.StartAsync().OrTimeout(); + + var firstException = new Exception(); + (await testConnectionFactory.GetNextOrCurrentTestConnection()).CompleteFromTransport(firstException); + + Assert.Same(firstException, await reconnectingErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Same(firstException, retryContexts[0].RetryReason); + Assert.Equal(0, retryContexts[0].PreviousRetryCount); + Assert.Equal(TimeSpan.Zero, retryContexts[0].ElapsedTime); + + await hubConnection.StopAsync().OrTimeout(); + + Assert.IsType(await closedErrorTcs.Task.OrTimeout()); + Assert.Single(retryContexts); + Assert.Equal(1, reconnectingCount); + Assert.Equal(0, reconnectedCount); + } + } + + private class ReconnectingConnectionFactory : IConnectionFactory + { + public readonly Func _testConnectionFactory; + public TaskCompletionSource _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + public ReconnectingConnectionFactory() + : this (() => new TestConnection()) + { + } + + public ReconnectingConnectionFactory(Func testConnectionFactory) + { + _testConnectionFactory = 
testConnectionFactory; + } + + public Task GetNextOrCurrentTestConnection() + { + return _testConnectionTcs.Task; + } + + public async Task ConnectAsync(TransferFormat transferFormat, CancellationToken cancellationToken = default) + { + var testConnection = _testConnectionFactory(); + + _testConnectionTcs.SetResult(testConnection); + + try + { + return await testConnection.StartAsync(transferFormat); + } + catch + { + _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + throw; + } + } + + public async Task DisposeAsync(ConnectionContext connection) + { + var disposingTestConnection = await _testConnectionTcs.Task; + + _testConnectionTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + await disposingTestConnection.DisposeAsync(); + } + } + } +} diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs index 6c6c426e5f..40cca161c0 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs @@ -339,7 +339,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // after cancellation, don't send from the pipe foreach (var number in new[] { 42, 43, 322, 3145, -1234 }) { - await channel.Writer.WriteAsync(number); } diff --git a/src/SignalR/clients/ts/signalr/src/HubConnectionBuilder.ts b/src/SignalR/clients/ts/signalr/src/HubConnectionBuilder.ts index ef51ea7ebd..c9ae4e15d9 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnectionBuilder.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnectionBuilder.ts @@ -164,7 +164,7 @@ export class HubConnectionBuilder { /** Configures the {@link @aspnet/signalr.HubConnection} to automatically attempt to reconnect if the connection is lost. 
* - * @param {number[]} reconnectPolicy An {@link @aspnet/signalR.IReconnectPolicy} that controls the timing and number of reconnect attempts. + * @param {IReconnectPolicy} reconnectPolicy An {@link @aspnet/signalR.IReconnectPolicy} that controls the timing and number of reconnect attempts. */ public withAutomaticReconnect(reconnectPolicy: IReconnectPolicy): HubConnectionBuilder; public withAutomaticReconnect(retryDelaysOrReconnectPolicy?: number[] | IReconnectPolicy): HubConnectionBuilder { diff --git a/src/SignalR/samples/ClientSample/HubSample.cs b/src/SignalR/samples/ClientSample/HubSample.cs index 198a316585..f37f6d1991 100644 --- a/src/SignalR/samples/ClientSample/HubSample.cs +++ b/src/SignalR/samples/ClientSample/HubSample.cs @@ -1,14 +1,13 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.IO; -using System.Linq; -using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Client; using Microsoft.Extensions.CommandLineUtils; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace ClientSample @@ -37,6 +36,11 @@ namespace ClientSample logging.AddConsole(); }); + connectionBuilder.Services.Configure(options => + { + options.MinLevel = LogLevel.Trace; + }); + if (uri.Scheme == "net.tcp") { connectionBuilder.WithEndPoint(uri); @@ -46,6 +50,8 @@ namespace ClientSample connectionBuilder.WithUrl(uri); } + connectionBuilder.WithAutomaticReconnect(); + var connection = connectionBuilder.Build(); Console.CancelKeyPress += (sender, a) => @@ -68,7 +74,8 @@ namespace ClientSample return Task.CompletedTask; }; - while (true) + + do { // Dispose the previous token closedTokenSource?.Dispose(); @@ -77,48 +84,34 @@ namespace ClientSample closedTokenSource = new CancellationTokenSource(); 
// Connect to the server - if (!await ConnectAsync(connection)) + } while (!await ConnectAsync(connection)); + + Console.WriteLine("Connected to {0}", uri); + + // Handle the connected connection + while (true) + { + try { + var line = Console.ReadLine(); + + if (line == null || closedTokenSource.Token.IsCancellationRequested) + { + break; + } + + await connection.InvokeAsync("Send", line); + } + catch (ObjectDisposedException) + { + // We're shutting down the client break; } - - Console.WriteLine("Connected to {0}", uri); ; - - // Handle the connected connection - while (true) + catch (Exception ex) { - try - { - var line = Console.ReadLine(); - - if (line == null || closedTokenSource.Token.IsCancellationRequested) - { - break; - } - - await connection.InvokeAsync("Send", line); - } - catch (IOException) - { - // Process being shutdown - break; - } - catch (OperationCanceledException) - { - // The connection closed - break; - } - catch (ObjectDisposedException) - { - // We're shutting down the client - break; - } - catch (Exception ex) - { - // Send could have failed because the connection closed - System.Console.WriteLine(ex); - break; - } + // Send could have failed because the connection closed + // Continue to loop because we should be reconnecting. + Console.WriteLine(ex); } }