[SignalR] Pass a resource into IPolicyEvaluator for Hub method auth (#11070)
This commit is contained in:
parent
f9aa85a829
commit
5b31a9540a
|
|
@ -172,6 +172,13 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public void Reset() { }
|
||||
}
|
||||
}
|
||||
public partial class HubInvocationContext
|
||||
{
|
||||
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { }
|
||||
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
|
||||
public System.Collections.Generic.IReadOnlyList<object> HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
|
||||
public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
|
||||
}
|
||||
public abstract partial class HubLifetimeManager<THub> where THub : Microsoft.AspNetCore.SignalR.Hub
|
||||
{
|
||||
protected HubLifetimeManager() { }
|
||||
|
|
|
|||
|
|
@ -59,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
_connectionContext = connectionContext;
|
||||
_logger = loggerFactory.CreateLogger<HubConnectionContext>();
|
||||
ConnectionAborted = _connectionAbortedTokenSource.Token;
|
||||
|
||||
HubCallerContext = new DefaultHubCallerContext(this);
|
||||
}
|
||||
|
||||
internal StreamTracker StreamTracker
|
||||
|
|
@ -75,6 +77,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
internal HubCallerContext HubCallerContext { get; }
|
||||
|
||||
/// <summary>
|
||||
/// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted.
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class HubInvocationContext
|
||||
{
|
||||
public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments)
|
||||
{
|
||||
HubMethodName = hubMethodName;
|
||||
HubMethodArguments = hubMethodArguments;
|
||||
Context = context;
|
||||
}
|
||||
|
||||
public HubCallerContext Context { get; }
|
||||
public string HubMethodName { get; }
|
||||
public IReadOnlyList<object> HubMethodArguments { get; }
|
||||
}
|
||||
}
|
||||
|
|
@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
THub hub = null;
|
||||
try
|
||||
{
|
||||
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
|
||||
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments))
|
||||
{
|
||||
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
|
||||
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
|
||||
|
|
@ -479,11 +479,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
private void InitializeHub(THub hub, HubConnectionContext connection)
|
||||
{
|
||||
hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
|
||||
hub.Context = new DefaultHubCallerContext(connection);
|
||||
hub.Context = connection.HubCallerContext;
|
||||
hub.Groups = _hubContext.Groups;
|
||||
}
|
||||
|
||||
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
|
||||
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList<IAuthorizeData> policies, string hubMethodName, object[] hubMethodArguments)
|
||||
{
|
||||
// If there are no policies we don't need to run auth
|
||||
if (!policies.Any())
|
||||
|
|
@ -491,10 +491,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
return TaskCache.True;
|
||||
}
|
||||
|
||||
return IsHubMethodAuthorizedSlow(provider, principal, policies);
|
||||
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, hubMethodName, hubMethodArguments));
|
||||
}
|
||||
|
||||
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
|
||||
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)
|
||||
{
|
||||
var authService = provider.GetRequiredService<IAuthorizationService>();
|
||||
var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();
|
||||
|
|
@ -503,7 +503,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
|
|||
// AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
|
||||
Debug.Assert(authorizePolicy != null);
|
||||
|
||||
var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
|
||||
var authorizationResult = await authService.AuthorizeAsync(principal, resource, authorizePolicy);
|
||||
// Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
|
||||
return authorizationResult.Succeeded;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -143,6 +143,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
}
|
||||
|
||||
[Authorize("test")]
|
||||
public void MultiParamAuthMethod(string s1, string s2)
|
||||
{
|
||||
}
|
||||
|
||||
public Task SendToAllExcept(string message, IReadOnlyList<string> excludedConnectionIds)
|
||||
{
|
||||
return Clients.AllExcept(excludedConnectionIds).SendAsync("Send", message);
|
||||
|
|
|
|||
|
|
@ -12,9 +12,11 @@ using System.Text;
|
|||
using System.Threading.Tasks;
|
||||
using MessagePack;
|
||||
using MessagePack.Formatters;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.Connections;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Connections.Features;
|
||||
using Microsoft.AspNetCore.Http.Connections.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
|
|
@ -2198,6 +2200,69 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
private class TestAuthHandler : IAuthorizationHandler
|
||||
{
|
||||
public Task HandleAsync(AuthorizationHandlerContext context)
|
||||
{
|
||||
Assert.NotNull(context.Resource);
|
||||
var resource = Assert.IsType<HubInvocationContext>(context.Resource);
|
||||
Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName);
|
||||
Assert.Equal(2, resource.HubMethodArguments?.Count);
|
||||
Assert.Equal("Hello", resource.HubMethodArguments[0]);
|
||||
Assert.Equal("World!", resource.HubMethodArguments[1]);
|
||||
Assert.NotNull(resource.Context);
|
||||
Assert.Equal(context.User, resource.Context.User);
|
||||
Assert.NotNull(resource.Context.GetHttpContext());
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HubMethodWithAuthorizationProvidesResourceToAuthHandlers()
|
||||
{
|
||||
using (StartVerifiableLog())
|
||||
{
|
||||
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
|
||||
{
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
|
||||
services.AddSingleton<IAuthorizationHandler, TestAuthHandler>();
|
||||
}, LoggerFactory);
|
||||
|
||||
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
// Setup a HttpContext to make sure it flows to the AuthHandler correctly
|
||||
var httpConnectionContext = new HttpContextFeatureImpl();
|
||||
httpConnectionContext.HttpContext = new DefaultHttpContext();
|
||||
client.Connection.Features.Set<IHttpContextFeature>(httpConnectionContext);
|
||||
|
||||
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
|
||||
|
||||
await client.Connected.OrTimeout();
|
||||
|
||||
var message = await client.InvokeAsync(nameof(MethodHub.MultiParamAuthMethod), "Hello", "World!").OrTimeout();
|
||||
|
||||
Assert.Null(message.Error);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await connectionHandlerTask.OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HubOptionsCanUseCustomJsonSerializerSettings()
|
||||
{
|
||||
|
|
@ -3632,5 +3697,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
public int Bar { get; }
|
||||
public string Foo { get; }
|
||||
}
|
||||
|
||||
private class HttpContextFeatureImpl : IHttpContextFeature
|
||||
{
|
||||
public HttpContext HttpContext { get; set; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue