diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index f351e66970..378f7bc7e0 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -822,7 +822,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await connection.StartAsync().OrTimeout(); var channel = await connection.StreamAsChannelAsync("StreamBroken").OrTimeout(); var ex = await Assert.ThrowsAsync(() => channel.ReadAndCollectAllAsync()).OrTimeout(); - Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message); + Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<> or IAsyncEnumerable<>.", ex.Message); } catch (Exception ex) { diff --git a/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs b/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs index ee5401b7c1..f0d1dc4baa 100644 --- a/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs +++ b/src/SignalR/samples/SignalRSamples/Hubs/Streaming.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Reactive.Linq; using System.Threading.Channels; using System.Threading.Tasks; @@ -11,6 +12,15 @@ namespace SignalRSamples.Hubs { public class Streaming : Hub { + public async IAsyncEnumerable AsyncEnumerableCounter(int count, int delay) + { + for (var i = 0; i < count; i++) + { + yield return i; + await Task.Delay(delay); + } + } + public ChannelReader ObservableCounter(int count, int delay) { var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay)) diff --git a/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj b/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj index 25201e0f9a..6f0b3379d8 100644 --- a/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj +++ b/src/SignalR/samples/SignalRSamples/SignalRSamples.csproj @@ -1,7 +1,8 @@ - + netcoreapp3.0 + 8.0 diff --git a/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html b/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html index 44becb3cb2..5cc0b3fd67 100644 --- a/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html +++ b/src/SignalR/samples/SignalRSamples/wwwroot/streaming.html @@ -17,6 +17,7 @@
+
@@ -32,7 +33,7 @@ let resultsList = document.getElementById('resultsList'); let channelButton = document.getElementById('channelButton'); let observableButton = document.getElementById('observableButton'); - let clearButton = document.getElementById('clearButton'); + let asyncEnumerableButton = document.getElementById('asyncEnumerableButton'); let connectButton = document.getElementById('connectButton'); let disconnectButton = document.getElementById('disconnectButton'); @@ -61,6 +62,7 @@ connection.onclose(function () { channelButton.disabled = true; observableButton.disabled = true; + asyncEnumerableButton.disabled = true; connectButton.disabled = false; disconnectButton.disabled = true; @@ -71,12 +73,17 @@ .then(function () { channelButton.disabled = false; observableButton.disabled = false; + asyncEnumerableButton.disabled = false; connectButton.disabled = true; disconnectButton.disabled = false; addLine('resultsList', 'connected', 'green'); }); }); + click('asyncEnumerableButton', function () { + run('AsyncEnumerableCounter'); + }) + click('observableButton', function () { run('ObservableCounter'); }); diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs new file mode 100644 index 0000000000..c0bef6ad78 --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/AsyncEnumerableAdapters.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Threading; +using System.Threading.Channels; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + // True-internal because this is a weird and tricky class to use :) + internal static class AsyncEnumerableAdapters + { + public static IAsyncEnumerable MakeCancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken = default) + { + return new CancelableAsyncEnumerable(asyncEnumerable, cancellationToken); + } + + public static IAsyncEnumerable MakeCancelableAsyncEnumerableFromChannel(ChannelReader channel, CancellationToken cancellationToken = default) + { + return MakeCancelableAsyncEnumerable(channel.ReadAllAsync(), cancellationToken); + } + + /// Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object. + private class CancelableAsyncEnumerable : IAsyncEnumerable + { + private readonly IAsyncEnumerable _asyncEnumerable; + private readonly CancellationToken _cancellationToken; + + public CancelableAsyncEnumerable(IAsyncEnumerable asyncEnumerable, CancellationToken cancellationToken) + { + _asyncEnumerable = asyncEnumerable; + _cancellationToken = cancellationToken; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + // Assume that this will be iterated through with await foreach which always passes a default token. + // Instead use the token from the ctor. + Debug.Assert(cancellationToken == default); + + var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken); + return enumeratorOfT as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumeratorOfT); + } + + private class BoxedAsyncEnumerator : IAsyncEnumerator + { + private IAsyncEnumerator _asyncEnumerator; + + public BoxedAsyncEnumerator(IAsyncEnumerator asyncEnumerator) + { + _asyncEnumerator = asyncEnumerator; + } + + public object Current => _asyncEnumerator.Current; + + public ValueTask MoveNextAsync() + { + return _asyncEnumerator.MoveNextAsync(); + } + + public ValueTask DisposeAsync() + { + return _asyncEnumerator.DisposeAsync(); + } + } + } + } +} diff --git a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs b/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs deleted file mode 100644 index 0e50b5c433..0000000000 --- a/src/SignalR/server/Core/src/Internal/AsyncEnumeratorAdapters.cs +++ /dev/null @@ -1,84 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading; -using System.Threading.Channels; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.SignalR.Internal -{ - // True-internal because this is a weird and tricky class to use :) - internal static class AsyncEnumeratorAdapters - { - public static IAsyncEnumerator GetAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) - { - // Nothing to dispose when we finish enumerating in this case. - return new AsyncEnumerator(channel, cancellationToken, disposable: null); - } - - /// Provides an async enumerator for the data in a channel. - internal class AsyncEnumerator : IAsyncEnumerator, IDisposable - { - /// The channel being enumerated. - private readonly ChannelReader _channel; - /// Cancellation token used to cancel the enumeration. - private readonly CancellationToken _cancellationToken; - /// The current element of the enumeration. - private object _current; - - private readonly IDisposable _disposable; - - internal AsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken, IDisposable disposable) - { - _channel = channel; - _cancellationToken = cancellationToken; - _disposable = disposable; - } - - public object Current => _current; - - public Task MoveNextAsync() - { - var result = _channel.ReadAsync(_cancellationToken); - - if (result.IsCompletedSuccessfully) - { - _current = result.Result; - return Task.FromResult(true); - } - - return result.AsTask().ContinueWith((t, s) => - { - var thisRef = (AsyncEnumerator)s; - if (t.IsFaulted && t.Exception.InnerException is ChannelClosedException cce && cce.InnerException == null) - { - return false; - } - thisRef._current = t.GetAwaiter().GetResult(); - return true; - }, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default); - } - - public void Dispose() - { - _disposable?.Dispose(); - } - } - } - - /// Represents an enumerator accessed asynchronously. - /// Specifies the type of the data enumerated. - internal interface IAsyncEnumerator - { - /// Asynchronously move the enumerator to the next element. - /// - /// A task that returns true if the enumerator was successfully advanced to the next item, - /// or false if no more data was available in the collection. - /// - Task MoveNextAsync(); - - /// Gets the current element being enumerated. - T Current { get; } - } -} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 01f470b0c7..c08faa6456 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -293,16 +293,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal { var result = await ExecuteHubMethod(methodExecutor, hub, arguments); - if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts)) + if (result == null) { Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, - $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>."); + $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>."); return; } + cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts); + var enumerable = descriptor.FromReturnedStream(result, cts.Token); + Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage); + _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage); } else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -393,17 +397,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal return scope.DisposeAsync(); } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator, IServiceScope scope, + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable enumerable, IServiceScope scope, IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) { string error = null; try { - while (await enumerator.MoveNextAsync()) + await foreach (var streamItem in enumerable) { // Send the stream item - await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current)); + await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem)); } } catch (ChannelClosedException ex) @@ -422,8 +426,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal } finally { - (enumerator as IDisposable)?.Dispose(); - await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); // Dispose the linked CTS for the stream. @@ -502,10 +504,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal return authorizationResult.Succeeded; } - private async Task ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation, + private async Task ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse, HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection) { - if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation) + if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse) { // Non-null/empty InvocationId? Blocking if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -518,7 +520,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal return false; } - if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation) + if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse) { Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage); await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, @@ -530,26 +532,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; } - private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator enumerator, ref CancellationTokenSource streamCts) - { - if (result != null) - { - if (hubMethodDescriptor.IsChannel) - { - if (streamCts == null) - { - streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); - } - connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); - enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token); - return true; - } - } - - enumerator = null; - return false; - } - private void DiscoverHubMethods() { var hubType = typeof(THub); diff --git a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs index dec2e67aaf..205c1ced72 100644 --- a/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/server/Core/src/Internal/HubMethodDescriptor.cs @@ -15,9 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal { internal class HubMethodDescriptor { - private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters) + private static readonly MethodInfo MakeCancelableAsyncEnumerableMethod = typeof(AsyncEnumerableAdapters) .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod); + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod); + + private static readonly MethodInfo MakeCancelableAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters) + .GetRuntimeMethods() + .Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel)) && m.IsGenericMethod); + + private readonly MethodInfo _makeCancelableEnumerableMethodInfo; + private Func> _makeCancelableEnumerable; public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { @@ -27,17 +34,35 @@ namespace Microsoft.AspNetCore.SignalR.Internal ? MethodExecutor.AsyncResultType : MethodExecutor.MethodReturnType; - if (IsChannelType(NonAsyncReturnType, out var channelItemType)) + foreach (var returnType in NonAsyncReturnType.GetInterfaces().Concat(NonAsyncReturnType.AllBaseTypes())) { - IsChannel = true; - StreamReturnType = channelItemType; + if (!returnType.IsGenericType) + { + continue; + } + + var openReturnType = returnType.GetGenericTypeDefinition(); + + if (openReturnType == typeof(IAsyncEnumerable<>)) + { + StreamReturnType = returnType.GetGenericArguments()[0]; + _makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod; + break; + } + + if (openReturnType == typeof(ChannelReader<>)) + { + StreamReturnType = returnType.GetGenericArguments()[0]; + _makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableFromChannelMethod; + break; + } } // Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers ParameterTypes = methodExecutor.MethodParameters.Where(p => { // Only streams can take CancellationTokens currently - if (IsStreamable && p.ParameterType == typeof(CancellationToken)) + if (IsStreamResponse && p.ParameterType == typeof(CancellationToken)) { HasSyntheticArguments = true; return false; @@ -66,8 +91,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal public List StreamingParameters { get; private set; } - private Func> _convertToEnumerator; - public ObjectMethodExecutor MethodExecutor { get; } public IReadOnlyList ParameterTypes { get; } @@ -76,9 +99,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal public Type NonAsyncReturnType { get; } - public bool IsChannel { get; } - - public bool IsStreamable => IsChannel; + public bool IsStreamResponse => StreamReturnType != null; public Type StreamReturnType { get; } @@ -86,57 +107,39 @@ namespace Microsoft.AspNetCore.SignalR.Internal public bool HasSyntheticArguments { get; private set; } - private static bool IsChannelType(Type type, out Type payloadType) - { - var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>)); - if (channelType == null) - { - payloadType = null; - return false; - } - - payloadType = channelType.GetGenericArguments()[0]; - return true; - } - - public IAsyncEnumerator FromChannel(object channel, CancellationToken cancellationToken) + public IAsyncEnumerable FromReturnedStream(object stream, CancellationToken cancellationToken) { // there is the potential for compile to be called times but this has no harmful effect other than perf - if (_convertToEnumerator == null) + if (_makeCancelableEnumerable == null) { - _convertToEnumerator = CompileConvertToEnumerator(GetAsyncEnumeratorMethod, StreamReturnType); + _makeCancelableEnumerable = CompileConvertToEnumerable(_makeCancelableEnumerableMethodInfo, StreamReturnType); } - return _convertToEnumerator.Invoke(channel, cancellationToken); + return _makeCancelableEnumerable.Invoke(stream, cancellationToken); } - private static Func> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType) + private static Func> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType) { - // This will call one of two adapter methods to wrap the passed in streamable value - // and cancellation token to an IAsyncEnumerator - // ChannelReader - // AsyncEnumeratorAdapters.GetAsyncEnumerator(channelReader, cancellationToken); + // This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable: + // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable(asyncEnumerable, cancellationToken); + // - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel(channelReader, cancellationToken); - var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType); - - var methodParameters = genericMethodInfo.GetParameters(); - - // arg1 and arg2 are the parameter names on Func - // we reference these values and then use them to call adaptor method - var targetParameter = Expression.Parameter(typeof(object), "arg1"); - var parametersParameter = Expression.Parameter(typeof(CancellationToken), "arg2"); - - var parameters = new List + var parameters = new[] { - Expression.Convert(targetParameter, methodParameters[0].ParameterType), - parametersParameter + Expression.Parameter(typeof(object)), + Expression.Parameter(typeof(CancellationToken)), }; - var methodCall = Expression.Call(null, genericMethodInfo, parameters); + var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType); + var methodParameters = genericMethodInfo.GetParameters(); + var methodArguements = new Expression[] + { + Expression.Convert(parameters[0], methodParameters[0].ParameterType), + parameters[1], + }; - var castMethodCall = Expression.Convert(methodCall, typeof(IAsyncEnumerator)); - - var lambda = Expression.Lambda>>(castMethodCall, targetParameter, parametersParameter); + var methodCall = Expression.Call(null, genericMethodInfo, methodArguements); + var lambda = Expression.Lambda>>(methodCall, parameters); return lambda.Compile(); } } diff --git a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj index d26d4a30a9..8470423c2c 100644 --- a/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj +++ b/src/SignalR/server/Core/src/Microsoft.AspNetCore.SignalR.Core.csproj @@ -5,6 +5,7 @@ netcoreapp3.0 true Microsoft.AspNetCore.SignalR + 8.0 diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index 681c4759fd..0d5e64d1c6 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -591,6 +591,31 @@ namespace Microsoft.AspNetCore.SignalR.Tests return CounterChannel(count); } + public async IAsyncEnumerable CounterAsyncEnumerable(int count) + { + for (int i = 0; i < count; i++) + { + await Task.Yield(); + yield return i.ToString(); + } + } + + public async Task> CounterAsyncEnumerableAsync(int count) + { + await Task.Yield(); + return CounterAsyncEnumerable(count); + } + + public AsyncEnumerableImpl CounterAsyncEnumerableImpl(int count) + { + return new AsyncEnumerableImpl(CounterAsyncEnumerable(count)); + } + + public AsyncEnumerableImplChannelThrows AsyncEnumerableIsPreferedOverChannelReader(int count) + { + return new AsyncEnumerableImplChannelThrows(CounterChannel(count)); + } + public ChannelReader BlockingStream() { return Channel.CreateUnbounded().Reader; @@ -627,6 +652,99 @@ namespace Microsoft.AspNetCore.SignalR.Tests return output.Reader; } + + public class AsyncEnumerableImpl : IAsyncEnumerable + { + private readonly IAsyncEnumerable _inner; + + public AsyncEnumerableImpl(IAsyncEnumerable inner) + { + _inner = inner; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return _inner.GetAsyncEnumerator(cancellationToken); + } + } + + public class AsyncEnumerableImplChannelThrows : ChannelReader, IAsyncEnumerable + { + private ChannelReader _inner; + + public AsyncEnumerableImplChannelThrows(ChannelReader inner) + { + _inner = inner; + } + + public override bool TryRead(out T item) + { + // Not implemented to verify this is consumed as an IAsyncEnumerable instead of a ChannelReader. + throw new NotImplementedException(); + } + + public override ValueTask WaitToReadAsync(CancellationToken cancellationToken = default) + { + // Not implemented to verify this is consumed as an IAsyncEnumerable instead of a ChannelReader. + throw new NotImplementedException(); + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new ChannelAsyncEnumerator(_inner, cancellationToken); + } + + // Copied from AsyncEnumeratorAdapters + private class ChannelAsyncEnumerator : IAsyncEnumerator + { + /// The channel being enumerated. + private readonly ChannelReader _channel; + /// Cancellation token used to cancel the enumeration. + private readonly CancellationToken _cancellationToken; + /// The current element of the enumeration. + private T _current; + + public ChannelAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) + { + _channel = channel; + _cancellationToken = cancellationToken; + } + + public T Current => _current; + + public ValueTask MoveNextAsync() + { + var result = _channel.ReadAsync(_cancellationToken); + + if (result.IsCompletedSuccessfully) + { + _current = result.Result; + return new ValueTask(true); + } + + return new ValueTask(MoveNextAsyncAwaited(result)); + } + + private async Task MoveNextAsyncAwaited(ValueTask channelReadTask) + { + try + { + _current = await channelReadTask; + } + catch (ChannelClosedException ex) when (ex.InnerException == null) + { + return false; + } + + return true; + } + + public ValueTask DisposeAsync() + { + return default; + } + } + } } public class SimpleHub : Hub @@ -681,7 +799,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests return Channel.CreateUnbounded().Reader; } - public ChannelReader CancelableStream(CancellationToken token) + public ChannelReader CancelableStreamSingleParameter(CancellationToken token) { var channel = Channel.CreateBounded(10); @@ -696,7 +814,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests return channel.Reader; } - public ChannelReader CancelableStream2(int ignore, int ignore2, CancellationToken token) + public ChannelReader CancelableStreamMultiParameter(int ignore, int ignore2, CancellationToken token) { var channel = Channel.CreateBounded(10); @@ -711,7 +829,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests return channel.Reader; } - public ChannelReader CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2) + public ChannelReader CancelableStreamMiddleParameter(int ignore, CancellationToken token, int ignore2) { var channel = Channel.CreateBounded(10); @@ -726,16 +844,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests return channel.Reader; } + public async IAsyncEnumerable CancelableStreamGeneratedAsyncEnumerable(CancellationToken token) + { + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + _tcsService.EndMethod.SetResult(null); + yield break; + } + + public IAsyncEnumerable CancelableStreamCustomAsyncEnumerable() + { + return new CustomAsyncEnumerable(_tcsService); + } + public int SimpleMethod() { return 21; } + + private class CustomAsyncEnumerable : IAsyncEnumerable + { + private readonly TcsService _tcsService; + + public CustomAsyncEnumerable(TcsService tcsService) + { + _tcsService = tcsService; + } + + public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) + { + return new CustomAsyncEnumerator(_tcsService, cancellationToken); + } + + private class CustomAsyncEnumerator : IAsyncEnumerator + { + private readonly TcsService _tcsService; + private readonly CancellationToken _cancellationToken; + + public CustomAsyncEnumerator(TcsService tcsService, CancellationToken cancellationToken) + { + _tcsService = tcsService; + _cancellationToken = cancellationToken; + } + + public int Current => throw new NotImplementedException(); + + public ValueTask DisposeAsync() + { + return default; + } + + public async ValueTask MoveNextAsync() + { + _tcsService.StartedMethod.SetResult(null); + await _cancellationToken.WaitForCancellationAsync(); + _tcsService.EndMethod.SetResult(null); + return false; + } + } + } } public class TcsService { - public TaskCompletionSource StartedMethod = new TaskCompletionSource(); - public TaskCompletionSource EndMethod = new TaskCompletionSource(); + public TaskCompletionSource StartedMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + public TaskCompletionSource EndMethod = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); } public interface ITypedHubClient diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index b51c47c229..b7a576954f 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -1763,10 +1763,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var protocol = HubProtocolHelpers.GetHubProtocol(protocolName); - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); - var connectionHandler = serviceProvider.GetService>(); - var invocationBinder = new Mock(); - invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + var invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); using (var client = new TestClient(protocol: protocol, invocationBinder: invocationBinder.Object)) { @@ -1909,10 +1909,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests { get { - foreach (var method in new[] + var methods = new[] { - nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync) - }) + nameof(StreamingHub.CounterChannel), + nameof(StreamingHub.CounterChannelAsync), + nameof(StreamingHub.CounterChannelValueTaskAsync), + nameof(StreamingHub.CounterAsyncEnumerable), + nameof(StreamingHub.CounterAsyncEnumerableAsync), + nameof(StreamingHub.CounterAsyncEnumerableImpl), + nameof(StreamingHub.AsyncEnumerableIsPreferedOverChannelReader), + }; + + foreach (var method in methods) { foreach (var protocolName in HubProtocolHelpers.AllProtocolNames) { @@ -3150,10 +3158,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Theory] - [InlineData(nameof(LongRunningHub.CancelableStream))] - [InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)] - [InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)] - public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args) + [InlineData(nameof(LongRunningHub.CancelableStreamSingleParameter))] + [InlineData(nameof(LongRunningHub.CancelableStreamMultiParameter), 1, 2)] + [InlineData(nameof(LongRunningHub.CancelableStreamMiddleParameter), 1, 2)] + [InlineData(nameof(LongRunningHub.CancelableStreamGeneratedAsyncEnumerable))] + [InlineData(nameof(LongRunningHub.CancelableStreamCustomAsyncEnumerable))] + public async Task StreamHubMethodCanBeTriggeredOnCancellation(string methodName, params object[] args) { using (StartVerifiableLog()) { @@ -3207,7 +3217,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout(); + var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStreamSingleParameter)).OrTimeout(); // Wait for the stream method to start await tcsService.StartedMethod.Task.OrTimeout(); diff --git a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj index 9b35e64990..ec113f4e57 100644 --- a/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj +++ b/src/SignalR/server/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.csproj @@ -1,7 +1,8 @@ - + netcoreapp3.0 + 8.0