Allow CancellationToken in streaming hub methods (#2818)
This commit is contained in:
parent
4b378692a4
commit
6ba5e87b45
|
|
@ -188,11 +188,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
{
|
||||
InitializeHub(hub, connection);
|
||||
|
||||
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
|
||||
CancellationTokenSource cts = null;
|
||||
var arguments = hubMethodInvocationMessage.Arguments;
|
||||
if (descriptor.HasSyntheticArguments)
|
||||
{
|
||||
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
|
||||
arguments = new object[descriptor.OriginalParameterTypes.Count];
|
||||
|
||||
var hubInvocationArgumentPointer = 0;
|
||||
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
|
||||
{
|
||||
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
|
||||
hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer])
|
||||
{
|
||||
// The types match so it isn't a synthetic argument, just copy it into the arguments array
|
||||
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
|
||||
hubInvocationArgumentPointer++;
|
||||
}
|
||||
else
|
||||
{
|
||||
// This is the only synthetic argument type we currently support
|
||||
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
|
||||
{
|
||||
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
|
||||
arguments[parameterPointer] = cts.Token;
|
||||
}
|
||||
else
|
||||
{
|
||||
// This should never happen
|
||||
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'.");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
|
||||
|
||||
if (isStreamedInvocation)
|
||||
{
|
||||
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
|
||||
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
|
||||
{
|
||||
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
|
||||
|
||||
|
|
@ -204,7 +238,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
disposeScope = false;
|
||||
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
||||
// Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run
|
||||
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
|
||||
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts);
|
||||
}
|
||||
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
|
||||
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
|
||||
|
|
@ -375,29 +409,24 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return true;
|
||||
}
|
||||
|
||||
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
|
||||
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
|
||||
{
|
||||
if (result != null)
|
||||
{
|
||||
if (hubMethodDescriptor.IsChannel)
|
||||
{
|
||||
streamCts = CreateCancellation();
|
||||
if (streamCts == null)
|
||||
{
|
||||
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
|
||||
}
|
||||
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
|
||||
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
streamCts = null;
|
||||
enumerator = null;
|
||||
return false;
|
||||
|
||||
CancellationTokenSource CreateCancellation()
|
||||
{
|
||||
var userCts = new CancellationTokenSource();
|
||||
connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);
|
||||
|
||||
return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
|
||||
}
|
||||
}
|
||||
|
||||
private void DiscoverHubMethods()
|
||||
|
|
|
|||
|
|
@ -22,8 +22,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
||||
{
|
||||
MethodExecutor = methodExecutor;
|
||||
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
||||
Policies = policies.ToArray();
|
||||
|
||||
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
|
||||
? MethodExecutor.AsyncResultType
|
||||
|
|
@ -34,6 +32,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
IsChannel = true;
|
||||
StreamReturnType = channelItemType;
|
||||
}
|
||||
|
||||
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
|
||||
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
|
||||
{
|
||||
// Only streams can take CancellationTokens currently
|
||||
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
|
||||
{
|
||||
HasSyntheticArguments = true;
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}).Select(p => p.ParameterType).ToArray();
|
||||
|
||||
if (HasSyntheticArguments)
|
||||
{
|
||||
OriginalParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
||||
}
|
||||
|
||||
Policies = policies.ToArray();
|
||||
}
|
||||
|
||||
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
|
||||
|
|
@ -42,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
public IReadOnlyList<Type> ParameterTypes { get; }
|
||||
|
||||
public IReadOnlyList<Type> OriginalParameterTypes { get; }
|
||||
|
||||
public Type NonAsyncReturnType { get; }
|
||||
|
||||
public bool IsChannel { get; }
|
||||
|
|
@ -52,6 +71,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
public IList<IAuthorizeData> Policies { get; }
|
||||
|
||||
public bool HasSyntheticArguments { get; private set; }
|
||||
|
||||
private static bool IsChannelType(Type type, out Type payloadType)
|
||||
{
|
||||
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
public static class CancellationTokenExtensions
|
||||
{
|
||||
public static Task WaitForCancellationAsync(this CancellationToken token)
|
||||
{
|
||||
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
token.Register((t) =>
|
||||
{
|
||||
((TaskCompletionSource<object>)t).SetResult(null);
|
||||
}, tcs);
|
||||
return tcs.Task;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Channels;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
|
|
@ -165,6 +166,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
|
||||
}
|
||||
|
||||
public void InvalidArgument(CancellationToken token)
|
||||
{
|
||||
}
|
||||
|
||||
private class SelfRef
|
||||
{
|
||||
public SelfRef()
|
||||
|
|
@ -547,6 +552,51 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return Channel.CreateUnbounded<string>().Reader;
|
||||
}
|
||||
|
||||
public ChannelReader<int> CancelableStream(CancellationToken token)
|
||||
{
|
||||
var channel = Channel.CreateBounded<int>(10);
|
||||
|
||||
Task.Run(async () =>
|
||||
{
|
||||
_tcsService.StartedMethod.SetResult(null);
|
||||
await token.WaitForCancellationAsync();
|
||||
channel.Writer.TryComplete();
|
||||
_tcsService.EndMethod.SetResult(null);
|
||||
});
|
||||
|
||||
return channel.Reader;
|
||||
}
|
||||
|
||||
public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
|
||||
{
|
||||
var channel = Channel.CreateBounded<int>(10);
|
||||
|
||||
Task.Run(async () =>
|
||||
{
|
||||
_tcsService.StartedMethod.SetResult(null);
|
||||
await token.WaitForCancellationAsync();
|
||||
channel.Writer.TryComplete();
|
||||
_tcsService.EndMethod.SetResult(null);
|
||||
});
|
||||
|
||||
return channel.Reader;
|
||||
}
|
||||
|
||||
public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
|
||||
{
|
||||
var channel = Channel.CreateBounded<int>(10);
|
||||
|
||||
Task.Run(async () =>
|
||||
{
|
||||
_tcsService.StartedMethod.SetResult(null);
|
||||
await token.WaitForCancellationAsync();
|
||||
channel.Writer.TryComplete();
|
||||
_tcsService.EndMethod.SetResult(null);
|
||||
});
|
||||
|
||||
return channel.Reader;
|
||||
}
|
||||
|
||||
public int SimpleMethod()
|
||||
{
|
||||
return 21;
|
||||
|
|
|
|||
|
|
@ -2381,6 +2381,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStream))]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
|
||||
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
|
||||
{
|
||||
var tcsService = new TcsService();
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
|
||||
{
|
||||
builder.AddSingleton(tcsService);
|
||||
});
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||
|
||||
var streamInvocationId = await client.SendStreamInvocationAsync(methodName, args).OrTimeout();
|
||||
// Wait for the stream method to start
|
||||
await tcsService.StartedMethod.Task.OrTimeout();
|
||||
|
||||
// Cancel the stream which should trigger the CancellationToken in the hub method
|
||||
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();
|
||||
|
||||
var result = await client.ReadAsync().OrTimeout();
|
||||
|
||||
var simpleCompletion = Assert.IsType<CompletionMessage>(result);
|
||||
Assert.Null(simpleCompletion.Result);
|
||||
|
||||
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
|
||||
await tcsService.EndMethod.Task.OrTimeout();
|
||||
|
||||
// Shut down
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted()
|
||||
{
|
||||
var tcsService = new TcsService();
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
|
||||
{
|
||||
builder.AddSingleton(tcsService);
|
||||
});
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||
|
||||
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
|
||||
// Wait for the stream method to start
|
||||
await tcsService.StartedMethod.Task.OrTimeout();
|
||||
|
||||
// Shut down the client which should trigger the CancellationToken in the hub method
|
||||
client.Dispose();
|
||||
|
||||
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
|
||||
await tcsService.EndMethod.Task.OrTimeout();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
|
||||
{
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||
|
||||
var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).OrTimeout();
|
||||
|
||||
var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
|
||||
|
||||
Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
|
||||
{
|
||||
public int ReleaseCount;
|
||||
|
|
|
|||
Loading…
Reference in New Issue