Allow CancellationToken in streaming hub methods (#2818)

This commit is contained in:
BrennanConroy 2018-09-19 15:21:07 -07:00 committed by GitHub
parent 4b378692a4
commit 6ba5e87b45
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 226 additions and 16 deletions

View File

@ -188,11 +188,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
InitializeHub(hub, connection);
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
CancellationTokenSource cts = null;
var arguments = hubMethodInvocationMessage.Arguments;
if (descriptor.HasSyntheticArguments)
{
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
arguments = new object[descriptor.OriginalParameterTypes.Count];
var hubInvocationArgumentPointer = 0;
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer])
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
hubInvocationArgumentPointer++;
}
else
{
// This is the only synthetic argument type we currently support
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
{
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else
{
// This should never happen
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'.");
}
}
}
}
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
if (isStreamedInvocation)
{
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
@ -204,7 +238,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
disposeScope = false;
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
// Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts);
}
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@ -375,29 +409,24 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
{
if (result != null)
{
if (hubMethodDescriptor.IsChannel)
{
streamCts = CreateCancellation();
if (streamCts == null)
{
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
}
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
return true;
}
}
streamCts = null;
enumerator = null;
return false;
CancellationTokenSource CreateCancellation()
{
var userCts = new CancellationTokenSource();
connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);
return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
}
}
private void DiscoverHubMethods()

View File

@ -22,8 +22,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
Policies = policies.ToArray();
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
? MethodExecutor.AsyncResultType
@ -34,6 +32,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal
IsChannel = true;
StreamReturnType = channelItemType;
}
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
{
// Only streams can take CancellationTokens currently
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
{
HasSyntheticArguments = true;
return false;
}
return true;
}).Select(p => p.ParameterType).ToArray();
if (HasSyntheticArguments)
{
OriginalParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
}
Policies = policies.ToArray();
}
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
@ -42,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public IReadOnlyList<Type> ParameterTypes { get; }
public IReadOnlyList<Type> OriginalParameterTypes { get; }
public Type NonAsyncReturnType { get; }
public bool IsChannel { get; }
@ -52,6 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public IList<IAuthorizeData> Policies { get; }
public bool HasSyntheticArguments { get; private set; }
private static bool IsChannelType(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));

View File

@ -0,0 +1,21 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public static class CancellationTokenExtensions
{
public static Task WaitForCancellationAsync(this CancellationToken token)
{
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
token.Register((t) =>
{
((TaskCompletionSource<object>)t).SetResult(null);
}, tcs);
return tcs.Task;
}
}
}

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
@ -165,6 +166,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
}
public void InvalidArgument(CancellationToken token)
{
}
private class SelfRef
{
public SelfRef()
@ -547,6 +552,51 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Channel.CreateUnbounded<string>().Reader;
}
public ChannelReader<int> CancelableStream(CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});
return channel.Reader;
}
public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});
return channel.Reader;
}
public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});
return channel.Reader;
}
public int SimpleMethod()
{
return 21;

View File

@ -2381,6 +2381,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Theory]
[InlineData(nameof(LongRunningHub.CancelableStream))]
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(methodName, args).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();
// Cancel the stream which should trigger the CancellationToken in the hub method
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();
var result = await client.ReadAsync().OrTimeout();
var simpleCompletion = Assert.IsType<CompletionMessage>(result);
Assert.Null(simpleCompletion.Result);
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.OrTimeout();
// Shut down
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
[Fact]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted()
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();
// Shut down the client which should trigger the CancellationToken in the hub method
client.Dispose();
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.OrTimeout();
await connectionHandlerTask.OrTimeout();
}
}
[Fact]
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).OrTimeout();
var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
{
public int ReleaseCount;