Allow CancellationToken in streaming hub methods (#2818)
This commit is contained in:
parent
4b378692a4
commit
6ba5e87b45
|
|
@ -188,11 +188,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
{
|
{
|
||||||
InitializeHub(hub, connection);
|
InitializeHub(hub, connection);
|
||||||
|
|
||||||
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
|
CancellationTokenSource cts = null;
|
||||||
|
var arguments = hubMethodInvocationMessage.Arguments;
|
||||||
|
if (descriptor.HasSyntheticArguments)
|
||||||
|
{
|
||||||
|
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
|
||||||
|
arguments = new object[descriptor.OriginalParameterTypes.Count];
|
||||||
|
|
||||||
|
var hubInvocationArgumentPointer = 0;
|
||||||
|
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
|
||||||
|
{
|
||||||
|
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
|
||||||
|
hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer])
|
||||||
|
{
|
||||||
|
// The types match so it isn't a synthetic argument, just copy it into the arguments array
|
||||||
|
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
|
||||||
|
hubInvocationArgumentPointer++;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// This is the only synthetic argument type we currently support
|
||||||
|
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
|
||||||
|
{
|
||||||
|
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
|
||||||
|
arguments[parameterPointer] = cts.Token;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// This should never happen
|
||||||
|
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'.");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
|
||||||
|
|
||||||
if (isStreamedInvocation)
|
if (isStreamedInvocation)
|
||||||
{
|
{
|
||||||
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
|
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
|
||||||
{
|
{
|
||||||
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
|
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
|
||||||
|
|
||||||
|
|
@ -204,7 +238,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
disposeScope = false;
|
disposeScope = false;
|
||||||
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
||||||
// Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run
|
// Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run
|
||||||
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
|
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts);
|
||||||
}
|
}
|
||||||
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
|
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
|
||||||
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
|
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
|
||||||
|
|
@ -375,29 +409,24 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
|
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
|
||||||
{
|
{
|
||||||
if (result != null)
|
if (result != null)
|
||||||
{
|
{
|
||||||
if (hubMethodDescriptor.IsChannel)
|
if (hubMethodDescriptor.IsChannel)
|
||||||
{
|
{
|
||||||
streamCts = CreateCancellation();
|
if (streamCts == null)
|
||||||
|
{
|
||||||
|
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
|
||||||
|
}
|
||||||
|
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
|
||||||
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
|
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
streamCts = null;
|
|
||||||
enumerator = null;
|
enumerator = null;
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
CancellationTokenSource CreateCancellation()
|
|
||||||
{
|
|
||||||
var userCts = new CancellationTokenSource();
|
|
||||||
connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);
|
|
||||||
|
|
||||||
return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void DiscoverHubMethods()
|
private void DiscoverHubMethods()
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
||||||
{
|
{
|
||||||
MethodExecutor = methodExecutor;
|
MethodExecutor = methodExecutor;
|
||||||
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
|
||||||
Policies = policies.ToArray();
|
|
||||||
|
|
||||||
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
|
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
|
||||||
? MethodExecutor.AsyncResultType
|
? MethodExecutor.AsyncResultType
|
||||||
|
|
@ -34,6 +32,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
IsChannel = true;
|
IsChannel = true;
|
||||||
StreamReturnType = channelItemType;
|
StreamReturnType = channelItemType;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
|
||||||
|
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
|
||||||
|
{
|
||||||
|
// Only streams can take CancellationTokens currently
|
||||||
|
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
|
||||||
|
{
|
||||||
|
HasSyntheticArguments = true;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}).Select(p => p.ParameterType).ToArray();
|
||||||
|
|
||||||
|
if (HasSyntheticArguments)
|
||||||
|
{
|
||||||
|
OriginalParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
||||||
|
}
|
||||||
|
|
||||||
|
Policies = policies.ToArray();
|
||||||
}
|
}
|
||||||
|
|
||||||
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
|
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
|
||||||
|
|
@ -42,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
|
|
||||||
public IReadOnlyList<Type> ParameterTypes { get; }
|
public IReadOnlyList<Type> ParameterTypes { get; }
|
||||||
|
|
||||||
|
public IReadOnlyList<Type> OriginalParameterTypes { get; }
|
||||||
|
|
||||||
public Type NonAsyncReturnType { get; }
|
public Type NonAsyncReturnType { get; }
|
||||||
|
|
||||||
public bool IsChannel { get; }
|
public bool IsChannel { get; }
|
||||||
|
|
@ -52,6 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
|
|
||||||
public IList<IAuthorizeData> Policies { get; }
|
public IList<IAuthorizeData> Policies { get; }
|
||||||
|
|
||||||
|
public bool HasSyntheticArguments { get; private set; }
|
||||||
|
|
||||||
private static bool IsChannelType(Type type, out Type payloadType)
|
private static bool IsChannelType(Type type, out Type payloadType)
|
||||||
{
|
{
|
||||||
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
|
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System.Threading;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
{
|
||||||
|
public static class CancellationTokenExtensions
|
||||||
|
{
|
||||||
|
public static Task WaitForCancellationAsync(this CancellationToken token)
|
||||||
|
{
|
||||||
|
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
|
token.Register((t) =>
|
||||||
|
{
|
||||||
|
((TaskCompletionSource<object>)t).SetResult(null);
|
||||||
|
}, tcs);
|
||||||
|
return tcs.Task;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
|
using System.Threading;
|
||||||
using System.Threading.Channels;
|
using System.Threading.Channels;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
using Microsoft.AspNetCore.Authorization;
|
using Microsoft.AspNetCore.Authorization;
|
||||||
|
|
@ -165,6 +166,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
|
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void InvalidArgument(CancellationToken token)
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
private class SelfRef
|
private class SelfRef
|
||||||
{
|
{
|
||||||
public SelfRef()
|
public SelfRef()
|
||||||
|
|
@ -547,6 +552,51 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
return Channel.CreateUnbounded<string>().Reader;
|
return Channel.CreateUnbounded<string>().Reader;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public ChannelReader<int> CancelableStream(CancellationToken token)
|
||||||
|
{
|
||||||
|
var channel = Channel.CreateBounded<int>(10);
|
||||||
|
|
||||||
|
Task.Run(async () =>
|
||||||
|
{
|
||||||
|
_tcsService.StartedMethod.SetResult(null);
|
||||||
|
await token.WaitForCancellationAsync();
|
||||||
|
channel.Writer.TryComplete();
|
||||||
|
_tcsService.EndMethod.SetResult(null);
|
||||||
|
});
|
||||||
|
|
||||||
|
return channel.Reader;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
|
||||||
|
{
|
||||||
|
var channel = Channel.CreateBounded<int>(10);
|
||||||
|
|
||||||
|
Task.Run(async () =>
|
||||||
|
{
|
||||||
|
_tcsService.StartedMethod.SetResult(null);
|
||||||
|
await token.WaitForCancellationAsync();
|
||||||
|
channel.Writer.TryComplete();
|
||||||
|
_tcsService.EndMethod.SetResult(null);
|
||||||
|
});
|
||||||
|
|
||||||
|
return channel.Reader;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
|
||||||
|
{
|
||||||
|
var channel = Channel.CreateBounded<int>(10);
|
||||||
|
|
||||||
|
Task.Run(async () =>
|
||||||
|
{
|
||||||
|
_tcsService.StartedMethod.SetResult(null);
|
||||||
|
await token.WaitForCancellationAsync();
|
||||||
|
channel.Writer.TryComplete();
|
||||||
|
_tcsService.EndMethod.SetResult(null);
|
||||||
|
});
|
||||||
|
|
||||||
|
return channel.Reader;
|
||||||
|
}
|
||||||
|
|
||||||
public int SimpleMethod()
|
public int SimpleMethod()
|
||||||
{
|
{
|
||||||
return 21;
|
return 21;
|
||||||
|
|
|
||||||
|
|
@ -2381,6 +2381,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Theory]
|
||||||
|
[InlineData(nameof(LongRunningHub.CancelableStream))]
|
||||||
|
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
|
||||||
|
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
|
||||||
|
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
|
||||||
|
{
|
||||||
|
builder.AddSingleton(tcsService);
|
||||||
|
});
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||||
|
|
||||||
|
var streamInvocationId = await client.SendStreamInvocationAsync(methodName, args).OrTimeout();
|
||||||
|
// Wait for the stream method to start
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
// Cancel the stream which should trigger the CancellationToken in the hub method
|
||||||
|
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();
|
||||||
|
|
||||||
|
var result = await client.ReadAsync().OrTimeout();
|
||||||
|
|
||||||
|
var simpleCompletion = Assert.IsType<CompletionMessage>(result);
|
||||||
|
Assert.Null(simpleCompletion.Result);
|
||||||
|
|
||||||
|
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
|
||||||
|
await tcsService.EndMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
// Shut down
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted()
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
|
||||||
|
{
|
||||||
|
builder.AddSingleton(tcsService);
|
||||||
|
});
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||||
|
|
||||||
|
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
|
||||||
|
// Wait for the stream method to start
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
// Shut down the client which should trigger the CancellationToken in the hub method
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
|
||||||
|
await tcsService.EndMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
|
||||||
|
{
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||||
|
|
||||||
|
var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).OrTimeout();
|
||||||
|
|
||||||
|
var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
|
||||||
|
|
||||||
|
Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
|
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
|
||||||
{
|
{
|
||||||
public int ReleaseCount;
|
public int ReleaseCount;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue