Make StopAsync multi-thread safe (#1666).

This commit is contained in:
Cesar Blum Silveira 2017-07-11 14:29:50 -07:00 committed by GitHub
parent b0dc76a6ae
commit fd1758fdfc
2 changed files with 136 additions and 20 deletions

View File

@ -25,7 +25,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
private readonly ITransportFactory _transportFactory;
private bool _hasStarted;
private int _stopped;
private int _stopping;
private readonly TaskCompletionSource<object> _stoppedTcs = new TaskCompletionSource<object>();
public KestrelServer(IOptions<KestrelServerOptions> options, ITransportFactory transportFactory, ILoggerFactory loggerFactory)
: this(transportFactory, CreateServiceContext(options, loggerFactory))
@ -154,35 +155,46 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
// Graceful shutdown if possible
public async Task StopAsync(CancellationToken cancellationToken)
{
if (Interlocked.Exchange(ref _stopped, 1) == 1)
if (Interlocked.Exchange(ref _stopping, 1) == 1)
{
await _stoppedTcs.Task.ConfigureAwait(false);
return;
}
var tasks = new Task[_transports.Count];
for (int i = 0; i < _transports.Count; i++)
try
{
tasks[i] = _transports[i].UnbindAsync();
}
await Task.WhenAll(tasks).ConfigureAwait(false);
if (!await ConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false))
{
Trace.NotAllConnectionsClosedGracefully();
if (!await ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false))
var tasks = new Task[_transports.Count];
for (int i = 0; i < _transports.Count; i++)
{
Trace.NotAllConnectionsAborted();
tasks[i] = _transports[i].UnbindAsync();
}
}
await Task.WhenAll(tasks).ConfigureAwait(false);
for (int i = 0; i < _transports.Count; i++)
if (!await ConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false))
{
Trace.NotAllConnectionsClosedGracefully();
if (!await ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false))
{
Trace.NotAllConnectionsAborted();
}
}
for (int i = 0; i < _transports.Count; i++)
{
tasks[i] = _transports[i].StopAsync();
}
await Task.WhenAll(tasks).ConfigureAwait(false);
_heartbeat.Dispose();
}
catch (Exception ex)
{
tasks[i] = _transports[i].StopAsync();
_stoppedTcs.TrySetException(ex);
throw;
}
await Task.WhenAll(tasks).ConfigureAwait(false);
_heartbeat.Dispose();
_stoppedTcs.TrySetResult(null);
}
// Ungraceful shutdown

View File

@ -9,7 +9,6 @@ using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
@ -165,6 +164,111 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
Assert.Equal("transportFactory", exception.ParamName);
}
[Fact]
public async Task StopAsyncCallsCompleteWhenFirstCallCompletes()
{
var options = new KestrelServerOptions
{
ListenOptions =
{
new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))
}
};
var unbind = new SemaphoreSlim(0);
var stop = new SemaphoreSlim(0);
var mockTransport = new Mock<ITransport>();
mockTransport
.Setup(transport => transport.BindAsync())
.Returns(Task.CompletedTask);
mockTransport
.Setup(transport => transport.UnbindAsync())
.Returns(async () => await unbind.WaitAsync());
mockTransport
.Setup(transport => transport.StopAsync())
.Returns(async () => await stop.WaitAsync());
var mockTransportFactory = new Mock<ITransportFactory>();
mockTransportFactory
.Setup(transportFactory => transportFactory.Create(It.IsAny<IEndPointInformation>(), It.IsAny<IConnectionHandler>()))
.Returns(mockTransport.Object);
var server = new KestrelServer(Options.Create(options), mockTransportFactory.Object, Mock.Of<LoggerFactory>());
await server.StartAsync(new DummyApplication(), CancellationToken.None);
var stopTask1 = server.StopAsync(default(CancellationToken));
var stopTask2 = server.StopAsync(default(CancellationToken));
var stopTask3 = server.StopAsync(default(CancellationToken));
Assert.False(stopTask1.IsCompleted);
Assert.False(stopTask2.IsCompleted);
Assert.False(stopTask3.IsCompleted);
unbind.Release();
stop.Release();
await Task.WhenAll(new[] { stopTask1, stopTask2, stopTask3 }).TimeoutAfter(TimeSpan.FromSeconds(10));
mockTransport.Verify(transport => transport.UnbindAsync(), Times.Once);
mockTransport.Verify(transport => transport.StopAsync(), Times.Once);
}
[Fact]
public async Task StopAsyncCallsCompleteWithThrownException()
{
var options = new KestrelServerOptions
{
ListenOptions =
{
new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))
}
};
var unbind = new SemaphoreSlim(0);
var unbindException = new InvalidOperationException();
var mockTransport = new Mock<ITransport>();
mockTransport
.Setup(transport => transport.BindAsync())
.Returns(Task.CompletedTask);
mockTransport
.Setup(transport => transport.UnbindAsync())
.Returns(async () =>
{
await unbind.WaitAsync();
throw unbindException;
});
mockTransport
.Setup(transport => transport.StopAsync())
.Returns(Task.CompletedTask);
var mockTransportFactory = new Mock<ITransportFactory>();
mockTransportFactory
.Setup(transportFactory => transportFactory.Create(It.IsAny<IEndPointInformation>(), It.IsAny<IConnectionHandler>()))
.Returns(mockTransport.Object);
var server = new KestrelServer(Options.Create(options), mockTransportFactory.Object, Mock.Of<LoggerFactory>());
await server.StartAsync(new DummyApplication(), CancellationToken.None);
var stopTask1 = server.StopAsync(default(CancellationToken));
var stopTask2 = server.StopAsync(default(CancellationToken));
var stopTask3 = server.StopAsync(default(CancellationToken));
Assert.False(stopTask1.IsCompleted);
Assert.False(stopTask2.IsCompleted);
Assert.False(stopTask3.IsCompleted);
unbind.Release();
var timeout = TimeSpan.FromSeconds(10);
Assert.Same(unbindException, await Assert.ThrowsAsync<InvalidOperationException>(() => stopTask1.TimeoutAfter(timeout)));
Assert.Same(unbindException, await Assert.ThrowsAsync<InvalidOperationException>(() => stopTask2.TimeoutAfter(timeout)));
Assert.Same(unbindException, await Assert.ThrowsAsync<InvalidOperationException>(() => stopTask3.TimeoutAfter(timeout)));
mockTransport.Verify(transport => transport.UnbindAsync(), Times.Once);
}
private static KestrelServer CreateServer(KestrelServerOptions options, ILogger testLogger)
{
return new KestrelServer(Options.Create(options), new MockTransportFactory(), new LoggerFactory(new [] { new KestrelTestLoggerProvider(testLogger)} ));