From da21fc89cf9a68a2c86f9fdb77cd6ee813f31d24 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Thu, 10 May 2018 23:39:06 -0700 Subject: [PATCH] Handle exception in SocketConnection.Shutdown() (#2562) --- .../Internal/SocketConnection.cs | 12 +- test/Kestrel.FunctionalTests/RequestTests.cs | 144 ++++++++++++++ test/Kestrel.FunctionalTests/ResponseTests.cs | 180 ++++++++++++++++++ test/shared/PassThroughConnectionAdapter.cs | 139 +++++++++++++- 4 files changed, 472 insertions(+), 3 deletions(-) diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs b/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs index d410599663..e1f323abba 100644 --- a/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs +++ b/src/Kestrel.Transport.Sockets/Internal/SocketConnection.cs @@ -263,8 +263,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal _aborted = true; _trace.ConnectionWriteFin(ConnectionId); - // Try to gracefully close the socket even for aborts to match libuv behavior. - _socket.Shutdown(SocketShutdown.Both); + try + { + // Try to gracefully close the socket even for aborts to match libuv behavior. + _socket.Shutdown(SocketShutdown.Both); + } + catch + { + // Ignore any errors from Socket.Shutdown since we're tearing down the connection anyway. + } + _socket.Dispose(); } } diff --git a/test/Kestrel.FunctionalTests/RequestTests.cs b/test/Kestrel.FunctionalTests/RequestTests.cs index 63868ad897..3f82426e5f 100644 --- a/test/Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Kestrel.FunctionalTests/RequestTests.cs @@ -22,6 +22,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; @@ -1208,6 +1209,149 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests Assert.Equal(2, abortedRequestId); } + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task ServerCanAbortConnectionAfterUnobservedClose(ListenOptions listenOptions) + { + const int connectionPausedEventId = 4; + const int connectionFinSentEventId = 7; + const int maxRequestBufferSize = 4096; + + var readCallbackUnwired = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + var clientClosedConnection = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + var serverClosedConnection = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + var appFuncCompleted = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + if (eventId.Id == connectionPausedEventId) + { + readCallbackUnwired.TrySetResult(null); + } + else if (eventId.Id == connectionFinSentEventId) + { + serverClosedConnection.SetResult(null); + } + + Logger.Log(logLevel, eventId, state, exception, formatter); + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var testContext = new TestServiceContext(mockLoggerFactory.Object) + { + Log = mockKestrelTrace.Object, + ServerOptions = + { + Limits = + { + MaxRequestBufferSize = maxRequestBufferSize, + MaxRequestLineSize = maxRequestBufferSize, + MaxRequestHeadersTotalSize = maxRequestBufferSize, + } + } + }; + + var scratchBuffer = new byte[maxRequestBufferSize * 2 + 1]; + + using (var server = new TestServer(async context => + { + await clientClosedConnection.Task; + + context.Abort(); + + await serverClosedConnection.Task; + + // TaskContinuationOptions.RunContinuationsAsynchronously sometimes runs inline anyway in + // situations such as this where the awaiter starts awaiting right when SetResult is called. + _ = Task.Run(() => appFuncCompleted.SetResult(null)); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {scratchBuffer.Length}", + "", + ""); + + var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + + // Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed. + await readCallbackUnwired.Task.TimeoutAfter(TestConstants.DefaultTimeout); + } + + clientClosedConnection.SetResult(null); + + await appFuncCompleted.Task.TimeoutAfter(TestConstants.DefaultTimeout); + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.Once()); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task AppCanHandleClientAbortingConnectionMidRequest(ListenOptions listenOptions) + { + var readTcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var testContext = new TestServiceContext() + { + Log = mockKestrelTrace.Object, + }; + + var scratchBuffer = new byte[4096]; + + using (var server = new TestServer(async context => + { + try + { + await context.Request.Body.CopyToAsync(Stream.Null);; + } + catch (Exception ex) + { + readTcs.SetException(ex); + throw; + } + + readTcs.SetException(new Exception("This shouldn't be reached.")); + + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {scratchBuffer.Length * 2}", + "", + ""); + + await connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + } + + await Assert.ThrowsAnyAsync(() => readTcs.Task).TimeoutAfter(TestConstants.DefaultTimeout); + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.Once()); + } + [Theory] [MemberData(nameof(ConnectionAdapterData))] public async Task RequestHeadersAreResetOnEachRequest(ListenOptions listenOptions) diff --git a/test/Kestrel.FunctionalTests/ResponseTests.cs b/test/Kestrel.FunctionalTests/ResponseTests.cs index 72fe4d1ba2..3e7f4b7d58 100644 --- a/test/Kestrel.FunctionalTests/ResponseTests.cs +++ b/test/Kestrel.FunctionalTests/ResponseTests.cs @@ -21,6 +21,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Https; @@ -2327,6 +2328,185 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task WritingToConnectionAfterUnobservedCloseTriggersRequestAbortedToken(ListenOptions listenOptions) + { + const int connectionPausedEventId = 4; + const int maxRequestBufferSize = 2048; + + var requestAborted = false; + var readCallbackUnwired = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + var clientClosedConnection = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + var writeTcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var mockLogger = new Mock(); + mockLogger + .Setup(logger => logger.IsEnabled(It.IsAny())) + .Returns(true); + mockLogger + .Setup(logger => logger.Log(It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny(), It.IsAny>())) + .Callback>((logLevel, eventId, state, exception, formatter) => + { + if (eventId.Id == connectionPausedEventId) + { + readCallbackUnwired.TrySetResult(null); + } + + Logger.Log(logLevel, eventId, state, exception, formatter); + }); + + var mockLoggerFactory = new Mock(); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsAny())) + .Returns(Logger); + mockLoggerFactory + .Setup(factory => factory.CreateLogger(It.IsIn("Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv", + "Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets"))) + .Returns(mockLogger.Object); + + var testContext = new TestServiceContext(mockLoggerFactory.Object) + { + Log = mockKestrelTrace.Object, + ServerOptions = + { + Limits = + { + MaxRequestBufferSize = maxRequestBufferSize, + MaxRequestLineSize = maxRequestBufferSize, + MaxRequestHeadersTotalSize = maxRequestBufferSize, + } + } + }; + + var scratchBuffer = new byte[4096]; + + using (var server = new TestServer(async context => + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + await clientClosedConnection.Task; + + try + { + for (var i = 0; i < 1000; i++) + { + await context.Response.Body.WriteAsync(scratchBuffer, 0, scratchBuffer.Length, context.RequestAborted); + await Task.Delay(10); + } + } + catch (Exception ex) + { + // TaskContinuationOptions.RunContinuationsAsynchronously sometimes runs inline anyway in + // situations such as this where the awaiter starts awaiting right when SetResult is called. + _ = Task.Run(() => writeTcs.SetException(ex)); + throw; + } + + writeTcs.SetException(new Exception("This shouldn't be reached.")); + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + $"Content-Length: {scratchBuffer.Length}", + "", + ""); + + var ignore = connection.Stream.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + + // Wait until the read callback is no longer hooked up so that the connection disconnect isn't observed. + await readCallbackUnwired.Task.TimeoutAfter(TestConstants.DefaultTimeout); + } + + clientClosedConnection.SetResult(null); + + await Assert.ThrowsAnyAsync(() => writeTcs.Task).TimeoutAfter(TestConstants.DefaultTimeout); + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.Once()); + Assert.True(requestAborted); + } + + [Theory] + [MemberData(nameof(ConnectionAdapterData))] + public async Task AppCanHandleClientAbortingConnectionMidResponse(ListenOptions listenOptions) + { + const int responseBodySegmentSize = 65536; + const int responseBodySegmentCount = 100; + const int responseBodySize = responseBodySegmentSize * responseBodySegmentCount; + + var requestAborted = false; + var appCompletedTcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + + var mockKestrelTrace = new Mock(Logger) { CallBase = true }; + var testContext = new TestServiceContext() + { + Log = mockKestrelTrace.Object, + }; + + var scratchBuffer = new byte[responseBodySegmentSize]; + + using (var server = new TestServer(async context => + { + context.RequestAborted.Register(() => + { + requestAborted = true; + }); + + context.Response.ContentLength = responseBodySize; + + try + { + for (var i = 0; i < responseBodySegmentCount; i++) + { + await context.Response.Body.WriteAsync(scratchBuffer, 0, scratchBuffer.Length); + await Task.Delay(10); + } + } + finally + { + // WriteAsync shouldn't throw without a CancellationToken passed in. Unfortunately a ECONNRESET UvException sometimes gets thrown. + // This will be fixed by https://github.com/aspnet/KestrelHttpServer/pull/2547 + // TaskContinuationOptions.RunContinuationsAsynchronously sometimes runs inline anyway in + // situations such as this where the awaiter starts awaiting right when SetResult is called. + _ = Task.Run(() => appCompletedTcs.SetResult(null)); + } + }, testContext, listenOptions)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + + var readCount = 0; + + // Read just part of the response and close the connection. + // https://github.com/aspnet/KestrelHttpServer/issues/2554 + for (var i = 0; i < responseBodySegmentCount / 10; i++) + { + readCount += await connection.Stream.ReadAsync(scratchBuffer, 0, scratchBuffer.Length); + } + + connection.Socket.Shutdown(SocketShutdown.Send); + + await appCompletedTcs.Task.TimeoutAfter(TestConstants.DefaultTimeout); + } + } + + mockKestrelTrace.Verify(t => t.ConnectionStop(It.IsAny()), Times.Once()); + Assert.True(requestAborted); + } + [Theory] [MemberData(nameof(ConnectionAdapterData))] public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ListenOptions listenOptions) diff --git a/test/shared/PassThroughConnectionAdapter.cs b/test/shared/PassThroughConnectionAdapter.cs index 5f10736595..cfad5c7e7c 100644 --- a/test/shared/PassThroughConnectionAdapter.cs +++ b/test/shared/PassThroughConnectionAdapter.cs @@ -1,7 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Core.Adapter.Internal; @@ -13,7 +15,7 @@ namespace Microsoft.AspNetCore.Testing public Task OnConnectionAsync(ConnectionAdapterContext context) { - var adapted = new AdaptedConnection(new LoggingStream(context.ConnectionStream, new TestApplicationErrorLogger())); + var adapted = new AdaptedConnection(new PassThroughStream(context.ConnectionStream)); return Task.FromResult(adapted); } @@ -30,5 +32,140 @@ namespace Microsoft.AspNetCore.Testing { } } + + private class PassThroughStream : Stream + { + private readonly Stream _innerStream; + + public PassThroughStream(Stream innerStream) + { + _innerStream = innerStream; + } + + public override bool CanRead => _innerStream.CanRead; + + public override bool CanSeek => _innerStream.CanSeek; + + public override bool CanTimeout => _innerStream.CanTimeout; + + public override bool CanWrite => _innerStream.CanWrite; + + public override long Length => _innerStream.Length; + + public override long Position { get => _innerStream.Position; set => _innerStream.Position = value; } + + public override int ReadTimeout { get => _innerStream.ReadTimeout; set => _innerStream.ReadTimeout = value; } + + public override int WriteTimeout { get => _innerStream.WriteTimeout; set => _innerStream.WriteTimeout = value; } + + public override int Read(byte[] buffer, int offset, int count) + { + return _innerStream.Read(buffer, offset, count); + } + + public override int ReadByte() + { + return _innerStream.ReadByte(); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _innerStream.EndRead(asyncResult); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _innerStream.Write(buffer, offset, count); + } + + + public override void WriteByte(byte value) + { + _innerStream.WriteByte(value); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _innerStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _innerStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _innerStream.EndWrite(asyncResult); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _innerStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + + public override void Flush() + { + _innerStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _innerStream.FlushAsync(); + + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _innerStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _innerStream.SetLength(value); + } + + public override void Close() + { + _innerStream.Close(); + } + +#if NETCOREAPP2_1 + public override int Read(Span buffer) + { + return _innerStream.Read(buffer); + } + + public override ValueTask ReadAsync(Memory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.ReadAsync(buffer, cancellationToken); + } + + public override void Write(ReadOnlySpan buffer) + { + _innerStream.Write(buffer); + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + return _innerStream.WriteAsync(buffer, cancellationToken); + } + + public override void CopyTo(Stream destination, int bufferSize) + { + _innerStream.CopyTo(destination, bufferSize); + } +#endif + } } }