diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 8ed6e18184..9a013291a2 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -357,6 +357,52 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [InlineData("json")] + [InlineData("messagepack")] + public async Task CanStreamToHubWithIAsyncEnumerableMethodArg(string protocolName) + { + var protocol = HubProtocols[protocolName]; + using (StartServer(out var server)) + { + var connection = CreateHubConnection(server.Url, "/default", HttpTransportType.WebSockets, protocol, LoggerFactory); + try + { + async IAsyncEnumerable ClientStreamData(int value) + { + for (var i = 0; i < value; i++) + { + yield return i; + await Task.Delay(10); + } + } + + var streamTo = 5; + var stream = ClientStreamData(streamTo); + + await connection.StartAsync().OrTimeout(); + var expectedValue = 0; + var asyncEnumerable = connection.StreamAsync("StreamIAsyncConsumer", stream); + await foreach (var streamValue in asyncEnumerable) + { + Assert.Equal(expectedValue, streamValue); + expectedValue++; + } + + Assert.Equal(streamTo, expectedValue); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] [LogLevel(LogLevel.Trace)] diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs index 14f102da8f..6843a4cb64 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/Hubs.cs @@ -45,6 +45,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public ChannelReader StreamEchoInt(ChannelReader source) => TestHubMethodsImpl.StreamEchoInt(source); + public IAsyncEnumerable StreamIAsyncConsumer(IAsyncEnumerable source) => source; + public string GetUserIdentifier() { return Context.UserIdentifier; @@ -125,6 +127,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public ChannelReader StreamEcho(ChannelReader source) => TestHubMethodsImpl.StreamEcho(source); public ChannelReader StreamEchoInt(ChannelReader source) => TestHubMethodsImpl.StreamEchoInt(source); + + public IAsyncEnumerable StreamIAsyncConsumer(IAsyncEnumerable source) => source; } public class TestHubT : Hub @@ -157,6 +161,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public ChannelReader StreamEcho(ChannelReader source) => TestHubMethodsImpl.StreamEcho(source); public ChannelReader StreamEchoInt(ChannelReader source) => TestHubMethodsImpl.StreamEchoInt(source); + + public IAsyncEnumerable StreamIAsyncConsumer(IAsyncEnumerable source) => source; } internal static class TestHubMethodsImpl diff --git a/src/SignalR/common/Shared/ReflectionHelper.cs b/src/SignalR/common/Shared/ReflectionHelper.cs index 890b9e9adc..9a7b3371a6 100644 --- a/src/SignalR/common/Shared/ReflectionHelper.cs +++ b/src/SignalR/common/Shared/ReflectionHelper.cs @@ -38,6 +38,11 @@ namespace Microsoft.AspNetCore.SignalR #if NETCOREAPP3_0 public static bool IsIAsyncEnumerable(Type type) { + if (type.IsGenericType) + { + return type.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>); + } + return type.GetInterfaces().Any(t => { if (t.IsGenericType) diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index c08faa6456..ee39dcad9d 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -277,7 +277,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal { Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds[streamPointer]); var itemType = descriptor.StreamingParameters[streamPointer]; - arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], itemType); + arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], + itemType, descriptor.OriginalParameterTypes[parameterPointer]); + streamPointer++; } else diff --git a/src/SignalR/server/Core/src/Properties/AssemblyInfo.cs b/src/SignalR/server/Core/src/Properties/AssemblyInfo.cs index ce43e9c728..4d870535a2 100644 --- a/src/SignalR/server/Core/src/Properties/AssemblyInfo.cs +++ b/src/SignalR/server/Core/src/Properties/AssemblyInfo.cs @@ -5,4 +5,4 @@ using System.Runtime.CompilerServices; [assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests.Utils, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] [assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Microbenchmarks, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] -[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] \ No newline at end of file +[assembly: InternalsVisibleTo("Microsoft.AspNetCore.SignalR.Tests, PublicKey=0024000004800000940000000602000000240000525341310004000001000100f33a29044fa9d740c9b3213a93e57c84b472c84e0b8a0e1ae48e67a9f8f6de9d5f7f3d52ac23e48ac51801f1dc950abe901da34d2a9e3baadb141a17c77ef3c565dd5ee5054b91cf63bb3c6ab83f72ab3aafe93d0fc3c2348b764fafb0b1c0733de51459aeab46580384bf9d74c4e28164b7cde247f891ba07891c9d872ad2bb")] diff --git a/src/SignalR/server/Core/src/StreamTracker.cs b/src/SignalR/server/Core/src/StreamTracker.cs index 1445771046..1c5efc2871 100644 --- a/src/SignalR/server/Core/src/StreamTracker.cs +++ b/src/SignalR/server/Core/src/StreamTracker.cs @@ -21,11 +21,11 @@ namespace Microsoft.AspNetCore.SignalR /// /// Creates a new stream and returns the ChannelReader for it as an object. /// - public object AddStream(string streamId, Type itemType) + public object AddStream(string streamId, Type itemType, Type targetType) { var newConverter = (IStreamConverter)_buildConverterMethod.MakeGenericMethod(itemType).Invoke(null, Array.Empty()); _lookup[streamId] = newConverter; - return newConverter.GetReaderAsObject(); + return newConverter.GetReaderAsObject(targetType); } private bool TryGetConverter(string streamId, out IStreamConverter converter) @@ -79,7 +79,7 @@ namespace Microsoft.AspNetCore.SignalR private interface IStreamConverter { Type GetItemType(); - object GetReaderAsObject(); + object GetReaderAsObject(Type type); Task WriteToStream(object item); void TryComplete(Exception ex); } @@ -100,9 +100,16 @@ namespace Microsoft.AspNetCore.SignalR return typeof(T); } - public object GetReaderAsObject() + public object GetReaderAsObject(Type type) { - return _channel.Reader; + if (ReflectionHelper.IsIAsyncEnumerable(type)) + { + return _channel.Reader.ReadAllAsync(); + } + else + { + return _channel.Reader; + } } public Task WriteToStream(object o) diff --git a/src/SignalR/server/SignalR/test/Internal/ReflectionHelperTests.cs b/src/SignalR/server/SignalR/test/Internal/ReflectionHelperTests.cs new file mode 100644 index 0000000000..3602555539 --- /dev/null +++ b/src/SignalR/server/SignalR/test/Internal/ReflectionHelperTests.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests.Internal +{ + public class ReflectionHelperTests + { + [Theory] + [MemberData(nameof(TypesToCheck))] + public void IsIAsyncEnumerableTests(Type type, bool expectedOutcome) + { + Assert.Equal(expectedOutcome, ReflectionHelper.IsIAsyncEnumerable(type)); + } + + public static IEnumerable TypesToCheck() + { + yield return new object[] + { + typeof(IAsyncEnumerable), + true + }; + + yield return new object[] + { + typeof(ChannelReader), + false + }; + + async IAsyncEnumerable Stream() + { + await Task.Delay(10); + yield return 1; + } + + object streamAsObject = Stream(); + yield return new object[] + { + streamAsObject.GetType(), + true + }; + + yield return new object[] + { + typeof(string), + false + }; + + yield return new object[] + { + typeof(CustomAsyncEnumerable), + true + }; + } + + private class CustomAsyncEnumerable : IAsyncEnumerable + { + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + throw new NotImplementedException(); + } + } + } +}