Handle errors in Redis subscription callbacks (#1069)
This commit is contained in:
parent
8a7f495141
commit
2419867dfc
|
|
@ -38,6 +38,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
|||
private static readonly Action<ILogger, Exception> _connectionFailed =
|
||||
LoggerMessage.Define(LogLevel.Warning, new EventId(8, nameof(ConnectionFailed)), "Connection to Redis failed.");
|
||||
|
||||
private static readonly Action<ILogger, Exception> _failedWritingMessage =
|
||||
LoggerMessage.Define(LogLevel.Warning, new EventId(9, nameof(FailedWritingMessage)), "Failed writing message.");
|
||||
|
||||
private static readonly Action<ILogger, Exception> _internalMessageFailed =
|
||||
LoggerMessage.Define(LogLevel.Warning, new EventId(10, nameof(InternalMessageFailed)), "Error processing message for internal server message.");
|
||||
|
||||
public static void ConnectingToEndpoints(this ILogger logger, EndPointCollection endpoints)
|
||||
{
|
||||
if (logger.IsEnabled(LogLevel.Information))
|
||||
|
|
@ -88,5 +94,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Internal
|
|||
{
|
||||
_connectionFailed(logger, exception);
|
||||
}
|
||||
|
||||
public static void FailedWritingMessage(this ILogger logger, Exception exception)
|
||||
{
|
||||
_failedWritingMessage(logger, exception);
|
||||
}
|
||||
|
||||
public static void InternalMessageFailed(this ILogger logger, Exception exception)
|
||||
{
|
||||
_internalMessageFailed(logger, exception);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -88,86 +88,97 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
}
|
||||
_bus = _redisServerConnection.GetSubscriber();
|
||||
|
||||
var previousBroadcastTask = Task.CompletedTask;
|
||||
|
||||
var channelName = _channelNamePrefix;
|
||||
_logger.Subscribing(channelName);
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
await previousBroadcastTask;
|
||||
|
||||
_logger.ReceivedFromChannel(channelName);
|
||||
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
foreach (var connection in _connections)
|
||||
try
|
||||
{
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
_logger.ReceivedFromChannel(channelName);
|
||||
|
||||
previousBroadcastTask = Task.WhenAll(tasks);
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
});
|
||||
|
||||
var allExceptTask = Task.CompletedTask;
|
||||
channelName = _channelNamePrefix + ".AllExcept";
|
||||
_logger.Subscribing(channelName);
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
await allExceptTask;
|
||||
|
||||
_logger.ReceivedFromChannel(channelName);
|
||||
|
||||
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
|
||||
var excludedIds = message.ExcludedIds;
|
||||
|
||||
// TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
|
||||
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
foreach (var connection in _connections)
|
||||
try
|
||||
{
|
||||
if (!excludedIds.Contains(connection.ConnectionId))
|
||||
{
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
}
|
||||
_logger.ReceivedFromChannel(channelName);
|
||||
|
||||
allExceptTask = Task.WhenAll(tasks);
|
||||
var message = DeserializeMessage<RedisExcludeClientsMessage>(data);
|
||||
var excludedIds = message.ExcludedIds;
|
||||
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
if (!excludedIds.Contains(connection.ConnectionId))
|
||||
{
|
||||
tasks.Add(WriteAsync(connection, message));
|
||||
}
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
});
|
||||
|
||||
channelName = _channelNamePrefix + ".internal.group";
|
||||
_bus.Subscribe(channelName, async (c, data) =>
|
||||
{
|
||||
var groupMessage = DeserializeMessage<GroupMessage>(data);
|
||||
|
||||
var connection = _connections[groupMessage.ConnectionId];
|
||||
if (connection == null)
|
||||
try
|
||||
{
|
||||
// user not on this server
|
||||
return;
|
||||
var groupMessage = DeserializeMessage<GroupMessage>(data);
|
||||
|
||||
var connection = _connections[groupMessage.ConnectionId];
|
||||
if (connection == null)
|
||||
{
|
||||
// user not on this server
|
||||
return;
|
||||
}
|
||||
|
||||
if (groupMessage.Action == GroupAction.Remove)
|
||||
{
|
||||
await RemoveGroupAsyncCore(connection, groupMessage.Group);
|
||||
}
|
||||
|
||||
if (groupMessage.Action == GroupAction.Add)
|
||||
{
|
||||
await AddGroupAsyncCore(connection, groupMessage.Group);
|
||||
}
|
||||
|
||||
// Sending ack to server that sent the original add/remove
|
||||
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage
|
||||
{
|
||||
Action = GroupAction.Ack,
|
||||
ConnectionId = groupMessage.ConnectionId,
|
||||
Group = groupMessage.Group,
|
||||
Id = groupMessage.Id
|
||||
});
|
||||
}
|
||||
|
||||
if (groupMessage.Action == GroupAction.Remove)
|
||||
catch (Exception ex)
|
||||
{
|
||||
await RemoveGroupAsyncCore(connection, groupMessage.Group);
|
||||
_logger.InternalMessageFailed(ex);
|
||||
}
|
||||
|
||||
if (groupMessage.Action == GroupAction.Add)
|
||||
{
|
||||
await AddGroupAsyncCore(connection, groupMessage.Group);
|
||||
}
|
||||
|
||||
// Sending ack to server that sent the original add/remove
|
||||
await PublishAsync($"{_channelNamePrefix}.internal.{groupMessage.Server}", new GroupMessage
|
||||
{
|
||||
Action = GroupAction.Ack,
|
||||
ConnectionId = groupMessage.ConnectionId,
|
||||
Group = groupMessage.Group,
|
||||
Id = groupMessage.Id
|
||||
});
|
||||
});
|
||||
|
||||
// Create server specific channel in order to send an ack to a single server
|
||||
|
|
@ -264,16 +275,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var connectionChannel = _channelNamePrefix + "." + connection.ConnectionId;
|
||||
redisSubscriptions.Add(connectionChannel);
|
||||
|
||||
var previousConnectionTask = Task.CompletedTask;
|
||||
|
||||
_logger.Subscribing(connectionChannel);
|
||||
connectionTask = _bus.SubscribeAsync(connectionChannel, async (c, data) =>
|
||||
{
|
||||
await previousConnectionTask;
|
||||
try
|
||||
{
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
previousConnectionTask = WriteAsync(connection, message);
|
||||
await WriteAsync(connection, message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
});
|
||||
|
||||
if (!string.IsNullOrEmpty(connection.UserIdentifier))
|
||||
|
|
@ -281,16 +295,19 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
var userChannel = _channelNamePrefix + ".user." + connection.UserIdentifier;
|
||||
redisSubscriptions.Add(userChannel);
|
||||
|
||||
var previousUserTask = Task.CompletedTask;
|
||||
|
||||
// TODO: Look at optimizing (looping over connections checking for Name)
|
||||
userTask = _bus.SubscribeAsync(userChannel, async (c, data) =>
|
||||
{
|
||||
await previousUserTask;
|
||||
try
|
||||
{
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
previousUserTask = WriteAsync(connection, message);
|
||||
await WriteAsync(connection, message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -383,25 +400,25 @@ namespace Microsoft.AspNetCore.SignalR.Redis
|
|||
return;
|
||||
}
|
||||
|
||||
var previousTask = Task.CompletedTask;
|
||||
|
||||
_logger.Subscribing(groupChannel);
|
||||
await _bus.SubscribeAsync(groupChannel, async (c, data) =>
|
||||
{
|
||||
// Since this callback is async, we await the previous task then
|
||||
// before sending the current message. This is because we don't
|
||||
// want to do concurrent writes to the outgoing connections
|
||||
await previousTask;
|
||||
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
var tasks = new List<Task>(group.Connections.Count);
|
||||
foreach (var groupConnection in group.Connections)
|
||||
try
|
||||
{
|
||||
tasks.Add(WriteAsync(groupConnection, message));
|
||||
}
|
||||
var message = DeserializeMessage<HubMessage>(data);
|
||||
|
||||
previousTask = Task.WhenAll(tasks);
|
||||
var tasks = new List<Task>(group.Connections.Count);
|
||||
foreach (var groupConnection in group.Connections)
|
||||
{
|
||||
tasks.Add(WriteAsync(groupConnection, message));
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.FailedWritingMessage(ex);
|
||||
}
|
||||
});
|
||||
}
|
||||
finally
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
|
|
@ -8,6 +10,7 @@ using Microsoft.AspNetCore.SignalR.Tests;
|
|||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||
|
|
@ -481,6 +484,93 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritingToRemoteConnectionThatFailsDoesNotThrow()
|
||||
{
|
||||
var manager1 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
var manager2 = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
// Force an exception when writing to connection
|
||||
var output = new Mock<Channel<HubMessage>>();
|
||||
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
|
||||
|
||||
var connection = new HubConnectionContext(output.Object, client.Connection);
|
||||
|
||||
await manager2.OnConnectedAsync(connection).OrTimeout();
|
||||
|
||||
// This doesn't throw because there is no connection.ConnectionId on this server so it has to publish to redis.
|
||||
// And once that happens there is no way to know if the invocation was successful or not.
|
||||
await manager1.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritingToLocalConnectionThatFailsThrowsException()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
// Force an exception when writing to connection
|
||||
var output = new Mock<Channel<HubMessage>>();
|
||||
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception("Message"));
|
||||
|
||||
var connection = new HubConnectionContext(output.Object, client.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection).OrTimeout();
|
||||
|
||||
var exception = await Assert.ThrowsAsync<Exception>(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout());
|
||||
Assert.Equal("Message", exception.Message);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task WritingToGroupWithOneConnectionFailingSecondConnectionStillReceivesMessage()
|
||||
{
|
||||
var manager = new RedisHubLifetimeManager<MyHub>(new LoggerFactory().CreateLogger<RedisHubLifetimeManager<MyHub>>(), Options.Create(new RedisOptions()
|
||||
{
|
||||
Factory = t => new TestConnectionMultiplexer()
|
||||
}));
|
||||
|
||||
using (var client1 = new TestClient())
|
||||
using (var client2 = new TestClient())
|
||||
{
|
||||
var output2 = Channel.CreateUnbounded<HubMessage>();
|
||||
|
||||
// Force an exception when writing to connection
|
||||
var output = new Mock<Channel<HubMessage>>();
|
||||
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception());
|
||||
|
||||
var connection1 = new HubConnectionContext(output.Object, client1.Connection);
|
||||
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager.AddGroupAsync(connection1.ConnectionId, "group");
|
||||
await manager.OnConnectedAsync(connection2).OrTimeout();
|
||||
await manager.AddGroupAsync(connection2.ConnectionId, "group");
|
||||
|
||||
await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout();
|
||||
// connection1 will throw when receiving a group message, we are making sure other connections
|
||||
// are not affected by another connection throwing
|
||||
AssertMessage(output2);
|
||||
|
||||
// Repeat to check that group can still be sent to
|
||||
await manager.InvokeGroupAsync("group", "Hello", new object[] { "World" }).OrTimeout();
|
||||
AssertMessage(output2);
|
||||
}
|
||||
}
|
||||
|
||||
private void AssertMessage(Channel<HubMessage> channel)
|
||||
{
|
||||
Assert.True(channel.In.TryRead(out var item));
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
using System.Threading.Tasks;
|
||||
using System;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Core;
|
||||
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
|
|
@ -21,10 +24,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var connection1 = new HubConnectionContext(output1, client1.Connection);
|
||||
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection1);
|
||||
await manager.OnConnectedAsync(connection2);
|
||||
await manager.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection2).OrTimeout();
|
||||
|
||||
await manager.InvokeAllAsync("Hello", new object[] { "World" });
|
||||
await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
|
||||
|
||||
Assert.True(output1.In.TryRead(out var item));
|
||||
var message = item as InvocationMessage;
|
||||
|
|
@ -55,12 +58,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var connection1 = new HubConnectionContext(output1, client1.Connection);
|
||||
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection1);
|
||||
await manager.OnConnectedAsync(connection2);
|
||||
await manager.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection2).OrTimeout();
|
||||
|
||||
await manager.OnDisconnectedAsync(connection2);
|
||||
await manager.OnDisconnectedAsync(connection2).OrTimeout();
|
||||
|
||||
await manager.InvokeAllAsync("Hello", new object[] { "World" });
|
||||
await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
|
||||
|
||||
Assert.True(output1.In.TryRead(out var item));
|
||||
var message = item as InvocationMessage;
|
||||
|
|
@ -86,12 +89,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var connection1 = new HubConnectionContext(output1, client1.Connection);
|
||||
var connection2 = new HubConnectionContext(output2, client2.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection1);
|
||||
await manager.OnConnectedAsync(connection2);
|
||||
await manager.OnConnectedAsync(connection1).OrTimeout();
|
||||
await manager.OnConnectedAsync(connection2).OrTimeout();
|
||||
|
||||
await manager.AddGroupAsync(connection1.ConnectionId, "gunit");
|
||||
await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout();
|
||||
|
||||
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" });
|
||||
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout();
|
||||
|
||||
Assert.True(output1.In.TryRead(out var item));
|
||||
var message = item as InvocationMessage;
|
||||
|
|
@ -113,9 +116,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||
var connection = new HubConnectionContext(output, client.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection);
|
||||
await manager.OnConnectedAsync(connection).OrTimeout();
|
||||
|
||||
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" });
|
||||
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
|
||||
|
||||
Assert.True(output.In.TryRead(out var item));
|
||||
var message = item as InvocationMessage;
|
||||
|
|
@ -126,25 +129,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeConnectionAsyncThrowsIfConnectionFailsToWrite()
|
||||
{
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
// Force an exception when writing to connection
|
||||
var output = new Mock<Channel<HubMessage>>();
|
||||
output.Setup(o => o.Out.WaitToWriteAsync(It.IsAny<CancellationToken>())).Throws(new Exception("Message"));
|
||||
|
||||
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||
var connection = new HubConnectionContext(output.Object, client.Connection);
|
||||
|
||||
await manager.OnConnectedAsync(connection).OrTimeout();
|
||||
|
||||
var exception = await Assert.ThrowsAsync<Exception>(() => manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout());
|
||||
Assert.Equal("Message", exception.Message);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task InvokeConnectionAsyncOnNonExistentConnectionNoops()
|
||||
{
|
||||
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||
await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" });
|
||||
await manager.InvokeConnectionAsync("NotARealConnectionId", "Hello", new object[] { "World" }).OrTimeout();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AddGroupOnNonExistentConnectionNoops()
|
||||
{
|
||||
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||
await manager.AddGroupAsync("NotARealConnectionId", "MyGroup");
|
||||
await manager.AddGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout();
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RemoveGroupOnNonExistentConnectionNoops()
|
||||
{
|
||||
var manager = new DefaultHubLifetimeManager<MyHub>();
|
||||
await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup");
|
||||
await manager.RemoveGroupAsync("NotARealConnectionId", "MyGroup").OrTimeout();
|
||||
}
|
||||
|
||||
private class MyHub : Hub
|
||||
|
|
|
|||
Loading…
Reference in New Issue