Merge branch 'release/2.1' into dev

This commit is contained in:
BrennanConroy 2018-05-15 16:07:57 -07:00
commit 9cb683a41d
8 changed files with 275 additions and 94 deletions

View File

@ -0,0 +1,63 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
internal class RedisSubscriptionManager
{
private readonly ConcurrentDictionary<string, HubConnectionStore> _subscriptions = new ConcurrentDictionary<string, HubConnectionStore>(StringComparer.Ordinal);
private readonly SemaphoreSlim _lock = new SemaphoreSlim(1, 1);
public async Task AddSubscriptionAsync(string id, HubConnectionContext connection, Func<string, HubConnectionStore, Task> subscribeMethod)
{
await _lock.WaitAsync();
try
{
var subscription = _subscriptions.GetOrAdd(id, _ => new HubConnectionStore());
subscription.Add(connection);
// Subscribe once
if (subscription.Count == 1)
{
await subscribeMethod(id, subscription);
}
}
finally
{
_lock.Release();
}
}
public async Task RemoveSubscriptionAsync(string id, HubConnectionContext connection, Func<string, Task> unsubscribeMethod)
{
await _lock.WaitAsync();
try
{
if (!_subscriptions.TryGetValue(id, out var subscription))
{
return;
}
subscription.Remove(connection);
if (subscription.Count == 0)
{
_subscriptions.TryRemove(id, out _);
await unsubscribeMethod(id);
}
}
finally
{
_lock.Release();
}
}
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
@ -20,8 +19,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable where THub : Hub
{
private readonly HubConnectionStore _connections = new HubConnectionStore();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>(StringComparer.Ordinal);
private readonly RedisSubscriptionManager _groups = new RedisSubscriptionManager();
private readonly RedisSubscriptionManager _users = new RedisSubscriptionManager();
private IConnectionMultiplexer _redisServerConnection;
private ISubscriber _bus;
private readonly ILogger _logger;
@ -54,17 +53,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var feature = new RedisFeature();
connection.Features.Set<IRedisFeature>(feature);
var redisSubscriptions = feature.Subscriptions;
var connectionTask = Task.CompletedTask;
var userTask = Task.CompletedTask;
_connections.Add(connection);
connectionTask = SubscribeToConnection(connection, redisSubscriptions);
connectionTask = SubscribeToConnection(connection);
if (!string.IsNullOrEmpty(connection.UserIdentifier))
{
userTask = SubscribeToUser(connection, redisSubscriptions);
userTask = SubscribeToUser(connection);
}
await Task.WhenAll(connectionTask, userTask);
@ -76,18 +74,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var tasks = new List<Task>();
var connectionChannel = _channels.Connection(connection.ConnectionId);
RedisLog.Unsubscribe(_logger, connectionChannel);
tasks.Add(_bus.UnsubscribeAsync(connectionChannel));
var feature = connection.Features.Get<IRedisFeature>();
var redisSubscriptions = feature.Subscriptions;
if (redisSubscriptions != null)
{
foreach (var subscription in redisSubscriptions)
{
RedisLog.Unsubscribe(_logger, subscription);
tasks.Add(_bus.UnsubscribeAsync(subscription));
}
}
var groupNames = feature.Groups;
if (groupNames != null)
@ -102,6 +93,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
if (!string.IsNullOrEmpty(connection.UserIdentifier))
{
tasks.Add(RemoveUserAsync(connection));
}
return Task.WhenAll(tasks);
}
@ -290,25 +286,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
var groupChannel = _channels.Group(groupName);
var group = _groups.GetOrAdd(groupChannel, _ => new GroupData());
await group.Lock.WaitAsync();
try
{
group.Connections.Add(connection);
// Subscribe once
if (group.Connections.Count > 1)
{
return;
}
await SubscribeToGroup(groupChannel, group);
}
finally
{
group.Lock.Release();
}
await _groups.AddSubscriptionAsync(groupChannel, connection, SubscribeToGroupAsync);
}
/// <summary>
@ -319,10 +297,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
var groupChannel = _channels.Group(groupName);
if (!_groups.TryGetValue(groupChannel, out var group))
await _groups.RemoveSubscriptionAsync(groupChannel, connection, async channelName =>
{
return;
}
RedisLog.Unsubscribe(_logger, channelName);
await _bus.UnsubscribeAsync(channelName);
});
var feature = connection.Features.Get<IRedisFeature>();
var groupNames = feature.Groups;
@ -333,25 +312,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis
groupNames.Remove(groupName);
}
}
await group.Lock.WaitAsync();
try
{
if (group.Connections.Count > 0)
{
group.Connections.Remove(connection);
if (group.Connections.Count == 0)
{
RedisLog.Unsubscribe(_logger, groupChannel);
await _bus.UnsubscribeAsync(groupChannel);
}
}
}
finally
{
group.Lock.Release();
}
}
private async Task SendGroupActionAndWaitForAck(string connectionId, string groupName, GroupAction action)
@ -365,6 +325,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await ack;
}
private Task RemoveUserAsync(HubConnectionContext connection)
{
var userChannel = _channels.User(connection.UserIdentifier);
return _users.RemoveSubscriptionAsync(userChannel, connection, async channelName =>
{
RedisLog.Unsubscribe(_logger, channelName);
await _bus.UnsubscribeAsync(channelName);
});
}
public void Dispose()
{
_bus?.UnsubscribeAll();
@ -448,10 +419,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
});
}
private Task SubscribeToConnection(HubConnectionContext connection, HashSet<string> redisSubscriptions)
private Task SubscribeToConnection(HubConnectionContext connection)
{
var connectionChannel = _channels.Connection(connection.ConnectionId);
redisSubscriptions.Add(connectionChannel);
RedisLog.Subscribing(_logger, connectionChannel);
return _bus.SubscribeAsync(connectionChannel, async (c, data) =>
@ -461,20 +431,35 @@ namespace Microsoft.AspNetCore.SignalR.Redis
});
}
private Task SubscribeToUser(HubConnectionContext connection, HashSet<string> redisSubscriptions)
private Task SubscribeToUser(HubConnectionContext connection)
{
var userChannel = _channels.User(connection.UserIdentifier);
redisSubscriptions.Add(userChannel);
// TODO: Look at optimizing (looping over connections checking for Name)
return _bus.SubscribeAsync(userChannel, async (c, data) =>
return _users.AddSubscriptionAsync(userChannel, connection, async (channelName, subscriptions) =>
{
var invocation = _protocol.ReadInvocation((byte[])data);
await connection.WriteAsync(invocation.Message);
await _bus.SubscribeAsync(channelName, async (c, data) =>
{
try
{
var invocation = _protocol.ReadInvocation((byte[])data);
var tasks = new List<Task>();
foreach (var userConnection in subscriptions)
{
tasks.Add(userConnection.WriteAsync(invocation.Message).AsTask());
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
RedisLog.FailedWritingMessage(_logger, ex);
}
});
});
}
private Task SubscribeToGroup(string groupChannel, GroupData group)
private Task SubscribeToGroupAsync(string groupChannel, HubConnectionStore groupConnections)
{
RedisLog.Subscribing(_logger, groupChannel);
return _bus.SubscribeAsync(groupChannel, async (c, data) =>
@ -484,7 +469,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var invocation = _protocol.ReadInvocation((byte[])data);
var tasks = new List<Task>();
foreach (var groupConnection in group.Connections)
foreach (var groupConnection in groupConnections)
{
if (invocation.ExcludedConnectionIds?.Contains(groupConnection.ConnectionId) == true)
{
@ -515,6 +500,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var writer = new LoggerTextWriter(_logger);
_redisServerConnection = await _options.ConnectAsync(writer);
_bus = _redisServerConnection.GetSubscriber();
_redisServerConnection.ConnectionRestored += (_, e) =>
{
// We use the subscription connection type
@ -589,21 +575,13 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
private class GroupData
{
public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public readonly HubConnectionStore Connections = new HubConnectionStore();
}
private interface IRedisFeature
{
HashSet<string> Subscriptions { get; }
HashSet<string> Groups { get; }
}
private class RedisFeature : IRedisFeature
{
public HashSet<string> Subscriptions { get; } = new HashSet<string>();
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}
}

View File

@ -18,6 +18,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
return Clients.Group(groupName).SendAsync("Echo", message);
}
public Task EchoUser(string userName, string message)
{
return Clients.User(userName).SendAsync("Echo", message);
}
public Task AddSelfToGroup(string groupName)
{
return Groups.AddToGroupAsync(Context.ConnectionId, groupName);

View File

@ -11,7 +11,6 @@ using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Xunit;
using Xunit.Abstractions;
@ -39,7 +38,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
_serverFixture = serverFixture;
}
[ConditionalTheory()]
[ConditionalTheory]
[SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task HubConnectionCanSendAndReceiveMessages(HttpTransportType transportType, string protocolName)
@ -60,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
[ConditionalTheory()]
[ConditionalTheory]
[SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task HubConnectionCanSendAndReceiveGroupMessages(HttpTransportType transportType, string protocolName)
@ -93,11 +92,77 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory)
[ConditionalTheory]
[SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser(HttpTransportType transportType, string protocolName)
{
using (StartVerifiableLog(out var loggerFactory, testName:
$"{nameof(CanSendAndReceiveUserMessagesFromMultipleConnectionsWithSameUser)}_{transportType.ToString()}_{protocolName}"))
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
var connection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var tcs = new TaskCompletionSource<string>();
connection.On<string>("Echo", message => tcs.TrySetResult(message));
var tcs2 = new TaskCompletionSource<string>();
secondConnection.On<string>("Echo", message => tcs2.TrySetResult(message));
await secondConnection.StartAsync().OrTimeout();
await connection.StartAsync().OrTimeout();
await connection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout();
Assert.Equal("Hello, World!", await tcs.Task.OrTimeout());
Assert.Equal("Hello, World!", await tcs2.Task.OrTimeout());
await connection.DisposeAsync().OrTimeout();
await secondConnection.DisposeAsync().OrTimeout();
}
}
[ConditionalTheory]
[SkipIfDockerNotPresent]
[MemberData(nameof(TransportTypesAndProtocolTypes))]
public async Task CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects(HttpTransportType transportType, string protocolName)
{
// Regression test:
// When multiple connections from the same user were connected and one left, it used to unsubscribe from the user channel
// Now we keep track of users connections and only unsubscribe when no users are listening
using (StartVerifiableLog(out var loggerFactory, testName:
$"{nameof(CanSendAndReceiveUserMessagesWhenOneConnectionWithUserDisconnects)}_{transportType.ToString()}_{protocolName}"))
{
var protocol = HubProtocolHelpers.GetHubProtocol(protocolName);
var firstConnection = CreateConnection(_serverFixture.FirstServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var secondConnection = CreateConnection(_serverFixture.SecondServer.Url + "/echo", transportType, protocol, loggerFactory, userName: "userA");
var tcs = new TaskCompletionSource<string>();
firstConnection.On<string>("Echo", message => tcs.TrySetResult(message));
await secondConnection.StartAsync().OrTimeout();
await firstConnection.StartAsync().OrTimeout();
await secondConnection.DisposeAsync().OrTimeout();
await firstConnection.InvokeAsync("EchoUser", "userA", "Hello, World!").OrTimeout();
Assert.Equal("Hello, World!", await tcs.Task.OrTimeout());
await firstConnection.DisposeAsync().OrTimeout();
}
}
private static HubConnection CreateConnection(string url, HttpTransportType transportType, IHubProtocol protocol, ILoggerFactory loggerFactory, string userName = null)
{
var hubConnectionBuilder = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
.WithUrl(url, transportType);
.WithUrl(url, transportType, httpConnectionOptions =>
{
if (!string.IsNullOrEmpty(userName))
{
httpConnectionOptions.Headers["UserName"] = userName;
}
});
hubConnectionBuilder.Services.AddSingleton(protocol);

View File

@ -500,6 +500,61 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
[Fact]
public async Task InvokeUserSendsToAllConnectionsForUser()
{
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
using (var client3 = new TestClient())
{
var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "userA");
var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "userA");
var connection3 = HubConnectionContextUtils.Create(client3.Connection, userIdentifier: "userB");
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.OnConnectedAsync(connection3).OrTimeout();
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client1);
await AssertMessageAsync(client2);
}
}
[Fact]
public async Task StillSubscribedToUserAfterOneOfMultipleConnectionsAssociatedWithUserDisconnects()
{
var server = new TestRedisServer();
var manager = CreateLifetimeManager(server);
using (var client1 = new TestClient())
using (var client2 = new TestClient())
using (var client3 = new TestClient())
{
var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "userA");
var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "userA");
var connection3 = HubConnectionContextUtils.Create(client3.Connection, userIdentifier: "userB");
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.OnConnectedAsync(connection3).OrTimeout();
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client1);
await AssertMessageAsync(client2);
// Disconnect one connection for the user
await manager.OnDisconnectedAsync(connection1).OrTimeout();
await manager.SendUserAsync("userA", "Hello", new object[] { "World" }).OrTimeout();
await AssertMessageAsync(client2);
}
}
[Fact]
public async Task CamelCasedJsonIsPreservedAcrossRedisBoundary()
{

View File

@ -4,6 +4,7 @@
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
{
@ -21,11 +22,28 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
// We start the servers before starting redis so we want to time them out ASAP
options.Configuration.ConnectTimeout = 1;
});
services.AddSingleton<IUserIdProvider, UserNameIdProvider>();
}
public void Configure(IApplicationBuilder app, IHostingEnvironment env)
{
app.UseSignalR(options => options.MapHub<EchoHub>("/echo"));
}
private class UserNameIdProvider : IUserIdProvider
{
public string GetUserId(HubConnectionContext connection)
{
// This is an AWFUL way to authenticate users! We're just using it for test purposes.
var userNameHeader = connection.GetHttpContext().Request.Headers["UserName"];
if (!StringValues.IsNullOrEmpty(userNameHeader))
{
return userNameHeader;
}
return null;
}
}
}
}

View File

@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public TransferFormat ActiveFormat { get; set; }
public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, bool addClaimId = false)
public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null)
{
var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false);
var pair = DuplexPipe.CreateConnectionPair(options, options);
@ -44,9 +44,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var claimValue = Interlocked.Increment(ref _id).ToString();
var claims = new List<Claim> { new Claim(ClaimTypes.Name, claimValue) };
if (addClaimId)
if (userIdentifier != null)
{
claims.Add(new Claim(ClaimTypes.NameIdentifier, claimValue));
claims.Add(new Claim(ClaimTypes.NameIdentifier, userIdentifier));
}
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(claims));

View File

@ -1100,9 +1100,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType);
using (var firstClient = new TestClient(addClaimId: true))
using (var secondClient = new TestClient(addClaimId: true))
using (var thirdClient = new TestClient(addClaimId: true))
using (var firstClient = new TestClient(userIdentifier: "userA"))
using (var secondClient = new TestClient(userIdentifier: "userB"))
using (var thirdClient = new TestClient(userIdentifier: "userC"))
{
var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler);
var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler);
@ -1110,10 +1110,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout();
var secondAndThirdClients = new HashSet<string> {secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value,
thirdClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value };
await firstClient.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), secondAndThirdClients, "Second and Third").OrTimeout();
await firstClient.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), new[] { "userB", "userC" }, "Second and Third").OrTimeout();
var secondClientResult = await secondClient.ReadAsync().OrTimeout();
var invocation = Assert.IsType<InvocationMessage>(secondClientResult);
@ -1344,15 +1341,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(hubType);
using (var firstClient = new TestClient(addClaimId: true))
using (var secondClient = new TestClient(addClaimId: true))
using (var firstClient = new TestClient(userIdentifier: "userA"))
using (var secondClient = new TestClient(userIdentifier: "userB"))
{
var firstConnectionHandlerTask = await firstClient.ConnectAsync(connectionHandler);
var secondConnectionHandlerTask = await secondClient.ConnectAsync(connectionHandler);
await Task.WhenAll(firstClient.Connected, secondClient.Connected).OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", secondClient.Connection.User.FindFirst(ClaimTypes.NameIdentifier)?.Value, "test").OrTimeout();
await firstClient.SendInvocationAsync("ClientSendMethod", "userB", "test").OrTimeout();
// check that 'secondConnection' has received the group send
var hubMessage = await secondClient.ReadAsync().OrTimeout();