- If we're done before the client sends a FIN, force a FIN into the raw SocketInput so the task in FileteredStreamAdapter finishes gracefully and we dispose everything in proper order. - If there's an error while writing to a stream (like ObjectDisposedException), log it once and prevent further write attempts. This means the client closed the connection while we were still writing output. - This also fixes a bug related to the point above, where memory blocks were being leaked instead of returned to the pool (because we weren't catching the exception from Write()).
This commit is contained in:
parent
33ad355114
commit
f29dd60999
|
|
@ -10,23 +10,27 @@ using Microsoft.Extensions.Logging;
|
|||
|
||||
namespace Microsoft.AspNetCore.Server.Kestrel.Filter
|
||||
{
|
||||
public class FilteredStreamAdapter
|
||||
public class FilteredStreamAdapter : IDisposable
|
||||
{
|
||||
private readonly string _connectionId;
|
||||
private readonly Stream _filteredStream;
|
||||
private readonly Stream _socketInputStream;
|
||||
private readonly IKestrelTrace _log;
|
||||
private readonly MemoryPool _memory;
|
||||
private MemoryPoolBlock _block;
|
||||
private bool _aborted = false;
|
||||
|
||||
public FilteredStreamAdapter(
|
||||
string connectionId,
|
||||
Stream filteredStream,
|
||||
MemoryPool memory,
|
||||
IKestrelTrace logger,
|
||||
IThreadPool threadPool)
|
||||
{
|
||||
SocketInput = new SocketInput(memory, threadPool);
|
||||
SocketOutput = new StreamSocketOutput(filteredStream, memory);
|
||||
SocketOutput = new StreamSocketOutput(connectionId, filteredStream, memory, logger);
|
||||
|
||||
_connectionId = connectionId;
|
||||
_log = logger;
|
||||
_filteredStream = filteredStream;
|
||||
_socketInputStream = new SocketInputStream(SocketInput);
|
||||
|
|
@ -37,16 +41,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
|
|||
|
||||
public ISocketOutput SocketOutput { get; private set; }
|
||||
|
||||
public void ReadInput()
|
||||
public Task ReadInputAsync()
|
||||
{
|
||||
_block = _memory.Lease();
|
||||
// Use pooled block for copy
|
||||
_filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) =>
|
||||
return _filteredStream.CopyToAsync(_socketInputStream, _block).ContinueWith((task, state) =>
|
||||
{
|
||||
((FilteredStreamAdapter)state).OnStreamClose(task);
|
||||
}, this);
|
||||
}
|
||||
|
||||
public void Abort()
|
||||
{
|
||||
_aborted = true;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
SocketInput.Dispose();
|
||||
}
|
||||
|
||||
private void OnStreamClose(Task copyAsyncTask)
|
||||
{
|
||||
_memory.Return(_block);
|
||||
|
|
@ -61,10 +75,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
|
|||
SocketInput.AbortAwaiting();
|
||||
_log.LogError("FilteredStreamAdapter.CopyToAsync canceled.");
|
||||
}
|
||||
else if (_aborted)
|
||||
{
|
||||
SocketInput.AbortAwaiting();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_filteredStream.Dispose();
|
||||
_socketInputStream.Dispose();
|
||||
}
|
||||
catch (Exception ex)
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
|
|||
protected override void Dispose(bool disposing)
|
||||
{
|
||||
// Close _socketInput with a fake zero-length write that will result in a zero-length read.
|
||||
_socketInput.IncomingData(null, 0, 0);
|
||||
_socketInput.IncomingFin();
|
||||
base.Dispose(disposing);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,33 +16,57 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
|
|||
private static readonly byte[] _endChunkBytes = Encoding.ASCII.GetBytes("\r\n");
|
||||
private static readonly byte[] _nullBuffer = new byte[0];
|
||||
|
||||
private readonly string _connectionId;
|
||||
private readonly Stream _outputStream;
|
||||
private readonly MemoryPool _memory;
|
||||
private readonly IKestrelTrace _logger;
|
||||
private MemoryPoolBlock _producingBlock;
|
||||
|
||||
private bool _canWrite = true;
|
||||
|
||||
private object _writeLock = new object();
|
||||
|
||||
public StreamSocketOutput(Stream outputStream, MemoryPool memory)
|
||||
public StreamSocketOutput(string connectionId, Stream outputStream, MemoryPool memory, IKestrelTrace logger)
|
||||
{
|
||||
_connectionId = connectionId;
|
||||
_outputStream = outputStream;
|
||||
_memory = memory;
|
||||
_logger = logger;
|
||||
}
|
||||
|
||||
public void Write(ArraySegment<byte> buffer, bool chunk)
|
||||
{
|
||||
lock (_writeLock)
|
||||
{
|
||||
if (chunk && buffer.Array != null)
|
||||
if (buffer.Count == 0 )
|
||||
{
|
||||
var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count);
|
||||
_outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count);
|
||||
return;
|
||||
}
|
||||
|
||||
_outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count);
|
||||
|
||||
if (chunk && buffer.Array != null)
|
||||
try
|
||||
{
|
||||
_outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length);
|
||||
if (!_canWrite)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (chunk && buffer.Array != null)
|
||||
{
|
||||
var beginChunkBytes = ChunkWriter.BeginChunkBytes(buffer.Count);
|
||||
_outputStream.Write(beginChunkBytes.Array, beginChunkBytes.Offset, beginChunkBytes.Count);
|
||||
}
|
||||
|
||||
_outputStream.Write(buffer.Array ?? _nullBuffer, buffer.Offset, buffer.Count);
|
||||
|
||||
if (chunk && buffer.Array != null)
|
||||
{
|
||||
_outputStream.Write(_endChunkBytes, 0, _endChunkBytes.Length);
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_canWrite = false;
|
||||
_logger.ConnectionError(_connectionId, ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -65,14 +89,38 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Filter
|
|||
var block = _producingBlock;
|
||||
while (block != end.Block)
|
||||
{
|
||||
_outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count);
|
||||
// If we don't handle an exception from _outputStream.Write() here, we'll leak memory blocks.
|
||||
if (_canWrite)
|
||||
{
|
||||
try
|
||||
{
|
||||
_outputStream.Write(block.Data.Array, block.Data.Offset, block.Data.Count);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_canWrite = false;
|
||||
_logger.ConnectionError(_connectionId, ex);
|
||||
}
|
||||
}
|
||||
|
||||
var returnBlock = block;
|
||||
block = block.Next;
|
||||
returnBlock.Pool.Return(returnBlock);
|
||||
}
|
||||
|
||||
if (_canWrite)
|
||||
{
|
||||
try
|
||||
{
|
||||
_outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_canWrite = false;
|
||||
_logger.ConnectionError(_connectionId, ex);
|
||||
}
|
||||
}
|
||||
|
||||
_outputStream.Write(end.Block.Array, end.Block.Data.Offset, end.Index - end.Block.Data.Offset);
|
||||
end.Block.Pool.Return(end.Block);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
|
|||
private Frame _frame;
|
||||
private ConnectionFilterContext _filterContext;
|
||||
private LibuvStream _libuvStream;
|
||||
private FilteredStreamAdapter _filteredStreamAdapter;
|
||||
private Task _readInputContinuation;
|
||||
|
||||
private readonly SocketInput _rawSocketInput;
|
||||
private readonly SocketOutput _rawSocketOutput;
|
||||
|
|
@ -175,13 +177,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
|
|||
// Called on Libuv thread
|
||||
public virtual void OnSocketClosed()
|
||||
{
|
||||
_rawSocketInput.Dispose();
|
||||
|
||||
// If a connection filter was applied there will be two SocketInputs.
|
||||
// If a connection filter failed, SocketInput will be null.
|
||||
if (SocketInput != null && SocketInput != _rawSocketInput)
|
||||
if (_filteredStreamAdapter != null)
|
||||
{
|
||||
SocketInput.Dispose();
|
||||
_filteredStreamAdapter.Abort();
|
||||
_rawSocketInput.IncomingFin();
|
||||
_readInputContinuation.ContinueWith((task, state) =>
|
||||
{
|
||||
((Connection)state)._filterContext.Connection.Dispose();
|
||||
((Connection)state)._filteredStreamAdapter.Dispose();
|
||||
((Connection)state)._rawSocketInput.Dispose();
|
||||
}, this);
|
||||
}
|
||||
else
|
||||
{
|
||||
_rawSocketInput.Dispose();
|
||||
}
|
||||
|
||||
lock (_stateLock)
|
||||
|
|
@ -207,12 +216,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
|
|||
|
||||
if (_filterContext.Connection != _libuvStream)
|
||||
{
|
||||
var filteredStreamAdapter = new FilteredStreamAdapter(_filterContext.Connection, Memory, Log, ThreadPool);
|
||||
_filteredStreamAdapter = new FilteredStreamAdapter(ConnectionId, _filterContext.Connection, Memory, Log, ThreadPool);
|
||||
|
||||
SocketInput = filteredStreamAdapter.SocketInput;
|
||||
SocketOutput = filteredStreamAdapter.SocketOutput;
|
||||
SocketInput = _filteredStreamAdapter.SocketInput;
|
||||
SocketOutput = _filteredStreamAdapter.SocketOutput;
|
||||
|
||||
filteredStreamAdapter.ReadInput();
|
||||
_readInputContinuation = _filteredStreamAdapter.ReadInputAsync();
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
|||
|
|
@ -140,6 +140,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
|
|||
}
|
||||
}
|
||||
|
||||
public void IncomingFin()
|
||||
{
|
||||
// Force a FIN
|
||||
IncomingData(null, 0, 0);
|
||||
}
|
||||
|
||||
private void Complete()
|
||||
{
|
||||
var awaitableState = Interlocked.Exchange(
|
||||
|
|
|
|||
|
|
@ -5,14 +5,32 @@ using System;
|
|||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Server.Kestrel;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.KestrelTests
|
||||
{
|
||||
public class ChunkedResponseTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task ResponsesAreChunkedAutomatically()
|
||||
public static TheoryData<ServiceContext> ConnectionFilterData
|
||||
{
|
||||
get
|
||||
{
|
||||
return new TheoryData<ServiceContext>
|
||||
{
|
||||
{
|
||||
new TestServiceContext()
|
||||
},
|
||||
{
|
||||
new TestServiceContext(new PassThroughConnectionFilter())
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task ResponsesAreChunkedAutomatically(ServiceContext testContext)
|
||||
{
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
|
|
@ -20,7 +38,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
response.Headers.Clear();
|
||||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6);
|
||||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
|
||||
}))
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
|
|
@ -43,8 +61,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ZeroLengthWritesAreIgnored()
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task ZeroLengthWritesAreIgnored(ServiceContext testContext)
|
||||
{
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
|
|
@ -53,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello "), 0, 6);
|
||||
await response.Body.WriteAsync(new byte[0], 0, 0);
|
||||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
|
||||
}))
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
|
|
@ -76,15 +95,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite()
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task EmptyResponseBodyHandledCorrectlyWithZeroLengthWrite(ServiceContext testContext)
|
||||
{
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
var response = httpContext.Response;
|
||||
response.Headers.Clear();
|
||||
await response.Body.WriteAsync(new byte[0], 0, 0);
|
||||
}))
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
|
|
@ -103,8 +123,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ConnectionClosedIfExeptionThrownAfterWrite()
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task ConnectionClosedIfExeptionThrownAfterWrite(ServiceContext testContext)
|
||||
{
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
|
|
@ -112,7 +133,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
response.Headers.Clear();
|
||||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World!"), 0, 12);
|
||||
throw new Exception();
|
||||
}))
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
|
|
@ -133,8 +154,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite()
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task ConnectionClosedIfExeptionThrownAfterZeroLengthWrite(ServiceContext testContext)
|
||||
{
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
|
|
@ -142,7 +164,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
response.Headers.Clear();
|
||||
await response.Body.WriteAsync(new byte[0], 0, 0);
|
||||
throw new Exception();
|
||||
}))
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
|
|
@ -162,8 +184,9 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritesAreFlushedPriorToResponseCompletion()
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task WritesAreFlushedPriorToResponseCompletion(ServiceContext testContext)
|
||||
{
|
||||
var flushWh = new ManualResetEventSlim();
|
||||
|
||||
|
|
@ -177,7 +200,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
flushWh.Wait();
|
||||
|
||||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("World!"), 0, 6);
|
||||
}))
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
|
|
|
|||
|
|
@ -1018,5 +1018,37 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
Assert.True(registrationWh.Wait(1000));
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(ConnectionFilterData))]
|
||||
public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ServiceContext testContext)
|
||||
{
|
||||
var testLogger = new TestApplicationErrorLogger();
|
||||
testContext.Log = new KestrelTrace(testLogger);
|
||||
|
||||
using (var server = new TestServer(async httpContext =>
|
||||
{
|
||||
var response = httpContext.Response;
|
||||
response.Headers.Clear();
|
||||
response.Headers["Content-Length"] = new[] { "11" };
|
||||
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11);
|
||||
}, testContext))
|
||||
{
|
||||
using (var connection = new TestConnection(server.Port))
|
||||
{
|
||||
await connection.Send(
|
||||
"GET / HTTP/1.0",
|
||||
"",
|
||||
"");
|
||||
await connection.ReceiveForcedEnd(
|
||||
"HTTP/1.0 200 OK",
|
||||
"Content-Length: 11",
|
||||
"",
|
||||
"Hello World");
|
||||
}
|
||||
}
|
||||
|
||||
Assert.Equal(0, testLogger.TotalErrorsLogged);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
// As it calls ProduceStart with write immediate == true
|
||||
// This happens in WebSocket Upgrade over SSL
|
||||
|
||||
ISocketOutput socketOutput = new StreamSocketOutput(new ThrowsOnNullWriteStream(), null);
|
||||
ISocketOutput socketOutput = new StreamSocketOutput("id", new ThrowsOnNullWriteStream(), null, new TestKestrelTrace());
|
||||
|
||||
// Should not throw
|
||||
socketOutput.Write(default(ArraySegment<byte>), true);
|
||||
|
|
|
|||
|
|
@ -32,12 +32,15 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
Console.WriteLine($"Log {logLevel}[{eventId}]: {formatter(state, exception)} {exception?.Message}");
|
||||
#endif
|
||||
|
||||
TotalErrorsLogged++;
|
||||
|
||||
if (eventId.Id == ApplicationErrorEventId)
|
||||
{
|
||||
ApplicationErrorsLogged++;
|
||||
}
|
||||
|
||||
if (logLevel == LogLevel.Error)
|
||||
{
|
||||
TotalErrorsLogged++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue