diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs index 5ad6bbf78d..34b68d7242 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/RequestContext.cs @@ -69,16 +69,24 @@ namespace Microsoft.Net.Http.Server // We need to be able to dispose of the registrations each request to prevent leaks. if (!_disconnectToken.HasValue) { - var connectionDisconnectToken = Server.DisconnectListener.GetTokenForConnection(Request.UConnectionId); - - if (connectionDisconnectToken.CanBeCanceled) + if (_disposed || Response.BodyIsFinished) { - _requestAbortSource = CancellationTokenSource.CreateLinkedTokenSource(connectionDisconnectToken); - _disconnectToken = _requestAbortSource.Token; + // We cannot register for disconnect notifications after the response has finished sending. + _disconnectToken = CancellationToken.None; } else { - _disconnectToken = CancellationToken.None; + var connectionDisconnectToken = Server.DisconnectListener.GetTokenForConnection(Request.UConnectionId); + + if (connectionDisconnectToken.CanBeCanceled) + { + _requestAbortSource = CancellationTokenSource.CreateLinkedTokenSource(connectionDisconnectToken); + _disconnectToken = _requestAbortSource.Token; + } + else + { + _disconnectToken = CancellationToken.None; + } } } return _disconnectToken.Value; diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs index 3ff9af71cd..1c6dc9c2c9 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs @@ -115,6 +115,8 @@ namespace Microsoft.Net.Http.Server } } + internal bool BodyIsFinished => _nativeStream?.IsDisposed ?? _responseState >= ResponseState.Closed; + /// /// The authentication challenges that will be added to the response if the status code is 401. /// This must be a subset of the AuthenticationSchemes enabled on the server. diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs index cacb354851..b5bfca3915 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/ResponseStream.cs @@ -63,6 +63,8 @@ namespace Microsoft.Net.Http.Server internal bool ThrowWriteExceptions => RequestContext.Server.Settings.ThrowWriteExceptions; + internal bool IsDisposed => _disposed; + public override bool CanSeek { get diff --git a/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs b/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs index 332774d211..21e60cb3aa 100644 --- a/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs +++ b/test/Microsoft.Net.Http.Server.FunctionalTests/ServerTests.cs @@ -104,6 +104,68 @@ namespace Microsoft.Net.Http.Server } } + [Fact] + public async Task Server_TokenRegisteredAfterClientDisconnects_CallCanceled() + { + var interval = TimeSpan.FromSeconds(1); + var canceled = new ManualResetEvent(false); + + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + using (var client = new HttpClient()) + { + var responseTask = client.GetAsync(address); + + var context = await server.AcceptAsync(); + + client.CancelPendingRequests(); + await Assert.ThrowsAsync(() => responseTask); + + var ct = context.DisconnectToken; + Assert.True(ct.CanBeCanceled, "CanBeCanceled"); + ct.Register(() => canceled.Set()); + Assert.True(ct.WaitHandle.WaitOne(interval)); + Assert.True(ct.IsCancellationRequested, "IsCancellationRequested"); + + Assert.True(canceled.WaitOne(interval), "canceled"); + + context.Dispose(); + } + } + } + + [Fact] + public async Task Server_TokenRegisteredAfterResponseSent_Success() + { + var interval = TimeSpan.FromSeconds(1); + var canceled = new ManualResetEvent(false); + + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + using (var client = new HttpClient()) + { + var responseTask = client.GetAsync(address); + + var context = await server.AcceptAsync(); + context.Dispose(); + + var response = await responseTask; + response.EnsureSuccessStatusCode(); + Assert.Equal(string.Empty, await response.Content.ReadAsStringAsync()); + + var ct = context.DisconnectToken; + Assert.False(ct.CanBeCanceled, "CanBeCanceled"); + ct.Register(() => canceled.Set()); + Assert.False(ct.WaitHandle.WaitOne(interval)); + Assert.False(ct.IsCancellationRequested, "IsCancellationRequested"); + + Assert.False(canceled.WaitOne(interval), "canceled"); + } + } + } + [Fact] public async Task Server_Abort_CallCanceled() {