Check for null before sending to a specific connection id (#935)
* Check for null before sending to a specific connection id - Added some tests for the DefaultHubLifetimeManager #905
This commit is contained in:
parent
0267695656
commit
26255cc29c
|
|
@ -17,6 +17,16 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
public override Task AddGroupAsync(string connectionId, string groupName)
|
public override Task AddGroupAsync(string connectionId, string groupName)
|
||||||
{
|
{
|
||||||
|
if (connectionId == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(connectionId));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groupName == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(groupName));
|
||||||
|
}
|
||||||
|
|
||||||
var connection = _connections[connectionId];
|
var connection = _connections[connectionId];
|
||||||
if (connection == null)
|
if (connection == null)
|
||||||
{
|
{
|
||||||
|
|
@ -36,6 +46,16 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
public override Task RemoveGroupAsync(string connectionId, string groupName)
|
public override Task RemoveGroupAsync(string connectionId, string groupName)
|
||||||
{
|
{
|
||||||
|
if (connectionId == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(connectionId));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groupName == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(groupName));
|
||||||
|
}
|
||||||
|
|
||||||
var connection = _connections[connectionId];
|
var connection = _connections[connectionId];
|
||||||
if (connection == null)
|
if (connection == null)
|
||||||
{
|
{
|
||||||
|
|
@ -79,8 +99,18 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
||||||
{
|
{
|
||||||
|
if (connectionId == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(connectionId));
|
||||||
|
}
|
||||||
|
|
||||||
var connection = _connections[connectionId];
|
var connection = _connections[connectionId];
|
||||||
|
|
||||||
|
if (connection == null)
|
||||||
|
{
|
||||||
|
return Task.CompletedTask;
|
||||||
|
}
|
||||||
|
|
||||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||||
|
|
||||||
return WriteAsync(connection, message);
|
return WriteAsync(connection, message);
|
||||||
|
|
@ -88,6 +118,11 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
||||||
{
|
{
|
||||||
|
if (groupName == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(groupName));
|
||||||
|
}
|
||||||
|
|
||||||
return InvokeAllWhere(methodName, args, connection =>
|
return InvokeAllWhere(methodName, args, connection =>
|
||||||
{
|
{
|
||||||
var feature = connection.Features.Get<IHubGroupsFeature>();
|
var feature = connection.Features.Get<IHubGroupsFeature>();
|
||||||
|
|
|
||||||
|
|
@ -175,6 +175,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
|
|
||||||
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
||||||
{
|
{
|
||||||
|
if (connectionId == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(connectionId));
|
||||||
|
}
|
||||||
|
|
||||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||||
|
|
||||||
return PublishAsync(_channelNamePrefix + "." + connectionId, message);
|
return PublishAsync(_channelNamePrefix + "." + connectionId, message);
|
||||||
|
|
@ -182,6 +187,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
|
|
||||||
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
||||||
{
|
{
|
||||||
|
if (groupName == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(groupName));
|
||||||
|
}
|
||||||
|
|
||||||
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
|
||||||
|
|
||||||
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
|
return PublishAsync(_channelNamePrefix + ".group." + groupName, message);
|
||||||
|
|
@ -291,6 +301,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
|
|
||||||
public override async Task AddGroupAsync(string connectionId, string groupName)
|
public override async Task AddGroupAsync(string connectionId, string groupName)
|
||||||
{
|
{
|
||||||
|
if (connectionId == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(connectionId));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groupName == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(groupName));
|
||||||
|
}
|
||||||
|
|
||||||
if (await AddGroupAsyncCore(connectionId, groupName))
|
if (await AddGroupAsyncCore(connectionId, groupName))
|
||||||
{
|
{
|
||||||
// short circuit if connection is on this server
|
// short circuit if connection is on this server
|
||||||
|
|
@ -361,6 +381,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
||||||
|
|
||||||
public override async Task RemoveGroupAsync(string connectionId, string groupName)
|
public override async Task RemoveGroupAsync(string connectionId, string groupName)
|
||||||
{
|
{
|
||||||
|
if (connectionId == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(connectionId));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (groupName == null)
|
||||||
|
{
|
||||||
|
throw new ArgumentNullException(nameof(groupName));
|
||||||
|
}
|
||||||
|
|
||||||
if (await RemoveGroupAsyncCore(connectionId, groupName))
|
if (await RemoveGroupAsyncCore(connectionId, groupName))
|
||||||
{
|
{
|
||||||
// short circuit if connection is on this server
|
// short circuit if connection is on this server
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using System.Threading.Tasks.Channels;
|
||||||
|
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
{
|
||||||
|
public class DefaultHubLifetimeManagerTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeAllAsyncWritesToAllConnectionsOutput()
|
||||||
|
{
|
||||||
|
using (var client1 = new TestClient())
|
||||||
|
using (var client2 = new TestClient())
|
||||||
|
{
|
||||||
|
var output1 = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
var output2 = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
var connection1 = new HubConnectionContext(output1, client1.Connection);
|
||||||
|
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||||
|
|
||||||
|
await manager.OnConnectedAsync(connection1);
|
||||||
|
await manager.OnConnectedAsync(connection2);
|
||||||
|
|
||||||
|
await manager.InvokeAllAsync("Hello", new object[] { "World" });
|
||||||
|
|
||||||
|
Assert.True(output1.In.TryRead(out var item));
|
||||||
|
var message = item as InvocationMessage;
|
||||||
|
Assert.NotNull(message);
|
||||||
|
Assert.Equal("Hello", message.Target);
|
||||||
|
Assert.Single(message.Arguments);
|
||||||
|
Assert.Equal("World", (string)message.Arguments[0]);
|
||||||
|
|
||||||
|
Assert.True(output2.In.TryRead(out item));
|
||||||
|
message = item as InvocationMessage;
|
||||||
|
Assert.NotNull(message);
|
||||||
|
Assert.Equal("Hello", message.Target);
|
||||||
|
Assert.Single(message.Arguments);
|
||||||
|
Assert.Equal("World", (string)message.Arguments[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput()
|
||||||
|
{
|
||||||
|
using (var client1 = new TestClient())
|
||||||
|
using (var client2 = new TestClient())
|
||||||
|
{
|
||||||
|
var output1 = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
var output2 = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
var connection1 = new HubConnectionContext(output1, client1.Connection);
|
||||||
|
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||||
|
|
||||||
|
await manager.OnConnectedAsync(connection1);
|
||||||
|
await manager.OnConnectedAsync(connection2);
|
||||||
|
|
||||||
|
await manager.OnDisconnectedAsync(connection2);
|
||||||
|
|
||||||
|
await manager.InvokeAllAsync("Hello", new object[] { "World" });
|
||||||
|
|
||||||
|
Assert.True(output1.In.TryRead(out var item));
|
||||||
|
var message = item as InvocationMessage;
|
||||||
|
Assert.NotNull(message);
|
||||||
|
Assert.Equal("Hello", message.Target);
|
||||||
|
Assert.Single(message.Arguments);
|
||||||
|
Assert.Equal("World", (string)message.Arguments[0]);
|
||||||
|
|
||||||
|
Assert.False(output2.In.TryRead(out item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput()
|
||||||
|
{
|
||||||
|
using (var client1 = new TestClient())
|
||||||
|
using (var client2 = new TestClient())
|
||||||
|
{
|
||||||
|
var output1 = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
var output2 = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
var connection1 = new HubConnectionContext(output1, client1.Connection);
|
||||||
|
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||||
|
|
||||||
|
await manager.OnConnectedAsync(connection1);
|
||||||
|
await manager.OnConnectedAsync(connection2);
|
||||||
|
|
||||||
|
await manager.AddGroupAsync(connection1.ConnectionId, "gunit");
|
||||||
|
|
||||||
|
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" });
|
||||||
|
|
||||||
|
Assert.True(output1.In.TryRead(out var item));
|
||||||
|
var message = item as InvocationMessage;
|
||||||
|
Assert.NotNull(message);
|
||||||
|
Assert.Equal("Hello", message.Target);
|
||||||
|
Assert.Single(message.Arguments);
|
||||||
|
Assert.Equal("World", (string)message.Arguments[0]);
|
||||||
|
|
||||||
|
Assert.False(output2.In.TryRead(out item));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeConnectionAsyncWritesToConnectionOutput()
|
||||||
|
{
|
||||||
|
using (var client = new TestClient())
|
||||||
|
{
|
||||||
|
var output = Channel.CreateUnbounded<HubMessage>();
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
var connection = new HubConnectionContext(output, client.Connection);
|
||||||
|
|
||||||
|
await manager.OnConnectedAsync(connection);
|
||||||
|
|
||||||
|
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" });
|
||||||
|
|
||||||
|
Assert.True(output.In.TryRead(out var item));
|
||||||
|
var message = item as InvocationMessage;
|
||||||
|
Assert.NotNull(message);
|
||||||
|
Assert.Equal("Hello", message.Target);
|
||||||
|
Assert.Single(message.Arguments);
|
||||||
|
Assert.Equal("World", (string)message.Arguments[0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops()
|
||||||
|
{
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" });
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task AddGroupOnNonExistentConnectionNoops()
|
||||||
|
{
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
await manager.AddGroupAsync("NotARealConnectionId", "MyGroup");
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task RemoveGroupOnNonExistentConnectionNoops()
|
||||||
|
{
|
||||||
|
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||||
|
await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup");
|
||||||
|
}
|
||||||
|
|
||||||
|
private class MyHub : Hub
|
||||||
|
{
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue