Fixed parameter count mismatch when invoking methods with wrong case (#162)

* Fixed parameter count mismatch when invoking methods with wrong case
- Hub methods were being tracked with 2 dictionaries, one for parameter names
the other for callbacks. This change introduces a single dictionary that stores
the hub name to a HubMethodDescriptor. That descriptor stores the parameter types
and method info for the bound hub method.
- The callback is now just an invoke method on the HubEndPoint itself.
- Added tests for case sensitivity in hub method names
This commit is contained in:
David Fowler 2017-01-26 18:25:49 +00:00 committed by GitHub
parent c997ea8165
commit eafbe74160
3 changed files with 130 additions and 64 deletions

View File

@ -29,9 +29,7 @@ namespace Microsoft.AspNetCore.SignalR
public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient>
{
private readonly Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
private readonly Dictionary<string, HubMethodDescriptor> _methods = new Dictionary<string, HubMethodDescriptor>(StringComparer.OrdinalIgnoreCase);
private readonly HubLifetimeManager<THub> _lifetimeManager;
private readonly IHubContext<THub, TClient> _hubContext;
@ -213,10 +211,10 @@ namespace Microsoft.AspNetCore.SignalR
private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor)
{
InvocationResultDescriptor result;
Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback))
HubMethodDescriptor descriptor;
if (_methods.TryGetValue(invocationDescriptor.Method, out descriptor))
{
result = await callback(connection, invocationDescriptor);
result = await Invoke(descriptor, connection, invocationDescriptor);
}
else
{
@ -246,6 +244,59 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task<InvocationResultDescriptor> Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationDescriptor invocationDescriptor)
{
var invocationResult = new InvocationResultDescriptor
{
Id = invocationDescriptor.Id
};
var methodInfo = descriptor.MethodInfo;
using (var scope = _serviceScopeFactory.CreateScope())
{
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
var hub = hubActivator.Create();
try
{
InitializeHub(hub, connection);
var result = methodInfo.Invoke(hub, invocationDescriptor.Arguments);
var resultTask = result as Task;
if (resultTask != null)
{
await resultTask;
if (methodInfo.ReturnType.GetTypeInfo().IsGenericType)
{
var property = resultTask.GetType().GetProperty("Result");
invocationResult.Result = property?.GetValue(resultTask);
}
}
else
{
invocationResult.Result = result;
}
}
catch (TargetInvocationException ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
invocationResult.Error = ex.InnerException.Message;
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
invocationResult.Error = ex.Message;
}
finally
{
hubActivator.Release(hub);
}
}
return invocationResult;
}
private void InitializeHub(THub hub, Connection connection)
{
hub.Clients = _hubContext.Clients;
@ -261,69 +312,17 @@ namespace Microsoft.AspNetCore.SignalR
{
var methodName = methodInfo.Name;
if (_callbacks.ContainsKey(methodName))
if (_methods.ContainsKey(methodName))
{
throw new NotSupportedException($"Duplicate definitions of '{methodInfo.Name}'. Overloading is not supported.");
}
var parameters = methodInfo.GetParameters();
_paramTypes[methodName] = parameters.Select(p => p.ParameterType).ToArray();
_methods[methodName] = new HubMethodDescriptor(methodInfo);
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Hub method '{methodName}' is bound", methodName);
}
_callbacks[methodName] = async (connection, invocationDescriptor) =>
{
var invocationResult = new InvocationResultDescriptor()
{
Id = invocationDescriptor.Id
};
using (var scope = _serviceScopeFactory.CreateScope())
{
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
var hub = hubActivator.Create();
try
{
InitializeHub(hub, connection);
var result = methodInfo.Invoke(hub, invocationDescriptor.Arguments);
var resultTask = result as Task;
if (resultTask != null)
{
await resultTask;
if (methodInfo.ReturnType.GetTypeInfo().IsGenericType)
{
var property = resultTask.GetType().GetProperty("Result");
invocationResult.Result = property?.GetValue(resultTask);
}
}
else
{
invocationResult.Result = result;
}
}
catch (TargetInvocationException ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
invocationResult.Error = ex.InnerException.Message;
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to invoke hub method");
invocationResult.Error = ex.Message;
}
finally
{
hubActivator.Release(hub);
}
}
return invocationResult;
};
};
}
@ -352,12 +351,26 @@ namespace Microsoft.AspNetCore.SignalR
Type[] IInvocationBinder.GetParameterTypes(string methodName)
{
Type[] types;
if (!_paramTypes.TryGetValue(methodName, out types))
HubMethodDescriptor descriptor;
if (!_methods.TryGetValue(methodName, out descriptor))
{
return Type.EmptyTypes;
}
return types;
return descriptor.ParameterTypes;
}
// REVIEW: We can decide to move this out of here if we want pluggable hub discovery
private class HubMethodDescriptor
{
public HubMethodDescriptor(MethodInfo methodInfo)
{
MethodInfo = methodInfo;
ParameterTypes = methodInfo.GetParameters().Select(p => p.ParameterType).ToArray();
}
public MethodInfo MethodInfo { get; }
public Type[] ParameterTypes { get; }
}
}
}

View File

@ -83,6 +83,27 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Fact]
public async Task MethodsAreCaseInsensitive()
{
var loggerFactory = CreateLogger();
const string originalMessage = "SignalR";
using (var httpClient = _testServer.CreateClient())
{
var transport = new LongPollingTransport(httpClient, loggerFactory);
using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"),
new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory))
{
EnsureConnectionEstablished(connection);
var result = await connection.Invoke<string>("echo", originalMessage);
Assert.Equal(originalMessage, result);
}
}
}
[Fact]
public async Task CanInvokeClientMethodFromServer()
{

View File

@ -149,6 +149,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task HubMethodsAreCaseInsensitive()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var connectionWrapper = new ConnectionWrapper())
{
var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection);
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper, adapter, "echo", "hello");
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper).OrTimeout();
Assert.Null(result.Error);
Assert.Equal("hello", result.Result);
// kill the connection
connectionWrapper.Connection.Dispose();
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task HubMethodCanReturnValue()
{
@ -163,7 +190,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>();
var adapter = invocationAdapter.GetInvocationAdapter("json");
await SendRequest(connectionWrapper, adapter, "ValueMethod");
await SendRequest(connectionWrapper, adapter, nameof(MethodHub.ValueMethod));
var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper).OrTimeout();
// json serializer makes this a long
@ -572,6 +599,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return 43;
}
public string Echo(string data)
{
return data;
}
static public string StaticMethod()
{
return "fromStatic";