Hub method reflection caching and invocation benchmarks (#1574)

This commit is contained in:
James Newton-King 2018-03-13 10:30:45 +13:00 committed by GitHub
parent d816c6ef60
commit 974eb28b8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 408 additions and 144 deletions

View File

@ -0,0 +1,235 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Reactive.Linq;
using System.Threading.Channels;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Protocols;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Moq;
using DefaultConnectionContext = Microsoft.AspNetCore.Sockets.DefaultConnectionContext;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
public class DefaultHubDispatcherBenchmark
{
private DefaultHubDispatcher<TestHub> _dispatcher;
private HubConnectionContext _connectionContext;
[GlobalSetup]
public void GlobalSetup()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSignalRCore();
var provider = serviceCollection.BuildServiceProvider();
var serviceScopeFactory = provider.GetService<IServiceScopeFactory>();
_dispatcher = new DefaultHubDispatcher<TestHub>(
serviceScopeFactory,
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>()),
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance));
var options = new PipeOptions();
var pair = DuplexPipe.CreateConnectionPair(options, options);
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application);
_connectionContext = new NoErrorHubConnectionContext(connection, TimeSpan.Zero, NullLoggerFactory.Instance);
_connectionContext.ProtocolReaderWriter = new HubProtocolReaderWriter(new FakeHubProtocol(), new FakeDataEncoder());
}
public class FakeHubProtocol : IHubProtocol
{
public string Name { get; }
public ProtocolType Type { get; }
public bool TryParseMessages(ReadOnlySpan<byte> input, IInvocationBinder binder, IList<HubMessage> messages)
{
return false;
}
public void WriteMessage(HubMessage message, Stream output)
{
}
}
public class FakeDataEncoder : IDataEncoder
{
public byte[] Encode(byte[] payload)
{
return null;
}
public bool TryDecode(ref ReadOnlySpan<byte> buffer, out ReadOnlySpan<byte> data)
{
return false;
}
}
public class NoErrorHubConnectionContext : HubConnectionContext
{
public NoErrorHubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) : base(connectionContext, keepAliveInterval, loggerFactory)
{
}
public override Task WriteAsync(HubMessage message)
{
if (message is CompletionMessage completionMessage)
{
if (!string.IsNullOrEmpty(completionMessage.Error))
{
throw new Exception("Error invoking hub method: " + completionMessage.Error);
}
}
return Task.CompletedTask;
}
}
public class TestHub : Hub
{
private static readonly IObservable<int> ObservableInstance = Observable.Empty<int>();
public void Invocation()
{
}
public Task InvocationAsync()
{
return Task.CompletedTask;
}
public int InvocationReturnValue()
{
return 1;
}
public Task<int> InvocationReturnAsync()
{
return Task.FromResult(1);
}
public ValueTask<int> InvocationValueTaskAsync()
{
return new ValueTask<int>(1);
}
public IObservable<int> StreamObservable()
{
return ObservableInstance;
}
public Task<IObservable<int>> StreamObservableAsync()
{
return Task.FromResult(ObservableInstance);
}
public ValueTask<IObservable<int>> StreamObservableValueTaskAsync()
{
return new ValueTask<IObservable<int>>(ObservableInstance);
}
public ChannelReader<int> StreamChannelReader()
{
var channel = Channel.CreateUnbounded<int>();
channel.Writer.Complete();
return channel;
}
public Task<ChannelReader<int>> StreamChannelReaderAsync()
{
var channel = Channel.CreateUnbounded<int>();
channel.Writer.Complete();
return Task.FromResult<ChannelReader<int>>(channel);
}
public ValueTask<ChannelReader<int>> StreamChannelReaderValueTaskAsync()
{
var channel = Channel.CreateUnbounded<int>();
channel.Writer.Complete();
return new ValueTask<ChannelReader<int>>(channel);
}
}
[Benchmark]
public Task Invocation()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "Invocation", null));
}
[Benchmark]
public Task InvocationAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationAsync", null));
}
[Benchmark]
public Task InvocationReturnValue()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnValue", null));
}
[Benchmark]
public Task InvocationReturnAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnAsync", null));
}
[Benchmark]
public Task InvocationValueTaskAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationValueTaskAsync", null));
}
[Benchmark]
public Task StreamObservable()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservable", null));
}
[Benchmark]
public Task StreamObservableAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableAsync", null));
}
[Benchmark]
public Task StreamObservableValueTaskAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableValueTaskAsync", null));
}
[Benchmark]
public Task StreamChannelReader()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReader", null));
}
[Benchmark]
public Task StreamChannelReaderAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderAsync", null));
}
[Benchmark]
public Task StreamChannelReaderValueTaskAsync()
{
return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderValueTaskAsync", null));
}
}
}

View File

@ -12,8 +12,11 @@
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
<PackageReference Include="BenchmarkDotNet" Version="$(BenchmarkDotNetPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.BenchmarkRunner.Sources" Version="$(MicrosoftAspNetCoreBenchmarkRunnerSourcesPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="$(MicrosoftExtensionsDependencyInjectionPackageVersion)" />
<PackageReference Include="Moq" Version="$(MoqPackageVersion)" />
<PackageReference Include="System.Threading.Channels" Version="$(SystemThreadingChannelsPackageVersion)" />
<PackageReference Include="System.Threading.Tasks.Extensions" Version="$(SystemThreadingTasksExtensionsPackageVersion)" />
<PackageReference Include="System.Reactive.Linq" Version="$(SystemReactiveLinqPackageVersion)" />
</ItemGroup>
</Project>

View File

@ -13,26 +13,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
// True-internal because this is a weird and tricky class to use :)
internal static class AsyncEnumeratorAdapters
{
private static readonly MethodInfo _boxEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(BoxEnumerator)) && m.IsGenericMethod);
private static readonly MethodInfo _fromObservableMethod = typeof(AsyncEnumeratorAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(FromObservable)) && m.IsGenericMethod);
private static readonly MethodInfo _getAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(GetAsyncEnumerator)) && m.IsGenericMethod);
public static IAsyncEnumerator<object> FromObservable(object observable, Type observableInterface, CancellationToken cancellationToken)
{
// TODO: Cache expressions by observable.GetType()?
return (IAsyncEnumerator<object>)_fromObservableMethod
.MakeGenericMethod(observableInterface.GetGenericArguments())
.Invoke(null, new[] { observable, cancellationToken });
}
public static IAsyncEnumerator<object> FromObservable<T>(IObservable<T> observable, CancellationToken cancellationToken)
{
// TODO: Allow bounding and optimizations?
@ -46,33 +26,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return GetAsyncEnumerator(channel.Reader, cancellationToken);
}
public static IAsyncEnumerator<object> FromChannel(object readableChannelOfT, Type payloadType, CancellationToken cancellationToken)
{
var enumerator = _getAsyncEnumeratorMethod
.MakeGenericMethod(payloadType)
.Invoke(null, new object[] { readableChannelOfT, cancellationToken });
if (payloadType.IsValueType)
{
return (IAsyncEnumerator<object>)_boxEnumeratorMethod
.MakeGenericMethod(payloadType)
.Invoke(null, new[] { enumerator });
}
else
{
return (IAsyncEnumerator<object>)enumerator;
}
}
private static IAsyncEnumerator<object> BoxEnumerator<T>(IAsyncEnumerator<T> input) where T : struct
{
return new BoxingEnumerator<T>(input);
}
private class ChannelObserver<T> : IObserver<T>
{
private ChannelWriter<object> _output;
private CancellationToken _cancellationToken;
private readonly ChannelWriter<object> _output;
private readonly CancellationToken _cancellationToken;
public ChannelObserver(ChannelWriter<object> output, CancellationToken cancellationToken)
{
@ -116,33 +73,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
private class BoxingEnumerator<T> : IAsyncEnumerator<object> where T : struct
{
private IAsyncEnumerator<T> _input;
public BoxingEnumerator(IAsyncEnumerator<T> input)
{
_input = input;
}
public object Current => _input.Current;
public Task<bool> MoveNextAsync() => _input.MoveNextAsync();
}
public static IAsyncEnumerator<T> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
public static IAsyncEnumerator<object> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
{
return new AsyncEnumerator<T>(channel, cancellationToken);
}
/// <summary>Provides an async enumerator for the data in a channel.</summary>
internal class AsyncEnumerator<T> : IAsyncEnumerator<T>
internal class AsyncEnumerator<T> : IAsyncEnumerator<object>
{
/// <summary>The channel being enumerated.</summary>
private readonly ChannelReader<T> _channel;
/// <summary>Cancellation token used to cancel the enumeration.</summary>
private readonly CancellationToken _cancellationToken;
/// <summary>The current element of the enumeration.</summary>
private T _current;
private object _current;
internal AsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
{
@ -150,7 +94,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
_cancellationToken = cancellationToken;
}
public T Current => _current;
public object Current => _current;
public Task<bool> MoveNextAsync()
{

View File

@ -164,7 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return;
}
if (!await ValidateInvocationMode(methodExecutor, isStreamedInvocation, hubMethodInvocationMessage, connection))
if (!await ValidateInvocationMode(descriptor, isStreamedInvocation, hubMethodInvocationMessage, connection))
{
return;
}
@ -187,7 +187,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
if (isStreamedInvocation)
{
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, out var enumerator))
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator))
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
@ -267,8 +267,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
{
// ReadableChannel is awaitable but we don't want to await it.
if (methodExecutor.IsMethodAsync && !IsChannel(methodExecutor.MethodReturnType, out _))
if (methodExecutor.IsMethodAsync)
{
if (methodExecutor.MethodReturnType == typeof(Task))
{
@ -305,21 +304,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
hub.Groups = _hubContext.Groups;
}
private static bool IsChannel(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
if (channelType == null)
{
payloadType = null;
return false;
}
else
{
payloadType = channelType.GetGenericArguments()[0];
return true;
}
}
private async Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
{
// If there are no policies we don't need to run auth
@ -340,11 +324,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return authorizationResult.Succeeded;
}
private async Task<bool> ValidateInvocationMode(ObjectMethodExecutor methodExecutor, bool isStreamedInvocation,
private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
{
var isStreamedResult = IsStreamed(methodExecutor);
if (isStreamedResult && !isStreamedInvocation)
if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
{
// Non-null/empty InvocationId? Blocking
if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@ -357,7 +340,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return false;
}
if (!isStreamedResult && isStreamedInvocation)
if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
{
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
@ -369,51 +352,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}
private static bool IsStreamed(ObjectMethodExecutor methodExecutor)
{
var resultType = (methodExecutor.IsMethodAsync)
? methodExecutor.AsyncResultType
: methodExecutor.MethodReturnType;
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
if (observableInterface != null)
{
return true;
}
if (IsChannel(resultType, out _))
{
return true;
}
return false;
}
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, out IAsyncEnumerator<object> enumerator)
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator)
{
if (result != null)
{
var resultType = (methodExecutor.IsMethodAsync)
? methodExecutor.AsyncResultType
: methodExecutor.MethodReturnType;
// TODO: cache reflection for performance, on HubMethodDescriptor maybe?
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
if (observableInterface != null)
if (hubMethodDescriptor.IsObservable)
{
enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, CreateCancellation());
enumerator = hubMethodDescriptor.FromObservable(result, CreateCancellation());
return true;
}
if (IsChannel(resultType, out var payloadType))
if (hubMethodDescriptor.IsChannel)
{
enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, CreateCancellation());
enumerator = hubMethodDescriptor.FromChannel(result, CreateCancellation());
return true;
}
}
@ -429,11 +380,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
private static bool IsIObservable(Type iface)
{
return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>);
}
private void DiscoverHubMethods()
{
var hubType = typeof(THub);
@ -458,22 +404,5 @@ namespace Microsoft.AspNetCore.SignalR.Internal
Log.HubMethodBound(_logger, hubName, methodName);
}
}
// REVIEW: We can decide to move this out of here if we want pluggable hub discovery
private class HubMethodDescriptor
{
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
Policies = policies.ToArray();
}
public ObjectMethodExecutor MethodExecutor { get; }
public IReadOnlyList<Type> ParameterTypes { get; }
public IList<IAuthorizeData> Policies { get; }
}
}
}

View File

@ -0,0 +1,153 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading;
using System.Threading.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class HubMethodDescriptor
{
private static readonly MethodInfo FromObservableMethod = typeof(AsyncEnumeratorAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.FromObservable)) && m.IsGenericMethod);
private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters)
.GetRuntimeMethods()
.Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod);
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
Policies = policies.ToArray();
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
? MethodExecutor.AsyncResultType
: MethodExecutor.MethodReturnType;
if (IsObservableType(NonAsyncReturnType, out var observableItemType))
{
IsObservable = true;
StreamReturnType = observableItemType;
}
else if (IsChannelType(NonAsyncReturnType, out var channelItemType))
{
IsChannel = true;
StreamReturnType = channelItemType;
}
}
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
public ObjectMethodExecutor MethodExecutor { get; }
public IReadOnlyList<Type> ParameterTypes { get; }
public Type NonAsyncReturnType { get; }
public bool IsObservable { get; }
public bool IsChannel { get; }
public bool IsStreamable => IsObservable || IsChannel;
public Type StreamReturnType { get; }
public IList<IAuthorizeData> Policies { get; }
private static bool IsChannelType(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));
if (channelType == null)
{
payloadType = null;
return false;
}
payloadType = channelType.GetGenericArguments()[0];
return true;
}
private static bool IsObservableType(Type type, out Type payloadType)
{
var observableInterface = IsIObservable(type) ? type : type.GetInterfaces().FirstOrDefault(IsIObservable);
if (observableInterface == null)
{
payloadType = null;
return false;
}
payloadType = observableInterface.GetGenericArguments()[0];
return true;
bool IsIObservable(Type iface)
{
return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>);
}
}
public IAsyncEnumerator<object> FromObservable(object observable, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
if (_convertToEnumerator == null)
{
_convertToEnumerator = CompileConvertToEnumerator(FromObservableMethod, StreamReturnType);
}
return _convertToEnumerator.Invoke(observable, cancellationToken);
}
public IAsyncEnumerator<object> FromChannel(object channel, CancellationToken cancellationToken)
{
// there is the potential for compile to be called times but this has no harmful effect other than perf
if (_convertToEnumerator == null)
{
_convertToEnumerator = CompileConvertToEnumerator(GetAsyncEnumeratorMethod, StreamReturnType);
}
return _convertToEnumerator.Invoke(channel, cancellationToken);
}
private static Func<object, CancellationToken, IAsyncEnumerator<object>> CompileConvertToEnumerator(MethodInfo adapterMethodInfo, Type streamReturnType)
{
// This will call one of two adapter methods to wrap the passed in streamable value
// and cancellation token to an IAsyncEnumerator<object>
//
// IObservable<T>:
// AsyncEnumeratorAdapters.FromObservable<T>(observable, cancellationToken);
//
// ChannelReader<T>
// AsyncEnumeratorAdapters.GetAsyncEnumerator<T>(channelReader, cancellationToken);
var genericMethodInfo = adapterMethodInfo.MakeGenericMethod(streamReturnType);
var methodParameters = genericMethodInfo.GetParameters();
// arg1 and arg2 are the parameter names on Func<T1, T2, TReturn>
// we reference these values and then use them to call adaptor method
var targetParameter = Expression.Parameter(typeof(object), "arg1");
var parametersParameter = Expression.Parameter(typeof(CancellationToken), "arg2");
var parameters = new List<Expression>
{
Expression.Convert(targetParameter, methodParameters[0].ParameterType),
parametersParameter
};
var methodCall = Expression.Call(null, genericMethodInfo, parameters);
var castMethodCall = Expression.Convert(methodCall, typeof(IAsyncEnumerator<object>));
var lambda = Expression.Lambda<Func<object, CancellationToken, IAsyncEnumerator<object>>>(castMethodCall, targetParameter, parametersParameter);
return lambda.Compile();
}
}
}