diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs index 2c1e512e2d..63f04221a8 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs @@ -38,6 +38,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal private static readonly Action _connectionFailed = LoggerMessage.Define(LogLevel.Warning, new EventId(8, nameof(ConnectionFailed)), "Connection to Redis failed."); + private static readonly Action _failedWritingMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(FailedWritingMessage)), "Failed writing message."); + + private static readonly Action _internalMessageFailed = + LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(InternalMessageFailed)), "Error processing message for internal server message."); + public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints) { if (logger.IsEnabled(LogLevel.Information)) @@ -88,5 +94,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal { _connectionFailed(logger, exception); } + + public static void FailedWritingMessage(this ILogger logger, Exception exception) + { + _failedWritingMessage(logger, exception); + } + + public static void InternalMessageFailed(this ILogger logger, Exception exception) + { + _internalMessageFailed(logger, exception); + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index e0c6c6ae43..a498305f52 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -88,86 +88,97 @@ namespace Microsoft.AspNetCore.SignalR.Redis } _bus = _redisServerConnection.GetSubscriber(); - var previousBroadcastTask = Task.CompletedTask; - var channelName = _channelNamePrefix; _logger.Subscribing(channelName); _bus.Subscribe(channelName, async (c, data) => { - await previousBroadcastTask; - - _logger.ReceivedFromChannel(channelName); - - var message = DeserializeMessage(data); - - // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf - var tasks = new List(_connections.Count); - - foreach (var connection in _connections) + try { - tasks.Add(WriteAsync(connection, message)); - } + _logger.ReceivedFromChannel(channelName); - previousBroadcastTask = Task.WhenAll(tasks); + var message = DeserializeMessage(data); + + var tasks = new List(_connections.Count); + + foreach (var connection in _connections) + { + tasks.Add(WriteAsync(connection, message)); + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } }); - var allExceptTask = Task.CompletedTask; channelName = _channelNamePrefix + ".AllExcept"; _logger.Subscribing(channelName); _bus.Subscribe(channelName, async (c, data) => { - await allExceptTask; - - _logger.ReceivedFromChannel(channelName); - - var message = DeserializeMessage(data); - var excludedIds = message.ExcludedIds; - - // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf - - var tasks = new List(_connections.Count); - - foreach (var connection in _connections) + try { - if (!excludedIds.Contains(connection.ConnectionId)) - { - tasks.Add(WriteAsync(connection, message)); - } - } + _logger.ReceivedFromChannel(channelName); - allExceptTask = Task.WhenAll(tasks); + var message = DeserializeMessage(data); + var excludedIds = message.ExcludedIds; + + var tasks = new List(_connections.Count); + + foreach (var connection in _connections) + { + if (!excludedIds.Contains(connection.ConnectionId)) + { + tasks.Add(WriteAsync(connection, message)); + } + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } }); channelName = _channelNamePrefix + ".internal.group"; _bus.Subscribe(channelName, async (c, data) => { - var groupMessage = DeserializeMessage(data); - - var connection = _connections[groupMessage.ConnectionId]; - if (connection == null) + try { - // user not on this server - return; + var groupMessage = DeserializeMessage(data); + + var connection = _connections[groupMessage.ConnectionId]; + if (connection == null) + { + // user not on this server + return; + } + + if (groupMessage.Action == GroupAction.Remove) + { + await RemoveGroupAsyncCore(connection, groupMessage.Group); + } + + if (groupMessage.Action == GroupAction.Add) + { + await AddGroupAsyncCore(connection, groupMessage.Group); + } + + // Sending ack to server that sent the original add/remove + await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage + { + Action = GroupAction.Ack, + ConnectionId = groupMessage.ConnectionId, + Group = groupMessage.Group, + Id = groupMessage.Id + }); } - - if (groupMessage.Action == GroupAction.Remove) + catch (Exception ex) { - await RemoveGroupAsyncCore(connection, groupMessage.Group); + _logger.InternalMessageFailed(ex); } - - if (groupMessage.Action == GroupAction.Add) - { - await AddGroupAsyncCore(connection, groupMessage.Group); - } - - // Sending ack to server that sent the original add/remove - await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage - { - Action = GroupAction.Ack, - ConnectionId = groupMessage.ConnectionId, - Group = groupMessage.Group, - Id = groupMessage.Id - }); }); // Create server specific channel in order to send an ack to a single server @@ -264,16 +275,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId; redisSubscriptions.Add(connectionChannel); - var previousConnectionTask = Task.CompletedTask; - _logger.Subscribing(connectionChannel); connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) => { - await previousConnectionTask; + try + { + var message = DeserializeMessage(data); - var message = DeserializeMessage(data); - - previousConnectionTask = WriteAsync(connection, message); + await WriteAsync(connection, message); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } }); if (!string.IsNullOrEmpty(connection.UserIdentifier)) @@ -281,16 +295,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier; redisSubscriptions.Add(userChannel); - var previousUserTask = Task.CompletedTask; - // TODO: Look at optimizing (looping over connections checking for Name) userTask = _bus.SubscribeAsync(userChannel, async (c, data) => { - await previousUserTask; + try + { + var message = DeserializeMessage(data); - var message = DeserializeMessage(data); - - previousUserTask = WriteAsync(connection, message); + await WriteAsync(connection, message); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } }); } @@ -383,25 +400,25 @@ namespace Microsoft.AspNetCore.SignalR.Redis return; } - var previousTask = Task.CompletedTask; - _logger.Subscribing(groupChannel); await _bus.SubscribeAsync(groupChannel, async (c, data) => { - // Since this callback is async, we await the previous task then - // before sending the current message. This is because we don't - // want to do concurrent writes to the outgoing connections - await previousTask; - - var message = DeserializeMessage(data); - - var tasks = new List(group.Connections.Count); - foreach (var groupConnection in group.Connections) + try { - tasks.Add(WriteAsync(groupConnection, message)); - } + var message = DeserializeMessage(data); - previousTask = Task.WhenAll(tasks); + var tasks = new List(group.Connections.Count); + foreach (var groupConnection in group.Connections) + { + tasks.Add(WriteAsync(groupConnection, message)); + } + + await Task.WhenAll(tasks); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } }); } finally diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index b8c038946c..246ea50fe1 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; @@ -8,6 +10,7 @@ using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; +using Moq; using Xunit; namespace Microsoft.AspNetCore.SignalR.Redis.Tests @@ -481,6 +484,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } } + [Fact] + public async Task WritingToRemoteConnectionThatFailsDoesNotThrow() + { + var manager1 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() + { + Factory = t => new TestConnectionMultiplexer() + })); + var manager2 = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() + { + Factory = t => new TestConnectionMultiplexer() + })); + + using (var client = new TestClient()) + { + // Force an exception when writing to connection + var output = new Mock>(); + output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); + + var connection = new HubConnectionContext(output.Object, client.Connection); + + await manager2.OnConnectedAsync(connection).OrTimeout(); + + // This doesn't throw because there is no connection.ConnectionId on this server so it has to publish to redis. + // And once that happens there is no way to know if the invocation was successful or not. + await manager1.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); + } + } + + [Fact] + public async Task WritingToLocalConnectionThatFailsThrowsException() + { + var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() + { + Factory = t => new TestConnectionMultiplexer() + })); + + using (var client = new TestClient()) + { + // Force an exception when writing to connection + var output = new Mock>(); + output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); + + var connection = new HubConnectionContext(output.Object, client.Connection); + + await manager.OnConnectedAsync(connection).OrTimeout(); + + var exception = await Assert.ThrowsAsync(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout()); + Assert.Equal("Message", exception.Message); + } + } + + [Fact] + public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage() + { + var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() + { + Factory = t => new TestConnectionMultiplexer() + })); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var output2 = Channel.CreateUnbounded(); + + // Force an exception when writing to connection + var output = new Mock>(); + output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception()); + + var connection1 = new HubConnectionContext(output.Object, client1.Connection); + var connection2 = new HubConnectionContext(output2, client2.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.AddGroupAsync(connection1.ConnectionId, "group"); + await manager.OnConnectedAsync(connection2).OrTimeout(); + await manager.AddGroupAsync(connection2.ConnectionId, "group"); + + await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout(); + // connection1 will throw when receiving a group message, we are making sure other connections + // are not affected by another connection throwing + AssertMessage(output2); + + // Repeat to check that group can still be sent to + await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout(); + AssertMessage(output2); + } + } + private void AssertMessage(Channel channel) { Assert.True(channel.In.TryRead(out var item)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index ecb02cb3c8..ca07869684 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -1,7 +1,10 @@ -using System.Threading.Tasks; +using System; +using System.Threading; +using System.Threading.Tasks; using System.Threading.Tasks.Channels; -using Microsoft.AspNetCore.SignalR.Core; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Tests.Common; +using Moq; using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests @@ -21,10 +24,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests var connection1 = new HubConnectionContext(output1, client1.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection); - await manager.OnConnectedAsync(connection1); - await manager.OnConnectedAsync(connection2); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); - await manager.InvokeAllAsync("Hello", new object[] { "World" }); + await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); Assert.True(output1.In.TryRead(out var item)); var message = item as InvocationMessage; @@ -55,12 +58,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests var connection1 = new HubConnectionContext(output1, client1.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection); - await manager.OnConnectedAsync(connection1); - await manager.OnConnectedAsync(connection2); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); - await manager.OnDisconnectedAsync(connection2); + await manager.OnDisconnectedAsync(connection2).OrTimeout(); - await manager.InvokeAllAsync("Hello", new object[] { "World" }); + await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout(); Assert.True(output1.In.TryRead(out var item)); var message = item as InvocationMessage; @@ -86,12 +89,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests var connection1 = new HubConnectionContext(output1, client1.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection); - await manager.OnConnectedAsync(connection1); - await manager.OnConnectedAsync(connection2); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); - await manager.AddGroupAsync(connection1.ConnectionId, "gunit"); + await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout(); - await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }); + await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout(); Assert.True(output1.In.TryRead(out var item)); var message = item as InvocationMessage; @@ -113,9 +116,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests var manager = new DefaultHubLifetimeManager(); var connection = new HubConnectionContext(output, client.Connection); - await manager.OnConnectedAsync(connection); + await manager.OnConnectedAsync(connection).OrTimeout(); - await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }); + await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); Assert.True(output.In.TryRead(out var item)); var message = item as InvocationMessage; @@ -126,25 +129,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite() + { + using (var client = new TestClient()) + { + // Force an exception when writing to connection + var output = new Mock>(); + output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny())).Throws(new Exception("Message")); + + var manager = new DefaultHubLifetimeManager(); + var connection = new HubConnectionContext(output.Object, client.Connection); + + await manager.OnConnectedAsync(connection).OrTimeout(); + + var exception = await Assert.ThrowsAsync(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout()); + Assert.Equal("Message", exception.Message); + } + } + [Fact] public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops() { var manager = new DefaultHubLifetimeManager(); - await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }); + await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout(); } [Fact] public async Task AddGroupOnNonExistentConnectionNoops() { var manager = new DefaultHubLifetimeManager(); - await manager.AddGroupAsync("NotARealConnectionId", "MyGroup"); + await manager.AddGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); } [Fact] public async Task RemoveGroupOnNonExistentConnectionNoops() { var manager = new DefaultHubLifetimeManager(); - await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup"); + await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); } private class MyHub : Hub