Handle errors in Redis subscription callbacks (#1069)

This commit is contained in:
BrennanConroy 2017-11-01 10:29:02 -07:00 committed by GitHub
parent 8a7f495141
commit 2419867dfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 246 additions and 101 deletions

View File

@ -38,6 +38,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
private static readonly Action<ILogger, Exception> _connectionFailed =
LoggerMessage.Define(LogLevel.Warning, new EventId(8, nameof(ConnectionFailed)), "Connection to Redis failed.");
private static readonly Action<ILogger, Exception> _failedWritingMessage =
LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(FailedWritingMessage)), "Failed writing message.");
private static readonly Action<ILogger, Exception> _internalMessageFailed =
LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(InternalMessageFailed)), "Error processing message for internal server message.");
public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints)
{
if (logger.IsEnabled(LogLevel.Information))
@ -88,5 +94,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
{
_connectionFailed(logger, exception);
}
public static void FailedWritingMessage(this ILogger logger, Exception exception)
{
_failedWritingMessage(logger, exception);
}
public static void InternalMessageFailed(this ILogger logger, Exception exception)
{
_internalMessageFailed(logger, exception);
}
}
}

View File

@ -88,86 +88,97 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
_bus = _redisServerConnection.GetSubscriber();
var previousBroadcastTask = Task.CompletedTask;
var channelName = _channelNamePrefix;
_logger.Subscribing(channelName);
_bus.Subscribe(channelName, async (c, data) =>
{
await previousBroadcastTask;
_logger.ReceivedFromChannel(channelName);
var message = DeserializeMessage<HubMessage>(data);
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
try
{
tasks.Add(WriteAsync(connection, message));
}
_logger.ReceivedFromChannel(channelName);
previousBroadcastTask = Task.WhenAll(tasks);
var message = DeserializeMessage<HubMessage>(data);
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{
tasks.Add(WriteAsync(connection, message));
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
});
var allExceptTask = Task.CompletedTask;
channelName = _channelNamePrefix + ".AllExcept";
_logger.Subscribing(channelName);
_bus.Subscribe(channelName, async (c, data) =>
{
await allExceptTask;
_logger.ReceivedFromChannel(channelName);
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
var excludedIds = message.ExcludedIds;
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
try
{
if (!excludedIds.Contains(connection.ConnectionId))
{
tasks.Add(WriteAsync(connection, message));
}
}
_logger.ReceivedFromChannel(channelName);
allExceptTask = Task.WhenAll(tasks);
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
var excludedIds = message.ExcludedIds;
var tasks = new List<Task>(_connections.Count);
foreach (var connection in _connections)
{
if (!excludedIds.Contains(connection.ConnectionId))
{
tasks.Add(WriteAsync(connection, message));
}
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
});
channelName = _channelNamePrefix + ".internal.group";
_bus.Subscribe(channelName, async (c, data) =>
{
var groupMessage = DeserializeMessage<GroupMessage>(data);
var connection = _connections[groupMessage.ConnectionId];
if (connection == null)
try
{
// user not on this server
return;
var groupMessage = DeserializeMessage<GroupMessage>(data);
var connection = _connections[groupMessage.ConnectionId];
if (connection == null)
{
// user not on this server
return;
}
if (groupMessage.Action == GroupAction.Remove)
{
await RemoveGroupAsyncCore(connection, groupMessage.Group);
}
if (groupMessage.Action == GroupAction.Add)
{
await AddGroupAsyncCore(connection, groupMessage.Group);
}
// Sending ack to server that sent the original add/remove
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage
{
Action = GroupAction.Ack,
ConnectionId = groupMessage.ConnectionId,
Group = groupMessage.Group,
Id = groupMessage.Id
});
}
if (groupMessage.Action == GroupAction.Remove)
catch (Exception ex)
{
await RemoveGroupAsyncCore(connection, groupMessage.Group);
_logger.InternalMessageFailed(ex);
}
if (groupMessage.Action == GroupAction.Add)
{
await AddGroupAsyncCore(connection, groupMessage.Group);
}
// Sending ack to server that sent the original add/remove
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage
{
Action = GroupAction.Ack,
ConnectionId = groupMessage.ConnectionId,
Group = groupMessage.Group,
Id = groupMessage.Id
});
});
// Create server specific channel in order to send an ack to a single server
@ -264,16 +275,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
redisSubscriptions.Add(connectionChannel);
var previousConnectionTask = Task.CompletedTask;
_logger.Subscribing(connectionChannel);
connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) =>
{
await previousConnectionTask;
try
{
var message = DeserializeMessage<HubMessage>(data);
var message = DeserializeMessage<HubMessage>(data);
previousConnectionTask = WriteAsync(connection, message);
await WriteAsync(connection, message);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
});
if (!string.IsNullOrEmpty(connection.UserIdentifier))
@ -281,16 +295,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
redisSubscriptions.Add(userChannel);
var previousUserTask = Task.CompletedTask;
// TODO: Look at optimizing (looping over connections checking for Name)
userTask = _bus.SubscribeAsync(userChannel, async (c, data) =>
{
await previousUserTask;
try
{
var message = DeserializeMessage<HubMessage>(data);
var message = DeserializeMessage<HubMessage>(data);
previousUserTask = WriteAsync(connection, message);
await WriteAsync(connection, message);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
});
}
@ -383,25 +400,25 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return;
}
var previousTask = Task.CompletedTask;
_logger.Subscribing(groupChannel);
await _bus.SubscribeAsync(groupChannel, async (c, data) =>
{
// Since this callback is async, we await the previous task then
// before sending the current message. This is because we don't
// want to do concurrent writes to the outgoing connections
await previousTask;
var message = DeserializeMessage<HubMessage>(data);
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
try
{
tasks.Add(WriteAsync(groupConnection, message));
}
var message = DeserializeMessage<HubMessage>(data);
previousTask = Task.WhenAll(tasks);
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections)
{
tasks.Add(WriteAsync(groupConnection, message));
}
await Task.WhenAll(tasks);
}
catch (Exception ex)
{
_logger.FailedWritingMessage(ex);
}
});
}
finally

View File

@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
@ -8,6 +10,7 @@ using Microsoft.AspNetCore.SignalR.Tests;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Moq;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
@ -481,6 +484,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
}
}
[Fact]
public async Task WritingToRemoteConnectionThatFailsDoesNotThrow()
{
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client = new TestClient())
{
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
var connection = new HubConnectionContext(output.Object, client.Connection);
await manager2.OnConnectedAsync(connection).OrTimeout();
// This doesn't throw because there is no connection.ConnectionId on this server so it has to publish to redis.
// And once that happens there is no way to know if the invocation was successful or not.
await manager1.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
}
}
[Fact]
public async Task WritingToLocalConnectionThatFailsThrowsException()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client = new TestClient())
{
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception("Message"));
var connection = new HubConnectionContext(output.Object, client.Connection);
await manager.OnConnectedAsync(connection).OrTimeout();
var exception = await Assert.ThrowsAsync<Exception>(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout());
Assert.Equal("Message", exception.Message);
}
}
[Fact]
public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage()
{
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
{
Factory = t => new TestConnectionMultiplexer()
}));
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var output2 = Channel.CreateUnbounded<HubMessage>();
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
var connection1 = new HubConnectionContext(output.Object, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.AddGroupAsync(connection1.ConnectionId, "group");
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.AddGroupAsync(connection2.ConnectionId, "group");
await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout();
// connection1 will throw when receiving a group message, we are making sure other connections
// are not affected by another connection throwing
AssertMessage(output2);
// Repeat to check that group can still be sent to
await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout();
AssertMessage(output2);
}
}
private void AssertMessage(Channel<HubMessage> channel)
{
Assert.True(channel.In.TryRead(out var item));

View File

@ -1,7 +1,10 @@
using System.Threading.Tasks;
using System;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Moq;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
@ -21,10 +24,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1);
await manager.OnConnectedAsync(connection2);
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.InvokeAllAsync("Hello", new object[] { "World" });
await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage;
@ -55,12 +58,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1);
await manager.OnConnectedAsync(connection2);
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.OnDisconnectedAsync(connection2);
await manager.OnDisconnectedAsync(connection2).OrTimeout();
await manager.InvokeAllAsync("Hello", new object[] { "World" });
await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage;
@ -86,12 +89,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connection1 = new HubConnectionContext(output1, client1.Connection);
var connection2 = new HubConnectionContext(output2, client2.Connection);
await manager.OnConnectedAsync(connection1);
await manager.OnConnectedAsync(connection2);
await manager.OnConnectedAsync(connection1).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.AddGroupAsync(connection1.ConnectionId, "gunit");
await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout();
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" });
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage;
@ -113,9 +116,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var manager = new DefaultHubLifetimeManager<MyHub>();
var connection = new HubConnectionContext(output, client.Connection);
await manager.OnConnectedAsync(connection);
await manager.OnConnectedAsync(connection).OrTimeout();
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" });
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
Assert.True(output.In.TryRead(out var item));
var message = item as InvocationMessage;
@ -126,25 +129,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite()
{
using (var client = new TestClient())
{
// Force an exception when writing to connection
var output = new Mock<Channel<HubMessage>>();
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception("Message"));
var manager = new DefaultHubLifetimeManager<MyHub>();
var connection = new HubConnectionContext(output.Object, client.Connection);
await manager.OnConnectedAsync(connection).OrTimeout();
var exception = await Assert.ThrowsAsync<Exception>(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout());
Assert.Equal("Message", exception.Message);
}
}
[Fact]
public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops()
{
var manager = new DefaultHubLifetimeManager<MyHub>();
await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" });
await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout();
}
[Fact]
public async Task AddGroupOnNonExistentConnectionNoops()
{
var manager = new DefaultHubLifetimeManager<MyHub>();
await manager.AddGroupAsync("NotARealConnectionId", "MyGroup");
await manager.AddGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout();
}
[Fact]
public async Task RemoveGroupOnNonExistentConnectionNoops()
{
var manager = new DefaultHubLifetimeManager<MyHub>();
await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup");
await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout();
}
private class MyHub : Hub