Handle errors in Redis subscription callbacks (#1069)

This commit is contained in:
BrennanConroy 2017-11-01 10:29:02 -07:00 committed by GitHub
parent 8a7f495141
commit 2419867dfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 246 additions and 101 deletions

View File

@ -38,6 +38,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
private static readonly Action<ILogger, Exception> _connectionFailed = private static readonly Action<ILogger, Exception> _connectionFailed =
LoggerMessage.Define(LogLevel.Warning, new EventId(8, nameof(ConnectionFailed)), "Connection to Redis failed."); LoggerMessage.Define(LogLevel.Warning, new EventId(8, nameof(ConnectionFailed)), "Connection to Redis failed.");
private static readonly Action<ILogger, Exception> _failedWritingMessage =
LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(FailedWritingMessage)), "Failed writing message.");
private static readonly Action<ILogger, Exception> _internalMessageFailed =
LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(InternalMessageFailed)), "Error processing message for internal server message.");
public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints) public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints)
{ {
if (logger.IsEnabled(LogLevel.Information)) if (logger.IsEnabled(LogLevel.Information))
@ -88,5 +94,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{ {
_connectionFailed(logger, exception); _connectionFailed(logger, exception);
} }
public static void FailedWritingMessage(this ILogger logger, Exception exception)
{
_failedWritingMessage(logger, exception);
}
public static void InternalMessageFailed(this ILogger logger, Exception exception)
{
_internalMessageFailed(logger, exception);
}
} }
} }

View File

@ -88,86 +88,97 @@ namespace Microsoft.AspNetCore.SignalR.Redis
} }
_bus = _redisServerConnection.GetSubscriber(); _bus = _redisServerConnection.GetSubscriber();
var previousBroadcastTask = Task.CompletedTask;
var channelName = _channelNamePrefix; var channelName = _channelNamePrefix;
_logger.Subscribing(channelName); _logger.Subscribing(channelName);
_bus.Subscribe(channelName, async (c, data) => _bus.Subscribe(channelName, async (c, data) =>
{ {
await previousBroadcastTask; try
_logger.ReceivedFromChannel(channelName);
var message = DeserializeMessage<HubMessage>(data);
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{ {
tasks.Add(WriteAsync(connection, message)); _logger.ReceivedFromChannel(channelName);
}
previousBroadcastTask = Task.WhenAll(tasks); var message = DeserializeMessage<HubMessage>(data);
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{
tasks.Add(WriteAsync(connection, message));
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
}); });
var allExceptTask = Task.CompletedTask;
channelName = _channelNamePrefix + ".AllExcept"; channelName = _channelNamePrefix + ".AllExcept";
_logger.Subscribing(channelName); _logger.Subscribing(channelName);
_bus.Subscribe(channelName, async (c, data) => _bus.Subscribe(channelName, async (c, data) =>
{ {
await allExceptTask; try
_logger.ReceivedFromChannel(channelName);
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
var excludedIds = message.ExcludedIds;
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{ {
if (!excludedIds.Contains(connection.ConnectionId)) _logger.ReceivedFromChannel(channelName);
{
tasks.Add(WriteAsync(connection, message));
}
}
allExceptTask = Task.WhenAll(tasks); var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
var excludedIds = message.ExcludedIds;
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{
if (!excludedIds.Contains(connection.ConnectionId))
{
tasks.Add(WriteAsync(connection, message));
}
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
}); });
channelName = _channelNamePrefix + ".internal.group"; channelName = _channelNamePrefix + ".internal.group";
_bus.Subscribe(channelName, async (c, data) => _bus.Subscribe(channelName, async (c, data) =>
{ {
var groupMessage = DeserializeMessage<GroupMessage>(data); try
var connection = _connections[groupMessage.ConnectionId];
if (connection == null)
{ {
// user not on this server var groupMessage = DeserializeMessage<GroupMessage>(data);
return;
var connection = _connections[groupMessage.ConnectionId];
if (connection == null)
{
// user not on this server
return;
}
if (groupMessage.Action == GroupAction.Remove)
{
await RemoveGroupAsyncCore(connection, groupMessage.Group);
}
if (groupMessage.Action == GroupAction.Add)
{
await AddGroupAsyncCore(connection, groupMessage.Group);
}
// Sending ack to server that sent the original add/remove
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage
{
Action = GroupAction.Ack,
ConnectionId = groupMessage.ConnectionId,
Group = groupMessage.Group,
Id = groupMessage.Id
});
} }
catch (Exception ex)
if (groupMessage.Action == GroupAction.Remove)
{ {
await RemoveGroupAsyncCore(connection, groupMessage.Group); _logger.InternalMessageFailed(ex);
} }
if (groupMessage.Action == GroupAction.Add)
{
await AddGroupAsyncCore(connection, groupMessage.Group);
}
// Sending ack to server that sent the original add/remove
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage
{
Action = GroupAction.Ack,
ConnectionId = groupMessage.ConnectionId,
Group = groupMessage.Group,
Id = groupMessage.Id
});
}); });
// Create server specific channel in order to send an ack to a single server // Create server specific channel in order to send an ack to a single server
@ -264,16 +275,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId; var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
redisSubscriptions.Add(connectionChannel); redisSubscriptions.Add(connectionChannel);
var previousConnectionTask = Task.CompletedTask;
_logger.Subscribing(connectionChannel); _logger.Subscribing(connectionChannel);
connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) => connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) =>
{ {
await previousConnectionTask; try
{
var message = DeserializeMessage<HubMessage>(data);
var message = DeserializeMessage<HubMessage>(data); await WriteAsync(connection, message);
}
previousConnectionTask = WriteAsync(connection, message); catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
}); });
if (!string.IsNullOrEmpty(connection.UserIdentifier)) if (!string.IsNullOrEmpty(connection.UserIdentifier))
@ -281,16 +295,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier; var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
redisSubscriptions.Add(userChannel); redisSubscriptions.Add(userChannel);
var previousUserTask = Task.CompletedTask;
// TODO: Look at optimizing (looping over connections checking for Name) // TODO: Look at optimizing (looping over connections checking for Name)
userTask = _bus.SubscribeAsync(userChannel, async (c, data) => userTask = _bus.SubscribeAsync(userChannel, async (c, data) =>
{ {
await previousUserTask; try
{
var message = DeserializeMessage<HubMessage>(data);
var message = DeserializeMessage<HubMessage>(data); await WriteAsync(connection, message);
}
previousUserTask = WriteAsync(connection, message); catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
}); });
} }
@ -383,25 +400,25 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return; return;
} }
var previousTask = Task.CompletedTask;
_logger.Subscribing(groupChannel); _logger.Subscribing(groupChannel);
await _bus.SubscribeAsync(groupChannel, async (c, data) => await _bus.SubscribeAsync(groupChannel, async (c, data) =>
{ {
// Since this callback is async, we await the previous task then try
// before sending the current message. This is because we don't
// want to do concurrent writes to the outgoing connections
await previousTask;
var message = DeserializeMessage<HubMessage>(data);
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
{ {
tasks.Add(WriteAsync(groupConnection, message)); var message = DeserializeMessage<HubMessage>(data);
}
previousTask = Task.WhenAll(tasks); var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
{
tasks.Add(WriteAsync(groupConnection, message));
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
}); });
} }
finally finally

View File

@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using System.Threading.Tasks.Channels; using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
@ -8,6 +10,7 @@ using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Moq;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests namespace Microsoft.AspNetCore.SignalR.Redis.Tests
@ -481,6 +484,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
} }
} }
[Fact]
public async Task WritingToRemoteConnectionThatFailsDoesNotThrow()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client = new TestClient())
{
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
var connection = new HubConnectionContext(output.Object, client.Connection);
await manager2.OnConnectedAsync(connection).OrTimeout();
// This doesn't throw because there is no connection.ConnectionId on this server so it has to publish to redis.
// And once that happens there is no way to know if the invocation was successful or not.
await manager1.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
}
}
[Fact]
public async Task WritingToLocalConnectionThatFailsThrowsException()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client = new TestClient())
{
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception("Message"));
var connection = new HubConnectionContext(output.Object, client.Connection);
await manager.OnConnectedAsync(connection).OrTimeout();
var exception = await Assert.ThrowsAsync<Exception>(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout());
Assert.Equal("Message", exception.Message);
}
}
[Fact]
public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var output2 = Channel.CreateUnbounded<HubMessage>();
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
var connection1 = new HubConnectionContext(output.Object, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.AddGroupAsync(connection1.ConnectionId, "group");
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.AddGroupAsync(connection2.ConnectionId, "group");
await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout();
// connection1 will throw when receiving a group message, we are making sure other connections
// are not affected by another connection throwing
AssertMessage(output2);
// Repeat to check that group can still be sent to
await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout();
AssertMessage(output2);
}
}
private void AssertMessage(Channel<HubMessage> channel) private void AssertMessage(Channel<HubMessage> channel)
{ {
Assert.True(channel.In.TryRead(out var item)); Assert.True(channel.In.TryRead(out var item));

View File

@ -1,7 +1,10 @@
using System.Threading.Tasks; using System;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels; using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Moq;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests namespace Microsoft.AspNetCore.SignalR.Tests
@ -21,10 +24,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connection1 = new HubConnectionContext(output1, client1.Connection); var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1); await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2); await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.InvokeAllAsync("Hello", new object[] { "World" }); await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item)); Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage; var message = item as InvocationMessage;
@ -55,12 +58,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connection1 = new HubConnectionContext(output1, client1.Connection); var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1); await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2); await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.OnDisconnectedAsync(connection2); await manager.OnDisconnectedAsync(connection2).OrTimeout();
await manager.InvokeAllAsync("Hello", new object[] { "World" }); await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item)); Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage; var message = item as InvocationMessage;
@ -86,12 +89,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connection1 = new HubConnectionContext(output1, client1.Connection); var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection); var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1); await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2); await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.AddGroupAsync(connection1.ConnectionId, "gunit"); await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout();
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }); await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item)); Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage; var message = item as InvocationMessage;
@ -113,9 +116,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var manager = new DefaultHubLifetimeManager<MyHub>(); var manager = new DefaultHubLifetimeManager<MyHub>();
var connection = new HubConnectionContext(output, client.Connection); var connection = new HubConnectionContext(output, client.Connection);
await manager.OnConnectedAsync(connection); await manager.OnConnectedAsync(connection).OrTimeout();
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }); await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
Assert.True(output.In.TryRead(out var item)); Assert.True(output.In.TryRead(out var item));
var message = item as InvocationMessage; var message = item as InvocationMessage;
@ -126,25 +129,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests
} }
} }
[Fact]
public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite()
{
using (var client = new TestClient())
{
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception("Message"));
var manager = new DefaultHubLifetimeManager<MyHub>();
var connection = new HubConnectionContext(output.Object, client.Connection);
await manager.OnConnectedAsync(connection).OrTimeout();
var exception = await Assert.ThrowsAsync<Exception>(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout());
Assert.Equal("Message", exception.Message);
}
}
[Fact] [Fact]
public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops() public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops()
{ {
var manager = new DefaultHubLifetimeManager<MyHub>(); var manager = new DefaultHubLifetimeManager<MyHub>();
await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }); await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout();
} }
[Fact] [Fact]
public async Task AddGroupOnNonExistentConnectionNoops() public async Task AddGroupOnNonExistentConnectionNoops()
{ {
var manager = new DefaultHubLifetimeManager<MyHub>(); var manager = new DefaultHubLifetimeManager<MyHub>();
await manager.AddGroupAsync("NotARealConnectionId", "MyGroup"); await manager.AddGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout();
} }
[Fact] [Fact]
public async Task RemoveGroupOnNonExistentConnectionNoops() public async Task RemoveGroupOnNonExistentConnectionNoops()
{ {
var manager = new DefaultHubLifetimeManager<MyHub>(); var manager = new DefaultHubLifetimeManager<MyHub>();
await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup"); await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout();
} }
private class MyHub : Hub private class MyHub : Hub