Hub filters! (#21278)
This commit is contained in:
parent
bad6e32e7e
commit
2ad8121efb
|
|
@ -38,7 +38,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
|
||||||
serviceScopeFactory,
|
serviceScopeFactory,
|
||||||
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>(NullLogger<DefaultHubLifetimeManager<TestHub>>.Instance)),
|
new HubContext<TestHub>(new DefaultHubLifetimeManager<TestHub>(NullLogger<DefaultHubLifetimeManager<TestHub>>.Instance)),
|
||||||
enableDetailedErrors: false,
|
enableDetailedErrors: false,
|
||||||
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance));
|
new Logger<DefaultHubDispatcher<TestHub>>(NullLoggerFactory.Instance),
|
||||||
|
hubFilters: null);
|
||||||
|
|
||||||
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
|
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
|
||||||
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport);
|
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport);
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,12 @@
|
||||||
// Copyright (c) .NET Foundation. All rights reserved.
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
using System;
|
|
||||||
using System.IO;
|
|
||||||
using System.Reflection;
|
using System.Reflection;
|
||||||
using System.Threading.Tasks;
|
|
||||||
using Microsoft.AspNetCore.Builder;
|
using Microsoft.AspNetCore.Builder;
|
||||||
using Microsoft.AspNetCore.Hosting;
|
using Microsoft.AspNetCore.Hosting;
|
||||||
using Microsoft.Extensions.DependencyInjection;
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
using Microsoft.Extensions.Hosting;
|
using Microsoft.Extensions.Hosting;
|
||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
using System.Text.Json.Serialization;
|
|
||||||
using SignalRSamples.ConnectionHandlers;
|
using SignalRSamples.ConnectionHandlers;
|
||||||
using SignalRSamples.Hubs;
|
using SignalRSamples.Hubs;
|
||||||
|
|
||||||
|
|
@ -18,7 +14,6 @@ namespace SignalRSamples
|
||||||
{
|
{
|
||||||
public class Startup
|
public class Startup
|
||||||
{
|
{
|
||||||
|
|
||||||
private readonly JsonWriterOptions _jsonWriterOptions = new JsonWriterOptions { Indented = true };
|
private readonly JsonWriterOptions _jsonWriterOptions = new JsonWriterOptions { Indented = true };
|
||||||
|
|
||||||
// This method gets called by the runtime. Use this method to add services to the container.
|
// This method gets called by the runtime. Use this method to add services to the container.
|
||||||
|
|
@ -27,11 +22,7 @@ namespace SignalRSamples
|
||||||
{
|
{
|
||||||
services.AddConnections();
|
services.AddConnections();
|
||||||
|
|
||||||
services.AddSignalR(options =>
|
services.AddSignalR()
|
||||||
{
|
|
||||||
// Faster pings for testing
|
|
||||||
options.KeepAliveInterval = TimeSpan.FromSeconds(5);
|
|
||||||
})
|
|
||||||
.AddMessagePackProtocol();
|
.AddMessagePackProtocol();
|
||||||
//.AddStackExchangeRedis();
|
//.AddStackExchangeRedis();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -181,12 +181,23 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
}
|
}
|
||||||
public partial class HubInvocationContext
|
public partial class HubInvocationContext
|
||||||
{
|
{
|
||||||
|
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.IServiceProvider serviceProvider, Microsoft.AspNetCore.SignalR.Hub hub, System.Reflection.MethodInfo hubMethod, System.Collections.Generic.IReadOnlyList<object> hubMethodArguments) { }
|
||||||
|
[System.ObsoleteAttribute("This constructor is obsolete and will be removed in a future version. The recommended alternative is to use the other constructor.")]
|
||||||
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { }
|
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { }
|
||||||
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.Type hubType, string hubMethodName, object[] hubMethodArguments) { }
|
|
||||||
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
|
public Microsoft.AspNetCore.SignalR.Hub Hub { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
|
public System.Reflection.MethodInfo HubMethod { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
public System.Collections.Generic.IReadOnlyList<object> HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
public System.Collections.Generic.IReadOnlyList<object> HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
|
[System.ObsoleteAttribute("This property is obsolete and will be removed in a future version. Use HubMethod.Name instead.")]
|
||||||
public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
public System.Type HubType { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
public System.IServiceProvider ServiceProvider { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
|
}
|
||||||
|
public sealed partial class HubLifetimeContext
|
||||||
|
{
|
||||||
|
public HubLifetimeContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, System.IServiceProvider serviceProvider, Microsoft.AspNetCore.SignalR.Hub hub) { }
|
||||||
|
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
|
public Microsoft.AspNetCore.SignalR.Hub Hub { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
|
public System.IServiceProvider ServiceProvider { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } }
|
||||||
}
|
}
|
||||||
public abstract partial class HubLifetimeManager<THub> where THub : Microsoft.AspNetCore.SignalR.Hub
|
public abstract partial class HubLifetimeManager<THub> where THub : Microsoft.AspNetCore.SignalR.Hub
|
||||||
{
|
{
|
||||||
|
|
@ -227,6 +238,12 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
public int? StreamBufferCapacity { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
|
public int? StreamBufferCapacity { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
|
||||||
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
|
public System.Collections.Generic.IList<string> SupportedProtocols { [System.Runtime.CompilerServices.CompilerGeneratedAttribute] get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute] set { } }
|
||||||
}
|
}
|
||||||
|
public static partial class HubOptionsExtensions
|
||||||
|
{
|
||||||
|
public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options, Microsoft.AspNetCore.SignalR.IHubFilter hubFilter) { }
|
||||||
|
public static void AddFilter(this Microsoft.AspNetCore.SignalR.HubOptions options, System.Type filterType) { }
|
||||||
|
public static void AddFilter<TFilter>(this Microsoft.AspNetCore.SignalR.HubOptions options) where TFilter : Microsoft.AspNetCore.SignalR.IHubFilter { }
|
||||||
|
}
|
||||||
public partial class HubOptionsSetup : Microsoft.Extensions.Options.IConfigureOptions<Microsoft.AspNetCore.SignalR.HubOptions>
|
public partial class HubOptionsSetup : Microsoft.Extensions.Options.IConfigureOptions<Microsoft.AspNetCore.SignalR.HubOptions>
|
||||||
{
|
{
|
||||||
public HubOptionsSetup(System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol> protocols) { }
|
public HubOptionsSetup(System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol> protocols) { }
|
||||||
|
|
@ -294,6 +311,12 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
Microsoft.AspNetCore.SignalR.IHubClients<T> Clients { get; }
|
Microsoft.AspNetCore.SignalR.IHubClients<T> Clients { get; }
|
||||||
Microsoft.AspNetCore.SignalR.IGroupManager Groups { get; }
|
Microsoft.AspNetCore.SignalR.IGroupManager Groups { get; }
|
||||||
}
|
}
|
||||||
|
public partial interface IHubFilter
|
||||||
|
{
|
||||||
|
System.Threading.Tasks.ValueTask<object> InvokeMethodAsync(Microsoft.AspNetCore.SignalR.HubInvocationContext invocationContext, System.Func<Microsoft.AspNetCore.SignalR.HubInvocationContext, System.Threading.Tasks.ValueTask<object>> next);
|
||||||
|
System.Threading.Tasks.Task OnConnectedAsync(Microsoft.AspNetCore.SignalR.HubLifetimeContext context, System.Func<Microsoft.AspNetCore.SignalR.HubLifetimeContext, System.Threading.Tasks.Task> next) { throw null; }
|
||||||
|
System.Threading.Tasks.Task OnDisconnectedAsync(Microsoft.AspNetCore.SignalR.HubLifetimeContext context, System.Exception exception, System.Func<Microsoft.AspNetCore.SignalR.HubLifetimeContext, System.Exception, System.Threading.Tasks.Task> next) { throw null; }
|
||||||
|
}
|
||||||
public partial interface IHubProtocolResolver
|
public partial interface IHubProtocolResolver
|
||||||
{
|
{
|
||||||
System.Collections.Generic.IReadOnlyList<Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol> AllProtocols { get; }
|
System.Collections.Generic.IReadOnlyList<Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol> AllProtocols { get; }
|
||||||
|
|
|
||||||
|
|
@ -64,22 +64,37 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
_userIdProvider = userIdProvider;
|
_userIdProvider = userIdProvider;
|
||||||
|
|
||||||
_enableDetailedErrors = false;
|
_enableDetailedErrors = false;
|
||||||
|
|
||||||
|
List<IHubFilter> hubFilters = null;
|
||||||
if (_hubOptions.UserHasSetValues)
|
if (_hubOptions.UserHasSetValues)
|
||||||
{
|
{
|
||||||
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize;
|
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize;
|
||||||
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
|
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
|
||||||
|
|
||||||
|
if (_hubOptions.HubFilters != null)
|
||||||
|
{
|
||||||
|
hubFilters = new List<IHubFilter>();
|
||||||
|
hubFilters.AddRange(_hubOptions.HubFilters);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
_maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize;
|
_maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize;
|
||||||
_enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
|
_enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
|
||||||
|
|
||||||
|
if (_globalHubOptions.HubFilters != null)
|
||||||
|
{
|
||||||
|
hubFilters = new List<IHubFilter>();
|
||||||
|
hubFilters.AddRange(_globalHubOptions.HubFilters);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_dispatcher = new DefaultHubDispatcher<THub>(
|
_dispatcher = new DefaultHubDispatcher<THub>(
|
||||||
serviceScopeFactory,
|
serviceScopeFactory,
|
||||||
new HubContext<THub>(lifetimeManager),
|
new HubContext<THub>(lifetimeManager),
|
||||||
_enableDetailedErrors,
|
_enableDetailedErrors,
|
||||||
new Logger<DefaultHubDispatcher<THub>>(loggerFactory));
|
new Logger<DefaultHubDispatcher<THub>>(loggerFactory),
|
||||||
|
hubFilters);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,9 @@
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
using Microsoft.AspNetCore.Authorization;
|
using System.Linq;
|
||||||
|
using System.Reflection;
|
||||||
|
using Microsoft.Extensions.Internal;
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
{
|
{
|
||||||
|
|
@ -12,16 +14,27 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public class HubInvocationContext
|
public class HubInvocationContext
|
||||||
{
|
{
|
||||||
|
internal ObjectMethodExecutor ObjectMethodExecutor { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Instantiates a new instance of the <see cref="HubInvocationContext"/> class.
|
/// Instantiates a new instance of the <see cref="HubInvocationContext"/> class.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
/// <param name="context">Context for the active Hub connection and caller.</param>
|
/// <param name="context">Context for the active Hub connection and caller.</param>
|
||||||
/// <param name="hubType">The type of the Hub.</param>
|
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.</param>
|
||||||
/// <param name="hubMethodName">The name of the Hub method being invoked.</param>
|
/// <param name="hub">The instance of the Hub.</param>
|
||||||
|
/// <param name="hubMethod">The <see cref="MethodInfo"/> for the Hub method being invoked.</param>
|
||||||
/// <param name="hubMethodArguments">The arguments provided by the client.</param>
|
/// <param name="hubMethodArguments">The arguments provided by the client.</param>
|
||||||
public HubInvocationContext(HubCallerContext context, Type hubType, string hubMethodName, object[] hubMethodArguments): this(context, hubMethodName, hubMethodArguments)
|
public HubInvocationContext(HubCallerContext context, IServiceProvider serviceProvider, Hub hub, MethodInfo hubMethod, IReadOnlyList<object> hubMethodArguments)
|
||||||
{
|
{
|
||||||
HubType = hubType;
|
Hub = hub;
|
||||||
|
ServiceProvider = serviceProvider;
|
||||||
|
HubMethod = hubMethod;
|
||||||
|
HubMethodArguments = hubMethodArguments;
|
||||||
|
Context = context;
|
||||||
|
|
||||||
|
#pragma warning disable CS0618 // Type or member is obsolete
|
||||||
|
HubMethodName = HubMethod.Name;
|
||||||
|
#pragma warning restore CS0618 // Type or member is obsolete
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -30,11 +43,16 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
/// <param name="context">Context for the active Hub connection and caller.</param>
|
/// <param name="context">Context for the active Hub connection and caller.</param>
|
||||||
/// <param name="hubMethodName">The name of the Hub method being invoked.</param>
|
/// <param name="hubMethodName">The name of the Hub method being invoked.</param>
|
||||||
/// <param name="hubMethodArguments">The arguments provided by the client.</param>
|
/// <param name="hubMethodArguments">The arguments provided by the client.</param>
|
||||||
|
[Obsolete("This constructor is obsolete and will be removed in a future version. The recommended alternative is to use the other constructor.")]
|
||||||
public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments)
|
public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments)
|
||||||
{
|
{
|
||||||
HubMethodName = hubMethodName;
|
throw new NotSupportedException("This constructor no longer works. Use the other constructor.");
|
||||||
HubMethodArguments = hubMethodArguments;
|
}
|
||||||
Context = context;
|
|
||||||
|
internal HubInvocationContext(ObjectMethodExecutor objectMethodExecutor, HubCallerContext context, IServiceProvider serviceProvider, Hub hub, object[] hubMethodArguments)
|
||||||
|
: this(context, serviceProvider, hub, objectMethodExecutor.MethodInfo, hubMethodArguments)
|
||||||
|
{
|
||||||
|
ObjectMethodExecutor = objectMethodExecutor;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -43,18 +61,29 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
public HubCallerContext Context { get; }
|
public HubCallerContext Context { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Gets the Hub type.
|
/// Gets the Hub instance.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public Type HubType { get; }
|
public Hub Hub { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Gets the name of the Hub method being invoked.
|
/// Gets the name of the Hub method being invoked.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
|
[Obsolete("This property is obsolete and will be removed in a future version. Use HubMethod.Name instead.")]
|
||||||
public string HubMethodName { get; }
|
public string HubMethodName { get; }
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Gets the arguments provided by the client.
|
/// Gets the arguments provided by the client.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public IReadOnlyList<object> HubMethodArguments { get; }
|
public IReadOnlyList<object> HubMethodArguments { get; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.
|
||||||
|
/// </summary>
|
||||||
|
public IServiceProvider ServiceProvider { get; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The <see cref="MethodInfo"/> for the Hub method being invoked.
|
||||||
|
/// </summary>
|
||||||
|
public MethodInfo HubMethod { get; }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,43 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
#nullable enable
|
||||||
|
|
||||||
|
using System;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Context for the hub lifetime events <see cref="Hub.OnConnectedAsync"/> and <see cref="Hub.OnDisconnectedAsync(Exception)"/>.
|
||||||
|
/// </summary>
|
||||||
|
public sealed class HubLifetimeContext
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Instantiates a new instance of the <see cref="HubLifetimeContext"/> class.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="context">Context for the active Hub connection and caller.</param>
|
||||||
|
/// <param name="serviceProvider">The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.</param>
|
||||||
|
/// <param name="hub">The instance of the Hub.</param>
|
||||||
|
public HubLifetimeContext(HubCallerContext context, IServiceProvider serviceProvider, Hub hub)
|
||||||
|
{
|
||||||
|
Hub = hub;
|
||||||
|
ServiceProvider = serviceProvider;
|
||||||
|
Context = context;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets the context for the active Hub connection and caller.
|
||||||
|
/// </summary>
|
||||||
|
public HubCallerContext Context { get; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Gets the Hub instance.
|
||||||
|
/// </summary>
|
||||||
|
public Hub Hub { get; }
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// The <see cref="IServiceProvider"/> specific to the scope of this Hub method invocation.
|
||||||
|
/// </summary>
|
||||||
|
public IServiceProvider ServiceProvider { get; }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -51,5 +51,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
/// Gets or sets the max buffer size for client upload streams. The default size is 10.
|
/// Gets or sets the max buffer size for client upload streams. The default size is 10.
|
||||||
/// </summary>
|
/// </summary>
|
||||||
public int? StreamBufferCapacity { get; set; } = null;
|
public int? StreamBufferCapacity { get; set; } = null;
|
||||||
|
|
||||||
|
internal List<IHubFilter> HubFilters { get; set; } = null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
#nullable enable
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using Microsoft.AspNetCore.SignalR.Internal;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Methods to add <see cref="IHubFilter"/>'s to Hubs.
|
||||||
|
/// </summary>
|
||||||
|
public static class HubOptionsExtensions
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Adds an instance of an <see cref="IHubFilter"/> to the <see cref="HubOptions"/>.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="options">The options to add a filter to.</param>
|
||||||
|
/// <param name="hubFilter">The filter instance to add to the options.</param>
|
||||||
|
public static void AddFilter(this HubOptions options, IHubFilter hubFilter)
|
||||||
|
{
|
||||||
|
_ = options ?? throw new ArgumentNullException(nameof(options));
|
||||||
|
_ = hubFilter ?? throw new ArgumentNullException(nameof(hubFilter));
|
||||||
|
|
||||||
|
if (options.HubFilters == null)
|
||||||
|
{
|
||||||
|
options.HubFilters = new List<IHubFilter>();
|
||||||
|
}
|
||||||
|
|
||||||
|
options.HubFilters.Add(hubFilter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Adds an <see cref="IHubFilter"/> type to the <see cref="HubOptions"/> that will be resolved via DI or type activated.
|
||||||
|
/// </summary>
|
||||||
|
/// <typeparam name="TFilter">The <see cref="IHubFilter"/> type that will be added to the options.</typeparam>
|
||||||
|
/// <param name="options">The options to add a filter to.</param>
|
||||||
|
public static void AddFilter<TFilter>(this HubOptions options) where TFilter : IHubFilter
|
||||||
|
{
|
||||||
|
_ = options ?? throw new ArgumentNullException(nameof(options));
|
||||||
|
|
||||||
|
options.AddFilter(typeof(TFilter));
|
||||||
|
}
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Adds an <see cref="IHubFilter"/> type to the <see cref="HubOptions"/> that will be resolved via DI or type activated.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="options">The options to add a filter to.</param>
|
||||||
|
/// <param name="filterType">The <see cref="IHubFilter"/> type that will be added to the options.</param>
|
||||||
|
public static void AddFilter(this HubOptions options, Type filterType)
|
||||||
|
{
|
||||||
|
_ = options ?? throw new ArgumentNullException(nameof(options));
|
||||||
|
_ = filterType ?? throw new ArgumentNullException(nameof(filterType));
|
||||||
|
|
||||||
|
options.AddFilter(new HubFilterFactory(filterType));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -17,11 +17,7 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
public void Configure(HubOptions<THub> options)
|
public void Configure(HubOptions<THub> options)
|
||||||
{
|
{
|
||||||
// Do a deep copy, otherwise users modifying the HubOptions<THub> list would be changing the global options list
|
// Do a deep copy, otherwise users modifying the HubOptions<THub> list would be changing the global options list
|
||||||
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols.Count);
|
options.SupportedProtocols = new List<string>(_hubOptions.SupportedProtocols);
|
||||||
foreach (var protocol in _hubOptions.SupportedProtocols)
|
|
||||||
{
|
|
||||||
options.SupportedProtocols.Add(protocol);
|
|
||||||
}
|
|
||||||
options.KeepAliveInterval = _hubOptions.KeepAliveInterval;
|
options.KeepAliveInterval = _hubOptions.KeepAliveInterval;
|
||||||
options.HandshakeTimeout = _hubOptions.HandshakeTimeout;
|
options.HandshakeTimeout = _hubOptions.HandshakeTimeout;
|
||||||
options.ClientTimeoutInterval = _hubOptions.ClientTimeoutInterval;
|
options.ClientTimeoutInterval = _hubOptions.ClientTimeoutInterval;
|
||||||
|
|
@ -30,6 +26,11 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity;
|
options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity;
|
||||||
|
|
||||||
options.UserHasSetValues = true;
|
options.UserHasSetValues = true;
|
||||||
|
|
||||||
|
if (_hubOptions.HubFilters != null)
|
||||||
|
{
|
||||||
|
options.HubFilters = new List<IHubFilter>(_hubOptions.HubFilters);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,41 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
#nullable enable
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// The filter abstraction for hub method invocations.
|
||||||
|
/// </summary>
|
||||||
|
public interface IHubFilter
|
||||||
|
{
|
||||||
|
/// <summary>
|
||||||
|
/// Allows handling of all Hub method invocations.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="invocationContext">The context for the method invocation that holds all the important information about the invoke.</param>
|
||||||
|
/// <param name="next">The next filter to run, and for the final one, the Hub invocation.</param>
|
||||||
|
/// <returns>Returns the result of the Hub method invoke.</returns>
|
||||||
|
ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Allows handling of the <see cref="Hub.OnConnectedAsync"/> method.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="context">The context for OnConnectedAsync.</param>
|
||||||
|
/// <param name="next">The next filter to run, and for the final one, the Hub invocation.</param>
|
||||||
|
/// <returns></returns>
|
||||||
|
Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next) => next(context);
|
||||||
|
|
||||||
|
/// <summary>
|
||||||
|
/// Allows handling of the <see cref="Hub.OnDisconnectedAsync(Exception)"/> method.
|
||||||
|
/// </summary>
|
||||||
|
/// <param name="context">The context for OnDisconnectedAsync.</param>
|
||||||
|
/// <param name="exception">The exception, if any, for the connection closing.</param>
|
||||||
|
/// <param name="next">The next filter to run, and for the final one, the Hub invocation.</param>
|
||||||
|
/// <returns></returns>
|
||||||
|
Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next) => next(context, exception);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -16,7 +16,6 @@ using Microsoft.AspNetCore.SignalR.Protocol;
|
||||||
using Microsoft.Extensions.DependencyInjection;
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
using Microsoft.Extensions.Internal;
|
using Microsoft.Extensions.Internal;
|
||||||
using Microsoft.Extensions.Logging;
|
using Microsoft.Extensions.Logging;
|
||||||
using Microsoft.Extensions.Options;
|
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR.Internal
|
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
{
|
{
|
||||||
|
|
@ -27,14 +26,48 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
private readonly IHubContext<THub> _hubContext;
|
private readonly IHubContext<THub> _hubContext;
|
||||||
private readonly ILogger<HubDispatcher<THub>> _logger;
|
private readonly ILogger<HubDispatcher<THub>> _logger;
|
||||||
private readonly bool _enableDetailedErrors;
|
private readonly bool _enableDetailedErrors;
|
||||||
|
private readonly Func<HubInvocationContext, ValueTask<object>> _invokeMiddleware;
|
||||||
|
private readonly Func<HubLifetimeContext, Task> _onConnectedMiddleware;
|
||||||
|
private readonly Func<HubLifetimeContext, Exception, Task> _onDisconnectedMiddleware;
|
||||||
|
|
||||||
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors, ILogger<DefaultHubDispatcher<THub>> logger)
|
public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, bool enableDetailedErrors,
|
||||||
|
ILogger<DefaultHubDispatcher<THub>> logger, List<IHubFilter> hubFilters)
|
||||||
{
|
{
|
||||||
_serviceScopeFactory = serviceScopeFactory;
|
_serviceScopeFactory = serviceScopeFactory;
|
||||||
_hubContext = hubContext;
|
_hubContext = hubContext;
|
||||||
_enableDetailedErrors = enableDetailedErrors;
|
_enableDetailedErrors = enableDetailedErrors;
|
||||||
_logger = logger;
|
_logger = logger;
|
||||||
DiscoverHubMethods();
|
DiscoverHubMethods();
|
||||||
|
|
||||||
|
var count = hubFilters?.Count ?? 0;
|
||||||
|
if (count != 0)
|
||||||
|
{
|
||||||
|
_invokeMiddleware = (invocationContext) =>
|
||||||
|
{
|
||||||
|
var arguments = invocationContext.HubMethodArguments as object[] ?? invocationContext.HubMethodArguments.ToArray();
|
||||||
|
if (invocationContext.ObjectMethodExecutor != null)
|
||||||
|
{
|
||||||
|
return ExecuteMethod(invocationContext.ObjectMethodExecutor, invocationContext.Hub, arguments);
|
||||||
|
}
|
||||||
|
return ExecuteMethod(invocationContext.HubMethod.Name, invocationContext.Hub, arguments);
|
||||||
|
};
|
||||||
|
|
||||||
|
_onConnectedMiddleware = (context) => context.Hub.OnConnectedAsync();
|
||||||
|
_onDisconnectedMiddleware = (context, exception) => context.Hub.OnDisconnectedAsync(exception);
|
||||||
|
|
||||||
|
for (var i = count - 1; i > -1; i--)
|
||||||
|
{
|
||||||
|
var resolvedFilter = hubFilters[i];
|
||||||
|
var nextFilter = _invokeMiddleware;
|
||||||
|
_invokeMiddleware = (context) => resolvedFilter.InvokeMethodAsync(context, nextFilter);
|
||||||
|
|
||||||
|
var connectedFilter = _onConnectedMiddleware;
|
||||||
|
_onConnectedMiddleware = (context) => resolvedFilter.OnConnectedAsync(context, connectedFilter);
|
||||||
|
|
||||||
|
var disconnectedFilter = _onDisconnectedMiddleware;
|
||||||
|
_onDisconnectedMiddleware = (context, exception) => resolvedFilter.OnDisconnectedAsync(context, exception, disconnectedFilter);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public override async Task OnConnectedAsync(HubConnectionContext connection)
|
public override async Task OnConnectedAsync(HubConnectionContext connection)
|
||||||
|
|
@ -50,7 +83,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
InitializeHub(hub, connection);
|
InitializeHub(hub, connection);
|
||||||
await hub.OnConnectedAsync();
|
|
||||||
|
if (_onConnectedMiddleware != null)
|
||||||
|
{
|
||||||
|
var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub);
|
||||||
|
await _onConnectedMiddleware(context);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
await hub.OnConnectedAsync();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
finally
|
finally
|
||||||
{
|
{
|
||||||
|
|
@ -76,7 +118,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
InitializeHub(hub, connection);
|
InitializeHub(hub, connection);
|
||||||
await hub.OnDisconnectedAsync(exception);
|
|
||||||
|
if (_onDisconnectedMiddleware != null)
|
||||||
|
{
|
||||||
|
var context = new HubLifetimeContext(connection.HubCallerContext, scope.ServiceProvider, hub);
|
||||||
|
await _onDisconnectedMiddleware(context, exception);
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
await hub.OnDisconnectedAsync(exception);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
finally
|
finally
|
||||||
{
|
{
|
||||||
|
|
@ -220,7 +271,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
THub hub = null;
|
THub hub = null;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments))
|
hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
|
||||||
|
hub = hubActivator.Create();
|
||||||
|
|
||||||
|
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor, hubMethodInvocationMessage.Arguments, hub))
|
||||||
{
|
{
|
||||||
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
|
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
|
||||||
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
|
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
|
||||||
|
|
@ -233,9 +287,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
|
|
||||||
hub = hubActivator.Create();
|
|
||||||
|
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
|
var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0;
|
||||||
|
|
@ -298,7 +349,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
|
|
||||||
if (isStreamResponse)
|
if (isStreamResponse)
|
||||||
{
|
{
|
||||||
var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
|
var result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider);
|
||||||
|
|
||||||
if (result == null)
|
if (result == null)
|
||||||
{
|
{
|
||||||
|
|
@ -315,7 +366,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
||||||
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
|
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
|
||||||
}
|
}
|
||||||
|
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
// Invoke or Send
|
// Invoke or Send
|
||||||
|
|
@ -324,7 +374,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
object result;
|
object result;
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
result = await ExecuteHubMethod(methodExecutor, hub, arguments);
|
result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider);
|
||||||
Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
|
||||||
}
|
}
|
||||||
catch (Exception ex)
|
catch (Exception ex)
|
||||||
|
|
@ -447,13 +497,36 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
|
private ValueTask<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments, HubConnectionContext connection, IServiceProvider serviceProvider)
|
||||||
|
{
|
||||||
|
if (_invokeMiddleware != null)
|
||||||
|
{
|
||||||
|
var invocationContext = new HubInvocationContext(methodExecutor, connection.HubCallerContext, serviceProvider, hub, arguments);
|
||||||
|
return _invokeMiddleware(invocationContext);
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no Hub filters are registered
|
||||||
|
return ExecuteMethod(methodExecutor, hub, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
private ValueTask<object> ExecuteMethod(string hubMethodName, Hub hub, object[] arguments)
|
||||||
|
{
|
||||||
|
if (!_methods.TryGetValue(hubMethodName, out var methodDescriptor))
|
||||||
|
{
|
||||||
|
throw new HubException($"Unknown hub method '{hubMethodName}'");
|
||||||
|
}
|
||||||
|
var methodExecutor = methodDescriptor.MethodExecutor;
|
||||||
|
return ExecuteMethod(methodExecutor, hub, arguments);
|
||||||
|
}
|
||||||
|
|
||||||
|
private async ValueTask<object> ExecuteMethod(ObjectMethodExecutor methodExecutor, Hub hub, object[] arguments)
|
||||||
{
|
{
|
||||||
if (methodExecutor.IsMethodAsync)
|
if (methodExecutor.IsMethodAsync)
|
||||||
{
|
{
|
||||||
if (methodExecutor.MethodReturnType == typeof(Task))
|
if (methodExecutor.MethodReturnType == typeof(Task))
|
||||||
{
|
{
|
||||||
await (Task)methodExecutor.Execute(hub, arguments);
|
await (Task)methodExecutor.Execute(hub, arguments);
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
|
@ -464,8 +537,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
{
|
{
|
||||||
return methodExecutor.Execute(hub, arguments);
|
return methodExecutor.Execute(hub, arguments);
|
||||||
}
|
}
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private async Task SendInvocationError(string invocationId,
|
private async Task SendInvocationError(string invocationId,
|
||||||
|
|
@ -486,15 +557,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
hub.Groups = _hubContext.Groups;
|
hub.Groups = _hubContext.Groups;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList<IAuthorizeData> policies, string hubMethodName, object[] hubMethodArguments)
|
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, HubMethodDescriptor descriptor, object[] hubMethodArguments, Hub hub)
|
||||||
{
|
{
|
||||||
// If there are no policies we don't need to run auth
|
// If there are no policies we don't need to run auth
|
||||||
if (policies.Count == 0)
|
if (descriptor.Policies.Count == 0)
|
||||||
{
|
{
|
||||||
return TaskCache.True;
|
return TaskCache.True;
|
||||||
}
|
}
|
||||||
|
|
||||||
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, typeof(THub), hubMethodName, hubMethodArguments));
|
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, descriptor.Policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, provider, hub, descriptor.MethodExecutor.MethodInfo, hubMethodArguments));
|
||||||
}
|
}
|
||||||
|
|
||||||
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)
|
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,100 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
#nullable enable
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Internal
|
||||||
|
{
|
||||||
|
internal class HubFilterFactory : IHubFilter
|
||||||
|
{
|
||||||
|
private readonly ObjectFactory _objectFactory;
|
||||||
|
private readonly Type _filterType;
|
||||||
|
|
||||||
|
public HubFilterFactory(Type filterType)
|
||||||
|
{
|
||||||
|
_objectFactory = ActivatorUtilities.CreateFactory(filterType, Array.Empty<Type>());
|
||||||
|
_filterType = filterType;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
var (filter, owned) = GetFilter(invocationContext.ServiceProvider);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
return await filter.InvokeMethodAsync(invocationContext, next);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
if (owned)
|
||||||
|
{
|
||||||
|
await DisposeFilter(filter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
|
||||||
|
{
|
||||||
|
var (filter, owned) = GetFilter(context.ServiceProvider);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
await filter.OnConnectedAsync(context, next);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
if (owned)
|
||||||
|
{
|
||||||
|
await DisposeFilter(filter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
|
||||||
|
{
|
||||||
|
var (filter, owned) = GetFilter(context.ServiceProvider);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
await filter.OnDisconnectedAsync(context, exception, next);
|
||||||
|
}
|
||||||
|
finally
|
||||||
|
{
|
||||||
|
if (owned)
|
||||||
|
{
|
||||||
|
await DisposeFilter(filter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private ValueTask DisposeFilter(IHubFilter filter)
|
||||||
|
{
|
||||||
|
if (filter is IAsyncDisposable asyncDispsable)
|
||||||
|
{
|
||||||
|
return asyncDispsable.DisposeAsync();
|
||||||
|
}
|
||||||
|
if (filter is IDisposable disposable)
|
||||||
|
{
|
||||||
|
disposable.Dispose();
|
||||||
|
}
|
||||||
|
return default;
|
||||||
|
}
|
||||||
|
|
||||||
|
private (IHubFilter, bool) GetFilter(IServiceProvider serviceProvider)
|
||||||
|
{
|
||||||
|
var owned = false;
|
||||||
|
var filter = (IHubFilter?)serviceProvider.GetService(_filterType);
|
||||||
|
if (filter == null)
|
||||||
|
{
|
||||||
|
filter = (IHubFilter)_objectFactory.Invoke(serviceProvider, null);
|
||||||
|
owned = true;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (filter, owned);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1048,8 +1048,19 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
|
||||||
public class TcsService
|
public class TcsService
|
||||||
{
|
{
|
||||||
public TaskCompletionSource<object> StartedMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
public TaskCompletionSource<object> StartedMethod;
|
||||||
public TaskCompletionSource<object> EndMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
public TaskCompletionSource<object> EndMethod;
|
||||||
|
|
||||||
|
public TcsService()
|
||||||
|
{
|
||||||
|
Reset();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Reset()
|
||||||
|
{
|
||||||
|
StartedMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
|
EndMethod = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public interface ITypedHubClient
|
public interface ITypedHubClient
|
||||||
|
|
|
||||||
|
|
@ -2232,14 +2232,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
{
|
{
|
||||||
Assert.NotNull(context.Resource);
|
Assert.NotNull(context.Resource);
|
||||||
var resource = Assert.IsType<HubInvocationContext>(context.Resource);
|
var resource = Assert.IsType<HubInvocationContext>(context.Resource);
|
||||||
Assert.Equal(typeof(MethodHub), resource.HubType);
|
Assert.Equal(typeof(MethodHub), resource.Hub.GetType());
|
||||||
|
#pragma warning disable CS0618 // Type or member is obsolete
|
||||||
Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName);
|
Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName);
|
||||||
|
#pragma warning restore CS0618 // Type or member is obsolete
|
||||||
Assert.Equal(2, resource.HubMethodArguments?.Count);
|
Assert.Equal(2, resource.HubMethodArguments?.Count);
|
||||||
Assert.Equal("Hello", resource.HubMethodArguments[0]);
|
Assert.Equal("Hello", resource.HubMethodArguments[0]);
|
||||||
Assert.Equal("World!", resource.HubMethodArguments[1]);
|
Assert.Equal("World!", resource.HubMethodArguments[1]);
|
||||||
Assert.NotNull(resource.Context);
|
Assert.NotNull(resource.Context);
|
||||||
Assert.Equal(context.User, resource.Context.User);
|
Assert.Equal(context.User, resource.Context.User);
|
||||||
Assert.NotNull(resource.Context.GetHttpContext());
|
Assert.NotNull(resource.Context.GetHttpContext());
|
||||||
|
Assert.NotNull(resource.ServiceProvider);
|
||||||
|
Assert.Equal(typeof(MethodHub).GetMethod(nameof(MethodHub.MultiParamAuthMethod)), resource.HubMethod);
|
||||||
|
|
||||||
return Task.CompletedTask;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,867 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.AspNetCore.Internal;
|
||||||
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
|
using Microsoft.Extensions.Logging.Testing;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
{
|
||||||
|
public class HubFilterTests : VerifiableLoggedTest
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task GlobalHubFilterByType_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter<VerifyMethodFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
await AssertMethodsCalled(serviceProvider, tcsService);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task GlobalHubFilterByInstance_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new VerifyMethodFilter(tcsService));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
await AssertMethodsCalled(serviceProvider, tcsService);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task PerHubFilterByInstance_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR().AddHubOptions<MethodHub>(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new VerifyMethodFilter(tcsService));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
await AssertMethodsCalled(serviceProvider, tcsService);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task PerHubFilterByCompileTimeType_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR().AddHubOptions<MethodHub>(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter<VerifyMethodFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
await AssertMethodsCalled(serviceProvider, tcsService);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task PerHubFilterByRuntimeType_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR().AddHubOptions<MethodHub>(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(typeof(VerifyMethodFilter));
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
await AssertMethodsCalled(serviceProvider, tcsService);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async Task AssertMethodsCalled(IServiceProvider serviceProvider, TcsService tcsService)
|
||||||
|
{
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
await tcsService.EndMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
tcsService.Reset();
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
await tcsService.EndMethod.Task.OrTimeout();
|
||||||
|
tcsService.Reset();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
await tcsService.EndMethod.Task.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task MutlipleFilters_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService1 = new TcsService();
|
||||||
|
var tcsService2 = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new VerifyMethodFilter(tcsService1));
|
||||||
|
options.AddFilter(new VerifyMethodFilter(tcsService2));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await tcsService1.StartedMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.StartedMethod.Task.OrTimeout();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
tcsService1.Reset();
|
||||||
|
tcsService2.Reset();
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
tcsService1.Reset();
|
||||||
|
tcsService2.Reset();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task MixingTypeAndInstanceGlobalFilters_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService1 = new TcsService();
|
||||||
|
var tcsService2 = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new VerifyMethodFilter(tcsService1));
|
||||||
|
options.AddFilter<VerifyMethodFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService2);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await tcsService1.StartedMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.StartedMethod.Task.OrTimeout();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
tcsService1.Reset();
|
||||||
|
tcsService2.Reset();
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
tcsService1.Reset();
|
||||||
|
tcsService2.Reset();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task MixingTypeAndInstanceHubSpecificFilters_MethodsAreCalled()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService1 = new TcsService();
|
||||||
|
var tcsService2 = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR()
|
||||||
|
.AddHubOptions<MethodHub>(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new VerifyMethodFilter(tcsService1));
|
||||||
|
options.AddFilter<VerifyMethodFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService2);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await tcsService1.StartedMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.StartedMethod.Task.OrTimeout();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
|
||||||
|
tcsService1.Reset();
|
||||||
|
tcsService2.Reset();
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
tcsService1.Reset();
|
||||||
|
tcsService2.Reset();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
await tcsService1.EndMethod.Task.OrTimeout();
|
||||||
|
await tcsService2.EndMethod.Task.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task GlobalFiltersRunInOrder()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var syncPoint1 = SyncPoint.Create(3, out var syncPoints1);
|
||||||
|
var syncPoint2 = SyncPoint.Create(3, out var syncPoints2);
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new SyncPointFilter(syncPoints1));
|
||||||
|
options.AddFilter(new SyncPointFilter(syncPoints2));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await syncPoints1[0].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[0].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[0].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[0].Continue();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
|
||||||
|
var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!");
|
||||||
|
|
||||||
|
await syncPoints1[1].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[1].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[1].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[1].Continue();
|
||||||
|
var message = await invokeTask.OrTimeout();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await syncPoints1[2].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[2].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[2].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[2].Continue();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task HubSpecificFiltersRunInOrder()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var syncPoint1 = SyncPoint.Create(3, out var syncPoints1);
|
||||||
|
var syncPoint2 = SyncPoint.Create(3, out var syncPoints2);
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR()
|
||||||
|
.AddHubOptions<MethodHub>(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new SyncPointFilter(syncPoints1));
|
||||||
|
options.AddFilter(new SyncPointFilter(syncPoints2));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await syncPoints1[0].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[0].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[0].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[0].Continue();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
|
||||||
|
var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!");
|
||||||
|
|
||||||
|
await syncPoints1[1].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[1].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[1].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[1].Continue();
|
||||||
|
var message = await invokeTask.OrTimeout();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await syncPoints1[2].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[2].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[2].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[2].Continue();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task GlobalFiltersRunBeforeHubSpecificFilters()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var syncPoint1 = SyncPoint.Create(3, out var syncPoints1);
|
||||||
|
var syncPoint2 = SyncPoint.Create(3, out var syncPoints2);
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new SyncPointFilter(syncPoints1));
|
||||||
|
})
|
||||||
|
.AddHubOptions<MethodHub>(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new SyncPointFilter(syncPoints2));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await syncPoints1[0].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[0].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[0].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[0].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[0].Continue();
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
|
||||||
|
var invokeTask = client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!");
|
||||||
|
|
||||||
|
await syncPoints1[1].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[1].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[1].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[1].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[1].Continue();
|
||||||
|
var message = await invokeTask.OrTimeout();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await syncPoints1[2].WaitForSyncPoint().OrTimeout();
|
||||||
|
// Second filter wont run yet because first filter is waiting on SyncPoint
|
||||||
|
Assert.False(syncPoints2[2].WaitForSyncPoint().IsCompleted);
|
||||||
|
syncPoints1[2].Continue();
|
||||||
|
|
||||||
|
await syncPoints2[2].WaitForSyncPoint().OrTimeout();
|
||||||
|
syncPoints2[2].Continue();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FilterCanBeResolvedFromDI()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter<VerifyMethodFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
// If this instance wasn't resolved, then the tcsService.StartedMethod waits would never trigger and fail the test
|
||||||
|
services.AddSingleton(new VerifyMethodFilter(tcsService));
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
await AssertMethodsCalled(serviceProvider, tcsService);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FiltersHaveTransientScopeByDefault()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var counter = new FilterCounter();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter<CounterFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(counter);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
// Filter is transient, so these counts are reset every time the filter is created
|
||||||
|
Assert.Equal(1, counter.OnConnectedAsyncCount);
|
||||||
|
Assert.Equal(0, counter.InvokeMethodAsyncCount);
|
||||||
|
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
// Filter is transient, so these counts are reset every time the filter is created
|
||||||
|
Assert.Equal(0, counter.OnConnectedAsyncCount);
|
||||||
|
Assert.Equal(1, counter.InvokeMethodAsyncCount);
|
||||||
|
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
// Filter is transient, so these counts are reset every time the filter is created
|
||||||
|
Assert.Equal(0, counter.OnConnectedAsyncCount);
|
||||||
|
Assert.Equal(0, counter.InvokeMethodAsyncCount);
|
||||||
|
Assert.Equal(1, counter.OnDisconnectedAsyncCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FiltersCanBeSingletonIfAddedToDI()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var counter = new FilterCounter();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter<CounterFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton<CounterFilter>();
|
||||||
|
services.AddSingleton(counter);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
Assert.Equal(1, counter.OnConnectedAsyncCount);
|
||||||
|
Assert.Equal(0, counter.InvokeMethodAsyncCount);
|
||||||
|
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
Assert.Equal(1, counter.OnConnectedAsyncCount);
|
||||||
|
Assert.Equal(1, counter.InvokeMethodAsyncCount);
|
||||||
|
Assert.Equal(0, counter.OnDisconnectedAsyncCount);
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
Assert.Equal(1, counter.OnConnectedAsyncCount);
|
||||||
|
Assert.Equal(1, counter.InvokeMethodAsyncCount);
|
||||||
|
Assert.Equal(1, counter.OnDisconnectedAsyncCount);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ConnectionContinuesIfOnConnectedAsyncThrowsAndFilterDoesNot()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter<NoExceptionFilter>();
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<OnConnectedThrowsHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
// Verify connection still connected, can't invoke a method if the connection is disconnected
|
||||||
|
var message = await client.InvokeAsync("Method");
|
||||||
|
Assert.Equal("Failed to invoke 'Method' due to an error on the server. HubException: Method does not exist.", message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task ConnectionContinuesIfOnConnectedAsyncNotCalledByFilter()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter(new SkipNextFilter(skipOnConnected: true));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
// Verify connection still connected, can't invoke a method if the connection is disconnected
|
||||||
|
var message = await client.InvokeAsync("Method");
|
||||||
|
Assert.Equal("Failed to invoke 'Method' due to an error on the server. HubException: Method does not exist.", message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FilterCanSkipCallingHubMethod()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.AddFilter(new SkipNextFilter(skipInvoke: true));
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
await client.Connected.OrTimeout();
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync(nameof(MethodHub.Echo), "Hello world!").OrTimeout();
|
||||||
|
|
||||||
|
Assert.Null(message.Error);
|
||||||
|
Assert.Null(message.Result);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FiltersWithIDisposableAreDisposed()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter<DisposableFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
// OnConnectedAsync creates and destroys the filter
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
tcsService.Reset();
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync("Echo", "Hello");
|
||||||
|
Assert.Equal("Hello", message.Result);
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
tcsService.Reset();
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
// OnDisconnectedAsync creates and destroys the filter
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InstanceFiltersWithIDisposableAreNotDisposed()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter(new DisposableFilter(tcsService));
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync("Echo", "Hello");
|
||||||
|
Assert.Equal("Hello", message.Result);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
Assert.False(tcsService.StartedMethod.Task.IsCompleted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task FiltersWithIAsyncDisposableAreDisposed()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter<AsyncDisposableFilter>();
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
// OnConnectedAsync creates and destroys the filter
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
tcsService.Reset();
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync("Echo", "Hello");
|
||||||
|
Assert.Equal("Hello", message.Result);
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
tcsService.Reset();
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
// OnDisconnectedAsync creates and destroys the filter
|
||||||
|
await tcsService.StartedMethod.Task.OrTimeout();
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InstanceFiltersWithIAsyncDisposableAreNotDisposed()
|
||||||
|
{
|
||||||
|
using (StartVerifiableLog())
|
||||||
|
{
|
||||||
|
var tcsService = new TcsService();
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter(new AsyncDisposableFilter(tcsService));
|
||||||
|
});
|
||||||
|
|
||||||
|
services.AddSingleton(tcsService);
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync("Echo", "Hello");
|
||||||
|
Assert.Equal("Hello", message.Result);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
|
||||||
|
Assert.False(tcsService.StartedMethod.Task.IsCompleted);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeFailsWhenFilterCallsNonExistantMethod()
|
||||||
|
{
|
||||||
|
bool ExpectedErrors(WriteContext writeContext)
|
||||||
|
{
|
||||||
|
return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
|
||||||
|
writeContext.EventId.Name == "FailedInvokingHubMethod";
|
||||||
|
}
|
||||||
|
|
||||||
|
using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors))
|
||||||
|
{
|
||||||
|
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||||
|
{
|
||||||
|
services.AddSignalR(options =>
|
||||||
|
{
|
||||||
|
options.EnableDetailedErrors = true;
|
||||||
|
options.AddFilter<ChangeMethodFilter>();
|
||||||
|
});
|
||||||
|
}, LoggerFactory);
|
||||||
|
|
||||||
|
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||||
|
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||||
|
|
||||||
|
var message = await client.InvokeAsync("Echo", "Hello");
|
||||||
|
Assert.Equal("An unexpected error occurred invoking 'Echo' on the server. HubException: Unknown hub method 'BaseMethod'", message.Error);
|
||||||
|
|
||||||
|
client.Dispose();
|
||||||
|
|
||||||
|
await connectionHandlerTask.OrTimeout();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,236 @@
|
||||||
|
// Copyright (c) .NET Foundation. All rights reserved.
|
||||||
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||||
|
|
||||||
|
using System;
|
||||||
|
using System.Diagnostics;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.AspNetCore.Internal;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
{
|
||||||
|
public class VerifyMethodFilter : IHubFilter
|
||||||
|
{
|
||||||
|
private readonly TcsService _service;
|
||||||
|
public VerifyMethodFilter(TcsService tcsService)
|
||||||
|
{
|
||||||
|
_service = tcsService;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
|
||||||
|
{
|
||||||
|
_service.StartedMethod.TrySetResult(null);
|
||||||
|
await next(context);
|
||||||
|
_service.EndMethod.TrySetResult(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
_service.StartedMethod.TrySetResult(null);
|
||||||
|
var result = await next(invocationContext);
|
||||||
|
_service.EndMethod.TrySetResult(null);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
|
||||||
|
{
|
||||||
|
_service.StartedMethod.TrySetResult(null);
|
||||||
|
await next(context, exception);
|
||||||
|
_service.EndMethod.TrySetResult(null);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class SyncPointFilter : IHubFilter
|
||||||
|
{
|
||||||
|
private readonly SyncPoint[] _syncPoint;
|
||||||
|
public SyncPointFilter(SyncPoint[] syncPoints)
|
||||||
|
{
|
||||||
|
Debug.Assert(syncPoints.Length == 3);
|
||||||
|
_syncPoint = syncPoints;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
|
||||||
|
{
|
||||||
|
await _syncPoint[0].WaitToContinue();
|
||||||
|
await next(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
await _syncPoint[1].WaitToContinue();
|
||||||
|
var result = await next(invocationContext);
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
|
||||||
|
{
|
||||||
|
await _syncPoint[2].WaitToContinue();
|
||||||
|
await next(context, exception);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class FilterCounter
|
||||||
|
{
|
||||||
|
public int OnConnectedAsyncCount;
|
||||||
|
public int InvokeMethodAsyncCount;
|
||||||
|
public int OnDisconnectedAsyncCount;
|
||||||
|
}
|
||||||
|
|
||||||
|
public class CounterFilter : IHubFilter
|
||||||
|
{
|
||||||
|
private readonly FilterCounter _counter;
|
||||||
|
public CounterFilter(FilterCounter counter)
|
||||||
|
{
|
||||||
|
_counter = counter;
|
||||||
|
_counter.OnConnectedAsyncCount = 0;
|
||||||
|
_counter.InvokeMethodAsyncCount = 0;
|
||||||
|
_counter.OnDisconnectedAsyncCount = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
|
||||||
|
{
|
||||||
|
_counter.OnConnectedAsyncCount++;
|
||||||
|
return next(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
|
||||||
|
{
|
||||||
|
_counter.OnDisconnectedAsyncCount++;
|
||||||
|
return next(context, exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
_counter.InvokeMethodAsyncCount++;
|
||||||
|
return next(invocationContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class NoExceptionFilter : IHubFilter
|
||||||
|
{
|
||||||
|
public async Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
await next(context);
|
||||||
|
}
|
||||||
|
catch { }
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
await next(context, exception);
|
||||||
|
}
|
||||||
|
catch { }
|
||||||
|
}
|
||||||
|
|
||||||
|
public async ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
try
|
||||||
|
{
|
||||||
|
return await next(invocationContext);
|
||||||
|
}
|
||||||
|
catch { }
|
||||||
|
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class SkipNextFilter : IHubFilter
|
||||||
|
{
|
||||||
|
private readonly bool _skipOnConnected;
|
||||||
|
private readonly bool _skipInvoke;
|
||||||
|
private readonly bool _skipOnDisconnected;
|
||||||
|
|
||||||
|
public SkipNextFilter(bool skipOnConnected = false, bool skipInvoke = false, bool skipOnDisconnected = false)
|
||||||
|
{
|
||||||
|
_skipOnConnected = skipOnConnected;
|
||||||
|
_skipInvoke = skipInvoke;
|
||||||
|
_skipOnDisconnected = skipOnDisconnected;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Task OnConnectedAsync(HubLifetimeContext context, Func<HubLifetimeContext, Task> next)
|
||||||
|
{
|
||||||
|
if (_skipOnConnected)
|
||||||
|
{
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
|
||||||
|
return next(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
public Task OnDisconnectedAsync(HubLifetimeContext context, Exception exception, Func<HubLifetimeContext, Exception, Task> next)
|
||||||
|
{
|
||||||
|
if (_skipOnDisconnected)
|
||||||
|
{
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
|
||||||
|
return next(context, exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
if (_skipInvoke)
|
||||||
|
{
|
||||||
|
return new ValueTask<object>();
|
||||||
|
}
|
||||||
|
|
||||||
|
return next(invocationContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class DisposableFilter : IHubFilter, IDisposable
|
||||||
|
{
|
||||||
|
private readonly TcsService _tcsService;
|
||||||
|
|
||||||
|
public DisposableFilter(TcsService tcsService)
|
||||||
|
{
|
||||||
|
_tcsService = tcsService;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void Dispose()
|
||||||
|
{
|
||||||
|
_tcsService.StartedMethod.SetResult(null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
return next(invocationContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class AsyncDisposableFilter : IHubFilter, IAsyncDisposable
|
||||||
|
{
|
||||||
|
private readonly TcsService _tcsService;
|
||||||
|
|
||||||
|
public AsyncDisposableFilter(TcsService tcsService)
|
||||||
|
{
|
||||||
|
_tcsService = tcsService;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ValueTask DisposeAsync()
|
||||||
|
{
|
||||||
|
_tcsService.StartedMethod.SetResult(null);
|
||||||
|
return default;
|
||||||
|
}
|
||||||
|
|
||||||
|
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
return next(invocationContext);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public class ChangeMethodFilter : IHubFilter
|
||||||
|
{
|
||||||
|
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
|
||||||
|
{
|
||||||
|
var methodInfo = typeof(BaseHub).GetMethod(nameof(BaseHub.BaseMethod));
|
||||||
|
var context = new HubInvocationContext(invocationContext.Context, invocationContext.ServiceProvider, invocationContext.Hub, methodInfo, invocationContext.HubMethodArguments);
|
||||||
|
return next(context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue