diff --git a/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs b/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs index 14087c25fe..6e2f3f1f71 100644 --- a/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs +++ b/src/SignalR/server/Core/ref/Microsoft.AspNetCore.SignalR.Core.netcoreapp3.0.cs @@ -172,6 +172,13 @@ namespace Microsoft.AspNetCore.SignalR public void Reset() { } } } + public partial class HubInvocationContext + { + public HubInvocationContext(Microsoft.AspNetCore.SignalR.HubCallerContext context, string hubMethodName, object[] hubMethodArguments) { } + public Microsoft.AspNetCore.SignalR.HubCallerContext Context { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + public System.Collections.Generic.IReadOnlyList HubMethodArguments { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + public string HubMethodName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + } public abstract partial class HubLifetimeManager where THub : Microsoft.AspNetCore.SignalR.Hub { protected HubLifetimeManager() { } diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 16ca9313bf..22400fa702 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -59,6 +59,8 @@ namespace Microsoft.AspNetCore.SignalR _connectionContext = connectionContext; _logger = loggerFactory.CreateLogger(); ConnectionAborted = _connectionAbortedTokenSource.Token; + + HubCallerContext = new DefaultHubCallerContext(this); } internal StreamTracker StreamTracker @@ -75,6 +77,8 @@ namespace Microsoft.AspNetCore.SignalR } } + internal HubCallerContext HubCallerContext { get; } + /// /// Gets a that notifies when the connection is aborted. /// diff --git a/src/SignalR/server/Core/src/HubInvocationContext.cs b/src/SignalR/server/Core/src/HubInvocationContext.cs new file mode 100644 index 0000000000..9cda75da11 --- /dev/null +++ b/src/SignalR/server/Core/src/HubInvocationContext.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubInvocationContext + { + public HubInvocationContext(HubCallerContext context, string hubMethodName, object[] hubMethodArguments) + { + HubMethodName = hubMethodName; + HubMethodArguments = hubMethodArguments; + Context = context; + } + + public HubCallerContext Context { get; } + public string HubMethodName { get; } + public IReadOnlyList HubMethodArguments { get; } + } +} diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 86fe4b32ac..1a21ad47f1 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal THub hub = null; try { - if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies)) + if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection, descriptor.Policies, descriptor.MethodExecutor.MethodInfo.Name, hubMethodInvocationMessage.Arguments)) { Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, @@ -479,11 +479,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal private void InitializeHub(THub hub, HubConnectionContext connection) { hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId); - hub.Context = new DefaultHubCallerContext(connection); + hub.Context = connection.HubCallerContext; hub.Groups = _hubContext.Groups; } - private Task IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList policies) + private Task IsHubMethodAuthorized(IServiceProvider provider, HubConnectionContext hubConnectionContext, IList policies, string hubMethodName, object[] hubMethodArguments) { // If there are no policies we don't need to run auth if (!policies.Any()) @@ -491,10 +491,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal return TaskCache.True; } - return IsHubMethodAuthorizedSlow(provider, principal, policies); + return IsHubMethodAuthorizedSlow(provider, hubConnectionContext.User, policies, new HubInvocationContext(hubConnectionContext.HubCallerContext, hubMethodName, hubMethodArguments)); } - private static async Task IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList policies) + private static async Task IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList policies, HubInvocationContext resource) { var authService = provider.GetRequiredService(); var policyProvider = provider.GetRequiredService(); @@ -503,7 +503,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal // AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above Debug.Assert(authorizePolicy != null); - var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy); + var authorizationResult = await authService.AuthorizeAsync(principal, resource, authorizePolicy); // Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation return authorizationResult.Succeeded; } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index f99cb6c76e..504c2cc751 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -143,6 +143,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { } + [Authorize("test")] + public void MultiParamAuthMethod(string s1, string s2) + { + } + public Task SendToAllExcept(string message, IReadOnlyList excludedConnectionIds) { return Clients.AllExcept(excludedConnectionIds).SendAsync("Send", message); diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 56b77c0b7a..1c047973a3 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -12,9 +12,11 @@ using System.Text; using System.Threading.Tasks; using MessagePack; using MessagePack.Formatters; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Connections.Features; +using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; @@ -2198,6 +2200,69 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + private class TestAuthHandler : IAuthorizationHandler + { + public Task HandleAsync(AuthorizationHandlerContext context) + { + Assert.NotNull(context.Resource); + var resource = Assert.IsType(context.Resource); + Assert.Equal(nameof(MethodHub.MultiParamAuthMethod), resource.HubMethodName); + Assert.Equal(2, resource.HubMethodArguments?.Count); + Assert.Equal("Hello", resource.HubMethodArguments[0]); + Assert.Equal("World!", resource.HubMethodArguments[1]); + Assert.NotNull(resource.Context); + Assert.Equal(context.User, resource.Context.User); + Assert.NotNull(resource.Context.GetHttpContext()); + + return Task.CompletedTask; + } + } + + [Fact] + public async Task HubMethodWithAuthorizationProvidesResourceToAuthHandlers() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddAuthorization(options => + { + options.AddPolicy("test", policy => + { + policy.RequireClaim(ClaimTypes.NameIdentifier); + policy.AddAuthenticationSchemes("Default"); + }); + }); + + services.AddSingleton(); + }, LoggerFactory); + + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); + + // Setup a HttpContext to make sure it flows to the AuthHandler correctly + var httpConnectionContext = new HttpContextFeatureImpl(); + httpConnectionContext.HttpContext = new DefaultHttpContext(); + client.Connection.Features.Set(httpConnectionContext); + + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + + var message = await client.InvokeAsync(nameof(MethodHub.MultiParamAuthMethod), "Hello", "World!").OrTimeout(); + + Assert.Null(message.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + [Fact] public async Task HubOptionsCanUseCustomJsonSerializerSettings() { @@ -3632,5 +3697,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests public int Bar { get; } public string Foo { get; } } + + private class HttpContextFeatureImpl : IHttpContextFeature + { + public HttpContext HttpContext { get; set; } + } } }