diff --git a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs index 6acbc0e378..b51d3ce6b8 100644 --- a/src/Kestrel.Core/Internal/Http/HttpProtocol.cs +++ b/src/Kestrel.Core/Internal/Http/HttpProtocol.cs @@ -1131,7 +1131,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http // chunked is applied to a response payload body, the sender MUST either // apply chunked as the final transfer coding or terminate the message // by closing the connection. - if (hasTransferEncoding && + if (hasTransferEncoding && HttpHeaders.GetFinalTransferCoding(responseHeaders.HeaderTransferEncoding) != TransferCoding.Chunked) { _keepAlive = false; diff --git a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs index 7ffa068d88..0ef973b98c 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2Connection.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2Connection.cs @@ -79,6 +79,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 private readonly object _stateLock = new object(); private int _highestOpenedStreamId; private Http2ConnectionState _state = Http2ConnectionState.Open; + private readonly TaskCompletionSource _streamsCompleted = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly ConcurrentDictionary _streams = new ConcurrentDictionary(); @@ -256,6 +257,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 _frameWriter.WriteGoAwayAsync(_highestOpenedStreamId, errorCode); UpdateState(Http2ConnectionState.Closed); } + + if (_streams.IsEmpty) + { + _streamsCompleted.TrySetResult(null); + } } // Ensure aborting each stream doesn't result in unnecessary WINDOW_UPDATE frames being sent. @@ -266,6 +272,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 stream.Abort(connectionError); } + await _streamsCompleted.Task; + _frameWriter.Complete(); } catch @@ -891,13 +899,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 { _streams.TryRemove(streamId, out _); - if (_state == Http2ConnectionState.Closing && _streams.IsEmpty) + if (_streams.IsEmpty) { - _frameWriter.WriteGoAwayAsync(_highestOpenedStreamId, Http2ErrorCode.NO_ERROR); - UpdateState(Http2ConnectionState.Closed); + if (_state == Http2ConnectionState.Closing) + { + _frameWriter.WriteGoAwayAsync(_highestOpenedStreamId, Http2ErrorCode.NO_ERROR); + UpdateState(Http2ConnectionState.Closed); - // Wake up request processing loop so the connection can complete if there are no pending requests - Input.CancelPendingRead(); + // Wake up request processing loop so the connection can complete if there are no pending requests + Input.CancelPendingRead(); + } + + + if (_state != Http2ConnectionState.Open) + { + // Complete the task waiting on all streams to finish + _streamsCompleted.TrySetResult(null); + } } } } diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketAwaitable.cs b/src/Kestrel.Transport.Sockets/Internal/SocketAwaitableEventArgs.cs similarity index 71% rename from src/Kestrel.Transport.Sockets/Internal/SocketAwaitable.cs rename to src/Kestrel.Transport.Sockets/Internal/SocketAwaitableEventArgs.cs index 6c4de75c45..d757316668 100644 --- a/src/Kestrel.Transport.Sockets/Internal/SocketAwaitable.cs +++ b/src/Kestrel.Transport.Sockets/Internal/SocketAwaitableEventArgs.cs @@ -11,22 +11,20 @@ using System.Threading.Tasks; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal { - public class SocketAwaitable : ICriticalNotifyCompletion + public class SocketAwaitableEventArgs : SocketAsyncEventArgs, ICriticalNotifyCompletion { private static readonly Action _callbackCompleted = () => { }; private readonly PipeScheduler _ioScheduler; private Action _callback; - private int _bytesTransferred; - private SocketError _error; - public SocketAwaitable(PipeScheduler ioScheduler) + public SocketAwaitableEventArgs(PipeScheduler ioScheduler) { _ioScheduler = ioScheduler; } - public SocketAwaitable GetAwaiter() => this; + public SocketAwaitableEventArgs GetAwaiter() => this; public bool IsCompleted => ReferenceEquals(_callback, _callbackCompleted); public int GetResult() @@ -35,12 +33,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal _callback = null; - if (_error != SocketError.Success) + if (SocketError != SocketError.Success) { - throw new SocketException((int)_error); + ThrowSocketException(SocketError); } - return _bytesTransferred; + return BytesTransferred; + + void ThrowSocketException(SocketError e) + { + throw new SocketException((int)e); + } } public void OnCompleted(Action continuation) @@ -57,10 +60,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal OnCompleted(continuation); } - public void Complete(int bytesTransferred, SocketError socketError) + public void Complete() + { + OnCompleted(this); + } + + protected override void OnCompleted(SocketAsyncEventArgs _) { - _error = socketError; - _bytesTransferred = bytesTransferred; var continuation = Interlocked.Exchange(ref _callback, _callbackCompleted); if (continuation != null) diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs b/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs index 223d5e9b70..d84b1dfa9d 100644 --- a/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs +++ b/src/Kestrel.Transport.Sockets/Internal/SocketReceiver.cs @@ -7,42 +7,29 @@ using System.Net.Sockets; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal { - public class SocketReceiver : IDisposable + public sealed class SocketReceiver : SocketSenderReceiverBase { - private readonly Socket _socket; - private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); - private readonly SocketAwaitable _awaitable; - - public SocketReceiver(Socket socket, PipeScheduler scheduler) + public SocketReceiver(Socket socket, PipeScheduler scheduler) : base(socket, scheduler) { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); } - public SocketAwaitable ReceiveAsync(Memory buffer) + public SocketAwaitableEventArgs ReceiveAsync(Memory buffer) { #if NETCOREAPP2_1 - _eventArgs.SetBuffer(buffer); + _awaitableEventArgs.SetBuffer(buffer); #elif NETSTANDARD2_0 var segment = buffer.GetArray(); - _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + _awaitableEventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); #else #error TFMs need to be updated #endif - if (!_socket.ReceiveAsync(_eventArgs)) + if (!_socket.ReceiveAsync(_awaitableEventArgs)) { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + _awaitableEventArgs.Complete(); } - return _awaitable; - } - - public void Dispose() - { - _eventArgs.Dispose(); + return _awaitableEventArgs; } } } diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs b/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs index 0684560d87..4dba6aedb4 100644 --- a/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs +++ b/src/Kestrel.Transport.Sockets/Internal/SocketSender.cs @@ -11,23 +11,15 @@ using System.Runtime.InteropServices; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal { - public class SocketSender : IDisposable + public sealed class SocketSender : SocketSenderReceiverBase { - private readonly Socket _socket; - private readonly SocketAsyncEventArgs _eventArgs = new SocketAsyncEventArgs(); - private readonly SocketAwaitable _awaitable; - private List> _bufferList; - public SocketSender(Socket socket, PipeScheduler scheduler) + public SocketSender(Socket socket, PipeScheduler scheduler) : base(socket, scheduler) { - _socket = socket; - _awaitable = new SocketAwaitable(scheduler); - _eventArgs.UserToken = _awaitable; - _eventArgs.Completed += (_, e) => ((SocketAwaitable)e.UserToken).Complete(e.BytesTransferred, e.SocketError); } - public SocketAwaitable SendAsync(ReadOnlySequence buffers) + public SocketAwaitableEventArgs SendAsync(ReadOnlySequence buffers) { if (buffers.IsSingleSegment) { @@ -35,49 +27,49 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal } #if NETCOREAPP2_1 - if (!_eventArgs.MemoryBuffer.Equals(Memory.Empty)) + if (!_awaitableEventArgs.MemoryBuffer.Equals(Memory.Empty)) #elif NETSTANDARD2_0 - if (_eventArgs.Buffer != null) + if (_awaitableEventArgs.Buffer != null) #else #error TFMs need to be updated #endif { - _eventArgs.SetBuffer(null, 0, 0); + _awaitableEventArgs.SetBuffer(null, 0, 0); } - _eventArgs.BufferList = GetBufferList(buffers); + _awaitableEventArgs.BufferList = GetBufferList(buffers); - if (!_socket.SendAsync(_eventArgs)) + if (!_socket.SendAsync(_awaitableEventArgs)) { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + _awaitableEventArgs.Complete(); } - return _awaitable; + return _awaitableEventArgs; } - private SocketAwaitable SendAsync(ReadOnlyMemory memory) + private SocketAwaitableEventArgs SendAsync(ReadOnlyMemory memory) { // The BufferList getter is much less expensive then the setter. - if (_eventArgs.BufferList != null) + if (_awaitableEventArgs.BufferList != null) { - _eventArgs.BufferList = null; + _awaitableEventArgs.BufferList = null; } #if NETCOREAPP2_1 - _eventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); + _awaitableEventArgs.SetBuffer(MemoryMarshal.AsMemory(memory)); #elif NETSTANDARD2_0 var segment = memory.GetArray(); - _eventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); + _awaitableEventArgs.SetBuffer(segment.Array, segment.Offset, segment.Count); #else #error TFMs need to be updated #endif - if (!_socket.SendAsync(_eventArgs)) + if (!_socket.SendAsync(_awaitableEventArgs)) { - _awaitable.Complete(_eventArgs.BytesTransferred, _eventArgs.SocketError); + _awaitableEventArgs.Complete(); } - return _awaitable; + return _awaitableEventArgs; } private List> GetBufferList(ReadOnlySequence buffer) @@ -102,10 +94,5 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal return _bufferList; } - - public void Dispose() - { - _eventArgs.Dispose(); - } } } diff --git a/src/Kestrel.Transport.Sockets/Internal/SocketSenderReceiverBase.cs b/src/Kestrel.Transport.Sockets/Internal/SocketSenderReceiverBase.cs new file mode 100644 index 0000000000..3258b31c58 --- /dev/null +++ b/src/Kestrel.Transport.Sockets/Internal/SocketSenderReceiverBase.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Net.Sockets; + +namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Sockets.Internal +{ + public abstract class SocketSenderReceiverBase : IDisposable + { + protected readonly Socket _socket; + protected readonly SocketAwaitableEventArgs _awaitableEventArgs; + + protected SocketSenderReceiverBase(Socket socket, PipeScheduler scheduler) + { + _socket = socket; + _awaitableEventArgs = new SocketAwaitableEventArgs(scheduler); + } + + public void Dispose() => _awaitableEventArgs.Dispose(); + } +} diff --git a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs index e183f10523..a41526e333 100644 --- a/test/Kestrel.Core.Tests/Http2ConnectionTests.cs +++ b/test/Kestrel.Core.Tests/Http2ConnectionTests.cs @@ -1065,12 +1065,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Fact] public async Task DATA_Received_NoStreamWindowSpace_ConnectionError() { - // I hate doing this, but it avoids exceptions from MemoryPool.Dipose() in debug mode. The problem is since - // the stream's ProcessRequestsAsync loop is never awaited by the connection, it's not really possible to - // observe when all the blocks are returned. This can be removed after we implement graceful shutdown. - Dispose(); - InitializeConnectionFields(new DiagnosticMemoryPool(KestrelMemoryPool.CreateSlabMemoryPool(), allowLateReturn: true)); - // _maxData should be 1/4th of the default initial window size + 1. Assert.Equal(Http2PeerSettings.DefaultInitialWindowSize + 1, (uint)_maxData.Length * 4); @@ -1093,12 +1087,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Fact] public async Task DATA_Received_NoConnectionWindowSpace_ConnectionError() { - // I hate doing this, but it avoids exceptions from MemoryPool.Dipose() in debug mode. The problem is since - // the stream's ProcessRequestsAsync loop is never awaited by the connection, it's not really possible to - // observe when all the blocks are returned. This can be removed after we implement graceful shutdown. - Dispose(); - InitializeConnectionFields(new DiagnosticMemoryPool(KestrelMemoryPool.CreateSlabMemoryPool(), allowLateReturn: true)); - // _maxData should be 1/4th of the default initial window size + 1. Assert.Equal(Http2PeerSettings.DefaultInitialWindowSize + 1, (uint)_maxData.Length * 4); @@ -3286,8 +3274,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests _pair.Application.Output.Complete(new ConnectionResetException(string.Empty)); - var result = await _pair.Application.Input.ReadAsync(); - Assert.True(result.IsCompleted); + await StopConnectionAsync(1, ignoreNonGoAwayFrames: false); Assert.Single(_logger.Messages, m => m.Exception is ConnectionResetException); } @@ -3337,6 +3324,54 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests VerifyGoAway(await ReceiveFrameAsync(), 0, Http2ErrorCode.NO_ERROR); } + [Fact] + public async Task StopProcessingNextRequestSendsGracefulGOAWAYAndWaitsForStreamsToComplete() + { + var task = Task.CompletedTask; + await InitializeConnectionAsync(context => task); + + // Send and receive an unblocked request + await StartStreamAsync(1, _browserRequestHeaders, endStream: true); + + await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + // Send a blocked request + var tcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + task = tcs.Task; + await StartStreamAsync(3, _browserRequestHeaders, endStream: false); + + // Close pipe + _pair.Application.Output.Complete(); + + // Assert connection closed + await _closedStateReached.Task.DefaultTimeout(); + VerifyGoAway(await ReceiveFrameAsync(), 3, Http2ErrorCode.NO_ERROR); + + // Assert connection shutdown is still blocked + // ProcessRequestsAsync completes the connection's Input pipe + var readTask = _pair.Application.Input.ReadAsync(); + _pair.Application.Input.CancelPendingRead(); + var result = await readTask; + Assert.False(result.IsCompleted); + + // Unblock the request and ProcessRequestsAsync + tcs.TrySetResult(null); + await _connectionTask; + + // Assert connection's Input pipe is completed + readTask = _pair.Application.Input.ReadAsync(); + _pair.Application.Input.CancelPendingRead(); + result = await readTask; + Assert.True(result.IsCompleted); + } + [Fact] public async Task StopProcessingNextRequestSendsGracefulGOAWAYThenFinalGOAWAYWhenAllStreamsComplete() { diff --git a/test/Kestrel.Core.Tests/Http2StreamTests.cs b/test/Kestrel.Core.Tests/Http2StreamTests.cs index 4e2a09c449..670aeb1c53 100644 --- a/test/Kestrel.Core.Tests/Http2StreamTests.cs +++ b/test/Kestrel.Core.Tests/Http2StreamTests.cs @@ -963,12 +963,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Fact] public async Task ContentLength_Received_SingleDataFrameUnderSize_Reset() { - // I hate doing this, but it avoids exceptions from MemoryPool.Dipose() in debug mode. The problem is since - // the stream's ProcessRequestsAsync loop is never awaited by the connection, it's not really possible to - // observe when all the blocks are returned. This can be removed after we implement graceful shutdown. - Dispose(); - InitializeConnectionFields(new DiagnosticMemoryPool(KestrelMemoryPool.CreateSlabMemoryPool(), allowLateReturn: true)); - var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), @@ -996,12 +990,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Fact] public async Task ContentLength_Received_MultipleDataFramesOverSize_Reset() { - // I hate doing this, but it avoids exceptions from MemoryPool.Dipose() in debug mode. The problem is since - // the stream's ProcessRequestsAsync loop is never awaited by the connection, it's not really possible to - // observe when all the blocks are returned. This can be removed after we implement graceful shutdown. - Dispose(); - InitializeConnectionFields(new DiagnosticMemoryPool(KestrelMemoryPool.CreateSlabMemoryPool(), allowLateReturn: true)); - var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), @@ -1032,12 +1020,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Fact] public async Task ContentLength_Received_MultipleDataFramesUnderSize_Reset() { - // I hate doing this, but it avoids exceptions from MemoryPool.Dipose() in debug mode. The problem is since - // the stream's ProcessRequestsAsync loop is never awaited by the connection, it's not really possible to - // observe when all the blocks are returned. This can be removed after we implement graceful shutdown. - Dispose(); - InitializeConnectionFields(new DiagnosticMemoryPool(KestrelMemoryPool.CreateSlabMemoryPool(), allowLateReturn: true)); - var headers = new[] { new KeyValuePair(HeaderNames.Method, "POST"), diff --git a/test/Kestrel.Transport.FunctionalTests/Http2/H2SpecTests.cs b/test/Kestrel.Transport.FunctionalTests/Http2/H2SpecTests.cs index b5a196c5f2..baca985366 100644 --- a/test/Kestrel.Transport.FunctionalTests/Http2/H2SpecTests.cs +++ b/test/Kestrel.Transport.FunctionalTests/Http2/H2SpecTests.cs @@ -27,9 +27,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.Http2 [MemberData(nameof(H2SpecTestCases))] public async Task RunIndividualTestCase(H2SpecTestCase testCase) { - var memoryPoolFactory = new DiagnosticMemoryPoolFactory(allowLateReturn: true); - - var hostBuilder = TransportSelector.GetWebHostBuilder(memoryPoolFactory.Create) + var hostBuilder = TransportSelector.GetWebHostBuilder() .UseKestrel(options => { options.Listen(IPAddress.Loopback, 0, listenOptions => @@ -66,7 +64,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.Http2 { skip = "https://github.com/aspnet/KestrelHttpServer/issues/2154"; } - + dataset.Add(new H2SpecTestCase() { Id = testcase.Item1, @@ -74,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests.Http2 Https = false, Skip = skip, }); - + dataset.Add(new H2SpecTestCase() { Id = testcase.Item1,