From 3177ba0aaeb775e8347ebdc103b12c37f5e6cf44 Mon Sep 17 00:00:00 2001 From: Cesar Blum Silveira Date: Mon, 10 Oct 2016 10:16:03 -0700 Subject: [PATCH] Wait for frame loop completion to dispose connection stream (#1156). --- .../Internal/Http/Connection.cs | 4 +- .../Internal/Http/Frame.cs | 4 +- .../HttpsTests.cs | 50 ++++++++++++++++++- .../FrameTests.cs | 42 +++++++++++++++- 4 files changed, 94 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs index 2ba693ef73..9c9e225a9f 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Connection.cs @@ -135,7 +135,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public Task StopAsync() { - _frame.Stop(); + _frame.StopAsync(); _frame.SocketInput.CompleteAwaiting(); return _socketClosedTcs.Task; @@ -156,7 +156,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { if (_filteredStreamAdapter != null) { - _readInputTask.ContinueWith((task, state) => + Task.WhenAll(_readInputTask, _frame.StopAsync()).ContinueWith((task, state) => { var connection = (Connection)state; connection._filterContext.Connection.Dispose(); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs index 9a13ab2cb7..f85940605c 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/Frame.cs @@ -365,7 +365,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http this, default(CancellationToken), TaskCreationOptions.DenyChildAttach, - TaskScheduler.Default); + TaskScheduler.Default).Unwrap(); } /// @@ -374,7 +374,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http /// Stop will be called on all active connections, and Task.WaitAll() will be called on every /// return value. /// - public Task Stop() + public Task StopAsync() { if (!_requestProcessingStopping) { diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs index 716c71e7b0..8f3a624406 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/HttpsTests.cs @@ -105,7 +105,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { try { - await httpContext.Response.WriteAsync($"hello, world\r\r", ct); + await httpContext.Response.WriteAsync($"hello, world", ct); await Task.Delay(1000, ct); } catch (TaskCanceledException) @@ -136,6 +136,54 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests Assert.False(loggerFactory.ErrorLogger.ObjectDisposedExceptionLogged); } + [Fact] + public async Task DoesNotThrowObjectDisposedExceptionFromWriteAsyncAfterConnectionIsAborted() + { + var tcs = new TaskCompletionSource(); + var x509Certificate2 = new X509Certificate2(@"TestResources/testCert.pfx", "testPassword"); + var loggerFactory = new HandshakeErrorLoggerFactory(); + var hostBuilder = new WebHostBuilder() + .UseKestrel(options => + { + options.UseHttps(@"TestResources/testCert.pfx", "testPassword"); + }) + .UseUrls("https://127.0.0.1:0/") + .UseLoggerFactory(loggerFactory) + .Configure(app => app.Run(async httpContext => + { + httpContext.Abort(); + try + { + await httpContext.Response.WriteAsync($"hello, world"); + tcs.SetResult(null); + } + catch (Exception ex) + { + tcs.SetException(ex); + } + })); + + using (var host = hostBuilder.Build()) + { + host.Start(); + + using (var socket = await HttpClientSlim.GetSocket(new Uri($"https://127.0.0.1:{host.GetPort()}/"))) + using (var stream = new NetworkStream(socket, ownsSocket: false)) + using (var sslStream = new SslStream(stream, true, (sender, certificate, chain, errors) => true)) + { + await sslStream.AuthenticateAsClientAsync("127.0.0.1", clientCertificates: null, + enabledSslProtocols: SslProtocols.Tls11 | SslProtocols.Tls12, + checkCertificateRevocation: false); + + var request = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\n\r\n"); + await sslStream.WriteAsync(request, 0, request.Length); + await sslStream.ReadAsync(new byte[32], 0, 32); + } + } + + await tcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10)); + } + private class HandshakeErrorLoggerFactory : ILoggerFactory { public HttpsConnectionFilterLogger FilterLogger { get; } = new HttpsConnectionFilterLogger(); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs index b411dc6b8f..be5c3e9348 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs @@ -6,6 +6,8 @@ using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting.Server; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel; using Microsoft.AspNetCore.Server.Kestrel.Internal; @@ -1290,7 +1292,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var expectedKeepAliveTimeout = (long)serviceContext.ServerOptions.Limits.KeepAliveTimeout.TotalMilliseconds; connectionControl.Verify(cc => cc.SetTimeout(expectedKeepAliveTimeout, TimeoutAction.CloseConnection)); - frame.Stop(); + frame.StopAsync(); socketInput.IncomingFin(); requestProcessingTask.Wait(); @@ -1464,5 +1466,43 @@ namespace Microsoft.AspNetCore.Server.KestrelTests // Assert Assert.Throws(() => frame.Flush()); } + + [Fact] + public async Task RequestProcessingTaskIsUnwrapped() + { + var trace = new KestrelTrace(new TestKestrelTrace()); + var ltp = new LoggingThreadPool(trace); + using (var pool = new MemoryPool()) + using (var socketInput = new SocketInput(pool, ltp)) + { + var serviceContext = new ServiceContext + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerOptions = new KestrelServerOptions(), + Log = trace + }; + var listenerContext = new ListenerContext(serviceContext) + { + ServerAddress = ServerAddress.FromUrl("http://localhost:5000") + }; + var connectionContext = new ConnectionContext(listenerContext) + { + ConnectionControl = Mock.Of(), + SocketInput = socketInput + }; + + var frame = new Frame(application: null, context: connectionContext); + frame.Start(); + + var data = Encoding.ASCII.GetBytes("GET / HTTP/1.1\r\n\r\n"); + socketInput.IncomingData(data, 0, data.Length); + + var requestProcessingTask = frame.StopAsync(); + Assert.IsNotType(typeof(Task), requestProcessingTask); + + await requestProcessingTask.TimeoutAfter(TimeSpan.FromSeconds(10)); + socketInput.IncomingFin(); + } + } } }