From 4f0480a4d02b09e32e36dbd2f07f1a2d94530fba Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 20 Jul 2015 12:26:39 -0700 Subject: [PATCH] Gracefully handle exceptions thrown from OnStarting callbacks - If OnStarting is being called after the app func has completed, return a 500. - If Onstarting is being called due to a call to write, throw from write. --- .../Http/Frame.cs | 19 ++-- .../EngineTests.cs | 87 +++++++++++++++++++ 2 files changed, 99 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index fd584323df..7dc2603bc3 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -25,7 +25,6 @@ namespace Microsoft.AspNet.Server.Kestrel.Http static Encoding _ascii = Encoding.ASCII; Mode _mode; - private bool _resultStarted; private bool _responseStarted; private bool _keepAlive; private readonly FrameRequestHeaders _requestHeaders = new FrameRequestHeaders(); @@ -246,6 +245,14 @@ namespace Microsoft.AspNet.Server.Kestrel.Http try { await Application.Invoke(this).ConfigureAwait(false); + + // Trigger FireOnStarting if ProduceStart hasn't been called yet. + // We call it here, so it can go through our normal error handling + // and respond with a 500 if an OnStarting callback throws. + if (!_responseStarted) + { + FireOnStarting(); + } } catch (Exception ex) { @@ -284,7 +291,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public void ProduceContinue() { - if (_resultStarted) return; + if (_responseStarted) return; string[] expect; if (HttpVersion.Equals("HTTP/1.1") && @@ -306,11 +313,9 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public void ProduceStart(bool immediate = true) { - if (_resultStarted) return; - _resultStarted = true; - + // ProduceStart shouldn't no-op in the future just b/c FireOnStarting throws. + if (_responseStarted) return; FireOnStarting(); - _responseStarted = true; var status = ReasonPhrases.ToStatus(StatusCode, ReasonPhrase); @@ -334,7 +339,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { if (ex != null) { - if (_resultStarted) + if (_responseStarted) { // We can no longer respond with a 500, so we simply close the connection. ConnectionControl.End(ProduceEndType.SocketDisconnect); diff --git a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs index 771471f632..27526edd83 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs @@ -535,5 +535,92 @@ namespace Microsoft.AspNet.Server.KestrelTests } } } + + [Fact] + public async Task ThrowingInOnStartingResultsIn500Response() + { + using (var server = new TestServer(frame => + { + frame.OnStarting(_ => + { + throw new Exception(); + }, null); + + frame.ResponseHeaders.Clear(); + frame.ResponseHeaders["Content-Length"] = new[] { "11" }; + + // If we write to the response stream, we will not get a 500. + + return Task.FromResult(null); + })) + { + using (var connection = new TestConnection()) + { + await connection.SendEnd( + "GET / HTTP/1.1", + "", + "GET / HTTP/1.1", + "Connection: close", + "", + ""); + await connection.Receive( + "HTTP/1.1 500 Internal Server Error", + ""); + await connection.ReceiveStartsWith("Date:"); + await connection.Receive( + "Content-Length: 0", + "Server: Kestrel", + "", + "HTTP/1.1 500 Internal Server Error", + ""); + await connection.ReceiveStartsWith("Date:"); + await connection.ReceiveEnd( + "Content-Length: 0", + "Server: Kestrel", + "Connection: close", + "", + ""); + } + } + } + + [Fact] + public async Task ThrowingInOnStartingResultsInFailedWrites() + { + using (var server = new TestServer(async frame => + { + var onStartingException = new Exception(); + + frame.OnStarting(_ => + { + throw onStartingException; + }, null); + + frame.ResponseHeaders.Clear(); + frame.ResponseHeaders["Content-Length"] = new[] { "11" }; + + var writeException = await Assert.ThrowsAsync(async () => + await frame.ResponseBody.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11)); + + Assert.Same(onStartingException, writeException); + + // The second write should succeed since the OnStarting callback will not be called again + await frame.ResponseBody.WriteAsync(Encoding.ASCII.GetBytes("Exception!!"), 0, 11); + })) + { + using (var connection = new TestConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + "Content-Length: 11", + "", + "Exception!!"); ; + } + } + } } }