diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 9dadfc8986..ab5cbd926b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -17,6 +17,16 @@ namespace Microsoft.AspNetCore.SignalR public override Task AddGroupAsync(string connectionId, string groupName) { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + var connection = _connections[connectionId]; if (connection == null) { @@ -36,6 +46,16 @@ namespace Microsoft.AspNetCore.SignalR public override Task RemoveGroupAsync(string connectionId, string groupName) { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + var connection = _connections[connectionId]; if (connection == null) { @@ -79,8 +99,18 @@ namespace Microsoft.AspNetCore.SignalR public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args) { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + var connection = _connections[connectionId]; + if (connection == null) + { + return Task.CompletedTask; + } + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return WriteAsync(connection, message); @@ -88,6 +118,11 @@ namespace Microsoft.AspNetCore.SignalR public override Task InvokeGroupAsync(string groupName, string methodName, object[] args) { + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + return InvokeAllWhere(methodName, args, connection => { var feature = connection.Features.Get(); diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index d6197dc5ea..b6185bbcab 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -175,6 +175,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args) { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return PublishAsync(_channelNamePrefix + "." + connectionId, message); @@ -182,6 +187,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override Task InvokeGroupAsync(string groupName, string methodName, object[] args) { + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); return PublishAsync(_channelNamePrefix + ".group." + groupName, message); @@ -291,6 +301,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override async Task AddGroupAsync(string connectionId, string groupName) { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + if (await AddGroupAsyncCore(connectionId, groupName)) { // short circuit if connection is on this server @@ -361,6 +381,16 @@ namespace Microsoft.AspNetCore.SignalR.Redis public override async Task RemoveGroupAsync(string connectionId, string groupName) { + if (connectionId == null) + { + throw new ArgumentNullException(nameof(connectionId)); + } + + if (groupName == null) + { + throw new ArgumentNullException(nameof(groupName)); + } + if (await RemoveGroupAsyncCore(connectionId, groupName)) { // short circuit if connection is on this server diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs new file mode 100644 index 0000000000..68acba4366 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -0,0 +1,154 @@ +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class DefaultHubLifetimeManagerTests + { + [Fact] + public async Task InvokeAllAsyncWritesToAllConnectionsOutput() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var output1 = Channel.CreateUnbounded(); + var output2 = Channel.CreateUnbounded(); + + var manager = new DefaultHubLifetimeManager(); + var connection1 = new HubConnectionContext(output1, client1.Connection); + var connection2 = new HubConnectionContext(output2, client2.Connection); + + await manager.OnConnectedAsync(connection1); + await manager.OnConnectedAsync(connection2); + + await manager.InvokeAllAsync("Hello", new object[] { "World" }); + + Assert.True(output1.In.TryRead(out var item)); + var message = item as InvocationMessage; + Assert.NotNull(message); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + Assert.True(output2.In.TryRead(out item)); + message = item as InvocationMessage; + Assert.NotNull(message); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + } + } + + [Fact] + public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var output1 = Channel.CreateUnbounded(); + var output2 = Channel.CreateUnbounded(); + + var manager = new DefaultHubLifetimeManager(); + var connection1 = new HubConnectionContext(output1, client1.Connection); + var connection2 = new HubConnectionContext(output2, client2.Connection); + + await manager.OnConnectedAsync(connection1); + await manager.OnConnectedAsync(connection2); + + await manager.OnDisconnectedAsync(connection2); + + await manager.InvokeAllAsync("Hello", new object[] { "World" }); + + Assert.True(output1.In.TryRead(out var item)); + var message = item as InvocationMessage; + Assert.NotNull(message); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + Assert.False(output2.In.TryRead(out item)); + } + } + + [Fact] + public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var output1 = Channel.CreateUnbounded(); + var output2 = Channel.CreateUnbounded(); + + var manager = new DefaultHubLifetimeManager(); + var connection1 = new HubConnectionContext(output1, client1.Connection); + var connection2 = new HubConnectionContext(output2, client2.Connection); + + await manager.OnConnectedAsync(connection1); + await manager.OnConnectedAsync(connection2); + + await manager.AddGroupAsync(connection1.ConnectionId, "gunit"); + + await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }); + + Assert.True(output1.In.TryRead(out var item)); + var message = item as InvocationMessage; + Assert.NotNull(message); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + Assert.False(output2.In.TryRead(out item)); + } + } + + [Fact] + public async Task InvokeConnectionAsyncWritesToConnectionOutput() + { + using (var client = new TestClient()) + { + var output = Channel.CreateUnbounded(); + var manager = new DefaultHubLifetimeManager(); + var connection = new HubConnectionContext(output, client.Connection); + + await manager.OnConnectedAsync(connection); + + await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }); + + Assert.True(output.In.TryRead(out var item)); + var message = item as InvocationMessage; + Assert.NotNull(message); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + } + } + + [Fact] + public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops() + { + var manager = new DefaultHubLifetimeManager(); + await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }); + } + + [Fact] + public async Task AddGroupOnNonExistentConnectionNoops() + { + var manager = new DefaultHubLifetimeManager(); + await manager.AddGroupAsync("NotARealConnectionId", "MyGroup"); + } + + [Fact] + public async Task RemoveGroupOnNonExistentConnectionNoops() + { + var manager = new DefaultHubLifetimeManager(); + await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup"); + } + + private class MyHub : Hub + { + + } + } +}