diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs index 4061f1e80a..182bc98654 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Connection.cs @@ -152,7 +152,13 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { handle.Libuv.Check(status, out error); } + _rawSocketInput.IncomingComplete(readCount, error); + + if (errorDone) + { + _frame.Abort(); + } } private Frame CreateFrame() diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs index 8cac73a946..b14f968f16 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.FeatureCollection.cs @@ -7,9 +7,11 @@ using System.Collections.Generic; using System.IO; using System.Linq; using System.Net; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.Http; using Microsoft.AspNet.Http.Features; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNet.Server.Kestrel.Http @@ -18,7 +20,8 @@ namespace Microsoft.AspNet.Server.Kestrel.Http IHttpRequestFeature, IHttpResponseFeature, IHttpUpgradeFeature, - IHttpConnectionFeature + IHttpConnectionFeature, + IHttpRequestLifetimeFeature { // NOTE: When feature interfaces are added to or removed from this Frame class implementation, // then the list of `implementedFeatures` in the generated code project MUST also be updated. @@ -260,6 +263,8 @@ namespace Microsoft.AspNet.Server.Kestrel.Http bool IHttpConnectionFeature.IsLocal { get; set; } + CancellationToken IHttpRequestLifetimeFeature.RequestAborted { get; set; } + object IFeatureCollection.this[Type key] { get { return FastFeatureGet(key); } @@ -298,5 +303,10 @@ namespace Microsoft.AspNet.Server.Kestrel.Http IEnumerator> IEnumerable>.GetEnumerator() => FastEnumerable().GetEnumerator(); IEnumerator IEnumerable.GetEnumerator() => FastEnumerable().GetEnumerator(); + + void IHttpRequestLifetimeFeature.Abort() + { + Abort(); + } } } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.Generated.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.Generated.cs index 985267a57f..e449f94d83 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.Generated.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.Generated.cs @@ -45,11 +45,11 @@ namespace Microsoft.AspNet.Server.Kestrel.Http _currentIHttpRequestFeature = this; _currentIHttpResponseFeature = this; _currentIHttpUpgradeFeature = this; + _currentIHttpRequestLifetimeFeature = this; _currentIHttpConnectionFeature = this; _currentIHttpRequestIdentifierFeature = null; _currentIServiceProvidersFeature = null; - _currentIHttpRequestLifetimeFeature = null; _currentIHttpAuthenticationFeature = null; _currentIQueryFeature = null; _currentIFormFeature = null; diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index 104fc6c1d9..39f4e428ec 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -41,6 +41,12 @@ namespace Microsoft.AspNet.Server.Kestrel.Http private bool _requestProcessingStarted; private Task _requestProcessingTask; private volatile bool _requestProcessingStopping; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx + private volatile bool _requestAborted; + private CancellationTokenSource _disconnectCts = new CancellationTokenSource(); + private CancellationTokenSource _requestAbortCts; + + private FrameRequestStream _requestBody; + private FrameResponseStream _responseBody; private bool _responseStarted; private bool _keepAlive; @@ -74,7 +80,6 @@ namespace Microsoft.AspNet.Server.Kestrel.Http public string QueryString { get; set; } public string HttpVersion { get; set; } public IHeaderDictionary RequestHeaders { get; set; } - public MessageBody MessageBody { get; set; } public Stream RequestBody { get; set; } public int StatusCode { get; set; } @@ -110,7 +115,6 @@ namespace Microsoft.AspNet.Server.Kestrel.Http QueryString = null; HttpVersion = null; RequestHeaders = _requestHeaders; - MessageBody = null; RequestBody = null; StatusCode = 200; ReasonPhrase = null; @@ -133,6 +137,8 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { httpConnectionFeature.IsLocal = false; } + + _requestAbortCts?.Dispose(); } public void ResetResponseHeaders() @@ -169,6 +175,30 @@ namespace Microsoft.AspNet.Server.Kestrel.Http return _requestProcessingTask ?? TaskUtilities.CompletedTask; } + /// + /// Immediate kill the connection and poison the request and response streams. + /// + public void Abort() + { + _requestProcessingStopping = true; + _requestAborted = true; + + _requestBody?.StopAcceptingReads(); + _responseBody?.StopAcceptingWrites(); + + try + { + ConnectionControl.End(ProduceEndType.SocketDisconnect); + SocketInput.AbortAwaiting(); + + _disconnectCts.Cancel(); + } + catch (Exception ex) + { + Log.LogError("Abort", ex); + } + } + /// /// Primary loop which consumes socket input, parses it for protocol framing, and invokes the /// application delegate for as long as the socket is intended to remain open. @@ -202,14 +232,17 @@ namespace Microsoft.AspNet.Server.Kestrel.Http if (!terminated && !_requestProcessingStopping) { - MessageBody = MessageBody.For(HttpVersion, _requestHeaders, this); - _keepAlive = MessageBody.RequestKeepAlive; - var requestBody = new FrameRequestStream(MessageBody); - RequestBody = requestBody; - var responseBody = new FrameResponseStream(this); - ResponseBody = responseBody; + var messageBody = MessageBody.For(HttpVersion, _requestHeaders, this); + _keepAlive = messageBody.RequestKeepAlive; + _requestBody = new FrameRequestStream(messageBody); + RequestBody = _requestBody; + _responseBody = new FrameResponseStream(this); + ResponseBody = _responseBody; DuplexStream = new FrameDuplexStream(RequestBody, ResponseBody); + _requestAbortCts = CancellationTokenSource.CreateLinkedTokenSource(_disconnectCts.Token); + ((IHttpRequestLifetimeFeature)this).RequestAborted = _requestAbortCts.Token; + var httpContext = HttpContextFactory.Create(this); try { @@ -234,13 +267,17 @@ namespace Microsoft.AspNet.Server.Kestrel.Http HttpContextFactory.Dispose(httpContext); - await ProduceEnd(); + // If _requestAbort is set, the connection has already been closed. + if (!_requestAborted) + { + await ProduceEnd(); - // Finish reading the request body in case the app did not. - await MessageBody.Consume(); + // Finish reading the request body in case the app did not. + await messageBody.Consume(); + } - requestBody.StopAcceptingReads(); - responseBody.StopAcceptingWrites(); + _requestBody.StopAcceptingReads(); + _responseBody.StopAcceptingWrites(); } terminated = !_keepAlive; @@ -257,14 +294,20 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { try { - // Inform client no more data will ever arrive - ConnectionControl.End(ProduceEndType.SocketShutdownSend); + _disconnectCts.Dispose(); - // Wait for client to either disconnect or send unexpected data - await SocketInput; + // If _requestAborted is set, the connection has already been closed. + if (!_requestAborted) + { + // Inform client no more data will ever arrive + ConnectionControl.End(ProduceEndType.SocketShutdownSend); - // Dispose socket - ConnectionControl.End(ProduceEndType.SocketDisconnect); + // Wait for client to either disconnect or send unexpected data + await SocketInput; + + // Dispose socket + ConnectionControl.End(ProduceEndType.SocketDisconnect); + } } catch (Exception ex) { diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs index 98910340ab..ae74bffb24 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInput.cs @@ -176,6 +176,22 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } + public void AbortAwaiting() + { + var awaitableState = Interlocked.Exchange( + ref _awaitableState, + _awaitableIsCompleted); + + _awaitableError = new ObjectDisposedException(nameof(SocketInput), "The request was aborted"); + _manualResetEvent.Set(); + + if (awaitableState != _awaitableIsCompleted && + awaitableState != _awaitableIsNotCompleted) + { + Task.Run(awaitableState); + } + } + public SocketInput GetAwaiter() { return this; @@ -199,6 +215,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http else { // THIS IS AN ERROR STATE - ONLY ONE WAITER CAN WAIT + throw new InvalidOperationException("Concurrent reads are not supported."); } } diff --git a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs index d5893e214b..48db5b9f21 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/EngineTests.cs @@ -6,8 +6,10 @@ using System.IO; using System.Net; using System.Net.Sockets; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNet.Http; +using Microsoft.AspNet.Http.Features; using Microsoft.AspNet.Server.Kestrel; using Microsoft.AspNet.Server.Kestrel.Filter; using Microsoft.AspNet.Testing.xunit; @@ -925,6 +927,76 @@ namespace Microsoft.AspNet.Server.KestrelTests } } + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task RequestsCanBeAbortedMidRead(ServiceContext testContext) + { + var readTcs = new TaskCompletionSource(); + var registrationTcs = new TaskCompletionSource(); + var requestId = 0; + + using (var server = new TestServer(async httpContext => + { + requestId++; + + var response = httpContext.Response; + var request = httpContext.Request; + var lifetime = httpContext.Features.Get(); + + lifetime.RequestAborted.Register(() => registrationTcs.TrySetResult(requestId)); + + if (requestId == 1) + { + response.Headers.Clear(); + response.Headers["Content-Length"] = new[] { "5" }; + + await response.WriteAsync("World"); + } + else + { + var readTask = request.Body.CopyToAsync(Stream.Null); + + lifetime.Abort(); + + try + { + await readTask; + } + catch (Exception ex) + { + readTcs.SetException(ex); + throw; + } + } + }, testContext)) + { + using (var connection = new TestConnection()) + { + // Never send the body so CopyToAsync always fails. + await connection.Send( + "POST / HTTP/1.1", + "Content-Length: 5", + "", + "HelloPOST / HTTP/1.1", + "Content-Length: 5", + "", + ""); + + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + "Content-Length: 5", + "", + "World"); + } + } + + await Assert.ThrowsAsync(async () => await readTcs.Task); + + // The cancellation token for only the last request should be triggered. + var abortedRequestId = await registrationTcs.Task; + Assert.Equal(2, abortedRequestId); + } + private class TestApplicationErrorLogger : ILogger { public int ApplicationErrorsLogged { get; set; } diff --git a/tools/Microsoft.AspNet.Server.Kestrel.GeneratedCode/FrameFeatureCollection.cs b/tools/Microsoft.AspNet.Server.Kestrel.GeneratedCode/FrameFeatureCollection.cs index 62b8e19245..e32e54c133 100644 --- a/tools/Microsoft.AspNet.Server.Kestrel.GeneratedCode/FrameFeatureCollection.cs +++ b/tools/Microsoft.AspNet.Server.Kestrel.GeneratedCode/FrameFeatureCollection.cs @@ -66,6 +66,7 @@ namespace Microsoft.AspNet.Server.Kestrel.GeneratedCode typeof(IHttpRequestFeature), typeof(IHttpResponseFeature), typeof(IHttpUpgradeFeature), + typeof(IHttpRequestLifetimeFeature), typeof(IHttpConnectionFeature) };