diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets/SocketConnection.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets/SocketConnection.cs index d8378a43c3..7a1dab14c5 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets/SocketConnection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets/SocketConnection.cs @@ -4,8 +4,10 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.IO; using System.Net; using System.Net.Sockets; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Internal.System.Buffers; using Microsoft.AspNetCore.Server.Kestrel.Internal.System.IO.Pipelines; @@ -23,8 +25,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets private IPipeWriter _input; private IPipeReader _output; private IList> _sendBufferList; - - private const int MinAllocBufferSize = 2048; // from libuv transport + private const int MinAllocBufferSize = 2048; internal SocketConnection(Socket socket, SocketTransport transport) { @@ -51,17 +52,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets Task receiveTask = DoReceive(); Task sendTask = DoSend(); - // Wait for eiher of them to complete (note they won't throw exceptions) - await Task.WhenAny(receiveTask, sendTask); - - // Shut the socket down and wait for both sides to end - _socket.Shutdown(SocketShutdown.Both); + // If the sending task completes then close the receive + // We don't need to do this in the other direction because the kestrel + // will trigger the output closing once the input is complete. + if (await Task.WhenAny(receiveTask, sendTask) == sendTask) + { + // Tell the reader it's being aborted + _socket.Dispose(); + } // Now wait for both to complete await receiveTask; await sendTask; - // Dispose the socket + // Dispose the socket(should noop if already called) _socket.Dispose(); } catch (Exception) @@ -90,9 +94,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets if (bytesReceived == 0) { - // We receive a FIN so throw an exception so that we cancel the input - // with an error - throw new TaskCanceledException("The request was aborted"); + // FIN + break; } buffer.Advance(bytesReceived); @@ -106,17 +109,45 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets if (result.IsCompleted) { // Pipe consumer is shut down, do we stop writing - _socket.Shutdown(SocketShutdown.Receive); break; } } + _connectionContext.Abort(ex: null); _input.Complete(); } catch (Exception ex) { - _connectionContext.Abort(ex); - _input.Complete(ex); + Exception error = null; + + if (ex is SocketException se) + { + if (se.SocketErrorCode == SocketError.ConnectionReset) + { + // Connection reset + error = new ConnectionResetException(ex.Message, ex); + } + else if (se.SocketErrorCode == SocketError.OperationAborted) + { + error = new TaskCanceledException("The request was aborted"); + } + } + + if (ex is ObjectDisposedException) + { + error = new TaskCanceledException("The request was aborted"); + } + else if (ex is IOException ioe) + { + error = ioe; + } + else if (error == null) + { + error = new IOException(ex.Message, ex); + } + + _connectionContext.Abort(error); + _input.Complete(error); } } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs index 11c494c242..517b6a0cca 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions; using Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Internal.Networking; using Microsoft.AspNetCore.Testing; using Microsoft.AspNetCore.Testing.xunit; @@ -462,9 +463,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { await context.Request.Body.ReadAsync(new byte[1], 0, 1); } - catch (IOException ex) + catch (ConnectionResetException) { - expectedExceptionThrown = ex.InnerException is UvException && ex.InnerException.Message.Contains("ECONNRESET"); + expectedExceptionThrown = true; } appDone.Release();