From 23375c911b4659d171bf8532b824c32771a54cb1 Mon Sep 17 00:00:00 2001 From: moozzyk Date: Thu, 16 Mar 2017 09:36:14 -0700 Subject: [PATCH] Fixes a race where wrong task could be awaited in channel completion We had a startTask we would await in DisposeAsync and channel completion continuation. This task would be initially set to a completed task and then once StartAsync was invoked it would be replaced with the actual task representing StartAsyncInternal. However if a transport failed immediately after starting the channel completion continuation could have been called before the StartAsyncInternal method completed. In this case we would await the inital completed task and then very likely would fail trying to await _receiveLoop task because it wouldn't necessarily be set. The fix is to use TaskCompletionSource so we don't try to swap tasks. We need to do some additional state checks because: - the TaskCompletionSource task may never be completed (e.g. DisposeAsync is being called without starting the connection) - TaskCompletionSource allows setting the result only once and we should not return its task more than once (e.g. calling StartAsync after connection was successfully started and stopped) Fixes #304 --- .../Connection.cs | 47 ++++++++++++++----- 1 file changed, 35 insertions(+), 12 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index cfdf7beca5..d1f1c6fa63 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Sockets.Client private volatile IChannelConnection _transportChannel; private volatile ITransport _transport; private volatile Task _receiveLoopTask; - private volatile Task _startTask = Task.CompletedTask; + private TaskCompletionSource _startTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private TaskQueue _eventQueue = new TaskQueue(); private ReadableChannel Input => _transportChannel.Input; @@ -52,18 +52,35 @@ namespace Microsoft.AspNetCore.Sockets.Client public Task StartAsync(ITransport transport, HttpClient httpClient) { - _startTask = StartAsyncInternal(transport, httpClient); - return _startTask; + if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) + != ConnectionState.Initial) + { + return Task.FromException( + new InvalidOperationException("Cannot start a connection that is not in the Initial state.")); + } + + StartAsyncInternal(transport, httpClient) + .ContinueWith(t => + { + if (t.IsFaulted) + { + _startTcs.SetException(t.Exception.InnerException); + } + else if (t.IsCanceled) + { + _startTcs.SetCanceled(); + } + else + { + _startTcs.SetResult(null); + } + }); + + return _startTcs.Task; } private async Task StartAsyncInternal(ITransport transport, HttpClient httpClient) { - if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) - != ConnectionState.Initial) - { - throw new InvalidOperationException("Cannot start a connection that is not in the Initial state."); - } - _logger.LogDebug("Starting connection."); try @@ -116,7 +133,8 @@ namespace Microsoft.AspNetCore.Sockets.Client // There is a short window between we start the channel and assign the _receiveLoopTask a value. // To make sure that _receiveLoopTask can be awaited (i.e. is not null) we need to await _startTask. _logger.LogDebug("Ensuring all outstanding messages are processed."); - await _startTask; + + await _startTcs.Task; await _receiveLoopTask; _logger.LogDebug("Draining event queue"); @@ -287,10 +305,15 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.LogInformation("Stopping client."); - Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + if (Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected) == ConnectionState.Initial) + { + // the connection was never started so there is nothing to clean up + return; + } + try { - await _startTask; + await _startTcs.Task; } catch {