Clients Subset - AllExcept (#700)

This commit is contained in:
Mikael Mengistu 2017-08-22 17:33:27 -07:00 committed by GitHub
parent 5c6fb642a0
commit d469cc3151
11 changed files with 160 additions and 23 deletions

View File

@ -2,12 +2,13 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.SignalR.Redis;
using System.Linq;
namespace ChatSample
{
@ -141,6 +142,11 @@ namespace ChatSample
return _wrappedHubLifetimeManager.InvokeAllAsync(methodName, args);
}
public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
{
return _wrappedHubLifetimeManager.InvokeAllExceptAsync(methodName, args, excludedIds);
}
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
{
return _wrappedHubLifetimeManager.InvokeConnectionAsync(connectionId, methodName, args);

View File

@ -3,6 +3,7 @@
using System;
using System.Threading.Tasks;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR;
namespace SocketsSample.Hubs

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol

View File

@ -26,6 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private readonly ISubscriber _bus;
private readonly ILogger _logger;
private readonly RedisOptions _options;
private readonly string _channelNamePrefix = typeof(THub).FullName;
// This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection.
private readonly JsonSerializer _serializer = new JsonSerializer
@ -60,7 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var previousBroadcastTask = Task.CompletedTask;
var channelName = typeof(THub).FullName;
var channelName = _channelNamePrefix;
_logger.LogInformation("Subscribing to channel: {channel}", channelName);
_bus.Subscribe(channelName, async (c, data) =>
{
@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_logger.LogTrace("Received message from redis channel {channel}", channelName);
var message = DeserializeMessage(data);
var message = DeserializeMessage<HubMessage>(data);
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
@ -80,34 +81,67 @@ namespace Microsoft.AspNetCore.SignalR.Redis
previousBroadcastTask = Task.WhenAll(tasks);
});
var allExceptTask = Task.CompletedTask;
channelName = _channelNamePrefix + ".AllExcept";
_logger.LogInformation("Subscribing to channel: {channel}", channelName);
_bus.Subscribe(channelName, async (c, data) =>
{
await allExceptTask;
_logger.LogTrace("Received message from redis channel {channel}", channelName);
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
var excludedIds = message.ExcludedIds;
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{
if (!excludedIds.Contains(connection.ConnectionId))
{
tasks.Add(WriteAsync(connection, message));
}
}
allExceptTask = Task.WhenAll(tasks);
});
}
public override Task InvokeAllAsync(string methodName, object[] args)
{
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName, message);
return PublishAsync(_channelNamePrefix, message);
}
public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
{
var message = new RedisExcludeClientsMessage(GetInvocationId(), nonBlocking: true, target: methodName, excludedIds: excludedIds, arguments: args);
return PublishAsync(_channelNamePrefix + ".AllExcept", message);
}
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
{
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + "." + connectionId, message);
return PublishAsync(_channelNamePrefix + "." + connectionId, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + ".group." + groupName, message);
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
}
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
return PublishAsync(typeof(THub).FullName + ".user." + userId, message);
return PublishAsync(_channelNamePrefix + ".user." + userId, message);
}
private async Task PublishAsync(string channel, HubMessage hubMessage)
@ -136,7 +170,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_connections.Add(connection);
var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId;
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
redisSubscriptions.Add(connectionChannel);
var previousConnectionTask = Task.CompletedTask;
@ -146,15 +180,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousConnectionTask;
var message = DeserializeMessage(data);
var message = DeserializeMessage<HubMessage>(data);
previousConnectionTask = WriteAsync(connection, message);
});
if (connection.User.Identity.IsAuthenticated)
{
var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name;
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
redisSubscriptions.Add(userChannel);
var previousUserTask = Task.CompletedTask;
@ -164,7 +197,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousUserTask;
var message = DeserializeMessage(data);
var message = DeserializeMessage<HubMessage>(data);
previousUserTask = WriteAsync(connection, message);
});
@ -208,7 +241,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override async Task AddGroupAsync(string connectionId, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
var groupChannel = _channelNamePrefix + ".group." + groupName;
var connection = _connections[connectionId];
if (connection == null)
{
@ -246,7 +279,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
// want to do concurrent writes to the outgoing connections
await previousTask;
var message = DeserializeMessage(data);
var message = DeserializeMessage<HubMessage>(data);
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
@ -265,7 +298,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
public override async Task RemoveGroupAsync(string connectionId, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
var groupChannel = _channelNamePrefix + ".group." + groupName;
GroupData group;
if (!_groups.TryGetValue(groupChannel, out group))
@ -329,15 +362,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return invocationId.ToString();
}
private HubMessage DeserializeMessage(RedisValue data)
private T DeserializeMessage<T>(RedisValue data)
{
HubMessage message;
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data))))
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream(data))))
{
message = (HubMessage)_serializer.Deserialize(reader);
return (T)_serializer.Deserialize(reader);
}
return message;
}
private class LoggerTextWriter : TextWriter
@ -362,6 +392,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public class RedisExcludeClientsMessage : InvocationMessage
{
public IReadOnlyList<string> ExcludedIds;
public RedisExcludeClientsMessage(string invocationId, bool nonBlocking, string target, IReadOnlyList<string> excludedIds, params object[] arguments)
: base(invocationId, nonBlocking, target, arguments)
{
ExcludedIds = excludedIds;
}
}
private class GroupData
{
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
@ -140,6 +141,14 @@ namespace Microsoft.AspNetCore.SignalR
return invocationId.ToString();
}
public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
{
return InvokeAllWhere(methodName, args, connection =>
{
return !excludedIds.Contains(connection.ConnectionId);
});
}
private interface IHubGroupsFeature
{
HashSet<string> Groups { get; }

View File

@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public class HubContext<THub> : IHubContext<THub>, IHubClients where THub : Hub
@ -20,6 +22,11 @@ namespace Microsoft.AspNetCore.SignalR
public virtual IGroupManager Groups { get; }
public IClientProxy AllExcept(IReadOnlyList<string> excludedIds)
{
return new AllClientsExceptProxy<THub>(_lifetimeManager, excludedIds);
}
public virtual IClientProxy Client(string connectionId)
{
return new SingleClientProxy<THub>(_lifetimeManager, connectionId);

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
@ -13,6 +14,8 @@ namespace Microsoft.AspNetCore.SignalR
public abstract Task InvokeAllAsync(string methodName, object[] args);
public abstract Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds);
public abstract Task InvokeConnectionAsync(string connectionId, string methodName, object[] args);
public abstract Task InvokeGroupAsync(string groupName, string methodName, object[] args);

View File

@ -1,12 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public interface IHubClients
{
IClientProxy All { get; }
IClientProxy AllExcept(IReadOnlyList<string> excludedIds);
IClientProxy Client(string connectionId);
IClientProxy Group(string groupName);

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR
@ -54,12 +55,28 @@ namespace Microsoft.AspNetCore.SignalR
}
}
public class AllClientsExceptProxy<THub> : IClientProxy
{
private readonly HubLifetimeManager<THub> _lifetimeManager;
private IReadOnlyList<string> _excludedIds;
public AllClientsExceptProxy(HubLifetimeManager<THub> lifetimeManager, IReadOnlyList<string> excludedIds)
{
_lifetimeManager = lifetimeManager;
_excludedIds = excludedIds;
}
public Task InvokeAsync(string method, params object[] args)
{
return _lifetimeManager.InvokeAllExceptAsync(method, args, _excludedIds);
}
}
public class SingleClientProxy<THub> : IClientProxy
{
private readonly string _connectionId;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public SingleClientProxy(HubLifetimeManager<THub> lifetimeManager, string connectionId)
{
_lifetimeManager = lifetimeManager;

View File

@ -511,6 +511,50 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task SendToAllExcept()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var firstClient = new TestClient())
using (var secondClient = new TestClient())
using (var thirdClient = new TestClient())
{
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);
Task thirdEndPointTask = endPoint.OnConnectedAsync(thirdClient.Connection);
await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout();
var excludeSecondClientId = new HashSet<string>();
excludeSecondClientId.Add(secondClient.Connection.ConnectionId);
var excludeThirdClientId = new HashSet<string>();
excludeThirdClientId.Add(thirdClient.Connection.ConnectionId);
await firstClient.SendInvocationAsync("SendToAllExcept", "To second", excludeThirdClientId).OrTimeout();
await firstClient.SendInvocationAsync("SendToAllExcept", "To third", excludeSecondClientId).OrTimeout();
var secondClientResult = await secondClient.ReadAsync().OrTimeout();
var invocation = Assert.IsType<InvocationMessage>(secondClientResult);
Assert.Equal("Send", invocation.Target);
Assert.Equal("To second", invocation.Arguments[0]);
var thirdClientResult = await thirdClient.ReadAsync().OrTimeout();
invocation = Assert.IsType<InvocationMessage>(thirdClientResult);
Assert.Equal("Send", invocation.Target);
Assert.Equal("To third", invocation.Arguments[0]);
// kill the connections
firstClient.Dispose();
secondClient.Dispose();
thirdClient.Dispose();
await Task.WhenAll(firstEndPointTask, secondEndPointTask, thirdEndPointTask).OrTimeout();
}
}
[Theory]
[MemberData(nameof(HubTypes))]
public async Task HubsCanAddAndSendToGroup(Type hubType)
@ -1063,6 +1107,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void AuthMethod()
{
}
public Task SendToAllExcept(string message, IReadOnlyList<string> excludedIds)
{
return Clients.AllExcept(excludedIds).InvokeAsync("Send", message);
}
}
private class InheritedHub : BaseHub

View File

@ -124,7 +124,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public async Task<string> SendInvocationAsync(string methodName, bool nonBlocking, params object[] args)
{
var invocationId = GetInvocationId();
var payload = _protocolReaderWriter.WriteMessage(new InvocationMessage(invocationId, nonBlocking, methodName, args));
await Application.Out.WriteAsync(payload);