diff --git a/src/Microsoft.AspNet.Server.WebListener/MessagePump.cs b/src/Microsoft.AspNet.Server.WebListener/MessagePump.cs index bcf7ba653f..890f9b9e63 100644 --- a/src/Microsoft.AspNet.Server.WebListener/MessagePump.cs +++ b/src/Microsoft.AspNet.Server.WebListener/MessagePump.cs @@ -39,6 +39,10 @@ namespace Microsoft.AspNet.Server.WebListener private int _acceptorCounts; private Action _processRequest; + private bool _stopping; + private int _outstandingRequests; + private ManualResetEvent _shutdownSignal; + // TODO: private IDictionary _capabilities; internal MessagePump(Microsoft.Net.Server.WebListener listener, ILoggerFactory loggerFactory) @@ -49,6 +53,7 @@ namespace Microsoft.AspNet.Server.WebListener _processRequest = new Action(ProcessRequestAsync); _maxAccepts = DefaultMaxAccepts; + _shutdownSignal = new ManualResetEvent(false); } internal Microsoft.Net.Server.WebListener Listener @@ -108,7 +113,7 @@ namespace Microsoft.AspNet.Server.WebListener private async void ProcessRequestsWorker() { int workerIndex = Interlocked.Increment(ref _acceptorCounts); - while (_listener.IsListening && workerIndex <= MaxAccepts) + while (!_stopping && workerIndex <= MaxAccepts) { // Receive a request RequestContext requestContext; @@ -124,7 +129,7 @@ namespace Microsoft.AspNet.Server.WebListener } try { - Task.Factory.StartNew(_processRequest, requestContext); + Task ignored = Task.Factory.StartNew(_processRequest, requestContext); } catch (Exception ex) { @@ -141,11 +146,18 @@ namespace Microsoft.AspNet.Server.WebListener var requestContext = requestContextObj as RequestContext; try { + if (_stopping) + { + SetFatalResponse(requestContext, 503); + return; + } try { + Interlocked.Increment(ref _outstandingRequests); FeatureContext featureContext = new FeatureContext(requestContext); await _appFunc(featureContext.Features).SupressContext(); // TODO: WebSocket/Opaque upgrade - await requestContext.ProcessResponseAsync().SupressContext(); + requestContext.Dispose(); } catch (Exception ex) { @@ -157,10 +169,16 @@ namespace Microsoft.AspNet.Server.WebListener else { // We haven't sent a response yet, try to send a 500 Internal Server Error - SetFatalResponse(requestContext); + SetFatalResponse(requestContext, 500); + } + } + finally + { + if (Interlocked.Decrement(ref _outstandingRequests) == 0 && _stopping) + { + _shutdownSignal.Set(); } } - requestContext.Dispose(); } catch (Exception ex) { @@ -169,16 +187,24 @@ namespace Microsoft.AspNet.Server.WebListener } } - private static void SetFatalResponse(RequestContext context) + private static void SetFatalResponse(RequestContext context, int status) { - context.Response.StatusCode = 500; + context.Response.StatusCode = status; context.Response.ReasonPhrase = string.Empty; context.Response.Headers.Clear(); context.Response.ContentLength = 0; + context.Dispose(); } public void Dispose() { + _stopping = true; + // Wait for active requests to drain + if (_outstandingRequests > 0) + { + _shutdownSignal.WaitOne(); + } + // All requests are finished _listener.Dispose(); } } diff --git a/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs b/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs index c6f47bb48c..0b84f3a63a 100644 --- a/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs +++ b/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs @@ -80,6 +80,26 @@ namespace Microsoft.AspNet.Server.WebListener } } + [Fact] + public async Task Server_ShutdownDurringRequest_Success() + { + Task responseTask; + ManualResetEvent received = new ManualResetEvent(false); + using (Utilities.CreateHttpServer(env => + { + received.Set(); + var httpContext = new DefaultHttpContext((IFeatureCollection)env); + httpContext.Response.ContentLength = 11; + return httpContext.Response.WriteAsync("Hello World"); + })) + { + responseTask = SendRequestAsync(Address); + Assert.True(received.WaitOne(10000)); + } + string response = await responseTask; + Assert.Equal("Hello World", response); + } + [Fact] public void Server_AppException_ClientReset() {