diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index ae4a357290..c447d882ab 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -29,9 +29,7 @@ namespace Microsoft.AspNetCore.SignalR public class HubEndPoint : EndPoint, IInvocationBinder where THub : Hub { - private readonly Dictionary>> _callbacks - = new Dictionary>>(StringComparer.OrdinalIgnoreCase); - private readonly Dictionary _paramTypes = new Dictionary(); + private readonly Dictionary _methods = new Dictionary(StringComparer.OrdinalIgnoreCase); private readonly HubLifetimeManager _lifetimeManager; private readonly IHubContext _hubContext; @@ -213,10 +211,10 @@ namespace Microsoft.AspNetCore.SignalR private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor) { InvocationResultDescriptor result; - Func> callback; - if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback)) + HubMethodDescriptor descriptor; + if (_methods.TryGetValue(invocationDescriptor.Method, out descriptor)) { - result = await callback(connection, invocationDescriptor); + result = await Invoke(descriptor, connection, invocationDescriptor); } else { @@ -246,6 +244,59 @@ namespace Microsoft.AspNetCore.SignalR } } + private async Task Invoke(HubMethodDescriptor descriptor, Connection connection, InvocationDescriptor invocationDescriptor) + { + var invocationResult = new InvocationResultDescriptor + { + Id = invocationDescriptor.Id + }; + + var methodInfo = descriptor.MethodInfo; + + using (var scope = _serviceScopeFactory.CreateScope()) + { + var hubActivator = scope.ServiceProvider.GetRequiredService>(); + var hub = hubActivator.Create(); + + try + { + InitializeHub(hub, connection); + + var result = methodInfo.Invoke(hub, invocationDescriptor.Arguments); + var resultTask = result as Task; + if (resultTask != null) + { + await resultTask; + if (methodInfo.ReturnType.GetTypeInfo().IsGenericType) + { + var property = resultTask.GetType().GetProperty("Result"); + invocationResult.Result = property?.GetValue(resultTask); + } + } + else + { + invocationResult.Result = result; + } + } + catch (TargetInvocationException ex) + { + _logger.LogError(0, ex, "Failed to invoke hub method"); + invocationResult.Error = ex.InnerException.Message; + } + catch (Exception ex) + { + _logger.LogError(0, ex, "Failed to invoke hub method"); + invocationResult.Error = ex.Message; + } + finally + { + hubActivator.Release(hub); + } + } + + return invocationResult; + } + private void InitializeHub(THub hub, Connection connection) { hub.Clients = _hubContext.Clients; @@ -261,69 +312,17 @@ namespace Microsoft.AspNetCore.SignalR { var methodName = methodInfo.Name; - if (_callbacks.ContainsKey(methodName)) + if (_methods.ContainsKey(methodName)) { throw new NotSupportedException($"Duplicate definitions of '{methodInfo.Name}'. Overloading is not supported."); } - var parameters = methodInfo.GetParameters(); - _paramTypes[methodName] = parameters.Select(p => p.ParameterType).ToArray(); + _methods[methodName] = new HubMethodDescriptor(methodInfo); if (_logger.IsEnabled(LogLevel.Debug)) { _logger.LogDebug("Hub method '{methodName}' is bound", methodName); } - - _callbacks[methodName] = async (connection, invocationDescriptor) => - { - var invocationResult = new InvocationResultDescriptor() - { - Id = invocationDescriptor.Id - }; - - using (var scope = _serviceScopeFactory.CreateScope()) - { - var hubActivator = scope.ServiceProvider.GetRequiredService>(); - var hub = hubActivator.Create(); - - try - { - InitializeHub(hub, connection); - - var result = methodInfo.Invoke(hub, invocationDescriptor.Arguments); - var resultTask = result as Task; - if (resultTask != null) - { - await resultTask; - if (methodInfo.ReturnType.GetTypeInfo().IsGenericType) - { - var property = resultTask.GetType().GetProperty("Result"); - invocationResult.Result = property?.GetValue(resultTask); - } - } - else - { - invocationResult.Result = result; - } - } - catch (TargetInvocationException ex) - { - _logger.LogError(0, ex, "Failed to invoke hub method"); - invocationResult.Error = ex.InnerException.Message; - } - catch (Exception ex) - { - _logger.LogError(0, ex, "Failed to invoke hub method"); - invocationResult.Error = ex.Message; - } - finally - { - hubActivator.Release(hub); - } - } - - return invocationResult; - }; }; } @@ -352,12 +351,26 @@ namespace Microsoft.AspNetCore.SignalR Type[] IInvocationBinder.GetParameterTypes(string methodName) { - Type[] types; - if (!_paramTypes.TryGetValue(methodName, out types)) + HubMethodDescriptor descriptor; + if (!_methods.TryGetValue(methodName, out descriptor)) { return Type.EmptyTypes; } - return types; + return descriptor.ParameterTypes; + } + + // REVIEW: We can decide to move this out of here if we want pluggable hub discovery + private class HubMethodDescriptor + { + public HubMethodDescriptor(MethodInfo methodInfo) + { + MethodInfo = methodInfo; + ParameterTypes = methodInfo.GetParameters().Select(p => p.ParameterType).ToArray(); + } + + public MethodInfo MethodInfo { get; } + + public Type[] ParameterTypes { get; } } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 71c798b038..cde8d074c0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -83,6 +83,27 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task MethodsAreCaseInsensitive() + { + var loggerFactory = CreateLogger(); + const string originalMessage = "SignalR"; + + using (var httpClient = _testServer.CreateClient()) + { + var transport = new LongPollingTransport(httpClient, loggerFactory); + using (var connection = await HubConnection.ConnectAsync(new Uri("http://test/hubs"), + new JsonNetInvocationAdapter(), transport, httpClient, loggerFactory)) + { + EnsureConnectionEstablished(connection); + + var result = await connection.Invoke("echo", originalMessage); + + Assert.Equal(originalMessage, result); + } + } + } + [Fact] public async Task CanInvokeClientMethodFromServer() { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 78bc3089dd..3722e6f894 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -149,6 +149,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task HubMethodsAreCaseInsensitive() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var connectionWrapper = new ConnectionWrapper()) + { + var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); + + var invocationAdapter = serviceProvider.GetService(); + var adapter = invocationAdapter.GetInvocationAdapter("json"); + + await SendRequest(connectionWrapper, adapter, "echo", "hello"); + var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); + + Assert.Null(result.Error); + Assert.Equal("hello", result.Result); + + // kill the connection + connectionWrapper.Connection.Dispose(); + + await endPointTask.OrTimeout(); + } + } + [Fact] public async Task HubMethodCanReturnValue() { @@ -163,7 +190,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper, adapter, "ValueMethod"); + await SendRequest(connectionWrapper, adapter, nameof(MethodHub.ValueMethod)); var result = await ReadConnectionOutputAsync(connectionWrapper).OrTimeout(); // json serializer makes this a long @@ -572,6 +599,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests return 43; } + public string Echo(string data) + { + return data; + } + static public string StaticMethod() { return "fromStatic";