diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs index 293c3facc4..126dce7891 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs @@ -275,7 +275,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http ResponseHeaders["Upgrade"] = values; } } - ProduceStart(); + ProduceStartAndFireOnStarting(immediate: true).GetAwaiter().GetResult(); return Task.FromResult(DuplexStream); } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index 7bcc55630c..161536aef3 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -42,6 +42,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http private bool _responseStarted; private bool _keepAlive; private bool _autoChunk; + private bool _applicationFailed; public Frame(ConnectionContext context) : base(context) { @@ -172,28 +173,27 @@ namespace Microsoft.AspNet.Server.Kestrel.Http ResponseBody = new FrameResponseStream(this); DuplexStream = new FrameDuplexStream(RequestBody, ResponseBody); - Exception error = null; try { await Application.Invoke(this).ConfigureAwait(false); - + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + finally + { // Trigger FireOnStarting if ProduceStart hasn't been called yet. // We call it here, so it can go through our normal error handling // and respond with a 500 if an OnStarting callback throws. if (!_responseStarted) { - FireOnStarting(); + await FireOnStarting(); } - } - catch (Exception ex) - { - Log.ApplicationError(ex); - error = ex; - } - finally - { - FireOnCompleted(); - ProduceEnd(error); + + await FireOnCompleted(); + + await ProduceEnd(); } terminated = !_keepAlive; @@ -250,7 +250,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } - private void FireOnStarting() + private async Task FireOnStarting() { List, object>> onStarting = null; lock (_onStartingSync) @@ -262,12 +262,19 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { foreach (var entry in onStarting) { - entry.Key.Invoke(entry.Value).Wait(); + try + { + await entry.Key.Invoke(entry.Value); + } + catch (Exception ex) + { + ReportApplicationError(ex); + } } } } - private void FireOnCompleted() + private async Task FireOnCompleted() { List, object>> onCompleted = null; lock (_onCompletedSync) @@ -281,11 +288,11 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { try { - entry.Key.Invoke(entry.Value).Wait(); + await entry.Key.Invoke(entry.Value); } - catch + catch (Exception ex) { - // Ignore exceptions + ReportApplicationError(ex); } } } @@ -293,19 +300,19 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public void Flush() { - ProduceStart(immediate: false); + ProduceStartAndFireOnStarting(immediate: false).GetAwaiter().GetResult(); SocketOutput.Write(_emptyData, immediate: true); } - public Task FlushAsync(CancellationToken cancellationToken) + public async Task FlushAsync(CancellationToken cancellationToken) { - ProduceStart(immediate: false); - return SocketOutput.WriteAsync(_emptyData, immediate: true); + await ProduceStartAndFireOnStarting(immediate: false); + await SocketOutput.WriteAsync(_emptyData, immediate: true); } public void Write(ArraySegment data) { - ProduceStart(immediate: false); + ProduceStartAndFireOnStarting(immediate: false).GetAwaiter().GetResult(); if (_autoChunk) { @@ -321,21 +328,21 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } - public Task WriteAsync(ArraySegment data, CancellationToken cancellationToken) + public async Task WriteAsync(ArraySegment data, CancellationToken cancellationToken) { - ProduceStart(immediate: false); + await ProduceStartAndFireOnStarting(immediate: false); if (_autoChunk) { if (data.Count == 0) { - return TaskUtilities.CompletedTask; + return; } - return WriteChunkedAsync(data, cancellationToken); + await WriteChunkedAsync(data, cancellationToken); } else { - return SocketOutput.WriteAsync(data, immediate: true, cancellationToken: cancellationToken); + await SocketOutput.WriteAsync(data, immediate: true, cancellationToken: cancellationToken); } } @@ -387,12 +394,6 @@ namespace Microsoft.AspNet.Server.Kestrel.Http SocketOutput.Write(_endChunkedResponseBytes, immediate: true); } - public void Upgrade(IDictionary options, Func callback) - { - _keepAlive = false; - ProduceStart(); - } - private static ArraySegment CreateAsciiByteArraySegment(string text) { var bytes = Encoding.ASCII.GetBytes(text); @@ -412,23 +413,38 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } - public void ProduceStart(bool immediate = true, bool appCompleted = false) + public async Task ProduceStartAndFireOnStarting(bool immediate = true) + { + if (_responseStarted) return; + + await FireOnStarting(); + + if (_applicationFailed) + { + throw new ObjectDisposedException(typeof(Frame).FullName); + } + + await ProduceStart(immediate, appCompleted: false); + } + + private async Task ProduceStart(bool immediate, bool appCompleted) { - // ProduceStart shouldn't no-op in the future just b/c FireOnStarting throws. if (_responseStarted) return; - FireOnStarting(); _responseStarted = true; var status = ReasonPhrases.ToStatus(StatusCode, ReasonPhrase); var responseHeader = CreateResponseHeader(status, appCompleted); - SocketOutput.Write(responseHeader.Item1, immediate: immediate); - responseHeader.Item2.Dispose(); + + using (responseHeader.Item2) + { + await SocketOutput.WriteAsync(responseHeader.Item1, immediate: immediate); + } } - public void ProduceEnd(Exception ex) + private async Task ProduceEnd() { - if (ex != null) + if (_applicationFailed) { if (_responseStarted) { @@ -450,7 +466,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } - ProduceStart(immediate: true, appCompleted: true); + await ProduceStart(immediate: true, appCompleted: true); // _autoChunk should be checked after we are sure ProduceStart() has been called // since ProduceStart() may set _autoChunk to true. @@ -728,5 +744,11 @@ namespace Microsoft.AspNet.Server.Kestrel.Http statusCode != 205 && statusCode != 304; } + + private void ReportApplicationError(Exception ex) + { + _applicationFailed = true; + Log.ApplicationError(ex); + } } } diff --git a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs index b426214341..b5c36efce0 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs @@ -550,7 +550,7 @@ namespace Microsoft.AspNet.Server.KestrelTests [MemberData(nameof(ConnectionFilterData))] public async Task ThrowingResultsIn500Response(ServiceContext testContext) { - bool onStartingCalled = false; + var onStartingCallCount = 0; var testLogger = new TestApplicationErrorLogger(); testContext.Log = new KestrelTrace(testLogger); @@ -560,7 +560,7 @@ namespace Microsoft.AspNet.Server.KestrelTests var response = frame.Get(); response.OnStarting(_ => { - onStartingCalled = true; + onStartingCallCount++; return Task.FromResult(null); }, null); @@ -597,7 +597,7 @@ namespace Microsoft.AspNet.Server.KestrelTests "", ""); - Assert.False(onStartingCalled); + Assert.Equal(2, onStartingCallCount); Assert.Equal(2, testLogger.ApplicationErrorsLogged); } } @@ -758,22 +758,30 @@ namespace Microsoft.AspNet.Server.KestrelTests [Theory] [MemberData(nameof(ConnectionFilterData))] - public async Task ThrowingInOnStartingResultsIn500Response(ServiceContext testContext) + public async Task ThrowingInOnStartingResultsInFailedWritesAnd500Response(ServiceContext testContext) { - using (var server = new TestServer(frame => + var onStartingCallCount = 0; + var failedWriteCount = 0; + + var testLogger = new TestApplicationErrorLogger(); + testContext.Log = new KestrelTrace(testLogger); + + using (var server = new TestServer(async frame => { var response = frame.Get(); response.OnStarting(_ => { + onStartingCallCount++; throw new Exception(); }, null); response.Headers.Clear(); response.Headers["Content-Length"] = new[] { "11" }; - // If we write to the response stream, we will not get a 500. + await Assert.ThrowsAsync(async () => + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); - return Task.FromResult(null); + failedWriteCount++; }, testContext)) { using (var connection = new TestConnection()) @@ -802,34 +810,31 @@ namespace Microsoft.AspNet.Server.KestrelTests "Connection: close", "", ""); + + Assert.Equal(2, onStartingCallCount); + Assert.Equal(2, testLogger.ApplicationErrorsLogged); } } } - [Theory] [MemberData(nameof(ConnectionFilterData))] - public async Task ThrowingInOnStartingResultsInFailedWrites(ServiceContext testContext) + public async Task ThrowingInOnCompletedIsLoggedAndClosesConnection(ServiceContext testContext) { + var testLogger = new TestApplicationErrorLogger(); + testContext.Log = new KestrelTrace(testLogger); + using (var server = new TestServer(async frame => { - var onStartingException = new Exception(); - var response = frame.Get(); - response.OnStarting(_ => + response.OnCompleted(_ => { - throw onStartingException; + throw new Exception(); }, null); response.Headers.Clear(); response.Headers["Content-Length"] = new[] { "11" }; - var writeException = await Assert.ThrowsAsync(async () => - await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); - - Assert.Same(onStartingException, writeException); - - // The second write should succeed since the OnStarting callback will not be called again - await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Exception!!"), 0, 11); + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); }, testContext)) { using (var connection = new TestConnection()) @@ -838,12 +843,14 @@ namespace Microsoft.AspNet.Server.KestrelTests "GET / HTTP/1.1", "", ""); - await connection.Receive( + await connection.ReceiveEnd( "HTTP/1.1 200 OK", "Content-Length: 11", "", - "Exception!!"); ; + "Hello World"); } + + Assert.Equal(1, testLogger.ApplicationErrorsLogged); } }