Don't over-discover Hub methods (#511)
This commit is contained in:
parent
8277b2cc27
commit
3cabb6aeb1
|
|
@ -2,18 +2,15 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Microsoft.Extensions.Logging;
|
||||
|
|
@ -386,7 +383,9 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private void DiscoverHubMethods()
|
||||
{
|
||||
var hubType = typeof(THub);
|
||||
foreach (var methodInfo in hubType.GetMethods().Where(m => IsHubMethod(m)))
|
||||
var hubTypeInfo = hubType.GetTypeInfo();
|
||||
|
||||
foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
|
||||
{
|
||||
var methodName = methodInfo.Name;
|
||||
|
||||
|
|
@ -395,7 +394,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported.");
|
||||
}
|
||||
|
||||
var executor = ObjectMethodExecutor.Create(methodInfo, hubType.GetTypeInfo());
|
||||
var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
|
||||
_methods[methodName] = new HubMethodDescriptor(executor);
|
||||
|
||||
if (_logger.IsEnabled(LogLevel.Debug))
|
||||
|
|
@ -405,24 +404,6 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private static bool IsHubMethod(MethodInfo methodInfo)
|
||||
{
|
||||
// TODO: Add more checks
|
||||
if (!methodInfo.IsPublic || methodInfo.IsSpecialName)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType;
|
||||
var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition;
|
||||
if (typeof(Hub<>) == baseType)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Type IInvocationBinder.GetReturnType(string invocationId)
|
||||
{
|
||||
return typeof(object);
|
||||
|
|
|
|||
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||
{
|
||||
public static class HubReflectionHelper
|
||||
{
|
||||
private static readonly Type[] _excludeInterfaces = new[] { typeof(Hub<>), typeof(IDisposable) };
|
||||
|
||||
public static IEnumerable<MethodInfo> GetHubMethods(Type hubType)
|
||||
{
|
||||
var methods = hubType.GetMethods(BindingFlags.Public | BindingFlags.Instance);
|
||||
var allInterfaceMethods = _excludeInterfaces.SelectMany(i => GetInterfaceMethods(hubType, i));
|
||||
|
||||
return methods.Except(allInterfaceMethods).Where(m => IsHubMethod(m));
|
||||
}
|
||||
|
||||
private static IEnumerable<MethodInfo> GetInterfaceMethods(Type type, Type iface)
|
||||
{
|
||||
if (!iface.IsAssignableFrom(type))
|
||||
{
|
||||
return Enumerable.Empty<MethodInfo>();
|
||||
}
|
||||
|
||||
return type.GetInterfaceMap(iface).TargetMethods;
|
||||
}
|
||||
|
||||
private static bool IsHubMethod(MethodInfo methodInfo)
|
||||
{
|
||||
var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType;
|
||||
if (typeof(object) == baseDefinition || methodInfo.IsSpecialName)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// removes methods such as Hub<TClient>.OnConnectedAsync
|
||||
var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition;
|
||||
return typeof(Hub<>) != baseType;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -336,6 +336,80 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CannotCallStaticHubMethods()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider();
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.InvokeAsync(nameof(MethodHub.StaticMethod)).OrTimeout();
|
||||
|
||||
Assert.Equal("Unknown hub method 'StaticMethod'", result.Error);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CannotCallObjectMethodsOnHub()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider();
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.InvokeAsync(nameof(MethodHub.ToString)).OrTimeout();
|
||||
Assert.Equal("Unknown hub method 'ToString'", result.Error);
|
||||
|
||||
result = await client.InvokeAsync(nameof(MethodHub.GetHashCode)).OrTimeout();
|
||||
Assert.Equal("Unknown hub method 'GetHashCode'", result.Error);
|
||||
|
||||
result = await client.InvokeAsync(nameof(MethodHub.Equals)).OrTimeout();
|
||||
Assert.Equal("Unknown hub method 'Equals'", result.Error);
|
||||
|
||||
result = await client.InvokeAsync(nameof(MethodHub.ReferenceEquals)).OrTimeout();
|
||||
Assert.Equal("Unknown hub method 'ReferenceEquals'", result.Error);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CannotCallDisposeMethodOnHub()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider();
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
var result = await client.InvokeAsync(nameof(MethodHub.Dispose)).OrTimeout();
|
||||
|
||||
Assert.Equal("Unknown hub method 'Dispose'", result.Error);
|
||||
|
||||
// kill the connection
|
||||
client.Dispose();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task BroadcastHubMethod_SendsToAllClients()
|
||||
{
|
||||
|
|
@ -715,6 +789,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
return Task.FromException(new InvalidOperationException("BOOM!"));
|
||||
}
|
||||
|
||||
public static void StaticMethod()
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
private class InheritedHub : BaseHub
|
||||
|
|
|
|||
|
|
@ -0,0 +1,88 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Linq;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
public class HubReflectionHelperTests
|
||||
{
|
||||
[Fact]
|
||||
public void EmptyHubHasNoHubMethods()
|
||||
{
|
||||
var hubMethods = HubReflectionHelper.GetHubMethods(typeof(EmptyHub));
|
||||
|
||||
Assert.Empty(hubMethods);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void HubWithMethodsHasHubMethods()
|
||||
{
|
||||
var hubType = typeof(BaseMethodHub);
|
||||
var hubMethods = HubReflectionHelper.GetHubMethods(hubType);
|
||||
|
||||
Assert.Equal(3, hubMethods.Count());
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("VoidMethod"));
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("IntMethod"));
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("ArgMethod"));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void InheritedHubHasBaseHubMethodsAndOwnMethods()
|
||||
{
|
||||
var hubType = typeof(InheritedMethodHub);
|
||||
var hubMethods = HubReflectionHelper.GetHubMethods(hubType);
|
||||
|
||||
Assert.Equal(4, hubMethods.Count());
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("ExtraMethod"));
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("VoidMethod"));
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("IntMethod"));
|
||||
Assert.Contains(hubMethods, m => m == hubType.GetMethod("ArgMethod"));
|
||||
}
|
||||
|
||||
private class EmptyHub : Hub
|
||||
{
|
||||
}
|
||||
|
||||
private class BaseMethodHub : Hub
|
||||
{
|
||||
public void VoidMethod()
|
||||
{
|
||||
}
|
||||
|
||||
public int IntMethod()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
public void ArgMethod(string str)
|
||||
{
|
||||
}
|
||||
|
||||
// static is not supported as a Hub method
|
||||
public static void StaticMethod()
|
||||
{
|
||||
}
|
||||
|
||||
// internal is not a Hub method
|
||||
internal void InternalMethod()
|
||||
{
|
||||
}
|
||||
|
||||
// private is not a Hub method
|
||||
private void PrivateMethod()
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
private class InheritedMethodHub : BaseMethodHub
|
||||
{
|
||||
public int ExtraMethod(bool b)
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue