aspnetcore/src/Microsoft.AspNetCore.Signal.../RedisHubLifetimeManager.cs

368 lines
13 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using Newtonsoft.Json;
using StackExchange.Redis;
namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
// Metadata key under which each connection stores the set of redis channel names it is subscribed to.
private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
// Connections tracked by this manager (added in OnConnectedAsync, removed in OnDisconnectedAsync).
private readonly HubConnectionList _connections = new HubConnectionList();
// Group channel name -> state for that group.
// TODO: Investigate "memory leak": entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
// Connection to redis; created in the constructor and disposed in Dispose.
private readonly ConnectionMultiplexer _redisServerConnection;
// Pub/sub interface obtained from _redisServerConnection; used for every subscribe/publish in this class.
private readonly ISubscriber _bus;
private readonly ILogger _logger;
private readonly RedisOptions _options;
// This serializer is ONLY used to transmit the data through redis; it has no connection to the serializer used on each connection.
private readonly JsonSerializer _serializer = new JsonSerializer
{
// We need to serialize objects "full-fidelity", even if it is noisy, so we preserve the original types
TypeNameAssemblyFormatHandling = TypeNameAssemblyFormatHandling.Simple,
TypeNameHandling = TypeNameHandling.All,
Formatting = Formatting.None
};
// Counter behind GetInvocationId; only modified through Interlocked.Increment.
private long _nextInvocationId = 0;
/// <summary>
/// Connects to redis using the supplied options and subscribes to the hub-wide broadcast channel
/// (named after <c>typeof(THub).FullName</c>).
/// </summary>
/// <param name="logger">Logger for this manager; it is also wrapped in a text writer that is handed to the redis connect call.</param>
/// <param name="options">Options supplying the redis endpoints and the connect routine.</param>
/// <exception cref="InvalidOperationException">The redis connection is not connected after the connect call returns.</exception>
public RedisHubLifetimeManager(ILogger<RedisHubLifetimeManager<THub>> logger,
                               IOptions<RedisOptions> options)
{
    _logger = logger;
    _options = options.Value;

    var writer = new LoggerTextWriter(logger);
    _logger.LogInformation("Connecting to redis endpoints: {endpoints}", string.Join(", ", options.Value.Options.EndPoints.Select(e => EndPointCollection.ToString(e))));
    _redisServerConnection = _options.Connect(writer);

    if (!_redisServerConnection.IsConnected)
    {
        // A constructor that throws never hands the instance to the caller, so Dispose() on this
        // manager can never run. Release the multiplexer here instead of leaking it.
        // TODO: We could support reconnecting, like old SignalR does.
        _redisServerConnection.Dispose();
        throw new InvalidOperationException("Connection to redis failed.");
    }

    _logger.LogInformation("Connected to redis");
    _bus = _redisServerConnection.GetSubscriber();

    // Holds the writes started for the previous broadcast message; each callback awaits it
    // before starting its own writes so messages are not written to connections concurrently.
    var previousBroadcastTask = Task.CompletedTask;

    var channelName = typeof(THub).FullName;
    _logger.LogInformation("Subscribing to channel: {channel}", channelName);
    _bus.Subscribe(channelName, async (c, data) =>
    {
        await previousBroadcastTask;

        _logger.LogTrace("Received message from redis channel {channel}", channelName);

        var message = DeserializeMessage(data);

        // TODO: This isn't going to work when we allow JsonSerializer customization or add Protobuf
        var tasks = new List<Task>(_connections.Count);
        foreach (var connection in _connections)
        {
            tasks.Add(WriteAsync(connection, message));
        }

        previousBroadcastTask = Task.WhenAll(tasks);
    });
}
/// <summary>
/// Publishes a non-blocking invocation of <paramref name="methodName"/> to the hub-wide redis channel.
/// </summary>
/// <param name="methodName">Target method name placed in the invocation message.</param>
/// <param name="args">Arguments placed in the invocation message.</param>
/// <returns>The task returned by the publish operation.</returns>
public override Task InvokeAllAsync(string methodName, object[] args)
{
    var invocation = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
    var hubChannel = typeof(THub).FullName;

    return PublishAsync(hubChannel, invocation);
}
/// <summary>
/// Publishes a non-blocking invocation of <paramref name="methodName"/> to the redis channel
/// of a single connection (<c>{hub full name}.{connectionId}</c>).
/// </summary>
/// <param name="connectionId">ID of the connection whose channel receives the message.</param>
/// <param name="methodName">Target method name placed in the invocation message.</param>
/// <param name="args">Arguments placed in the invocation message.</param>
/// <returns>The task returned by the publish operation.</returns>
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
{
    var invocation = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
    var connectionChannel = $"{typeof(THub).FullName}.{connectionId}";

    return PublishAsync(connectionChannel, invocation);
}
/// <summary>
/// Publishes a non-blocking invocation of <paramref name="methodName"/> to the redis channel
/// of a group (<c>{hub full name}.group.{groupName}</c>).
/// </summary>
/// <param name="groupName">Name of the group whose channel receives the message.</param>
/// <param name="methodName">Target method name placed in the invocation message.</param>
/// <param name="args">Arguments placed in the invocation message.</param>
/// <returns>The task returned by the publish operation.</returns>
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
    var invocation = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
    var groupChannel = $"{typeof(THub).FullName}.group.{groupName}";

    return PublishAsync(groupChannel, invocation);
}
/// <summary>
/// Publishes a non-blocking invocation of <paramref name="methodName"/> to the redis channel
/// of a user (<c>{hub full name}.user.{userId}</c>).
/// </summary>
/// <param name="userId">User identifier appended to the channel name.</param>
/// <param name="methodName">Target method name placed in the invocation message.</param>
/// <param name="args">Arguments placed in the invocation message.</param>
/// <returns>The task returned by the publish operation.</returns>
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
    var invocation = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
    var userChannel = $"{typeof(THub).FullName}.user.{userId}";

    return PublishAsync(userChannel, invocation);
}
/// <summary>
/// Serializes <paramref name="hubMessage"/> to JSON with the redis-only serializer and publishes
/// the resulting bytes on <paramref name="channel"/>.
/// </summary>
/// <param name="channel">Redis channel to publish on.</param>
/// <param name="hubMessage">Message to serialize and send.</param>
private async Task PublishAsync(string channel, HubMessage hubMessage)
{
    byte[] serialized;
    using (var buffer = new MemoryStream())
    {
        using (var jsonWriter = new JsonTextWriter(new StreamWriter(buffer)))
        {
            _serializer.Serialize(jsonWriter, hubMessage);

            // Flush before copying so buffered text has reached the memory stream.
            await jsonWriter.FlushAsync();
            serialized = buffer.ToArray();
        }
    }

    _logger.LogTrace("Publishing message to redis channel {channel}", channel);
    await _bus.PublishAsync(channel, serialized);
}
/// <summary>
/// Registers a new connection and subscribes to its redis channel, plus the channel of its
/// user when the connection's identity is authenticated.
/// </summary>
/// <param name="connection">The connection that has just connected.</param>
/// <returns>A task that completes when the started subscriptions have completed.</returns>
public override Task OnConnectedAsync(HubConnectionContext connection)
{
    var hubName = typeof(THub).FullName;

    // Channel names are recorded on the connection so OnDisconnectedAsync can unsubscribe them.
    var subscribedChannels = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new HashSet<string>());

    _connections.Add(connection);

    var connectionChannel = hubName + "." + connection.ConnectionId;
    subscribedChannels.Add(connectionChannel);

    // Each callback awaits the write started by the previous one before starting its own.
    var lastConnectionWrite = Task.CompletedTask;

    _logger.LogInformation("Subscribing to connection channel: {channel}", connectionChannel);
    var connectionSubscribed = _bus.SubscribeAsync(connectionChannel, async (redisChannel, payload) =>
    {
        await lastConnectionWrite;

        var received = DeserializeMessage(payload);
        lastConnectionWrite = WriteAsync(connection, received);
    });

    var userSubscribed = Task.CompletedTask;
    if (connection.User.Identity.IsAuthenticated)
    {
        var userChannel = hubName + ".user." + connection.User.Identity.Name;
        subscribedChannels.Add(userChannel);

        var lastUserWrite = Task.CompletedTask;

        // TODO: Look at optimizing (looping over connections checking for Name)
        userSubscribed = _bus.SubscribeAsync(userChannel, async (redisChannel, payload) =>
        {
            await lastUserWrite;

            var received = DeserializeMessage(payload);
            lastUserWrite = WriteAsync(connection, received);
        });
    }

    return Task.WhenAll(connectionSubscribed, userSubscribed);
}
/// <summary>
/// Cleans up after a connection: unsubscribes the redis channels recorded for it, removes it
/// from every group it joined, and stops tracking it.
/// </summary>
/// <param name="connection">The connection that has disconnected.</param>
/// <returns>A task that completes when all unsubscribe and group-removal operations have completed.</returns>
public override Task OnDisconnectedAsync(HubConnectionContext connection)
{
    var tasks = new List<Task>();

    var redisSubscriptions = connection.Metadata.Get<HashSet<string>>(RedisSubscriptionsMetadataName);
    if (redisSubscriptions != null)
    {
        foreach (var subscription in redisSubscriptions)
        {
            _logger.LogInformation("Unsubscribing from channel: {channel}", subscription);
            tasks.Add(_bus.UnsubscribeAsync(subscription));
        }
    }

    var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
    if (groupNames != null)
    {
        // Copy the groups to an array here because they get removed from this collection
        // in RemoveGroupAsync
        foreach (var group in groupNames.ToArray())
        {
            tasks.Add(RemoveGroupAsync(connection.ConnectionId, group));
        }
    }

    // RemoveGroupAsync resolves the connection through _connections[connectionId] before its
    // first await, so the connection must still be registered while the calls above start.
    // Removing it first (as this method used to) leaves RemoveGroupAsync unable to find it.
    // NOTE(review): assumes HubConnectionList's indexer returns null for a removed connection; confirm.
    _connections.Remove(connection);

    return Task.WhenAll(tasks);
}
/// <summary>
/// Adds a connection to a group, subscribing to the group's redis channel when the group
/// holds no more than one connection after the add.
/// </summary>
/// <param name="connectionId">ID of the connection to add; the call does nothing when it is not tracked by this manager.</param>
/// <param name="groupName">Name of the group to join.</param>
public override async Task AddGroupAsync(string connectionId, string groupName)
{
    var groupChannel = typeof(THub).FullName + ".group." + groupName;

    var member = _connections[connectionId];
    if (member == null)
    {
        return;
    }

    // Record the membership on the connection itself; the set is shared, so guard it.
    var memberGroups = member.Metadata.GetOrAdd(HubConnectionMetadataNames.Groups, _ => new HashSet<string>());
    lock (memberGroups)
    {
        memberGroups.Add(groupName);
    }

    var groupData = _groups.GetOrAdd(groupChannel, _ => new GroupData());

    await groupData.Lock.WaitAsync();
    try
    {
        groupData.Connections.Add(member);

        // Subscribe once
        var alreadySubscribed = groupData.Connections.Count > 1;
        if (!alreadySubscribed)
        {
            var lastGroupWrite = Task.CompletedTask;

            _logger.LogInformation("Subscribing to group channel: {channel}", groupChannel);
            await _bus.SubscribeAsync(groupChannel, async (redisChannel, payload) =>
            {
                // Since this callback is async, we await the previous task
                // before sending the current message. This is because we don't
                // want to do concurrent writes to the outgoing connections
                await lastGroupWrite;

                var received = DeserializeMessage(payload);

                var writes = new List<Task>(groupData.Connections.Count);
                foreach (var groupConnection in groupData.Connections)
                {
                    writes.Add(WriteAsync(groupConnection, received));
                }

                lastGroupWrite = Task.WhenAll(writes);
            });
        }
    }
    finally
    {
        groupData.Lock.Release();
    }
}
/// <summary>
/// Removes a connection from a group and unsubscribes from the group's redis channel once
/// the group has no connections left.
/// </summary>
/// <param name="connectionId">ID of the connection to remove; the call does nothing when it is not tracked by this manager.</param>
/// <param name="groupName">Name of the group to leave; the call does nothing when no state exists for the group.</param>
public override async Task RemoveGroupAsync(string connectionId, string groupName)
{
    var groupChannel = typeof(THub).FullName + ".group." + groupName;

    GroupData group;
    if (!_groups.TryGetValue(groupChannel, out group))
    {
        return;
    }

    var connection = _connections[connectionId];
    if (connection == null)
    {
        // Unknown connection: nothing to remove. The check must be "== null" (as in
        // AddGroupAsync); bailing out on a found connection would skip every real removal
        // and dereference null for unknown IDs.
        return;
    }

    // Drop the group from the connection's own membership set; the set is shared, so guard it.
    var groupNames = connection.Metadata.Get<HashSet<string>>(HubConnectionMetadataNames.Groups);
    if (groupNames != null)
    {
        lock (groupNames)
        {
            groupNames.Remove(groupName);
        }
    }

    await group.Lock.WaitAsync();
    try
    {
        group.Connections.Remove(connection);

        // Last connection gone: stop listening on the group's channel.
        if (group.Connections.Count == 0)
        {
            _logger.LogInformation("Unsubscribing from group channel: {channel}", groupChannel);
            await _bus.UnsubscribeAsync(groupChannel);
        }
    }
    finally
    {
        group.Lock.Release();
    }
}
/// <summary>
/// Removes every subscription held on the redis subscriber, then disposes the redis connection.
/// </summary>
public void Dispose()
{
// Unsubscribe first, while the connection the subscriber was obtained from is still alive.
_bus.UnsubscribeAll();
_redisServerConnection.Dispose();
}
/// <summary>
/// Encodes <paramref name="hubMessage"/> with the connection's protocol and writes it to the
/// connection's output, retrying while the output reports that it can still be written to.
/// </summary>
/// <param name="connection">Connection that receives the message.</param>
/// <param name="hubMessage">Message to encode and write.</param>
private async Task WriteAsync(HubConnectionContext connection, HubMessage hubMessage)
{
    var payload = connection.Protocol.WriteToArray(hubMessage);

    // WaitToWriteAsync yielding false ends the loop without the payload having been written.
    while (await connection.Output.WaitToWriteAsync())
    {
        var written = connection.Output.TryWrite(payload);
        if (written)
        {
            return;
        }
    }
}
/// <summary>
/// Produces the next invocation ID by atomically incrementing the shared counter.
/// </summary>
/// <returns>The incremented counter value formatted as a string.</returns>
private string GetInvocationId() => Interlocked.Increment(ref _nextInvocationId).ToString();
/// <summary>
/// Turns a payload received from redis back into a <see cref="HubMessage"/> using the
/// redis-only JSON serializer.
/// </summary>
/// <param name="data">Raw value received from the redis channel; read as a byte array.</param>
/// <returns>The deserialized message.</returns>
private HubMessage DeserializeMessage(RedisValue data)
{
    var bytes = (byte[])data;

    using (var jsonReader = new JsonTextReader(new StreamReader(new MemoryStream(bytes))))
    {
        return (HubMessage)_serializer.Deserialize(jsonReader);
    }
}
/// <summary>
/// A <see cref="TextWriter"/> that forwards written lines to an <see cref="ILogger"/> at debug level.
/// </summary>
private class LoggerTextWriter : TextWriter
{
    private readonly ILogger _logger;

    public LoggerTextWriter(ILogger logger)
    {
        _logger = logger;
    }

    public override Encoding Encoding
    {
        get { return Encoding.UTF8; }
    }

    public override void Write(char value)
    {
        // Intentionally empty: single characters are discarded.
    }

    // Whole lines are the only output forwarded to the logger.
    public override void WriteLine(string value) => _logger.LogDebug(value);
}
// State kept per group channel: the connections in the group and the lock taken while they change.
private class GroupData
{
// Semaphore with a single permit; AddGroupAsync and RemoveGroupAsync hold it while they
// update Connections and subscribe to / unsubscribe from the group's redis channel.
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
// Connections that currently belong to the group.
public HubConnectionList Connections = new HubConnectionList();
}
}
}