diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index ab5cbd926b..43dddfa54f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -138,17 +138,14 @@ namespace Microsoft.AspNetCore.SignalR public override Task InvokeUserAsync(string userId, string methodName, object[] args) { - return InvokeAllWhere(methodName, args, connection => - { - return string.Equals(connection.User.Identity.Name, userId, StringComparison.Ordinal); - }); + return InvokeAllWhere(methodName, args, connection => + string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); } public override Task OnConnectedAsync(HubConnectionContext connection) { // Set the hub groups feature connection.Features.Set(new HubGroupsFeature()); - _connections.Add(connection); return Task.CompletedTask; } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultUserIdProvider.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultUserIdProvider.cs new file mode 100644 index 0000000000..b297431276 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultUserIdProvider.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Security.Claims; + +namespace Microsoft.AspNetCore.SignalR.Core +{ + public class DefaultUserIdProvider : IUserIdProvider + { + public string GetUserId(HubConnectionContext connection) + { + return connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index c95660c199..f016986d34 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -67,6 +67,8 @@ namespace Microsoft.AspNetCore.SignalR Task.Factory.StartNew(_abortedCallback, this); } + public string UserIdentifier { get; internal set; } + internal void Abort(Exception exception) { AbortException = ExceptionDispatchInfo.Capture(exception); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index 2ed66603f4..fbf9151e41 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -11,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Core.Internal; using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal; @@ -39,13 +40,15 @@ namespace Microsoft.AspNetCore.SignalR private readonly IServiceScopeFactory _serviceScopeFactory; private readonly IHubProtocolResolver _protocolResolver; private readonly IOptions _hubOptions; + private readonly IUserIdProvider _userIdProvider; public HubEndPoint(HubLifetimeManager lifetimeManager, IHubProtocolResolver protocolResolver, IHubContext hubContext, IOptions hubOptions, ILogger> logger, - IServiceScopeFactory serviceScopeFactory) + IServiceScopeFactory serviceScopeFactory, + IUserIdProvider userIdProvider) { _protocolResolver = protocolResolver; _lifetimeManager = lifetimeManager; @@ -53,6 +56,7 @@ namespace Microsoft.AspNetCore.SignalR _hubOptions = hubOptions; _logger = logger; _serviceScopeFactory = serviceScopeFactory; + _userIdProvider = userIdProvider; DiscoverHubMethods(); } @@ -72,6 +76,8 @@ namespace Microsoft.AspNetCore.SignalR return; } + connectionContext.UserIdentifier = _userIdProvider.GetUserId(connectionContext); + // Hubs support multiple producers so we set up this loop to copy // data written to the HubConnectionContext's channel to the transport channel var protocolReaderWriter = connectionContext.ProtocolReaderWriter; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/IUserIdProvider.cs b/src/Microsoft.AspNetCore.SignalR.Core/IUserIdProvider.cs new file mode 100644 index 0000000000..227047fbf8 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/IUserIdProvider.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Core +{ + public interface IUserIdProvider + { + string GetUserId(HubConnectionContext connection); + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs index ba48638ad6..cad16b740b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.SignalR; +using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal; namespace Microsoft.Extensions.DependencyInjection @@ -15,6 +16,7 @@ namespace Microsoft.Extensions.DependencyInjection services.AddSingleton(typeof(IHubContext<>), typeof(HubContext<>)); services.AddSingleton(typeof(IHubContext<,>), typeof(HubContext<,>)); services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>)); + services.AddSingleton(typeof(IUserIdProvider), typeof(DefaultUserIdProvider)); services.AddScoped(typeof(IHubActivator<>), typeof(DefaultHubActivator<>)); services.AddAuthorization(); diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index b6185bbcab..b9463caeab 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -245,9 +245,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis previousConnectionTask = WriteAsync(connection, message); }); - if (connection.User.Identity.IsAuthenticated) + if (!string.IsNullOrEmpty(connection.UserIdentifier)) { - var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name; + var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier; redisSubscriptions.Add(userChannel); var previousUserTask = Task.CompletedTask; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index 68acba4366..ecb02cb3c8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -1,5 +1,6 @@ using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Xunit; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index d9123e448c..04ef10118b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -850,15 +850,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests dynamic endPoint = serviceProvider.GetService(GetEndPointType(hubType)); - using (var firstClient = new TestClient()) - using (var secondClient = new TestClient()) + using (var firstClient = new TestClient(addClaimId: true)) + using (var secondClient = new TestClient(addClaimId: true)) { Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); - await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.Identity.Name, "test").OrTimeout(); + await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout(); // check that 'secondConnection' has received the group send var hubMessage = await secondClient.ReadAsync().OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 516a38cab4..7835863279 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -28,7 +28,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public Channel Application { get; } public Task Connected => ((TaskCompletionSource)Connection.Metadata["ConnectedTask"]).Task; - public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null) + public TestClient(bool synchronousCallbacks = false, IHubProtocol protocol = null, bool addClaimId = false) { var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks }; var transportToApplication = Channel.CreateUnbounded(options); @@ -38,7 +38,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests _transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), _transport, Application); - Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); + + var claimValue = Interlocked.Increment(ref _id).ToString(); + var claims = new List{ new Claim(ClaimTypes.Name, claimValue) }; + if (addClaimId) + { + claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue)); + } + + Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims)); Connection.Metadata["ConnectedTask"] = new TaskCompletionSource(); protocol = protocol ?? new JsonHubProtocol(); @@ -182,4 +190,4 @@ namespace Microsoft.AspNetCore.SignalR.Tests return typeof(object); } } -} +} \ No newline at end of file