Fixing a race DisposeAsync is called when StartAsync hasn't finished

Fixes: #248
This commit is contained in:
moozzyk 2017-03-01 16:49:07 -08:00
parent aca34cb4a1
commit 62c3c15a1f
4 changed files with 97 additions and 32 deletions

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
@ -17,10 +16,11 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
private int _connectionState = ConnectionState.Initial;
private IChannelConnection<Message> _transportChannel;
private ITransport _transport;
private Task _receiveLoopTask;
private volatile int _connectionState = ConnectionState.Initial;
private volatile IChannelConnection<Message> _transportChannel;
private volatile ITransport _transport;
private volatile Task _receiveLoopTask;
private volatile Task _startTask = Task.CompletedTask;
private ReadableChannel<Message> Input => _transportChannel.Input;
private WritableChannel<Message> Output => _transportChannel.Output;
@ -47,11 +47,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
public Task StartAsync(HttpClient httpClient) => StartAsync(transport: null, httpClient: httpClient);
public Task StartAsync(ITransport transport) => StartAsync(transport: transport, httpClient: null);
// TODO HIGH: Fix a race when the connection is being stopped/disposed when start has not finished running
public async Task StartAsync(ITransport transport, HttpClient httpClient)
public Task StartAsync(ITransport transport, HttpClient httpClient)
{
_transport = transport ?? new WebSocketsTransport(_loggerFactory);
_startTask = StartAsyncInternal(transport, httpClient);
return _startTask;
}
private async Task StartAsyncInternal(ITransport transport, HttpClient httpClient)
{
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial)
!= ConnectionState.Initial)
{
@ -61,6 +64,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
try
{
var connectUrl = await GetConnectUrl(Url, httpClient, _logger);
// Connection is being stopped while start was in progress
if (_connectionState == ConnectionState.Disconnected)
{
return;
}
_transport = transport ?? new WebSocketsTransport(_loggerFactory);
await StartTransport(connectUrl);
}
catch
@ -69,16 +80,31 @@ namespace Microsoft.AspNetCore.Sockets.Client
throw;
}
// start receive loop
_receiveLoopTask = ReceiveAsync();
Interlocked.Exchange(ref _connectionState, ConnectionState.Connected);
// Do not "simplify" - events can be removed from a different thread
var connectedEventHandler = Connected;
if (connectedEventHandler != null)
// if the connection is not in the Connecting state here it means the user called DisposeAsync
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connected, ConnectionState.Connecting)
== ConnectionState.Connecting)
{
connectedEventHandler();
// Do not "simplify" - events can be removed from a different thread
var connectedEventHandler = Connected;
if (connectedEventHandler != null)
{
connectedEventHandler();
}
var ignore = Input.Completion.ContinueWith(t =>
{
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);
// Do not "simplify" - events can be removed from a different thread
var closedEventHandler = Closed;
if (closedEventHandler != null)
{
closedEventHandler(t.IsFaulted ? t.Exception.InnerException : null);
}
});
// start receive loop
_receiveLoopTask = ReceiveAsync();
}
}
@ -126,18 +152,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
_transportChannel = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var ignore = Input.Completion.ContinueWith(t =>
{
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);
// Do not "simplify" - events can be removed from a different thread
var closedEventHandler = Closed;
if (closedEventHandler != null)
{
closedEventHandler(t.IsFaulted ? t.Exception.InnerException : null);
}
});
// Start the transport, giving it one end of the pipeline
try
{
@ -213,6 +227,15 @@ namespace Microsoft.AspNetCore.Sockets.Client
public async Task DisposeAsync()
{
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);
try
{
await _startTask;
}
catch
{
// We only await the start task to make sure that StartAsync completed. The
// _startTask is returned to the user and they should handle exceptions.
}
if (_transportChannel != null)
{

View File

@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
private Task _poller;
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
public Task Running { get; private set; }
public Task Running { get; private set; } = Task.CompletedTask;
public LongPollingTransport(HttpClient httpClient)
: this(httpClient, null)

View File

@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("WebSocketsTransport");
}
public Task Running { get; private set; }
public Task Running { get; private set; } = Task.CompletedTask;
public async Task StartAsync(Uri url, IChannelConnection<Message> application)
{

View File

@ -105,6 +105,49 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
}
}
[Fact]
public async Task CanStopStartingConnection()
{
// Used to make sure StartAsync is not completed before DisposeAsync is called
var releaseNegotiateTcs = new TaskCompletionSource<object>();
// Used to make sure that DisposeAsync runs after we check the state in StartAsync
var allowDisposeTcs = new TaskCompletionSource<object>();
// Used to make sure that DisposeAsync continues only after StartAsync finished
var releaseDisposeTcs = new TaskCompletionSource<object>();
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
// allow DisposeAsync to continue once we know we are past the connection state check
allowDisposeTcs.SetResult(null);
await releaseNegotiateTcs.Task;
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var transport = new Mock<ITransport>();
transport.Setup(t => t.StopAsync()).Returns(async () => { await releaseDisposeTcs.Task; });
var connection = new Connection(new Uri("http://fakeuri.org/"));
var startTask = connection.StartAsync(transport.Object, httpClient);
await allowDisposeTcs.Task;
var disposeTask = connection.DisposeAsync();
// allow StartAsync to continue once DisposeAsync has started
releaseNegotiateTcs.SetResult(null);
// unblock DisposeAsync only after StartAsync completed
await startTask.OrTimeout();
releaseDisposeTcs.SetResult(null);
await disposeTask.OrTimeout();
transport.Verify(t => t.StartAsync(It.IsAny<Uri>(), It.IsAny<IChannelConnection<Message>>()), Times.Never);
}
}
[Fact]
public async Task SendReturnsFalseIfConnectionIsNotStarted()
{
@ -165,7 +208,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
{
await connection.DisposeAsync();
}
}
}