fix issue with incorrect user detection when Invoking for User (#747)
* fix issue with incorrect user detection when Invoking for User * fix failed testcases * use proper extension method to avoid potential null reference exception * fix for channel name in redis version + follow SignalR team recommendations * remove unncessary freespace * remove whitespaces * introduce IUserIdProvider to resolve user id * Move IUserIdProvider from HubLifetimeManager to HubConnectionContext * setting user id to connection context in hubendpoint
This commit is contained in:
parent
3c5d283689
commit
665f166d67
|
|
@ -138,17 +138,14 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
|
||||
{
|
||||
return InvokeAllWhere(methodName, args, connection =>
|
||||
{
|
||||
return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal);
|
||||
});
|
||||
return InvokeAllWhere(methodName, args, connection =>
|
||||
string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal));
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(HubConnectionContext connection)
|
||||
{
|
||||
// Set the hub groups feature
|
||||
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());
|
||||
|
||||
_connections.Add(connection);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Security.Claims;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Core
|
||||
{
|
||||
public class DefaultUserIdProvider : IUserIdProvider
|
||||
{
|
||||
public string GetUserId(HubConnectionContext connection)
|
||||
{
|
||||
return connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -67,6 +67,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
Task.Factory.StartNew(_abortedCallback, this);
|
||||
}
|
||||
|
||||
public string UserIdentifier { get; internal set; }
|
||||
|
||||
internal void Abort(Exception exception)
|
||||
{
|
||||
AbortException = ExceptionDispatchInfo.Capture(exception);
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ using System.Threading;
|
|||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.SignalR.Core;
|
||||
using Microsoft.AspNetCore.SignalR.Core.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Features;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
|
|
@ -39,13 +40,15 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private readonly IServiceScopeFactory _serviceScopeFactory;
|
||||
private readonly IHubProtocolResolver _protocolResolver;
|
||||
private readonly IOptions<HubOptions> _hubOptions;
|
||||
private readonly IUserIdProvider _userIdProvider;
|
||||
|
||||
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
|
||||
IHubProtocolResolver protocolResolver,
|
||||
IHubContext<THub> hubContext,
|
||||
IOptions<HubOptions> hubOptions,
|
||||
ILogger<HubEndPoint<THub>> logger,
|
||||
IServiceScopeFactory serviceScopeFactory)
|
||||
IServiceScopeFactory serviceScopeFactory,
|
||||
IUserIdProvider userIdProvider)
|
||||
{
|
||||
_protocolResolver = protocolResolver;
|
||||
_lifetimeManager = lifetimeManager;
|
||||
|
|
@ -53,6 +56,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
_hubOptions = hubOptions;
|
||||
_logger = logger;
|
||||
_serviceScopeFactory = serviceScopeFactory;
|
||||
_userIdProvider = userIdProvider;
|
||||
|
||||
DiscoverHubMethods();
|
||||
}
|
||||
|
|
@ -72,6 +76,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
return;
|
||||
}
|
||||
|
||||
connectionContext.UserIdentifier = _userIdProvider.GetUserId(connectionContext);
|
||||
|
||||
// Hubs support multiple producers so we set up this loop to copy
|
||||
// data written to the HubConnectionContext's channel to the transport channel
|
||||
var protocolReaderWriter = connectionContext.ProtocolReaderWriter;
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Core
|
||||
{
|
||||
public interface IUserIdProvider
|
||||
{
|
||||
string GetUserId(HubConnectionContext connection);
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
using Microsoft.AspNetCore.SignalR.Core;
|
||||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
|
||||
namespace Microsoft.Extensions.DependencyInjection
|
||||
|
|
@ -15,6 +16,7 @@ namespace Microsoft.Extensions.DependencyInjection
|
|||
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
|
||||
services.AddSingleton(typeof(IHubContext<,>), typeof(HubContext<,>));
|
||||
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
|
||||
services.AddSingleton(typeof(IUserIdProvider), typeof(DefaultUserIdProvider));
|
||||
services.AddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>));
|
||||
|
||||
services.AddAuthorization();
|
||||
|
|
|
|||
|
|
@ -245,9 +245,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
previousConnectionTask = WriteAsync(connection, message);
|
||||
});
|
||||
|
||||
if (connection.User.Identity.IsAuthenticated)
|
||||
if (!string.IsNullOrEmpty(connection.UserIdentifier))
|
||||
{
|
||||
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
|
||||
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
|
||||
redisSubscriptions.Add(userChannel);
|
||||
|
||||
var previousUserTask = Task.CompletedTask;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Core;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Xunit;
|
||||
|
||||
|
|
|
|||
|
|
@ -850,15 +850,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType));
|
||||
|
||||
using (var firstClient = new TestClient())
|
||||
using (var secondClient = new TestClient())
|
||||
using (var firstClient = new TestClient(addClaimId: true))
|
||||
using (var secondClient = new TestClient(addClaimId: true))
|
||||
{
|
||||
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
|
||||
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);
|
||||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
|
||||
|
||||
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.Identity.Name, "test").OrTimeout();
|
||||
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();
|
||||
|
||||
// check that 'secondConnection' has received the group send
|
||||
var hubMessage = await secondClient.ReadAsync().OrTimeout();
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
public Channel<byte[]> Application { get; }
|
||||
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;
|
||||
|
||||
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null)
|
||||
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false)
|
||||
{
|
||||
var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks };
|
||||
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
|
||||
|
|
@ -38,7 +38,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
_transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);
|
||||
|
||||
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);
|
||||
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
|
||||
|
||||
var claimValue = Interlocked.Increment(ref _id).ToString();
|
||||
var claims = new List<Claim>{ new Claim(ClaimTypes.Name, claimValue) };
|
||||
if (addClaimId)
|
||||
{
|
||||
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
|
||||
}
|
||||
|
||||
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
|
||||
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
|
||||
|
||||
protocol = protocol ?? new JsonHubProtocol();
|
||||
|
|
@ -182,4 +190,4 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return typeof(object);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue