diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index cfdf7beca5..d1f1c6fa63 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Sockets.Client private volatile IChannelConnection _transportChannel; private volatile ITransport _transport; private volatile Task _receiveLoopTask; - private volatile Task _startTask = Task.CompletedTask; + private TaskCompletionSource _startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private TaskQueue _eventQueue = new TaskQueue(); private ReadableChannel Input => _transportChannel.Input; @@ -52,18 +52,35 @@ namespace Microsoft.AspNetCore.Sockets.Client public Task StartAsync(ITransport transport, HttpClient httpClient) { - _startTask = StartAsyncInternal(transport, httpClient); - return _startTask; + if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) + != ConnectionState.Initial) + { + return Task.FromException( + new InvalidOperationException("Cannot start a connection that is not in the Initial state.")); + } + + StartAsyncInternal(transport, httpClient) + .ContinueWith(t => + { + if (t.IsFaulted) + { + _startTcs.SetException(t.Exception.InnerException); + } + else if (t.IsCanceled) + { + _startTcs.SetCanceled(); + } + else + { + _startTcs.SetResult(null); + } + }); + + return _startTcs.Task; } private async Task StartAsyncInternal(ITransport transport, HttpClient httpClient) { - if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) - != ConnectionState.Initial) - { - throw new InvalidOperationException("Cannot start a connection that is not in the Initial state."); - } - _logger.LogDebug("Starting connection."); try @@ -116,7 +133,8 @@ namespace Microsoft.AspNetCore.Sockets.Client // There is a short window between we start the channel and assign the _receiveLoopTask a value. // To make sure that _receiveLoopTask can be awaited (i.e. is not null) we need to await _startTask. _logger.LogDebug("Ensuring all outstanding messages are processed."); - await _startTask; + + await _startTcs.Task; await _receiveLoopTask; _logger.LogDebug("Draining event queue"); @@ -287,10 +305,15 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.LogInformation("Stopping client."); - Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + if (Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected) == ConnectionState.Initial) + { + // the connection was never started so there is nothing to clean up + return; + } + try { - await _startTask; + await _startTcs.Task; } catch {