Fix connection termination issues when using connection filters (#737, #747).

- If we're done before the client sends a FIN, force a FIN into the raw
  SocketInput so the task in FileteredStreamAdapter finishes gracefully
  and we dispose everything in proper order.
- If there's an error while writing to a stream (like ObjectDisposedException),
  log it once and prevent further write attempts. This means the client closed
  the connection while we were still writing output.
- This also fixes a bug related to the point above, where memory blocks were
  being leaked instead of returned to the pool (because we weren't catching
  the exception from Write()).
This commit is contained in:
Cesar Blum Silveira 2016-04-10 21:14:08 -07:00
parent 33ad355114
commit f29dd60999
9 changed files with 185 additions and 47 deletions

View File

@ -10,23 +10,27 @@ using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Server.Kestrel.Filter namespace Microsoft.AspNetCore.Server.Kestrel.Filter
{ {
public class FilteredStreamAdapter public class FilteredStreamAdapter : IDisposable
{ {
private readonly string _connectionId;
private readonly Stream _filteredStream; private readonly Stream _filteredStream;
private readonly Stream _socketInputStream; private readonly Stream _socketInputStream;
private readonly IKestrelTrace _log; private readonly IKestrelTrace _log;
private readonly MemoryPool _memory; private readonly MemoryPool _memory;
private MemoryPoolBlock _block; private MemoryPoolBlock _block;
private bool _aborted = false;
public FilteredStreamAdapter( public FilteredStreamAdapter(
string connectionId,
Stream filteredStream, Stream filteredStream,
MemoryPool memory, MemoryPool memory,
IKestrelTrace logger, IKestrelTrace logger,
IThreadPool threadPool) IThreadPool threadPool)
{ {
SocketInput = new SocketInput(memory, threadPool); SocketInput = new SocketInput(memory, threadPool);
SocketOutput = new StreamSocketOutput(filteredStream, memory); SocketOutput = new StreamSocketOutput(connectionId, filteredStream, memory, logger);
_connectionId = connectionId;
_log = logger; _log = logger;
_filteredStream = filteredStream; _filteredStream = filteredStream;
_socketInputStream = new SocketInputStream(SocketInput); _socketInputStream = new SocketInputStream(SocketInput);
@ -37,16 +41,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
public ISocketOutput SocketOutput { get; private set; } public ISocketOutput SocketOutput { get; private set; }
public void ReadInput() public Task ReadInputAsync()
{ {
_block = _memory.Lease(); _block = _memory.Lease();
// Use pooled block for copy // Use pooled block for copy
_filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) => return _filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) =>
{ {
((FilteredStreamAdapter)state).OnStreamClose(task); ((FilteredStreamAdapter)state).OnStreamClose(task);
}, this); }, this);
} }
public void Abort()
{
_aborted = true;
}
public void Dispose()
{
SocketInput.Dispose();
}
private void OnStreamClose(Task copyAsyncTask) private void OnStreamClose(Task copyAsyncTask)
{ {
_memory.Return(_block); _memory.Return(_block);
@ -61,10 +75,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
SocketInput.AbortAwaiting(); SocketInput.AbortAwaiting();
_log.LogError("FilteredStreamAdapter.CopyToAsync canceled."); _log.LogError("FilteredStreamAdapter.CopyToAsync canceled.");
} }
else if (_aborted)
{
SocketInput.AbortAwaiting();
}
try try
{ {
_filteredStream.Dispose();
_socketInputStream.Dispose(); _socketInputStream.Dispose();
} }
catch (Exception ex) catch (Exception ex)

View File

@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
protected override void Dispose(bool disposing) protected override void Dispose(bool disposing)
{ {
// Close _socketInput with a fake zero-length write that will result in a zero-length read. // Close _socketInput with a fake zero-length write that will result in a zero-length read.
_socketInput.IncomingData(null, 0, 0); _socketInput.IncomingFin();
base.Dispose(disposing); base.Dispose(disposing);
} }
} }

View File

@ -16,33 +16,57 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
private static readonly byte[] _endChunkBytes = Encoding.ASCII.GetBytes("\r\n"); private static readonly byte[] _endChunkBytes = Encoding.ASCII.GetBytes("\r\n");
private static readonly byte[] _nullBuffer = new byte[0]; private static readonly byte[] _nullBuffer = new byte[0];
private readonly string _connectionId;
private readonly Stream _outputStream; private readonly Stream _outputStream;
private readonly MemoryPool _memory; private readonly MemoryPool _memory;
private readonly IKestrelTrace _logger;
private MemoryPoolBlock _producingBlock; private MemoryPoolBlock _producingBlock;
private bool _canWrite = true;
private object _writeLock = new object(); private object _writeLock = new object();
public StreamSocketOutput(Stream outputStream, MemoryPool memory) public StreamSocketOutput(string connectionId, Stream outputStream, MemoryPool memory, IKestrelTrace logger)
{ {
_connectionId = connectionId;
_outputStream = outputStream; _outputStream = outputStream;
_memory = memory; _memory = memory;
_logger = logger;
} }
public void Write(ArraySegment<byte> buffer, bool chunk) public void Write(ArraySegment<byte> buffer, bool chunk)
{ {
lock (_writeLock) lock (_writeLock)
{ {
if (chunk && buffer.Array != null) if (buffer.Count == 0 )
{ {
var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count); return;
_outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count);
} }
_outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count); try
if (chunk && buffer.Array != null)
{ {
_outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length); if (!_canWrite)
{
return;
}
if (chunk && buffer.Array != null)
{
var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count);
_outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count);
}
_outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count);
if (chunk && buffer.Array != null)
{
_outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length);
}
}
catch (Exception ex)
{
_canWrite = false;
_logger.ConnectionError(_connectionId, ex);
} }
} }
} }
@ -65,14 +89,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
var block = _producingBlock; var block = _producingBlock;
while (block != end.Block) while (block != end.Block)
{ {
_outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count); // If we don't handle an exception from _outputStream.Write() here, we'll leak memory blocks.
if (_canWrite)
{
try
{
_outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count);
}
catch (Exception ex)
{
_canWrite = false;
_logger.ConnectionError(_connectionId, ex);
}
}
var returnBlock = block; var returnBlock = block;
block = block.Next; block = block.Next;
returnBlock.Pool.Return(returnBlock); returnBlock.Pool.Return(returnBlock);
} }
if (_canWrite)
{
try
{
_outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset);
}
catch (Exception ex)
{
_canWrite = false;
_logger.ConnectionError(_connectionId, ex);
}
}
_outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset);
end.Block.Pool.Return(end.Block); end.Block.Pool.Return(end.Block);
} }
} }

View File

@ -30,6 +30,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private Frame _frame; private Frame _frame;
private ConnectionFilterContext _filterContext; private ConnectionFilterContext _filterContext;
private LibuvStream _libuvStream; private LibuvStream _libuvStream;
private FilteredStreamAdapter _filteredStreamAdapter;
private Task _readInputContinuation;
private readonly SocketInput _rawSocketInput; private readonly SocketInput _rawSocketInput;
private readonly SocketOutput _rawSocketOutput; private readonly SocketOutput _rawSocketOutput;
@ -175,13 +177,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
// Called on Libuv thread // Called on Libuv thread
public virtual void OnSocketClosed() public virtual void OnSocketClosed()
{ {
_rawSocketInput.Dispose(); if (_filteredStreamAdapter != null)
// If a connection filter was applied there will be two SocketInputs.
// If a connection filter failed, SocketInput will be null.
if (SocketInput != null && SocketInput != _rawSocketInput)
{ {
SocketInput.Dispose(); _filteredStreamAdapter.Abort();
_rawSocketInput.IncomingFin();
_readInputContinuation.ContinueWith((task, state) =>
{
((Connection)state)._filterContext.Connection.Dispose();
((Connection)state)._filteredStreamAdapter.Dispose();
((Connection)state)._rawSocketInput.Dispose();
}, this);
}
else
{
_rawSocketInput.Dispose();
} }
lock (_stateLock) lock (_stateLock)
@ -207,12 +216,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
if (_filterContext.Connection != _libuvStream) if (_filterContext.Connection != _libuvStream)
{ {
var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory, Log, ThreadPool); _filteredStreamAdapter = new FilteredStreamAdapter(ConnectionId, _filterContext.Connection, Memory, Log, ThreadPool);
SocketInput = filteredStreamAdapter.SocketInput; SocketInput = _filteredStreamAdapter.SocketInput;
SocketOutput = filteredStreamAdapter.SocketOutput; SocketOutput = _filteredStreamAdapter.SocketOutput;
filteredStreamAdapter.ReadInput(); _readInputContinuation = _filteredStreamAdapter.ReadInputAsync();
} }
else else
{ {

View File

@ -140,6 +140,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
} }
} }
public void IncomingFin()
{
// Force a FIN
IncomingData(null, 0, 0);
}
private void Complete() private void Complete()
{ {
var awaitableState = Interlocked.Exchange( var awaitableState = Interlocked.Exchange(

View File

@ -5,14 +5,32 @@ using System;
using System.Text; using System.Text;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.Server.KestrelTests namespace Microsoft.AspNetCore.Server.KestrelTests
{ {
public class ChunkedResponseTests public class ChunkedResponseTests
{ {
[Fact] public static TheoryData<ServiceContext> ConnectionFilterData
public async Task ResponsesAreChunkedAutomatically() {
get
{
return new TheoryData<ServiceContext>
{
{
new TestServiceContext()
},
{
new TestServiceContext(new PassThroughConnectionFilter())
}
};
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ResponsesAreChunkedAutomatically(ServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -20,7 +38,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.Headers.Clear(); response.Headers.Clear();
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6);
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
})) }, testContext))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
@ -43,8 +61,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
[Fact] [Theory]
public async Task ZeroLengthWritesAreIgnored() [MemberData(nameof(ConnectionFilterData))]
public async Task ZeroLengthWritesAreIgnored(ServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -53,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6);
await response.Body.WriteAsync(new byte[0], 0, 0); await response.Body.WriteAsync(new byte[0], 0, 0);
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
})) }, testContext))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
@ -76,15 +95,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
[Fact] [Theory]
public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite() [MemberData(nameof(ConnectionFilterData))]
public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite(ServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
var response = httpContext.Response; var response = httpContext.Response;
response.Headers.Clear(); response.Headers.Clear();
await response.Body.WriteAsync(new byte[0], 0, 0); await response.Body.WriteAsync(new byte[0], 0, 0);
})) }, testContext))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
@ -103,8 +123,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
[Fact] [Theory]
public async Task ConnectionClosedIfExeptionThrownAfterWrite() [MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosedIfExeptionThrownAfterWrite(ServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -112,7 +133,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.Headers.Clear(); response.Headers.Clear();
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World!"), 0, 12); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World!"), 0, 12);
throw new Exception(); throw new Exception();
})) }, testContext))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
@ -133,8 +154,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
[Fact] [Theory]
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite() [MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite(ServiceContext testContext)
{ {
using (var server = new TestServer(async httpContext => using (var server = new TestServer(async httpContext =>
{ {
@ -142,7 +164,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.Headers.Clear(); response.Headers.Clear();
await response.Body.WriteAsync(new byte[0], 0, 0); await response.Body.WriteAsync(new byte[0], 0, 0);
throw new Exception(); throw new Exception();
})) }, testContext))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {
@ -162,8 +184,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
} }
} }
[Fact] [Theory]
public async Task WritesAreFlushedPriorToResponseCompletion() [MemberData(nameof(ConnectionFilterData))]
public async Task WritesAreFlushedPriorToResponseCompletion(ServiceContext testContext)
{ {
var flushWh = new ManualResetEventSlim(); var flushWh = new ManualResetEventSlim();
@ -177,7 +200,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
flushWh.Wait(); flushWh.Wait();
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6); await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
})) }, testContext))
{ {
using (var connection = new TestConnection(server.Port)) using (var connection = new TestConnection(server.Port))
{ {

View File

@ -1018,5 +1018,37 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.True(registrationWh.Wait(1000)); Assert.True(registrationWh.Wait(1000));
} }
} }
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ServiceContext testContext)
{
var testLogger = new TestApplicationErrorLogger();
testContext.Log = new KestrelTrace(testLogger);
using (var server = new TestServer(async httpContext =>
{
var response = httpContext.Response;
response.Headers.Clear();
response.Headers["Content-Length"] = new[] { "11" };
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11);
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
await connection.Send(
"GET / HTTP/1.0",
"",
"");
await connection.ReceiveForcedEnd(
"HTTP/1.0 200 OK",
"Content-Length: 11",
"",
"Hello World");
}
}
Assert.Equal(0, testLogger.TotalErrorsLogged);
}
} }
} }

View File

@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
// As it calls ProduceStart with write immediate == true // As it calls ProduceStart with write immediate == true
// This happens in WebSocket Upgrade over SSL // This happens in WebSocket Upgrade over SSL
ISocketOutput socketOutput = new StreamSocketOutput(new ThrowsOnNullWriteStream(), null); ISocketOutput socketOutput = new StreamSocketOutput("id", new ThrowsOnNullWriteStream(), null, new TestKestrelTrace());
// Should not throw // Should not throw
socketOutput.Write(default(ArraySegment<byte>), true); socketOutput.Write(default(ArraySegment<byte>), true);

View File

@ -32,12 +32,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}"); Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}");
#endif #endif
TotalErrorsLogged++;
if (eventId.Id == ApplicationErrorEventId) if (eventId.Id == ApplicationErrorEventId)
{ {
ApplicationErrorsLogged++; ApplicationErrorsLogged++;
} }
if (logLevel == LogLevel.Error)
{
TotalErrorsLogged++;
}
} }
} }
} }