Don't over-discover Hub methods (#511)

This commit is contained in:
BrennanConroy 2017-06-12 14:26:33 -07:00 committed by GitHub
parent 8277b2cc27
commit 3cabb6aeb1
4 changed files with 216 additions and 23 deletions

View File

@ -2,18 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
@ -386,7 +383,9 @@ namespace Microsoft.AspNetCore.SignalR
private void DiscoverHubMethods()
{
var hubType = typeof(THub);
foreach (var methodInfo in hubType.GetMethods().Where(m => IsHubMethod(m)))
var hubTypeInfo = hubType.GetTypeInfo();
foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
{
var methodName = methodInfo.Name;
@ -395,7 +394,7 @@ namespace Microsoft.AspNetCore.SignalR
throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported.");
}
var executor = ObjectMethodExecutor.Create(methodInfo, hubType.GetTypeInfo());
var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
_methods[methodName] = new HubMethodDescriptor(executor);
if (_logger.IsEnabled(LogLevel.Debug))
@ -405,24 +404,6 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private static bool IsHubMethod(MethodInfo methodInfo)
{
// TODO: Add more checks
if (!methodInfo.IsPublic || methodInfo.IsSpecialName)
{
return false;
}
var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType;
var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition;
if (typeof(Hub<>) == baseType)
{
return false;
}
return true;
}
Type IInvocationBinder.GetReturnType(string invocationId)
{
return typeof(object);

View File

@ -0,0 +1,46 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public static class HubReflectionHelper
{
private static readonly Type[] _excludeInterfaces = new[] { typeof(Hub<>), typeof(IDisposable) };
public static IEnumerable<MethodInfo> GetHubMethods(Type hubType)
{
var methods = hubType.GetMethods(BindingFlags.Public | BindingFlags.Instance);
var allInterfaceMethods = _excludeInterfaces.SelectMany(i => GetInterfaceMethods(hubType, i));
return methods.Except(allInterfaceMethods).Where(m => IsHubMethod(m));
}
private static IEnumerable<MethodInfo> GetInterfaceMethods(Type type, Type iface)
{
if (!iface.IsAssignableFrom(type))
{
return Enumerable.Empty<MethodInfo>();
}
return type.GetInterfaceMap(iface).TargetMethods;
}
private static bool IsHubMethod(MethodInfo methodInfo)
{
var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType;
if (typeof(object) == baseDefinition || methodInfo.IsSpecialName)
{
return false;
}
// removes methods such as Hub<TClient>.OnConnectedAsync
var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition;
return typeof(Hub<>) != baseType;
}
}
}

View File

@ -336,6 +336,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task CannotCallStaticHubMethods()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.InvokeAsync(nameof(MethodHub.StaticMethod)).OrTimeout();
Assert.Equal("Unknown hub method 'StaticMethod'", result.Error);
// kill the connection
client.Dispose();
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task CannotCallObjectMethodsOnHub()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.InvokeAsync(nameof(MethodHub.ToString)).OrTimeout();
Assert.Equal("Unknown hub method 'ToString'", result.Error);
result = await client.InvokeAsync(nameof(MethodHub.GetHashCode)).OrTimeout();
Assert.Equal("Unknown hub method 'GetHashCode'", result.Error);
result = await client.InvokeAsync(nameof(MethodHub.Equals)).OrTimeout();
Assert.Equal("Unknown hub method 'Equals'", result.Error);
result = await client.InvokeAsync(nameof(MethodHub.ReferenceEquals)).OrTimeout();
Assert.Equal("Unknown hub method 'ReferenceEquals'", result.Error);
// kill the connection
client.Dispose();
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task CannotCallDisposeMethodOnHub()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var result = await client.InvokeAsync(nameof(MethodHub.Dispose)).OrTimeout();
Assert.Equal("Unknown hub method 'Dispose'", result.Error);
// kill the connection
client.Dispose();
await endPointTask.OrTimeout();
}
}
[Fact]
public async Task BroadcastHubMethod_SendsToAllClients()
{
@ -715,6 +789,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
return Task.FromException(new InvalidOperationException("BOOM!"));
}
public static void StaticMethod()
{
}
}
private class InheritedHub : BaseHub

View File

@ -0,0 +1,88 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Linq;
using Microsoft.AspNetCore.SignalR.Internal;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class HubReflectionHelperTests
{
[Fact]
public void EmptyHubHasNoHubMethods()
{
var hubMethods = HubReflectionHelper.GetHubMethods(typeof(EmptyHub));
Assert.Empty(hubMethods);
}
[Fact]
public void HubWithMethodsHasHubMethods()
{
var hubType = typeof(BaseMethodHub);
var hubMethods = HubReflectionHelper.GetHubMethods(hubType);
Assert.Equal(3, hubMethods.Count());
Assert.Contains(hubMethods, m => m == hubType.GetMethod("VoidMethod"));
Assert.Contains(hubMethods, m => m == hubType.GetMethod("IntMethod"));
Assert.Contains(hubMethods, m => m == hubType.GetMethod("ArgMethod"));
}
[Fact]
public void InheritedHubHasBaseHubMethodsAndOwnMethods()
{
var hubType = typeof(InheritedMethodHub);
var hubMethods = HubReflectionHelper.GetHubMethods(hubType);
Assert.Equal(4, hubMethods.Count());
Assert.Contains(hubMethods, m => m == hubType.GetMethod("ExtraMethod"));
Assert.Contains(hubMethods, m => m == hubType.GetMethod("VoidMethod"));
Assert.Contains(hubMethods, m => m == hubType.GetMethod("IntMethod"));
Assert.Contains(hubMethods, m => m == hubType.GetMethod("ArgMethod"));
}
private class EmptyHub : Hub
{
}
private class BaseMethodHub : Hub
{
public void VoidMethod()
{
}
public int IntMethod()
{
return 0;
}
public void ArgMethod(string str)
{
}
// static is not supported as a Hub method
public static void StaticMethod()
{
}
// internal is not a Hub method
internal void InternalMethod()
{
}
// private is not a Hub method
private void PrivateMethod()
{
}
}
private class InheritedMethodHub : BaseMethodHub
{
public int ExtraMethod(bool b)
{
return 2;
}
}
}
}