diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index eb2f6376c0..b8a6211863 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -116,8 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal public override IReadOnlyList GetParameterTypes(string methodName) { - HubMethodDescriptor descriptor; - if (!_methods.TryGetValue(methodName, out descriptor)) + if (!_methods.TryGetValue(methodName, out var descriptor)) { return Type.EmptyTypes; } @@ -165,7 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal return; } - if (!await ValidateInvocationMode(methodExecutor.MethodReturnType, isStreamedInvocation, hubMethodInvocationMessage, connection)) + if (!await ValidateInvocationMode(methodExecutor, isStreamedInvocation, hubMethodInvocationMessage, connection)) { return; } @@ -188,7 +187,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal if (isStreamedInvocation) { - if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator)) + if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, out var enumerator)) { Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); @@ -341,10 +340,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal return authorizationResult.Succeeded; } - private async Task ValidateInvocationMode(Type resultType, bool isStreamedInvocation, + private async Task ValidateInvocationMode(ObjectMethodExecutor methodExecutor, bool isStreamedInvocation, HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection) { - var isStreamedResult = IsStreamed(resultType); + var isStreamedResult = IsStreamed(methodExecutor); if (isStreamedResult && !isStreamedInvocation) { // Non-null/empty InvocationId? Blocking @@ -370,11 +369,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; } - private static bool IsStreamed(Type resultType) + private static bool IsStreamed(ObjectMethodExecutor methodExecutor) { - // TODO: cache reflection for performance, on HubMethodDescriptor maybe? - resultType = UnwrapTask(resultType); + var resultType = (methodExecutor.IsMethodAsync) + ? methodExecutor.AsyncResultType + : methodExecutor.MethodReturnType; + // TODO: cache reflection for performance, on HubMethodDescriptor maybe? var observableInterface = IsIObservable(resultType) ? resultType : resultType.GetInterfaces().FirstOrDefault(IsIObservable); @@ -392,13 +393,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal return false; } - private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator enumerator) + private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, out IAsyncEnumerator enumerator) { if (result != null) { - // TODO: cache reflection for performance, on HubMethodDescriptor maybe? - resultType = UnwrapTask(resultType); + var resultType = (methodExecutor.IsMethodAsync) + ? methodExecutor.AsyncResultType + : methodExecutor.MethodReturnType; + // TODO: cache reflection for performance, on HubMethodDescriptor maybe? var observableInterface = IsIObservable(resultType) ? resultType : resultType.GetInterfaces().FirstOrDefault(IsIObservable); @@ -426,16 +429,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal } } - private static Type UnwrapTask(Type type) - { - if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>)) - { - return type.GetGenericArguments()[0]; - } - - return type; - } - private static bool IsIObservable(Type iface) { return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs index c7dc5d0539..981497ee29 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs @@ -519,6 +519,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils return CounterObservable(count); } + public async ValueTask> CounterObservableValueTaskAsync(int count) + { + await Task.Yield(); + return CounterObservable(count); + } + public ChannelReader CounterChannel(int count) { var channel = Channel.CreateUnbounded(); @@ -541,6 +547,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils return CounterChannel(count); } + public async ValueTask> CounterChannelValueTaskAsync(int count) + { + await Task.Yield(); + return CounterChannel(count); + } + public ChannelReader BlockingStream() { return Channel.CreateUnbounded().Reader; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index e1b1bca8f5..2ef56dbc0d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -1453,7 +1453,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { get { - foreach (var method in new[] { nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterObservable), nameof(StreamingHub.CounterObservableAsync) }) + foreach (var method in new[] + { + nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync), + nameof(StreamingHub.CounterObservable), nameof(StreamingHub.CounterObservableAsync), nameof(StreamingHub.CounterObservableValueTaskAsync) + }) { foreach (var protocol in new IHubProtocol[] { new JsonHubProtocol(), new MessagePackHubProtocol() }) {