From b73e42b617a84f425a7c43ffd6198e6d539f672e Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Tue, 3 Nov 2015 15:34:29 -0800 Subject: [PATCH] Abort request on any write failure --- .../Http/Connection.cs | 18 +++++- .../Http/SocketOutput.cs | 24 +++++-- .../EngineTests.cs | 64 +++++++++++++++++++ .../SocketOutputTests.cs | 8 +-- 4 files changed, 103 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs index 182bc98654..4e95d3d1ca 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs @@ -42,7 +42,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http _connectionId = Interlocked.Increment(ref _lastConnectionId); _rawSocketInput = new SocketInput(Memory2); - _rawSocketOutput = new SocketOutput(Thread, _socket, _connectionId, Log); + _rawSocketOutput = new SocketOutput(Thread, _socket, this, _connectionId, Log); } public void Start() @@ -100,6 +100,20 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } + public void Abort() + { + if (_frame != null) + { + // Frame.Abort calls user code while this method is always + // called from a libuv thread. + ThreadPool.QueueUserWorkItem(state => + { + var connection = (Connection)this; + connection._frame.Abort(); + }, this); + } + } + private void ApplyConnectionFilter() { var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory2, Log); @@ -157,7 +171,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http if (errorDone) { - _frame.Abort(); + Abort(); } } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs index d2534206e1..d1e353388f 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketOutput.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.Server.Kestrel.Infrastructure; @@ -18,6 +19,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http private readonly KestrelThread _thread; private readonly UvStreamHandle _socket; + private readonly Connection _connection; private readonly long _connectionId; private readonly IKestrelTrace _log; @@ -33,10 +35,16 @@ namespace Microsoft.AspNet.Server.Kestrel.Http private WriteContext _nextWriteContext; private readonly Queue> _tasksPending; - public SocketOutput(KestrelThread thread, UvStreamHandle socket, long connectionId, IKestrelTrace log) + public SocketOutput( + KestrelThread thread, + UvStreamHandle socket, + Connection connection, + long connectionId, + IKestrelTrace log) { _thread = thread; _socket = socket; + _connection = connection; _connectionId = connectionId; _log = log; _tasksPending = new Queue>(); @@ -176,10 +184,16 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { _log.ConnectionWriteCallback(_connectionId, status); + if (error != null) + { + _lastWriteError = new IOException(error.Message, error); + + // Abort the connection for any failed write. + _connection.Abort(); + } + lock (_lockObj) { - _lastWriteError = error; - if (_nextWriteContext != null) { ScheduleWrite(); @@ -208,7 +222,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http _numBytesPreCompleted += bytesToWrite; bytesLeftToBuffer -= bytesToWrite; - if (error == null) + if (_lastWriteError == null) { ThreadPool.QueueUserWorkItem( (o) => ((TaskCompletionSource)o).SetResult(null), @@ -218,7 +232,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { // error is closure captured ThreadPool.QueueUserWorkItem( - (o) => ((TaskCompletionSource)o).SetException(error), + (o) => ((TaskCompletionSource)o).SetException(_lastWriteError), tcs); } } diff --git a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs index 48db5b9f21..f19b2b3077 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs @@ -967,6 +967,8 @@ namespace Microsoft.AspNet.Server.KestrelTests readTcs.SetException(ex); throw; } + + readTcs.SetCanceled(); } }, testContext)) { @@ -997,6 +999,68 @@ namespace Microsoft.AspNet.Server.KestrelTests Assert.Equal(2, abortedRequestId); } + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task FailedWritesResultInAbortedRequest(ServiceContext testContext) + { + var writeTcs = new TaskCompletionSource(); + var registrationWh = new ManualResetEventSlim(); + var connectionCloseWh = new ManualResetEventSlim(); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + var request = httpContext.Request; + var lifetime = httpContext.Features.Get(); + + lifetime.RequestAborted.Register(() => registrationWh.Set()); + + await request.Body.CopyToAsync(Stream.Null); + connectionCloseWh.Wait(); + + response.Headers.Clear(); + response.Headers["Content-Length"] = new[] { "5" }; + + try + { + // Ensure write is long enough to disable write-behind buffering + for (int i = 0; i < 10; i++) + { + await response.WriteAsync(new string('a', 65537)); + } + } + catch (Exception ex) + { + writeTcs.SetException(ex); + + // Give a chance for RequestAborted to trip before the app completes + registrationWh.Wait(1000); + + throw; + } + + writeTcs.SetCanceled(); + }, testContext)) + { + using (var connection = new TestConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Content-Length: 5", + "", + "Hello"); + // Don't wait to receive the response. Just close the socket. + } + + connectionCloseWh.Set(); + + // Write failed + await Assert.ThrowsAsync(async () => await writeTcs.Task); + // RequestAborted tripped + Assert.True(registrationWh.Wait(200)); + } + } + private class TestApplicationErrorLogger : ILogger { public int ApplicationErrorsLogged { get; set; } diff --git a/test/Microsoft.AspNet.Server.KestrelTests/SocketOutputTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/SocketOutputTests.cs index d8d4f5cee7..69224073dc 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/SocketOutputTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/SocketOutputTests.cs @@ -40,7 +40,7 @@ namespace Microsoft.AspNet.Server.KestrelTests var kestrelThread = kestrelEngine.Threads[0]; var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); var trace = new KestrelTrace(new TestKestrelTrace()); - var socketOutput = new SocketOutput(kestrelThread, socket, 0, trace); + var socketOutput = new SocketOutput(kestrelThread, socket, null, 0, trace); // I doubt _maxBytesPreCompleted will ever be over a MB. If it is, we should change this test. var bufferSize = 1048576; @@ -85,7 +85,7 @@ namespace Microsoft.AspNet.Server.KestrelTests var kestrelThread = kestrelEngine.Threads[0]; var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); var trace = new KestrelTrace(new TestKestrelTrace()); - var socketOutput = new SocketOutput(kestrelThread, socket, 0, trace); + var socketOutput = new SocketOutput(kestrelThread, socket, null, 0, trace); var bufferSize = maxBytesPreCompleted; var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize); @@ -140,7 +140,7 @@ namespace Microsoft.AspNet.Server.KestrelTests var kestrelThread = kestrelEngine.Threads[0]; var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); var trace = new KestrelTrace(new TestKestrelTrace()); - var socketOutput = new SocketOutput(kestrelThread, socket, 0, trace); + var socketOutput = new SocketOutput(kestrelThread, socket, null, 0, trace); var bufferSize = maxBytesPreCompleted; @@ -219,7 +219,7 @@ namespace Microsoft.AspNet.Server.KestrelTests var kestrelThread = kestrelEngine.Threads[0]; var socket = new MockSocket(kestrelThread.Loop.ThreadId, new TestKestrelTrace()); var trace = new KestrelTrace(new TestKestrelTrace()); - var socketOutput = new SocketOutput(kestrelThread, socket, 0, trace); + var socketOutput = new SocketOutput(kestrelThread, socket, null, 0, trace); var bufferSize = maxBytesPreCompleted; var buffer = new ArraySegment(new byte[bufferSize], 0, bufferSize);