Fixed streaming hub method with ValueTask (#1572)

This commit is contained in:
James Newton-King 2018-03-10 12:31:38 +13:00 committed by GitHub
parent d941a4be09
commit d6178f2482
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 23 deletions

View File

@ -116,8 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public override IReadOnlyList<Type> GetParameterTypes(string methodName)
{
HubMethodDescriptor descriptor;
if (!_methods.TryGetValue(methodName, out descriptor))
if (!_methods.TryGetValue(methodName, out var descriptor))
{
return Type.EmptyTypes;
}
@ -165,7 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return;
}
if (!await ValidateInvocationMode(methodExecutor.MethodReturnType, isStreamedInvocation, hubMethodInvocationMessage, connection))
if (!await ValidateInvocationMode(methodExecutor, isStreamedInvocation, hubMethodInvocationMessage, connection))
{
return;
}
@ -188,7 +187,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
if (isStreamedInvocation)
{
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator))
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, out var enumerator))
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
@ -341,10 +340,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return authorizationResult.Succeeded;
}
private async Task<bool> ValidateInvocationMode(Type resultType, bool isStreamedInvocation,
private async Task<bool> ValidateInvocationMode(ObjectMethodExecutor methodExecutor, bool isStreamedInvocation,
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
{
var isStreamedResult = IsStreamed(resultType);
var isStreamedResult = IsStreamed(methodExecutor);
if (isStreamedResult && !isStreamedInvocation)
{
// Non-null/empty InvocationId? Blocking
@ -370,11 +369,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}
private static bool IsStreamed(Type resultType)
private static bool IsStreamed(ObjectMethodExecutor methodExecutor)
{
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
resultType = UnwrapTask(resultType);
var resultType = (methodExecutor.IsMethodAsync)
? methodExecutor.AsyncResultType
: methodExecutor.MethodReturnType;
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
@ -392,13 +393,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return false;
}
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator<object> enumerator)
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, out IAsyncEnumerator<object> enumerator)
{
if (result != null)
{
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
resultType = UnwrapTask(resultType);
var resultType = (methodExecutor.IsMethodAsync)
? methodExecutor.AsyncResultType
: methodExecutor.MethodReturnType;
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
@ -426,16 +429,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
private static Type UnwrapTask(Type type)
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>))
{
return type.GetGenericArguments()[0];
}
return type;
}
private static bool IsIObservable(Type iface)
{
return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>);

View File

@ -519,6 +519,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils
return CounterObservable(count);
}
public async ValueTask<IObservable<string>> CounterObservableValueTaskAsync(int count)
{
await Task.Yield();
return CounterObservable(count);
}
public ChannelReader<string> CounterChannel(int count)
{
var channel = Channel.CreateUnbounded<string>();
@ -541,6 +547,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils
return CounterChannel(count);
}
public async ValueTask<ChannelReader<string>> CounterChannelValueTaskAsync(int count)
{
await Task.Yield();
return CounterChannel(count);
}
public ChannelReader<string> BlockingStream()
{
return Channel.CreateUnbounded<string>().Reader;

View File

@ -1453,7 +1453,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
get
{
foreach (var method in new[] { nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterObservable), nameof(StreamingHub.CounterObservableAsync) })
foreach (var method in new[]
{
nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync),
nameof(StreamingHub.CounterObservable), nameof(StreamingHub.CounterObservableAsync), nameof(StreamingHub.CounterObservableValueTaskAsync)
})
{
foreach (var protocol in new IHubProtocol[] { new JsonHubProtocol(), new MessagePackHubProtocol() })
{