fix issue with incorrect user detection when Invoking for User (#747)

* fix issue with incorrect user detection when Invoking for User

* fix failed testcases

* use proper extension method to avoid potential null reference exception

* fix for channel name in redis version + follow SignalR team recommendations

* remove unncessary freespace

* remove whitespaces

* introduce IUserIdProvider to resolve user id

* Move IUserIdProvider from HubLifetimeManager to HubConnectionContext

* setting user id to connection context in hubendpoint
This commit is contained in:
ivankarpey 2017-10-06 21:58:18 +03:00 committed by David Fowler
parent 3c5d283689
commit 665f166d67
10 changed files with 55 additions and 14 deletions

View File

@ -138,17 +138,14 @@ namespace Microsoft.AspNetCore.SignalR
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal);
});
return InvokeAllWhere(methodName, args, connection =>
string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal));
}
public override Task OnConnectedAsync(HubConnectionContext connection)
{
// Set the hub groups feature
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());
_connections.Add(connection);
return Task.CompletedTask;
}

View File

@ -0,0 +1,15 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Security.Claims;
namespace Microsoft.AspNetCore.SignalR.Core
{
public class DefaultUserIdProvider : IUserIdProvider
{
public string GetUserId(HubConnectionContext connection)
{
return connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value;
}
}
}

View File

@ -67,6 +67,8 @@ namespace Microsoft.AspNetCore.SignalR
Task.Factory.StartNew(_abortedCallback, this);
}
public string UserIdentifier { get; internal set; }
internal void Abort(Exception exception)
{
AbortException = ExceptionDispatchInfo.Capture(exception);

View File

@ -11,6 +11,7 @@ using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Core.Internal;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal;
@ -39,13 +40,15 @@ namespace Microsoft.AspNetCore.SignalR
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly IHubProtocolResolver _protocolResolver;
private readonly IOptions<HubOptions> _hubOptions;
private readonly IUserIdProvider _userIdProvider;
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
IHubProtocolResolver protocolResolver,
IHubContext<THub> hubContext,
IOptions<HubOptions> hubOptions,
ILogger<HubEndPoint<THub>> logger,
IServiceScopeFactory serviceScopeFactory)
IServiceScopeFactory serviceScopeFactory,
IUserIdProvider userIdProvider)
{
_protocolResolver = protocolResolver;
_lifetimeManager = lifetimeManager;
@ -53,6 +56,7 @@ namespace Microsoft.AspNetCore.SignalR
_hubOptions = hubOptions;
_logger = logger;
_serviceScopeFactory = serviceScopeFactory;
_userIdProvider = userIdProvider;
DiscoverHubMethods();
}
@ -72,6 +76,8 @@ namespace Microsoft.AspNetCore.SignalR
return;
}
connectionContext.UserIdentifier = _userIdProvider.GetUserId(connectionContext);
// Hubs support multiple producers so we set up this loop to copy
// data written to the HubConnectionContext's channel to the transport channel
var protocolReaderWriter = connectionContext.ProtocolReaderWriter;

View File

@ -0,0 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Core
{
public interface IUserIdProvider
{
string GetUserId(HubConnectionContext connection);
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
namespace Microsoft.Extensions.DependencyInjection
@ -15,6 +16,7 @@ namespace Microsoft.Extensions.DependencyInjection
services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>));
services.AddSingleton(typeof(IHubContext<,>), typeof(HubContext<,>));
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
services.AddSingleton(typeof(IUserIdProvider), typeof(DefaultUserIdProvider));
services.AddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>));
services.AddAuthorization();

View File

@ -245,9 +245,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
previousConnectionTask = WriteAsync(connection, message);
});
if (connection.User.Identity.IsAuthenticated)
if (!string.IsNullOrEmpty(connection.UserIdentifier))
{
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
redisSubscriptions.Add(userChannel);
var previousUserTask = Task.CompletedTask;

View File

@ -1,5 +1,6 @@
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Xunit;

View File

@ -850,15 +850,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType));
using (var firstClient = new TestClient())
using (var secondClient = new TestClient())
using (var firstClient = new TestClient(addClaimId: true))
using (var secondClient = new TestClient(addClaimId: true))
{
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.Identity.Name, "test").OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();
// check that 'secondConnection' has received the group send
var hubMessage = await secondClient.ReadAsync().OrTimeout();

View File

@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public Channel<byte[]> Application { get; }
public Task Connected => ((TaskCompletionSource<bool>)Connection.Metadata["ConnectedTask"]).Task;
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null)
public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false)
{
var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks };
var transportToApplication = Channel.CreateUnbounded<byte[]>(options);
@ -38,7 +38,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
_transport = ChannelConnection.Create<byte[]>(input: transportToApplication, output: applicationToTransport);
Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application);
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
var claimValue = Interlocked.Increment(ref _id).ToString();
var claims = new List<Claim>{ new Claim(ClaimTypes.Name, claimValue) };
if (addClaimId)
{
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
}
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
protocol = protocol ?? new JsonHubProtocol();
@ -182,4 +190,4 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return typeof(object);
}
}
}
}