Support IAsyncEnumerable returns in SignalR hubs (#6791)

This commit is contained in:
Stephen Halter 2019-02-25 15:08:11 -08:00 committed by GitHub
parent 03460d81ce
commit 46fe595606
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 359 additions and 185 deletions

View File

@ -822,7 +822,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
await connection.StartAsync().OrTimeout();
var channel = await connection.StreamAsChannelAsync<int>("StreamBroken").OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAndCollectAllAsync()).OrTimeout();
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message);
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<> or IAsyncEnumerable<>.", ex.Message);
}
catch (Exception ex)
{

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Reactive.Linq;
using System.Threading.Channels;
using System.Threading.Tasks;
@ -11,6 +12,15 @@ namespace SignalRSamples.Hubs
{
public class Streaming : Hub
{
public async IAsyncEnumerable<int> AsyncEnumerableCounter(int count, int delay)
{
for (var i = 0; i < count; i++)
{
yield return i;
await Task.Delay(delay);
}
}
public ChannelReader<int> ObservableCounter(int count, int delay)
{
var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay))

View File

@ -1,7 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk.Web">
<Project Sdk="Microsoft.NET.Sdk.Web">
<PropertyGroup>
<TargetFramework>netcoreapp3.0</TargetFramework>
<LangVersion>8.0</LangVersion>
</PropertyGroup>
<ItemGroup>

View File

@ -17,6 +17,7 @@
</div>
<div>
<button id="asyncEnumerableButton" name="asyncEnumerable" type="button" disabled>From IAsyncEnumerable</button>
<button id="observableButton" name="observable" type="button" disabled>From Observable</button>
<button id="channelButton" name="channel" type="button" disabled>From Channel</button>
</div>
@ -32,7 +33,7 @@
let resultsList = document.getElementById('resultsList');
let channelButton = document.getElementById('channelButton');
let observableButton = document.getElementById('observableButton');
let clearButton = document.getElementById('clearButton');
let asyncEnumerableButton = document.getElementById('asyncEnumerableButton');
let connectButton = document.getElementById('connectButton');
let disconnectButton = document.getElementById('disconnectButton');
@ -61,6 +62,7 @@
connection.onclose(function () {
channelButton.disabled = true;
observableButton.disabled = true;
asyncEnumerableButton.disabled = true;
connectButton.disabled = false;
disconnectButton.disabled = true;
@ -71,12 +73,17 @@
.then(function () {
channelButton.disabled = false;
observableButton.disabled = false;
asyncEnumerableButton.disabled = false;
connectButton.disabled = true;
disconnectButton.disabled = false;
addLine('resultsList', 'connected', 'green');
});
});
click('asyncEnumerableButton', function () {
run('AsyncEnumerableCounter');
})
click('observableButton', function () {
run('ObservableCounter');
});

View File

@ -0,0 +1,70 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal
{
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumerableAdapters
{
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
{
return new CancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
}
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default)
{
return MakeCancelableAsyncEnumerable(channel.ReadAllAsync(), cancellationToken);
}
/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
private class CancelableAsyncEnumerable<T> : IAsyncEnumerable<object>
{
private readonly IAsyncEnumerable<T> _asyncEnumerable;
private readonly CancellationToken _cancellationToken;
public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken)
{
_asyncEnumerable = asyncEnumerable;
_cancellationToken = cancellationToken;
}
public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
// Assume that this will be iterated through with await foreach which always passes a default token.
// Instead use the token from the ctor.
Debug.Assert(cancellationToken == default);
var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken);
return enumeratorOfT as IAsyncEnumerator<object> ?? new BoxedAsyncEnumerator(enumeratorOfT);
}
private class BoxedAsyncEnumerator : IAsyncEnumerator<object>
{
private IAsyncEnumerator<T> _asyncEnumerator;
public BoxedAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
{
_asyncEnumerator = asyncEnumerator;
}
public object Current => _asyncEnumerator.Current;
public ValueTask<bool> MoveNextAsync()
{
return _asyncEnumerator.MoveNextAsync();
}
public ValueTask DisposeAsync()
{
return _asyncEnumerator.DisposeAsync();
}
}
}
}
}

View File

@ -1,84 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal
{
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumeratorAdapters
{
public static IAsyncEnumerator<object> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
{
// Nothing to dispose when we finish enumerating in this case.
return new AsyncEnumerator<T>(channel, cancellationToken, disposable: null);
}
/// <summary>Provides an async enumerator for the data in a channel.</summary>
internal class AsyncEnumerator<T> : IAsyncEnumerator<object>, IDisposable
{
/// <summary>The channel being enumerated.</summary>
private readonly ChannelReader<T> _channel;
/// <summary>Cancellation token used to cancel the enumeration.</summary>
private readonly CancellationToken _cancellationToken;
/// <summary>The current element of the enumeration.</summary>
private object _current;
private readonly IDisposable _disposable;
internal AsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken, IDisposable disposable)
{
_channel = channel;
_cancellationToken = cancellationToken;
_disposable = disposable;
}
public object Current => _current;
public Task<bool> MoveNextAsync()
{
var result = _channel.ReadAsync(_cancellationToken);
if (result.IsCompletedSuccessfully)
{
_current = result.Result;
return Task.FromResult(true);
}
return result.AsTask().ContinueWith((t, s) =>
{
var thisRef = (AsyncEnumerator<T>)s;
if (t.IsFaulted && t.Exception.InnerException is ChannelClosedException cce && cce.InnerException == null)
{
return false;
}
thisRef._current = t.GetAwaiter().GetResult();
return true;
}, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default);
}
public void Dispose()
{
_disposable?.Dispose();
}
}
}
/// <summary>Represents an enumerator accessed asynchronously.</summary>
/// <typeparam name="T">Specifies the type of the data enumerated.</typeparam>
internal interface IAsyncEnumerator<out T>
{
/// <summary>Asynchronously move the enumerator to the next element.</summary>
/// <returns>
/// A task that returns true if the enumerator was successfully advanced to the next item,
/// or false if no more data was available in the collection.
/// </returns>
Task<bool> MoveNextAsync();
/// <summary>Gets the current element being enumerated.</summary>
T Current { get; }
}
}

View File

@ -293,16 +293,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
if (result == null)
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.");
return;
}
cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts);
var enumerable = descriptor.FromReturnedStream(result, cts.Token);
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
}
else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@ -393,17 +397,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return scope.DisposeAsync();
}
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable<object> enumerable, IServiceScope scope,
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage)
{
string error = null;
try
{
while (await enumerator.MoveNextAsync())
await foreach (var streamItem in enumerable)
{
// Send the stream item
await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem));
}
}
catch (ChannelClosedException ex)
@ -422,8 +426,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
finally
{
(enumerator as IDisposable)?.Dispose();
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
// Dispose the linked CTS for the stream.
@ -502,10 +504,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return authorizationResult.Succeeded;
}
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
{
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
{
// Non-null/empty InvocationId? Blocking
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@ -518,7 +520,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return false;
}
if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
{
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
@ -530,26 +532,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
{
if (result != null)
{
if (hubMethodDescriptor.IsChannel)
{
if (streamCts == null)
{
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
}
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
return true;
}
}
enumerator = null;
return false;
}
private void DiscoverHubMethods()
{
var hubType = typeof(THub);

View File

@ -15,9 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class HubMethodDescriptor
{
private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
private static readonly MethodInfo MakeCancelableAsyncEnumerableMethod = typeof(AsyncEnumerableAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod);
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod);
private static readonly MethodInfo MakeCancelableAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel)) && m.IsGenericMethod);
private readonly MethodInfo _makeCancelableEnumerableMethodInfo;
private Func<object, CancellationToken, IAsyncEnumerable<object>> _makeCancelableEnumerable;
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
@ -27,17 +34,35 @@ namespace Microsoft.AspNetCore.SignalR.Internal
? MethodExecutor.AsyncResultType
: MethodExecutor.MethodReturnType;
if (IsChannelType(NonAsyncReturnType, out var channelItemType))
foreach (var returnType in NonAsyncReturnType.GetInterfaces().Concat(NonAsyncReturnType.AllBaseTypes()))
{
IsChannel = true;
StreamReturnType = channelItemType;
if (!returnType.IsGenericType)
{
continue;
}
var openReturnType = returnType.GetGenericTypeDefinition();
if (openReturnType == typeof(IAsyncEnumerable<>))
{
StreamReturnType = returnType.GetGenericArguments()[0];
_makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod;
break;
}
if (openReturnType == typeof(ChannelReader<>))
{
StreamReturnType = returnType.GetGenericArguments()[0];
_makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableFromChannelMethod;
break;
}
}
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
{
// Only streams can take CancellationTokens currently
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
{
HasSyntheticArguments = true;
return false;
@ -66,8 +91,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public List<Type> StreamingParameters { get; private set; }
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
public ObjectMethodExecutor MethodExecutor { get; }
public IReadOnlyList<Type> ParameterTypes { get; }
@ -76,9 +99,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public Type NonAsyncReturnType { get; }
public bool IsChannel { get; }
public bool IsStreamable => IsChannel;
public bool IsStreamResponse => StreamReturnType != null;
public Type StreamReturnType { get; }
@ -86,57 +107,39 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public bool HasSyntheticArguments { get; private set; }
private static bool IsChannelType(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
if (channelType == null)
{
payloadType = null;
return false;
}
payloadType = channelType.GetGenericArguments()[0];
return true;
}
public IAsyncEnumerator<object> FromChannel(object channel, CancellationToken cancellationToken)
public IAsyncEnumerable<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
if (_convertToEnumerator == null)
if (_makeCancelableEnumerable == null)
{
_convertToEnumerator = CompileConvertToEnumerator(GetAsyncEnumeratorMethod, StreamReturnType);
_makeCancelableEnumerable = CompileConvertToEnumerable(_makeCancelableEnumerableMethodInfo, StreamReturnType);
}
return _convertToEnumerator.Invoke(channel, cancellationToken);
return _makeCancelableEnumerable.Invoke(stream, cancellationToken);
}
private static Func<object, CancellationToken, IAsyncEnumerator<object>> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType)
private static Func<object, CancellationToken, IAsyncEnumerable<object>> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType)
{
// This will call one of two adapter methods to wrap the passed in streamable value
// and cancellation token to an IAsyncEnumerator<object>
// ChannelReader<T>
// AsyncEnumeratorAdapters.GetAsyncEnumerator<T>(channelReader, cancellationToken);
// This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable<object>:
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel<T>(channelReader, cancellationToken);
var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType);
var methodParameters = genericMethodInfo.GetParameters();
// arg1 and arg2 are the parameter names on Func<T1, T2, TReturn>
// we reference these values and then use them to call adaptor method
var targetParameter = Expression.Parameter(typeof(object), "arg1");
var parametersParameter = Expression.Parameter(typeof(CancellationToken), "arg2");
var parameters = new List<Expression>
var parameters = new[]
{
Expression.Convert(targetParameter, methodParameters[0].ParameterType),
parametersParameter
Expression.Parameter(typeof(object)),
Expression.Parameter(typeof(CancellationToken)),
};
var methodCall = Expression.Call(null, genericMethodInfo, parameters);
var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType);
var methodParameters = genericMethodInfo.GetParameters();
var methodArguements = new Expression[]
{
Expression.Convert(parameters[0], methodParameters[0].ParameterType),
parameters[1],
};
var castMethodCall = Expression.Convert(methodCall, typeof(IAsyncEnumerator<object>));
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerator<object>>>(castMethodCall, targetParameter, parametersParameter);
var methodCall = Expression.Call(null, genericMethodInfo, methodArguements);
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerable<object>>>(methodCall, parameters);
return lambda.Compile();
}
}

View File

@ -5,6 +5,7 @@
<TargetFramework>netcoreapp3.0</TargetFramework>
<IsAspNetCoreApp>true</IsAspNetCoreApp>
<RootNamespace>Microsoft.AspNetCore.SignalR</RootNamespace>
<LangVersion>8.0</LangVersion>
</PropertyGroup>
<ItemGroup>

View File

@ -591,6 +591,31 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return CounterChannel(count);
}
public async IAsyncEnumerable<string> CounterAsyncEnumerable(int count)
{
for (int i = 0; i < count; i++)
{
await Task.Yield();
yield return i.ToString();
}
}
public async Task<IAsyncEnumerable<string>> CounterAsyncEnumerableAsync(int count)
{
await Task.Yield();
return CounterAsyncEnumerable(count);
}
public AsyncEnumerableImpl<string> CounterAsyncEnumerableImpl(int count)
{
return new AsyncEnumerableImpl<string>(CounterAsyncEnumerable(count));
}
public AsyncEnumerableImplChannelThrows<string> AsyncEnumerableIsPreferedOverChannelReader(int count)
{
return new AsyncEnumerableImplChannelThrows<string>(CounterChannel(count));
}
public ChannelReader<string> BlockingStream()
{
return Channel.CreateUnbounded<string>().Reader;
@ -627,6 +652,99 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return output.Reader;
}
public class AsyncEnumerableImpl<T> : IAsyncEnumerable<T>
{
private readonly IAsyncEnumerable<T> _inner;
public AsyncEnumerableImpl(IAsyncEnumerable<T> inner)
{
_inner = inner;
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return _inner.GetAsyncEnumerator(cancellationToken);
}
}
public class AsyncEnumerableImplChannelThrows<T> : ChannelReader<T>, IAsyncEnumerable<T>
{
private ChannelReader<T> _inner;
public AsyncEnumerableImplChannelThrows(ChannelReader<T> inner)
{
_inner = inner;
}
public override bool TryRead(out T item)
{
// Not implemented to verify this is consumed as an IAsyncEnumerable<T> instead of a ChannelReader<T>.
throw new NotImplementedException();
}
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
{
// Not implemented to verify this is consumed as an IAsyncEnumerable<T> instead of a ChannelReader<T>.
throw new NotImplementedException();
}
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new ChannelAsyncEnumerator(_inner, cancellationToken);
}
// Copied from AsyncEnumeratorAdapters
private class ChannelAsyncEnumerator : IAsyncEnumerator<T>
{
/// <summary>The channel being enumerated.</summary>
private readonly ChannelReader<T> _channel;
/// <summary>Cancellation token used to cancel the enumeration.</summary>
private readonly CancellationToken _cancellationToken;
/// <summary>The current element of the enumeration.</summary>
private T _current;
public ChannelAsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
{
_channel = channel;
_cancellationToken = cancellationToken;
}
public T Current => _current;
public ValueTask<bool> MoveNextAsync()
{
var result = _channel.ReadAsync(_cancellationToken);
if (result.IsCompletedSuccessfully)
{
_current = result.Result;
return new ValueTask<bool>(true);
}
return new ValueTask<bool>(MoveNextAsyncAwaited(result));
}
private async Task<bool> MoveNextAsyncAwaited(ValueTask<T> channelReadTask)
{
try
{
_current = await channelReadTask;
}
catch (ChannelClosedException ex) when (ex.InnerException == null)
{
return false;
}
return true;
}
public ValueTask DisposeAsync()
{
return default;
}
}
}
}
public class SimpleHub : Hub
@ -681,7 +799,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Channel.CreateUnbounded<string>().Reader;
}
public ChannelReader<int> CancelableStream(CancellationToken token)
public ChannelReader<int> CancelableStreamSingleParameter(CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
@ -696,7 +814,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return channel.Reader;
}
public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
public ChannelReader<int> CancelableStreamMultiParameter(int ignore, int ignore2, CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
@ -711,7 +829,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return channel.Reader;
}
public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
public ChannelReader<int> CancelableStreamMiddleParameter(int ignore, CancellationToken token, int ignore2)
{
var channel = Channel.CreateBounded<int>(10);
@ -726,16 +844,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return channel.Reader;
}
public async IAsyncEnumerable<int> CancelableStreamGeneratedAsyncEnumerable(CancellationToken token)
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
_tcsService.EndMethod.SetResult(null);
yield break;
}
public IAsyncEnumerable<int> CancelableStreamCustomAsyncEnumerable()
{
return new CustomAsyncEnumerable(_tcsService);
}
public int SimpleMethod()
{
return 21;
}
private class CustomAsyncEnumerable : IAsyncEnumerable<int>
{
private readonly TcsService _tcsService;
public CustomAsyncEnumerable(TcsService tcsService)
{
_tcsService = tcsService;
}
public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
return new CustomAsyncEnumerator(_tcsService, cancellationToken);
}
private class CustomAsyncEnumerator : IAsyncEnumerator<int>
{
private readonly TcsService _tcsService;
private readonly CancellationToken _cancellationToken;
public CustomAsyncEnumerator(TcsService tcsService, CancellationToken cancellationToken)
{
_tcsService = tcsService;
_cancellationToken = cancellationToken;
}
public int Current => throw new NotImplementedException();
public ValueTask DisposeAsync()
{
return default;
}
public async ValueTask<bool> MoveNextAsync()
{
_tcsService.StartedMethod.SetResult(null);
await _cancellationToken.WaitForCancellationAsync();
_tcsService.EndMethod.SetResult(null);
return false;
}
}
}
}
public class TcsService
{
public TaskCompletionSource<object> StartedMethod = new TaskCompletionSource<object>();
public TaskCompletionSource<object> EndMethod = new TaskCompletionSource<object>();
public TaskCompletionSource<object> StartedMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
public TaskCompletionSource<object> EndMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
}
public interface ITypedHubClient

View File

@ -1763,10 +1763,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
var invocationBinder = new Mock<IInvocationBinder>();
invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny<string>())).Returns(typeof(string));
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
var invocationBinder = new Mock<IInvocationBinder>();
invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny<string>())).Returns(typeof(string));
using (var client = new TestClient(protocol: protocol, invocationBinder: invocationBinder.Object))
{
@ -1909,10 +1909,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
get
{
foreach (var method in new[]
var methods = new[]
{
nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync)
})
nameof(StreamingHub.CounterChannel),
nameof(StreamingHub.CounterChannelAsync),
nameof(StreamingHub.CounterChannelValueTaskAsync),
nameof(StreamingHub.CounterAsyncEnumerable),
nameof(StreamingHub.CounterAsyncEnumerableAsync),
nameof(StreamingHub.CounterAsyncEnumerableImpl),
nameof(StreamingHub.AsyncEnumerableIsPreferedOverChannelReader),
};
foreach (var method in methods)
{
foreach (var protocolName in HubProtocolHelpers.AllProtocolNames)
{
@ -3150,10 +3158,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
[Theory]
[InlineData(nameof(LongRunningHub.CancelableStream))]
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
[InlineData(nameof(LongRunningHub.CancelableStreamSingleParameter))]
[InlineData(nameof(LongRunningHub.CancelableStreamMultiParameter), 1, 2)]
[InlineData(nameof(LongRunningHub.CancelableStreamMiddleParameter), 1, 2)]
[InlineData(nameof(LongRunningHub.CancelableStreamGeneratedAsyncEnumerable))]
[InlineData(nameof(LongRunningHub.CancelableStreamCustomAsyncEnumerable))]
public async Task StreamHubMethodCanBeTriggeredOnCancellation(string methodName, params object[] args)
{
using (StartVerifiableLog())
{
@ -3207,7 +3217,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStreamSingleParameter)).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();

View File

@ -1,7 +1,8 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netcoreapp3.0</TargetFramework>
<LangVersion>8.0</LangVersion>
</PropertyGroup>