Clients Subset - AllExcept (#700)
This commit is contained in:
parent
5c6fb642a0
commit
d469cc3151
|
|
@ -2,12 +2,13 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.AspNetCore.SignalR.Redis;
|
||||
using System.Linq;
|
||||
|
||||
namespace ChatSample
|
||||
{
|
||||
|
|
@ -141,6 +142,11 @@ namespace ChatSample
|
|||
return _wrappedHubLifetimeManager.InvokeAllAsync(methodName, args);
|
||||
}
|
||||
|
||||
public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
return _wrappedHubLifetimeManager.InvokeAllExceptAsync(methodName, args, excludedIds);
|
||||
}
|
||||
|
||||
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
||||
{
|
||||
return _wrappedHubLifetimeManager.InvokeConnectionAsync(connectionId, methodName, args);
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.AspNetCore.SignalR;
|
||||
|
||||
namespace SocketsSample.Hubs
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
private readonly ISubscriber _bus;
|
||||
private readonly ILogger _logger;
|
||||
private readonly RedisOptions _options;
|
||||
private readonly string _channelNamePrefix = typeof(THub).FullName;
|
||||
|
||||
// This serializer is ONLY use to transmit the data through redis, it has no connection to the serializer used on each connection.
|
||||
private readonly JsonSerializer _serializer = new JsonSerializer
|
||||
|
|
@ -60,7 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
var previousBroadcastTask = Task.CompletedTask;
|
||||
|
||||
var channelName = typeof(THub).FullName;
|
||||
var channelName = _channelNamePrefix;
|
||||
_logger.LogInformation("Subscribing to channel: {channel}", channelName);
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
|
|
@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
_logger.LogTrace("Received message from redis channel {channel}", channelName);
|
||||
|
||||
var message = DeserializeMessage(data);
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
|
@ -80,34 +81,67 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
previousBroadcastTask = Task.WhenAll(tasks);
|
||||
});
|
||||
|
||||
var allExceptTask = Task.CompletedTask;
|
||||
channelName = _channelNamePrefix + ".AllExcept";
|
||||
_logger.LogInformation("Subscribing to channel: {channel}", channelName);
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
await allExceptTask;
|
||||
|
||||
_logger.LogTrace("Received message from redis channel {channel}", channelName);
|
||||
|
||||
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
|
||||
var excludedIds = message.ExcludedIds;
|
||||
|
||||
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
|
||||
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
if (!excludedIds.Contains(connection.ConnectionId))
|
||||
{
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
}
|
||||
|
||||
allExceptTask = Task.WhenAll(tasks);
|
||||
});
|
||||
}
|
||||
|
||||
public override Task InvokeAllAsync(string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName, message);
|
||||
return PublishAsync(_channelNamePrefix, message);
|
||||
}
|
||||
|
||||
public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
var message = new RedisExcludeClientsMessage(GetInvocationId(), nonBlocking: true, target: methodName, excludedIds: excludedIds, arguments: args);
|
||||
return PublishAsync(_channelNamePrefix + ".AllExcept", message);
|
||||
}
|
||||
|
||||
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName + "." + connectionId, message);
|
||||
return PublishAsync(_channelNamePrefix + "." + connectionId, message);
|
||||
}
|
||||
|
||||
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName + ".group." + groupName, message);
|
||||
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
|
||||
}
|
||||
|
||||
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
|
||||
{
|
||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||
|
||||
return PublishAsync(typeof(THub).FullName + ".user." + userId, message);
|
||||
return PublishAsync(_channelNamePrefix + ".user." + userId, message);
|
||||
}
|
||||
|
||||
private async Task PublishAsync(string channel, HubMessage hubMessage)
|
||||
|
|
@ -136,7 +170,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
_connections.Add(connection);
|
||||
|
||||
var connectionChannel = typeof(THub).FullName + "." + connection.ConnectionId;
|
||||
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
|
||||
redisSubscriptions.Add(connectionChannel);
|
||||
|
||||
var previousConnectionTask = Task.CompletedTask;
|
||||
|
|
@ -146,15 +180,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
{
|
||||
await previousConnectionTask;
|
||||
|
||||
var message = DeserializeMessage(data);
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
previousConnectionTask = WriteAsync(connection, message);
|
||||
});
|
||||
|
||||
|
||||
if (connection.User.Identity.IsAuthenticated)
|
||||
{
|
||||
var userChannel = typeof(THub).FullName + ".user." + connection.User.Identity.Name;
|
||||
var userChannel = _channelNamePrefix + ".user." + connection.User.Identity.Name;
|
||||
redisSubscriptions.Add(userChannel);
|
||||
|
||||
var previousUserTask = Task.CompletedTask;
|
||||
|
|
@ -164,7 +197,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
{
|
||||
await previousUserTask;
|
||||
|
||||
var message = DeserializeMessage(data);
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
previousUserTask = WriteAsync(connection, message);
|
||||
});
|
||||
|
|
@ -208,7 +241,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
public override async Task AddGroupAsync(string connectionId, string groupName)
|
||||
{
|
||||
var groupChannel = typeof(THub).FullName + ".group." + groupName;
|
||||
var groupChannel = _channelNamePrefix + ".group." + groupName;
|
||||
var connection = _connections[connectionId];
|
||||
if (connection == null)
|
||||
{
|
||||
|
|
@ -246,7 +279,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
// want to do concurrent writes to the outgoing connections
|
||||
await previousTask;
|
||||
|
||||
var message = DeserializeMessage(data);
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
var tasks = new List<Task>(group.Connections.Count);
|
||||
foreach (var groupConnection in group.Connections)
|
||||
|
|
@ -265,7 +298,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
|
||||
public override async Task RemoveGroupAsync(string connectionId, string groupName)
|
||||
{
|
||||
var groupChannel = typeof(THub).FullName + ".group." + groupName;
|
||||
var groupChannel = _channelNamePrefix + ".group." + groupName;
|
||||
|
||||
GroupData group;
|
||||
if (!_groups.TryGetValue(groupChannel, out group))
|
||||
|
|
@ -329,15 +362,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return invocationId.ToString();
|
||||
}
|
||||
|
||||
private HubMessage DeserializeMessage(RedisValue data)
|
||||
private T DeserializeMessage<T>(RedisValue data)
|
||||
{
|
||||
HubMessage message;
|
||||
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream((byte[])data))))
|
||||
using (var reader = new JsonTextReader(new StreamReader(new MemoryStream(data))))
|
||||
{
|
||||
message = (HubMessage)_serializer.Deserialize(reader);
|
||||
return (T)_serializer.Deserialize(reader);
|
||||
}
|
||||
|
||||
return message;
|
||||
}
|
||||
|
||||
private class LoggerTextWriter : TextWriter
|
||||
|
|
@ -362,6 +392,17 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
}
|
||||
|
||||
public class RedisExcludeClientsMessage : InvocationMessage
|
||||
{
|
||||
public IReadOnlyList<string> ExcludedIds;
|
||||
|
||||
public RedisExcludeClientsMessage(string invocationId, bool nonBlocking, string target, IReadOnlyList<string> excludedIds, params object[] arguments)
|
||||
: base(invocationId, nonBlocking, target, arguments)
|
||||
{
|
||||
ExcludedIds = excludedIds;
|
||||
}
|
||||
}
|
||||
|
||||
private class GroupData
|
||||
{
|
||||
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
|
|
@ -140,6 +141,14 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
return invocationId.ToString();
|
||||
}
|
||||
|
||||
public override Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
return InvokeAllWhere(methodName, args, connection =>
|
||||
{
|
||||
return !excludedIds.Contains(connection.ConnectionId);
|
||||
});
|
||||
}
|
||||
|
||||
private interface IHubGroupsFeature
|
||||
{
|
||||
HashSet<string> Groups { get; }
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public class HubContext<THub> : IHubContext<THub>, IHubClients where THub : Hub
|
||||
|
|
@ -20,6 +22,11 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public virtual IGroupManager Groups { get; }
|
||||
|
||||
public IClientProxy AllExcept(IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
return new AllClientsExceptProxy<THub>(_lifetimeManager, excludedIds);
|
||||
}
|
||||
|
||||
public virtual IClientProxy Client(string connectionId)
|
||||
{
|
||||
return new SingleClientProxy<THub>(_lifetimeManager, connectionId);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
|
|
@ -13,6 +14,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public abstract Task InvokeAllAsync(string methodName, object[] args);
|
||||
|
||||
public abstract Task InvokeAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds);
|
||||
|
||||
public abstract Task InvokeConnectionAsync(string connectionId, string methodName, object[] args);
|
||||
|
||||
public abstract Task InvokeGroupAsync(string groupName, string methodName, object[] args);
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
{
|
||||
public interface IHubClients
|
||||
{
|
||||
IClientProxy All { get; }
|
||||
|
||||
IClientProxy AllExcept(IReadOnlyList<string> excludedIds);
|
||||
|
||||
IClientProxy Client(string connectionId);
|
||||
|
||||
IClientProxy Group(string groupName);
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
|
|
@ -54,12 +55,28 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
}
|
||||
}
|
||||
|
||||
public class AllClientsExceptProxy<THub> : IClientProxy
|
||||
{
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
private IReadOnlyList<string> _excludedIds;
|
||||
|
||||
public AllClientsExceptProxy(HubLifetimeManager<THub> lifetimeManager, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
_excludedIds = excludedIds;
|
||||
}
|
||||
|
||||
public Task InvokeAsync(string method, params object[] args)
|
||||
{
|
||||
return _lifetimeManager.InvokeAllExceptAsync(method, args, _excludedIds);
|
||||
}
|
||||
}
|
||||
|
||||
public class SingleClientProxy<THub> : IClientProxy
|
||||
{
|
||||
private readonly string _connectionId;
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
|
||||
public SingleClientProxy(HubLifetimeManager<THub> lifetimeManager, string connectionId)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
|
|
|
|||
|
|
@ -511,6 +511,50 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task SendToAllExcept()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider();
|
||||
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
|
||||
|
||||
using (var firstClient = new TestClient())
|
||||
using (var secondClient = new TestClient())
|
||||
using (var thirdClient = new TestClient())
|
||||
{
|
||||
Task firstEndPointTask = endPoint.OnConnectedAsync(firstClient.Connection);
|
||||
Task secondEndPointTask = endPoint.OnConnectedAsync(secondClient.Connection);
|
||||
Task thirdEndPointTask = endPoint.OnConnectedAsync(thirdClient.Connection);
|
||||
|
||||
await Task.WhenAll(firstClient.Connected, secondClient.Connected, thirdClient.Connected).OrTimeout();
|
||||
|
||||
var excludeSecondClientId = new HashSet<string>();
|
||||
excludeSecondClientId.Add(secondClient.Connection.ConnectionId);
|
||||
var excludeThirdClientId = new HashSet<string>();
|
||||
excludeThirdClientId.Add(thirdClient.Connection.ConnectionId);
|
||||
|
||||
await firstClient.SendInvocationAsync("SendToAllExcept", "To second", excludeThirdClientId).OrTimeout();
|
||||
await firstClient.SendInvocationAsync("SendToAllExcept", "To third", excludeSecondClientId).OrTimeout();
|
||||
|
||||
var secondClientResult = await secondClient.ReadAsync().OrTimeout();
|
||||
var invocation = Assert.IsType<InvocationMessage>(secondClientResult);
|
||||
Assert.Equal("Send", invocation.Target);
|
||||
Assert.Equal("To second", invocation.Arguments[0]);
|
||||
|
||||
var thirdClientResult = await thirdClient.ReadAsync().OrTimeout();
|
||||
invocation = Assert.IsType<InvocationMessage>(thirdClientResult);
|
||||
Assert.Equal("Send", invocation.Target);
|
||||
Assert.Equal("To third", invocation.Arguments[0]);
|
||||
|
||||
// kill the connections
|
||||
firstClient.Dispose();
|
||||
secondClient.Dispose();
|
||||
thirdClient.Dispose();
|
||||
|
||||
await Task.WhenAll(firstEndPointTask, secondEndPointTask, thirdEndPointTask).OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(HubTypes))]
|
||||
public async Task HubsCanAddAndSendToGroup(Type hubType)
|
||||
|
|
@ -1063,6 +1107,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
public void AuthMethod()
|
||||
{
|
||||
}
|
||||
|
||||
public Task SendToAllExcept(string message, IReadOnlyList<string> excludedIds)
|
||||
{
|
||||
return Clients.AllExcept(excludedIds).InvokeAsync("Send", message);
|
||||
}
|
||||
}
|
||||
|
||||
private class InheritedHub : BaseHub
|
||||
|
|
|
|||
|
|
@ -124,7 +124,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
public async Task<string> SendInvocationAsync(string methodName, bool nonBlocking, params object[] args)
|
||||
{
|
||||
var invocationId = GetInvocationId();
|
||||
|
||||
var payload = _protocolReaderWriter.WriteMessage(new InvocationMessage(invocationId, nonBlocking, methodName, args));
|
||||
await Application.Out.WriteAsync(payload);
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue