Improve HTTP/2 stream abort logic (#2819)

- Fix race where headers frame could be written after an abort was observed
- Fix Http2StreamTests to verify expected abort-related exceptions
This commit is contained in:
Stephen Halter 2018-08-13 11:45:17 -07:00 committed by GitHub
parent 2bbd890357
commit cd6de2fa18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 143 additions and 50 deletions

View File

@ -6,13 +6,11 @@ using System.Buffers;
using System.Diagnostics; using System.Diagnostics;
using System.Globalization; using System.Globalization;
using System.IO.Pipelines; using System.IO.Pipelines;
using System.Runtime.InteropServices; using System.Threading;
using System.Text;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Connections.Abstractions;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
@ -28,6 +26,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
protected readonly long _keepAliveTicks; protected readonly long _keepAliveTicks;
private readonly long _requestHeadersTimeoutTicks; private readonly long _requestHeadersTimeoutTicks;
private int _requestAborted;
private volatile bool _requestTimedOut; private volatile bool _requestTimedOut;
private uint _requestCount; private uint _requestCount;
@ -61,6 +60,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
public override bool IsUpgradableRequest => _upgradeAvailable; public override bool IsUpgradableRequest => _upgradeAvailable;
/// <summary>
/// Immediately kill the connection and poison the request body stream with an error.
/// </summary>
public void Abort(ConnectionAbortedException abortReason)
{
if (Interlocked.Exchange(ref _requestAborted, 1) != 0)
{
return;
}
// Abort output prior to calling OnIOCompleted() to give the transport the chance to complete the input
// with the correct error and message.
Output.Abort(abortReason);
OnInputOrOutputCompleted();
PoisonRequestBodyStream(abortReason);
}
protected override void ApplicationAbort()
{
Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier);
Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication));
}
/// <summary> /// <summary>
/// Stops the request processing loop between requests. /// Stops the request processing loop between requests.
/// Called on all active connections when the server wants to initiate a shutdown /// Called on all active connections when the server wants to initiate a shutdown

View File

@ -6,7 +6,6 @@ using System.IO;
using System.Net; using System.Net;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
@ -230,10 +229,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ApplicationAbort(); ApplicationAbort();
} }
protected virtual void ApplicationAbort() protected abstract void ApplicationAbort();
{
Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier);
Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication));
}
} }
} }

View File

@ -18,7 +18,6 @@ using Microsoft.AspNetCore.Hosting.Server;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives; using Microsoft.Extensions.Primitives;
@ -42,7 +41,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private Stack<KeyValuePair<Func<object, Task>, object>> _onStarting; private Stack<KeyValuePair<Func<object, Task>, object>> _onStarting;
private Stack<KeyValuePair<Func<object, Task>, object>> _onCompleted; private Stack<KeyValuePair<Func<object, Task>, object>> _onCompleted;
private int _requestAborted;
private volatile int _ioCompleted; private volatile int _ioCompleted;
private CancellationTokenSource _abortedCts; private CancellationTokenSource _abortedCts;
private CancellationToken? _manuallySetRequestAbortToken; private CancellationToken? _manuallySetRequestAbortToken;
@ -385,6 +383,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{ {
} }
protected virtual void OnErrorAfterResponseStarted()
{
}
protected virtual bool BeginRead(out ValueTask<ReadResult> awaitable) protected virtual bool BeginRead(out ValueTask<ReadResult> awaitable)
{ {
awaitable = default; awaitable = default;
@ -425,23 +427,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this); ServiceContext.Scheduler.Schedule(state => ((HttpProtocol)state).CancelRequestAbortedToken(), this);
} }
/// <summary> protected void PoisonRequestBodyStream(Exception abortReason)
/// Immediately kill the connection and poison the request and response streams with an error if there is one.
/// </summary>
public virtual void Abort(ConnectionAbortedException abortReason)
{ {
if (Interlocked.Exchange(ref _requestAborted, 1) != 0)
{
return;
}
_streams?.Abort(abortReason); _streams?.Abort(abortReason);
// Abort output prior to calling OnIOCompleted() to give the transport the chance to
// complete the input with the correct error and message.
Output.Abort(abortReason);
OnInputOrOutputCompleted();
} }
public void OnHeader(Span<byte> name, Span<byte> value) public void OnHeader(Span<byte> name, Span<byte> value)
@ -1032,7 +1020,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{ {
if (HasResponseStarted) if (HasResponseStarted)
{ {
ErrorAfterResponseStarted(); // We can no longer change the response, so we simply close the connection.
_keepAlive = false;
OnErrorAfterResponseStarted();
return Task.CompletedTask; return Task.CompletedTask;
} }
@ -1057,12 +1047,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return WriteSuffix(); return WriteSuffix();
} }
protected virtual void ErrorAfterResponseStarted()
{
// We can no longer change the response, so we simply close the connection.
_keepAlive = false;
}
[MethodImpl(MethodImplOptions.NoInlining)] [MethodImpl(MethodImplOptions.NoInlining)]
private async Task ProduceEndAwaited() private async Task ProduceEndAwaited()
{ {

View File

@ -201,7 +201,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
catch (Http2StreamErrorException ex) catch (Http2StreamErrorException ex)
{ {
Log.Http2StreamError(ConnectionId, ex); Log.Http2StreamError(ConnectionId, ex);
AbortStream(_incomingFrame.StreamId, new ConnectionAbortedException(ex.Message, ex)); AbortStream(_incomingFrame.StreamId, new IOException(ex.Message, ex));
await _frameWriter.WriteRstStreamAsync(ex.StreamId, ex.ErrorCode); await _frameWriter.WriteRstStreamAsync(ex.StreamId, ex.ErrorCode);
} }
finally finally
@ -269,7 +269,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
foreach (var stream in _streams.Values) foreach (var stream in _streams.Values)
{ {
stream.Abort(connectionError); stream.Abort(new IOException(CoreStrings.Http2StreamAborted, connectionError));
} }
await _streamsCompleted.Task; await _streamsCompleted.Task;
@ -583,7 +583,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
} }
ThrowIfIncomingFrameSentToIdleStream(); ThrowIfIncomingFrameSentToIdleStream();
AbortStream(_incomingFrame.StreamId, new ConnectionAbortedException(CoreStrings.Http2StreamResetByClient)); AbortStream(_incomingFrame.StreamId, new IOException(CoreStrings.Http2StreamResetByClient));
return Task.CompletedTask; return Task.CompletedTask;
} }
@ -885,7 +885,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
} }
} }
private void AbortStream(int streamId, ConnectionAbortedException error) private void AbortStream(int streamId, IOException error)
{ {
if (_streams.TryGetValue(streamId, out var stream)) if (_streams.TryGetValue(streamId, out var stream))
{ {

View File

@ -352,7 +352,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
return _context.FrameWriter.TryUpdateStreamWindow(_outputFlowControl, bytes); return _context.FrameWriter.TryUpdateStreamWindow(_outputFlowControl, bytes);
} }
public override void Abort(ConnectionAbortedException abortReason) public void Abort(IOException abortReason)
{ {
if (!TryApplyCompletionFlag(StreamCompletionFlags.Aborted)) if (!TryApplyCompletionFlag(StreamCompletionFlags.Aborted))
{ {
@ -362,10 +362,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
AbortCore(abortReason); AbortCore(abortReason);
} }
protected override void ErrorAfterResponseStarted() protected override void OnErrorAfterResponseStarted()
{ {
// We can no longer change the response, send a Reset instead. // We can no longer change the response, send a Reset instead.
base.ErrorAfterResponseStarted();
var abortReason = new ConnectionAbortedException(CoreStrings.Http2StreamErrorAfterHeaders); var abortReason = new ConnectionAbortedException(CoreStrings.Http2StreamErrorAfterHeaders);
ResetAndAbort(abortReason, Http2ErrorCode.INTERNAL_ERROR); ResetAndAbort(abortReason, Http2ErrorCode.INTERNAL_ERROR);
} }
@ -391,12 +390,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
AbortCore(abortReason); AbortCore(abortReason);
} }
private void AbortCore(ConnectionAbortedException abortReason) private void AbortCore(Exception abortReason)
{ {
base.Abort(abortReason); // Call OnIOCompleted() which closes the output prior to poisoning the request body stream or pipe to
// ensure that an app that completes early due to the abort doesn't result in header frames being sent.
OnInputOrOutputCompleted();
// Unblock the request body. // Unblock the request body.
RequestBodyPipe.Writer.Complete(new IOException(CoreStrings.Http2StreamAborted, abortReason)); PoisonRequestBodyStream(abortReason);
RequestBodyPipe.Writer.Complete(abortReason);
_inputFlowControl.Abort(); _inputFlowControl.Abort();
} }
@ -420,7 +422,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
var lastCompletionState = _completionState; var lastCompletionState = _completionState;
_completionState |= completionState; _completionState |= completionState;
if (ShoulStopTrackingStream(_completionState) && !ShoulStopTrackingStream(lastCompletionState)) if (ShouldStopTrackingStream(_completionState) && !ShouldStopTrackingStream(lastCompletionState))
{ {
_context.StreamLifetimeHandler.OnStreamCompleted(StreamId); _context.StreamLifetimeHandler.OnStreamCompleted(StreamId);
} }
@ -429,7 +431,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
} }
} }
private static bool ShoulStopTrackingStream(StreamCompletionFlags completionState) private static bool ShouldStopTrackingStream(StreamCompletionFlags completionState)
{ {
// This could be a single condition, but I think it reads better as two if's. // This could be a single condition, but I think it reads better as two if's.
if ((completionState & StreamCompletionFlags.RequestProcessingEnded) == StreamCompletionFlags.RequestProcessingEnded) if ((completionState & StreamCompletionFlags.RequestProcessingEnded) == StreamCompletionFlags.RequestProcessingEnded)

View File

@ -937,6 +937,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
[Fact] [Fact]
public async Task ContentLength_Received_SingleDataFrameOverSize_Reset() public async Task ContentLength_Received_SingleDataFrameOverSize_Reset()
{ {
IOException thrownEx = null;
var headers = new[] var headers = new[]
{ {
new KeyValuePair<string, string>(HeaderNames.Method, "POST"), new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -946,7 +948,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}; };
await InitializeConnectionAsync(async context => await InitializeConnectionAsync(async context =>
{ {
await Assert.ThrowsAsync<ConnectionAbortedException>(async () => thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{ {
var buffer = new byte[100]; var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -959,11 +961,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength); await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorMoreDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
} }
[Fact] [Fact]
public async Task ContentLength_Received_SingleDataFrameUnderSize_Reset() public async Task ContentLength_Received_SingleDataFrameUnderSize_Reset()
{ {
IOException thrownEx = null;
var headers = new[] var headers = new[]
{ {
new KeyValuePair<string, string>(HeaderNames.Method, "POST"), new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -973,7 +983,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}; };
await InitializeConnectionAsync(async context => await InitializeConnectionAsync(async context =>
{ {
await Assert.ThrowsAsync<ConnectionAbortedException>(async () => thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{ {
var buffer = new byte[100]; var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -986,11 +996,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength); await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorLessDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
} }
[Fact] [Fact]
public async Task ContentLength_Received_MultipleDataFramesOverSize_Reset() public async Task ContentLength_Received_MultipleDataFramesOverSize_Reset()
{ {
IOException thrownEx = null;
var headers = new[] var headers = new[]
{ {
new KeyValuePair<string, string>(HeaderNames.Method, "POST"), new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -1000,7 +1018,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}; };
await InitializeConnectionAsync(async context => await InitializeConnectionAsync(async context =>
{ {
await Assert.ThrowsAsync<ConnectionAbortedException>(async () => thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{ {
var buffer = new byte[100]; var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -1016,11 +1034,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength); await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorMoreDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorMoreDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
} }
[Fact] [Fact]
public async Task ContentLength_Received_MultipleDataFramesUnderSize_Reset() public async Task ContentLength_Received_MultipleDataFramesUnderSize_Reset()
{ {
IOException thrownEx = null;
var headers = new[] var headers = new[]
{ {
new KeyValuePair<string, string>(HeaderNames.Method, "POST"), new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
@ -1030,7 +1056,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
}; };
await InitializeConnectionAsync(async context => await InitializeConnectionAsync(async context =>
{ {
await Assert.ThrowsAsync<ConnectionAbortedException>(async () => thrownEx = await Assert.ThrowsAsync<IOException>(async () =>
{ {
var buffer = new byte[100]; var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
@ -1044,6 +1070,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength); await WaitForStreamErrorAsync(1, Http2ErrorCode.PROTOCOL_ERROR, CoreStrings.Http2StreamErrorLessDataThanLength);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
var expectedError = new Http2StreamErrorException(1, CoreStrings.Http2StreamErrorLessDataThanLength, Http2ErrorCode.PROTOCOL_ERROR);
Assert.NotNull(thrownEx);
Assert.Equal(expectedError.Message, thrownEx.Message);
Assert.IsType<Http2StreamErrorException>(thrownEx.InnerException);
} }
[Fact] [Fact]
@ -1490,6 +1522,62 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
} }
[Fact]
public async Task RequestAbort_ThrowsOperationCanceledExceptionFromSubsequentRequestBodyStreamRead()
{
OperationCanceledException thrownEx = null;
await InitializeConnectionAsync(async context =>
{
context.Abort();
var buffer = new byte[100];
var thrownExTask = Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Request.Body.ReadAsync(buffer, 0, buffer.Length));
Assert.True(thrownExTask.IsCompleted);
thrownEx = await thrownExTask;
});
await StartStreamAsync(1, _browserRequestHeaders, endStream: false);
await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.INTERNAL_ERROR, CoreStrings.ConnectionAbortedByApplication);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
Assert.NotNull(thrownEx);
Assert.IsType<ConnectionAbortedException>(thrownEx);
Assert.Equal(CoreStrings.ConnectionAbortedByApplication, thrownEx.Message);
}
[Fact]
public async Task RequestAbort_ThrowsOperationCanceledExceptionFromOngoingRequestBodyStreamRead()
{
OperationCanceledException thrownEx = null;
await InitializeConnectionAsync(async context =>
{
var buffer = new byte[100];
var thrownExTask = Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Request.Body.ReadAsync(buffer, 0, buffer.Length));
Assert.False(thrownExTask.IsCompleted);
context.Abort();
thrownEx = await thrownExTask.DefaultTimeout();
});
await StartStreamAsync(1, _browserRequestHeaders, endStream: false);
await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.INTERNAL_ERROR, CoreStrings.ConnectionAbortedByApplication);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
Assert.NotNull(thrownEx);
Assert.IsType<TaskCanceledException>(thrownEx);
Assert.Equal("The request was aborted", thrownEx.Message);
Assert.IsType<ConnectionAbortedException>(thrownEx.InnerException);
Assert.Equal(CoreStrings.ConnectionAbortedByApplication, thrownEx.InnerException.Message);
}
private async Task InitializeConnectionAsync(RequestDelegate application) private async Task InitializeConnectionAsync(RequestDelegate application)
{ {
_connectionTask = _connection.ProcessRequestsAsync(new DummyApplication(application)); _connectionTask = _connection.ProcessRequestsAsync(new DummyApplication(application));