[SignalR] Pass a resource into IPolicyEvaluator for Hub method auth (#11070)

This commit is contained in:
Brennan 2019-06-21 14:08:50 -07:00 committed by GitHub
parent f9aa85a829
commit 5b31a9540a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 113 additions and 6 deletions

View File

@ -172,6 +172,13 @@ namespace Microsoft.AspNetCore.SignalR
public void Reset() { }
}
}
public partial class HubInvocationContext
{
public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { }
public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
public System.Collections.Generic.IReadOnlyList<object> HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
}
public abstract partial class HubLifetimeManager<THub> where THub : Microsoft.AspNetCore.SignalR.Hub
{
protected HubLifetimeManager() { }

View File

@ -59,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR
_connectionContext = connectionContext;
_logger = loggerFactory.CreateLogger<HubConnectionContext>();
ConnectionAborted = _connectionAbortedTokenSource.Token;
HubCallerContext = new DefaultHubCallerContext(this);
}
internal StreamTracker StreamTracker
@ -75,6 +77,8 @@ namespace Microsoft.AspNetCore.SignalR
}
}
internal HubCallerContext HubCallerContext { get; }
/// <summary>
/// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted.
/// </summary>

View File

@ -0,0 +1,21 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public class HubInvocationContext
{
public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments)
{
HubMethodName = hubMethodName;
HubMethodArguments = hubMethodArguments;
Context = context;
}
public HubCallerContext Context { get; }
public string HubMethodName { get; }
public IReadOnlyList<object> HubMethodArguments { get; }
}
}

View File

@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
THub hub = null;
try
{
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments))
{
Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
@ -479,11 +479,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
private void InitializeHub(THub hub, HubConnectionContext connection)
{
hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
hub.Context = new DefaultHubCallerContext(connection);
hub.Context = connection.HubCallerContext;
hub.Groups = _hubContext.Groups;
}
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList<IAuthorizeData> policies, string hubMethodName, object[] hubMethodArguments)
{
// If there are no policies we don't need to run auth
if (!policies.Any())
@ -491,10 +491,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return TaskCache.True;
}
return IsHubMethodAuthorizedSlow(provider, principal, policies);
return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, hubMethodName, hubMethodArguments));
}
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies, HubInvocationContext resource)
{
var authService = provider.GetRequiredService<IAuthorizationService>();
var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();
@ -503,7 +503,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
// AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
Debug.Assert(authorizePolicy != null);
var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
var authorizationResult = await authService.AuthorizeAsync(principal, resource, authorizePolicy);
// Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
return authorizationResult.Succeeded;
}

View File

@ -143,6 +143,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
}
[Authorize("test")]
public void MultiParamAuthMethod(string s1, string s2)
{
}
public Task SendToAllExcept(string message, IReadOnlyList<string> excludedConnectionIds)
{
return Clients.AllExcept(excludedConnectionIds).SendAsync("Send", message);

View File

@ -12,9 +12,11 @@ using System.Text;
using System.Threading.Tasks;
using MessagePack;
using MessagePack.Formatters;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections.Features;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
@ -2198,6 +2200,69 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
private class TestAuthHandler : IAuthorizationHandler
{
public Task HandleAsync(AuthorizationHandlerContext context)
{
Assert.NotNull(context.Resource);
var resource = Assert.IsType<HubInvocationContext>(context.Resource);
Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName);
Assert.Equal(2, resource.HubMethodArguments?.Count);
Assert.Equal("Hello", resource.HubMethodArguments[0]);
Assert.Equal("World!", resource.HubMethodArguments[1]);
Assert.NotNull(resource.Context);
Assert.Equal(context.User, resource.Context.User);
Assert.NotNull(resource.Context.GetHttpContext());
return Task.CompletedTask;
}
}
[Fact]
public async Task HubMethodWithAuthorizationProvidesResourceToAuthHandlers()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddSingleton<IAuthorizationHandler, TestAuthHandler>();
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// Setup a HttpContext to make sure it flows to the AuthHandler correctly
var httpConnectionContext = new HttpContextFeatureImpl();
httpConnectionContext.HttpContext = new DefaultHttpContext();
client.Connection.Features.Set<IHttpContextFeature>(httpConnectionContext);
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
var message = await client.InvokeAsync(nameof(MethodHub.MultiParamAuthMethod), "Hello", "World!").OrTimeout();
Assert.Null(message.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task HubOptionsCanUseCustomJsonSerializerSettings()
{
@ -3632,5 +3697,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public int Bar { get; }
public string Foo { get; }
}
private class HttpContextFeatureImpl : IHttpContextFeature
{
public HttpContext HttpContext { get; set; }
}
}
}