diff --git a/src/Microsoft.AspNetCore.Server.HttpSys/MessagePump.cs b/src/Microsoft.AspNetCore.Server.HttpSys/MessagePump.cs index 82e6a03039..53330c6f1c 100644 --- a/src/Microsoft.AspNetCore.Server.HttpSys/MessagePump.cs +++ b/src/Microsoft.AspNetCore.Server.HttpSys/MessagePump.cs @@ -25,9 +25,10 @@ namespace Microsoft.AspNetCore.Server.HttpSys private int _acceptorCounts; private Action _processRequest; - private bool _stopping; + private volatile int _stopping; private int _outstandingRequests; - private TaskCompletionSource _shutdownSignal; + private readonly TaskCompletionSource _shutdownSignal = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private int _shutdownSignalCompleted; private readonly ServerAddressesFeature _serverAddresses; @@ -56,13 +57,14 @@ namespace Microsoft.AspNetCore.Server.HttpSys _processRequest = new Action(ProcessRequestAsync); _maxAccepts = _options.MaxAccepts; - _shutdownSignal = new TaskCompletionSource(); } internal HttpSysListener Listener { get; } public IFeatureCollection Features { get; } + private bool Stopping => _stopping == 1; + public Task StartAsync(IHttpApplication application, CancellationToken cancellationToken) { if (application == null) @@ -146,7 +148,7 @@ namespace Microsoft.AspNetCore.Server.HttpSys private async void ProcessRequestsWorker() { int workerIndex = Interlocked.Increment(ref _acceptorCounts); - while (!_stopping && workerIndex <= _maxAccepts) + while (!Stopping && workerIndex <= _maxAccepts) { // Receive a request RequestContext requestContext; @@ -156,8 +158,8 @@ namespace Microsoft.AspNetCore.Server.HttpSys } catch (Exception exception) { - Contract.Assert(_stopping); - if (_stopping) + Contract.Assert(Stopping); + if (Stopping) { LogHelper.LogDebug(_logger, "ListenForNextRequestAsync-Stopping", exception); } @@ -186,7 +188,7 @@ namespace Microsoft.AspNetCore.Server.HttpSys var requestContext = requestContextObj as RequestContext; try { - if (_stopping) + if (Stopping) { SetFatalResponse(requestContext, 503); return; @@ -227,7 +229,7 @@ namespace Microsoft.AspNetCore.Server.HttpSys } finally { - if (Interlocked.Decrement(ref _outstandingRequests) == 0 && _stopping) + if (Interlocked.Decrement(ref _outstandingRequests) == 0 && Stopping) { LogHelper.LogInfo(_logger, "All requests drained."); _shutdownSignal.TrySetResult(0); @@ -250,28 +252,51 @@ namespace Microsoft.AspNetCore.Server.HttpSys public Task StopAsync(CancellationToken cancellationToken) { - _stopping = true; - // Wait for active requests to drain - if (_outstandingRequests > 0) + void RegisterCancelation() { - LogHelper.LogInfo(_logger, "Stopping, waiting for " + _outstandingRequests + " request(s) to drain."); - - var waitForStop = new TaskCompletionSource(); cancellationToken.Register(() => { - LogHelper.LogInfo(_logger, "Timed out, terminating " + _outstandingRequests + " request(s)."); - waitForStop.TrySetResult(0); + if (Interlocked.Exchange(ref _shutdownSignalCompleted, 1) == 0) + { + LogHelper.LogInfo(_logger, "Canceled, terminating " + _outstandingRequests + " request(s)."); + _shutdownSignal.TrySetResult(null); + } }); - - return Task.WhenAny(_shutdownSignal.Task, waitForStop.Task); } - return Task.CompletedTask; + if (Interlocked.Exchange(ref _stopping, 1) == 1) + { + RegisterCancelation(); + + return _shutdownSignal.Task; + } + + try + { + // Wait for active requests to drain + if (_outstandingRequests > 0) + { + LogHelper.LogInfo(_logger, "Stopping, waiting for " + _outstandingRequests + " request(s) to drain."); + RegisterCancelation(); + } + else + { + _shutdownSignal.TrySetResult(null); + } + } + catch (Exception ex) + { + _shutdownSignal.TrySetException(ex); + } + + return _shutdownSignal.Task; } public void Dispose() { - _stopping = true; + _stopping = 1; + _shutdownSignal.TrySetResult(null); + Listener.Dispose(); } diff --git a/test/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests/ServerTests.cs b/test/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests/ServerTests.cs index 9e5bc9e444..b5c9734bb7 100644 --- a/test/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests/ServerTests.cs +++ b/test/Microsoft.AspNetCore.Server.HttpSys.FunctionalTests/ServerTests.cs @@ -11,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Testing; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -406,6 +407,194 @@ namespace Microsoft.AspNetCore.Server.HttpSys } } + [ConditionalFact] + public async Task Server_MultipleStopAsyncCallsWaitForRequestsToDrain_Success() + { + Task responseTask; + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent run = new ManualResetEvent(false); + string address; + using (var server = Utilities.CreateHttpServer(out address, httpContext => + { + received.Set(); + Assert.True(run.WaitOne(TimeSpan.FromSeconds(10))); + httpContext.Response.ContentLength = 11; + return httpContext.Response.WriteAsync("Hello World"); + })) + { + responseTask = SendRequestAsync(address); + Assert.True(received.WaitOne(TimeSpan.FromSeconds(10))); + + var cts = new CancellationTokenSource(TimeSpan.FromSeconds(10)); + var stopTask1 = server.StopAsync(cts.Token); + var stopTask2 = server.StopAsync(cts.Token); + var stopTask3 = server.StopAsync(cts.Token); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + Assert.False(stopTask3.IsCompleted); + + run.Set(); + + await Task.WhenAll(stopTask1, stopTask2, stopTask3).TimeoutAfter(TimeSpan.FromSeconds(10)); + } + string response = await responseTask; + Assert.Equal("Hello World", response); + } + + [ConditionalFact] + public async Task Server_MultipleStopAsyncCallsCompleteOnCancellation_SameToken_Success() + { + Task responseTask; + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent run = new ManualResetEvent(false); + string address; + using (var server = Utilities.CreateHttpServer(out address, httpContext => + { + received.Set(); + Assert.True(run.WaitOne(TimeSpan.FromSeconds(10))); + httpContext.Response.ContentLength = 11; + return httpContext.Response.WriteAsync("Hello World"); + })) + { + responseTask = SendRequestAsync(address); + Assert.True(received.WaitOne(TimeSpan.FromSeconds(10))); + + var cts = new CancellationTokenSource(); + var stopTask1 = server.StopAsync(cts.Token); + var stopTask2 = server.StopAsync(cts.Token); + var stopTask3 = server.StopAsync(cts.Token); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + Assert.False(stopTask3.IsCompleted); + + cts.Cancel(); + + await Task.WhenAll(stopTask1, stopTask2, stopTask3).TimeoutAfter(TimeSpan.FromSeconds(10)); + + run.Set(); + + string response = await responseTask; + Assert.Equal("Hello World", response); + } + } + + [ConditionalFact] + public async Task Server_MultipleStopAsyncCallsCompleteOnSingleCancellation_FirstToken_Success() + { + Task responseTask; + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent run = new ManualResetEvent(false); + string address; + using (var server = Utilities.CreateHttpServer(out address, httpContext => + { + received.Set(); + Assert.True(run.WaitOne(TimeSpan.FromSeconds(10))); + httpContext.Response.ContentLength = 11; + return httpContext.Response.WriteAsync("Hello World"); + })) + { + responseTask = SendRequestAsync(address); + Assert.True(received.WaitOne(TimeSpan.FromSeconds(10))); + + var cts = new CancellationTokenSource(); + var stopTask1 = server.StopAsync(cts.Token); + var stopTask2 = server.StopAsync(new CancellationTokenSource().Token); + var stopTask3 = server.StopAsync(new CancellationTokenSource().Token); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + Assert.False(stopTask3.IsCompleted); + + cts.Cancel(); + + await Task.WhenAll(stopTask1, stopTask2, stopTask3).TimeoutAfter(TimeSpan.FromSeconds(10)); + + run.Set(); + + string response = await responseTask; + Assert.Equal("Hello World", response); + } + } + + [ConditionalFact] + public async Task Server_MultipleStopAsyncCallsCompleteOnSingleCancellation_SubsequentToken_Success() + { + Task responseTask; + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent run = new ManualResetEvent(false); + string address; + using (var server = Utilities.CreateHttpServer(out address, httpContext => + { + received.Set(); + Assert.True(run.WaitOne(TimeSpan.FromSeconds(10))); + httpContext.Response.ContentLength = 11; + return httpContext.Response.WriteAsync("Hello World"); + })) + { + responseTask = SendRequestAsync(address); + Assert.True(received.WaitOne(10000)); + + var cts = new CancellationTokenSource(); + var stopTask1 = server.StopAsync(new CancellationTokenSource().Token); + var stopTask2 = server.StopAsync(cts.Token); + var stopTask3 = server.StopAsync(new CancellationTokenSource().Token); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + Assert.False(stopTask3.IsCompleted); + + cts.Cancel(); + + await Task.WhenAll(stopTask1, stopTask2, stopTask3).TimeoutAfter(TimeSpan.FromSeconds(10)); + + run.Set(); + + string response = await responseTask; + Assert.Equal("Hello World", response); + } + } + + [ConditionalFact] + public async Task Server_DisposeContinuesPendingStopAsyncCalls() + { + Task responseTask; + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent run = new ManualResetEvent(false); + string address; + Task stopTask1; + Task stopTask2; + using (var server = Utilities.CreateHttpServer(out address, httpContext => + { + received.Set(); + Assert.True(run.WaitOne(TimeSpan.FromSeconds(10))); + httpContext.Response.ContentLength = 11; + return httpContext.Response.WriteAsync("Hello World"); + })) + { + responseTask = SendRequestAsync(address); + Assert.True(received.WaitOne(TimeSpan.FromSeconds(10))); + + stopTask1 = server.StopAsync(new CancellationTokenSource().Token); + stopTask2 = server.StopAsync(new CancellationTokenSource().Token); + + Assert.False(stopTask1.IsCompleted); + Assert.False(stopTask2.IsCompleted); + } + + await Task.WhenAll(stopTask1, stopTask2).TimeoutAfter(TimeSpan.FromSeconds(10)); + } + + [ConditionalFact] + public async Task Server_StopAsyncCalledWithNoRequests_Success() + { + using (var server = Utilities.CreateHttpServer(out _, httpContext => Task.CompletedTask)) + { + await server.StopAsync(default(CancellationToken)).TimeoutAfter(TimeSpan.FromSeconds(10)); + } + } + private async Task SendRequestAsync(string uri) { using (HttpClient client = new HttpClient())