Fixing a race DisposeAsync is called when StartAsync hasn't finished
Fixes: #248
This commit is contained in:
parent
aca34cb4a1
commit
62c3c15a1f
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO.Pipelines;
|
||||
using System.Net.Http;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
|
@ -17,10 +16,11 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
{
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly ILogger _logger;
|
||||
private int _connectionState = ConnectionState.Initial;
|
||||
private IChannelConnection<Message> _transportChannel;
|
||||
private ITransport _transport;
|
||||
private Task _receiveLoopTask;
|
||||
private volatile int _connectionState = ConnectionState.Initial;
|
||||
private volatile IChannelConnection<Message> _transportChannel;
|
||||
private volatile ITransport _transport;
|
||||
private volatile Task _receiveLoopTask;
|
||||
private volatile Task _startTask = Task.CompletedTask;
|
||||
|
||||
private ReadableChannel<Message> Input => _transportChannel.Input;
|
||||
private WritableChannel<Message> Output => _transportChannel.Output;
|
||||
|
|
@ -47,11 +47,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
public Task StartAsync(HttpClient httpClient) => StartAsync(transport: null, httpClient: httpClient);
|
||||
public Task StartAsync(ITransport transport) => StartAsync(transport: transport, httpClient: null);
|
||||
|
||||
// TODO HIGH: Fix a race when the connection is being stopped/disposed when start has not finished running
|
||||
public async Task StartAsync(ITransport transport, HttpClient httpClient)
|
||||
public Task StartAsync(ITransport transport, HttpClient httpClient)
|
||||
{
|
||||
_transport = transport ?? new WebSocketsTransport(_loggerFactory);
|
||||
_startTask = StartAsyncInternal(transport, httpClient);
|
||||
return _startTask;
|
||||
}
|
||||
|
||||
private async Task StartAsyncInternal(ITransport transport, HttpClient httpClient)
|
||||
{
|
||||
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connecting, ConnectionState.Initial)
|
||||
!= ConnectionState.Initial)
|
||||
{
|
||||
|
|
@ -61,6 +64,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
try
|
||||
{
|
||||
var connectUrl = await GetConnectUrl(Url, httpClient, _logger);
|
||||
|
||||
// Connection is being stopped while start was in progress
|
||||
if (_connectionState == ConnectionState.Disconnected)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
_transport = transport ?? new WebSocketsTransport(_loggerFactory);
|
||||
await StartTransport(connectUrl);
|
||||
}
|
||||
catch
|
||||
|
|
@ -69,16 +80,31 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
throw;
|
||||
}
|
||||
|
||||
// start receive loop
|
||||
_receiveLoopTask = ReceiveAsync();
|
||||
|
||||
Interlocked.Exchange(ref _connectionState, ConnectionState.Connected);
|
||||
|
||||
// Do not "simplify" - events can be removed from a different thread
|
||||
var connectedEventHandler = Connected;
|
||||
if (connectedEventHandler != null)
|
||||
// if the connection is not in the Connecting state here it means the user called DisposeAsync
|
||||
if (Interlocked.CompareExchange(ref _connectionState, ConnectionState.Connected, ConnectionState.Connecting)
|
||||
== ConnectionState.Connecting)
|
||||
{
|
||||
connectedEventHandler();
|
||||
// Do not "simplify" - events can be removed from a different thread
|
||||
var connectedEventHandler = Connected;
|
||||
if (connectedEventHandler != null)
|
||||
{
|
||||
connectedEventHandler();
|
||||
}
|
||||
|
||||
var ignore = Input.Completion.ContinueWith(t =>
|
||||
{
|
||||
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);
|
||||
|
||||
// Do not "simplify" - events can be removed from a different thread
|
||||
var closedEventHandler = Closed;
|
||||
if (closedEventHandler != null)
|
||||
{
|
||||
closedEventHandler(t.IsFaulted ? t.Exception.InnerException : null);
|
||||
}
|
||||
});
|
||||
|
||||
// start receive loop
|
||||
_receiveLoopTask = ReceiveAsync();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -126,18 +152,6 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
|
||||
_transportChannel = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
|
||||
|
||||
var ignore = Input.Completion.ContinueWith(t =>
|
||||
{
|
||||
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);
|
||||
|
||||
// Do not "simplify" - events can be removed from a different thread
|
||||
var closedEventHandler = Closed;
|
||||
if (closedEventHandler != null)
|
||||
{
|
||||
closedEventHandler(t.IsFaulted ? t.Exception.InnerException : null);
|
||||
}
|
||||
});
|
||||
|
||||
// Start the transport, giving it one end of the pipeline
|
||||
try
|
||||
{
|
||||
|
|
@ -213,6 +227,15 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
public async Task DisposeAsync()
|
||||
{
|
||||
Interlocked.Exchange(ref _connectionState, ConnectionState.Disconnected);
|
||||
try
|
||||
{
|
||||
await _startTask;
|
||||
}
|
||||
catch
|
||||
{
|
||||
// We only await the start task to make sure that StartAsync completed. The
|
||||
// _startTask is returned to the user and they should handle exceptions.
|
||||
}
|
||||
|
||||
if (_transportChannel != null)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
private Task _poller;
|
||||
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
|
||||
|
||||
public Task Running { get; private set; }
|
||||
public Task Running { get; private set; } = Task.CompletedTask;
|
||||
|
||||
public LongPollingTransport(HttpClient httpClient)
|
||||
: this(httpClient, null)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("WebSocketsTransport");
|
||||
}
|
||||
|
||||
public Task Running { get; private set; }
|
||||
public Task Running { get; private set; } = Task.CompletedTask;
|
||||
|
||||
public async Task StartAsync(Uri url, IChannelConnection<Message> application)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -105,6 +105,49 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CanStopStartingConnection()
|
||||
{
|
||||
// Used to make sure StartAsync is not completed before DisposeAsync is called
|
||||
var releaseNegotiateTcs = new TaskCompletionSource<object>();
|
||||
// Used to make sure that DisposeAsync runs after we check the state in StartAsync
|
||||
var allowDisposeTcs = new TaskCompletionSource<object>();
|
||||
// Used to make sure that DisposeAsync continues only after StartAsync finished
|
||||
var releaseDisposeTcs = new TaskCompletionSource<object>();
|
||||
|
||||
var mockHttpHandler = new Mock<HttpMessageHandler>();
|
||||
mockHttpHandler.Protected()
|
||||
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
|
||||
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
|
||||
{
|
||||
await Task.Yield();
|
||||
// allow DisposeAsync to continue once we know we are past the connection state check
|
||||
allowDisposeTcs.SetResult(null);
|
||||
await releaseNegotiateTcs.Task;
|
||||
return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) };
|
||||
});
|
||||
|
||||
using (var httpClient = new HttpClient(mockHttpHandler.Object))
|
||||
{
|
||||
var transport = new Mock<ITransport>();
|
||||
transport.Setup(t => t.StopAsync()).Returns(async () => { await releaseDisposeTcs.Task; });
|
||||
var connection = new Connection(new Uri("http://fakeuri.org/"));
|
||||
|
||||
var startTask = connection.StartAsync(transport.Object, httpClient);
|
||||
await allowDisposeTcs.Task;
|
||||
var disposeTask = connection.DisposeAsync();
|
||||
// allow StartAsync to continue once DisposeAsync has started
|
||||
releaseNegotiateTcs.SetResult(null);
|
||||
|
||||
// unblock DisposeAsync only after StartAsync completed
|
||||
await startTask.OrTimeout();
|
||||
releaseDisposeTcs.SetResult(null);
|
||||
await disposeTask.OrTimeout();
|
||||
|
||||
transport.Verify(t => t.StartAsync(It.IsAny<Uri>(), It.IsAny<IChannelConnection<Message>>()), Times.Never);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SendReturnsFalseIfConnectionIsNotStarted()
|
||||
{
|
||||
|
|
@ -165,7 +208,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
|
|||
{
|
||||
await connection.DisposeAsync();
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue