From f29dd60999a2800cd2d2901b72ef29d3271d1a01 Mon Sep 17 00:00:00 2001 From: Cesar Blum Silveira Date: Sun, 10 Apr 2016 21:14:08 -0700 Subject: [PATCH] Fix connection termination issues when using connection filters (#737, #747). - If we're done before the client sends a FIN, force a FIN into the raw SocketInput so the task in FileteredStreamAdapter finishes gracefully and we dispose everything in proper order. - If there's an error while writing to a stream (like ObjectDisposedException), log it once and prevent further write attempts. This means the client closed the connection while we were still writing output. - This also fixes a bug related to the point above, where memory blocks were being leaked instead of returned to the pool (because we weren't catching the exception from Write()). --- .../Filter/FilteredStreamAdapter.cs | 27 ++++++-- .../Filter/SocketInputStream.cs | 2 +- .../Filter/StreamSocketOutput.cs | 68 ++++++++++++++++--- .../Http/Connection.cs | 29 +++++--- .../Http/SocketInput.cs | 6 ++ .../ChunkedResponseTests.cs | 59 +++++++++++----- .../EngineTests.cs | 32 +++++++++ .../StreamSocketOutputTests.cs | 2 +- .../TestHelpers/TestApplicationErrorLogger.cs | 7 +- 9 files changed, 185 insertions(+), 47 deletions(-) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Filter/FilteredStreamAdapter.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Filter/FilteredStreamAdapter.cs index 2fa981ceb4..357f6175f8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Filter/FilteredStreamAdapter.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Filter/FilteredStreamAdapter.cs @@ -10,23 +10,27 @@ using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Server.Kestrel.Filter { - public class FilteredStreamAdapter + public class FilteredStreamAdapter : IDisposable { + private readonly string _connectionId; private readonly Stream _filteredStream; private readonly Stream _socketInputStream; private readonly IKestrelTrace _log; private readonly MemoryPool _memory; private MemoryPoolBlock _block; + private bool _aborted = false; public FilteredStreamAdapter( + string connectionId, Stream filteredStream, MemoryPool memory, IKestrelTrace logger, IThreadPool threadPool) { SocketInput = new SocketInput(memory, threadPool); - SocketOutput = new StreamSocketOutput(filteredStream, memory); + SocketOutput = new StreamSocketOutput(connectionId, filteredStream, memory, logger); + _connectionId = connectionId; _log = logger; _filteredStream = filteredStream; _socketInputStream = new SocketInputStream(SocketInput); @@ -37,16 +41,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter public ISocketOutput SocketOutput { get; private set; } - public void ReadInput() + public Task ReadInputAsync() { _block = _memory.Lease(); // Use pooled block for copy - _filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) => + return _filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) => { ((FilteredStreamAdapter)state).OnStreamClose(task); }, this); } + public void Abort() + { + _aborted = true; + } + + public void Dispose() + { + SocketInput.Dispose(); + } + private void OnStreamClose(Task copyAsyncTask) { _memory.Return(_block); @@ -61,10 +75,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter SocketInput.AbortAwaiting(); _log.LogError("FilteredStreamAdapter.CopyToAsync canceled."); } + else if (_aborted) + { + SocketInput.AbortAwaiting(); + } try { - _filteredStream.Dispose(); _socketInputStream.Dispose(); } catch (Exception ex) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Filter/SocketInputStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Filter/SocketInputStream.cs index b10b1f9c69..d2d90b2f2d 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Filter/SocketInputStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Filter/SocketInputStream.cs @@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter protected override void Dispose(bool disposing) { // Close _socketInput with a fake zero-length write that will result in a zero-length read. - _socketInput.IncomingData(null, 0, 0); + _socketInput.IncomingFin(); base.Dispose(disposing); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Filter/StreamSocketOutput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Filter/StreamSocketOutput.cs index 7460cefdb1..f9ba1e42e0 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Filter/StreamSocketOutput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Filter/StreamSocketOutput.cs @@ -16,33 +16,57 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter private static readonly byte[] _endChunkBytes = Encoding.ASCII.GetBytes("\r\n"); private static readonly byte[] _nullBuffer = new byte[0]; + private readonly string _connectionId; private readonly Stream _outputStream; private readonly MemoryPool _memory; + private readonly IKestrelTrace _logger; private MemoryPoolBlock _producingBlock; + private bool _canWrite = true; + private object _writeLock = new object(); - public StreamSocketOutput(Stream outputStream, MemoryPool memory) + public StreamSocketOutput(string connectionId, Stream outputStream, MemoryPool memory, IKestrelTrace logger) { + _connectionId = connectionId; _outputStream = outputStream; _memory = memory; + _logger = logger; } public void Write(ArraySegment buffer, bool chunk) { lock (_writeLock) { - if (chunk && buffer.Array != null) + if (buffer.Count == 0 ) { - var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count); - _outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count); + return; } - _outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count); - - if (chunk && buffer.Array != null) + try { - _outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length); + if (!_canWrite) + { + return; + } + + if (chunk && buffer.Array != null) + { + var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count); + _outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count); + } + + _outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count); + + if (chunk && buffer.Array != null) + { + _outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length); + } + } + catch (Exception ex) + { + _canWrite = false; + _logger.ConnectionError(_connectionId, ex); } } } @@ -65,14 +89,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter var block = _producingBlock; while (block != end.Block) { - _outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count); + // If we don't handle an exception from _outputStream.Write() here, we'll leak memory blocks. + if (_canWrite) + { + try + { + _outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count); + } + catch (Exception ex) + { + _canWrite = false; + _logger.ConnectionError(_connectionId, ex); + } + } var returnBlock = block; block = block.Next; returnBlock.Pool.Return(returnBlock); } + + if (_canWrite) + { + try + { + _outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset); + } + catch (Exception ex) + { + _canWrite = false; + _logger.ConnectionError(_connectionId, ex); + } + } - _outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset); end.Block.Pool.Return(end.Block); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs index 71a40284b0..83ede18be7 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Connection.cs @@ -30,6 +30,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private Frame _frame; private ConnectionFilterContext _filterContext; private LibuvStream _libuvStream; + private FilteredStreamAdapter _filteredStreamAdapter; + private Task _readInputContinuation; private readonly SocketInput _rawSocketInput; private readonly SocketOutput _rawSocketOutput; @@ -175,13 +177,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http // Called on Libuv thread public virtual void OnSocketClosed() { - _rawSocketInput.Dispose(); - - // If a connection filter was applied there will be two SocketInputs. - // If a connection filter failed, SocketInput will be null. - if (SocketInput != null && SocketInput != _rawSocketInput) + if (_filteredStreamAdapter != null) { - SocketInput.Dispose(); + _filteredStreamAdapter.Abort(); + _rawSocketInput.IncomingFin(); + _readInputContinuation.ContinueWith((task, state) => + { + ((Connection)state)._filterContext.Connection.Dispose(); + ((Connection)state)._filteredStreamAdapter.Dispose(); + ((Connection)state)._rawSocketInput.Dispose(); + }, this); + } + else + { + _rawSocketInput.Dispose(); } lock (_stateLock) @@ -207,12 +216,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http if (_filterContext.Connection != _libuvStream) { - var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory, Log, ThreadPool); + _filteredStreamAdapter = new FilteredStreamAdapter(ConnectionId, _filterContext.Connection, Memory, Log, ThreadPool); - SocketInput = filteredStreamAdapter.SocketInput; - SocketOutput = filteredStreamAdapter.SocketOutput; + SocketInput = _filteredStreamAdapter.SocketInput; + SocketOutput = _filteredStreamAdapter.SocketOutput; - filteredStreamAdapter.ReadInput(); + _readInputContinuation = _filteredStreamAdapter.ReadInputAsync(); } else { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs index f57d3f6289..1588e83011 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/SocketInput.cs @@ -140,6 +140,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } } + public void IncomingFin() + { + // Force a FIN + IncomingData(null, 0, 0); + } + private void Complete() { var awaitableState = Interlocked.Exchange( diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs index 59b2e16e10..fd4035cc3a 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedResponseTests.cs @@ -5,14 +5,32 @@ using System; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel; using Xunit; namespace Microsoft.AspNetCore.Server.KestrelTests { public class ChunkedResponseTests { - [Fact] - public async Task ResponsesAreChunkedAutomatically() + public static TheoryData ConnectionFilterData + { + get + { + return new TheoryData + { + { + new TestServiceContext() + }, + { + new TestServiceContext(new PassThroughConnectionFilter()) + } + }; + } + } + + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task ResponsesAreChunkedAutomatically(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -20,7 +38,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests response.Headers.Clear(); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); - })) + }, testContext)) { using (var connection = new TestConnection(server.Port)) { @@ -43,8 +61,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Fact] - public async Task ZeroLengthWritesAreIgnored() + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task ZeroLengthWritesAreIgnored(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -53,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); await response.Body.WriteAsync(new byte[0], 0, 0); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); - })) + }, testContext)) { using (var connection = new TestConnection(server.Port)) { @@ -76,15 +95,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Fact] - public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite() + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { var response = httpContext.Response; response.Headers.Clear(); await response.Body.WriteAsync(new byte[0], 0, 0); - })) + }, testContext)) { using (var connection = new TestConnection(server.Port)) { @@ -103,8 +123,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Fact] - public async Task ConnectionClosedIfExeptionThrownAfterWrite() + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task ConnectionClosedIfExeptionThrownAfterWrite(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -112,7 +133,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests response.Headers.Clear(); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World!"), 0, 12); throw new Exception(); - })) + }, testContext)) { using (var connection = new TestConnection(server.Port)) { @@ -133,8 +154,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Fact] - public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite() + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -142,7 +164,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests response.Headers.Clear(); await response.Body.WriteAsync(new byte[0], 0, 0); throw new Exception(); - })) + }, testContext)) { using (var connection = new TestConnection(server.Port)) { @@ -162,8 +184,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } - [Fact] - public async Task WritesAreFlushedPriorToResponseCompletion() + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task WritesAreFlushedPriorToResponseCompletion(ServiceContext testContext) { var flushWh = new ManualResetEventSlim(); @@ -177,7 +200,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests flushWh.Wait(); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); - })) + }, testContext)) { using (var connection = new TestConnection(server.Port)) { diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index c068803b40..aadf9361c1 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -1018,5 +1018,37 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.True(registrationWh.Wait(1000)); } } + + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ServiceContext testContext) + { + var testLogger = new TestApplicationErrorLogger(); + testContext.Log = new KestrelTrace(testLogger); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + response.Headers.Clear(); + response.Headers["Content-Length"] = new[] { "11" }; + await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); + }, testContext)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.Send( + "GET / HTTP/1.0", + "", + ""); + await connection.ReceiveForcedEnd( + "HTTP/1.0 200 OK", + "Content-Length: 11", + "", + "Hello World"); + } + } + + Assert.Equal(0, testLogger.TotalErrorsLogged); + } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/StreamSocketOutputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/StreamSocketOutputTests.cs index e787ed8589..46c0384f30 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/StreamSocketOutputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/StreamSocketOutputTests.cs @@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests // As it calls ProduceStart with write immediate == true // This happens in WebSocket Upgrade over SSL - ISocketOutput socketOutput = new StreamSocketOutput(new ThrowsOnNullWriteStream(), null); + ISocketOutput socketOutput = new StreamSocketOutput("id", new ThrowsOnNullWriteStream(), null, new TestKestrelTrace()); // Should not throw socketOutput.Write(default(ArraySegment), true); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/TestApplicationErrorLogger.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/TestApplicationErrorLogger.cs index de8d462ed9..ea02bb3c64 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/TestApplicationErrorLogger.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/TestHelpers/TestApplicationErrorLogger.cs @@ -32,12 +32,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}"); #endif - TotalErrorsLogged++; - if (eventId.Id == ApplicationErrorEventId) { ApplicationErrorsLogged++; } + + if (logLevel == LogLevel.Error) + { + TotalErrorsLogged++; + } } } }