Support IAsyncEnumerable returns in SignalR hubs (#6791)
This commit is contained in:
parent
03460d81ce
commit
46fe595606
|
|
@ -822,7 +822,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
await connection.StartAsync().OrTimeout();
|
||||
var channel = await connection.StreamAsChannelAsync<int>("StreamBroken").OrTimeout();
|
||||
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAndCollectAllAsync()).OrTimeout();
|
||||
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message);
|
||||
Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<> or IAsyncEnumerable<>.", ex.Message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Reactive.Linq;
|
||||
using System.Threading.Channels;
|
||||
using System.Threading.Tasks;
|
||||
|
|
@ -11,6 +12,15 @@ namespace SignalRSamples.Hubs
|
|||
{
|
||||
public class Streaming : Hub
|
||||
{
|
||||
public async IAsyncEnumerable<int> AsyncEnumerableCounter(int count, int delay)
|
||||
{
|
||||
for (var i = 0; i < count; i++)
|
||||
{
|
||||
yield return i;
|
||||
await Task.Delay(delay);
|
||||
}
|
||||
}
|
||||
|
||||
public ChannelReader<int> ObservableCounter(int count, int delay)
|
||||
{
|
||||
var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay))
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk.Web">
|
||||
<Project Sdk="Microsoft.NET.Sdk.Web">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>netcoreapp3.0</TargetFramework>
|
||||
<LangVersion>8.0</LangVersion>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
</div>
|
||||
|
||||
<div>
|
||||
<button id="asyncEnumerableButton" name="asyncEnumerable" type="button" disabled>From IAsyncEnumerable</button>
|
||||
<button id="observableButton" name="observable" type="button" disabled>From Observable</button>
|
||||
<button id="channelButton" name="channel" type="button" disabled>From Channel</button>
|
||||
</div>
|
||||
|
|
@ -32,7 +33,7 @@
|
|||
let resultsList = document.getElementById('resultsList');
|
||||
let channelButton = document.getElementById('channelButton');
|
||||
let observableButton = document.getElementById('observableButton');
|
||||
let clearButton = document.getElementById('clearButton');
|
||||
let asyncEnumerableButton = document.getElementById('asyncEnumerableButton');
|
||||
|
||||
let connectButton = document.getElementById('connectButton');
|
||||
let disconnectButton = document.getElementById('disconnectButton');
|
||||
|
|
@ -61,6 +62,7 @@
|
|||
connection.onclose(function () {
|
||||
channelButton.disabled = true;
|
||||
observableButton.disabled = true;
|
||||
asyncEnumerableButton.disabled = true;
|
||||
connectButton.disabled = false;
|
||||
disconnectButton.disabled = true;
|
||||
|
||||
|
|
@ -71,12 +73,17 @@
|
|||
.then(function () {
|
||||
channelButton.disabled = false;
|
||||
observableButton.disabled = false;
|
||||
asyncEnumerableButton.disabled = false;
|
||||
connectButton.disabled = true;
|
||||
disconnectButton.disabled = false;
|
||||
addLine('resultsList', 'connected', 'green');
|
||||
});
|
||||
});
|
||||
|
||||
click('asyncEnumerableButton', function () {
|
||||
run('AsyncEnumerableCounter');
|
||||
})
|
||||
|
||||
click('observableButton', function () {
|
||||
run('ObservableCounter');
|
||||
});
|
||||
|
|
|
|||
|
|
@ -0,0 +1,70 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Threading;
|
||||
using System.Threading.Channels;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
// True-internal because this is a weird and tricky class to use :)
|
||||
internal static class AsyncEnumerableAdapters
|
||||
{
|
||||
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerable<T>(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return new CancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
|
||||
}
|
||||
|
||||
public static IAsyncEnumerable<object> MakeCancelableAsyncEnumerableFromChannel<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default)
|
||||
{
|
||||
return MakeCancelableAsyncEnumerable(channel.ReadAllAsync(), cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
|
||||
private class CancelableAsyncEnumerable<T> : IAsyncEnumerable<object>
|
||||
{
|
||||
private readonly IAsyncEnumerable<T> _asyncEnumerable;
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
|
||||
public CancelableAsyncEnumerable(IAsyncEnumerable<T> asyncEnumerable, CancellationToken cancellationToken)
|
||||
{
|
||||
_asyncEnumerable = asyncEnumerable;
|
||||
_cancellationToken = cancellationToken;
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Assume that this will be iterated through with await foreach which always passes a default token.
|
||||
// Instead use the token from the ctor.
|
||||
Debug.Assert(cancellationToken == default);
|
||||
|
||||
var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken);
|
||||
return enumeratorOfT as IAsyncEnumerator<object> ?? new BoxedAsyncEnumerator(enumeratorOfT);
|
||||
}
|
||||
|
||||
private class BoxedAsyncEnumerator : IAsyncEnumerator<object>
|
||||
{
|
||||
private IAsyncEnumerator<T> _asyncEnumerator;
|
||||
|
||||
public BoxedAsyncEnumerator(IAsyncEnumerator<T> asyncEnumerator)
|
||||
{
|
||||
_asyncEnumerator = asyncEnumerator;
|
||||
}
|
||||
|
||||
public object Current => _asyncEnumerator.Current;
|
||||
|
||||
public ValueTask<bool> MoveNextAsync()
|
||||
{
|
||||
return _asyncEnumerator.MoveNextAsync();
|
||||
}
|
||||
|
||||
public ValueTask DisposeAsync()
|
||||
{
|
||||
return _asyncEnumerator.DisposeAsync();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,84 +0,0 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Channels;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
// True-internal because this is a weird and tricky class to use :)
|
||||
internal static class AsyncEnumeratorAdapters
|
||||
{
|
||||
public static IAsyncEnumerator<object> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
// Nothing to dispose when we finish enumerating in this case.
|
||||
return new AsyncEnumerator<T>(channel, cancellationToken, disposable: null);
|
||||
}
|
||||
|
||||
/// <summary>Provides an async enumerator for the data in a channel.</summary>
|
||||
internal class AsyncEnumerator<T> : IAsyncEnumerator<object>, IDisposable
|
||||
{
|
||||
/// <summary>The channel being enumerated.</summary>
|
||||
private readonly ChannelReader<T> _channel;
|
||||
/// <summary>Cancellation token used to cancel the enumeration.</summary>
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
/// <summary>The current element of the enumeration.</summary>
|
||||
private object _current;
|
||||
|
||||
private readonly IDisposable _disposable;
|
||||
|
||||
internal AsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken, IDisposable disposable)
|
||||
{
|
||||
_channel = channel;
|
||||
_cancellationToken = cancellationToken;
|
||||
_disposable = disposable;
|
||||
}
|
||||
|
||||
public object Current => _current;
|
||||
|
||||
public Task<bool> MoveNextAsync()
|
||||
{
|
||||
var result = _channel.ReadAsync(_cancellationToken);
|
||||
|
||||
if (result.IsCompletedSuccessfully)
|
||||
{
|
||||
_current = result.Result;
|
||||
return Task.FromResult(true);
|
||||
}
|
||||
|
||||
return result.AsTask().ContinueWith((t, s) =>
|
||||
{
|
||||
var thisRef = (AsyncEnumerator<T>)s;
|
||||
if (t.IsFaulted && t.Exception.InnerException is ChannelClosedException cce && cce.InnerException == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
thisRef._current = t.GetAwaiter().GetResult();
|
||||
return true;
|
||||
}, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default);
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_disposable?.Dispose();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>Represents an enumerator accessed asynchronously.</summary>
|
||||
/// <typeparam name="T">Specifies the type of the data enumerated.</typeparam>
|
||||
internal interface IAsyncEnumerator<out T>
|
||||
{
|
||||
/// <summary>Asynchronously move the enumerator to the next element.</summary>
|
||||
/// <returns>
|
||||
/// A task that returns true if the enumerator was successfully advanced to the next item,
|
||||
/// or false if no more data was available in the collection.
|
||||
/// </returns>
|
||||
Task<bool> MoveNextAsync();
|
||||
|
||||
/// <summary>Gets the current element being enumerated.</summary>
|
||||
T Current { get; }
|
||||
}
|
||||
}
|
||||
|
|
@ -293,16 +293,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
{
|
||||
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
|
||||
|
||||
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
|
||||
if (result == null)
|
||||
{
|
||||
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
|
||||
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
|
||||
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
|
||||
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.");
|
||||
return;
|
||||
}
|
||||
|
||||
cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
|
||||
connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts);
|
||||
var enumerable = descriptor.FromReturnedStream(result, cts.Token);
|
||||
|
||||
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
||||
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
|
||||
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
|
||||
}
|
||||
|
||||
else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
|
||||
|
|
@ -393,17 +397,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return scope.DisposeAsync();
|
||||
}
|
||||
|
||||
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
|
||||
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable<object> enumerable, IServiceScope scope,
|
||||
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage)
|
||||
{
|
||||
string error = null;
|
||||
|
||||
try
|
||||
{
|
||||
while (await enumerator.MoveNextAsync())
|
||||
await foreach (var streamItem in enumerable)
|
||||
{
|
||||
// Send the stream item
|
||||
await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
|
||||
await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem));
|
||||
}
|
||||
}
|
||||
catch (ChannelClosedException ex)
|
||||
|
|
@ -422,8 +426,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
}
|
||||
finally
|
||||
{
|
||||
(enumerator as IDisposable)?.Dispose();
|
||||
|
||||
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
|
||||
|
||||
// Dispose the linked CTS for the stream.
|
||||
|
|
@ -502,10 +504,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return authorizationResult.Succeeded;
|
||||
}
|
||||
|
||||
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
|
||||
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamResponse,
|
||||
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
|
||||
{
|
||||
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
|
||||
if (hubMethodDescriptor.IsStreamResponse && !isStreamResponse)
|
||||
{
|
||||
// Non-null/empty InvocationId? Blocking
|
||||
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
|
||||
|
|
@ -518,7 +520,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
|
||||
if (!hubMethodDescriptor.IsStreamResponse && isStreamResponse)
|
||||
{
|
||||
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
|
||||
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
|
||||
|
|
@ -530,26 +532,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return true;
|
||||
}
|
||||
|
||||
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
|
||||
{
|
||||
if (result != null)
|
||||
{
|
||||
if (hubMethodDescriptor.IsChannel)
|
||||
{
|
||||
if (streamCts == null)
|
||||
{
|
||||
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
|
||||
}
|
||||
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
|
||||
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
enumerator = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
private void DiscoverHubMethods()
|
||||
{
|
||||
var hubType = typeof(THub);
|
||||
|
|
|
|||
|
|
@ -15,9 +15,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
{
|
||||
internal class HubMethodDescriptor
|
||||
{
|
||||
private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
|
||||
private static readonly MethodInfo MakeCancelableAsyncEnumerableMethod = typeof(AsyncEnumerableAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod);
|
||||
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable)) && m.IsGenericMethod);
|
||||
|
||||
private static readonly MethodInfo MakeCancelableAsyncEnumerableFromChannelMethod = typeof(AsyncEnumerableAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel)) && m.IsGenericMethod);
|
||||
|
||||
private readonly MethodInfo _makeCancelableEnumerableMethodInfo;
|
||||
private Func<object, CancellationToken, IAsyncEnumerable<object>> _makeCancelableEnumerable;
|
||||
|
||||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
||||
{
|
||||
|
|
@ -27,17 +34,35 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
? MethodExecutor.AsyncResultType
|
||||
: MethodExecutor.MethodReturnType;
|
||||
|
||||
if (IsChannelType(NonAsyncReturnType, out var channelItemType))
|
||||
foreach (var returnType in NonAsyncReturnType.GetInterfaces().Concat(NonAsyncReturnType.AllBaseTypes()))
|
||||
{
|
||||
IsChannel = true;
|
||||
StreamReturnType = channelItemType;
|
||||
if (!returnType.IsGenericType)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var openReturnType = returnType.GetGenericTypeDefinition();
|
||||
|
||||
if (openReturnType == typeof(IAsyncEnumerable<>))
|
||||
{
|
||||
StreamReturnType = returnType.GetGenericArguments()[0];
|
||||
_makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableMethod;
|
||||
break;
|
||||
}
|
||||
|
||||
if (openReturnType == typeof(ChannelReader<>))
|
||||
{
|
||||
StreamReturnType = returnType.GetGenericArguments()[0];
|
||||
_makeCancelableEnumerableMethodInfo = MakeCancelableAsyncEnumerableFromChannelMethod;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
|
||||
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
|
||||
{
|
||||
// Only streams can take CancellationTokens currently
|
||||
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
|
||||
if (IsStreamResponse && p.ParameterType == typeof(CancellationToken))
|
||||
{
|
||||
HasSyntheticArguments = true;
|
||||
return false;
|
||||
|
|
@ -66,8 +91,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
public List<Type> StreamingParameters { get; private set; }
|
||||
|
||||
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
|
||||
|
||||
public ObjectMethodExecutor MethodExecutor { get; }
|
||||
|
||||
public IReadOnlyList<Type> ParameterTypes { get; }
|
||||
|
|
@ -76,9 +99,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
public Type NonAsyncReturnType { get; }
|
||||
|
||||
public bool IsChannel { get; }
|
||||
|
||||
public bool IsStreamable => IsChannel;
|
||||
public bool IsStreamResponse => StreamReturnType != null;
|
||||
|
||||
public Type StreamReturnType { get; }
|
||||
|
||||
|
|
@ -86,57 +107,39 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
public bool HasSyntheticArguments { get; private set; }
|
||||
|
||||
private static bool IsChannelType(Type type, out Type payloadType)
|
||||
{
|
||||
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
|
||||
if (channelType == null)
|
||||
{
|
||||
payloadType = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
payloadType = channelType.GetGenericArguments()[0];
|
||||
return true;
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<object> FromChannel(object channel, CancellationToken cancellationToken)
|
||||
public IAsyncEnumerable<object> FromReturnedStream(object stream, CancellationToken cancellationToken)
|
||||
{
|
||||
// there is the potential for compile to be called times but this has no harmful effect other than perf
|
||||
if (_convertToEnumerator == null)
|
||||
if (_makeCancelableEnumerable == null)
|
||||
{
|
||||
_convertToEnumerator = CompileConvertToEnumerator(GetAsyncEnumeratorMethod, StreamReturnType);
|
||||
_makeCancelableEnumerable = CompileConvertToEnumerable(_makeCancelableEnumerableMethodInfo, StreamReturnType);
|
||||
}
|
||||
|
||||
return _convertToEnumerator.Invoke(channel, cancellationToken);
|
||||
return _makeCancelableEnumerable.Invoke(stream, cancellationToken);
|
||||
}
|
||||
|
||||
private static Func<object, CancellationToken, IAsyncEnumerator<object>> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType)
|
||||
private static Func<object, CancellationToken, IAsyncEnumerable<object>> CompileConvertToEnumerable(MethodInfo adapterMethodInfo, Type streamReturnType)
|
||||
{
|
||||
// This will call one of two adapter methods to wrap the passed in streamable value
|
||||
// and cancellation token to an IAsyncEnumerator<object>
|
||||
// ChannelReader<T>
|
||||
// AsyncEnumeratorAdapters.GetAsyncEnumerator<T>(channelReader, cancellationToken);
|
||||
// This will call one of two adapter methods to wrap the passed in streamable value into an IAsyncEnumerable<object>:
|
||||
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerable<T>(asyncEnumerable, cancellationToken);
|
||||
// - AsyncEnumerableAdapters.MakeCancelableAsyncEnumerableFromChannel<T>(channelReader, cancellationToken);
|
||||
|
||||
var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType);
|
||||
|
||||
var methodParameters = genericMethodInfo.GetParameters();
|
||||
|
||||
// arg1 and arg2 are the parameter names on Func<T1, T2, TReturn>
|
||||
// we reference these values and then use them to call adaptor method
|
||||
var targetParameter = Expression.Parameter(typeof(object), "arg1");
|
||||
var parametersParameter = Expression.Parameter(typeof(CancellationToken), "arg2");
|
||||
|
||||
var parameters = new List<Expression>
|
||||
var parameters = new[]
|
||||
{
|
||||
Expression.Convert(targetParameter, methodParameters[0].ParameterType),
|
||||
parametersParameter
|
||||
Expression.Parameter(typeof(object)),
|
||||
Expression.Parameter(typeof(CancellationToken)),
|
||||
};
|
||||
|
||||
var methodCall = Expression.Call(null, genericMethodInfo, parameters);
|
||||
var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType);
|
||||
var methodParameters = genericMethodInfo.GetParameters();
|
||||
var methodArguements = new Expression[]
|
||||
{
|
||||
Expression.Convert(parameters[0], methodParameters[0].ParameterType),
|
||||
parameters[1],
|
||||
};
|
||||
|
||||
var castMethodCall = Expression.Convert(methodCall, typeof(IAsyncEnumerator<object>));
|
||||
|
||||
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerator<object>>>(castMethodCall, targetParameter, parametersParameter);
|
||||
var methodCall = Expression.Call(null, genericMethodInfo, methodArguements);
|
||||
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerable<object>>>(methodCall, parameters);
|
||||
return lambda.Compile();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
<TargetFramework>netcoreapp3.0</TargetFramework>
|
||||
<IsAspNetCoreApp>true</IsAspNetCoreApp>
|
||||
<RootNamespace>Microsoft.AspNetCore.SignalR</RootNamespace>
|
||||
<LangVersion>8.0</LangVersion>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
|
|
|
|||
|
|
@ -591,6 +591,31 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return CounterChannel(count);
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<string> CounterAsyncEnumerable(int count)
|
||||
{
|
||||
for (int i = 0; i < count; i++)
|
||||
{
|
||||
await Task.Yield();
|
||||
yield return i.ToString();
|
||||
}
|
||||
}
|
||||
|
||||
public async Task<IAsyncEnumerable<string>> CounterAsyncEnumerableAsync(int count)
|
||||
{
|
||||
await Task.Yield();
|
||||
return CounterAsyncEnumerable(count);
|
||||
}
|
||||
|
||||
public AsyncEnumerableImpl<string> CounterAsyncEnumerableImpl(int count)
|
||||
{
|
||||
return new AsyncEnumerableImpl<string>(CounterAsyncEnumerable(count));
|
||||
}
|
||||
|
||||
public AsyncEnumerableImplChannelThrows<string> AsyncEnumerableIsPreferedOverChannelReader(int count)
|
||||
{
|
||||
return new AsyncEnumerableImplChannelThrows<string>(CounterChannel(count));
|
||||
}
|
||||
|
||||
public ChannelReader<string> BlockingStream()
|
||||
{
|
||||
return Channel.CreateUnbounded<string>().Reader;
|
||||
|
|
@ -627,6 +652,99 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
return output.Reader;
|
||||
}
|
||||
|
||||
public class AsyncEnumerableImpl<T> : IAsyncEnumerable<T>
|
||||
{
|
||||
private readonly IAsyncEnumerable<T> _inner;
|
||||
|
||||
public AsyncEnumerableImpl(IAsyncEnumerable<T> inner)
|
||||
{
|
||||
_inner = inner;
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
|
||||
{
|
||||
return _inner.GetAsyncEnumerator(cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
public class AsyncEnumerableImplChannelThrows<T> : ChannelReader<T>, IAsyncEnumerable<T>
|
||||
{
|
||||
private ChannelReader<T> _inner;
|
||||
|
||||
public AsyncEnumerableImplChannelThrows(ChannelReader<T> inner)
|
||||
{
|
||||
_inner = inner;
|
||||
}
|
||||
|
||||
public override bool TryRead(out T item)
|
||||
{
|
||||
// Not implemented to verify this is consumed as an IAsyncEnumerable<T> instead of a ChannelReader<T>.
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override ValueTask<bool> WaitToReadAsync(CancellationToken cancellationToken = default)
|
||||
{
|
||||
// Not implemented to verify this is consumed as an IAsyncEnumerable<T> instead of a ChannelReader<T>.
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken = default)
|
||||
{
|
||||
return new ChannelAsyncEnumerator(_inner, cancellationToken);
|
||||
}
|
||||
|
||||
// Copied from AsyncEnumeratorAdapters
|
||||
private class ChannelAsyncEnumerator : IAsyncEnumerator<T>
|
||||
{
|
||||
/// <summary>The channel being enumerated.</summary>
|
||||
private readonly ChannelReader<T> _channel;
|
||||
/// <summary>Cancellation token used to cancel the enumeration.</summary>
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
/// <summary>The current element of the enumeration.</summary>
|
||||
private T _current;
|
||||
|
||||
public ChannelAsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
|
||||
{
|
||||
_channel = channel;
|
||||
_cancellationToken = cancellationToken;
|
||||
}
|
||||
|
||||
public T Current => _current;
|
||||
|
||||
public ValueTask<bool> MoveNextAsync()
|
||||
{
|
||||
var result = _channel.ReadAsync(_cancellationToken);
|
||||
|
||||
if (result.IsCompletedSuccessfully)
|
||||
{
|
||||
_current = result.Result;
|
||||
return new ValueTask<bool>(true);
|
||||
}
|
||||
|
||||
return new ValueTask<bool>(MoveNextAsyncAwaited(result));
|
||||
}
|
||||
|
||||
private async Task<bool> MoveNextAsyncAwaited(ValueTask<T> channelReadTask)
|
||||
{
|
||||
try
|
||||
{
|
||||
_current = await channelReadTask;
|
||||
}
|
||||
catch (ChannelClosedException ex) when (ex.InnerException == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public ValueTask DisposeAsync()
|
||||
{
|
||||
return default;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class SimpleHub : Hub
|
||||
|
|
@ -681,7 +799,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return Channel.CreateUnbounded<string>().Reader;
|
||||
}
|
||||
|
||||
public ChannelReader<int> CancelableStream(CancellationToken token)
|
||||
public ChannelReader<int> CancelableStreamSingleParameter(CancellationToken token)
|
||||
{
|
||||
var channel = Channel.CreateBounded<int>(10);
|
||||
|
||||
|
|
@ -696,7 +814,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return channel.Reader;
|
||||
}
|
||||
|
||||
public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
|
||||
public ChannelReader<int> CancelableStreamMultiParameter(int ignore, int ignore2, CancellationToken token)
|
||||
{
|
||||
var channel = Channel.CreateBounded<int>(10);
|
||||
|
||||
|
|
@ -711,7 +829,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return channel.Reader;
|
||||
}
|
||||
|
||||
public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
|
||||
public ChannelReader<int> CancelableStreamMiddleParameter(int ignore, CancellationToken token, int ignore2)
|
||||
{
|
||||
var channel = Channel.CreateBounded<int>(10);
|
||||
|
||||
|
|
@ -726,16 +844,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return channel.Reader;
|
||||
}
|
||||
|
||||
public async IAsyncEnumerable<int> CancelableStreamGeneratedAsyncEnumerable(CancellationToken token)
|
||||
{
|
||||
_tcsService.StartedMethod.SetResult(null);
|
||||
await token.WaitForCancellationAsync();
|
||||
_tcsService.EndMethod.SetResult(null);
|
||||
yield break;
|
||||
}
|
||||
|
||||
public IAsyncEnumerable<int> CancelableStreamCustomAsyncEnumerable()
|
||||
{
|
||||
return new CustomAsyncEnumerable(_tcsService);
|
||||
}
|
||||
|
||||
public int SimpleMethod()
|
||||
{
|
||||
return 21;
|
||||
}
|
||||
|
||||
private class CustomAsyncEnumerable : IAsyncEnumerable<int>
|
||||
{
|
||||
private readonly TcsService _tcsService;
|
||||
|
||||
public CustomAsyncEnumerable(TcsService tcsService)
|
||||
{
|
||||
_tcsService = tcsService;
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<int> GetAsyncEnumerator(CancellationToken cancellationToken = default)
|
||||
{
|
||||
return new CustomAsyncEnumerator(_tcsService, cancellationToken);
|
||||
}
|
||||
|
||||
private class CustomAsyncEnumerator : IAsyncEnumerator<int>
|
||||
{
|
||||
private readonly TcsService _tcsService;
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
|
||||
public CustomAsyncEnumerator(TcsService tcsService, CancellationToken cancellationToken)
|
||||
{
|
||||
_tcsService = tcsService;
|
||||
_cancellationToken = cancellationToken;
|
||||
}
|
||||
|
||||
public int Current => throw new NotImplementedException();
|
||||
|
||||
public ValueTask DisposeAsync()
|
||||
{
|
||||
return default;
|
||||
}
|
||||
|
||||
public async ValueTask<bool> MoveNextAsync()
|
||||
{
|
||||
_tcsService.StartedMethod.SetResult(null);
|
||||
await _cancellationToken.WaitForCancellationAsync();
|
||||
_tcsService.EndMethod.SetResult(null);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class TcsService
|
||||
{
|
||||
public TaskCompletionSource<object> StartedMethod = new TaskCompletionSource<object>();
|
||||
public TaskCompletionSource<object> EndMethod = new TaskCompletionSource<object>();
|
||||
public TaskCompletionSource<object> StartedMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
public TaskCompletionSource<object> EndMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
}
|
||||
|
||||
public interface ITypedHubClient
|
||||
|
|
|
|||
|
|
@ -1763,10 +1763,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
|
||||
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory);
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
|
||||
var invocationBinder = new Mock<IInvocationBinder>();
|
||||
invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny<string>())).Returns(typeof(string));
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory);
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
|
||||
var invocationBinder = new Mock<IInvocationBinder>();
|
||||
invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny<string>())).Returns(typeof(string));
|
||||
|
||||
using (var client = new TestClient(protocol: protocol, invocationBinder: invocationBinder.Object))
|
||||
{
|
||||
|
|
@ -1909,10 +1909,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
get
|
||||
{
|
||||
foreach (var method in new[]
|
||||
var methods = new[]
|
||||
{
|
||||
nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync)
|
||||
})
|
||||
nameof(StreamingHub.CounterChannel),
|
||||
nameof(StreamingHub.CounterChannelAsync),
|
||||
nameof(StreamingHub.CounterChannelValueTaskAsync),
|
||||
nameof(StreamingHub.CounterAsyncEnumerable),
|
||||
nameof(StreamingHub.CounterAsyncEnumerableAsync),
|
||||
nameof(StreamingHub.CounterAsyncEnumerableImpl),
|
||||
nameof(StreamingHub.AsyncEnumerableIsPreferedOverChannelReader),
|
||||
};
|
||||
|
||||
foreach (var method in methods)
|
||||
{
|
||||
foreach (var protocolName in HubProtocolHelpers.AllProtocolNames)
|
||||
{
|
||||
|
|
@ -3150,10 +3158,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStream))]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
|
||||
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamSingleParameter))]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamMultiParameter), 1, 2)]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamMiddleParameter), 1, 2)]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamGeneratedAsyncEnumerable))]
|
||||
[InlineData(nameof(LongRunningHub.CancelableStreamCustomAsyncEnumerable))]
|
||||
public async Task StreamHubMethodCanBeTriggeredOnCancellation(string methodName, params object[] args)
|
||||
{
|
||||
using (StartVerifiableLog())
|
||||
{
|
||||
|
|
@ -3207,7 +3217,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
|
||||
|
||||
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
|
||||
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStreamSingleParameter)).OrTimeout();
|
||||
// Wait for the stream method to start
|
||||
await tcsService.StartedMethod.Task.OrTimeout();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
<Project Sdk="Microsoft.NET.Sdk">
|
||||
|
||||
<PropertyGroup>
|
||||
<TargetFramework>netcoreapp3.0</TargetFramework>
|
||||
<LangVersion>8.0</LangVersion>
|
||||
</PropertyGroup>
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue