From 6f7a841374a6bbcc8375da71b26ee3edb759cb80 Mon Sep 17 00:00:00 2001 From: John Luo Date: Fri, 17 Aug 2018 17:45:47 -0700 Subject: [PATCH] Fire OnStreamCompleted after all pipes are closed --- .../Internal/Http2/Http2Stream.cs | 85 +++++++++++++------ 1 file changed, 60 insertions(+), 25 deletions(-) diff --git a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs index 9ae576a1b8..82935a343b 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Stream.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Stream.cs @@ -63,13 +63,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 protected override void OnRequestProcessingEnded() { - TryApplyCompletionFlag(StreamCompletionFlags.RequestProcessingEnded); + var states = ApplyCompletionFlag(StreamCompletionFlags.RequestProcessingEnded); - RequestBodyPipe.Reader.Complete(); + try + { + RequestBodyPipe.Reader.Complete(); - // The app can no longer read any more of the request body, so return any bytes that weren't read to the - // connection's flow-control window. - _inputFlowControl.Abort(); + // The app can no longer read any more of the request body, so return any bytes that weren't read to the + // connection's flow-control window. + _inputFlowControl.Abort(); + } + finally + { + TryFireOnStreamCompleted(states); + } } protected override string CreateRequestId() @@ -335,11 +342,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 } } - TryApplyCompletionFlag(StreamCompletionFlags.EndStreamReceived); + var states = ApplyCompletionFlag(StreamCompletionFlags.EndStreamReceived); - RequestBodyPipe.Writer.Complete(); + try + { + RequestBodyPipe.Writer.Complete(); - _inputFlowControl.StopWindowUpdates(); + _inputFlowControl.StopWindowUpdates(); + } + finally + { + TryFireOnStreamCompleted(states); + } } public void OnDataRead(int bytesRead) @@ -354,12 +368,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 public void Abort(IOException abortReason) { - if (!TryApplyCompletionFlag(StreamCompletionFlags.Aborted)) - { - return; - } + var states = ApplyCompletionFlag(StreamCompletionFlags.Aborted); - AbortCore(abortReason); + try + { + if (states.OldState == states.NewState) + { + return; + } + + AbortCore(abortReason); + } + finally + { + TryFireOnStreamCompleted(states); + } } protected override void OnErrorAfterResponseStarted() @@ -377,17 +400,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 private void ResetAndAbort(ConnectionAbortedException abortReason, Http2ErrorCode error) { - if (!TryApplyCompletionFlag(StreamCompletionFlags.Aborted)) + var states = ApplyCompletionFlag(StreamCompletionFlags.Aborted); + + if (states.OldState == states.NewState) { return; } - Log.Http2StreamResetAbort(TraceIdentifier, error, abortReason); + try + { + Log.Http2StreamResetAbort(TraceIdentifier, error, abortReason); - // Don't block on IO. This never faults. - _ = _http2Output.WriteRstStreamAsync(error); + // Don't block on IO. This never faults. + _ = _http2Output.WriteRstStreamAsync(error); - AbortCore(abortReason); + AbortCore(abortReason); + } + finally + { + TryFireOnStreamCompleted(states); + } } private void AbortCore(Exception abortReason) @@ -415,19 +447,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 minimumSegmentSize: KestrelMemoryPool.MinimumSegmentSize )); - private bool TryApplyCompletionFlag(StreamCompletionFlags completionState) + private (StreamCompletionFlags OldState, StreamCompletionFlags NewState) ApplyCompletionFlag(StreamCompletionFlags completionState) { lock (_completionLock) { - var lastCompletionState = _completionState; + var oldCompletionState = _completionState; _completionState |= completionState; - if (ShouldStopTrackingStream(_completionState) && !ShouldStopTrackingStream(lastCompletionState)) - { - _context.StreamLifetimeHandler.OnStreamCompleted(StreamId); - } + return (oldCompletionState, _completionState); + } + } - return _completionState != lastCompletionState; + private void TryFireOnStreamCompleted((StreamCompletionFlags OldState, StreamCompletionFlags NewState) states) + { + if (!ShouldStopTrackingStream(states.OldState) && ShouldStopTrackingStream(states.NewState)) + { + _context.StreamLifetimeHandler.OnStreamCompleted(StreamId); } }