From db0dc0f96019d6658b6758cfd748281fd809c317 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Fri, 16 Mar 2018 16:48:05 -0700 Subject: [PATCH] Ignore writeasync failures when sending to multiple connections (#1589) --- .../BroadcastBenchmark.cs | 3 +- .../DefaultHubDispatcherBenchmark.cs | 2 +- .../DefaultHubLifetimeManager.cs | 42 ++++++++++-- .../Internal/RedisLoggerExtensions.cs | 22 +++---- .../RedisHubLifetimeManager.cs | 66 +++++++------------ .../RedisHubLifetimeManagerTests.cs | 5 +- .../DefaultHubLifetimeManagerTests.cs | 65 +++++++++++++----- 7 files changed, 123 insertions(+), 82 deletions(-) diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index 042b5034ed..8133b10e55 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -7,6 +7,7 @@ using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks @@ -25,7 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [GlobalSetup] public void GlobalSetup() { - _hubLifetimeManager = new DefaultHubLifetimeManager(); + _hubLifetimeManager = new DefaultHubLifetimeManager(NullLogger>.Instance); IHubProtocol protocol; diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 5799673af7..ca1ef0f128 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _dispatcher = new DefaultHubDispatcher( serviceScopeFactory, - new HubContext(new DefaultHubLifetimeManager()), + new HubContext(new DefaultHubLifetimeManager(NullLogger>.Instance)), new Logger>(NullLoggerFactory.Instance)); var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 67cf01085f..03f07a5b6d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR { @@ -13,6 +14,12 @@ namespace Microsoft.AspNetCore.SignalR { private readonly HubConnectionList _connections = new HubConnectionList(); private readonly HubGroupList _groups = new HubGroupList(); + private readonly ILogger _logger; + + public DefaultHubLifetimeManager(ILogger> logger) + { + _logger = logger; + } public override Task AddGroupAsync(string connectionId, string groupName) { @@ -83,7 +90,7 @@ namespace Microsoft.AspNetCore.SignalR continue; } - tasks.Add(connection.WriteAsync(message)); + tasks.Add(SafeWriteAsync(connection, message)); } return Task.WhenAll(tasks); @@ -105,7 +112,7 @@ namespace Microsoft.AspNetCore.SignalR var message = CreateInvocationMessage(methodName, args); - return connection.WriteAsync(message); + return SafeWriteAsync(connection, message); } public override Task SendGroupAsync(string groupName, string methodName, object[] args) @@ -119,7 +126,7 @@ namespace Microsoft.AspNetCore.SignalR if (group != null) { var message = CreateInvocationMessage(methodName, args); - var tasks = group.Values.Select(c => c.WriteAsync(message)); + var tasks = group.Values.Select(c => SafeWriteAsync(c, message)); return Task.WhenAll(tasks); } @@ -142,7 +149,7 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - tasks.Add(Task.WhenAll(group.Values.Select(c => c.WriteAsync(message)))); + tasks.Add(Task.WhenAll(group.Values.Select(c => SafeWriteAsync(c, message)))); } } @@ -161,7 +168,7 @@ namespace Microsoft.AspNetCore.SignalR { var message = CreateInvocationMessage(methodName, args); var tasks = group.Values.Where(connection => !excludedIds.Contains(connection.ConnectionId)) - .Select(c => c.WriteAsync(message)); + .Select(c => SafeWriteAsync(c, message)); return Task.WhenAll(tasks); } @@ -215,5 +222,30 @@ namespace Microsoft.AspNetCore.SignalR return userIds.Contains(connection.UserIdentifier); }); } + + // This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to + private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message) + { + try + { + await connection.WriteAsync(message); + } + // This exception isn't interesting to users + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); + } + } + + private static class Log + { + private static readonly Action _failedWritingMessage = + LoggerMessage.Define(LogLevel.Warning, new EventId(1, "FailedWritingMessage"), "Failed writing message."); + + public static void FailedWritingMessage(ILogger logger, Exception exception) + { + _failedWritingMessage(logger, exception); + } + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs index 19e9791072..e8862cf0e1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/Internal/RedisLoggerExtensions.cs @@ -12,37 +12,37 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal { // Category: RedisHubLifetimeManager private static readonly Action _connectingToEndpoints = - LoggerMessage.Define(LogLevel.Information, new EventId(1, nameof(ConnectingToEndpoints)), "Connecting to Redis endpoints: {Endpoints}."); + LoggerMessage.Define(LogLevel.Information, new EventId(1, "ConnectingToEndpoints"), "Connecting to Redis endpoints: {Endpoints}."); private static readonly Action _connected = - LoggerMessage.Define(LogLevel.Information, new EventId(2, nameof(Connected)), "Connected to Redis."); + LoggerMessage.Define(LogLevel.Information, new EventId(2, "Connected"), "Connected to Redis."); private static readonly Action _subscribing = - LoggerMessage.Define(LogLevel.Trace, new EventId(3, nameof(Subscribing)), "Subscribing to channel: {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(3, "Subscribing"), "Subscribing to channel: {Channel}."); private static readonly Action _receivedFromChannel = - LoggerMessage.Define(LogLevel.Trace, new EventId(4, nameof(ReceivedFromChannel)), "Received message from Redis channel {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(4, "ReceivedFromChannel"), "Received message from Redis channel {Channel}."); private static readonly Action _publishToChannel = - LoggerMessage.Define(LogLevel.Trace, new EventId(5, nameof(PublishToChannel)), "Publishing message to Redis channel {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(5, "PublishToChannel"), "Publishing message to Redis channel {Channel}."); private static readonly Action _unsubscribe = - LoggerMessage.Define(LogLevel.Trace, new EventId(6, nameof(Unsubscribe)), "Unsubscribing from channel: {Channel}."); + LoggerMessage.Define(LogLevel.Trace, new EventId(6, "Unsubscribe"), "Unsubscribing from channel: {Channel}."); private static readonly Action _notConnected = - LoggerMessage.Define(LogLevel.Warning, new EventId(7, nameof(Connected)), "Not connected to Redis."); + LoggerMessage.Define(LogLevel.Warning, new EventId(7, "Connected"), "Not connected to Redis."); private static readonly Action _connectionRestored = - LoggerMessage.Define(LogLevel.Information, new EventId(8, nameof(ConnectionRestored)), "Connection to Redis restored."); + LoggerMessage.Define(LogLevel.Information, new EventId(8, "ConnectionRestored"), "Connection to Redis restored."); private static readonly Action _connectionFailed = - LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(ConnectionFailed)), "Connection to Redis failed."); + LoggerMessage.Define(LogLevel.Warning, new EventId(9, "ConnectionFailed"), "Connection to Redis failed."); private static readonly Action _failedWritingMessage = - LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(FailedWritingMessage)), "Failed writing message."); + LoggerMessage.Define(LogLevel.Warning, new EventId(10, "FailedWritingMessage"), "Failed writing message."); private static readonly Action _internalMessageFailed = - LoggerMessage.Define(LogLevel.Warning, new EventId(11, nameof(InternalMessageFailed)), "Error processing message for internal server message."); + LoggerMessage.Define(LogLevel.Warning, new EventId(11, "InternalMessageFailed"), "Error processing message for internal server message."); public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints) { diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 13d0d4d819..125924eeb0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -173,7 +173,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis var connection = _connections[connectionId]; if (connection != null) { - return connection.WriteAsync(message.CreateInvocation()); + return SafeWriteAsync(connection, message.CreateInvocation()); } return PublishAsync(_channelNamePrefix + "." + connectionId, message); @@ -402,14 +402,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis var invocation = message.CreateInvocation(); foreach (var connection in _connections) { - try - { - tasks.Add(connection.WriteAsync(invocation)); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + tasks.Add(SafeWriteAsync(connection, invocation)); } await Task.WhenAll(tasks); @@ -441,14 +434,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { if (!excludedIds.Contains(connection.ConnectionId)) { - try - { - tasks.Add(connection.WriteAsync(invocation)); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + tasks.Add(SafeWriteAsync(connection, invocation)); } } @@ -524,16 +510,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis _logger.Subscribing(connectionChannel); return _bus.SubscribeAsync(connectionChannel, async (c, data) => { - try - { - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); - await connection.WriteAsync(message.CreateInvocation()); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + await SafeWriteAsync(connection, message.CreateInvocation()); }); } @@ -545,16 +524,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis // TODO: Look at optimizing (looping over connections checking for Name) return _bus.SubscribeAsync(userChannel, async (c, data) => { - try - { - var message = DeserializeMessage(data); + var message = DeserializeMessage(data); - await connection.WriteAsync(message.CreateInvocation()); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + await SafeWriteAsync(connection, message.CreateInvocation()); }); } @@ -576,14 +548,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis continue; } - try - { - tasks.Add(groupConnection.WriteAsync(invocation)); - } - catch (Exception ex) - { - _logger.FailedWritingMessage(ex); - } + tasks.Add(SafeWriteAsync(groupConnection, invocation)); } await Task.WhenAll(tasks); @@ -611,7 +576,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis // This also saves serializing and deserializing the message! if (connection != null) { - publishTasks.Add(connection.WriteAsync(message.CreateInvocation())); + publishTasks.Add(SafeWriteAsync(connection, message.CreateInvocation())); } else { @@ -662,6 +627,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.CompletedTask; } + // This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to + private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message) + { + try + { + await connection.WriteAsync(message); + } + catch (Exception ex) + { + _logger.FailedWritingMessage(ex); + } + } + private class LoggerTextWriter : TextWriter { private readonly ILogger _logger; diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs index 7f42a38095..7a6aeca22b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisHubLifetimeManagerTests.cs @@ -503,7 +503,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests } [Fact] - public async Task WritingToLocalConnectionThatFailsThrowsException() + public async Task WritingToLocalConnectionThatFailsDoesNotThrowException() { var manager = new RedisHubLifetimeManager(new LoggerFactory().CreateLogger>(), Options.Create(new RedisOptions() { @@ -519,8 +519,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests await manager.OnConnectedAsync(connection).OrTimeout(); - var exception = await Assert.ThrowsAsync(() => manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout()); - Assert.Equal("Message", exception.Message); + await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index f298fc24eb..dd87b3527f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -3,6 +3,8 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Moq; using Xunit; @@ -11,12 +13,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests public class DefaultHubLifetimeManagerTests { [Fact] - public async Task InvokeAllAsyncWritesToAllConnectionsOutput() + public async Task SendAllAsyncWritesToAllConnectionsOutput() { using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -38,12 +40,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() + public async Task SendAllAsyncDoesNotWriteToDisconnectedConnectionsOutput() { using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -64,12 +66,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeGroupAsyncWritesToAllConnectionsInGroupOutput() + public async Task SendGroupAsyncWritesToAllConnectionsInGroupOutput() { using (var client1 = new TestClient()) using (var client2 = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection1 = HubConnectionContextUtils.Create(client1.Connection); var connection2 = HubConnectionContextUtils.Create(client2.Connection); @@ -90,11 +92,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeConnectionAsyncWritesToConnectionOutput() + public async Task SendConnectionAsyncWritesToConnectionOutput() { using (var client = new TestClient()) { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connection = HubConnectionContextUtils.Create(client.Connection); await manager.OnConnectedAsync(connection).OrTimeout(); @@ -109,42 +111,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite() + public async Task SendConnectionAsyncDoesNotThrowIfConnectionFailsToWrite() { using (var client = new TestClient()) { - // Force an exception when writing to connection - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + // Force an exception when writing to connection connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception("Message")); var connection = connectionMock.Object; await manager.OnConnectedAsync(connection).OrTimeout(); - var exception = await Assert.ThrowsAsync(() => manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout()); - Assert.Equal("Message", exception.Message); + await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); } } [Fact] - public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops() + public async Task SendAllAsyncSendsToAllConnectionsEvenWhenSomeFailToSend() { - var manager = new DefaultHubLifetimeManager(); + using (var client = new TestClient()) + using (var client2 = new TestClient()) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + + var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); + var connectionMock2 = HubConnectionContextUtils.CreateMock(client2.Connection); + + var tcs = new TaskCompletionSource(); + var tcs2 = new TaskCompletionSource(); + // Force an exception when writing to connection + connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Callback(() => tcs.TrySetResult(null)).Throws(new Exception("Message")); + connectionMock2.Setup(m => m.WriteAsync(It.IsAny())).Callback(() => tcs2.TrySetResult(null)).Throws(new Exception("Message")); + var connection = connectionMock.Object; + var connection2 = connectionMock2.Object; + + await manager.OnConnectedAsync(connection).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + await manager.SendAllAsync("Hello", new object[] { "World" }).OrTimeout(); + + // Check that all connections were "written" to + await tcs.Task.OrTimeout(); + await tcs2.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendConnectionAsyncOnNonExistentConnectionNoops() + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); await manager.SendConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout(); } [Fact] public async Task AddGroupOnNonExistentConnectionNoops() { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); await manager.AddGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); } [Fact] public async Task RemoveGroupOnNonExistentConnectionNoops() { - var manager = new DefaultHubLifetimeManager(); + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout(); }