From cc15b1bb43d67b794b70cd3e17c33f3fc65a36a8 Mon Sep 17 00:00:00 2001 From: Brennan Date: Tue, 23 Jun 2020 22:14:12 -0700 Subject: [PATCH] [SignalR] Implement IConnectionLifetimeFeature (#20604) --- .../src/DefaultConnectionContext.cs | 3 +- .../src/Internal/HttpConnectionContext.cs | 23 ++- .../test/HttpConnectionDispatcherTests.cs | 169 +++++++++++++++++- .../server/Core/src/HubConnectionContext.cs | 19 +- .../server/Core/src/HubConnectionHandler.cs | 2 + .../HubConnectionHandlerTestUtils/Hubs.cs | 9 +- .../SignalR/test/HubConnectionHandlerTests.cs | 24 +++ 7 files changed, 239 insertions(+), 10 deletions(-) diff --git a/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs b/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs index ce3d115e2a..9ad2c62e8a 100644 --- a/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs +++ b/src/Servers/Connections.Abstractions/src/DefaultConnectionContext.cs @@ -26,7 +26,6 @@ namespace Microsoft.AspNetCore.Connections public DefaultConnectionContext() : this(Guid.NewGuid().ToString()) { - ConnectionClosed = _connectionClosedTokenSource.Token; } /// @@ -45,6 +44,8 @@ namespace Microsoft.AspNetCore.Connections Features.Set(this); Features.Set(this); Features.Set(this); + + ConnectionClosed = _connectionClosedTokenSource.Token; } public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index 783bcf07bf..ec7b124f8f 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -29,7 +29,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal ITransferFormatFeature, IHttpContextFeature, IHttpTransportFeature, - IConnectionInherentKeepAliveFeature + IConnectionInherentKeepAliveFeature, + IConnectionLifetimeFeature { private static long _tenSeconds = TimeSpan.FromSeconds(10).Ticks; @@ -41,6 +42,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private PipeWriterStream _applicationStream; private IDuplexPipe _application; private IDictionary _items; + private CancellationTokenSource _connectionClosedTokenSource; private CancellationTokenSource _sendCts; private bool _activeSend; @@ -82,6 +84,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Features.Set(this); Features.Set(this); Features.Set(this); + Features.Set(this); + + _connectionClosedTokenSource = new CancellationTokenSource(); + ConnectionClosed = _connectionClosedTokenSource.Token; } public CancellationTokenSource Cancellation { get; set; } @@ -170,6 +176,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal public HttpContext HttpContext { get; set; } + public override CancellationToken ConnectionClosed { get; set; } + + public override void Abort() + { + ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource); + + HttpContext?.Abort(); + } + public void OnHeartbeat(Action action, object state) { lock (_heartbeatLock) @@ -305,6 +320,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // Now complete the application Application?.Output.Complete(); Application?.Input.Complete(); + + // Trigger ConnectionClosed + ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource); } } else @@ -313,6 +331,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Application?.Output.Complete(transportTask.Exception?.InnerException); Application?.Input.Complete(); + // Trigger ConnectionClosed + ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource); + try { // A poorly written application *could* in theory get stuck forever and it'll show up as a memory leak diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index 22e5d5d6a2..bed8bba4db 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -961,7 +961,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } [Fact] - public async Task SynchronusExceptionEndsConnection() + public async Task SynchronousExceptionEndsConnection() { bool ExpectedErrors(WriteContext writeContext) { @@ -2269,6 +2269,173 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Fact] + public async Task LongPollingConnectionClosingTriggersConnectionClosedToken() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var pipeOptions = new PipeOptions(pauseWriterThreshold: 2, resumeWriterThreshold: 1); + var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var pollTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(pollTask.IsCompleted); + + // Now send the second poll + pollTask = dispatcher.ExecuteAsync(context, options, app); + + // Issue the delete request and make sure the poll completes + var deleteContext = new DefaultHttpContext(); + deleteContext.Request.Path = "/foo"; + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.Method = "DELETE"; + + Assert.False(pollTask.IsCompleted); + + await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); + + await pollTask.OrTimeout(); + + // Verify that transport shuts down + await connection.TransportTask.OrTimeout(); + + // Verify the response from the DELETE request + Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); + Assert.Equal("text/plain", deleteContext.Response.ContentType); + Assert.Equal(HttpConnectionStatus.Disposed, connection.Status); + + // Verify the connection not removed because application is hanging + Assert.True(manager.TryGetConnection(connection.ConnectionId, out _)); + + Assert.True(connection.ConnectionClosed.IsCancellationRequested); + } + } + + [Fact] + public async Task SSEConnectionClosingTriggersConnectionClosedToken() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = MakeRequest("/foo", connection); + SetTransport(context, connection.TransportType); + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + _ = dispatcher.ExecuteAsync(context, options, app); + + // Close the SSE connection + connection.Transport.Output.Complete(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection.ConnectionClosed.Register(() => tcs.SetResult(null)); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task WebSocketConnectionClosingTriggersConnectionClosedToken() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.WebSockets; + + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + + var context = MakeRequest("/foo", connection); + SetTransport(context, HttpTransportType.WebSockets); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1); + + _ = dispatcher.ExecuteAsync(context, options, app); + + var websocket = (TestWebSocketConnectionFeature)context.Features.Get(); + await websocket.Accepted.OrTimeout(); + await websocket.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken: default).OrTimeout(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection.ConnectionClosed.Register(() => tcs.SetResult(null)); + await tcs.Task.OrTimeout(); + } + } + + public class CustomHttpRequestLifetimeFeature : IHttpRequestLifetimeFeature + { + public CancellationToken RequestAborted { get; set; } + + private CancellationTokenSource _cts; + public CustomHttpRequestLifetimeFeature() + { + _cts = new CancellationTokenSource(); + RequestAborted = _cts.Token; + } + + public void Abort() + { + _cts.Cancel(); + } + } + + [Fact] + public async Task AbortingConnectionAbortsHttpContextAndTriggersConnectionClosedToken() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = MakeRequest("/foo", connection); + var lifetimeFeature = new CustomHttpRequestLifetimeFeature(); + context.Features.Set(lifetimeFeature); + SetTransport(context, connection.TransportType); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + _ = dispatcher.ExecuteAsync(context, options, app); + + connection.Abort(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection.ConnectionClosed.Register(() => tcs.SetResult(null)); + await tcs.Task.OrTimeout(); + + tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + lifetimeFeature.RequestAborted.Register(() => tcs.SetResult(null)); + await tcs.Task.OrTimeout(); + } + } + private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory) { var manager = CreateConnectionManager(loggerFactory); diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 8611e54bae..01d07d2e8a 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -37,6 +37,7 @@ namespace Microsoft.AspNetCore.SignalR private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1); private readonly object _receiveMessageTimeoutLock = new object(); private readonly ISystemClock _systemClock; + private readonly CancellationTokenRegistration _closedRegistration; private StreamTracker _streamTracker; private long _lastSendTimeStamp; @@ -66,6 +67,7 @@ namespace Microsoft.AspNetCore.SignalR _connectionContext = connectionContext; _logger = loggerFactory.CreateLogger(); ConnectionAborted = _connectionAbortedTokenSource.Token; + _closedRegistration = connectionContext.ConnectionClosed.Register((state) => ((HubConnectionContext)state).Abort(), this); HubCallerContext = new DefaultHubCallerContext(this); @@ -624,12 +626,6 @@ namespace Microsoft.AspNetCore.SignalR finally { _ = InnerAbortConnection(connection); - - // Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist - if (connection._streamTracker != null) - { - connection._streamTracker.CompleteAll(new OperationCanceledException("The underlying connection was closed.")); - } } static async Task InnerAbortConnection(HubConnectionContext connection) @@ -670,6 +666,17 @@ namespace Microsoft.AspNetCore.SignalR } } + internal void Cleanup() + { + _closedRegistration.Dispose(); + + // Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist + if (_streamTracker != null) + { + _streamTracker.CompleteAll(new OperationCanceledException("The underlying connection was closed.")); + } + } + private static class Log { // Category: HubConnectionContext diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 7aea15a5d9..403a03e8ae 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -139,6 +139,8 @@ namespace Microsoft.AspNetCore.SignalR } finally { + connectionContext.Cleanup(); + Log.ConnectedEnding(_logger); await _lifetimeManager.OnDisconnectedAsync(connectionContext); } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index c9e790c11a..3031ee9ffd 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -221,7 +221,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - public async Task StreamingSum(ChannelReader source) { var total = 0; @@ -322,6 +321,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests tcs.TrySetResult(42); } } + + public async Task BlockingMethod() + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + Context.ConnectionAborted.Register(state => ((TaskCompletionSource)state).SetResult(null), tcs); + + await tcs.Task; + } } public abstract class TestHub : Hub diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 962d32bc7c..ceab4322be 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -948,6 +948,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.True(hasErrorLog); } + [Fact] + public async Task HubMethodListeningToConnectionAbortedClosesOnConnectionContextAbort() + { + using (StartVerifiableLog()) + { + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(MethodHub), loggerFactory: LoggerFactory); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + var invokeTask = client.InvokeAsync(nameof(MethodHub.BlockingMethod)); + + client.Connection.Abort(); + + // If this completes then the server has completed the connection + await connectionHandlerTask.OrTimeout(); + + // Nothing written to connection because it was closed + Assert.False(invokeTask.IsCompleted); + } + } + } + [Fact] public async Task DetailedExceptionEvenWhenNotExplicitlySet() {