From 1c44e8febfd57a1038715d6c74f030cebdcd14f2 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Wed, 7 Mar 2018 20:07:06 +1300 Subject: [PATCH] Fix streaming hub methods combined with async (#1544) --- .../Internal/DefaultHubDispatcher.cs | 16 ++++++++++++++++ .../HubEndpointTestUtils/Hubs.cs | 12 ++++++++++++ .../HubEndpointTests.cs | 2 +- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index 2c82b3a75e..eb2f6376c0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -372,6 +372,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal private static bool IsStreamed(Type resultType) { + // TODO: cache reflection for performance, on HubMethodDescriptor maybe? + resultType = UnwrapTask(resultType); + var observableInterface = IsIObservable(resultType) ? resultType : resultType.GetInterfaces().FirstOrDefault(IsIObservable); @@ -393,6 +396,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal { if (result != null) { + // TODO: cache reflection for performance, on HubMethodDescriptor maybe? + resultType = UnwrapTask(resultType); + var observableInterface = IsIObservable(resultType) ? resultType : resultType.GetInterfaces().FirstOrDefault(IsIObservable); @@ -420,6 +426,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal } } + private static Type UnwrapTask(Type type) + { + if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(Task<>)) + { + return type.GetGenericArguments()[0]; + } + + return type; + } + private static bool IsIObservable(Type iface) { return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs index 4dd0378bc3..c7dc5d0539 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs @@ -513,6 +513,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils return new CountingObservable(count); } + public async Task> CounterObservableAsync(int count) + { + await Task.Yield(); + return CounterObservable(count); + } + public ChannelReader CounterChannel(int count) { var channel = Channel.CreateUnbounded(); @@ -529,6 +535,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests.HubEndpointTestUtils return channel.Reader; } + public async Task> CounterChannelAsync(int count) + { + await Task.Yield(); + return CounterChannel(count); + } + public ChannelReader BlockingStream() { return Channel.CreateUnbounded().Reader; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 493670b0a9..e1b1bca8f5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -1453,7 +1453,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { get { - foreach (var method in new[] { nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterObservable) }) + foreach (var method in new[] { nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterObservable), nameof(StreamingHub.CounterObservableAsync) }) { foreach (var protocol in new IHubProtocol[] { new JsonHubProtocol(), new MessagePackHubProtocol() }) {