Add authorization per hub method (#577)
This commit is contained in:
parent
a84ba8820f
commit
ef273b4796
|
|
@ -81,22 +81,29 @@ let transportType = signalR.TransportType[getParameterByName('transport')] || si
|
|||
|
||||
document.getElementById('head1').innerHTML = signalR.TransportType[transportType];
|
||||
|
||||
let http = new signalR.HttpConnection(`http://${document.location.host}/hubs`, { transport: transportType });
|
||||
let connection = new signalR.HubConnection(http);
|
||||
connection.on('Send', msg => {
|
||||
addLine('message-list', msg);
|
||||
});
|
||||
|
||||
connection.onClosed = e => {
|
||||
if (e) {
|
||||
addLine('message-list', 'Connection closed with error: ' + e, 'red');
|
||||
}
|
||||
else {
|
||||
addLine('message-list', 'Disconnected', 'green');
|
||||
}
|
||||
}
|
||||
let connectButton = document.getElementById('connect');
|
||||
let disconnectButton = document.getElementById('disconnect');
|
||||
disconnectButton.disabled = true;
|
||||
var connection;
|
||||
|
||||
click('connect', event => {
|
||||
connectButton.disabled = true;
|
||||
disconnectButton.disabled = false;
|
||||
let http = new signalR.HttpConnection(`http://${document.location.host}/hubs`, { transport: transportType });
|
||||
connection = new signalR.HubConnection(http);
|
||||
connection.on('Send', msg => {
|
||||
addLine('message-list', msg);
|
||||
});
|
||||
|
||||
connection.onClosed = e => {
|
||||
if (e) {
|
||||
addLine('message-list', 'Connection closed with error: ' + e, 'red');
|
||||
}
|
||||
else {
|
||||
addLine('message-list', 'Disconnected', 'green');
|
||||
}
|
||||
}
|
||||
|
||||
connection.start()
|
||||
.then(() => {
|
||||
isConnected = true;
|
||||
|
|
@ -108,6 +115,8 @@ click('connect', event => {
|
|||
});
|
||||
|
||||
click('disconnect', event => {
|
||||
connectButton.disabled = false;
|
||||
disconnectButton.disabled = true;
|
||||
connection.stop()
|
||||
.then(() => {
|
||||
isConnected = false;
|
||||
|
|
|
|||
|
|
@ -25,11 +25,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public void MapHub<THub>(string path, Action<HttpSocketOptions> socketOptions) where THub : Hub<IClientProxy>
|
||||
{
|
||||
// find auth attributes
|
||||
var authorizeAttribute = typeof(THub).GetCustomAttribute<AuthorizeAttribute>(inherit: true);
|
||||
var authorizeAttributes = typeof(THub).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
|
||||
var options = new HttpSocketOptions();
|
||||
if (authorizeAttribute != null)
|
||||
foreach (var attribute in authorizeAttributes)
|
||||
{
|
||||
options.AuthorizationData.Add(authorizeAttribute);
|
||||
options.AuthorizationData.Add(attribute);
|
||||
}
|
||||
socketOptions?.Invoke(options);
|
||||
|
||||
|
|
|
|||
|
|
@ -3,11 +3,14 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Security.Claims;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
|
@ -251,6 +254,13 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
using (var scope = _serviceScopeFactory.CreateScope())
|
||||
{
|
||||
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
|
||||
{
|
||||
_logger.LogDebug("Failed to invoke {hubMethod} because user is unauthorized", invocationMessage.Target);
|
||||
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized"));
|
||||
return;
|
||||
}
|
||||
|
||||
var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
|
||||
var hub = hubActivator.Create();
|
||||
|
||||
|
|
@ -395,7 +405,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
|
||||
var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
|
||||
_methods[methodName] = new HubMethodDescriptor(executor);
|
||||
var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
|
||||
_methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes);
|
||||
|
||||
if (_logger.IsEnabled(LogLevel.Debug))
|
||||
{
|
||||
|
|
@ -404,6 +415,26 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
private async Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
|
||||
{
|
||||
// If there are no policies we don't need to run auth
|
||||
if (!policies.Any())
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
var authService = provider.GetRequiredService<IAuthorizationService>();
|
||||
var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();
|
||||
|
||||
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
|
||||
// AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
|
||||
Debug.Assert(authorizePolicy != null);
|
||||
|
||||
var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
|
||||
// Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
|
||||
return authorizationResult.Succeeded;
|
||||
}
|
||||
|
||||
Type IInvocationBinder.GetReturnType(string invocationId)
|
||||
{
|
||||
return typeof(object);
|
||||
|
|
@ -422,15 +453,18 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
// REVIEW: We can decide to move this out of here if we want pluggable hub discovery
|
||||
private class HubMethodDescriptor
|
||||
{
|
||||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor)
|
||||
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
|
||||
{
|
||||
MethodExecutor = methodExecutor;
|
||||
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
|
||||
Policies = policies.ToArray();
|
||||
}
|
||||
|
||||
public ObjectMethodExecutor MethodExecutor { get; }
|
||||
|
||||
public Type[] ParameterTypes { get; }
|
||||
|
||||
public IList<IAuthorizeData> Policies { get; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@
|
|||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Abstractions\Microsoft.AspNetCore.Sockets.Abstractions.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Authorization" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.ClosedGenericMatcher.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ namespace Microsoft.Extensions.DependencyInjection
|
|||
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
|
||||
services.AddScoped(typeof(IHubActivator<,>), typeof(DefaultHubActivator<,>));
|
||||
|
||||
services.AddAuthorization();
|
||||
|
||||
return new SignalRBuilder(services);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@
|
|||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common.Http\Microsoft.AspNetCore.Sockets.Common.Http.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common.Http\Microsoft.AspNetCore.Sockets.Common.Http.csproj" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Security.Claims;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
|
@ -597,7 +599,74 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
client.Dispose();
|
||||
|
||||
await endPointLifetime;
|
||||
await endPointLifetime.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UnauthorizedConnectionCannotInvokeHubMethodWithAuthorization()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider(services =>
|
||||
{
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
await client.Connected.OrTimeout();
|
||||
|
||||
var message = await client.InvokeAsync(nameof(MethodHub.AuthMethod)).OrTimeout();
|
||||
|
||||
Assert.NotNull(message.Error);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await endPointLifetime.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionCanInvokeHubMethodWithAuthorization()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider(services =>
|
||||
{
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
client.Connection.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
await client.Connected.OrTimeout();
|
||||
|
||||
var message = await client.InvokeAsync(nameof(MethodHub.AuthMethod)).OrTimeout();
|
||||
|
||||
Assert.Null(message.Error);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await endPointLifetime.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -793,6 +862,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
public static void StaticMethod()
|
||||
{
|
||||
}
|
||||
|
||||
[Authorize("test")]
|
||||
public void AuthMethod()
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
private class InheritedHub : BaseHub
|
||||
|
|
|
|||
|
|
@ -74,6 +74,28 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
Assert.Equal(1, authCount);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void MapHubFindsMultipleAuthAttributesOnDoubleAuthHub()
|
||||
{
|
||||
var authCount = 0;
|
||||
var builder = new WebHostBuilder()
|
||||
.UseKestrel()
|
||||
.ConfigureServices(services =>
|
||||
{
|
||||
services.AddSignalR();
|
||||
})
|
||||
.Configure(app =>
|
||||
{
|
||||
app.UseSignalR(options => options.MapHub<DoubleAuthHub>("path", httpSocketOptions =>
|
||||
{
|
||||
authCount += httpSocketOptions.AuthorizationData.Count;
|
||||
}));
|
||||
})
|
||||
.Build();
|
||||
|
||||
Assert.Equal(2, authCount);
|
||||
}
|
||||
|
||||
private class InvalidHub : Hub
|
||||
{
|
||||
public void OverloadedMethod(int num)
|
||||
|
|
@ -85,6 +107,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Authorize]
|
||||
private class DoubleAuthHub : AuthHub
|
||||
{
|
||||
}
|
||||
|
||||
private class InheritedAuthHub : AuthHub
|
||||
{
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,16 +52,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var message = await Read();
|
||||
|
||||
if (!string.Equals(message.InvocationId, invocationId))
|
||||
{
|
||||
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
|
||||
}
|
||||
|
||||
if (message == null)
|
||||
{
|
||||
throw new InvalidOperationException("Connection aborted!");
|
||||
}
|
||||
|
||||
if (!string.Equals(message.InvocationId, invocationId))
|
||||
{
|
||||
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
|
||||
}
|
||||
|
||||
switch (message)
|
||||
{
|
||||
case StreamItemMessage _:
|
||||
|
|
@ -84,16 +84,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var message = await Read();
|
||||
|
||||
if (!string.Equals(message.InvocationId, invocationId))
|
||||
{
|
||||
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
|
||||
}
|
||||
|
||||
if (message == null)
|
||||
{
|
||||
throw new InvalidOperationException("Connection aborted!");
|
||||
}
|
||||
|
||||
if (!string.Equals(message.InvocationId, invocationId))
|
||||
{
|
||||
throw new NotSupportedException("TestClient does not support multiple outgoing invocations!");
|
||||
}
|
||||
|
||||
switch (message)
|
||||
{
|
||||
case StreamItemMessage result:
|
||||
|
|
|
|||
|
|
@ -1071,7 +1071,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
{
|
||||
public override void OnCompleted(Func<object, Task> callback, object state)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
public override void OnStarting(Func<object, Task> callback, object state)
|
||||
|
|
|
|||
Loading…
Reference in New Issue