Merge branch 'release/2.1' into dev
This commit is contained in:
commit
9cb683a41d
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
||||
{
|
||||
internal class RedisSubscriptionManager
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, HubConnectionStore> _subscriptions = new ConcurrentDictionary<string, HubConnectionStore>(StringComparer.Ordinal);
|
||||
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
|
||||
|
||||
public async Task AddSubscriptionAsync(string id, HubConnectionContext connection, Func<string, HubConnectionStore, Task> subscribeMethod)
|
||||
{
|
||||
await _lock.WaitAsync();
|
||||
|
||||
try
|
||||
{
|
||||
var subscription = _subscriptions.GetOrAdd(id, _ => new HubConnectionStore());
|
||||
|
||||
subscription.Add(connection);
|
||||
|
||||
// Subscribe once
|
||||
if (subscription.Count == 1)
|
||||
{
|
||||
await subscribeMethod(id, subscription);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
_lock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
public async Task RemoveSubscriptionAsync(string id, HubConnectionContext connection, Func<string, Task> unsubscribeMethod)
|
||||
{
|
||||
await _lock.WaitAsync();
|
||||
|
||||
try
|
||||
{
|
||||
if (!_subscriptions.TryGetValue(id, out var subscription))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
subscription.Remove(connection);
|
||||
|
||||
if (subscription.Count == 0)
|
||||
{
|
||||
_subscriptions.TryRemove(id, out _);
|
||||
await unsubscribeMethod(id);
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
_lock.Release();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -2,7 +2,6 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
|
|
@ -20,8 +19,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable where THub : Hub
|
||||
{
|
||||
private readonly HubConnectionStore _connections = new HubConnectionStore();
|
||||
// TODO: Investigate "memory leak" entries never get removed
|
||||
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>(StringComparer.Ordinal);
|
||||
private readonly RedisSubscriptionManager _groups = new RedisSubscriptionManager();
|
||||
private readonly RedisSubscriptionManager _users = new RedisSubscriptionManager();
|
||||
private IConnectionMultiplexer _redisServerConnection;
|
||||
private ISubscriber _bus;
|
||||
private readonly ILogger _logger;
|
||||
|
|
@ -54,17 +53,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var feature = new RedisFeature();
|
||||
connection.Features.Set<IRedisFeature>(feature);
|
||||
|
||||
var redisSubscriptions = feature.Subscriptions;
|
||||
var connectionTask = Task.CompletedTask;
|
||||
var userTask = Task.CompletedTask;
|
||||
|
||||
_connections.Add(connection);
|
||||
|
||||
connectionTask = SubscribeToConnection(connection, redisSubscriptions);
|
||||
connectionTask = SubscribeToConnection(connection);
|
||||
|
||||
if (!string.IsNullOrEmpty(connection.UserIdentifier))
|
||||
{
|
||||
userTask = SubscribeToUser(connection, redisSubscriptions);
|
||||
userTask = SubscribeToUser(connection);
|
||||
}
|
||||
|
||||
await Task.WhenAll(connectionTask, userTask);
|
||||
|
|
@ -76,18 +74,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
var tasks = new List<Task>();
|
||||
|
||||
var connectionChannel = _channels.Connection(connection.ConnectionId);
|
||||
RedisLog.Unsubscribe(_logger, connectionChannel);
|
||||
tasks.Add(_bus.UnsubscribeAsync(connectionChannel));
|
||||
|
||||
var feature = connection.Features.Get<IRedisFeature>();
|
||||
|
||||
var redisSubscriptions = feature.Subscriptions;
|
||||
if (redisSubscriptions != null)
|
||||
{
|
||||
foreach (var subscription in redisSubscriptions)
|
||||
{
|
||||
RedisLog.Unsubscribe(_logger, subscription);
|
||||
tasks.Add(_bus.UnsubscribeAsync(subscription));
|
||||
}
|
||||
}
|
||||
|
||||
var groupNames = feature.Groups;
|
||||
|
||||
if (groupNames != null)
|
||||
|
|
@ -102,6 +93,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(connection.UserIdentifier))
|
||||
{
|
||||
tasks.Add(RemoveUserAsync(connection));
|
||||
}
|
||||
|
||||
return Task.WhenAll(tasks);
|
||||
}
|
||||
|
||||
|
|
@ -290,25 +286,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
|
||||
var groupChannel = _channels.Group(groupName);
|
||||
var group = _groups.GetOrAdd(groupChannel, _ => new GroupData());
|
||||
|
||||
await group.Lock.WaitAsync();
|
||||
try
|
||||
{
|
||||
group.Connections.Add(connection);
|
||||
|
||||
// Subscribe once
|
||||
if (group.Connections.Count > 1)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
await SubscribeToGroup(groupChannel, group);
|
||||
}
|
||||
finally
|
||||
{
|
||||
group.Lock.Release();
|
||||
}
|
||||
await _groups.AddSubscriptionAsync(groupChannel, connection, SubscribeToGroupAsync);
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -319,10 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
{
|
||||
var groupChannel = _channels.Group(groupName);
|
||||
|
||||
if (!_groups.TryGetValue(groupChannel, out var group))
|
||||
await _groups.RemoveSubscriptionAsync(groupChannel, connection, async channelName =>
|
||||
{
|
||||
return;
|
||||
}
|
||||
RedisLog.Unsubscribe(_logger, channelName);
|
||||
await _bus.UnsubscribeAsync(channelName);
|
||||
});
|
||||
|
||||
var feature = connection.Features.Get<IRedisFeature>();
|
||||
var groupNames = feature.Groups;
|
||||
|
|
@ -333,25 +312,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
groupNames.Remove(groupName);
|
||||
}
|
||||
}
|
||||
|
||||
await group.Lock.WaitAsync();
|
||||
try
|
||||
{
|
||||
if (group.Connections.Count > 0)
|
||||
{
|
||||
group.Connections.Remove(connection);
|
||||
|
||||
if (group.Connections.Count == 0)
|
||||
{
|
||||
RedisLog.Unsubscribe(_logger, groupChannel);
|
||||
await _bus.UnsubscribeAsync(groupChannel);
|
||||
}
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
group.Lock.Release();
|
||||
}
|
||||
}
|
||||
|
||||
private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action)
|
||||
|
|
@ -365,6 +325,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
await ack;
|
||||
}
|
||||
|
||||
private Task RemoveUserAsync(HubConnectionContext connection)
|
||||
{
|
||||
var userChannel = _channels.User(connection.UserIdentifier);
|
||||
|
||||
return _users.RemoveSubscriptionAsync(userChannel, connection, async channelName =>
|
||||
{
|
||||
RedisLog.Unsubscribe(_logger, channelName);
|
||||
await _bus.UnsubscribeAsync(channelName);
|
||||
});
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
_bus?.UnsubscribeAll();
|
||||
|
|
@ -448,10 +419,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
});
|
||||
}
|
||||
|
||||
private Task SubscribeToConnection(HubConnectionContext connection, HashSet<string> redisSubscriptions)
|
||||
private Task SubscribeToConnection(HubConnectionContext connection)
|
||||
{
|
||||
var connectionChannel = _channels.Connection(connection.ConnectionId);
|
||||
redisSubscriptions.Add(connectionChannel);
|
||||
|
||||
RedisLog.Subscribing(_logger, connectionChannel);
|
||||
return _bus.SubscribeAsync(connectionChannel, async (c, data) =>
|
||||
|
|
@ -461,20 +431,35 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
});
|
||||
}
|
||||
|
||||
private Task SubscribeToUser(HubConnectionContext connection, HashSet<string> redisSubscriptions)
|
||||
private Task SubscribeToUser(HubConnectionContext connection)
|
||||
{
|
||||
var userChannel = _channels.User(connection.UserIdentifier);
|
||||
redisSubscriptions.Add(userChannel);
|
||||
|
||||
// TODO: Look at optimizing (looping over connections checking for Name)
|
||||
return _bus.SubscribeAsync(userChannel, async (c, data) =>
|
||||
return _users.AddSubscriptionAsync(userChannel, connection, async (channelName, subscriptions) =>
|
||||
{
|
||||
var invocation = _protocol.ReadInvocation((byte[])data);
|
||||
await connection.WriteAsync(invocation.Message);
|
||||
await _bus.SubscribeAsync(channelName, async (c, data) =>
|
||||
{
|
||||
try
|
||||
{
|
||||
var invocation = _protocol.ReadInvocation((byte[])data);
|
||||
|
||||
var tasks = new List<Task>();
|
||||
foreach (var userConnection in subscriptions)
|
||||
{
|
||||
tasks.Add(userConnection.WriteAsync(invocation.Message).AsTask());
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
RedisLog.FailedWritingMessage(_logger, ex);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
private Task SubscribeToGroup(string groupChannel, GroupData group)
|
||||
private Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore groupConnections)
|
||||
{
|
||||
RedisLog.Subscribing(_logger, groupChannel);
|
||||
return _bus.SubscribeAsync(groupChannel, async (c, data) =>
|
||||
|
|
@ -484,7 +469,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var invocation = _protocol.ReadInvocation((byte[])data);
|
||||
|
||||
var tasks = new List<Task>();
|
||||
foreach (var groupConnection in group.Connections)
|
||||
foreach (var groupConnection in groupConnections)
|
||||
{
|
||||
if (invocation.ExcludedConnectionIds?.Contains(groupConnection.ConnectionId) == true)
|
||||
{
|
||||
|
|
@ -515,6 +500,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var writer = new LoggerTextWriter(_logger);
|
||||
_redisServerConnection = await _options.ConnectAsync(writer);
|
||||
_bus = _redisServerConnection.GetSubscriber();
|
||||
|
||||
_redisServerConnection.ConnectionRestored += (_, e) =>
|
||||
{
|
||||
// We use the subscription connection type
|
||||
|
|
@ -589,21 +575,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
private class GroupData
|
||||
{
|
||||
public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
|
||||
public readonly HubConnectionStore Connections = new HubConnectionStore();
|
||||
}
|
||||
|
||||
private interface IRedisFeature
|
||||
{
|
||||
HashSet<string> Subscriptions { get; }
|
||||
HashSet<string> Groups { get; }
|
||||
}
|
||||
|
||||
private class RedisFeature : IRedisFeature
|
||||
{
|
||||
public HashSet<string> Subscriptions { get; } = new HashSet<string>();
|
||||
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
return Clients.Group(groupName).SendAsync("Echo", message);
|
||||
}
|
||||
|
||||
public Task EchoUser(string userName, string message)
|
||||
{
|
||||
return Clients.User(userName).SendAsync("Echo", message);
|
||||
}
|
||||
|
||||
public Task AddSelfToGroup(string groupName)
|
||||
{
|
||||
return Groups.AddToGroupAsync(Context.ConnectionId, groupName);
|
||||
|
|
|
|||
|
|
@ -11,7 +11,6 @@ using Microsoft.AspNetCore.SignalR.Tests;
|
|||
using Microsoft.AspNetCore.Testing.xunit;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Logging.Testing;
|
||||
using Xunit;
|
||||
using Xunit.Abstractions;
|
||||
|
||||
|
|
@ -39,7 +38,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
_serverFixture = serverFixture;
|
||||
}
|
||||
|
||||
[ConditionalTheory()]
|
||||
[ConditionalTheory]
|
||||
[SkipIfDockerNotPresent]
|
||||
[MemberData(nameof(TransportTypesAndProtocolTypes))]
|
||||
public async Task HubConnectionCanSendAndReceiveMessages(HttpTransportType transportType, string protocolName)
|
||||
|
|
@ -60,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[ConditionalTheory()]
|
||||
[ConditionalTheory]
|
||||
[SkipIfDockerNotPresent]
|
||||
[MemberData(nameof(TransportTypesAndProtocolTypes))]
|
||||
public async Task HubConnectionCanSendAndReceiveGroupMessages(HttpTransportType transportType, string protocolName)
|
||||
|
|
@ -93,11 +92,77 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory)
|
||||
[ConditionalTheory]
|
||||
[SkipIfDockerNotPresent]
|
||||
[MemberData(nameof(TransportTypesAndProtocolTypes))]
|
||||
public async Task CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser(HttpTransportType transportType, string protocolName)
|
||||
{
|
||||
using (StartVerifiableLog(out var loggerFactory, testName:
|
||||
$"{nameof(CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser)}_{transportType.ToString()}_{protocolName}"))
|
||||
{
|
||||
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
|
||||
|
||||
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
|
||||
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
|
||||
|
||||
var tcs = new TaskCompletionSource<string>();
|
||||
connection.On<string>("Echo", message => tcs.TrySetResult(message));
|
||||
var tcs2 = new TaskCompletionSource<string>();
|
||||
secondConnection.On<string>("Echo", message => tcs2.TrySetResult(message));
|
||||
|
||||
await secondConnection.StartAsync().OrTimeout();
|
||||
await connection.StartAsync().OrTimeout();
|
||||
await connection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout();
|
||||
|
||||
Assert.Equal("Hello, World!", await tcs.Task.OrTimeout());
|
||||
Assert.Equal("Hello, World!", await tcs2.Task.OrTimeout());
|
||||
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
await secondConnection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[ConditionalTheory]
|
||||
[SkipIfDockerNotPresent]
|
||||
[MemberData(nameof(TransportTypesAndProtocolTypes))]
|
||||
public async Task CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects(HttpTransportType transportType, string protocolName)
|
||||
{
|
||||
// Regression test:
|
||||
// When multiple connections from the same user were connected and one left, it used to unsubscribe from the user channel
|
||||
// Now we keep track of users connections and only unsubscribe when no users are listening
|
||||
using (StartVerifiableLog(out var loggerFactory, testName:
|
||||
$"{nameof(CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects)}_{transportType.ToString()}_{protocolName}"))
|
||||
{
|
||||
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
|
||||
|
||||
var firstConnection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
|
||||
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
|
||||
|
||||
var tcs = new TaskCompletionSource<string>();
|
||||
firstConnection.On<string>("Echo", message => tcs.TrySetResult(message));
|
||||
|
||||
await secondConnection.StartAsync().OrTimeout();
|
||||
await firstConnection.StartAsync().OrTimeout();
|
||||
await secondConnection.DisposeAsync().OrTimeout();
|
||||
await firstConnection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout();
|
||||
|
||||
Assert.Equal("Hello, World!", await tcs.Task.OrTimeout());
|
||||
|
||||
await firstConnection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
|
||||
{
|
||||
var hubConnectionBuilder = new HubConnectionBuilder()
|
||||
.WithLoggerFactory(loggerFactory)
|
||||
.WithUrl(url, transportType);
|
||||
.WithUrl(url, transportType, httpConnectionOptions =>
|
||||
{
|
||||
if (!string.IsNullOrEmpty(userName))
|
||||
{
|
||||
httpConnectionOptions.Headers["UserName"] = userName;
|
||||
}
|
||||
});
|
||||
|
||||
hubConnectionBuilder.Services.AddSingleton(protocol);
|
||||
|
||||
|
|
|
|||
|
|
@ -500,6 +500,61 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeUserSendsToAllConnectionsForUser()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
using (var client3 = new TestClient())
|
||||
{
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "userA");
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "userA");
|
||||
var connection3 = HubConnectionContextUtils.Create(client3.Connection, userIdentifier: "userB");
|
||||
|
||||
await manager.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection2).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection3).OrTimeout();
|
||||
|
||||
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
|
||||
await AssertMessageAsync(client1);
|
||||
await AssertMessageAsync(client2);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StillSubscribedToUserAfterOneOfMultipleConnectionsAssociatedWithUserDisconnects()
|
||||
{
|
||||
var server = new TestRedisServer();
|
||||
|
||||
var manager = CreateLifetimeManager(server);
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
using (var client3 = new TestClient())
|
||||
{
|
||||
var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "userA");
|
||||
var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "userA");
|
||||
var connection3 = HubConnectionContextUtils.Create(client3.Connection, userIdentifier: "userB");
|
||||
|
||||
await manager.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection2).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection3).OrTimeout();
|
||||
|
||||
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
|
||||
await AssertMessageAsync(client1);
|
||||
await AssertMessageAsync(client2);
|
||||
|
||||
// Disconnect one connection for the user
|
||||
await manager.OnDisconnectedAsync(connection1).OrTimeout();
|
||||
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
|
||||
await AssertMessageAsync(client2);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Primitives;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||
{
|
||||
|
|
@ -21,11 +22,28 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
// We start the servers before starting redis so we want to time them out ASAP
|
||||
options.Configuration.ConnectTimeout = 1;
|
||||
});
|
||||
|
||||
services.AddSingleton<IUserIdProvider, UserNameIdProvider>();
|
||||
}
|
||||
|
||||
public void Configure(IApplicationBuilder app, IHostingEnvironment env)
|
||||
{
|
||||
app.UseSignalR(options => options.MapHub<EchoHub>("/echo"));
|
||||
}
|
||||
|
||||
private class UserNameIdProvider : IUserIdProvider
|
||||
{
|
||||
public string GetUserId(HubConnectionContext connection)
|
||||
{
|
||||
// This is an AWFUL way to authenticate users! We're just using it for test purposes.
|
||||
var userNameHeader = connection.GetHttpContext().Request.Headers["UserName"];
|
||||
if (!StringValues.IsNullOrEmpty(userNameHeader))
|
||||
{
|
||||
return userNameHeader;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
public TransferFormat ActiveFormat { get; set; }
|
||||
|
||||
public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false)
|
||||
public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null)
|
||||
{
|
||||
var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false);
|
||||
var pair = DuplexPipe.CreateConnectionPair(options, options);
|
||||
|
|
@ -44,9 +44,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
var claimValue = Interlocked.Increment(ref _id).ToString();
|
||||
var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) };
|
||||
if (addClaimId)
|
||||
if (userIdentifier != null)
|
||||
{
|
||||
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
|
||||
claims.Add(new Claim(ClaimTypes.NameIdentifier, userIdentifier));
|
||||
}
|
||||
|
||||
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));
|
||||
|
|
|
|||
|
|
@ -1100,9 +1100,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType);
|
||||
|
||||
using (var firstClient = new TestClient(addClaimId: true))
|
||||
using (var secondClient = new TestClient(addClaimId: true))
|
||||
using (var thirdClient = new TestClient(addClaimId: true))
|
||||
using (var firstClient = new TestClient(userIdentifier: "userA"))
|
||||
using (var secondClient = new TestClient(userIdentifier: "userB"))
|
||||
using (var thirdClient = new TestClient(userIdentifier: "userC"))
|
||||
{
|
||||
var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler);
|
||||
var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler);
|
||||
|
|
@ -1110,10 +1110,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout();
|
||||
|
||||
var secondAndThirdClients = new HashSet<string> {secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value,
|
||||
thirdClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value };
|
||||
|
||||
await firstClient.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), secondAndThirdClients, "Second and Third").OrTimeout();
|
||||
await firstClient.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), new[] { "userB", "userC" }, "Second and Third").OrTimeout();
|
||||
|
||||
var secondClientResult = await secondClient.ReadAsync().OrTimeout();
|
||||
var invocation = Assert.IsType<InvocationMessage>(secondClientResult);
|
||||
|
|
@ -1344,15 +1341,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType);
|
||||
|
||||
using (var firstClient = new TestClient(addClaimId: true))
|
||||
using (var secondClient = new TestClient(addClaimId: true))
|
||||
using (var firstClient = new TestClient(userIdentifier: "userA"))
|
||||
using (var secondClient = new TestClient(userIdentifier: "userB"))
|
||||
{
|
||||
var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler);
|
||||
var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler);
|
||||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
|
||||
|
||||
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();
|
||||
await firstClient.SendInvocationAsync("ClientSendMethod", "userB", "test").OrTimeout();
|
||||
|
||||
// check that 'secondConnection' has received the group send
|
||||
var hubMessage = await secondClient.ReadAsync().OrTimeout();
|
||||
|
|
|
|||
Loading…
Reference in New Issue