From cd6de2fa187b247e552954ae62df03a0b2198f87 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Mon, 13 Aug 2018 11:45:17 -0700 Subject: [PATCH] Improve HTTP/2 stream abort logic (#2819) - Fix race where headers frame could be written after an abort was observed - Fix Http2StreamTests to verify expected abort-related exceptions --- .../Internal/Http/Http1Connection.cs | 32 ++++++- .../Http/HttpProtocol.FeatureCollection.cs | 7 +- .../Internal/Http/HttpProtocol.cs | 32 ++----- .../Internal/Http2/Http2Connection.cs | 8 +- .../Internal/Http2/Http2Stream.cs | 18 ++-- test/Kestrel.Core.Tests/Http2StreamTests.cs | 96 ++++++++++++++++++- 6 files changed, 143 insertions(+), 50 deletions(-) diff --git a/src/Kestrel.Core/Internal/Http/Http1Connection.cs b/src/Kestrel.Core/Internal/Http/Http1Connection.cs index 68c6187c43..5a4e1d68bb 100644 --- a/src/Kestrel.Core/Internal/Http/Http1Connection.cs +++ b/src/Kestrel.Core/Internal/Http/Http1Connection.cs @@ -6,13 +6,11 @@ using System.Buffers; using System.Diagnostics; using System.Globalization; using System.IO.Pipelines; -using System.Runtime.InteropServices; -using System.Text; +using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; -using Microsoft.AspNetCore.Connections.Abstractions; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http @@ -28,6 +26,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http protected readonly long _keepAliveTicks; private readonly long _requestHeadersTimeoutTicks; + private int _requestAborted; private volatile bool _requestTimedOut; private uint _requestCount; @@ -61,6 +60,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public override bool IsUpgradableRequest => _upgradeAvailable; + /// + /// Immediately kill the connection and poison the request body stream with an error. + /// + public void Abort(ConnectionAbortedException abortReason) + { + if (Interlocked.Exchange(ref _requestAborted, 1) != 0) + { + return; + } + + // Abort output prior to calling OnIOCompleted() to give the transport the chance to complete the input + // with the correct error and message. + Output.Abort(abortReason); + + OnInputOrOutputCompleted(); + + PoisonRequestBodyStream(abortReason); + } + + protected override void ApplicationAbort() + { + Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier); + Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication)); + } + /// /// Stops the request processing loop between requests. /// Called on all active connections when the server wants to initiate a shutdown diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs index 96ff0ab391..7f3b047d70 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -6,7 +6,6 @@ using System.IO; using System.Net; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; @@ -230,10 +229,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http ApplicationAbort(); } - protected virtual void ApplicationAbort() - { - Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier); - Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication)); - } + protected abstract void ApplicationAbort(); } } diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs index b51d3ce6b8..f272f0ba29 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs @@ -18,7 +18,6 @@ using Microsoft.AspNetCore.Hosting.Server; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; -using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; @@ -42,7 +41,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private Stack, object>> _onStarting; private Stack, object>> _onCompleted; - private int _requestAborted; private volatile int _ioCompleted; private CancellationTokenSource _abortedCts; private CancellationToken? _manuallySetRequestAbortToken; @@ -385,6 +383,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { } + protected virtual void OnErrorAfterResponseStarted() + { + } + protected virtual bool BeginRead(out ValueTask awaitable) { awaitable = default; @@ -425,23 +427,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this); } - /// - /// Immediately kill the connection and poison the request and response streams with an error if there is one. - /// - public virtual void Abort(ConnectionAbortedException abortReason) + protected void PoisonRequestBodyStream(Exception abortReason) { - if (Interlocked.Exchange(ref _requestAborted, 1) != 0) - { - return; - } - _streams?.Abort(abortReason); - - // Abort output prior to calling OnIOCompleted() to give the transport the chance to - // complete the input with the correct error and message. - Output.Abort(abortReason); - - OnInputOrOutputCompleted(); } public void OnHeader(Span name, Span value) @@ -1032,7 +1020,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { if (HasResponseStarted) { - ErrorAfterResponseStarted(); + // We can no longer change the response, so we simply close the connection. + _keepAlive = false; + OnErrorAfterResponseStarted(); return Task.CompletedTask; } @@ -1057,12 +1047,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return WriteSuffix(); } - protected virtual void ErrorAfterResponseStarted() - { - // We can no longer change the response, so we simply close the connection. - _keepAlive = false; - } - [MethodImpl(MethodImplOptions.NoInlining)] private async Task ProduceEndAwaited() { diff --git a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs index 0ef973b98c..47e2605578 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs @@ -201,7 +201,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 catch (Http2StreamErrorException ex) { Log.Http2StreamError(ConnectionId, ex); - AbortStream(_incomingFrame.StreamId, new ConnectionAbortedException(ex.Message, ex)); + AbortStream(_incomingFrame.StreamId, new IOException(ex.Message, ex)); await _frameWriter.WriteRstStreamAsync(ex.StreamId, ex.ErrorCode); } finally @@ -269,7 +269,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 foreach (var stream in _streams.Values) { - stream.Abort(connectionError); + stream.Abort(new IOException(CoreStrings.Http2StreamAborted, connectionError)); } await _streamsCompleted.Task; @@ -583,7 +583,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 } ThrowIfIncomingFrameSentToIdleStream(); - AbortStream(_incomingFrame.StreamId, new ConnectionAbortedException(CoreStrings.Http2StreamResetByClient)); + AbortStream(_incomingFrame.StreamId, new IOException(CoreStrings.Http2StreamResetByClient)); return Task.CompletedTask; } @@ -885,7 +885,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 } } - private void AbortStream(int streamId, ConnectionAbortedException error) + private void AbortStream(int streamId, IOException error) { if (_streams.TryGetValue(streamId, out var stream)) { diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs index 464075cbd2..940fda25a6 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs @@ -352,7 +352,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 return _context.FrameWriter.TryUpdateStreamWindow(_outputFlowControl, bytes); } - public override void Abort(ConnectionAbortedException abortReason) + public void Abort(IOException abortReason) { if (!TryApplyCompletionFlag(StreamCompletionFlags.Aborted)) { @@ -362,10 +362,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 AbortCore(abortReason); } - protected override void ErrorAfterResponseStarted() + protected override void OnErrorAfterResponseStarted() { // We can no longer change the response, send a Reset instead. - base.ErrorAfterResponseStarted(); var abortReason = new ConnectionAbortedException(CoreStrings.Http2StreamErrorAfterHeaders); ResetAndAbort(abortReason, Http2ErrorCode.INTERNAL_ERROR); } @@ -391,12 +390,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 AbortCore(abortReason); } - private void AbortCore(ConnectionAbortedException abortReason) + private void AbortCore(Exception abortReason) { - base.Abort(abortReason); + // Call OnIOCompleted() which closes the output prior to poisoning the request body stream or pipe to + // ensure that an app that completes early due to the abort doesn't result in header frames being sent. + OnInputOrOutputCompleted(); // Unblock the request body. - RequestBodyPipe.Writer.Complete(new IOException(CoreStrings.Http2StreamAborted, abortReason)); + PoisonRequestBodyStream(abortReason); + RequestBodyPipe.Writer.Complete(abortReason); _inputFlowControl.Abort(); } @@ -420,7 +422,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 var lastCompletionState = _completionState; _completionState |= completionState; - if (ShoulStopTrackingStream(_completionState) && !ShoulStopTrackingStream(lastCompletionState)) + if (ShouldStopTrackingStream(_completionState) && !ShouldStopTrackingStream(lastCompletionState)) { _context.StreamLifetimeHandler.OnStreamCompleted(StreamId); } @@ -429,7 +431,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 } } - private static bool ShoulStopTrackingStream(StreamCompletionFlags completionState) + private static bool ShouldStopTrackingStream(StreamCompletionFlags completionState) { // This could be a single condition, but I think it reads better as two if's. if ((completionState & StreamCompletionFlags.RequestProcessingEnded) == StreamCompletionFlags.RequestProcessingEnded) diff --git a/test/Kestrel.Core.Tests/Http2StreamTests.cs b/test/Kestrel.Core.Tests/Http2StreamTests.cs index 1eb2c10ad8..f5e0665e74 100644 --- a/test/Kestrel.Core.Tests/Http2StreamTests.cs +++ b/test/Kestrel.Core.Tests/Http2StreamTests.cs @@ -937,6 +937,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Fact] public async Task ContentLength_Received_SingleDataFrameOverSize_Reset() { + IOException thrownEx = null; + var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), @@ -946,7 +948,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; await InitializeConnectionAsync(async context => { - await Assert.ThrowsAsync(async () => + thrownEx = await Assert.ThrowsAsync(async () => { var buffer = new byte[100]; while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } @@ -959,11 +961,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorMoreDataThanLength, Http2ErrorCode.PROTOCOL_ERROR); + + Assert.NotNull(thrownEx); + Assert.Equal(expectedError.Message, thrownEx.Message); + Assert.IsType(thrownEx.InnerException); } [Fact] public async Task ContentLength_Received_SingleDataFrameUnderSize_Reset() { + IOException thrownEx = null; + var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), @@ -973,7 +983,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; await InitializeConnectionAsync(async context => { - await Assert.ThrowsAsync(async () => + thrownEx = await Assert.ThrowsAsync(async () => { var buffer = new byte[100]; while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } @@ -986,11 +996,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorLessDataThanLength, Http2ErrorCode.PROTOCOL_ERROR); + + Assert.NotNull(thrownEx); + Assert.Equal(expectedError.Message, thrownEx.Message); + Assert.IsType(thrownEx.InnerException); } [Fact] public async Task ContentLength_Received_MultipleDataFramesOverSize_Reset() { + IOException thrownEx = null; + var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), @@ -1000,7 +1018,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; await InitializeConnectionAsync(async context => { - await Assert.ThrowsAsync(async () => + thrownEx = await Assert.ThrowsAsync(async () => { var buffer = new byte[100]; while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } @@ -1016,11 +1034,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorMoreDataThanLength, Http2ErrorCode.PROTOCOL_ERROR); + + Assert.NotNull(thrownEx); + Assert.Equal(expectedError.Message, thrownEx.Message); + Assert.IsType(thrownEx.InnerException); } [Fact] public async Task ContentLength_Received_MultipleDataFramesUnderSize_Reset() { + IOException thrownEx = null; + var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), @@ -1030,7 +1056,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; await InitializeConnectionAsync(async context => { - await Assert.ThrowsAsync(async () => + thrownEx = await Assert.ThrowsAsync(async () => { var buffer = new byte[100]; while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } @@ -1044,6 +1070,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorLessDataThanLength, Http2ErrorCode.PROTOCOL_ERROR); + + Assert.NotNull(thrownEx); + Assert.Equal(expectedError.Message, thrownEx.Message); + Assert.IsType(thrownEx.InnerException); } [Fact] @@ -1490,6 +1522,62 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } + [Fact] + public async Task RequestAbort_ThrowsOperationCanceledExceptionFromSubsequentRequestBodyStreamRead() + { + OperationCanceledException thrownEx = null; + + await InitializeConnectionAsync(async context => + { + context.Abort(); + + var buffer = new byte[100]; + var thrownExTask = Assert.ThrowsAnyAsync(() => context.Request.Body.ReadAsync(buffer, 0, buffer.Length)); + + Assert.True(thrownExTask.IsCompleted); + + thrownEx = await thrownExTask; + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.INTERNAL_ERROR, CoreStrings.ConnectionAbortedByApplication); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.NotNull(thrownEx); + Assert.IsType(thrownEx); + Assert.Equal(CoreStrings.ConnectionAbortedByApplication, thrownEx.Message); + } + + [Fact] + public async Task RequestAbort_ThrowsOperationCanceledExceptionFromOngoingRequestBodyStreamRead() + { + OperationCanceledException thrownEx = null; + + await InitializeConnectionAsync(async context => + { + var buffer = new byte[100]; + var thrownExTask = Assert.ThrowsAnyAsync(() => context.Request.Body.ReadAsync(buffer, 0, buffer.Length)); + + Assert.False(thrownExTask.IsCompleted); + + context.Abort(); + + thrownEx = await thrownExTask.DefaultTimeout(); + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: false); + await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.INTERNAL_ERROR, CoreStrings.ConnectionAbortedByApplication); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + Assert.NotNull(thrownEx); + Assert.IsType(thrownEx); + Assert.Equal("The request was aborted", thrownEx.Message); + Assert.IsType(thrownEx.InnerException); + Assert.Equal(CoreStrings.ConnectionAbortedByApplication, thrownEx.InnerException.Message); + } + private async Task InitializeConnectionAsync(RequestDelegate application) { _connectionTask = _connection.ProcessRequestsAsync(new DummyApplication(application));