Improve HTTP/2 stream abort logic (#2819)

- Fix race where headers frame could be written after an abort was observed
- Fix Http2StreamTests to verify expected abort-related exceptions
This commit is contained in:
Stephen Halter 2018-08-13 11:45:17 -07:00 committed by GitHub
parent 2bbd890357
commit cd6de2fa18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 143 additions and 50 deletions

View File

@ -6,13 +6,11 @@ using System.Buffers;
using System.Diagnostics;
using System.Globalization;
using System.IO.Pipelines;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Connections.Abstractions;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
@ -28,6 +26,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
protected readonly long _keepAliveTicks;
private readonly long _requestHeadersTimeoutTicks;
private int _requestAborted;
private volatile bool _requestTimedOut;
private uint _requestCount;
@ -61,6 +60,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override bool IsUpgradableRequest => _upgradeAvailable;
/// <summary>
/// Immediately kill the connection and poison the request body stream with an error.
/// </summary>
public void Abort(ConnectionAbortedException abortReason)
{
if (Interlocked.Exchange(ref _requestAborted, 1) != 0)
{
return;
}
// Abort output prior to calling OnIOCompleted() to give the transport the chance to complete the input
// with the correct error and message.
Output.Abort(abortReason);
OnInputOrOutputCompleted();
PoisonRequestBodyStream(abortReason);
}
protected override void ApplicationAbort()
{
Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier);
Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication));
}
/// <summary>
/// Stops the request processing loop between requests.
/// Called on all active connections when the server wants to initiate a shutdown

View File

@ -6,7 +6,6 @@ using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
@ -230,10 +229,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ApplicationAbort();
}
protected virtual void ApplicationAbort()
{
Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier);
Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication));
}
protected abstract void ApplicationAbort();
}
}

View File

@ -18,7 +18,6 @@ using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
@ -42,7 +41,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private Stack<KeyValuePair<Func<object, Task>, object>> _onStarting;
private Stack<KeyValuePair<Func<object, Task>, object>> _onCompleted;
private int _requestAborted;
private volatile int _ioCompleted;
private CancellationTokenSource _abortedCts;
private CancellationToken? _manuallySetRequestAbortToken;
@ -385,6 +383,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
}
protected virtual void OnErrorAfterResponseStarted()
{
}
protected virtual bool BeginRead(out ValueTask<ReadResult> awaitable)
{
awaitable = default;
@ -425,23 +427,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this);
}
/// <summary>
/// Immediately kill the connection and poison the request and response streams with an error if there is one.
/// </summary>
public virtual void Abort(ConnectionAbortedException abortReason)
protected void PoisonRequestBodyStream(Exception abortReason)
{
if (Interlocked.Exchange(ref _requestAborted, 1) != 0)
{
return;
}
_streams?.Abort(abortReason);
// Abort output prior to calling OnIOCompleted() to give the transport the chance to
// complete the input with the correct error and message.
Output.Abort(abortReason);
OnInputOrOutputCompleted();
}
public void OnHeader(Span<byte> name, Span<byte> value)
@ -1032,7 +1020,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
if (HasResponseStarted)
{
ErrorAfterResponseStarted();
// We can no longer change the response, so we simply close the connection.
_keepAlive = false;
OnErrorAfterResponseStarted();
return Task.CompletedTask;
}
@ -1057,12 +1047,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return WriteSuffix();
}
protected virtual void ErrorAfterResponseStarted()
{
// We can no longer change the response, so we simply close the connection.
_keepAlive = false;
}
[MethodImpl(MethodImplOptions.NoInlining)]
private async Task ProduceEndAwaited()
{

View File

@ -201,7 +201,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
catch (Http2StreamErrorException ex)
{
Log.Http2StreamError(ConnectionId, ex);
AbortStream(_incomingFrame.StreamId, new ConnectionAbortedException(ex.Message, ex));
AbortStream(_incomingFrame.StreamId, new IOException(ex.Message, ex));
await _frameWriter.WriteRstStreamAsync(ex.StreamId, ex.ErrorCode);
}
finally
@ -269,7 +269,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
foreach (var stream in _streams.Values)
{
stream.Abort(connectionError);
stream.Abort(new IOException(CoreStrings.Http2StreamAborted, connectionError));
}
await _streamsCompleted.Task;
@ -583,7 +583,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
}
ThrowIfIncomingFrameSentToIdleStream();
AbortStream(_incomingFrame.StreamId, new ConnectionAbortedException(CoreStrings.Http2StreamResetByClient));
AbortStream(_incomingFrame.StreamId, new IOException(CoreStrings.Http2StreamResetByClient));
return Task.CompletedTask;
}
@ -885,7 +885,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
}
}
private void AbortStream(int streamId, ConnectionAbortedException error)
private void AbortStream(int streamId, IOException error)
{
if (_streams.TryGetValue(streamId, out var stream))
{

View File

@ -352,7 +352,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
return _context.FrameWriter.TryUpdateStreamWindow(_outputFlowControl, bytes);
}
public override void Abort(ConnectionAbortedException abortReason)
public void Abort(IOException abortReason)
{
if (!TryApplyCompletionFlag(StreamCompletionFlags.Aborted))
{
@ -362,10 +362,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
AbortCore(abortReason);
}
protected override void ErrorAfterResponseStarted()
protected override void OnErrorAfterResponseStarted()
{
// We can no longer change the response, send a Reset instead.
base.ErrorAfterResponseStarted();
var abortReason = new ConnectionAbortedException(CoreStrings.Http2StreamErrorAfterHeaders);
ResetAndAbort(abortReason, Http2ErrorCode.INTERNAL_ERROR);
}
@ -391,12 +390,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
AbortCore(abortReason);
}
private void AbortCore(ConnectionAbortedException abortReason)
private void AbortCore(Exception abortReason)
{
base.Abort(abortReason);
// Call OnIOCompleted() which closes the output prior to poisoning the request body stream or pipe to
// ensure that an app that completes early due to the abort doesn't result in header frames being sent.
OnInputOrOutputCompleted();
// Unblock the request body.
RequestBodyPipe.Writer.Complete(new IOException(CoreStrings.Http2StreamAborted, abortReason));
PoisonRequestBodyStream(abortReason);
RequestBodyPipe.Writer.Complete(abortReason);
_inputFlowControl.Abort();
}
@ -420,7 +422,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
var lastCompletionState = _completionState;
_completionState |= completionState;
if (ShoulStopTrackingStream(_completionState) && !ShoulStopTrackingStream(lastCompletionState))
if (ShouldStopTrackingStream(_completionState) && !ShouldStopTrackingStream(lastCompletionState))
{
_context.StreamLifetimeHandler.OnStreamCompleted(StreamId);
}
@ -429,7 +431,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
}
}
private static bool ShoulStopTrackingStream(StreamCompletionFlags completionState)
private static bool ShouldStopTrackingStream(StreamCompletionFlags completionState)
{
// This could be a single condition, but I think it reads better as two if's.
if ((completionState & StreamCompletionFlags.RequestProcessingEnded) == StreamCompletionFlags.RequestProcessingEnded)

View File

@ -937,6 +937,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
[Fact]
public async Task ContentLength_Received_SingleDataFrameOverSize_Reset()
{
IOException thrownEx = null;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -946,7 +948,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
};
await InitializeConnectionAsync(async context =>
{
await Assert.ThrowsAsync<ConnectionAbortedException>(async () =>
thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -959,11 +961,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorMoreDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
}
[Fact]
public async Task ContentLength_Received_SingleDataFrameUnderSize_Reset()
{
IOException thrownEx = null;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -973,7 +983,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
};
await InitializeConnectionAsync(async context =>
{
await Assert.ThrowsAsync<ConnectionAbortedException>(async () =>
thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -986,11 +996,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorLessDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
}
[Fact]
public async Task ContentLength_Received_MultipleDataFramesOverSize_Reset()
{
IOException thrownEx = null;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -1000,7 +1018,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
};
await InitializeConnectionAsync(async context =>
{
await Assert.ThrowsAsync<ConnectionAbortedException>(async () =>
thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -1016,11 +1034,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorMoreDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
}
[Fact]
public async Task ContentLength_Received_MultipleDataFramesUnderSize_Reset()
{
IOException thrownEx = null;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -1030,7 +1056,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
};
await InitializeConnectionAsync(async context =>
{
await Assert.ThrowsAsync<ConnectionAbortedException>(async () =>
thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -1044,6 +1070,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorLessDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
}
[Fact]
@ -1490,6 +1522,62 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
}
[Fact]
public async Task RequestAbort_ThrowsOperationCanceledExceptionFromSubsequentRequestBodyStreamRead()
{
OperationCanceledException thrownEx = null;
await InitializeConnectionAsync(async context =>
{
context.Abort();
var buffer = new byte[100];
var thrownExTask = Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Request.Body.ReadAsync(buffer, 0, buffer.Length));
Assert.True(thrownExTask.IsCompleted);
thrownEx = await thrownExTask;
});
await StartStreamAsync(1, _browserRequestHeaders, endStream: false);
await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.INTERNAL_ERROR, CoreStrings.ConnectionAbortedByApplication);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
Assert.NotNull(thrownEx);
Assert.IsType<ConnectionAbortedException>(thrownEx);
Assert.Equal(CoreStrings.ConnectionAbortedByApplication, thrownEx.Message);
}
[Fact]
public async Task RequestAbort_ThrowsOperationCanceledExceptionFromOngoingRequestBodyStreamRead()
{
OperationCanceledException thrownEx = null;
await InitializeConnectionAsync(async context =>
{
var buffer = new byte[100];
var thrownExTask = Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Request.Body.ReadAsync(buffer, 0, buffer.Length));
Assert.False(thrownExTask.IsCompleted);
context.Abort();
thrownEx = await thrownExTask.DefaultTimeout();
});
await StartStreamAsync(1, _browserRequestHeaders, endStream: false);
await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.INTERNAL_ERROR, CoreStrings.ConnectionAbortedByApplication);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
Assert.NotNull(thrownEx);
Assert.IsType<TaskCanceledException>(thrownEx);
Assert.Equal("The request was aborted", thrownEx.Message);
Assert.IsType<ConnectionAbortedException>(thrownEx.InnerException);
Assert.Equal(CoreStrings.ConnectionAbortedByApplication, thrownEx.InnerException.Message);
}
private async Task InitializeConnectionAsync(RequestDelegate application)
{
_connectionTask = _connection.ProcessRequestsAsync(new DummyApplication(application));