Add authorization per hub method (#577)

This commit is contained in:
BrennanConroy 2017-06-23 10:22:05 -07:00 committed by GitHub
parent a84ba8820f
commit ef273b4796
10 changed files with 178 additions and 32 deletions

View File

@ -81,22 +81,29 @@ let transportType = signalR.TransportType[getParameterByName('transport')] || si
document.getElementById('head1').innerHTML = signalR.TransportType[transportType];
let http = new signalR.HttpConnection(`http://${document.location.host}/hubs`, { transport: transportType });
let connection = new signalR.HubConnection(http);
connection.on('Send', msg => {
addLine('message-list', msg);
});
connection.onClosed = e => {
if (e) {
addLine('message-list', 'Connection closed with error: ' + e, 'red');
}
else {
addLine('message-list', 'Disconnected', 'green');
}
}
let connectButton = document.getElementById('connect');
let disconnectButton = document.getElementById('disconnect');
disconnectButton.disabled = true;
var connection;
click('connect', event => {
connectButton.disabled = true;
disconnectButton.disabled = false;
let http = new signalR.HttpConnection(`http://${document.location.host}/hubs`, { transport: transportType });
connection = new signalR.HubConnection(http);
connection.on('Send', msg => {
addLine('message-list', msg);
});
connection.onClosed = e => {
if (e) {
addLine('message-list', 'Connection closed with error: ' + e, 'red');
}
else {
addLine('message-list', 'Disconnected', 'green');
}
}
connection.start()
.then(() => {
isConnected = true;
@ -108,6 +115,8 @@ click('connect', event => {
});
click('disconnect', event => {
connectButton.disabled = false;
disconnectButton.disabled = true;
connection.stop()
.then(() => {
isConnected = false;

View File

@ -25,11 +25,11 @@ namespace Microsoft.AspNetCore.SignalR
public void MapHub<THub>(string path, Action<HttpSocketOptions> socketOptions) where THub : Hub<IClientProxy>
{
// find auth attributes
var authorizeAttribute = typeof(THub).GetCustomAttribute<AuthorizeAttribute>(inherit: true);
var authorizeAttributes = typeof(THub).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
var options = new HttpSocketOptions();
if (authorizeAttribute != null)
foreach (var attribute in authorizeAttributes)
{
options.AuthorizationData.Add(authorizeAttribute);
options.AuthorizationData.Add(attribute);
}
socketOptions?.Invoke(options);

View File

@ -3,11 +3,14 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
@ -251,6 +254,13 @@ namespace Microsoft.AspNetCore.SignalR
using (var scope = _serviceScopeFactory.CreateScope())
{
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
{
_logger.LogDebug("Failed to invoke {hubMethod} because user is unauthorized", invocationMessage.Target);
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized"));
return;
}
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
var hub = hubActivator.Create();
@ -395,7 +405,8 @@ namespace Microsoft.AspNetCore.SignalR
}
var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
_methods[methodName] = new HubMethodDescriptor(executor);
var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
_methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes);
if (_logger.IsEnabled(LogLevel.Debug))
{
@ -404,6 +415,26 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
{
// If there are no policies we don't need to run auth
if (!policies.Any())
{
return true;
}
var authService = provider.GetRequiredService<IAuthorizationService>();
var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
// AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
Debug.Assert(authorizePolicy != null);
var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
// Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
return authorizationResult.Succeeded;
}
Type IInvocationBinder.GetReturnType(string invocationId)
{
return typeof(object);
@ -422,15 +453,18 @@ namespace Microsoft.AspNetCore.SignalR
// REVIEW: We can decide to move this out of here if we want pluggable hub discovery
private class HubMethodDescriptor
{
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor)
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
Policies = policies.ToArray();
}
public ObjectMethodExecutor MethodExecutor { get; }
public Type[] ParameterTypes { get; }
public IList<IAuthorizeData> Policies { get; }
}
}
}

View File

@ -14,6 +14,7 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Authorization" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.ClosedGenericMatcher.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />

View File

@ -16,6 +16,8 @@ namespace Microsoft.Extensions.DependencyInjection
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>));
services.AddAuthorization();
return new SignalRBuilder(services);
}
}

View File

@ -13,7 +13,7 @@
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common.Http\Microsoft.AspNetCore.Sockets.Common.Http.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common.Http\Microsoft.AspNetCore.Sockets.Common.Http.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />

View File

@ -2,9 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
@ -597,7 +599,74 @@ namespace Microsoft.AspNetCore.SignalR.Tests
client.Dispose();
await endPointLifetime;
await endPointLifetime.OrTimeout();
}
}
[Fact]
public async Task UnauthorizedConnectionCannotInvokeHubMethodWithAuthorization()
{
var serviceProvider = CreateServiceProvider(services =>
{
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
});
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
await client.Connected.OrTimeout();
var message = await client.InvokeAsync(nameof(MethodHub.AuthMethod)).OrTimeout();
Assert.NotNull(message.Error);
client.Dispose();
await endPointLifetime.OrTimeout();
}
}
[Fact]
public async Task AuthorizedConnectionCanInvokeHubMethodWithAuthorization()
{
var serviceProvider = CreateServiceProvider(services =>
{
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
});
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
await client.Connected.OrTimeout();
var message = await client.InvokeAsync(nameof(MethodHub.AuthMethod)).OrTimeout();
Assert.Null(message.Error);
client.Dispose();
await endPointLifetime.OrTimeout();
}
}
@ -793,6 +862,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public static void StaticMethod()
{
}
[Authorize("test")]
public void AuthMethod()
{
}
}
private class InheritedHub : BaseHub

View File

@ -74,6 +74,28 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Equal(1, authCount);
}
[Fact]
public void MapHubFindsMultipleAuthAttributesOnDoubleAuthHub()
{
var authCount = 0;
var builder = new WebHostBuilder()
.UseKestrel()
.ConfigureServices(services =>
{
services.AddSignalR();
})
.Configure(app =>
{
app.UseSignalR(options => options.MapHub<DoubleAuthHub>("path", httpSocketOptions =>
{
authCount += httpSocketOptions.AuthorizationData.Count;
}));
})
.Build();
Assert.Equal(2, authCount);
}
private class InvalidHub : Hub
{
public void OverloadedMethod(int num)
@ -85,6 +107,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Authorize]
private class DoubleAuthHub : AuthHub
{
}
private class InheritedAuthHub : AuthHub
{
}

View File

@ -52,16 +52,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var message = await Read();
if (!string.Equals(message.InvocationId, invocationId))
{
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
}
if (message == null)
{
throw new InvalidOperationException("Connection aborted!");
}
if (!string.Equals(message.InvocationId, invocationId))
{
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
}
switch (message)
{
case StreamItemMessage _:
@ -84,16 +84,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var message = await Read();
if (!string.Equals(message.InvocationId, invocationId))
{
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
}
if (message == null)
{
throw new InvalidOperationException("Connection aborted!");
}
if (!string.Equals(message.InvocationId, invocationId))
{
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
}
switch (message)
{
case StreamItemMessage result:

View File

@ -1071,7 +1071,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
{
public override void OnCompleted(Func<object, Task> callback, object state)
{
}
public override void OnStarting(Func<object, Task> callback, object state)