Merge branch 'release/2.1' into dev

This commit is contained in:
BrennanConroy 2018-05-15 16:07:57 -07:00
commit 9cb683a41d
8 changed files with 275 additions and 94 deletions

View File

@ -0,0 +1,63 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
internal class RedisSubscriptionManager
{
private readonly ConcurrentDictionary<string, HubConnectionStore> _subscriptions = new ConcurrentDictionary<string, HubConnectionStore>(StringComparer.Ordinal);
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
public async Task AddSubscriptionAsync(string id, HubConnectionContext connection, Func<string, HubConnectionStore, Task> subscribeMethod)
{
await _lock.WaitAsync();
try
{
var subscription = _subscriptions.GetOrAdd(id, _ => new HubConnectionStore());
subscription.Add(connection);
// Subscribe once
if (subscription.Count == 1)
{
await subscribeMethod(id, subscription);
}
}
finally
{
_lock.Release();
}
}
public async Task RemoveSubscriptionAsync(string id, HubConnectionContext connection, Func<string, Task> unsubscribeMethod)
{
await _lock.WaitAsync();
try
{
if (!_subscriptions.TryGetValue(id, out var subscription))
{
return;
}
subscription.Remove(connection);
if (subscription.Count == 0)
{
_subscriptions.TryRemove(id, out _);
await unsubscribeMethod(id);
}
}
finally
{
_lock.Release();
}
}
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Collections.Concurrent;
using System.Collections.Generic; using System.Collections.Generic;
using System.IO; using System.IO;
using System.Linq; using System.Linq;
@ -20,8 +19,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable where THub : Hub public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable where THub : Hub
{ {
private readonly HubConnectionStore _connections = new HubConnectionStore(); private readonly HubConnectionStore _connections = new HubConnectionStore();
// TODO: Investigate "memory leak" entries never get removed private readonly RedisSubscriptionManager _groups = new RedisSubscriptionManager();
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>(StringComparer.Ordinal); private readonly RedisSubscriptionManager _users = new RedisSubscriptionManager();
private IConnectionMultiplexer _redisServerConnection; private IConnectionMultiplexer _redisServerConnection;
private ISubscriber _bus; private ISubscriber _bus;
private readonly ILogger _logger; private readonly ILogger _logger;
@ -54,17 +53,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var feature = new RedisFeature(); var feature = new RedisFeature();
connection.Features.Set<IRedisFeature>(feature); connection.Features.Set<IRedisFeature>(feature);
var redisSubscriptions = feature.Subscriptions;
var connectionTask = Task.CompletedTask; var connectionTask = Task.CompletedTask;
var userTask = Task.CompletedTask; var userTask = Task.CompletedTask;
_connections.Add(connection); _connections.Add(connection);
connectionTask = SubscribeToConnection(connection, redisSubscriptions); connectionTask = SubscribeToConnection(connection);
if (!string.IsNullOrEmpty(connection.UserIdentifier)) if (!string.IsNullOrEmpty(connection.UserIdentifier))
{ {
userTask = SubscribeToUser(connection, redisSubscriptions); userTask = SubscribeToUser(connection);
} }
await Task.WhenAll(connectionTask, userTask); await Task.WhenAll(connectionTask, userTask);
@ -76,18 +74,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var tasks = new List<Task>(); var tasks = new List<Task>();
var connectionChannel = _channels.Connection(connection.ConnectionId);
RedisLog.Unsubscribe(_logger, connectionChannel);
tasks.Add(_bus.UnsubscribeAsync(connectionChannel));
var feature = connection.Features.Get<IRedisFeature>(); var feature = connection.Features.Get<IRedisFeature>();
var redisSubscriptions = feature.Subscriptions;
if (redisSubscriptions != null)
{
foreach (var subscription in redisSubscriptions)
{
RedisLog.Unsubscribe(_logger, subscription);
tasks.Add(_bus.UnsubscribeAsync(subscription));
}
}
var groupNames = feature.Groups; var groupNames = feature.Groups;
if (groupNames != null) if (groupNames != null)
@ -102,6 +93,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
} }
} }
if (!string.IsNullOrEmpty(connection.UserIdentifier))
{
tasks.Add(RemoveUserAsync(connection));
}
return Task.WhenAll(tasks); return Task.WhenAll(tasks);
} }
@ -290,25 +286,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
} }
var groupChannel = _channels.Group(groupName); var groupChannel = _channels.Group(groupName);
var group = _groups.GetOrAdd(groupChannel, _ => new GroupData()); await _groups.AddSubscriptionAsync(groupChannel, connection, SubscribeToGroupAsync);
await group.Lock.WaitAsync();
try
{
group.Connections.Add(connection);
// Subscribe once
if (group.Connections.Count > 1)
{
return;
}
await SubscribeToGroup(groupChannel, group);
}
finally
{
group.Lock.Release();
}
} }
/// <summary> /// <summary>
@ -319,10 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{ {
var groupChannel = _channels.Group(groupName); var groupChannel = _channels.Group(groupName);
if (!_groups.TryGetValue(groupChannel, out var group)) await _groups.RemoveSubscriptionAsync(groupChannel, connection, async channelName =>
{ {
return; RedisLog.Unsubscribe(_logger, channelName);
} await _bus.UnsubscribeAsync(channelName);
});
var feature = connection.Features.Get<IRedisFeature>(); var feature = connection.Features.Get<IRedisFeature>();
var groupNames = feature.Groups; var groupNames = feature.Groups;
@ -333,25 +312,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis
groupNames.Remove(groupName); groupNames.Remove(groupName);
} }
} }
await group.Lock.WaitAsync();
try
{
if (group.Connections.Count > 0)
{
group.Connections.Remove(connection);
if (group.Connections.Count == 0)
{
RedisLog.Unsubscribe(_logger, groupChannel);
await _bus.UnsubscribeAsync(groupChannel);
}
}
}
finally
{
group.Lock.Release();
}
} }
private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action) private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action)
@ -365,6 +325,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await ack; await ack;
} }
private Task RemoveUserAsync(HubConnectionContext connection)
{
var userChannel = _channels.User(connection.UserIdentifier);
return _users.RemoveSubscriptionAsync(userChannel, connection, async channelName =>
{
RedisLog.Unsubscribe(_logger, channelName);
await _bus.UnsubscribeAsync(channelName);
});
}
public void Dispose() public void Dispose()
{ {
_bus?.UnsubscribeAll(); _bus?.UnsubscribeAll();
@ -448,10 +419,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}); });
} }
private Task SubscribeToConnection(HubConnectionContext connection, HashSet<string> redisSubscriptions) private Task SubscribeToConnection(HubConnectionContext connection)
{ {
var connectionChannel = _channels.Connection(connection.ConnectionId); var connectionChannel = _channels.Connection(connection.ConnectionId);
redisSubscriptions.Add(connectionChannel);
RedisLog.Subscribing(_logger, connectionChannel); RedisLog.Subscribing(_logger, connectionChannel);
return _bus.SubscribeAsync(connectionChannel, async (c, data) => return _bus.SubscribeAsync(connectionChannel, async (c, data) =>
@ -461,20 +431,35 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}); });
} }
private Task SubscribeToUser(HubConnectionContext connection, HashSet<string> redisSubscriptions) private Task SubscribeToUser(HubConnectionContext connection)
{ {
var userChannel = _channels.User(connection.UserIdentifier); var userChannel = _channels.User(connection.UserIdentifier);
redisSubscriptions.Add(userChannel);
// TODO: Look at optimizing (looping over connections checking for Name) return _users.AddSubscriptionAsync(userChannel, connection, async (channelName, subscriptions) =>
return _bus.SubscribeAsync(userChannel, async (c, data) =>
{ {
var invocation = _protocol.ReadInvocation((byte[])data); await _bus.SubscribeAsync(channelName, async (c, data) =>
await connection.WriteAsync(invocation.Message); {
try
{
var invocation = _protocol.ReadInvocation((byte[])data);
var tasks = new List<Task>();
foreach (var userConnection in subscriptions)
{
tasks.Add(userConnection.WriteAsync(invocation.Message).AsTask());
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
RedisLog.FailedWritingMessage(_logger, ex);
}
});
}); });
} }
private Task SubscribeToGroup(string groupChannel, GroupData group) private Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore groupConnections)
{ {
RedisLog.Subscribing(_logger, groupChannel); RedisLog.Subscribing(_logger, groupChannel);
return _bus.SubscribeAsync(groupChannel, async (c, data) => return _bus.SubscribeAsync(groupChannel, async (c, data) =>
@ -484,7 +469,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var invocation = _protocol.ReadInvocation((byte[])data); var invocation = _protocol.ReadInvocation((byte[])data);
var tasks = new List<Task>(); var tasks = new List<Task>();
foreach (var groupConnection in group.Connections) foreach (var groupConnection in groupConnections)
{ {
if (invocation.ExcludedConnectionIds?.Contains(groupConnection.ConnectionId) == true) if (invocation.ExcludedConnectionIds?.Contains(groupConnection.ConnectionId) == true)
{ {
@ -515,6 +500,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var writer = new LoggerTextWriter(_logger); var writer = new LoggerTextWriter(_logger);
_redisServerConnection = await _options.ConnectAsync(writer); _redisServerConnection = await _options.ConnectAsync(writer);
_bus = _redisServerConnection.GetSubscriber(); _bus = _redisServerConnection.GetSubscriber();
_redisServerConnection.ConnectionRestored += (_, e) => _redisServerConnection.ConnectionRestored += (_, e) =>
{ {
// We use the subscription connection type // We use the subscription connection type
@ -589,21 +575,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
} }
} }
private class GroupData
{
public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public readonly HubConnectionStore Connections = new HubConnectionStore();
}
private interface IRedisFeature private interface IRedisFeature
{ {
HashSet<string> Subscriptions { get; }
HashSet<string> Groups { get; } HashSet<string> Groups { get; }
} }
private class RedisFeature : IRedisFeature private class RedisFeature : IRedisFeature
{ {
public HashSet<string> Subscriptions { get; } = new HashSet<string>();
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase); public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
} }
} }

View File

@ -18,6 +18,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
return Clients.Group(groupName).SendAsync("Echo", message); return Clients.Group(groupName).SendAsync("Echo", message);
} }
public Task EchoUser(string userName, string message)
{
return Clients.User(userName).SendAsync("Echo", message);
}
public Task AddSelfToGroup(string groupName) public Task AddSelfToGroup(string groupName)
{ {
return Groups.AddToGroupAsync(Context.ConnectionId, groupName); return Groups.AddToGroupAsync(Context.ConnectionId, groupName);

View File

@ -11,7 +11,6 @@ using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.Testing.xunit; using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Xunit; using Xunit;
using Xunit.Abstractions; using Xunit.Abstractions;
@ -39,7 +38,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
_serverFixture = serverFixture; _serverFixture = serverFixture;
} }
[ConditionalTheory()] [ConditionalTheory]
[SkipIfDockerNotPresent] [SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))] [MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task HubConnectionCanSendAndReceiveMessages(HttpTransportType transportType, string protocolName) public async Task HubConnectionCanSendAndReceiveMessages(HttpTransportType transportType, string protocolName)
@ -60,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
} }
} }
[ConditionalTheory()] [ConditionalTheory]
[SkipIfDockerNotPresent] [SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))] [MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task HubConnectionCanSendAndReceiveGroupMessages(HttpTransportType transportType, string protocolName) public async Task HubConnectionCanSendAndReceiveGroupMessages(HttpTransportType transportType, string protocolName)
@ -93,11 +92,77 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
} }
} }
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory) [ConditionalTheory]
[SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser(HttpTransportType transportType, string protocolName)
{
using (StartVerifiableLog(out var loggerFactory, testName:
$"{nameof(CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser)}_{transportType.ToString()}_{protocolName}"))
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var tcs = new TaskCompletionSource<string>();
connection.On<string>("Echo", message => tcs.TrySetResult(message));
var tcs2 = new TaskCompletionSource<string>();
secondConnection.On<string>("Echo", message => tcs2.TrySetResult(message));
await secondConnection.StartAsync().OrTimeout();
await connection.StartAsync().OrTimeout();
await connection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout();
Assert.Equal("Hello, World!", await tcs.Task.OrTimeout());
Assert.Equal("Hello, World!", await tcs2.Task.OrTimeout());
await connection.DisposeAsync().OrTimeout();
await secondConnection.DisposeAsync().OrTimeout();
}
}
[ConditionalTheory]
[SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects(HttpTransportType transportType, string protocolName)
{
// Regression test:
// When multiple connections from the same user were connected and one left, it used to unsubscribe from the user channel
// Now we keep track of users connections and only unsubscribe when no users are listening
using (StartVerifiableLog(out var loggerFactory, testName:
$"{nameof(CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects)}_{transportType.ToString()}_{protocolName}"))
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
var firstConnection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var tcs = new TaskCompletionSource<string>();
firstConnection.On<string>("Echo", message => tcs.TrySetResult(message));
await secondConnection.StartAsync().OrTimeout();
await firstConnection.StartAsync().OrTimeout();
await secondConnection.DisposeAsync().OrTimeout();
await firstConnection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout();
Assert.Equal("Hello, World!", await tcs.Task.OrTimeout());
await firstConnection.DisposeAsync().OrTimeout();
}
}
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
{ {
var hubConnectionBuilder = new HubConnectionBuilder() var hubConnectionBuilder = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory) .WithLoggerFactory(loggerFactory)
.WithUrl(url, transportType); .WithUrl(url, transportType, httpConnectionOptions =>
{
if (!string.IsNullOrEmpty(userName))
{
httpConnectionOptions.Headers["UserName"] = userName;
}
});
hubConnectionBuilder.Services.AddSingleton(protocol); hubConnectionBuilder.Services.AddSingleton(protocol);

View File

@ -500,6 +500,61 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
} }
} }
[Fact]
public async Task InvokeUserSendsToAllConnectionsForUser()
{
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
using (var client3 = new TestClient())
{
var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "userA");
var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "userA");
var connection3 = HubConnectionContextUtils.Create(client3.Connection, userIdentifier: "userB");
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.OnConnectedAsync(connection3).OrTimeout();
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client1);
await AssertMessageAsync(client2);
}
}
[Fact]
public async Task StillSubscribedToUserAfterOneOfMultipleConnectionsAssociatedWithUserDisconnects()
{
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
using (var client3 = new TestClient())
{
var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "userA");
var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "userA");
var connection3 = HubConnectionContextUtils.Create(client3.Connection, userIdentifier: "userB");
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.OnConnectedAsync(connection3).OrTimeout();
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client1);
await AssertMessageAsync(client2);
// Disconnect one connection for the user
await manager.OnDisconnectedAsync(connection1).OrTimeout();
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client2);
}
}
[Fact] [Fact]
public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary() public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary()
{ {

View File

@ -4,6 +4,7 @@
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests namespace Microsoft.AspNetCore.SignalR.Redis.Tests
{ {
@ -21,11 +22,28 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
// We start the servers before starting redis so we want to time them out ASAP // We start the servers before starting redis so we want to time them out ASAP
options.Configuration.ConnectTimeout = 1; options.Configuration.ConnectTimeout = 1;
}); });
services.AddSingleton<IUserIdProvider, UserNameIdProvider>();
} }
public void Configure(IApplicationBuilder app, IHostingEnvironment env) public void Configure(IApplicationBuilder app, IHostingEnvironment env)
{ {
app.UseSignalR(options => options.MapHub<EchoHub>("/echo")); app.UseSignalR(options => options.MapHub<EchoHub>("/echo"));
} }
private class UserNameIdProvider : IUserIdProvider
{
public string GetUserId(HubConnectionContext connection)
{
// This is an AWFUL way to authenticate users! We're just using it for test purposes.
var userNameHeader = connection.GetHttpContext().Request.Headers["UserName"];
if (!StringValues.IsNullOrEmpty(userNameHeader))
{
return userNameHeader;
}
return null;
}
}
} }
} }

View File

@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public TransferFormat ActiveFormat { get; set; } public TransferFormat ActiveFormat { get; set; }
public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false) public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null)
{ {
var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false);
var pair = DuplexPipe.CreateConnectionPair(options, options); var pair = DuplexPipe.CreateConnectionPair(options, options);
@ -44,9 +44,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var claimValue = Interlocked.Increment(ref _id).ToString(); var claimValue = Interlocked.Increment(ref _id).ToString();
var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) }; var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) };
if (addClaimId) if (userIdentifier != null)
{ {
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue)); claims.Add(new Claim(ClaimTypes.NameIdentifier, userIdentifier));
} }
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims)); Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));

View File

@ -1100,9 +1100,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{ {
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType); var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType);
using (var firstClient = new TestClient(addClaimId: true)) using (var firstClient = new TestClient(userIdentifier: "userA"))
using (var secondClient = new TestClient(addClaimId: true)) using (var secondClient = new TestClient(userIdentifier: "userB"))
using (var thirdClient = new TestClient(addClaimId: true)) using (var thirdClient = new TestClient(userIdentifier: "userC"))
{ {
var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler); var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler);
var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler); var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler);
@ -1110,10 +1110,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout(); await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout();
var secondAndThirdClients = new HashSet<string> {secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, await firstClient.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), new[] { "userB", "userC" }, "Second and Third").OrTimeout();
thirdClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value };
await firstClient.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), secondAndThirdClients, "Second and Third").OrTimeout();
var secondClientResult = await secondClient.ReadAsync().OrTimeout(); var secondClientResult = await secondClient.ReadAsync().OrTimeout();
var invocation = Assert.IsType<InvocationMessage>(secondClientResult); var invocation = Assert.IsType<InvocationMessage>(secondClientResult);
@ -1344,15 +1341,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{ {
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType); var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType);
using (var firstClient = new TestClient(addClaimId: true)) using (var firstClient = new TestClient(userIdentifier: "userA"))
using (var secondClient = new TestClient(addClaimId: true)) using (var secondClient = new TestClient(userIdentifier: "userB"))
{ {
var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler); var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler);
var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler); var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler);
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout(); await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout(); await firstClient.SendInvocationAsync("ClientSendMethod", "userB", "test").OrTimeout();
// check that 'secondConnection' has received the group send // check that 'secondConnection' has received the group send
var hubMessage = await secondClient.ReadAsync().OrTimeout(); var hubMessage = await secondClient.ReadAsync().OrTimeout();