Added Cancellation support (#897)

* Added Cancellation support
- Added ConnectionAbortedToken to the HubConnectionContext. This allows
arbitrary code to access a handle that represents the connection lifetime
without handling OnDisconnectedAsync on the hub itself.
- Expose Abort on HubConnectionContext to allow server side methods to
abort the connection.
- Use the Abort to stop the main loop when unexpected invocation errors happen.
- Use the connection aborted token as unsubscribe from the IObservable and to complete
the IAsyncEnumerator for streaming results.
This commit is contained in:
David Fowler 2017-09-18 12:47:38 -07:00 committed by GitHub
parent 20b07a0dff
commit a93e4be82f
4 changed files with 383 additions and 35 deletions

View File

@ -1,8 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Features;
@ -15,13 +19,18 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubConnectionContext
{
private static Action<object> _abortedCallback = AbortConnection;
private readonly WritableChannel<HubMessage> _output;
private readonly ConnectionContext _connectionContext;
private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource();
private readonly TaskCompletionSource<object> _abortCompletedTcs = new TaskCompletionSource<object>();
public HubConnectionContext(WritableChannel<HubMessage> output, ConnectionContext connectionContext)
{
_output = output;
_connectionContext = connectionContext;
ConnectionAbortedToken = _connectionAbortedTokenSource.Token;
}
private IHubFeature HubFeature => Features.Get<IHubFeature>();
@ -29,6 +38,10 @@ namespace Microsoft.AspNetCore.SignalR
// Used by the HubEndPoint only
internal ReadableChannel<byte[]> Input => _connectionContext.Transport;
internal ExceptionDispatchInfo AbortException { get; private set; }
public virtual CancellationToken ConnectionAbortedToken { get; }
public virtual string ConnectionId => _connectionContext.ConnectionId;
public virtual ClaimsPrincipal User => Features.Get<IConnectionUserFeature>()?.User;
@ -40,5 +53,49 @@ namespace Microsoft.AspNetCore.SignalR
public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
public virtual WritableChannel<HubMessage> Output => _output;
public virtual void Abort()
{
// If we already triggered the token then noop, this isn't thread safe but it's good enough
// to avoid spawning a new task in the most common cases
if (_connectionAbortedTokenSource.IsCancellationRequested)
{
return;
}
// We fire and forget since this can trigger user code to run
Task.Factory.StartNew(_abortedCallback, this);
}
internal void Abort(Exception exception)
{
AbortException = ExceptionDispatchInfo.Capture(exception);
Abort();
}
// Used by the HubEndPoint only
internal Task AbortAsync()
{
Abort();
return _abortCompletedTcs.Task;
}
private static void AbortConnection(object state)
{
var connection = (HubConnectionContext)state;
try
{
connection._connectionAbortedTokenSource.Cancel();
// Communicate the fact that we're finished triggering abort callbacks
connection._abortCompletedTcs.TrySetResult(null);
}
catch (Exception ex)
{
// TODO: Should we log if the cancellation callback fails? This is more preventative to make sure
// we don't end up with an unobserved task
connection._abortCompletedTcs.TrySetException(ex);
}
}
}
}

View File

@ -207,6 +207,19 @@ namespace Microsoft.AspNetCore.SignalR
{
try
{
// We wait on abort to complete, this is so that we can guarantee that all callbacks have fired
// before OnDisconnectedAsync
try
{
// Ensure the connection is aborted before firing disconnect
await connection.AbortAsync();
}
catch (Exception ex)
{
_logger.LogTrace(0, ex, "Abort callback failed");
}
using (var scope = _serviceScopeFactory.CreateScope())
{
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
@ -231,16 +244,12 @@ namespace Microsoft.AspNetCore.SignalR
private async Task DispatchMessagesAsync(HubConnectionContext connection)
{
// We use these for error handling. Since we dispatch multiple hub invocations
// in parallel, we need a way to communicate failure back to the main processing loop. The
// cancellation token is used to stop reading from the channel, the tcs
// is used to get the exception so we can bubble it up the stack
var cts = new CancellationTokenSource();
var completion = new TaskCompletionSource<object>();
// Since we dispatch multiple hub invocations in parallel, we need a way to communicate failure back to the main processing loop.
// This is done by aborting the connection.
try
{
while (await connection.Input.WaitToReadAsync(cts.Token))
while (await connection.Input.WaitToReadAsync(connection.ConnectionAbortedToken))
{
while (connection.Input.TryRead(out var buffer))
{
@ -258,7 +267,7 @@ namespace Microsoft.AspNetCore.SignalR
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
var ignore = ProcessInvocation(connection, invocationMessage, cts, completion);
_ = ProcessInvocation(connection, invocationMessage);
break;
// Other kind of message we weren't expecting
@ -273,15 +282,12 @@ namespace Microsoft.AspNetCore.SignalR
}
catch (OperationCanceledException)
{
// Await the task so the exception bubbles up to the caller
await completion.Task;
// If there's an exception, bubble it to the caller
connection.AbortException?.Throw();
}
}
private async Task ProcessInvocation(HubConnectionContext connection,
InvocationMessage invocationMessage,
CancellationTokenSource dispatcherCancellation,
TaskCompletionSource<object> dispatcherCompletion)
private async Task ProcessInvocation(HubConnectionContext connection, InvocationMessage invocationMessage)
{
try
{
@ -291,11 +297,8 @@ namespace Microsoft.AspNetCore.SignalR
}
catch (Exception ex)
{
// Set the exception on the task completion source
dispatcherCompletion.TrySetException(ex);
// Cancel reading operation
dispatcherCancellation.Cancel();
// Abort the entire connection if the invocation fails in an unexpected way
connection.Abort(ex);
}
}
@ -370,7 +373,7 @@ namespace Microsoft.AspNetCore.SignalR
result = methodExecutor.Execute(hub, invocationMessage.Arguments);
}
if (IsStreamed(methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator))
if (IsStreamed(connection, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator))
{
_logger.LogTrace("[{connectionId}/{invocationId}] Streaming result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await StreamResultsAsync(invocationMessage.InvocationId, connection, enumerator);
@ -426,9 +429,8 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection,IAsyncEnumerator<object> enumerator)
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator)
{
// TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481
try
{
while (await enumerator.MoveNextAsync())
@ -445,7 +447,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private bool IsStreamed(ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator<object> enumerator)
private bool IsStreamed(HubConnectionContext connection, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator<object> enumerator)
{
if (result == null)
{
@ -453,17 +455,20 @@ namespace Microsoft.AspNetCore.SignalR
return false;
}
// TODO: We need to support cancelling the stream without a client disconnect as well.
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
if (observableInterface != null)
{
enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface);
enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, connection.ConnectionAbortedToken);
return true;
}
else if (IsChannel(resultType, out var payloadType))
{
enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType);
enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, connection.ConnectionAbortedToken);
return true;
}
else

View File

@ -21,32 +21,33 @@ namespace Microsoft.AspNetCore.SignalR.Internal
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(FromObservable)) && m.IsGenericMethod);
private static readonly object[] _getAsyncEnumeratorArgs = new object[] { CancellationToken.None };
public static IAsyncEnumerator<object> FromObservable(object observable, Type observableInterface)
public static IAsyncEnumerator<object> FromObservable(object observable, Type observableInterface, CancellationToken cancellationToken)
{
// TODO: Cache expressions by observable.GetType()?
return (IAsyncEnumerator<object>)_fromObservableMethod
.MakeGenericMethod(observableInterface.GetGenericArguments())
.Invoke(null, new[] { observable });
.Invoke(null, new[] { observable, cancellationToken });
}
public static IAsyncEnumerator<object> FromObservable<T>(IObservable<T> observable)
public static IAsyncEnumerator<object> FromObservable<T>(IObservable<T> observable, CancellationToken cancellationToken)
{
// TODO: Allow bounding and optimizations?
var channel = Channel.CreateUnbounded<object>();
var subscription = observable.Subscribe(new ChannelObserver<T>(channel.Out, CancellationToken.None));
var subscription = observable.Subscribe(new ChannelObserver<T>(channel.Out, cancellationToken));
return channel.In.GetAsyncEnumerator();
// Dispose the subscription when the token is cancelled
cancellationToken.Register(state => ((IDisposable)state).Dispose(), subscription);
return channel.In.GetAsyncEnumerator(cancellationToken);
}
public static IAsyncEnumerator<object> FromChannel(object readableChannelOfT, Type payloadType)
public static IAsyncEnumerator<object> FromChannel(object readableChannelOfT, Type payloadType, CancellationToken cancellationToken)
{
var enumerator = readableChannelOfT
.GetType()
.GetRuntimeMethod("GetAsyncEnumerator", new[] { typeof(CancellationToken) })
.Invoke(readableChannelOfT, _getAsyncEnumeratorArgs);
.Invoke(readableChannelOfT, new object[] { cancellationToken });
if (payloadType.IsValueType)
{

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
@ -42,6 +43,167 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task ConnectionAbortedTokenTriggers()
{
var state = new ConnectionLifetimeState();
var serviceProvider = CreateServiceProvider(s => s.AddSingleton(state));
var endPoint = serviceProvider.GetService<HubEndPoint<ConnectionLifetimeHub>>();
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
// kill the connection
client.Dispose();
await endPointTask.OrTimeout();
Assert.True(state.TokenCallbackTriggered);
Assert.False(state.TokenStateInConnected);
Assert.True(state.TokenStateInDisconnected);
}
}
[Fact]
public async Task AbortFromHubMethodForcesClientDisconnect()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<AbortHub>>();
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
await client.InvokeAsync(nameof(AbortHub.Kill));
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task ObservableHubRemovesSubscriptionsWithInfiniteStreams()
{
var observable = new Observable<int>();
var serviceProvider = CreateServiceProvider(s => s.AddSingleton(observable));
var endPoint = serviceProvider.GetService<HubEndPoint<ObservableHub>>();
var waitForSubscribe = new TaskCompletionSource<object>();
observable.OnSubscribe = o =>
{
waitForSubscribe.TrySetResult(null);
};
var waitForDispose = new TaskCompletionSource<object>();
observable.OnDispose = o =>
{
waitForDispose.TrySetResult(null);
};
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
async Task Produce()
{
int i = 0;
while (true)
{
observable.OnNext(i++);
await Task.Delay(100);
}
}
_ = Produce();
Assert.Empty(observable.Observers);
var subscribeTask = client.StreamAsync(nameof(ObservableHub.Subscribe));
await waitForSubscribe.Task.OrTimeout();
Assert.Single(observable.Observers);
client.Dispose();
// We don't care if this throws, we just expect it to complete
try
{
await subscribeTask.OrTimeout();
}
catch
{
}
await waitForDispose.Task.OrTimeout();
Assert.Empty(observable.Observers);
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task ObservableHubRemovesSubscriptions()
{
var observable = new Observable<int>();
var serviceProvider = CreateServiceProvider(s => s.AddSingleton(observable));
var endPoint = serviceProvider.GetService<HubEndPoint<ObservableHub>>();
var waitForSubscribe = new TaskCompletionSource<object>();
observable.OnSubscribe = o =>
{
waitForSubscribe.TrySetResult(null);
};
var waitForDispose = new TaskCompletionSource<object>();
observable.OnDispose = o =>
{
waitForDispose.TrySetResult(null);
};
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
async Task Subscribe()
{
var results = await client.StreamAsync(nameof(ObservableHub.Subscribe));
var items = results.OfType<StreamItemMessage>().ToList();
Assert.Single(items);
Assert.Equal(2, (long)items[0].Item);
}
observable.OnNext(1);
Assert.Empty(observable.Observers);
var subscribeTask = Subscribe();
await waitForSubscribe.Task.OrTimeout();
Assert.Single(observable.Observers);
observable.OnNext(2);
observable.Complete();
await subscribeTask.OrTimeout();
client.Dispose();
await waitForDispose.Task.OrTimeout();
Assert.Empty(observable.Observers);
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task MissingNegotiateAndMessageSentFromHubConnectionCanBeDisposedCleanly()
{
@ -534,7 +696,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var excludeSecondClientId = new HashSet<string>();
excludeSecondClientId.Add(secondClient.Connection.ConnectionId);
var excludeThirdClientId = new HashSet<string>();
var excludeThirdClientId = new HashSet<string>();
excludeThirdClientId.Add(thirdClient.Connection.ConnectionId);
await firstClient.SendInvocationAsync("SendToAllExcept", "To second", excludeThirdClientId).OrTimeout();
@ -1007,6 +1169,129 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Task Broadcast(string message);
}
public class Observable<T> : IObservable<T>
{
public List<IObserver<T>> Observers = new List<IObserver<T>>();
public Action<IObserver<T>> OnSubscribe;
public Action<IObserver<T>> OnDispose;
public IDisposable Subscribe(IObserver<T> observer)
{
lock (Observers)
{
Observers.Add(observer);
}
OnSubscribe?.Invoke(observer);
return new DisposableAction(() =>
{
lock (Observers)
{
Observers.Remove(observer);
}
OnDispose?.Invoke(observer);
});
}
public void OnNext(T value)
{
lock (Observers)
{
foreach (var observer in Observers)
{
observer.OnNext(value);
}
}
}
public void Complete()
{
lock (Observers)
{
foreach (var observer in Observers)
{
observer.OnCompleted();
}
}
}
private class DisposableAction : IDisposable
{
private readonly Action _action;
public DisposableAction(Action action)
{
_action = action;
}
public void Dispose()
{
_action();
}
}
}
public class ObservableHub : Hub
{
private readonly Observable<int> _numbers;
public ObservableHub(Observable<int> numbers)
{
_numbers = numbers;
}
public IObservable<int> Subscribe() => _numbers;
}
public class AbortHub : Hub
{
public void Kill()
{
Context.Connection.Abort();
}
}
public class ConnectionLifetimeState
{
public bool TokenCallbackTriggered { get; set; }
public bool TokenStateInConnected { get; set; }
public bool TokenStateInDisconnected { get; set; }
}
public class ConnectionLifetimeHub : Hub
{
private ConnectionLifetimeState _state;
public ConnectionLifetimeHub(ConnectionLifetimeState state)
{
_state = state;
}
public override Task OnConnectedAsync()
{
_state.TokenStateInConnected = Context.Connection.ConnectionAbortedToken.IsCancellationRequested;
Context.Connection.ConnectionAbortedToken.Register(() =>
{
_state.TokenCallbackTriggered = true;
});
return base.OnConnectedAsync();
}
public override Task OnDisconnectedAsync(Exception exception)
{
_state.TokenStateInDisconnected = Context.Connection.ConnectionAbortedToken.IsCancellationRequested;
return base.OnDisconnectedAsync(exception);
}
}
public class HubT : Hub<Test>
{
public override Task OnConnectedAsync()