Handle exception in SocketConnection.Shutdown() (#2562)

This commit is contained in:
Stephen Halter 2018-05-10 23:39:06 -07:00 committed by GitHub
parent e6a88c1b9c
commit da21fc89cf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 472 additions and 3 deletions

View File

@ -263,8 +263,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal
_aborted = true;
_trace.ConnectionWriteFin(ConnectionId);
// Try to gracefully close the socket even for aborts to match libuv behavior.
_socket.Shutdown(SocketShutdown.Both);
try
{
// Try to gracefully close the socket even for aborts to match libuv behavior.
_socket.Shutdown(SocketShutdown.Both);
}
catch
{
// Ignore any errors from Socket.Shutdown since we're tearing down the connection anyway.
}
_socket.Dispose();
}
}

View File

@ -22,6 +22,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Testing;
@ -1208,6 +1209,149 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
Assert.Equal(2, abortedRequestId);
}
[Theory]
[MemberData(nameof(ConnectionAdapterData))]
public async Task ServerCanAbortConnectionAfterUnobservedClose(ListenOptions listenOptions)
{
const int connectionPausedEventId = 4;
const int connectionFinSentEventId = 7;
const int maxRequestBufferSize = 4096;
var readCallbackUnwired = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var clientClosedConnection = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var serverClosedConnection = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var appFuncCompleted = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var mockLogger = new Mock<ILogger>();
mockLogger
.Setup(logger => logger.IsEnabled(It.IsAny<LogLevel>()))
.Returns(true);
mockLogger
.Setup(logger => logger.Log(It.IsAny<LogLevel>(), It.IsAny<EventId>(), It.IsAny<object>(), It.IsAny<Exception>(), It.IsAny<Func<object, Exception, string>>()))
.Callback<LogLevel, EventId, object, Exception, Func<object, Exception, string>>((logLevel, eventId, state, exception, formatter) =>
{
if (eventId.Id == connectionPausedEventId)
{
readCallbackUnwired.TrySetResult(null);
}
else if (eventId.Id == connectionFinSentEventId)
{
serverClosedConnection.SetResult(null);
}
Logger.Log(logLevel, eventId, state, exception, formatter);
});
var mockLoggerFactory = new Mock<ILoggerFactory>();
mockLoggerFactory
.Setup(factory => factory.CreateLogger(It.IsAny<string>()))
.Returns(Logger);
mockLoggerFactory
.Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv",
"Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")))
.Returns(mockLogger.Object);
var mockKestrelTrace = new Mock<KestrelTrace>(Logger) { CallBase = true };
var testContext = new TestServiceContext(mockLoggerFactory.Object)
{
Log = mockKestrelTrace.Object,
ServerOptions =
{
Limits =
{
MaxRequestBufferSize = maxRequestBufferSize,
MaxRequestLineSize = maxRequestBufferSize,
MaxRequestHeadersTotalSize = maxRequestBufferSize,
}
}
};
var scratchBuffer = new byte[maxRequestBufferSize * 2 + 1];
using (var server = new TestServer(async context =>
{
await clientClosedConnection.Task;
context.Abort();
await serverClosedConnection.Task;
// TaskContinuationOptions.RunContinuationsAsynchronously sometimes runs inline anyway in
// situations such as this where the awaiter starts awaiting right when SetResult is called.
_ = Task.Run(() => appFuncCompleted.SetResult(null));
}, testContext, listenOptions))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
$"Content-Length: {scratchBuffer.Length}",
"",
"");
var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);
// Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed.
await readCallbackUnwired.Task.TimeoutAfter(TestConstants.DefaultTimeout);
}
clientClosedConnection.SetResult(null);
await appFuncCompleted.Task.TimeoutAfter(TestConstants.DefaultTimeout);
}
mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
}
[Theory]
[MemberData(nameof(ConnectionAdapterData))]
public async Task AppCanHandleClientAbortingConnectionMidRequest(ListenOptions listenOptions)
{
var readTcs = new TaskCompletionSource<Exception>(TaskContinuationOptions.RunContinuationsAsynchronously);
var mockKestrelTrace = new Mock<KestrelTrace>(Logger) { CallBase = true };
var testContext = new TestServiceContext()
{
Log = mockKestrelTrace.Object,
};
var scratchBuffer = new byte[4096];
using (var server = new TestServer(async context =>
{
try
{
await context.Request.Body.CopyToAsync(Stream.Null);;
}
catch (Exception ex)
{
readTcs.SetException(ex);
throw;
}
readTcs.SetException(new Exception("This shouldn't be reached."));
}, testContext, listenOptions))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
$"Content-Length: {scratchBuffer.Length * 2}",
"",
"");
await connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);
}
await Assert.ThrowsAnyAsync<IOException>(() => readTcs.Task).TimeoutAfter(TestConstants.DefaultTimeout);
}
mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
}
[Theory]
[MemberData(nameof(ConnectionAdapterData))]
public async Task RequestHeadersAreResetOnEachRequest(ListenOptions listenOptions)

View File

@ -21,6 +21,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core;
using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Https;
@ -2327,6 +2328,185 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
}
}
[Theory]
[MemberData(nameof(ConnectionAdapterData))]
public async Task WritingToConnectionAfterUnobservedCloseTriggersRequestAbortedToken(ListenOptions listenOptions)
{
const int connectionPausedEventId = 4;
const int maxRequestBufferSize = 2048;
var requestAborted = false;
var readCallbackUnwired = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var clientClosedConnection = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var writeTcs = new TaskCompletionSource<Exception>(TaskContinuationOptions.RunContinuationsAsynchronously);
var mockKestrelTrace = new Mock<KestrelTrace>(Logger) { CallBase = true };
var mockLogger = new Mock<ILogger>();
mockLogger
.Setup(logger => logger.IsEnabled(It.IsAny<LogLevel>()))
.Returns(true);
mockLogger
.Setup(logger => logger.Log(It.IsAny<LogLevel>(), It.IsAny<EventId>(), It.IsAny<object>(), It.IsAny<Exception>(), It.IsAny<Func<object, Exception, string>>()))
.Callback<LogLevel, EventId, object, Exception, Func<object, Exception, string>>((logLevel, eventId, state, exception, formatter) =>
{
if (eventId.Id == connectionPausedEventId)
{
readCallbackUnwired.TrySetResult(null);
}
Logger.Log(logLevel, eventId, state, exception, formatter);
});
var mockLoggerFactory = new Mock<ILoggerFactory>();
mockLoggerFactory
.Setup(factory => factory.CreateLogger(It.IsAny<string>()))
.Returns(Logger);
mockLoggerFactory
.Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv",
"Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets")))
.Returns(mockLogger.Object);
var testContext = new TestServiceContext(mockLoggerFactory.Object)
{
Log = mockKestrelTrace.Object,
ServerOptions =
{
Limits =
{
MaxRequestBufferSize = maxRequestBufferSize,
MaxRequestLineSize = maxRequestBufferSize,
MaxRequestHeadersTotalSize = maxRequestBufferSize,
}
}
};
var scratchBuffer = new byte[4096];
using (var server = new TestServer(async context =>
{
context.RequestAborted.Register(() =>
{
requestAborted = true;
});
await clientClosedConnection.Task;
try
{
for (var i = 0; i < 1000; i++)
{
await context.Response.Body.WriteAsync(scratchBuffer, 0, scratchBuffer.Length, context.RequestAborted);
await Task.Delay(10);
}
}
catch (Exception ex)
{
// TaskContinuationOptions.RunContinuationsAsynchronously sometimes runs inline anyway in
// situations such as this where the awaiter starts awaiting right when SetResult is called.
_ = Task.Run(() => writeTcs.SetException(ex));
throw;
}
writeTcs.SetException(new Exception("This shouldn't be reached."));
}, testContext, listenOptions))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"POST / HTTP/1.1",
"Host:",
$"Content-Length: {scratchBuffer.Length}",
"",
"");
var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);
// Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed.
await readCallbackUnwired.Task.TimeoutAfter(TestConstants.DefaultTimeout);
}
clientClosedConnection.SetResult(null);
await Assert.ThrowsAnyAsync<OperationCanceledException>(() => writeTcs.Task).TimeoutAfter(TestConstants.DefaultTimeout);
}
mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
Assert.True(requestAborted);
}
[Theory]
[MemberData(nameof(ConnectionAdapterData))]
public async Task AppCanHandleClientAbortingConnectionMidResponse(ListenOptions listenOptions)
{
const int responseBodySegmentSize = 65536;
const int responseBodySegmentCount = 100;
const int responseBodySize = responseBodySegmentSize * responseBodySegmentCount;
var requestAborted = false;
var appCompletedTcs = new TaskCompletionSource<object>(TaskContinuationOptions.RunContinuationsAsynchronously);
var mockKestrelTrace = new Mock<KestrelTrace>(Logger) { CallBase = true };
var testContext = new TestServiceContext()
{
Log = mockKestrelTrace.Object,
};
var scratchBuffer = new byte[responseBodySegmentSize];
using (var server = new TestServer(async context =>
{
context.RequestAborted.Register(() =>
{
requestAborted = true;
});
context.Response.ContentLength = responseBodySize;
try
{
for (var i = 0; i < responseBodySegmentCount; i++)
{
await context.Response.Body.WriteAsync(scratchBuffer, 0, scratchBuffer.Length);
await Task.Delay(10);
}
}
finally
{
// WriteAsync shouldn't throw without a CancellationToken passed in. Unfortunately a ECONNRESET UvException sometimes gets thrown.
// This will be fixed by https://github.com/aspnet/KestrelHttpServer/pull/2547
// TaskContinuationOptions.RunContinuationsAsynchronously sometimes runs inline anyway in
// situations such as this where the awaiter starts awaiting right when SetResult is called.
_ = Task.Run(() => appCompletedTcs.SetResult(null));
}
}, testContext, listenOptions))
{
using (var connection = server.CreateConnection())
{
await connection.Send(
"GET / HTTP/1.1",
"Host:",
"",
"");
var readCount = 0;
// Read just part of the response and close the connection.
// https://github.com/aspnet/KestrelHttpServer/issues/2554
for (var i = 0; i < responseBodySegmentCount / 10; i++)
{
readCount += await connection.Stream.ReadAsync(scratchBuffer, 0, scratchBuffer.Length);
}
connection.Socket.Shutdown(SocketShutdown.Send);
await appCompletedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout);
}
}
mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny<string>()), Times.Once());
Assert.True(requestAborted);
}
[Theory]
[MemberData(nameof(ConnectionAdapterData))]
public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ListenOptions listenOptions)

View File

@ -1,7 +1,9 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal;
@ -13,7 +15,7 @@ namespace Microsoft.AspNetCore.Testing
public Task<IAdaptedConnection> OnConnectionAsync(ConnectionAdapterContext context)
{
var adapted = new AdaptedConnection(new LoggingStream(context.ConnectionStream, new TestApplicationErrorLogger()));
var adapted = new AdaptedConnection(new PassThroughStream(context.ConnectionStream));
return Task.FromResult<IAdaptedConnection>(adapted);
}
@ -30,5 +32,140 @@ namespace Microsoft.AspNetCore.Testing
{
}
}
private class PassThroughStream : Stream
{
private readonly Stream _innerStream;
public PassThroughStream(Stream innerStream)
{
_innerStream = innerStream;
}
public override bool CanRead => _innerStream.CanRead;
public override bool CanSeek => _innerStream.CanSeek;
public override bool CanTimeout => _innerStream.CanTimeout;
public override bool CanWrite => _innerStream.CanWrite;
public override long Length => _innerStream.Length;
public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; }
public override int ReadTimeout { get => _innerStream.ReadTimeout; set => _innerStream.ReadTimeout = value; }
public override int WriteTimeout { get => _innerStream.WriteTimeout; set => _innerStream.WriteTimeout = value; }
public override int Read(byte[] buffer, int offset, int count)
{
return _innerStream.Read(buffer, offset, count);
}
public override int ReadByte()
{
return _innerStream.ReadByte();
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.ReadAsync(buffer, offset, count, cancellationToken);
}
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return _innerStream.BeginRead(buffer, offset, count, callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
return _innerStream.EndRead(asyncResult);
}
public override void Write(byte[] buffer, int offset, int count)
{
_innerStream.Write(buffer, offset, count);
}
public override void WriteByte(byte value)
{
_innerStream.WriteByte(value);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _innerStream.WriteAsync(buffer, offset, count, cancellationToken);
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return _innerStream.BeginWrite(buffer, offset, count, callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
_innerStream.EndWrite(asyncResult);
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
return _innerStream.CopyToAsync(destination, bufferSize, cancellationToken);
}
public override void Flush()
{
_innerStream.Flush();
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
return _innerStream.FlushAsync();
}
public override long Seek(long offset, SeekOrigin origin)
{
return _innerStream.Seek(offset, origin);
}
public override void SetLength(long value)
{
_innerStream.SetLength(value);
}
public override void Close()
{
_innerStream.Close();
}
#if NETCOREAPP2_1
public override int Read(Span<byte> buffer)
{
return _innerStream.Read(buffer);
}
public override ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
return _innerStream.ReadAsync(buffer, cancellationToken);
}
public override void Write(ReadOnlySpan<byte> buffer)
{
_innerStream.Write(buffer);
}
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
{
return _innerStream.WriteAsync(buffer, cancellationToken);
}
public override void CopyTo(Stream destination, int bufferSize)
{
_innerStream.CopyTo(destination, bufferSize);
}
#endif
}
}
}