From 3cabb6aeb11d0c80617c31c4eea909ce80d823b9 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Mon, 12 Jun 2017 14:26:33 -0700 Subject: [PATCH] Don't over-discover Hub methods (#511) --- .../HubEndPoint.cs | 27 +----- .../Internal/HubReflectionHelper.cs | 46 ++++++++++ .../HubEndpointTests.cs | 78 ++++++++++++++++ .../HubReflectionHelperTests.cs | 88 +++++++++++++++++++ 4 files changed, 216 insertions(+), 23 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR/Internal/HubReflectionHelper.cs create mode 100644 test/Microsoft.AspNetCore.SignalR.Tests/HubReflectionHelperTests.cs diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 15a3942b89..139920dd05 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -2,18 +2,15 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Buffers; using System.Collections.Generic; using System.Linq; using System.Reflection; -using System.Text; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; @@ -386,7 +383,9 @@ namespace Microsoft.AspNetCore.SignalR private void DiscoverHubMethods() { var hubType = typeof(THub); - foreach (var methodInfo in hubType.GetMethods().Where(m => IsHubMethod(m))) + var hubTypeInfo = hubType.GetTypeInfo(); + + foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType)) { var methodName = methodInfo.Name; @@ -395,7 +394,7 @@ namespace Microsoft.AspNetCore.SignalR throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported."); } - var executor = ObjectMethodExecutor.Create(methodInfo, hubType.GetTypeInfo()); + var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo); _methods[methodName] = new HubMethodDescriptor(executor); if (_logger.IsEnabled(LogLevel.Debug)) @@ -405,24 +404,6 @@ namespace Microsoft.AspNetCore.SignalR } } - private static bool IsHubMethod(MethodInfo methodInfo) - { - // TODO: Add more checks - if (!methodInfo.IsPublic || methodInfo.IsSpecialName) - { - return false; - } - - var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType; - var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition; - if (typeof(Hub<>) == baseType) - { - return false; - } - - return true; - } - Type IInvocationBinder.GetReturnType(string invocationId) { return typeof(object); diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/HubReflectionHelper.cs b/src/Microsoft.AspNetCore.SignalR/Internal/HubReflectionHelper.cs new file mode 100644 index 0000000000..5ed3f98829 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Internal/HubReflectionHelper.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Reflection; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + public static class HubReflectionHelper + { + private static readonly Type[] _excludeInterfaces = new[] { typeof(Hub<>), typeof(IDisposable) }; + + public static IEnumerable GetHubMethods(Type hubType) + { + var methods = hubType.GetMethods(BindingFlags.Public | BindingFlags.Instance); + var allInterfaceMethods = _excludeInterfaces.SelectMany(i => GetInterfaceMethods(hubType, i)); + + return methods.Except(allInterfaceMethods).Where(m => IsHubMethod(m)); + } + + private static IEnumerable GetInterfaceMethods(Type type, Type iface) + { + if (!iface.IsAssignableFrom(type)) + { + return Enumerable.Empty(); + } + + return type.GetInterfaceMap(iface).TargetMethods; + } + + private static bool IsHubMethod(MethodInfo methodInfo) + { + var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType; + if (typeof(object) == baseDefinition || methodInfo.IsSpecialName) + { + return false; + } + + // removes methods such as Hub.OnConnectedAsync + var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition; + return typeof(Hub<>) != baseType; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index b51bd2ba47..c7a43e1030 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -336,6 +336,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task CannotCallStaticHubMethods() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + var result = await client.InvokeAsync(nameof(MethodHub.StaticMethod)).OrTimeout(); + + Assert.Equal("Unknown hub method 'StaticMethod'", result.Error); + + // kill the connection + client.Dispose(); + + await endPointTask.OrTimeout(); + } + } + + [Fact] + public async Task CannotCallObjectMethodsOnHub() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + var result = await client.InvokeAsync(nameof(MethodHub.ToString)).OrTimeout(); + Assert.Equal("Unknown hub method 'ToString'", result.Error); + + result = await client.InvokeAsync(nameof(MethodHub.GetHashCode)).OrTimeout(); + Assert.Equal("Unknown hub method 'GetHashCode'", result.Error); + + result = await client.InvokeAsync(nameof(MethodHub.Equals)).OrTimeout(); + Assert.Equal("Unknown hub method 'Equals'", result.Error); + + result = await client.InvokeAsync(nameof(MethodHub.ReferenceEquals)).OrTimeout(); + Assert.Equal("Unknown hub method 'ReferenceEquals'", result.Error); + + // kill the connection + client.Dispose(); + + await endPointTask.OrTimeout(); + } + } + + [Fact] + public async Task CannotCallDisposeMethodOnHub() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + var result = await client.InvokeAsync(nameof(MethodHub.Dispose)).OrTimeout(); + + Assert.Equal("Unknown hub method 'Dispose'", result.Error); + + // kill the connection + client.Dispose(); + + await endPointTask.OrTimeout(); + } + } + [Fact] public async Task BroadcastHubMethod_SendsToAllClients() { @@ -715,6 +789,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests { return Task.FromException(new InvalidOperationException("BOOM!")); } + + public static void StaticMethod() + { + } } private class InheritedHub : BaseHub diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubReflectionHelperTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubReflectionHelperTests.cs new file mode 100644 index 0000000000..9ac1c4aacd --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubReflectionHelperTests.cs @@ -0,0 +1,88 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Linq; +using Microsoft.AspNetCore.SignalR.Internal; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class HubReflectionHelperTests + { + [Fact] + public void EmptyHubHasNoHubMethods() + { + var hubMethods = HubReflectionHelper.GetHubMethods(typeof(EmptyHub)); + + Assert.Empty(hubMethods); + } + + [Fact] + public void HubWithMethodsHasHubMethods() + { + var hubType = typeof(BaseMethodHub); + var hubMethods = HubReflectionHelper.GetHubMethods(hubType); + + Assert.Equal(3, hubMethods.Count()); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("VoidMethod")); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("IntMethod")); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("ArgMethod")); + } + + [Fact] + public void InheritedHubHasBaseHubMethodsAndOwnMethods() + { + var hubType = typeof(InheritedMethodHub); + var hubMethods = HubReflectionHelper.GetHubMethods(hubType); + + Assert.Equal(4, hubMethods.Count()); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("ExtraMethod")); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("VoidMethod")); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("IntMethod")); + Assert.Contains(hubMethods, m => m == hubType.GetMethod("ArgMethod")); + } + + private class EmptyHub : Hub + { + } + + private class BaseMethodHub : Hub + { + public void VoidMethod() + { + } + + public int IntMethod() + { + return 0; + } + + public void ArgMethod(string str) + { + } + + // static is not supported as a Hub method + public static void StaticMethod() + { + } + + // internal is not a Hub method + internal void InternalMethod() + { + } + + // private is not a Hub method + private void PrivateMethod() + { + } + } + + private class InheritedMethodHub : BaseMethodHub + { + public int ExtraMethod(bool b) + { + return 2; + } + } + } +}