diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs index 235e724ffc..96ff0ab391 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -226,6 +226,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } void IHttpRequestLifetimeFeature.Abort() + { + ApplicationAbort(); + } + + protected virtual void ApplicationAbort() { Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier); Abort(new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication)); diff --git a/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs b/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs index 03babbedeb..d7f1c886dd 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2OutputProducer.cs @@ -73,7 +73,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 public void Abort(ConnectionAbortedException abortReason) { - // TODO: RST_STREAM? Dispose(); } @@ -128,13 +127,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 { lock (_dataWriterLock) { + // The HPACK header compressor is stateful, if we compress headers for an aborted stream we must send them. + // Optimize for not compressing or sending them. if (_completed) { return; } - // The HPACK header compressor is stateful, if we compress headers for an aborted stream we must send them. - // Optimize for not compressing or sending them. _frameWriter.WriteResponseHeaders(_streamId, statusCode, responseHeaders); } } @@ -181,6 +180,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 } } + public Task WriteRstStreamAsync(Http2ErrorCode error) + { + lock (_dataWriterLock) + { + // Always send the reset even if the response body is _completed. The request body may not have completed yet. + + Dispose(); + + return _frameWriter.WriteRstStreamAsync(_streamId, error); + } + } + private async Task ProcessDataWrites() { try diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs index d47a78be13..5b52546f20 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs @@ -5,6 +5,7 @@ using System; using System.Buffers; using System.IO; using System.IO.Pipelines; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; @@ -16,15 +17,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 public partial class Http2Stream : HttpProtocol { private readonly Http2StreamContext _context; + private readonly Http2OutputProducer _http2Output; private readonly Http2StreamOutputFlowControl _outputFlowControl; + private int _requestAborted; public Http2Stream(Http2StreamContext context) : base(context) { _context = context; _outputFlowControl = new Http2StreamOutputFlowControl(context.ConnectionOutputFlowControl, context.ClientPeerSettings.InitialWindowSize); - - Output = new Http2OutputProducer(context.StreamId, context.FrameWriter, _outputFlowControl, context.TimeoutControl, context.MemoryPool); + _http2Output = new Http2OutputProducer(context.StreamId, context.FrameWriter, _outputFlowControl, context.TimeoutControl, context.MemoryPool); + Output = _http2Output; } public int StreamId => _context.StreamId; @@ -144,17 +147,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 } } + public bool TryUpdateOutputWindow(int bytes) + { + return _context.FrameWriter.TryUpdateStreamWindow(_outputFlowControl, bytes); + } + public override void Abort(ConnectionAbortedException abortReason) + { + if (Interlocked.Exchange(ref _requestAborted, 1) != 0) + { + return; + } + + AbortCore(abortReason); + } + + protected override void ApplicationAbort() + { + Log.ApplicationAbortedConnection(ConnectionId, TraceIdentifier); + var abortReason = new ConnectionAbortedException(CoreStrings.ConnectionAbortedByApplication); + ResetAndAbort(abortReason, Http2ErrorCode.CANCEL); + } + + private void ResetAndAbort(ConnectionAbortedException abortReason, Http2ErrorCode error) + { + if (Interlocked.Exchange(ref _requestAborted, 1) != 0) + { + return; + } + + // Don't block on IO. This never faults. + _ = _http2Output.WriteRstStreamAsync(error); + + AbortCore(abortReason); + } + + private void AbortCore(ConnectionAbortedException abortReason) { base.Abort(abortReason); // Unblock the request body. RequestBodyPipe.Writer.Complete(new IOException(CoreStrings.Http2StreamAborted, abortReason)); } - - public bool TryUpdateOutputWindow(int bytes) - { - return _context.FrameWriter.TryUpdateStreamWindow(_outputFlowControl, bytes); - } } } diff --git a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs index ce15b17a46..997285ef4e 100644 --- a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs +++ b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs @@ -1892,12 +1892,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForAllStreamsAsync(); Assert.Contains(1, _abortedStreamIds); - await SendGoAwayAsync(); - - // No data is received from the stream since it was aborted before writing anything - await WaitForConnectionStopAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); - - // TODO: Check logs + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } [Fact] @@ -1910,12 +1905,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForAllStreamsAsync(); Assert.Contains(1, _abortedStreamIds); - await SendGoAwayAsync(); - - // No END_STREAM HEADERS or DATA frame is received since the stream was aborted - await WaitForConnectionStopAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); - - // TODO: Check logs + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } [Fact] @@ -1928,12 +1918,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForAllStreamsAsync(); Assert.Contains(1, _abortedStreamIds); - await SendGoAwayAsync(); - - // No END_STREAM HEADERS or DATA frame is received since the stream was aborted - await WaitForConnectionStopAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); - - // TODO: Check logs + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } [Fact] @@ -2179,12 +2164,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForAllStreamsAsync(); Assert.Contains(1, _abortedStreamIds); - await SendGoAwayAsync(); - - // No data is received from the stream since it was aborted before writing anything - await WaitForConnectionStopAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); - - // TODO: Check logs + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } [Fact] @@ -2228,12 +2208,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await WaitForAllStreamsAsync(); Assert.Contains(1, _abortedStreamIds); - await SendGoAwayAsync(); - - // No data is received from the stream since it was aborted before writing anything - await WaitForConnectionStopAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); - - // TODO: Check logs + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } [Fact] @@ -2312,6 +2287,98 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests expectedErrorMessage: CoreStrings.FormatHttp2ErrorHeadersInterleaved(Http2FrameType.RST_STREAM, streamId: 1, headersStreamId: 1)); } + [Fact] + public async Task RequestAbort_SendsRstStream() + { + await InitializeConnectionAsync(async context => + { + var streamIdFeature = context.Features.Get(); + + try + { + context.RequestAborted.Register(() => + { + lock (_abortedStreamIdsLock) + { + _abortedStreamIds.Add(streamIdFeature.StreamId); + } + + _runningStreams[streamIdFeature.StreamId].TrySetResult(null); + }); + + context.Abort(); + + // Not sent + await context.Response.Body.WriteAsync(new byte[10], 0, 10); + + await _runningStreams[streamIdFeature.StreamId].Task; + } + catch (Exception ex) + { + _runningStreams[streamIdFeature.StreamId].TrySetException(ex); + } + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.CANCEL, expectedErrorMessage: null); + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task RequestAbort_AfterDataSent_SendsRstStream() + { + await InitializeConnectionAsync(async context => + { + var streamIdFeature = context.Features.Get(); + + try + { + context.RequestAborted.Register(() => + { + lock (_abortedStreamIdsLock) + { + _abortedStreamIds.Add(streamIdFeature.StreamId); + } + + _runningStreams[streamIdFeature.StreamId].TrySetResult(null); + }); + + await context.Response.Body.WriteAsync(new byte[10], 0, 10); + + context.Abort(); + + // Not sent + await context.Response.Body.WriteAsync(new byte[11], 0, 11); + + await _runningStreams[streamIdFeature.StreamId].Task; + } + catch (Exception ex) + { + _runningStreams[streamIdFeature.StreamId].TrySetException(ex); + } + }); + + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 10, + withFlags: 0, + withStreamId: 1); + + await WaitForStreamErrorAsync(expectedStreamId: 1, Http2ErrorCode.CANCEL, expectedErrorMessage: null); + await WaitForAllStreamsAsync(); + Assert.Contains(1, _abortedStreamIds); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + [Fact] public async Task SETTINGS_Received_Sends_ACK() {