From a93e4be82fe4185819c468c1df08edc5ab6977a0 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Mon, 18 Sep 2017 12:47:38 -0700 Subject: [PATCH] Added Cancellation support (#897) * Added Cancellation support - Added ConnectionAbortedToken to the HubConnectionContext. This allows arbitrary code to access a handle that represents the connection lifetime without handling OnDisconnectedAsync on the hub itself. - Expose Abort on HubConnectionContext to allow server side methods to abort the connection. - Use the Abort to stop the main loop when unexpected invocation errors happen. - Use the connection aborted token as unsubscribe from the IObservable and to complete the IAsyncEnumerator for streaming results. --- .../HubConnectionContext.cs | 57 ++++ .../HubEndPoint.cs | 55 ++-- .../Internal/AsyncEnumeratorAdapters.cs | 19 +- .../HubEndpointTests.cs | 287 +++++++++++++++++- 4 files changed, 383 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 5441855336..c95660c199 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -1,8 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; +using System.Runtime.ExceptionServices; using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Features; @@ -15,13 +19,18 @@ namespace Microsoft.AspNetCore.SignalR { public class HubConnectionContext { + private static Action _abortedCallback = AbortConnection; + private readonly WritableChannel _output; private readonly ConnectionContext _connectionContext; + private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource(); + private readonly TaskCompletionSource _abortCompletedTcs = new TaskCompletionSource(); public HubConnectionContext(WritableChannel output, ConnectionContext connectionContext) { _output = output; _connectionContext = connectionContext; + ConnectionAbortedToken = _connectionAbortedTokenSource.Token; } private IHubFeature HubFeature => Features.Get(); @@ -29,6 +38,10 @@ namespace Microsoft.AspNetCore.SignalR // Used by the HubEndPoint only internal ReadableChannel Input => _connectionContext.Transport; + internal ExceptionDispatchInfo AbortException { get; private set; } + + public virtual CancellationToken ConnectionAbortedToken { get; } + public virtual string ConnectionId => _connectionContext.ConnectionId; public virtual ClaimsPrincipal User => Features.Get()?.User; @@ -40,5 +53,49 @@ namespace Microsoft.AspNetCore.SignalR public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; } public virtual WritableChannel Output => _output; + + public virtual void Abort() + { + // If we already triggered the token then noop, this isn't thread safe but it's good enough + // to avoid spawning a new task in the most common cases + if (_connectionAbortedTokenSource.IsCancellationRequested) + { + return; + } + + // We fire and forget since this can trigger user code to run + Task.Factory.StartNew(_abortedCallback, this); + } + + internal void Abort(Exception exception) + { + AbortException = ExceptionDispatchInfo.Capture(exception); + Abort(); + } + + // Used by the HubEndPoint only + internal Task AbortAsync() + { + Abort(); + return _abortCompletedTcs.Task; + } + + private static void AbortConnection(object state) + { + var connection = (HubConnectionContext)state; + try + { + connection._connectionAbortedTokenSource.Cancel(); + + // Communicate the fact that we're finished triggering abort callbacks + connection._abortCompletedTcs.TrySetResult(null); + } + catch (Exception ex) + { + // TODO: Should we log if the cancellation callback fails? This is more preventative to make sure + // we don't end up with an unobserved task + connection._abortCompletedTcs.TrySetException(ex); + } + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 29066f3c99..01e95ea142 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -207,6 +207,19 @@ namespace Microsoft.AspNetCore.SignalR { try { + // We wait on abort to complete, this is so that we can guarantee that all callbacks have fired + // before OnDisconnectedAsync + + try + { + // Ensure the connection is aborted before firing disconnect + await connection.AbortAsync(); + } + catch (Exception ex) + { + _logger.LogTrace(0, ex, "Abort callback failed"); + } + using (var scope = _serviceScopeFactory.CreateScope()) { var hubActivator = scope.ServiceProvider.GetRequiredService>(); @@ -231,16 +244,12 @@ namespace Microsoft.AspNetCore.SignalR private async Task DispatchMessagesAsync(HubConnectionContext connection) { - // We use these for error handling. Since we dispatch multiple hub invocations - // in parallel, we need a way to communicate failure back to the main processing loop. The - // cancellation token is used to stop reading from the channel, the tcs - // is used to get the exception so we can bubble it up the stack - var cts = new CancellationTokenSource(); - var completion = new TaskCompletionSource(); + // Since we dispatch multiple hub invocations in parallel, we need a way to communicate failure back to the main processing loop. + // This is done by aborting the connection. try { - while (await connection.Input.WaitToReadAsync(cts.Token)) + while (await connection.Input.WaitToReadAsync(connection.ConnectionAbortedToken)) { while (connection.Input.TryRead(out var buffer)) { @@ -258,7 +267,7 @@ namespace Microsoft.AspNetCore.SignalR // Don't wait on the result of execution, continue processing other // incoming messages on this connection. - var ignore = ProcessInvocation(connection, invocationMessage, cts, completion); + _ = ProcessInvocation(connection, invocationMessage); break; // Other kind of message we weren't expecting @@ -273,15 +282,12 @@ namespace Microsoft.AspNetCore.SignalR } catch (OperationCanceledException) { - // Await the task so the exception bubbles up to the caller - await completion.Task; + // If there's an exception, bubble it to the caller + connection.AbortException?.Throw(); } } - private async Task ProcessInvocation(HubConnectionContext connection, - InvocationMessage invocationMessage, - CancellationTokenSource dispatcherCancellation, - TaskCompletionSource dispatcherCompletion) + private async Task ProcessInvocation(HubConnectionContext connection, InvocationMessage invocationMessage) { try { @@ -291,11 +297,8 @@ namespace Microsoft.AspNetCore.SignalR } catch (Exception ex) { - // Set the exception on the task completion source - dispatcherCompletion.TrySetException(ex); - - // Cancel reading operation - dispatcherCancellation.Cancel(); + // Abort the entire connection if the invocation fails in an unexpected way + connection.Abort(ex); } } @@ -370,7 +373,7 @@ namespace Microsoft.AspNetCore.SignalR result = methodExecutor.Execute(hub, invocationMessage.Arguments); } - if (IsStreamed(methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator)) + if (IsStreamed(connection, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator)) { _logger.LogTrace("[{connectionId}/{invocationId}] Streaming result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName); await StreamResultsAsync(invocationMessage.InvocationId, connection, enumerator); @@ -426,9 +429,8 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection,IAsyncEnumerator enumerator) + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator) { - // TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481 try { while (await enumerator.MoveNextAsync()) @@ -445,7 +447,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private bool IsStreamed(ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator enumerator) + private bool IsStreamed(HubConnectionContext connection, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator enumerator) { if (result == null) { @@ -453,17 +455,20 @@ namespace Microsoft.AspNetCore.SignalR return false; } + + // TODO: We need to support cancelling the stream without a client disconnect as well. + var observableInterface = IsIObservable(resultType) ? resultType : resultType.GetInterfaces().FirstOrDefault(IsIObservable); if (observableInterface != null) { - enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface); + enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, connection.ConnectionAbortedToken); return true; } else if (IsChannel(resultType, out var payloadType)) { - enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType); + enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, connection.ConnectionAbortedToken); return true; } else diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs index 58411afedc..f387b22329 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs @@ -21,32 +21,33 @@ namespace Microsoft.AspNetCore.SignalR.Internal .GetRuntimeMethods() .Single(m => m.Name.Equals(nameof(FromObservable)) && m.IsGenericMethod); - private static readonly object[] _getAsyncEnumeratorArgs = new object[] { CancellationToken.None }; - - public static IAsyncEnumerator FromObservable(object observable, Type observableInterface) + public static IAsyncEnumerator FromObservable(object observable, Type observableInterface, CancellationToken cancellationToken) { // TODO: Cache expressions by observable.GetType()? return (IAsyncEnumerator)_fromObservableMethod .MakeGenericMethod(observableInterface.GetGenericArguments()) - .Invoke(null, new[] { observable }); + .Invoke(null, new[] { observable, cancellationToken }); } - public static IAsyncEnumerator FromObservable(IObservable observable) + public static IAsyncEnumerator FromObservable(IObservable observable, CancellationToken cancellationToken) { // TODO: Allow bounding and optimizations? var channel = Channel.CreateUnbounded(); - var subscription = observable.Subscribe(new ChannelObserver(channel.Out, CancellationToken.None)); + var subscription = observable.Subscribe(new ChannelObserver(channel.Out, cancellationToken)); - return channel.In.GetAsyncEnumerator(); + // Dispose the subscription when the token is cancelled + cancellationToken.Register(state => ((IDisposable)state).Dispose(), subscription); + + return channel.In.GetAsyncEnumerator(cancellationToken); } - public static IAsyncEnumerator FromChannel(object readableChannelOfT, Type payloadType) + public static IAsyncEnumerator FromChannel(object readableChannelOfT, Type payloadType, CancellationToken cancellationToken) { var enumerator = readableChannelOfT .GetType() .GetRuntimeMethod("GetAsyncEnumerator", new[] { typeof(CancellationToken) }) - .Invoke(readableChannelOfT, _getAsyncEnumeratorArgs); + .Invoke(readableChannelOfT, new object[] { cancellationToken }); if (payloadType.IsValueType) { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 59956f2a46..012d49b7e1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; @@ -42,6 +43,167 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task ConnectionAbortedTokenTriggers() + { + var state = new ConnectionLifetimeState(); + var serviceProvider = CreateServiceProvider(s => s.AddSingleton(state)); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + // kill the connection + client.Dispose(); + + await endPointTask.OrTimeout(); + + Assert.True(state.TokenCallbackTriggered); + Assert.False(state.TokenStateInConnected); + Assert.True(state.TokenStateInDisconnected); + } + } + + [Fact] + public async Task AbortFromHubMethodForcesClientDisconnect() + { + var serviceProvider = CreateServiceProvider(); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + await client.InvokeAsync(nameof(AbortHub.Kill)); + + await endPointTask.OrTimeout(); + } + } + + [Fact] + public async Task ObservableHubRemovesSubscriptionsWithInfiniteStreams() + { + var observable = new Observable(); + var serviceProvider = CreateServiceProvider(s => s.AddSingleton(observable)); + var endPoint = serviceProvider.GetService>(); + + var waitForSubscribe = new TaskCompletionSource(); + observable.OnSubscribe = o => + { + waitForSubscribe.TrySetResult(null); + }; + + var waitForDispose = new TaskCompletionSource(); + observable.OnDispose = o => + { + waitForDispose.TrySetResult(null); + }; + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + async Task Produce() + { + int i = 0; + while (true) + { + observable.OnNext(i++); + await Task.Delay(100); + } + } + + _ = Produce(); + + Assert.Empty(observable.Observers); + + var subscribeTask = client.StreamAsync(nameof(ObservableHub.Subscribe)); + + await waitForSubscribe.Task.OrTimeout(); + + Assert.Single(observable.Observers); + + client.Dispose(); + + + // We don't care if this throws, we just expect it to complete + try + { + await subscribeTask.OrTimeout(); + } + catch + { + + } + + await waitForDispose.Task.OrTimeout(); + + Assert.Empty(observable.Observers); + + await endPointTask.OrTimeout(); + } + } + + [Fact] + public async Task ObservableHubRemovesSubscriptions() + { + var observable = new Observable(); + var serviceProvider = CreateServiceProvider(s => s.AddSingleton(observable)); + var endPoint = serviceProvider.GetService>(); + + var waitForSubscribe = new TaskCompletionSource(); + observable.OnSubscribe = o => + { + waitForSubscribe.TrySetResult(null); + }; + + var waitForDispose = new TaskCompletionSource(); + observable.OnDispose = o => + { + waitForDispose.TrySetResult(null); + }; + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + async Task Subscribe() + { + var results = await client.StreamAsync(nameof(ObservableHub.Subscribe)); + + var items = results.OfType().ToList(); + + Assert.Single(items); + Assert.Equal(2, (long)items[0].Item); + } + + observable.OnNext(1); + + Assert.Empty(observable.Observers); + + var subscribeTask = Subscribe(); + + await waitForSubscribe.Task.OrTimeout(); + + Assert.Single(observable.Observers); + + observable.OnNext(2); + + observable.Complete(); + + await subscribeTask.OrTimeout(); + + client.Dispose(); + + await waitForDispose.Task.OrTimeout(); + + Assert.Empty(observable.Observers); + + await endPointTask.OrTimeout(); + } + } + [Fact] public async Task MissingNegotiateAndMessageSentFromHubConnectionCanBeDisposedCleanly() { @@ -534,7 +696,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var excludeSecondClientId = new HashSet(); excludeSecondClientId.Add(secondClient.Connection.ConnectionId); - var excludeThirdClientId = new HashSet(); + var excludeThirdClientId = new HashSet(); excludeThirdClientId.Add(thirdClient.Connection.ConnectionId); await firstClient.SendInvocationAsync("SendToAllExcept", "To second", excludeThirdClientId).OrTimeout(); @@ -1007,6 +1169,129 @@ namespace Microsoft.AspNetCore.SignalR.Tests Task Broadcast(string message); } + public class Observable : IObservable + { + public List> Observers = new List>(); + + public Action> OnSubscribe; + + public Action> OnDispose; + + public IDisposable Subscribe(IObserver observer) + { + lock (Observers) + { + Observers.Add(observer); + } + + OnSubscribe?.Invoke(observer); + + return new DisposableAction(() => + { + lock (Observers) + { + Observers.Remove(observer); + } + + OnDispose?.Invoke(observer); + }); + } + + public void OnNext(T value) + { + lock (Observers) + { + foreach (var observer in Observers) + { + observer.OnNext(value); + } + } + } + + public void Complete() + { + lock (Observers) + { + foreach (var observer in Observers) + { + observer.OnCompleted(); + } + } + } + + private class DisposableAction : IDisposable + { + private readonly Action _action; + public DisposableAction(Action action) + { + _action = action; + } + + public void Dispose() + { + _action(); + } + } + } + + public class ObservableHub : Hub + { + private readonly Observable _numbers; + + public ObservableHub(Observable numbers) + { + _numbers = numbers; + } + + public IObservable Subscribe() => _numbers; + } + + public class AbortHub : Hub + { + public void Kill() + { + Context.Connection.Abort(); + } + } + + public class ConnectionLifetimeState + { + public bool TokenCallbackTriggered { get; set; } + + public bool TokenStateInConnected { get; set; } + + public bool TokenStateInDisconnected { get; set; } + } + + public class ConnectionLifetimeHub : Hub + { + private ConnectionLifetimeState _state; + + public ConnectionLifetimeHub(ConnectionLifetimeState state) + { + _state = state; + } + + public override Task OnConnectedAsync() + { + _state.TokenStateInConnected = Context.Connection.ConnectionAbortedToken.IsCancellationRequested; + + Context.Connection.ConnectionAbortedToken.Register(() => + { + _state.TokenCallbackTriggered = true; + }); + + return base.OnConnectedAsync(); + } + + public override Task OnDisconnectedAsync(Exception exception) + { + _state.TokenStateInDisconnected = Context.Connection.ConnectionAbortedToken.IsCancellationRequested; + + return base.OnDisconnectedAsync(exception); + } + } + public class HubT : Hub { public override Task OnConnectedAsync()