Hub filters! (#21278)

This commit is contained in:
Brennan 2020-05-19 22:05:34 -07:00 committed by GitHub
parent bad6e32e7e
commit 2ad8121efb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 1544 additions and 49 deletions

View File

@ -38,7 +38,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
serviceScopeFactory,
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>(NullLogger<DefaultHubLifetimeManager<TestHub>>.Instance)),
enableDetailedErrors: false,
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance));
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance),
hubFilters: null);
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport);

View File

@ -1,16 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using System.Text.Json;
using System.Text.Json.Serialization;
using SignalRSamples.ConnectionHandlers;
using SignalRSamples.Hubs;
@ -18,7 +14,6 @@ namespace SignalRSamples
{
public class Startup
{
private readonly JsonWriterOptions _jsonWriterOptions = new JsonWriterOptions { Indented = true };
// This method gets called by the runtime. Use this method to add services to the container.
@ -27,11 +22,7 @@ namespace SignalRSamples
{
services.AddConnections();
services.AddSignalR(options =>
{
// Faster pings for testing
options.KeepAliveInterval = TimeSpan.FromSeconds(5);
})
services.AddSignalR()
.AddMessagePackProtocol();
//.AddStackExchangeRedis();
}

View File

@ -181,12 +181,23 @@ namespace Microsoft.AspNetCore.SignalR
}
public partial class HubInvocationContext
{
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.IServiceProvider serviceProvider, Microsoft.AspNetCore.SignalR.Hub hub, System.Reflection.MethodInfo hubMethod, System.Collections.Generic.IReadOnlyList<object> hubMethodArguments) { }
[System.ObsoleteAttribute("This constructor is obsolete and will be removed in a future version. The recommended alternative is to use the other constructor.")]
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { }
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.Type hubType, string hubMethodName, object[] hubMethodArguments) { }
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public Microsoft.AspNetCore.SignalR.Hub Hub { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public System.Reflection.MethodInfo HubMethod { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public System.Collections.Generic.IReadOnlyList<object> HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
[System.ObsoleteAttribute("This property is obsolete and will be removed in a future version. Use HubMethod.Name instead.")]
public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public System.Type HubType { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public System.IServiceProvider ServiceProvider { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
}
public sealed partial class HubLifetimeContext
{
public HubLifetimeContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.IServiceProvider serviceProvider, Microsoft.AspNetCore.SignalR.Hub hub) { }
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public Microsoft.AspNetCore.SignalR.Hub Hub { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
public System.IServiceProvider ServiceProvider { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
}
public abstract partial class HubLifetimeManager<THub> where THub : Microsoft.AspNetCore.SignalR.Hub
{
@ -227,6 +238,12 @@ namespace Microsoft.AspNetCore.SignalR
public int? StreamBufferCapacity { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
}
public static partial class HubOptionsExtensions
{
public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options, Microsoft.AspNetCore.SignalR.IHubFilter hubFilter) { }
public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options, System.Type filterType) { }
public static void AddFilter<TFilter>(this Microsoft.AspNetCore.SignalR.HubOptions options) where TFilter : Microsoft.AspNetCore.SignalR.IHubFilter { }
}
public partial class HubOptionsSetup : Microsoft.Extensions.Options.IConfigureOptions<Microsoft.AspNetCore.SignalR.HubOptions>
{
public HubOptionsSetup(System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol> protocols) { }
@ -294,6 +311,12 @@ namespace Microsoft.AspNetCore.SignalR
Microsoft.AspNetCore.SignalR.IHubClients<T> Clients { get; }
Microsoft.AspNetCore.SignalR.IGroupManager Groups { get; }
}
public partial interface IHubFilter
{
System.Threading.Tasks.ValueTask<object> InvokeMethodAsync(Microsoft.AspNetCore.SignalR.HubInvocationContext invocationContext, System.Func<Microsoft.AspNetCore.SignalR.HubInvocationContext, System.Threading.Tasks.ValueTask<object>> next);
System.Threading.Tasks.Task OnConnectedAsync(Microsoft.AspNetCore.SignalR.HubLifetimeContext context, System.Func<Microsoft.AspNetCore.SignalR.HubLifetimeContext, System.Threading.Tasks.Task> next) { throw null; }
System.Threading.Tasks.Task OnDisconnectedAsync(Microsoft.AspNetCore.SignalR.HubLifetimeContext context, System.Exception exception, System.Func<Microsoft.AspNetCore.SignalR.HubLifetimeContext, System.Exception, System.Threading.Tasks.Task> next) { throw null; }
}
public partial interface IHubProtocolResolver
{
System.Collections.Generic.IReadOnlyList<Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol> AllProtocols { get; }

View File

@ -64,22 +64,37 @@ namespace Microsoft.AspNetCore.SignalR
_userIdProvider = userIdProvider;
_enableDetailedErrors = false;
List<IHubFilter> hubFilters = null;
if (_hubOptions.UserHasSetValues)
{
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize;
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
if (_hubOptions.HubFilters != null)
{
hubFilters = new List<IHubFilter>();
hubFilters.AddRange(_hubOptions.HubFilters);
}
}
else
{
_maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize;
_enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
if (_globalHubOptions.HubFilters != null)
{
hubFilters = new List<IHubFilter>();
hubFilters.AddRange(_globalHubOptions.HubFilters);
}
}
_dispatcher = new DefaultHubDispatcher<THub>(
serviceScopeFactory,
new HubContext<THub>(lifetimeManager),
_enableDetailedErrors,
new Logger<DefaultHubDispatcher<THub>>(loggerFactory));
new Logger<DefaultHubDispatcher<THub>>(loggerFactory),
hubFilters);
}
/// <inheritdoc />

View File

@ -3,7 +3,9 @@
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.Authorization;
using System.Linq;
using System.Reflection;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR
{
@ -12,16 +14,27 @@ namespace Microsoft.AspNetCore.SignalR
/// </summary>
public class HubInvocationContext
{
internal ObjectMethodExecutor ObjectMethodExecutor { get; }
/// <summary>
/// Instantiates a new instance of the <see cref="HubInvocationContext"/> class.
/// </summary>
/// <param name="context">Context for the active Hub connection and caller.</param>
/// <param name="hubType">The type of the Hub.</param>
/// <param name="hubMethodName">The name of the Hub method being invoked.</param>
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.</param>
/// <param name="hub">The instance of the Hub.</param>
/// <param name="hubMethod">The <see cref="MethodInfo"/> for the Hub method being invoked.</param>
/// <param name="hubMethodArguments">The arguments provided by the client.</param>
public HubInvocationContext(HubCallerContext context, Type hubType, string hubMethodName, object[] hubMethodArguments): this(context, hubMethodName, hubMethodArguments)
public HubInvocationContext(HubCallerContext context, IServiceProvider serviceProvider, Hub hub, MethodInfo hubMethod, IReadOnlyList<object> hubMethodArguments)
{
HubType = hubType;
Hub = hub;
ServiceProvider = serviceProvider;
HubMethod = hubMethod;
HubMethodArguments = hubMethodArguments;
Context = context;
#pragma warning disable CS0618 // Type or member is obsolete
HubMethodName = HubMethod.Name;
#pragma warning restore CS0618 // Type or member is obsolete
}
/// <summary>
@ -30,11 +43,16 @@ namespace Microsoft.AspNetCore.SignalR
/// <param name="context">Context for the active Hub connection and caller.</param>
/// <param name="hubMethodName">The name of the Hub method being invoked.</param>
/// <param name="hubMethodArguments">The arguments provided by the client.</param>
[Obsolete("This constructor is obsolete and will be removed in a future version. The recommended alternative is to use the other constructor.")]
public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments)
{
HubMethodName = hubMethodName;
HubMethodArguments = hubMethodArguments;
Context = context;
throw new NotSupportedException("This constructor no longer works. Use the other constructor.");
}
internal HubInvocationContext(ObjectMethodExecutor objectMethodExecutor, HubCallerContext context, IServiceProvider serviceProvider, Hub hub, object[] hubMethodArguments)
: this(context, serviceProvider, hub, objectMethodExecutor.MethodInfo, hubMethodArguments)
{
ObjectMethodExecutor = objectMethodExecutor;
}
/// <summary>
@ -43,18 +61,29 @@ namespace Microsoft.AspNetCore.SignalR
public HubCallerContext Context { get; }
/// <summary>
/// Gets the Hub type.
/// Gets the Hub instance.
/// </summary>
public Type HubType { get; }
public Hub Hub { get; }
/// <summary>
/// Gets the name of the Hub method being invoked.
/// </summary>
[Obsolete("This property is obsolete and will be removed in a future version. Use HubMethod.Name instead.")]
public string HubMethodName { get; }
/// <summary>
/// Gets the arguments provided by the client.
/// </summary>
public IReadOnlyList<object> HubMethodArguments { get; }
/// <summary>
/// The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.
/// </summary>
public IServiceProvider ServiceProvider { get; }
/// <summary>
/// The <see cref="MethodInfo"/> for the Hub method being invoked.
/// </summary>
public MethodInfo HubMethod { get; }
}
}

View File

@ -0,0 +1,43 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
#nullable enable
using System;
namespace Microsoft.AspNetCore.SignalR
{
/// <summary>
/// Context for the hub lifetime events <see cref="Hub.OnConnectedAsync"/> and <see cref="Hub.OnDisconnectedAsync(Exception)"/>.
/// </summary>
public sealed class HubLifetimeContext
{
/// <summary>
/// Instantiates a new instance of the <see cref="HubLifetimeContext"/> class.
/// </summary>
/// <param name="context">Context for the active Hub connection and caller.</param>
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.</param>
/// <param name="hub">The instance of the Hub.</param>
public HubLifetimeContext(HubCallerContext context, IServiceProvider serviceProvider, Hub hub)
{
Hub = hub;
ServiceProvider = serviceProvider;
Context = context;
}
/// <summary>
/// Gets the context for the active Hub connection and caller.
/// </summary>
public HubCallerContext Context { get; }
/// <summary>
/// Gets the Hub instance.
/// </summary>
public Hub Hub { get; }
/// <summary>
/// The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.
/// </summary>
public IServiceProvider ServiceProvider { get; }
}
}

View File

@ -51,5 +51,7 @@ namespace Microsoft.AspNetCore.SignalR
/// Gets or sets the max buffer size for client upload streams. The default size is 10.
/// </summary>
public int? StreamBufferCapacity { get; set; } = null;
internal List<IHubFilter> HubFilters { get; set; } = null;
}
}

View File

@ -0,0 +1,60 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
#nullable enable
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR.Internal;
namespace Microsoft.AspNetCore.SignalR
{
/// <summary>
/// Methods to add <see cref="IHubFilter"/>'s to Hubs.
/// </summary>
public static class HubOptionsExtensions
{
/// <summary>
/// Adds an instance of an <see cref="IHubFilter"/> to the <see cref="HubOptions"/>.
/// </summary>
/// <param name="options">The options to add a filter to.</param>
/// <param name="hubFilter">The filter instance to add to the options.</param>
public static void AddFilter(this HubOptions options, IHubFilter hubFilter)
{
_ = options ?? throw new ArgumentNullException(nameof(options));
_ = hubFilter ?? throw new ArgumentNullException(nameof(hubFilter));
if (options.HubFilters == null)
{
options.HubFilters = new List<IHubFilter>();
}
options.HubFilters.Add(hubFilter);
}
/// <summary>
/// Adds an <see cref="IHubFilter"/> type to the <see cref="HubOptions"/> that will be resolved via DI or type activated.
/// </summary>
/// <typeparam name="TFilter">The <see cref="IHubFilter"/> type that will be added to the options.</typeparam>
/// <param name="options">The options to add a filter to.</param>
public static void AddFilter<TFilter>(this HubOptions options) where TFilter : IHubFilter
{
_ = options ?? throw new ArgumentNullException(nameof(options));
options.AddFilter(typeof(TFilter));
}
/// <summary>
/// Adds an <see cref="IHubFilter"/> type to the <see cref="HubOptions"/> that will be resolved via DI or type activated.
/// </summary>
/// <param name="options">The options to add a filter to.</param>
/// <param name="filterType">The <see cref="IHubFilter"/> type that will be added to the options.</param>
public static void AddFilter(this HubOptions options, Type filterType)
{
_ = options ?? throw new ArgumentNullException(nameof(options));
_ = filterType ?? throw new ArgumentNullException(nameof(filterType));
options.AddFilter(new HubFilterFactory(filterType));
}
}
}

View File

@ -17,11 +17,7 @@ namespace Microsoft.AspNetCore.SignalR
public void Configure(HubOptions<THub> options)
{
// Do a deep copy, otherwise users modifying the HubOptions<THub> list would be changing the global options list
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols.Count);
foreach (var protocol in _hubOptions.SupportedProtocols)
{
options.SupportedProtocols.Add(protocol);
}
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols);
options.KeepAliveInterval = _hubOptions.KeepAliveInterval;
options.HandshakeTimeout = _hubOptions.HandshakeTimeout;
options.ClientTimeoutInterval = _hubOptions.ClientTimeoutInterval;
@ -30,6 +26,11 @@ namespace Microsoft.AspNetCore.SignalR
options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity;
options.UserHasSetValues = true;
if (_hubOptions.HubFilters != null)
{
options.HubFilters = new List<IHubFilter>(_hubOptions.HubFilters);
}
}
}
}

View File

@ -0,0 +1,41 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
#nullable enable
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
{
/// <summary>
/// The filter abstraction for hub method invocations.
/// </summary>
public interface IHubFilter
{
/// <summary>
/// Allows handling of all Hub method invocations.
/// </summary>
/// <param name="invocationContext">The context for the method invocation that holds all the important information about the invoke.</param>
/// <param name="next">The next filter to run, and for the final one, the Hub invocation.</param>
/// <returns>Returns the result of the Hub method invoke.</returns>
ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next);
/// <summary>
/// Allows handling of the <see cref="Hub.OnConnectedAsync"/> method.
/// </summary>
/// <param name="context">The context for OnConnectedAsync.</param>
/// <param name="next">The next filter to run, and for the final one, the Hub invocation.</param>
/// <returns></returns>
Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next) => next(context);
/// <summary>
/// Allows handling of the <see cref="Hub.OnDisconnectedAsync(Exception)"/> method.
/// </summary>
/// <param name="context">The context for OnDisconnectedAsync.</param>
/// <param name="exception">The exception, if any, for the connection closing.</param>
/// <param name="next">The next filter to run, and for the final one, the Hub invocation.</param>
/// <returns></returns>
Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next) => next(context, exception);
}
}

View File

@ -16,7 +16,6 @@ using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR.Internal
{
@ -27,14 +26,48 @@ namespace Microsoft.AspNetCore.SignalR.Internal
private readonly IHubContext<THub> _hubContext;
private readonly ILogger<HubDispatcher<THub>> _logger;
private readonly bool _enableDetailedErrors;
private readonly Func<HubInvocationContext, ValueTask<object>> _invokeMiddleware;
private readonly Func<HubLifetimeContext, Task> _onConnectedMiddleware;
private readonly Func<HubLifetimeContext, Exception, Task> _onDisconnectedMiddleware;
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors, ILogger<DefaultHubDispatcher<THub>> logger)
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors,
ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter> hubFilters)
{
_serviceScopeFactory = serviceScopeFactory;
_hubContext = hubContext;
_enableDetailedErrors = enableDetailedErrors;
_logger = logger;
DiscoverHubMethods();
var count = hubFilters?.Count ?? 0;
if (count != 0)
{
_invokeMiddleware = (invocationContext) =>
{
var arguments = invocationContext.HubMethodArguments as object[] ?? invocationContext.HubMethodArguments.ToArray();
if (invocationContext.ObjectMethodExecutor != null)
{
return ExecuteMethod(invocationContext.ObjectMethodExecutor, invocationContext.Hub, arguments);
}
return ExecuteMethod(invocationContext.HubMethod.Name, invocationContext.Hub, arguments);
};
_onConnectedMiddleware = (context) => context.Hub.OnConnectedAsync();
_onDisconnectedMiddleware = (context, exception) => context.Hub.OnDisconnectedAsync(exception);
for (var i = count - 1; i > -1; i--)
{
var resolvedFilter = hubFilters[i];
var nextFilter = _invokeMiddleware;
_invokeMiddleware = (context) => resolvedFilter.InvokeMethodAsync(context, nextFilter);
var connectedFilter = _onConnectedMiddleware;
_onConnectedMiddleware = (context) => resolvedFilter.OnConnectedAsync(context, connectedFilter);
var disconnectedFilter = _onDisconnectedMiddleware;
_onDisconnectedMiddleware = (context, exception) => resolvedFilter.OnDisconnectedAsync(context, exception, disconnectedFilter);
}
}
}
public override async Task OnConnectedAsync(HubConnectionContext connection)
@ -50,7 +83,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
try
{
InitializeHub(hub, connection);
await hub.OnConnectedAsync();
if (_onConnectedMiddleware != null)
{
var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub);
await _onConnectedMiddleware(context);
}
else
{
await hub.OnConnectedAsync();
}
}
finally
{
@ -76,7 +118,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
try
{
InitializeHub(hub, connection);
await hub.OnDisconnectedAsync(exception);
if (_onDisconnectedMiddleware != null)
{
var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub);
await _onDisconnectedMiddleware(context, exception);
}
else
{
await hub.OnDisconnectedAsync(exception);
}
}
finally
{
@ -220,7 +271,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
THub hub = null;
try
{
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments))
hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
hub = hubActivator.Create();
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor, hubMethodInvocationMessage.Arguments, hub))
{
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
@ -233,9 +287,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return;
}
hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
hub = hubActivator.Create();
try
{
var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
@ -298,7 +349,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
if (isStreamResponse)
{
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
var result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider);
if (result == null)
{
@ -315,7 +366,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
}
else
{
// Invoke or Send
@ -324,7 +374,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
object result;
try
{
result = await ExecuteHubMethod(methodExecutor, hub, arguments);
result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider);
Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
}
catch (Exception ex)
@ -447,13 +497,36 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
private ValueTask<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments, HubConnectionContext connection, IServiceProvider serviceProvider)
{
if (_invokeMiddleware != null)
{
var invocationContext = new HubInvocationContext(methodExecutor, connection.HubCallerContext, serviceProvider, hub, arguments);
return _invokeMiddleware(invocationContext);
}
// If no Hub filters are registered
return ExecuteMethod(methodExecutor, hub, arguments);
}
private ValueTask<object> ExecuteMethod(string hubMethodName, Hub hub, object[] arguments)
{
if (!_methods.TryGetValue(hubMethodName, out var methodDescriptor))
{
throw new HubException($"Unknown hub method '{hubMethodName}'");
}
var methodExecutor = methodDescriptor.MethodExecutor;
return ExecuteMethod(methodExecutor, hub, arguments);
}
private async ValueTask<object> ExecuteMethod(ObjectMethodExecutor methodExecutor, Hub hub, object[] arguments)
{
if (methodExecutor.IsMethodAsync)
{
if (methodExecutor.MethodReturnType == typeof(Task))
{
await (Task)methodExecutor.Execute(hub, arguments);
return null;
}
else
{
@ -464,8 +537,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
return methodExecutor.Execute(hub, arguments);
}
return null;
}
private async Task SendInvocationError(string invocationId,
@ -486,15 +557,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal
hub.Groups = _hubContext.Groups;
}
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList<IAuthorizeData> policies, string hubMethodName, object[] hubMethodArguments)
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, HubMethodDescriptor descriptor, object[] hubMethodArguments, Hub hub)
{
// If there are no policies we don't need to run auth
if (policies.Count == 0)
if (descriptor.Policies.Count == 0)
{
return TaskCache.True;
}
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, typeof(THub), hubMethodName, hubMethodArguments));
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, descriptor.Policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, provider, hub, descriptor.MethodExecutor.MethodInfo, hubMethodArguments));
}
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)

View File

@ -0,0 +1,100 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
#nullable enable
using System;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class HubFilterFactory : IHubFilter
{
private readonly ObjectFactory _objectFactory;
private readonly Type _filterType;
public HubFilterFactory(Type filterType)
{
_objectFactory = ActivatorUtilities.CreateFactory(filterType, Array.Empty<Type>());
_filterType = filterType;
}
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
var (filter, owned) = GetFilter(invocationContext.ServiceProvider);
try
{
return await filter.InvokeMethodAsync(invocationContext, next);
}
finally
{
if (owned)
{
await DisposeFilter(filter);
}
}
}
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
var (filter, owned) = GetFilter(context.ServiceProvider);
try
{
await filter.OnConnectedAsync(context, next);
}
finally
{
if (owned)
{
await DisposeFilter(filter);
}
}
}
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
{
var (filter, owned) = GetFilter(context.ServiceProvider);
try
{
await filter.OnDisconnectedAsync(context, exception, next);
}
finally
{
if (owned)
{
await DisposeFilter(filter);
}
}
}
private ValueTask DisposeFilter(IHubFilter filter)
{
if (filter is IAsyncDisposable asyncDispsable)
{
return asyncDispsable.DisposeAsync();
}
if (filter is IDisposable disposable)
{
disposable.Dispose();
}
return default;
}
private (IHubFilter, bool) GetFilter(IServiceProvider serviceProvider)
{
var owned = false;
var filter = (IHubFilter?)serviceProvider.GetService(_filterType);
if (filter == null)
{
filter = (IHubFilter)_objectFactory.Invoke(serviceProvider, null);
owned = true;
}
return (filter, owned);
}
}
}

View File

@ -1048,8 +1048,19 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public class TcsService
{
public TaskCompletionSource<object> StartedMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
public TaskCompletionSource<object> EndMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
public TaskCompletionSource<object> StartedMethod;
public TaskCompletionSource<object> EndMethod;
public TcsService()
{
Reset();
}
public void Reset()
{
StartedMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
EndMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
}
}
public interface ITypedHubClient

View File

@ -2232,14 +2232,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
Assert.NotNull(context.Resource);
var resource = Assert.IsType<HubInvocationContext>(context.Resource);
Assert.Equal(typeof(MethodHub), resource.HubType);
Assert.Equal(typeof(MethodHub), resource.Hub.GetType());
#pragma warning disable CS0618 // Type or member is obsolete
Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName);
#pragma warning restore CS0618 // Type or member is obsolete
Assert.Equal(2, resource.HubMethodArguments?.Count);
Assert.Equal("Hello", resource.HubMethodArguments[0]);
Assert.Equal("World!", resource.HubMethodArguments[1]);
Assert.NotNull(resource.Context);
Assert.Equal(context.User, resource.Context.User);
Assert.NotNull(resource.Context.GetHttpContext());
Assert.NotNull(resource.ServiceProvider);
Assert.Equal(typeof(MethodHub).GetMethod(nameof(MethodHub.MultiParamAuthMethod)), resource.HubMethod);
return Task.CompletedTask;
}

View File

@ -0,0 +1,867 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class HubFilterTests : VerifiableLoggedTest
{
[Fact]
public async Task GlobalHubFilterByType_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter<VerifyMethodFilter>();
});
services.AddSingleton(tcsService);
}, LoggerFactory);
await AssertMethodsCalled(serviceProvider, tcsService);
}
}
[Fact]
public async Task GlobalHubFilterByInstance_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter(new VerifyMethodFilter(tcsService));
});
}, LoggerFactory);
await AssertMethodsCalled(serviceProvider, tcsService);
}
}
[Fact]
public async Task PerHubFilterByInstance_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR().AddHubOptions<MethodHub>(options =>
{
options.AddFilter(new VerifyMethodFilter(tcsService));
});
}, LoggerFactory);
await AssertMethodsCalled(serviceProvider, tcsService);
}
}
[Fact]
public async Task PerHubFilterByCompileTimeType_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR().AddHubOptions<MethodHub>(options =>
{
options.AddFilter<VerifyMethodFilter>();
});
services.AddSingleton(tcsService);
}, LoggerFactory);
await AssertMethodsCalled(serviceProvider, tcsService);
}
}
[Fact]
public async Task PerHubFilterByRuntimeType_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR().AddHubOptions<MethodHub>(options =>
{
options.AddFilter(typeof(VerifyMethodFilter));
});
services.AddSingleton(tcsService);
}, LoggerFactory);
await AssertMethodsCalled(serviceProvider, tcsService);
}
}
private async Task AssertMethodsCalled(IServiceProvider serviceProvider, TcsService tcsService)
{
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await tcsService.StartedMethod.Task.OrTimeout();
await client.Connected.OrTimeout();
await tcsService.EndMethod.Task.OrTimeout();
tcsService.Reset();
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
await tcsService.EndMethod.Task.OrTimeout();
tcsService.Reset();
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
await tcsService.EndMethod.Task.OrTimeout();
}
}
[Fact]
public async Task MutlipleFilters_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService1 = new TcsService();
var tcsService2 = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter(new VerifyMethodFilter(tcsService1));
options.AddFilter(new VerifyMethodFilter(tcsService2));
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await tcsService1.StartedMethod.Task.OrTimeout();
await tcsService2.StartedMethod.Task.OrTimeout();
await client.Connected.OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
tcsService1.Reset();
tcsService2.Reset();
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
tcsService1.Reset();
tcsService2.Reset();
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
}
}
}
[Fact]
public async Task MixingTypeAndInstanceGlobalFilters_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService1 = new TcsService();
var tcsService2 = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter(new VerifyMethodFilter(tcsService1));
options.AddFilter<VerifyMethodFilter>();
});
services.AddSingleton(tcsService2);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await tcsService1.StartedMethod.Task.OrTimeout();
await tcsService2.StartedMethod.Task.OrTimeout();
await client.Connected.OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
tcsService1.Reset();
tcsService2.Reset();
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
tcsService1.Reset();
tcsService2.Reset();
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
}
}
}
[Fact]
public async Task MixingTypeAndInstanceHubSpecificFilters_MethodsAreCalled()
{
using (StartVerifiableLog())
{
var tcsService1 = new TcsService();
var tcsService2 = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR()
.AddHubOptions<MethodHub>(options =>
{
options.AddFilter(new VerifyMethodFilter(tcsService1));
options.AddFilter<VerifyMethodFilter>();
});
services.AddSingleton(tcsService2);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await tcsService1.StartedMethod.Task.OrTimeout();
await tcsService2.StartedMethod.Task.OrTimeout();
await client.Connected.OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
tcsService1.Reset();
tcsService2.Reset();
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
tcsService1.Reset();
tcsService2.Reset();
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
await tcsService1.EndMethod.Task.OrTimeout();
await tcsService2.EndMethod.Task.OrTimeout();
}
}
}
[Fact]
public async Task GlobalFiltersRunInOrder()
{
using (StartVerifiableLog())
{
var syncPoint1 = SyncPoint.Create(3, out var syncPoints1);
var syncPoint2 = SyncPoint.Create(3, out var syncPoints2);
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter(new SyncPointFilter(syncPoints1));
options.AddFilter(new SyncPointFilter(syncPoints2));
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await syncPoints1[0].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted);
syncPoints1[0].Continue();
await syncPoints2[0].WaitForSyncPoint().OrTimeout();
syncPoints2[0].Continue();
await client.Connected.OrTimeout();
var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!");
await syncPoints1[1].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted);
syncPoints1[1].Continue();
await syncPoints2[1].WaitForSyncPoint().OrTimeout();
syncPoints2[1].Continue();
var message = await invokeTask.OrTimeout();
Assert.Null(message.Error);
client.Dispose();
await syncPoints1[2].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted);
syncPoints1[2].Continue();
await syncPoints2[2].WaitForSyncPoint().OrTimeout();
syncPoints2[2].Continue();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task HubSpecificFiltersRunInOrder()
{
using (StartVerifiableLog())
{
var syncPoint1 = SyncPoint.Create(3, out var syncPoints1);
var syncPoint2 = SyncPoint.Create(3, out var syncPoints2);
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR()
.AddHubOptions<MethodHub>(options =>
{
options.AddFilter(new SyncPointFilter(syncPoints1));
options.AddFilter(new SyncPointFilter(syncPoints2));
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await syncPoints1[0].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted);
syncPoints1[0].Continue();
await syncPoints2[0].WaitForSyncPoint().OrTimeout();
syncPoints2[0].Continue();
await client.Connected.OrTimeout();
var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!");
await syncPoints1[1].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted);
syncPoints1[1].Continue();
await syncPoints2[1].WaitForSyncPoint().OrTimeout();
syncPoints2[1].Continue();
var message = await invokeTask.OrTimeout();
Assert.Null(message.Error);
client.Dispose();
await syncPoints1[2].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted);
syncPoints1[2].Continue();
await syncPoints2[2].WaitForSyncPoint().OrTimeout();
syncPoints2[2].Continue();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task GlobalFiltersRunBeforeHubSpecificFilters()
{
using (StartVerifiableLog())
{
var syncPoint1 = SyncPoint.Create(3, out var syncPoints1);
var syncPoint2 = SyncPoint.Create(3, out var syncPoints2);
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter(new SyncPointFilter(syncPoints1));
})
.AddHubOptions<MethodHub>(options =>
{
options.AddFilter(new SyncPointFilter(syncPoints2));
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await syncPoints1[0].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted);
syncPoints1[0].Continue();
await syncPoints2[0].WaitForSyncPoint().OrTimeout();
syncPoints2[0].Continue();
await client.Connected.OrTimeout();
var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!");
await syncPoints1[1].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted);
syncPoints1[1].Continue();
await syncPoints2[1].WaitForSyncPoint().OrTimeout();
syncPoints2[1].Continue();
var message = await invokeTask.OrTimeout();
Assert.Null(message.Error);
client.Dispose();
await syncPoints1[2].WaitForSyncPoint().OrTimeout();
// Second filter wont run yet because first filter is waiting on SyncPoint
Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted);
syncPoints1[2].Continue();
await syncPoints2[2].WaitForSyncPoint().OrTimeout();
syncPoints2[2].Continue();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task FilterCanBeResolvedFromDI()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter<VerifyMethodFilter>();
});
// If this instance wasn't resolved, then the tcsService.StartedMethod waits would never trigger and fail the test
services.AddSingleton(new VerifyMethodFilter(tcsService));
}, LoggerFactory);
await AssertMethodsCalled(serviceProvider, tcsService);
}
}
[Fact]
public async Task FiltersHaveTransientScopeByDefault()
{
using (StartVerifiableLog())
{
var counter = new FilterCounter();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter<CounterFilter>();
});
services.AddSingleton(counter);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
// Filter is transient, so these counts are reset every time the filter is created
Assert.Equal(1, counter.OnConnectedAsyncCount);
Assert.Equal(0, counter.InvokeMethodAsyncCount);
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
// Filter is transient, so these counts are reset every time the filter is created
Assert.Equal(0, counter.OnConnectedAsyncCount);
Assert.Equal(1, counter.InvokeMethodAsyncCount);
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
// Filter is transient, so these counts are reset every time the filter is created
Assert.Equal(0, counter.OnConnectedAsyncCount);
Assert.Equal(0, counter.InvokeMethodAsyncCount);
Assert.Equal(1, counter.OnDisconnectedAsyncCount);
}
}
}
[Fact]
public async Task FiltersCanBeSingletonIfAddedToDI()
{
using (StartVerifiableLog())
{
var counter = new FilterCounter();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter<CounterFilter>();
});
services.AddSingleton<CounterFilter>();
services.AddSingleton(counter);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
Assert.Equal(1, counter.OnConnectedAsyncCount);
Assert.Equal(0, counter.InvokeMethodAsyncCount);
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
Assert.Equal(1, counter.OnConnectedAsyncCount);
Assert.Equal(1, counter.InvokeMethodAsyncCount);
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
Assert.Equal(1, counter.OnConnectedAsyncCount);
Assert.Equal(1, counter.InvokeMethodAsyncCount);
Assert.Equal(1, counter.OnDisconnectedAsyncCount);
}
}
}
[Fact]
public async Task ConnectionContinuesIfOnConnectedAsyncThrowsAndFilterDoesNot()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter<NoExceptionFilter>();
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<OnConnectedThrowsHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
// Verify connection still connected, can't invoke a method if the connection is disconnected
var message = await client.InvokeAsync("Method");
Assert.Equal("Failed to invoke 'Method' due to an error on the server. HubException: Method does not exist.", message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task ConnectionContinuesIfOnConnectedAsyncNotCalledByFilter()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter(new SkipNextFilter(skipOnConnected: true));
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
// Verify connection still connected, can't invoke a method if the connection is disconnected
var message = await client.InvokeAsync("Method");
Assert.Equal("Failed to invoke 'Method' due to an error on the server. HubException: Method does not exist.", message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task FilterCanSkipCallingHubMethod()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.AddFilter(new SkipNextFilter(skipInvoke: true));
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
Assert.Null(message.Error);
Assert.Null(message.Result);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task FiltersWithIDisposableAreDisposed()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter<DisposableFilter>();
});
services.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
// OnConnectedAsync creates and destroys the filter
await tcsService.StartedMethod.Task.OrTimeout();
tcsService.Reset();
var message = await client.InvokeAsync("Echo", "Hello");
Assert.Equal("Hello", message.Result);
await tcsService.StartedMethod.Task.OrTimeout();
tcsService.Reset();
client.Dispose();
// OnDisconnectedAsync creates and destroys the filter
await tcsService.StartedMethod.Task.OrTimeout();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task InstanceFiltersWithIDisposableAreNotDisposed()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter(new DisposableFilter(tcsService));
});
services.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
var message = await client.InvokeAsync("Echo", "Hello");
Assert.Equal("Hello", message.Result);
client.Dispose();
await connectionHandlerTask.OrTimeout();
Assert.False(tcsService.StartedMethod.Task.IsCompleted);
}
}
}
[Fact]
public async Task FiltersWithIAsyncDisposableAreDisposed()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter<AsyncDisposableFilter>();
});
services.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
// OnConnectedAsync creates and destroys the filter
await tcsService.StartedMethod.Task.OrTimeout();
tcsService.Reset();
var message = await client.InvokeAsync("Echo", "Hello");
Assert.Equal("Hello", message.Result);
await tcsService.StartedMethod.Task.OrTimeout();
tcsService.Reset();
client.Dispose();
// OnDisconnectedAsync creates and destroys the filter
await tcsService.StartedMethod.Task.OrTimeout();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task InstanceFiltersWithIAsyncDisposableAreNotDisposed()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter(new AsyncDisposableFilter(tcsService));
});
services.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
var message = await client.InvokeAsync("Echo", "Hello");
Assert.Equal("Hello", message.Result);
client.Dispose();
await connectionHandlerTask.OrTimeout();
Assert.False(tcsService.StartedMethod.Task.IsCompleted);
}
}
}
[Fact]
public async Task InvokeFailsWhenFilterCallsNonExistantMethod()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
writeContext.EventId.Name == "FailedInvokingHubMethod";
}
using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors))
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.EnableDetailedErrors = true;
options.AddFilter<ChangeMethodFilter>();
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
var message = await client.InvokeAsync("Echo", "Hello");
Assert.Equal("An unexpected error occurred invoking 'Echo' on the server. HubException: Unknown hub method 'BaseMethod'", message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
}
}

View File

@ -0,0 +1,236 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Internal;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class VerifyMethodFilter : IHubFilter
{
private readonly TcsService _service;
public VerifyMethodFilter(TcsService tcsService)
{
_service = tcsService;
}
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
_service.StartedMethod.TrySetResult(null);
await next(context);
_service.EndMethod.TrySetResult(null);
}
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
_service.StartedMethod.TrySetResult(null);
var result = await next(invocationContext);
_service.EndMethod.TrySetResult(null);
return result;
}
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
{
_service.StartedMethod.TrySetResult(null);
await next(context, exception);
_service.EndMethod.TrySetResult(null);
}
}
public class SyncPointFilter : IHubFilter
{
private readonly SyncPoint[] _syncPoint;
public SyncPointFilter(SyncPoint[] syncPoints)
{
Debug.Assert(syncPoints.Length == 3);
_syncPoint = syncPoints;
}
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
await _syncPoint[0].WaitToContinue();
await next(context);
}
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
await _syncPoint[1].WaitToContinue();
var result = await next(invocationContext);
return result;
}
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
{
await _syncPoint[2].WaitToContinue();
await next(context, exception);
}
}
public class FilterCounter
{
public int OnConnectedAsyncCount;
public int InvokeMethodAsyncCount;
public int OnDisconnectedAsyncCount;
}
public class CounterFilter : IHubFilter
{
private readonly FilterCounter _counter;
public CounterFilter(FilterCounter counter)
{
_counter = counter;
_counter.OnConnectedAsyncCount = 0;
_counter.InvokeMethodAsyncCount = 0;
_counter.OnDisconnectedAsyncCount = 0;
}
public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
_counter.OnConnectedAsyncCount++;
return next(context);
}
public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
{
_counter.OnDisconnectedAsyncCount++;
return next(context, exception);
}
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
_counter.InvokeMethodAsyncCount++;
return next(invocationContext);
}
}
public class NoExceptionFilter : IHubFilter
{
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
try
{
await next(context);
}
catch { }
}
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
{
try
{
await next(context, exception);
}
catch { }
}
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
try
{
return await next(invocationContext);
}
catch { }
return null;
}
}
public class SkipNextFilter : IHubFilter
{
private readonly bool _skipOnConnected;
private readonly bool _skipInvoke;
private readonly bool _skipOnDisconnected;
public SkipNextFilter(bool skipOnConnected = false, bool skipInvoke = false, bool skipOnDisconnected = false)
{
_skipOnConnected = skipOnConnected;
_skipInvoke = skipInvoke;
_skipOnDisconnected = skipOnDisconnected;
}
public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
{
if (_skipOnConnected)
{
return Task.CompletedTask;
}
return next(context);
}
public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
{
if (_skipOnDisconnected)
{
return Task.CompletedTask;
}
return next(context, exception);
}
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
if (_skipInvoke)
{
return new ValueTask<object>();
}
return next(invocationContext);
}
}
public class DisposableFilter : IHubFilter, IDisposable
{
private readonly TcsService _tcsService;
public DisposableFilter(TcsService tcsService)
{
_tcsService = tcsService;
}
public void Dispose()
{
_tcsService.StartedMethod.SetResult(null);
}
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
return next(invocationContext);
}
}
public class AsyncDisposableFilter : IHubFilter, IAsyncDisposable
{
private readonly TcsService _tcsService;
public AsyncDisposableFilter(TcsService tcsService)
{
_tcsService = tcsService;
}
public ValueTask DisposeAsync()
{
_tcsService.StartedMethod.SetResult(null);
return default;
}
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
return next(invocationContext);
}
}
public class ChangeMethodFilter : IHubFilter
{
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
var methodInfo = typeof(BaseHub).GetMethod(nameof(BaseHub.BaseMethod));
var context = new HubInvocationContext(invocationContext.Context, invocationContext.ServiceProvider, invocationContext.Hub, methodInfo, invocationContext.HubMethodArguments);
return next(context);
}
}
}