Make StopAsync multi-thread safe (#1666).
This commit is contained in:
parent
b0dc76a6ae
commit
fd1758fdfc
|
|
@ -25,7 +25,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
|
|||
private readonly ITransportFactory _transportFactory;
|
||||
|
||||
private bool _hasStarted;
|
||||
private int _stopped;
|
||||
private int _stopping;
|
||||
private readonly TaskCompletionSource<object> _stoppedTcs = new TaskCompletionSource<object>();
|
||||
|
||||
public KestrelServer(IOptions<KestrelServerOptions> options, ITransportFactory transportFactory, ILoggerFactory loggerFactory)
|
||||
: this(transportFactory, CreateServiceContext(options, loggerFactory))
|
||||
|
|
@ -154,35 +155,46 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
|
|||
// Graceful shutdown if possible
|
||||
public async Task StopAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
if (Interlocked.Exchange(ref _stopped, 1) == 1)
|
||||
if (Interlocked.Exchange(ref _stopping, 1) == 1)
|
||||
{
|
||||
await _stoppedTcs.Task.ConfigureAwait(false);
|
||||
return;
|
||||
}
|
||||
|
||||
var tasks = new Task[_transports.Count];
|
||||
for (int i = 0; i < _transports.Count; i++)
|
||||
try
|
||||
{
|
||||
tasks[i] = _transports[i].UnbindAsync();
|
||||
}
|
||||
await Task.WhenAll(tasks).ConfigureAwait(false);
|
||||
|
||||
if (!await ConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
Trace.NotAllConnectionsClosedGracefully();
|
||||
|
||||
if (!await ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false))
|
||||
var tasks = new Task[_transports.Count];
|
||||
for (int i = 0; i < _transports.Count; i++)
|
||||
{
|
||||
Trace.NotAllConnectionsAborted();
|
||||
tasks[i] = _transports[i].UnbindAsync();
|
||||
}
|
||||
}
|
||||
await Task.WhenAll(tasks).ConfigureAwait(false);
|
||||
|
||||
for (int i = 0; i < _transports.Count; i++)
|
||||
if (!await ConnectionManager.CloseAllConnectionsAsync(cancellationToken).ConfigureAwait(false))
|
||||
{
|
||||
Trace.NotAllConnectionsClosedGracefully();
|
||||
|
||||
if (!await ConnectionManager.AbortAllConnectionsAsync().ConfigureAwait(false))
|
||||
{
|
||||
Trace.NotAllConnectionsAborted();
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < _transports.Count; i++)
|
||||
{
|
||||
tasks[i] = _transports[i].StopAsync();
|
||||
}
|
||||
await Task.WhenAll(tasks).ConfigureAwait(false);
|
||||
|
||||
_heartbeat.Dispose();
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
tasks[i] = _transports[i].StopAsync();
|
||||
_stoppedTcs.TrySetException(ex);
|
||||
throw;
|
||||
}
|
||||
await Task.WhenAll(tasks).ConfigureAwait(false);
|
||||
|
||||
_heartbeat.Dispose();
|
||||
_stoppedTcs.TrySetResult(null);
|
||||
}
|
||||
|
||||
// Ungraceful shutdown
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ using System.Threading.Tasks;
|
|||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting.Server;
|
||||
using Microsoft.AspNetCore.Hosting.Server.Features;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
|
||||
using Microsoft.AspNetCore.Testing;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
@ -165,6 +164,111 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
|
|||
Assert.Equal("transportFactory", exception.ParamName);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StopAsyncCallsCompleteWhenFirstCallCompletes()
|
||||
{
|
||||
var options = new KestrelServerOptions
|
||||
{
|
||||
ListenOptions =
|
||||
{
|
||||
new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))
|
||||
}
|
||||
};
|
||||
|
||||
var unbind = new SemaphoreSlim(0);
|
||||
var stop = new SemaphoreSlim(0);
|
||||
|
||||
var mockTransport = new Mock<ITransport>();
|
||||
mockTransport
|
||||
.Setup(transport => transport.BindAsync())
|
||||
.Returns(Task.CompletedTask);
|
||||
mockTransport
|
||||
.Setup(transport => transport.UnbindAsync())
|
||||
.Returns(async () => await unbind.WaitAsync());
|
||||
mockTransport
|
||||
.Setup(transport => transport.StopAsync())
|
||||
.Returns(async () => await stop.WaitAsync());
|
||||
|
||||
var mockTransportFactory = new Mock<ITransportFactory>();
|
||||
mockTransportFactory
|
||||
.Setup(transportFactory => transportFactory.Create(It.IsAny<IEndPointInformation>(), It.IsAny<IConnectionHandler>()))
|
||||
.Returns(mockTransport.Object);
|
||||
|
||||
var server = new KestrelServer(Options.Create(options), mockTransportFactory.Object, Mock.Of<LoggerFactory>());
|
||||
await server.StartAsync(new DummyApplication(), CancellationToken.None);
|
||||
|
||||
var stopTask1 = server.StopAsync(default(CancellationToken));
|
||||
var stopTask2 = server.StopAsync(default(CancellationToken));
|
||||
var stopTask3 = server.StopAsync(default(CancellationToken));
|
||||
|
||||
Assert.False(stopTask1.IsCompleted);
|
||||
Assert.False(stopTask2.IsCompleted);
|
||||
Assert.False(stopTask3.IsCompleted);
|
||||
|
||||
unbind.Release();
|
||||
stop.Release();
|
||||
|
||||
await Task.WhenAll(new[] { stopTask1, stopTask2, stopTask3 }).TimeoutAfter(TimeSpan.FromSeconds(10));
|
||||
|
||||
mockTransport.Verify(transport => transport.UnbindAsync(), Times.Once);
|
||||
mockTransport.Verify(transport => transport.StopAsync(), Times.Once);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StopAsyncCallsCompleteWithThrownException()
|
||||
{
|
||||
var options = new KestrelServerOptions
|
||||
{
|
||||
ListenOptions =
|
||||
{
|
||||
new ListenOptions(new IPEndPoint(IPAddress.Loopback, 0))
|
||||
}
|
||||
};
|
||||
|
||||
var unbind = new SemaphoreSlim(0);
|
||||
var unbindException = new InvalidOperationException();
|
||||
|
||||
var mockTransport = new Mock<ITransport>();
|
||||
mockTransport
|
||||
.Setup(transport => transport.BindAsync())
|
||||
.Returns(Task.CompletedTask);
|
||||
mockTransport
|
||||
.Setup(transport => transport.UnbindAsync())
|
||||
.Returns(async () =>
|
||||
{
|
||||
await unbind.WaitAsync();
|
||||
throw unbindException;
|
||||
});
|
||||
mockTransport
|
||||
.Setup(transport => transport.StopAsync())
|
||||
.Returns(Task.CompletedTask);
|
||||
|
||||
var mockTransportFactory = new Mock<ITransportFactory>();
|
||||
mockTransportFactory
|
||||
.Setup(transportFactory => transportFactory.Create(It.IsAny<IEndPointInformation>(), It.IsAny<IConnectionHandler>()))
|
||||
.Returns(mockTransport.Object);
|
||||
|
||||
var server = new KestrelServer(Options.Create(options), mockTransportFactory.Object, Mock.Of<LoggerFactory>());
|
||||
await server.StartAsync(new DummyApplication(), CancellationToken.None);
|
||||
|
||||
var stopTask1 = server.StopAsync(default(CancellationToken));
|
||||
var stopTask2 = server.StopAsync(default(CancellationToken));
|
||||
var stopTask3 = server.StopAsync(default(CancellationToken));
|
||||
|
||||
Assert.False(stopTask1.IsCompleted);
|
||||
Assert.False(stopTask2.IsCompleted);
|
||||
Assert.False(stopTask3.IsCompleted);
|
||||
|
||||
unbind.Release();
|
||||
|
||||
var timeout = TimeSpan.FromSeconds(10);
|
||||
Assert.Same(unbindException, await Assert.ThrowsAsync<InvalidOperationException>(() => stopTask1.TimeoutAfter(timeout)));
|
||||
Assert.Same(unbindException, await Assert.ThrowsAsync<InvalidOperationException>(() => stopTask2.TimeoutAfter(timeout)));
|
||||
Assert.Same(unbindException, await Assert.ThrowsAsync<InvalidOperationException>(() => stopTask3.TimeoutAfter(timeout)));
|
||||
|
||||
mockTransport.Verify(transport => transport.UnbindAsync(), Times.Once);
|
||||
}
|
||||
|
||||
private static KestrelServer CreateServer(KestrelServerOptions options, ILogger testLogger)
|
||||
{
|
||||
return new KestrelServer(Options.Create(options), new MockTransportFactory(), new LoggerFactory(new [] { new KestrelTestLoggerProvider(testLogger)} ));
|
||||
|
|
|
|||
Loading…
Reference in New Issue