From d469cc3151454fdf4913670e6775af9e0a60e8e8 Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Tue, 22 Aug 2017 17:33:27 -0700 Subject: [PATCH] Clients Subset - AllExcept (#700) --- .../ChatSample/PresenceHubLifetimeManager.cs | 8 +- samples/SocketsSample/Hubs/Chat.cs | 1 + .../Internal/Protocol/InvocationMessage.cs | 1 + .../RedisHubLifetimeManager.cs | 81 ++++++++++++++----- .../DefaultHubLifetimeManager.cs | 9 +++ .../HubContext.cs | 7 ++ .../HubLifetimeManager.cs | 3 + .../IHubClients.cs | 4 + src/Microsoft.AspNetCore.SignalR/Proxies.cs | 19 ++++- .../HubEndpointTests.cs | 49 +++++++++++ .../TestClient.cs | 1 - 11 files changed, 160 insertions(+), 23 deletions(-) diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index 841eeaa907..ac3c61f66a 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -2,12 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.AspNetCore.SignalR.Redis; -using System.Linq; namespace ChatSample { @@ -141,6 +142,11 @@ namespace ChatSample return _wrappedHubLifetimeManager.InvokeAllAsync(methodName, args); } + public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedIds) + { + return _wrappedHubLifetimeManager.InvokeAllExceptAsync(methodName, args, excludedIds); + } + public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args) { return _wrappedHubLifetimeManager.InvokeConnectionAsync(connectionId, methodName, args); diff --git a/samples/SocketsSample/Hubs/Chat.cs b/samples/SocketsSample/Hubs/Chat.cs index 29f09f1ca8..a01e1dec00 100644 --- a/samples/SocketsSample/Hubs/Chat.cs +++ b/samples/SocketsSample/Hubs/Chat.cs @@ -3,6 +3,7 @@ using System; using System.Threading.Tasks; +using System.Collections.Generic; using Microsoft.AspNetCore.SignalR; namespace SocketsSample.Hubs diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs index 4f8a4c738f..fa941b3fe4 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/InvocationMessage.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Linq; namespace Microsoft.AspNetCore.SignalR.Internal.Protocol diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 1771594675..fcc40935be 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -26,6 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis private readonly ISubscriber _bus; private readonly ILogger _logger; private readonly RedisOptions _options; + private readonly string _channelNamePrefix = typeof(THub).FullName; // This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection. private readonly JsonSerializer _serializer = new JsonSerializer @@ -60,7 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis var previousBroadcastTask = Task.CompletedTask; - var channelName = typeof(THub).FullName; + var channelName = _channelNamePrefix; _logger.LogInformation("Subscribing to channel: {channel}", channelName); _bus.Subscribe(channelName, async (c, data) => { @@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis _logger.LogTrace("Received message from redis channel {channel}", channelName); - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf var tasks = new List(_connections.Count); @@ -80,34 +81,67 @@ namespace Microsoft.AspNetCore.SignalR.Redis previousBroadcastTask = Task.WhenAll(tasks); }); + + var allExceptTask = Task.CompletedTask; + channelName = _channelNamePrefix + ".AllExcept"; + _logger.LogInformation("Subscribing to channel: {channel}", channelName); + _bus.Subscribe(channelName, async (c, data) => + { + await allExceptTask; + + _logger.LogTrace("Received message from redis channel {channel}", channelName); + + var message = DeserializeMessage(data); + var excludedIds = message.ExcludedIds; + + // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf + + var tasks = new List(_connections.Count); + + foreach (var connection in _connections) + { + if (!excludedIds.Contains(connection.ConnectionId)) + { + tasks.Add(WriteAsync(connection, message)); + } + } + + allExceptTask = Task.WhenAll(tasks); + }); } public override Task InvokeAllAsync(string methodName, object[] args) { var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); - return PublishAsync(typeof(THub).FullName, message); + return PublishAsync(_channelNamePrefix, message); + } + + public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedIds) + { + var message = new RedisExcludeClientsMessage(GetInvocationId(), nonBlocking: true, target: methodName, excludedIds: excludedIds, arguments: args); + return PublishAsync(_channelNamePrefix + ".AllExcept", message); } public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args) { var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); - return PublishAsync(typeof(THub).FullName + "." + connectionId, message); + return PublishAsync(_channelNamePrefix + "." + connectionId, message); } public override Task InvokeGroupAsync(string groupName, string methodName, object[] args) { var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); - return PublishAsync(typeof(THub).FullName + ".group." + groupName, message); + return PublishAsync(_channelNamePrefix + ".group." + groupName, message); } public override Task InvokeUserAsync(string userId, string methodName, object[] args) { var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); - return PublishAsync(typeof(THub).FullName + ".user." + userId, message); + return PublishAsync(_channelNamePrefix + ".user." + userId, message); } private async Task PublishAsync(string channel, HubMessage hubMessage) @@ -136,7 +170,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis _connections.Add(connection); - var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId; + var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId; redisSubscriptions.Add(connectionChannel); var previousConnectionTask = Task.CompletedTask; @@ -146,15 +180,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis { await previousConnectionTask; - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); previousConnectionTask = WriteAsync(connection, message); }); - if (connection.User.Identity.IsAuthenticated) { - var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name; + var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name; redisSubscriptions.Add(userChannel); var previousUserTask = Task.CompletedTask; @@ -164,7 +197,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { await previousUserTask; - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); previousUserTask = WriteAsync(connection, message); }); @@ -208,7 +241,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override async Task AddGroupAsync(string connectionId, string groupName) { - var groupChannel = typeof(THub).FullName + ".group." + groupName; + var groupChannel = _channelNamePrefix + ".group." + groupName; var connection = _connections[connectionId]; if (connection == null) { @@ -246,7 +279,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis // want to do concurrent writes to the outgoing connections await previousTask; - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); var tasks = new List(group.Connections.Count); foreach (var groupConnection in group.Connections) @@ -265,7 +298,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override async Task RemoveGroupAsync(string connectionId, string groupName) { - var groupChannel = typeof(THub).FullName + ".group." + groupName; + var groupChannel = _channelNamePrefix + ".group." + groupName; GroupData group; if (!_groups.TryGetValue(groupChannel, out group)) @@ -329,15 +362,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis return invocationId.ToString(); } - private HubMessage DeserializeMessage(RedisValue data) + private T DeserializeMessage(RedisValue data) { - HubMessage message; - using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data)))) + using (var reader = new JsonTextReader(new StreamReader(new MemoryStream(data)))) { - message = (HubMessage)_serializer.Deserialize(reader); + return (T)_serializer.Deserialize(reader); } - - return message; } private class LoggerTextWriter : TextWriter @@ -362,6 +392,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } + public class RedisExcludeClientsMessage : InvocationMessage + { + public IReadOnlyList ExcludedIds; + + public RedisExcludeClientsMessage(string invocationId, bool nonBlocking, string target, IReadOnlyList excludedIds, params object[] arguments) + : base(invocationId, nonBlocking, target, arguments) + { + ExcludedIds = excludedIds; + } + } + private class GroupData { public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 18c9c80c9c..9dadfc8986 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; @@ -140,6 +141,14 @@ namespace Microsoft.AspNetCore.SignalR return invocationId.ToString(); } + public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedIds) + { + return InvokeAllWhere(methodName, args, connection => + { + return !excludedIds.Contains(connection.ConnectionId); + }); + } + private interface IHubGroupsFeature { HashSet Groups { get; } diff --git a/src/Microsoft.AspNetCore.SignalR/HubContext.cs b/src/Microsoft.AspNetCore.SignalR/HubContext.cs index c527174b6d..8fbcf2d0bc 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubContext.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; + namespace Microsoft.AspNetCore.SignalR { public class HubContext : IHubContext, IHubClients where THub : Hub @@ -20,6 +22,11 @@ namespace Microsoft.AspNetCore.SignalR public virtual IGroupManager Groups { get; } + public IClientProxy AllExcept(IReadOnlyList excludedIds) + { + return new AllClientsExceptProxy(_lifetimeManager, excludedIds); + } + public virtual IClientProxy Client(string connectionId) { return new SingleClientProxy(_lifetimeManager, connectionId); diff --git a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs index 0879fa3227..9e57825636 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR @@ -13,6 +14,8 @@ namespace Microsoft.AspNetCore.SignalR public abstract Task InvokeAllAsync(string methodName, object[] args); + public abstract Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedIds); + public abstract Task InvokeConnectionAsync(string connectionId, string methodName, object[] args); public abstract Task InvokeGroupAsync(string groupName, string methodName, object[] args); diff --git a/src/Microsoft.AspNetCore.SignalR/IHubClients.cs b/src/Microsoft.AspNetCore.SignalR/IHubClients.cs index 7721dc1c88..823a9db9d2 100644 --- a/src/Microsoft.AspNetCore.SignalR/IHubClients.cs +++ b/src/Microsoft.AspNetCore.SignalR/IHubClients.cs @@ -1,12 +1,16 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; + namespace Microsoft.AspNetCore.SignalR { public interface IHubClients { IClientProxy All { get; } + IClientProxy AllExcept(IReadOnlyList excludedIds); + IClientProxy Client(string connectionId); IClientProxy Group(string groupName); diff --git a/src/Microsoft.AspNetCore.SignalR/Proxies.cs b/src/Microsoft.AspNetCore.SignalR/Proxies.cs index 048ee43ba9..5d4672fc53 100644 --- a/src/Microsoft.AspNetCore.SignalR/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR/Proxies.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR @@ -54,12 +55,28 @@ namespace Microsoft.AspNetCore.SignalR } } + public class AllClientsExceptProxy : IClientProxy + { + private readonly HubLifetimeManager _lifetimeManager; + private IReadOnlyList _excludedIds; + + public AllClientsExceptProxy(HubLifetimeManager lifetimeManager, IReadOnlyList excludedIds) + { + _lifetimeManager = lifetimeManager; + _excludedIds = excludedIds; + } + + public Task InvokeAsync(string method, params object[] args) + { + return _lifetimeManager.InvokeAllExceptAsync(method, args, _excludedIds); + } + } + public class SingleClientProxy : IClientProxy { private readonly string _connectionId; private readonly HubLifetimeManager _lifetimeManager; - public SingleClientProxy(HubLifetimeManager lifetimeManager, string connectionId) { _lifetimeManager = lifetimeManager; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 8dee3a565a..4fbaf6d2a1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -511,6 +511,50 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task SendToAllExcept() + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var firstClient = new TestClient()) + using (var secondClient = new TestClient()) + using (var thirdClient = new TestClient()) + { + Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection); + Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection); + Task thirdEndPointTask = endPoint.OnConnectedAsync(thirdClient.Connection); + + await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout(); + + var excludeSecondClientId = new HashSet(); + excludeSecondClientId.Add(secondClient.Connection.ConnectionId); + var excludeThirdClientId = new HashSet(); + excludeThirdClientId.Add(thirdClient.Connection.ConnectionId); + + await firstClient.SendInvocationAsync("SendToAllExcept", "To second", excludeThirdClientId).OrTimeout(); + await firstClient.SendInvocationAsync("SendToAllExcept", "To third", excludeSecondClientId).OrTimeout(); + + var secondClientResult = await secondClient.ReadAsync().OrTimeout(); + var invocation = Assert.IsType(secondClientResult); + Assert.Equal("Send", invocation.Target); + Assert.Equal("To second", invocation.Arguments[0]); + + var thirdClientResult = await thirdClient.ReadAsync().OrTimeout(); + invocation = Assert.IsType(thirdClientResult); + Assert.Equal("Send", invocation.Target); + Assert.Equal("To third", invocation.Arguments[0]); + + // kill the connections + firstClient.Dispose(); + secondClient.Dispose(); + thirdClient.Dispose(); + + await Task.WhenAll(firstEndPointTask, secondEndPointTask, thirdEndPointTask).OrTimeout(); + } + } + [Theory] [MemberData(nameof(HubTypes))] public async Task HubsCanAddAndSendToGroup(Type hubType) @@ -1063,6 +1107,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void AuthMethod() { } + + public Task SendToAllExcept(string message, IReadOnlyList excludedIds) + { + return Clients.AllExcept(excludedIds).InvokeAsync("Send", message); + } } private class InheritedHub : BaseHub diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index 6e80b42c5a..e589fefff1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -124,7 +124,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task SendInvocationAsync(string methodName, bool nonBlocking, params object[] args) { var invocationId = GetInvocationId(); - var payload = _protocolReaderWriter.WriteMessage(new InvocationMessage(invocationId, nonBlocking, methodName, args)); await Application.Out.WriteAsync(payload);