From ef273b47969f24ed9b11eb251ad3de7a0b249a2e Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Fri, 23 Jun 2017 10:22:05 -0700 Subject: [PATCH] Add authorization per hub method (#577) --- samples/SocketsSample/wwwroot/hubs.html | 37 +++++---- .../HubRouteBuilder.cs | 6 +- .../HubEndPoint.cs | 38 +++++++++- .../Microsoft.AspNetCore.SignalR.csproj | 1 + .../SignalRDependencyInjectionExtensions.cs | 2 + .../Microsoft.AspNetCore.Sockets.Http.csproj | 2 +- .../HubEndpointTests.cs | 76 ++++++++++++++++++- .../MapSignalRTests.cs | 27 +++++++ .../TestClient.cs | 20 ++--- .../HttpConnectionDispatcherTests.cs | 1 - 10 files changed, 178 insertions(+), 32 deletions(-) diff --git a/samples/SocketsSample/wwwroot/hubs.html b/samples/SocketsSample/wwwroot/hubs.html index 8049de291a..c7fcc7e4e8 100644 --- a/samples/SocketsSample/wwwroot/hubs.html +++ b/samples/SocketsSample/wwwroot/hubs.html @@ -81,22 +81,29 @@ let transportType = signalR.TransportType[getParameterByName('transport')] || si document.getElementById('head1').innerHTML = signalR.TransportType[transportType]; -let http = new signalR.HttpConnection(`http://${document.location.host}/hubs`, { transport: transportType }); -let connection = new signalR.HubConnection(http); -connection.on('Send', msg => { - addLine('message-list', msg); -}); - -connection.onClosed = e => { - if (e) { - addLine('message-list', 'Connection closed with error: ' + e, 'red'); - } - else { - addLine('message-list', 'Disconnected', 'green'); - } -} +let connectButton = document.getElementById('connect'); +let disconnectButton = document.getElementById('disconnect'); +disconnectButton.disabled = true; +var connection; click('connect', event => { + connectButton.disabled = true; + disconnectButton.disabled = false; + let http = new signalR.HttpConnection(`http://${document.location.host}/hubs`, { transport: transportType }); + connection = new signalR.HubConnection(http); + connection.on('Send', msg => { + addLine('message-list', msg); + }); + + connection.onClosed = e => { + if (e) { + addLine('message-list', 'Connection closed with error: ' + e, 'red'); + } + else { + addLine('message-list', 'Disconnected', 'green'); + } + } + connection.start() .then(() => { isConnected = true; @@ -108,6 +115,8 @@ click('connect', event => { }); click('disconnect', event => { + connectButton.disabled = false; + disconnectButton.disabled = true; connection.stop() .then(() => { isConnected = false; diff --git a/src/Microsoft.AspNetCore.SignalR.Http/HubRouteBuilder.cs b/src/Microsoft.AspNetCore.SignalR.Http/HubRouteBuilder.cs index 7e720b2fbd..d50dd2b17f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Http/HubRouteBuilder.cs +++ b/src/Microsoft.AspNetCore.SignalR.Http/HubRouteBuilder.cs @@ -25,11 +25,11 @@ namespace Microsoft.AspNetCore.SignalR public void MapHub(string path, Action socketOptions) where THub : Hub { // find auth attributes - var authorizeAttribute = typeof(THub).GetCustomAttribute(inherit: true); + var authorizeAttributes = typeof(THub).GetCustomAttributes(inherit: true); var options = new HttpSocketOptions(); - if (authorizeAttribute != null) + foreach (var attribute in authorizeAttributes) { - options.AuthorizationData.Add(authorizeAttribute); + options.AuthorizationData.Add(attribute); } socketOptions?.Invoke(options); diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 696dc7408f..b865abdc5c 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -3,11 +3,14 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; using System.Reflection; +using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; @@ -251,6 +254,13 @@ namespace Microsoft.AspNetCore.SignalR using (var scope = _serviceScopeFactory.CreateScope()) { + if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies)) + { + _logger.LogDebug("Failed to invoke {hubMethod} because user is unauthorized", invocationMessage.Target); + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized")); + return; + } + var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); @@ -395,7 +405,8 @@ namespace Microsoft.AspNetCore.SignalR } var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo); - _methods[methodName] = new HubMethodDescriptor(executor); + var authorizeAttributes = methodInfo.GetCustomAttributes(inherit: true); + _methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes); if (_logger.IsEnabled(LogLevel.Debug)) { @@ -404,6 +415,26 @@ namespace Microsoft.AspNetCore.SignalR } } + private async Task IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList policies) + { + // If there are no policies we don't need to run auth + if (!policies.Any()) + { + return true; + } + + var authService = provider.GetRequiredService(); + var policyProvider = provider.GetRequiredService(); + + var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies); + // AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above + Debug.Assert(authorizePolicy != null); + + var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy); + // Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation + return authorizationResult.Succeeded; + } + Type IInvocationBinder.GetReturnType(string invocationId) { return typeof(object); @@ -422,15 +453,18 @@ namespace Microsoft.AspNetCore.SignalR // REVIEW: We can decide to move this out of here if we want pluggable hub discovery private class HubMethodDescriptor { - public HubMethodDescriptor(ObjectMethodExecutor methodExecutor) + public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { MethodExecutor = methodExecutor; ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray(); + Policies = policies.ToArray(); } public ObjectMethodExecutor MethodExecutor { get; } public Type[] ParameterTypes { get; } + + public IList Policies { get; } } } } diff --git a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj index 1a5d416215..64ac863d88 100644 --- a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj +++ b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj @@ -14,6 +14,7 @@ + diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs index 8fe8166b85..5961e221fe 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs @@ -16,6 +16,8 @@ namespace Microsoft.Extensions.DependencyInjection services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>)); services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>)); + services.AddAuthorization(); + return new SignalRBuilder(services); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj index 560672734b..085abf658f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj @@ -13,7 +13,7 @@ - + diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index c7a43e1030..dccf6de789 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -2,9 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Security.Claims; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; @@ -597,7 +599,74 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); - await endPointLifetime; + await endPointLifetime.OrTimeout(); + } + } + + [Fact] + public async Task UnauthorizedConnectionCannotInvokeHubMethodWithAuthorization() + { + var serviceProvider = CreateServiceProvider(services => + { + services.AddAuthorization(options => + { + options.AddPolicy("test", policy => + { + policy.RequireClaim(ClaimTypes.NameIdentifier); + policy.AddAuthenticationSchemes("Default"); + }); + }); + }); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); + + await client.Connected.OrTimeout(); + + var message = await client.InvokeAsync(nameof(MethodHub.AuthMethod)).OrTimeout(); + + Assert.NotNull(message.Error); + + client.Dispose(); + + await endPointLifetime.OrTimeout(); + } + } + + [Fact] + public async Task AuthorizedConnectionCanInvokeHubMethodWithAuthorization() + { + var serviceProvider = CreateServiceProvider(services => + { + services.AddAuthorization(options => + { + options.AddPolicy("test", policy => + { + policy.RequireClaim(ClaimTypes.NameIdentifier); + policy.AddAuthenticationSchemes("Default"); + }); + }); + }); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); + + await client.Connected.OrTimeout(); + + var message = await client.InvokeAsync(nameof(MethodHub.AuthMethod)).OrTimeout(); + + Assert.Null(message.Error); + + client.Dispose(); + + await endPointLifetime.OrTimeout(); } } @@ -793,6 +862,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests public static void StaticMethod() { } + + [Authorize("test")] + public void AuthMethod() + { + } } private class InheritedHub : BaseHub diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/MapSignalRTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/MapSignalRTests.cs index 0b097e0101..559eeb477a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/MapSignalRTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/MapSignalRTests.cs @@ -74,6 +74,28 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(1, authCount); } + [Fact] + public void MapHubFindsMultipleAuthAttributesOnDoubleAuthHub() + { + var authCount = 0; + var builder = new WebHostBuilder() + .UseKestrel() + .ConfigureServices(services => + { + services.AddSignalR(); + }) + .Configure(app => + { + app.UseSignalR(options => options.MapHub("path", httpSocketOptions => + { + authCount += httpSocketOptions.AuthorizationData.Count; + })); + }) + .Build(); + + Assert.Equal(2, authCount); + } + private class InvalidHub : Hub { public void OverloadedMethod(int num) @@ -85,6 +107,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Authorize] + private class DoubleAuthHub : AuthHub + { + } + private class InheritedAuthHub : AuthHub { } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 37d01ea6ec..847b8f3a0e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -52,16 +52,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var message = await Read(); - if (!string.Equals(message.InvocationId, invocationId)) - { - throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); - } - if (message == null) { throw new InvalidOperationException("Connection aborted!"); } + if (!string.Equals(message.InvocationId, invocationId)) + { + throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); + } + switch (message) { case StreamItemMessage _: @@ -84,16 +84,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var message = await Read(); - if (!string.Equals(message.InvocationId, invocationId)) - { - throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); - } - if (message == null) { throw new InvalidOperationException("Connection aborted!"); } + if (!string.Equals(message.InvocationId, invocationId)) + { + throw new NotSupportedException("TestClient does not support multiple outgoing invocations!"); + } + switch (message) { case StreamItemMessage result: diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 297549971b..602db4d1d5 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -1071,7 +1071,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests { public override void OnCompleted(Func callback, object state) { - } public override void OnStarting(Func callback, object state)