Hub method reflection caching and invocation benchmarks (#1574)
This commit is contained in:
parent
d816c6ef60
commit
974eb28b8b
|
|
@ -0,0 +1,235 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Reactive.Linq;
|
||||
using System.Threading.Channels;
|
||||
using System.Threading.Tasks;
|
||||
using BenchmarkDotNet.Attributes;
|
||||
using Microsoft.AspNetCore.Protocols;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Moq;
|
||||
using DefaultConnectionContext = Microsoft.AspNetCore.Sockets.DefaultConnectionContext;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
||||
{
|
||||
public class DefaultHubDispatcherBenchmark
|
||||
{
|
||||
private DefaultHubDispatcher<TestHub> _dispatcher;
|
||||
private HubConnectionContext _connectionContext;
|
||||
|
||||
[GlobalSetup]
|
||||
public void GlobalSetup()
|
||||
{
|
||||
var serviceCollection = new ServiceCollection();
|
||||
serviceCollection.AddSignalRCore();
|
||||
|
||||
var provider = serviceCollection.BuildServiceProvider();
|
||||
|
||||
var serviceScopeFactory = provider.GetService<IServiceScopeFactory>();
|
||||
|
||||
_dispatcher = new DefaultHubDispatcher<TestHub>(
|
||||
serviceScopeFactory,
|
||||
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>()),
|
||||
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance));
|
||||
|
||||
var options = new PipeOptions();
|
||||
var pair = DuplexPipe.CreateConnectionPair(options, options);
|
||||
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application);
|
||||
|
||||
_connectionContext = new NoErrorHubConnectionContext(connection, TimeSpan.Zero, NullLoggerFactory.Instance);
|
||||
|
||||
_connectionContext.ProtocolReaderWriter = new HubProtocolReaderWriter(new FakeHubProtocol(), new FakeDataEncoder());
|
||||
}
|
||||
|
||||
public class FakeHubProtocol : IHubProtocol
|
||||
{
|
||||
public string Name { get; }
|
||||
public ProtocolType Type { get; }
|
||||
|
||||
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, IList<HubMessage> messages)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
public void WriteMessage(HubMessage message, Stream output)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
public class FakeDataEncoder : IDataEncoder
|
||||
{
|
||||
public byte[] Encode(byte[] payload)
|
||||
{
|
||||
return null;
|
||||
}
|
||||
|
||||
public bool TryDecode(ref ReadOnlySpan<byte> buffer, out ReadOnlySpan<byte> data)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
public class NoErrorHubConnectionContext : HubConnectionContext
|
||||
{
|
||||
public NoErrorHubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) : base(connectionContext, keepAliveInterval, loggerFactory)
|
||||
{
|
||||
}
|
||||
|
||||
public override Task WriteAsync(HubMessage message)
|
||||
{
|
||||
if (message is CompletionMessage completionMessage)
|
||||
{
|
||||
if (!string.IsNullOrEmpty(completionMessage.Error))
|
||||
{
|
||||
throw new Exception("Error invoking hub method: " + completionMessage.Error);
|
||||
}
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
public class TestHub : Hub
|
||||
{
|
||||
private static readonly IObservable<int> ObservableInstance = Observable.Empty<int>();
|
||||
|
||||
public void Invocation()
|
||||
{
|
||||
}
|
||||
|
||||
public Task InvocationAsync()
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public int InvocationReturnValue()
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
public Task<int> InvocationReturnAsync()
|
||||
{
|
||||
return Task.FromResult(1);
|
||||
}
|
||||
|
||||
public ValueTask<int> InvocationValueTaskAsync()
|
||||
{
|
||||
return new ValueTask<int>(1);
|
||||
}
|
||||
|
||||
public IObservable<int> StreamObservable()
|
||||
{
|
||||
return ObservableInstance;
|
||||
}
|
||||
|
||||
public Task<IObservable<int>> StreamObservableAsync()
|
||||
{
|
||||
return Task.FromResult(ObservableInstance);
|
||||
}
|
||||
|
||||
public ValueTask<IObservable<int>> StreamObservableValueTaskAsync()
|
||||
{
|
||||
return new ValueTask<IObservable<int>>(ObservableInstance);
|
||||
}
|
||||
|
||||
public ChannelReader<int> StreamChannelReader()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<int>();
|
||||
channel.Writer.Complete();
|
||||
|
||||
return channel;
|
||||
}
|
||||
|
||||
public Task<ChannelReader<int>> StreamChannelReaderAsync()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<int>();
|
||||
channel.Writer.Complete();
|
||||
|
||||
return Task.FromResult<ChannelReader<int>>(channel);
|
||||
}
|
||||
|
||||
public ValueTask<ChannelReader<int>> StreamChannelReaderValueTaskAsync()
|
||||
{
|
||||
var channel = Channel.CreateUnbounded<int>();
|
||||
channel.Writer.Complete();
|
||||
|
||||
return new ValueTask<ChannelReader<int>>(channel);
|
||||
}
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task Invocation()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "Invocation", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task InvocationAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationAsync", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task InvocationReturnValue()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnValue", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task InvocationReturnAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnAsync", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task InvocationValueTaskAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationValueTaskAsync", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task StreamObservable()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservable", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task StreamObservableAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableAsync", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task StreamObservableValueTaskAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableValueTaskAsync", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task StreamChannelReader()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReader", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task StreamChannelReaderAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderAsync", null));
|
||||
}
|
||||
|
||||
[Benchmark]
|
||||
public Task StreamChannelReaderValueTaskAsync()
|
||||
{
|
||||
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderValueTaskAsync", null));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -12,8 +12,11 @@
|
|||
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
|
||||
<PackageReference Include="BenchmarkDotNet" Version="$(BenchmarkDotNetPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.BenchmarkRunner.Sources" Version="$(MicrosoftAspNetCoreBenchmarkRunnerSourcesPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="$(MicrosoftExtensionsDependencyInjectionPackageVersion)" />
|
||||
<PackageReference Include="Moq" Version="$(MoqPackageVersion)" />
|
||||
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
|
||||
<PackageReference Include="System.Threading.Tasks.Extensions" Version="$(SystemThreadingTasksExtensionsPackageVersion)" />
|
||||
<PackageReference Include="System.Reactive.Linq" Version="$(SystemReactiveLinqPackageVersion)" />
|
||||
</ItemGroup>
|
||||
|
||||
</Project>
|
||||
|
|
|
|||
|
|
@ -13,26 +13,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
// True-internal because this is a weird and tricky class to use :)
|
||||
internal static class AsyncEnumeratorAdapters
|
||||
{
|
||||
private static readonly MethodInfo _boxEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(BoxEnumerator)) && m.IsGenericMethod);
|
||||
|
||||
private static readonly MethodInfo _fromObservableMethod = typeof(AsyncEnumeratorAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(FromObservable)) && m.IsGenericMethod);
|
||||
|
||||
private static readonly MethodInfo _getAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(GetAsyncEnumerator)) && m.IsGenericMethod);
|
||||
|
||||
public static IAsyncEnumerator<object> FromObservable(object observable, Type observableInterface, CancellationToken cancellationToken)
|
||||
{
|
||||
// TODO: Cache expressions by observable.GetType()?
|
||||
return (IAsyncEnumerator<object>)_fromObservableMethod
|
||||
.MakeGenericMethod(observableInterface.GetGenericArguments())
|
||||
.Invoke(null, new[] { observable, cancellationToken });
|
||||
}
|
||||
|
||||
public static IAsyncEnumerator<object> FromObservable<T>(IObservable<T> observable, CancellationToken cancellationToken)
|
||||
{
|
||||
// TODO: Allow bounding and optimizations?
|
||||
|
|
@ -46,33 +26,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return GetAsyncEnumerator(channel.Reader, cancellationToken);
|
||||
}
|
||||
|
||||
public static IAsyncEnumerator<object> FromChannel(object readableChannelOfT, Type payloadType, CancellationToken cancellationToken)
|
||||
{
|
||||
var enumerator = _getAsyncEnumeratorMethod
|
||||
.MakeGenericMethod(payloadType)
|
||||
.Invoke(null, new object[] { readableChannelOfT, cancellationToken });
|
||||
|
||||
if (payloadType.IsValueType)
|
||||
{
|
||||
return (IAsyncEnumerator<object>)_boxEnumeratorMethod
|
||||
.MakeGenericMethod(payloadType)
|
||||
.Invoke(null, new[] { enumerator });
|
||||
}
|
||||
else
|
||||
{
|
||||
return (IAsyncEnumerator<object>)enumerator;
|
||||
}
|
||||
}
|
||||
|
||||
private static IAsyncEnumerator<object> BoxEnumerator<T>(IAsyncEnumerator<T> input) where T : struct
|
||||
{
|
||||
return new BoxingEnumerator<T>(input);
|
||||
}
|
||||
|
||||
private class ChannelObserver<T> : IObserver<T>
|
||||
{
|
||||
private ChannelWriter<object> _output;
|
||||
private CancellationToken _cancellationToken;
|
||||
private readonly ChannelWriter<object> _output;
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
|
||||
public ChannelObserver(ChannelWriter<object> output, CancellationToken cancellationToken)
|
||||
{
|
||||
|
|
@ -116,33 +73,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
}
|
||||
}
|
||||
|
||||
private class BoxingEnumerator<T> : IAsyncEnumerator<object> where T : struct
|
||||
{
|
||||
private IAsyncEnumerator<T> _input;
|
||||
|
||||
public BoxingEnumerator(IAsyncEnumerator<T> input)
|
||||
{
|
||||
_input = input;
|
||||
}
|
||||
|
||||
public object Current => _input.Current;
|
||||
public Task<bool> MoveNextAsync() => _input.MoveNextAsync();
|
||||
}
|
||||
|
||||
public static IAsyncEnumerator<T> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
|
||||
public static IAsyncEnumerator<object> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
return new AsyncEnumerator<T>(channel, cancellationToken);
|
||||
}
|
||||
|
||||
/// <summary>Provides an async enumerator for the data in a channel.</summary>
|
||||
internal class AsyncEnumerator<T> : IAsyncEnumerator<T>
|
||||
internal class AsyncEnumerator<T> : IAsyncEnumerator<object>
|
||||
{
|
||||
/// <summary>The channel being enumerated.</summary>
|
||||
private readonly ChannelReader<T> _channel;
|
||||
/// <summary>Cancellation token used to cancel the enumeration.</summary>
|
||||
private readonly CancellationToken _cancellationToken;
|
||||
/// <summary>The current element of the enumeration.</summary>
|
||||
private T _current;
|
||||
private object _current;
|
||||
|
||||
internal AsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
|
||||
{
|
||||
|
|
@ -150,7 +94,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
_cancellationToken = cancellationToken;
|
||||
}
|
||||
|
||||
public T Current => _current;
|
||||
public object Current => _current;
|
||||
|
||||
public Task<bool> MoveNextAsync()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -164,7 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return;
|
||||
}
|
||||
|
||||
if (!await ValidateInvocationMode(methodExecutor, isStreamedInvocation, hubMethodInvocationMessage, connection))
|
||||
if (!await ValidateInvocationMode(descriptor, isStreamedInvocation, hubMethodInvocationMessage, connection))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
|
@ -187,7 +187,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
if (isStreamedInvocation)
|
||||
{
|
||||
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, out var enumerator))
|
||||
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator))
|
||||
{
|
||||
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
|
||||
|
||||
|
|
@ -267,8 +267,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
|
||||
private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
|
||||
{
|
||||
// ReadableChannel is awaitable but we don't want to await it.
|
||||
if (methodExecutor.IsMethodAsync && !IsChannel(methodExecutor.MethodReturnType, out _))
|
||||
if (methodExecutor.IsMethodAsync)
|
||||
{
|
||||
if (methodExecutor.MethodReturnType == typeof(Task))
|
||||
{
|
||||
|
|
@ -305,21 +304,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
hub.Groups = _hubContext.Groups;
|
||||
}
|
||||
|
||||
private static bool IsChannel(Type type, out Type payloadType)
|
||||
{
|
||||
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
|
||||
if (channelType == null)
|
||||
{
|
||||
payloadType = null;
|
||||
return false;
|
||||
}
|
||||
else
|
||||
{
|
||||
payloadType = channelType.GetGenericArguments()[0];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
|
||||
{
|
||||
// If there are no policies we don't need to run auth
|
||||
|
|
@ -340,11 +324,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return authorizationResult.Succeeded;
|
||||
}
|
||||
|
||||
private async Task<bool> ValidateInvocationMode(ObjectMethodExecutor methodExecutor, bool isStreamedInvocation,
|
||||
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
|
||||
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
|
||||
{
|
||||
var isStreamedResult = IsStreamed(methodExecutor);
|
||||
if (isStreamedResult && !isStreamedInvocation)
|
||||
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
|
||||
{
|
||||
// Non-null/empty InvocationId? Blocking
|
||||
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
|
||||
|
|
@ -357,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return false;
|
||||
}
|
||||
|
||||
if (!isStreamedResult && isStreamedInvocation)
|
||||
if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
|
||||
{
|
||||
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
|
||||
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
|
||||
|
|
@ -369,51 +352,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return true;
|
||||
}
|
||||
|
||||
private static bool IsStreamed(ObjectMethodExecutor methodExecutor)
|
||||
{
|
||||
var resultType = (methodExecutor.IsMethodAsync)
|
||||
? methodExecutor.AsyncResultType
|
||||
: methodExecutor.MethodReturnType;
|
||||
|
||||
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
|
||||
var observableInterface = IsIObservable(resultType) ?
|
||||
resultType :
|
||||
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
|
||||
|
||||
if (observableInterface != null)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
if (IsChannel(resultType, out _))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, out IAsyncEnumerator<object> enumerator)
|
||||
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator)
|
||||
{
|
||||
if (result != null)
|
||||
{
|
||||
var resultType = (methodExecutor.IsMethodAsync)
|
||||
? methodExecutor.AsyncResultType
|
||||
: methodExecutor.MethodReturnType;
|
||||
|
||||
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
|
||||
var observableInterface = IsIObservable(resultType) ?
|
||||
resultType :
|
||||
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
|
||||
if (observableInterface != null)
|
||||
if (hubMethodDescriptor.IsObservable)
|
||||
{
|
||||
enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, CreateCancellation());
|
||||
enumerator = hubMethodDescriptor.FromObservable(result, CreateCancellation());
|
||||
return true;
|
||||
}
|
||||
|
||||
if (IsChannel(resultType, out var payloadType))
|
||||
if (hubMethodDescriptor.IsChannel)
|
||||
{
|
||||
enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, CreateCancellation());
|
||||
enumerator = hubMethodDescriptor.FromChannel(result, CreateCancellation());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
|
@ -429,11 +380,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
}
|
||||
}
|
||||
|
||||
private static bool IsIObservable(Type iface)
|
||||
{
|
||||
return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>);
|
||||
}
|
||||
|
||||
private void DiscoverHubMethods()
|
||||
{
|
||||
var hubType = typeof(THub);
|
||||
|
|
@ -458,22 +404,5 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
Log.HubMethodBound(_logger, hubName, methodName);
|
||||
}
|
||||
}
|
||||
|
||||
// REVIEW: We can decide to move this out of here if we want pluggable hub discovery
|
||||
private class HubMethodDescriptor
|
||||
{
|
||||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
||||
{
|
||||
MethodExecutor = methodExecutor;
|
||||
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
||||
Policies = policies.ToArray();
|
||||
}
|
||||
|
||||
public ObjectMethodExecutor MethodExecutor { get; }
|
||||
|
||||
public IReadOnlyList<Type> ParameterTypes { get; }
|
||||
|
||||
public IList<IAuthorizeData> Policies { get; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,153 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Linq.Expressions;
|
||||
using System.Reflection;
|
||||
using System.Threading;
|
||||
using System.Threading.Channels;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
internal class HubMethodDescriptor
|
||||
{
|
||||
private static readonly MethodInfo FromObservableMethod = typeof(AsyncEnumeratorAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.FromObservable)) && m.IsGenericMethod);
|
||||
|
||||
private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
|
||||
.GetRuntimeMethods()
|
||||
.Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod);
|
||||
|
||||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
||||
{
|
||||
MethodExecutor = methodExecutor;
|
||||
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
||||
Policies = policies.ToArray();
|
||||
|
||||
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
|
||||
? MethodExecutor.AsyncResultType
|
||||
: MethodExecutor.MethodReturnType;
|
||||
|
||||
if (IsObservableType(NonAsyncReturnType, out var observableItemType))
|
||||
{
|
||||
IsObservable = true;
|
||||
StreamReturnType = observableItemType;
|
||||
}
|
||||
else if (IsChannelType(NonAsyncReturnType, out var channelItemType))
|
||||
{
|
||||
IsChannel = true;
|
||||
StreamReturnType = channelItemType;
|
||||
}
|
||||
}
|
||||
|
||||
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
|
||||
|
||||
public ObjectMethodExecutor MethodExecutor { get; }
|
||||
|
||||
public IReadOnlyList<Type> ParameterTypes { get; }
|
||||
|
||||
public Type NonAsyncReturnType { get; }
|
||||
|
||||
public bool IsObservable { get; }
|
||||
|
||||
public bool IsChannel { get; }
|
||||
|
||||
public bool IsStreamable => IsObservable || IsChannel;
|
||||
|
||||
public Type StreamReturnType { get; }
|
||||
|
||||
public IList<IAuthorizeData> Policies { get; }
|
||||
|
||||
private static bool IsChannelType(Type type, out Type payloadType)
|
||||
{
|
||||
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
|
||||
if (channelType == null)
|
||||
{
|
||||
payloadType = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
payloadType = channelType.GetGenericArguments()[0];
|
||||
return true;
|
||||
}
|
||||
|
||||
private static bool IsObservableType(Type type, out Type payloadType)
|
||||
{
|
||||
var observableInterface = IsIObservable(type) ? type : type.GetInterfaces().FirstOrDefault(IsIObservable);
|
||||
if (observableInterface == null)
|
||||
{
|
||||
payloadType = null;
|
||||
return false;
|
||||
}
|
||||
|
||||
payloadType = observableInterface.GetGenericArguments()[0];
|
||||
return true;
|
||||
|
||||
bool IsIObservable(Type iface)
|
||||
{
|
||||
return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>);
|
||||
}
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<object> FromObservable(object observable, CancellationToken cancellationToken)
|
||||
{
|
||||
// there is the potential for compile to be called times but this has no harmful effect other than perf
|
||||
if (_convertToEnumerator == null)
|
||||
{
|
||||
_convertToEnumerator = CompileConvertToEnumerator(FromObservableMethod, StreamReturnType);
|
||||
}
|
||||
|
||||
return _convertToEnumerator.Invoke(observable, cancellationToken);
|
||||
}
|
||||
|
||||
public IAsyncEnumerator<object> FromChannel(object channel, CancellationToken cancellationToken)
|
||||
{
|
||||
// there is the potential for compile to be called times but this has no harmful effect other than perf
|
||||
if (_convertToEnumerator == null)
|
||||
{
|
||||
_convertToEnumerator = CompileConvertToEnumerator(GetAsyncEnumeratorMethod, StreamReturnType);
|
||||
}
|
||||
|
||||
return _convertToEnumerator.Invoke(channel, cancellationToken);
|
||||
}
|
||||
|
||||
private static Func<object, CancellationToken, IAsyncEnumerator<object>> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType)
|
||||
{
|
||||
// This will call one of two adapter methods to wrap the passed in streamable value
|
||||
// and cancellation token to an IAsyncEnumerator<object>
|
||||
//
|
||||
// IObservable<T>:
|
||||
// AsyncEnumeratorAdapters.FromObservable<T>(observable, cancellationToken);
|
||||
//
|
||||
// ChannelReader<T>
|
||||
// AsyncEnumeratorAdapters.GetAsyncEnumerator<T>(channelReader, cancellationToken);
|
||||
|
||||
var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType);
|
||||
|
||||
var methodParameters = genericMethodInfo.GetParameters();
|
||||
|
||||
// arg1 and arg2 are the parameter names on Func<T1, T2, TReturn>
|
||||
// we reference these values and then use them to call adaptor method
|
||||
var targetParameter = Expression.Parameter(typeof(object), "arg1");
|
||||
var parametersParameter = Expression.Parameter(typeof(CancellationToken), "arg2");
|
||||
|
||||
var parameters = new List<Expression>
|
||||
{
|
||||
Expression.Convert(targetParameter, methodParameters[0].ParameterType),
|
||||
parametersParameter
|
||||
};
|
||||
|
||||
var methodCall = Expression.Call(null, genericMethodInfo, parameters);
|
||||
|
||||
var castMethodCall = Expression.Convert(methodCall, typeof(IAsyncEnumerator<object>));
|
||||
|
||||
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerator<object>>>(castMethodCall, targetParameter, parametersParameter);
|
||||
return lambda.Compile();
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue