diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index 60400beed6..82e6c9cc89 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -17,10 +16,11 @@ namespace Microsoft.AspNetCore.Sockets.Client { private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; - private int _connectionState = ConnectionState.Initial; - private IChannelConnection _transportChannel; - private ITransport _transport; - private Task _receiveLoopTask; + private volatile int _connectionState = ConnectionState.Initial; + private volatile IChannelConnection _transportChannel; + private volatile ITransport _transport; + private volatile Task _receiveLoopTask; + private volatile Task _startTask = Task.CompletedTask; private ReadableChannel Input => _transportChannel.Input; private WritableChannel Output => _transportChannel.Output; @@ -47,11 +47,14 @@ namespace Microsoft.AspNetCore.Sockets.Client public Task StartAsync(HttpClient httpClient) => StartAsync(transport: null, httpClient: httpClient); public Task StartAsync(ITransport transport) => StartAsync(transport: transport, httpClient: null); - // TODO HIGH: Fix a race when the connection is being stopped/disposed when start has not finished running - public async Task StartAsync(ITransport transport, HttpClient httpClient) + public Task StartAsync(ITransport transport, HttpClient httpClient) { - _transport = transport ?? new WebSocketsTransport(_loggerFactory); + _startTask = StartAsyncInternal(transport, httpClient); + return _startTask; + } + private async Task StartAsyncInternal(ITransport transport, HttpClient httpClient) + { if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial) != ConnectionState.Initial) { @@ -61,6 +64,14 @@ namespace Microsoft.AspNetCore.Sockets.Client try { var connectUrl = await GetConnectUrl(Url, httpClient, _logger); + + // Connection is being stopped while start was in progress + if (_connectionState == ConnectionState.Disconnected) + { + return; + } + + _transport = transport ?? new WebSocketsTransport(_loggerFactory); await StartTransport(connectUrl); } catch @@ -69,16 +80,31 @@ namespace Microsoft.AspNetCore.Sockets.Client throw; } - // start receive loop - _receiveLoopTask = ReceiveAsync(); - - Interlocked.Exchange(ref _connectionState, ConnectionState.Connected); - - // Do not "simplify" - events can be removed from a different thread - var connectedEventHandler = Connected; - if (connectedEventHandler != null) + // if the connection is not in the Connecting state here it means the user called DisposeAsync + if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connected, ConnectionState.Connecting) + == ConnectionState.Connecting) { - connectedEventHandler(); + // Do not "simplify" - events can be removed from a different thread + var connectedEventHandler = Connected; + if (connectedEventHandler != null) + { + connectedEventHandler(); + } + + var ignore = Input.Completion.ContinueWith(t => + { + Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + + // Do not "simplify" - events can be removed from a different thread + var closedEventHandler = Closed; + if (closedEventHandler != null) + { + closedEventHandler(t.IsFaulted ? t.Exception.InnerException : null); + } + }); + + // start receive loop + _receiveLoopTask = ReceiveAsync(); } } @@ -126,18 +152,6 @@ namespace Microsoft.AspNetCore.Sockets.Client _transportChannel = new ChannelConnection(applicationToTransport, transportToApplication); - var ignore = Input.Completion.ContinueWith(t => - { - Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); - - // Do not "simplify" - events can be removed from a different thread - var closedEventHandler = Closed; - if (closedEventHandler != null) - { - closedEventHandler(t.IsFaulted ? t.Exception.InnerException : null); - } - }); - // Start the transport, giving it one end of the pipeline try { @@ -213,6 +227,15 @@ namespace Microsoft.AspNetCore.Sockets.Client public async Task DisposeAsync() { Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected); + try + { + await _startTask; + } + catch + { + // We only await the start task to make sure that StartAsync completed. The + // _startTask is returned to the user and they should handle exceptions. + } if (_transportChannel != null) { diff --git a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs index c861c58a34..a0cfceab64 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/LongPollingTransport.cs @@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets.Client private Task _poller; private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); - public Task Running { get; private set; } + public Task Running { get; private set; } = Task.CompletedTask; public LongPollingTransport(HttpClient httpClient) : this(httpClient, null) diff --git a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs index f34ffc47d7..17253d3f8e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs @@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("WebSocketsTransport"); } - public Task Running { get; private set; } + public Task Running { get; private set; } = Task.CompletedTask; public async Task StartAsync(Uri url, IChannelConnection application) { diff --git a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs index 0d47df2017..2f7d675e95 100644 --- a/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Client.Tests/ConnectionTests.cs @@ -105,6 +105,49 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } } + [Fact] + public async Task CanStopStartingConnection() + { + // Used to make sure StartAsync is not completed before DisposeAsync is called + var releaseNegotiateTcs = new TaskCompletionSource(); + // Used to make sure that DisposeAsync runs after we check the state in StartAsync + var allowDisposeTcs = new TaskCompletionSource(); + // Used to make sure that DisposeAsync continues only after StartAsync finished + var releaseDisposeTcs = new TaskCompletionSource(); + + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + // allow DisposeAsync to continue once we know we are past the connection state check + allowDisposeTcs.SetResult(null); + await releaseNegotiateTcs.Task; + return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var transport = new Mock(); + transport.Setup(t => t.StopAsync()).Returns(async () => { await releaseDisposeTcs.Task; }); + var connection = new Connection(new Uri("http://fakeuri.org/")); + + var startTask = connection.StartAsync(transport.Object, httpClient); + await allowDisposeTcs.Task; + var disposeTask = connection.DisposeAsync(); + // allow StartAsync to continue once DisposeAsync has started + releaseNegotiateTcs.SetResult(null); + + // unblock DisposeAsync only after StartAsync completed + await startTask.OrTimeout(); + releaseDisposeTcs.SetResult(null); + await disposeTask.OrTimeout(); + + transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>()), Times.Never); + } + } + [Fact] public async Task SendReturnsFalseIfConnectionIsNotStarted() { @@ -165,7 +208,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests { await connection.DisposeAsync(); } - } }