From 8e23200fd2d3f6e22e74f2a1a6f8ba0f682c6eff Mon Sep 17 00:00:00 2001 From: Chris Ross Date: Mon, 28 Apr 2014 14:54:33 -0700 Subject: [PATCH] #3 - Implement IHttpRequestLifetime --- .../FeatureContext.cs | 13 ++++++- .../RequestProcessing/RequestContext.cs | 38 ++++++++++++------- .../ServerTests.cs | 36 ++++++++++++++++-- 3 files changed, 69 insertions(+), 18 deletions(-) diff --git a/src/Microsoft.AspNet.Server.WebListener/FeatureContext.cs b/src/Microsoft.AspNet.Server.WebListener/FeatureContext.cs index 727fff3423..361a26face 100644 --- a/src/Microsoft.AspNet.Server.WebListener/FeatureContext.cs +++ b/src/Microsoft.AspNet.Server.WebListener/FeatureContext.cs @@ -11,7 +11,7 @@ using Microsoft.Net.Server; namespace Microsoft.AspNet.Server.WebListener { - internal class FeatureContext : IHttpRequestInformation, IHttpConnection, IHttpResponseInformation, IHttpSendFile, IHttpTransportLayerSecurity + internal class FeatureContext : IHttpRequestInformation, IHttpConnection, IHttpResponseInformation, IHttpSendFile, IHttpTransportLayerSecurity, IHttpRequestLifetime { private RequestContext _requestContext; private FeatureCollection _features; @@ -66,6 +66,7 @@ namespace Microsoft.AspNet.Server.WebListener } _features.Add(typeof(IHttpResponseInformation), this); _features.Add(typeof(IHttpSendFile), this); + _features.Add(typeof(IHttpRequestLifetime), this); // TODO: // _environment.CallCancelled = _cts.Token; @@ -334,5 +335,15 @@ namespace Microsoft.AspNet.Server.WebListener { return Response.SendFileAsync(path, offset, length, cancellation); } + + public CancellationToken OnRequestAborted + { + get { return _requestContext.DisconnectToken; } + } + + public void Abort() + { + _requestContext.Abort(); + } } } diff --git a/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs b/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs index 0546cf349a..16ac300615 100644 --- a/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs +++ b/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs @@ -25,7 +25,7 @@ namespace Microsoft.Net.Server private NativeRequestContext _memoryBlob; private OpaqueFunc _opaqueCallback; private bool _disposed; - private CancellationTokenRegistration? _disconnectRegistration; + private CancellationTokenSource _requestAbortSource; private CancellationToken? _disconnectToken; internal RequestContext(WebListener httpListener, NativeRequestContext memoryBlob) @@ -63,24 +63,26 @@ namespace Microsoft.Net.Server { get { + // Create a new token per request, but link it to a single connection token. + // We need to be able to dispose of the registrations each request to prevent leaks. if (!_disconnectToken.HasValue) { - _disconnectToken = _server.RegisterForDisconnectNotification(this); - if (_disconnectToken.Value.CanBeCanceled) + var connectionDisconnectToken = _server.RegisterForDisconnectNotification(this); + + if (connectionDisconnectToken.CanBeCanceled) { - _disconnectRegistration = _disconnectToken.Value.Register(Cancel, this); + _requestAbortSource = CancellationTokenSource.CreateLinkedTokenSource(connectionDisconnectToken); + _disconnectToken = _requestAbortSource.Token; + } + else + { + _disconnectToken = CancellationToken.None; } } return _disconnectToken.Value; } } - private static void Cancel(object obj) - { - RequestContext context = (RequestContext)obj; - context.Abort(); - } - internal WebListener Server { get @@ -138,9 +140,9 @@ namespace Microsoft.Net.Server // TODO: Verbose log try { - if (_disconnectRegistration.HasValue) + if (_requestAbortSource != null) { - _disconnectRegistration.Value.Dispose(); + _requestAbortSource.Dispose(); } _response.Dispose(); } @@ -155,9 +157,17 @@ namespace Microsoft.Net.Server // May be called from Dispose() code path, don't check _disposed. // TODO: Verbose log _disposed = true; - if (_disconnectRegistration.HasValue) + if (_requestAbortSource != null) { - _disconnectRegistration.Value.Dispose(); + try + { + _requestAbortSource.Cancel(); + } + catch (Exception ex) + { + LogHelper.LogException(Logger, "Abort", ex); + } + _requestAbortSource.Dispose(); } ForceCancelRequest(RequestQueueHandle, _request.RequestId); _request.Dispose(); diff --git a/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs b/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs index d6233e5c6d..9c7511aba3 100644 --- a/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs +++ b/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ServerTests.cs @@ -147,7 +147,7 @@ namespace Microsoft.AspNet.Server.WebListener Assert.True(Task.WaitAll(requestTasks.ToArray(), TimeSpan.FromSeconds(2)), "Timed out"); } } - /* TODO: + [Fact] public async Task Server_ClientDisconnects_CallCancelled() { @@ -158,7 +158,8 @@ namespace Microsoft.AspNet.Server.WebListener using (Utilities.CreateHttpServer(env => { - CancellationToken ct = env.Get("owin.CallCancelled"); + var httpContext = new DefaultHttpContext((IFeatureCollection)env); + CancellationToken ct = httpContext.OnRequestAborted; Assert.True(ct.CanBeCanceled, "CanBeCanceled"); Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); ct.Register(() => canceled.Set()); @@ -180,7 +181,36 @@ namespace Microsoft.AspNet.Server.WebListener Assert.True(canceled.WaitOne(interval), "canceled"); } } - */ + + [Fact] + public async Task Server_Abort_CallCancelled() + { + TimeSpan interval = TimeSpan.FromSeconds(100); + ManualResetEvent received = new ManualResetEvent(false); + ManualResetEvent aborted = new ManualResetEvent(false); + ManualResetEvent canceled = new ManualResetEvent(false); + + using (Utilities.CreateHttpServer(env => + { + var httpContext = new DefaultHttpContext((IFeatureCollection)env); + CancellationToken ct = httpContext.OnRequestAborted; + Assert.True(ct.CanBeCanceled, "CanBeCanceled"); + Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); + ct.Register(() => canceled.Set()); + received.Set(); + httpContext.Abort(); + Assert.True(canceled.WaitOne(interval), "Aborted"); + Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + return Task.FromResult(0); + })) + { + using (Socket socket = await SendHungRequestAsync("GET", Address)) + { + Assert.True(received.WaitOne(interval), "Receive Timeout"); + Assert.Throws(() => socket.Receive(new byte[10])); + } + } + } [Fact] public async Task Server_SetQueueLimit_Success()