diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs index d48bc08944..f92f419bf8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs @@ -154,15 +154,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // Called on Libuv thread public virtual void OnSocketClosed() { - if (_filteredStreamAdapter != null) + _frame.FrameStartedTask.ContinueWith((task, state) => { - Task.WhenAll(_readInputTask, _frame.StopAsync()).ContinueWith((task, state) => + var connection = (Connection)state; + + if (_filteredStreamAdapter != null) { - var connection = (Connection)state; - connection._filterContext.Connection.Dispose(); - connection._filteredStreamAdapter.Dispose(); - }, this); - } + Task.WhenAll(_readInputTask, _frame.StopAsync()).ContinueWith((task2, state2) => + { + var connection2 = (Connection)state2; + connection2._filterContext.Connection.Dispose(); + connection2._filteredStreamAdapter.Dispose(); + }, connection); + } + }, this); SocketInput.Dispose(); _socketClosedTcs.TrySetResult(null); @@ -199,7 +204,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } _frame.PrepareRequest = _filterContext.PrepareRequest; - _frame.Start(); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index c4a0c50d85..eabcf12081 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -51,6 +51,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http protected Stack, object>> _onStarting; protected Stack, object>> _onCompleted; + private TaskCompletionSource _frameStartedTcs = new TaskCompletionSource(); private Task _requestProcessingTask; protected volatile bool _requestProcessingStopping; // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx protected int _requestAborted; @@ -205,6 +206,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public Stream DuplexStream { get; set; } + public Task FrameStartedTask => _frameStartedTcs.Task; + public CancellationToken RequestAborted { get @@ -366,6 +369,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http default(CancellationToken), TaskCreationOptions.DenyChildAttach, TaskScheduler.Default).Unwrap(); + _frameStartedTcs.SetResult(null); } /// diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs index 8f3a624406..f6badb31cf 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Net; using System.Net.Security; using System.Net.Sockets; using System.Security.Authentication; @@ -184,6 +185,35 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests await tcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); } + // Regression test for https://github.com/aspnet/KestrelHttpServer/pull/1197 + [Fact] + public void ConnectionFilterDoesNotLeakBlock() + { + var loggerFactory = new HandshakeErrorLoggerFactory(); + + var hostBuilder = new WebHostBuilder() + .UseKestrel(options => + { + options.UseHttps(@"TestResources/testCert.pfx", "testPassword"); + }) + .UseUrls("https://127.0.0.1:0/") + .UseLoggerFactory(loggerFactory) + .Configure(app => { }); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp)) + { + socket.Connect(new IPEndPoint(IPAddress.Loopback, host.GetPort())); + + // Close socket immediately + socket.LingerState = new LingerOption(true, 0); + } + } + } + private class HandshakeErrorLoggerFactory : ILoggerFactory { public HttpsConnectionFilterLogger FilterLogger { get; } = new HttpsConnectionFilterLogger();