Fix connection termination issues when using connection filters (#737, #747).

- If we're done before the client sends a FIN, force a FIN into the raw
  SocketInput so the task in FileteredStreamAdapter finishes gracefully
  and we dispose everything in proper order.
- If there's an error while writing to a stream (like ObjectDisposedException),
  log it once and prevent further write attempts. This means the client closed
  the connection while we were still writing output.
- This also fixes a bug related to the point above, where memory blocks were
  being leaked instead of returned to the pool (because we weren't catching
  the exception from Write()).
This commit is contained in:
Cesar Blum Silveira 2016-04-10 21:14:08 -07:00
parent 33ad355114
commit f29dd60999
9 changed files with 185 additions and 47 deletions

View File

@ -10,23 +10,27 @@ using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Server.Kestrel.Filter
{
public class FilteredStreamAdapter
public class FilteredStreamAdapter : IDisposable
{
private readonly string _connectionId;
private readonly Stream _filteredStream;
private readonly Stream _socketInputStream;
private readonly IKestrelTrace _log;
private readonly MemoryPool _memory;
private MemoryPoolBlock _block;
private bool _aborted = false;
public FilteredStreamAdapter(
string connectionId,
Stream filteredStream,
MemoryPool memory,
IKestrelTrace logger,
IThreadPool threadPool)
{
SocketInput = new SocketInput(memory, threadPool);
SocketOutput = new StreamSocketOutput(filteredStream, memory);
SocketOutput = new StreamSocketOutput(connectionId, filteredStream, memory, logger);
_connectionId = connectionId;
_log = logger;
_filteredStream = filteredStream;
_socketInputStream = new SocketInputStream(SocketInput);
@ -37,16 +41,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
public ISocketOutput SocketOutput { get; private set; }
public void ReadInput()
public Task ReadInputAsync()
{
_block = _memory.Lease();
// Use pooled block for copy
_filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) =>
return _filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) =>
{
((FilteredStreamAdapter)state).OnStreamClose(task);
}, this);
}
public void Abort()
{
_aborted = true;
}
public void Dispose()
{
SocketInput.Dispose();
}
private void OnStreamClose(Task copyAsyncTask)
{
_memory.Return(_block);
@ -61,10 +75,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
SocketInput.AbortAwaiting();
_log.LogError("FilteredStreamAdapter.CopyToAsync canceled.");
}
else if (_aborted)
{
SocketInput.AbortAwaiting();
}
try
{
_filteredStream.Dispose();
_socketInputStream.Dispose();
}
catch (Exception ex)

View File

@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
protected override void Dispose(bool disposing)
{
// Close _socketInput with a fake zero-length write that will result in a zero-length read.
_socketInput.IncomingData(null, 0, 0);
_socketInput.IncomingFin();
base.Dispose(disposing);
}
}

View File

@ -16,33 +16,57 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
private static readonly byte[] _endChunkBytes = Encoding.ASCII.GetBytes("\r\n");
private static readonly byte[] _nullBuffer = new byte[0];
private readonly string _connectionId;
private readonly Stream _outputStream;
private readonly MemoryPool _memory;
private readonly IKestrelTrace _logger;
private MemoryPoolBlock _producingBlock;
private bool _canWrite = true;
private object _writeLock = new object();
public StreamSocketOutput(Stream outputStream, MemoryPool memory)
public StreamSocketOutput(string connectionId, Stream outputStream, MemoryPool memory, IKestrelTrace logger)
{
_connectionId = connectionId;
_outputStream = outputStream;
_memory = memory;
_logger = logger;
}
public void Write(ArraySegment<byte> buffer, bool chunk)
{
lock (_writeLock)
{
if (chunk && buffer.Array != null)
if (buffer.Count == 0 )
{
var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count);
_outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count);
return;
}
_outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count);
if (chunk && buffer.Array != null)
try
{
_outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length);
if (!_canWrite)
{
return;
}
if (chunk && buffer.Array != null)
{
var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count);
_outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count);
}
_outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count);
if (chunk && buffer.Array != null)
{
_outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length);
}
}
catch (Exception ex)
{
_canWrite = false;
_logger.ConnectionError(_connectionId, ex);
}
}
}
@ -65,14 +89,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
var block = _producingBlock;
while (block != end.Block)
{
_outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count);
// If we don't handle an exception from _outputStream.Write() here, we'll leak memory blocks.
if (_canWrite)
{
try
{
_outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count);
}
catch (Exception ex)
{
_canWrite = false;
_logger.ConnectionError(_connectionId, ex);
}
}
var returnBlock = block;
block = block.Next;
returnBlock.Pool.Return(returnBlock);
}
if (_canWrite)
{
try
{
_outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset);
}
catch (Exception ex)
{
_canWrite = false;
_logger.ConnectionError(_connectionId, ex);
}
}
_outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset);
end.Block.Pool.Return(end.Block);
}
}

View File

@ -30,6 +30,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private Frame _frame;
private ConnectionFilterContext _filterContext;
private LibuvStream _libuvStream;
private FilteredStreamAdapter _filteredStreamAdapter;
private Task _readInputContinuation;
private readonly SocketInput _rawSocketInput;
private readonly SocketOutput _rawSocketOutput;
@ -175,13 +177,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
// Called on Libuv thread
public virtual void OnSocketClosed()
{
_rawSocketInput.Dispose();
// If a connection filter was applied there will be two SocketInputs.
// If a connection filter failed, SocketInput will be null.
if (SocketInput != null && SocketInput != _rawSocketInput)
if (_filteredStreamAdapter != null)
{
SocketInput.Dispose();
_filteredStreamAdapter.Abort();
_rawSocketInput.IncomingFin();
_readInputContinuation.ContinueWith((task, state) =>
{
((Connection)state)._filterContext.Connection.Dispose();
((Connection)state)._filteredStreamAdapter.Dispose();
((Connection)state)._rawSocketInput.Dispose();
}, this);
}
else
{
_rawSocketInput.Dispose();
}
lock (_stateLock)
@ -207,12 +216,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
if (_filterContext.Connection != _libuvStream)
{
var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory, Log, ThreadPool);
_filteredStreamAdapter = new FilteredStreamAdapter(ConnectionId, _filterContext.Connection, Memory, Log, ThreadPool);
SocketInput = filteredStreamAdapter.SocketInput;
SocketOutput = filteredStreamAdapter.SocketOutput;
SocketInput = _filteredStreamAdapter.SocketInput;
SocketOutput = _filteredStreamAdapter.SocketOutput;
filteredStreamAdapter.ReadInput();
_readInputContinuation = _filteredStreamAdapter.ReadInputAsync();
}
else
{

View File

@ -140,6 +140,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
}
}
public void IncomingFin()
{
// Force a FIN
IncomingData(null, 0, 0);
}
private void Complete()
{
var awaitableState = Interlocked.Exchange(

View File

@ -5,14 +5,32 @@ using System;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel;
using Xunit;
namespace Microsoft.AspNetCore.Server.KestrelTests
{
public class ChunkedResponseTests
{
[Fact]
public async Task ResponsesAreChunkedAutomatically()
public static TheoryData<ServiceContext> ConnectionFilterData
{
get
{
return new TheoryData<ServiceContext>
{
{
new TestServiceContext()
},
{
new TestServiceContext(new PassThroughConnectionFilter())
}
};
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ResponsesAreChunkedAutomatically(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
@ -20,7 +38,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.Headers.Clear();
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6);
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
}))
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
@ -43,8 +61,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task ZeroLengthWritesAreIgnored()
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ZeroLengthWritesAreIgnored(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
@ -53,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6);
await response.Body.WriteAsync(new byte[0], 0, 0);
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
}))
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
@ -76,15 +95,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite()
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
var response = httpContext.Response;
response.Headers.Clear();
await response.Body.WriteAsync(new byte[0], 0, 0);
}))
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
@ -103,8 +123,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task ConnectionClosedIfExeptionThrownAfterWrite()
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosedIfExeptionThrownAfterWrite(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
@ -112,7 +133,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.Headers.Clear();
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World!"), 0, 12);
throw new Exception();
}))
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
@ -133,8 +154,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite()
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
@ -142,7 +164,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
response.Headers.Clear();
await response.Body.WriteAsync(new byte[0], 0, 0);
throw new Exception();
}))
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
@ -162,8 +184,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task WritesAreFlushedPriorToResponseCompletion()
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task WritesAreFlushedPriorToResponseCompletion(ServiceContext testContext)
{
var flushWh = new ManualResetEventSlim();
@ -177,7 +200,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
flushWh.Wait();
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
}))
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{

View File

@ -1018,5 +1018,37 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.True(registrationWh.Wait(1000));
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ServiceContext testContext)
{
var testLogger = new TestApplicationErrorLogger();
testContext.Log = new KestrelTrace(testLogger);
using (var server = new TestServer(async httpContext =>
{
var response = httpContext.Response;
response.Headers.Clear();
response.Headers["Content-Length"] = new[] { "11" };
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11);
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
await connection.Send(
"GET / HTTP/1.0",
"",
"");
await connection.ReceiveForcedEnd(
"HTTP/1.0 200 OK",
"Content-Length: 11",
"",
"Hello World");
}
}
Assert.Equal(0, testLogger.TotalErrorsLogged);
}
}
}

View File

@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
// As it calls ProduceStart with write immediate == true
// This happens in WebSocket Upgrade over SSL
ISocketOutput socketOutput = new StreamSocketOutput(new ThrowsOnNullWriteStream(), null);
ISocketOutput socketOutput = new StreamSocketOutput("id", new ThrowsOnNullWriteStream(), null, new TestKestrelTrace());
// Should not throw
socketOutput.Write(default(ArraySegment<byte>), true);

View File

@ -32,12 +32,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}");
#endif
TotalErrorsLogged++;
if (eventId.Id == ApplicationErrorEventId)
{
ApplicationErrorsLogged++;
}
if (logLevel == LogLevel.Error)
{
TotalErrorsLogged++;
}
}
}
}