aspnetcore/src/Microsoft.AspNetCore.Signal.../RedisHubLifetimeManager.cs

307 lines
10 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using StackExchange.Redis;
namespace Microsoft.AspNetCore.SignalR.Redis
{
/// <summary>
/// A <see cref="HubLifetimeManager{THub}"/> that publishes hub invocations to Redis pub/sub channels
/// and writes the payloads received on its subscribed channels to the matching local connections.
/// </summary>
/// <remarks>
/// Channel names are all prefixed with the hub's full type name:
/// <c>{Hub}</c> (broadcast), <c>{Hub}.{connectionId}</c>, <c>{Hub}.group.{groupName}</c>
/// and <c>{Hub}.user.{userName}</c>.
/// </remarks>
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
    // Connection metadata key holding the per-connection Redis subscriptions (channel -> handler).
    private const string RedisSubscriptionsMetadataName = "redis_subscriptions";
    // Connection metadata key holding the set of group names the connection has joined.
    private const string GroupsMetadataName = "group";

    private readonly ConnectionList _connections = new ConnectionList();
    // TODO: Investigate "memory leak" entries never get removed
    private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
    private readonly InvocationAdapterRegistry _registry;
    private readonly ConnectionMultiplexer _redisServerConnection;
    private readonly ISubscriber _bus;
    private readonly ILoggerFactory _loggerFactory;
    private readonly ILogger _logger;
    private readonly RedisOptions _options;

    /// <summary>
    /// Connects to Redis using <paramref name="options"/> and subscribes to the hub's broadcast channel.
    /// </summary>
    /// <param name="registry">Registry used to obtain the adapter that serializes published invocations.</param>
    /// <param name="loggerFactory">Factory for this manager's logger; Redis connection log lines are forwarded to it.</param>
    /// <param name="options">Redis options; <c>Connect</c> is called on its value to open the connection.</param>
    public RedisHubLifetimeManager(InvocationAdapterRegistry registry,
                                   ILoggerFactory loggerFactory,
                                   IOptions<RedisOptions> options)
    {
        _loggerFactory = loggerFactory;
        _registry = registry;
        _options = options.Value;
        _logger = loggerFactory.CreateLogger<RedisHubLifetimeManager<THub>>();

        var writer = new LoggerTextWriter(_logger);
        _redisServerConnection = _options.Connect(writer);
        _bus = _redisServerConnection.GetSubscriber();

        var previousBroadcastTask = TaskCache.CompletedTask;
        _bus.Subscribe(AllChannel, (c, data) =>
        {
            // Record the new tail of the chain before yielding so the next message queues behind this one.
            previousBroadcastTask = WriteAfterAsync(previousBroadcastTask, () => WriteToAllAsync(_connections, (byte[])data));
        });
    }

    /// <summary>Publishes an invocation to every connection subscribed to the hub's broadcast channel.</summary>
    public override Task InvokeAllAsync(string methodName, object[] args)
    {
        return PublishAsync(AllChannel, CreateInvocation(methodName, args));
    }

    /// <summary>Publishes an invocation to the channel of a single connection.</summary>
    public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
    {
        return PublishAsync(ConnectionChannel(connectionId), CreateInvocation(methodName, args));
    }

    /// <summary>Publishes an invocation to the channel of a group.</summary>
    public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
    {
        return PublishAsync(GroupChannel(groupName), CreateInvocation(methodName, args));
    }

    /// <summary>Publishes an invocation to the channel of a user (keyed by the identity name).</summary>
    public override Task InvokeUserAsync(string userId, string methodName, object[] args)
    {
        return PublishAsync(UserChannel(userId), CreateInvocation(methodName, args));
    }

    // Broadcast channel name; also the prefix of every other channel.
    private static string AllChannel => typeof(THub).FullName;

    // Channel name for a single connection.
    private static string ConnectionChannel(string connectionId) => AllChannel + "." + connectionId;

    // Channel name for a group.
    private static string GroupChannel(string groupName) => AllChannel + ".group." + groupName;

    // Channel name for a user.
    private static string UserChannel(string userId) => AllChannel + ".user." + userId;

    // Builds the invocation message shared by all Invoke*Async overloads.
    private static InvocationDescriptor CreateInvocation(string methodName, object[] args)
    {
        return new InvocationDescriptor
        {
            Method = methodName,
            Arguments = args
        };
    }

    /// <summary>Serializes <paramref name="message"/> and publishes the bytes on <paramref name="channel"/>.</summary>
    private async Task PublishAsync(string channel, InvocationDescriptor message)
    {
        // TODO: What format?? The payload is always serialized with the "json" adapter.
        var invocationAdapter = _registry.GetInvocationAdapter("json");
        // NOTE: buffers the whole payload in memory and copies it again via ToArray before publishing.
        using (var ms = new MemoryStream())
        {
            await invocationAdapter.WriteMessageAsync(message, ms);
            await _bus.PublishAsync(channel, ms.ToArray());
        }
    }

    /// <summary>
    /// Runs <paramref name="write"/> after <paramref name="previous"/> completes, logging (not rethrowing)
    /// any failure so one failed write neither breaks later writes nor escapes a Redis callback.
    /// </summary>
    private async Task WriteAfterAsync(Task previous, Func<Task> write)
    {
        // previous is either an already-completed task or a task returned by this method,
        // which catches write failures itself.
        await previous;
        try
        {
            await write();
        }
        catch (Exception ex)
        {
            _logger.LogError(0, ex, "Failed to write a Redis message to a connection.");
        }
    }

    // Writes the same payload to every connection currently in the list and completes when all writes do.
    private static Task WriteToAllAsync(ConnectionList connections, byte[] payload)
    {
        var tasks = new List<Task>(connections.Count);
        foreach (var connection in connections)
        {
            tasks.Add(connection.Channel.Output.WriteAsync(payload));
        }
        return Task.WhenAll(tasks);
    }

    /// <summary>
    /// Tracks the connection and subscribes it to its connection channel and, when the user is
    /// authenticated, to its user channel.
    /// </summary>
    public override Task OnConnectedAsync(Connection connection)
    {
        var redisSubscriptions = connection.Metadata.GetOrAdd(RedisSubscriptionsMetadataName, _ => new SubscriptionMap());
        _connections.Add(connection);

        // One handler (and one write chain) serves both per-connection channels, so messages arriving
        // on the connection channel and the user channel are written to this connection in order.
        var previousWriteTask = TaskCache.CompletedTask;
        Action<RedisChannel, RedisValue> handler = (c, data) =>
        {
            previousWriteTask = WriteAfterAsync(previousWriteTask, () => connection.Channel.Output.WriteAsync((byte[])data));
        };

        var connectionChannel = ConnectionChannel(connection.ConnectionId);
        redisSubscriptions[connectionChannel] = handler;
        var connectionTask = _bus.SubscribeAsync(connectionChannel, handler);

        var userTask = TaskCache.CompletedTask;
        if (connection.User?.Identity?.IsAuthenticated == true)
        {
            // TODO: Look at optimizing (looping over connections checking for Name)
            var userChannel = UserChannel(connection.User.Identity.Name);
            redisSubscriptions[userChannel] = handler;
            userTask = _bus.SubscribeAsync(userChannel, handler);
        }

        return Task.WhenAll(connectionTask, userTask);
    }

    /// <summary>
    /// Stops tracking the connection, removes its own Redis handlers and removes it from every group it joined.
    /// </summary>
    public override Task OnDisconnectedAsync(Connection connection)
    {
        _connections.Remove(connection);
        var tasks = new List<Task>();

        var redisSubscriptions = connection.Metadata.Get<SubscriptionMap>(RedisSubscriptionsMetadataName);
        if (redisSubscriptions != null)
        {
            foreach (var subscription in redisSubscriptions)
            {
                // Unsubscribe only this connection's handler: other connections of the same user
                // have their own handlers registered on the shared user channel.
                tasks.Add(_bus.UnsubscribeAsync(subscription.Key, subscription.Value));
            }
        }

        var groupNames = connection.Metadata.Get<HashSet<string>>(GroupsMetadataName);
        if (groupNames != null)
        {
            // Snapshot under the same lock used by Add/RemoveGroupAsync, because RemoveGroupAsync
            // mutates this set while we iterate.
            string[] groupsToLeave;
            lock (groupNames)
            {
                groupsToLeave = groupNames.ToArray();
            }
            foreach (var group in groupsToLeave)
            {
                tasks.Add(RemoveGroupAsync(connection, group));
            }
        }

        return Task.WhenAll(tasks);
    }

    /// <summary>
    /// Adds the connection to a group, subscribing this server to the group channel when the
    /// connection is the group's first local member. Adding an existing member is a no-op.
    /// </summary>
    public override async Task AddGroupAsync(Connection connection, string groupName)
    {
        var groupChannel = GroupChannel(groupName);
        var groupNames = connection.Metadata.GetOrAdd(GroupsMetadataName, _ => new HashSet<string>());
        bool isNewMembership;
        lock (groupNames)
        {
            isNewMembership = groupNames.Add(groupName);
        }
        if (!isNewMembership)
        {
            // Already a member: re-adding must not register a duplicate member or group handler.
            return;
        }

        var group = _groups.GetOrAdd(groupChannel, _ => new GroupData());
        await group.Lock.WaitAsync();
        try
        {
            group.Connections.Add(connection);

            // Subscribe once
            if (group.Connections.Count > 1)
            {
                return;
            }

            var previousTask = TaskCache.CompletedTask;
            await _bus.SubscribeAsync(groupChannel, (c, data) =>
            {
                // Chain each message behind the previous one (recorded before yielding) because we
                // don't want concurrent or reordered writes to the outgoing connections.
                previousTask = WriteAfterAsync(previousTask, () => WriteToAllAsync(group.Connections, (byte[])data));
            });
        }
        finally
        {
            group.Lock.Release();
        }
    }

    /// <summary>
    /// Removes the connection from a group, unsubscribing from the group channel when no local members remain.
    /// </summary>
    public override async Task RemoveGroupAsync(Connection connection, string groupName)
    {
        var groupChannel = GroupChannel(groupName);

        GroupData group;
        if (!_groups.TryGetValue(groupChannel, out group))
        {
            return;
        }

        var groupNames = connection.Metadata.Get<HashSet<string>>(GroupsMetadataName);
        if (groupNames != null)
        {
            lock (groupNames)
            {
                groupNames.Remove(groupName);
            }
        }

        await group.Lock.WaitAsync();
        try
        {
            group.Connections.Remove(connection);

            if (group.Connections.Count == 0)
            {
                await _bus.UnsubscribeAsync(groupChannel);
            }
        }
        finally
        {
            group.Lock.Release();
        }
    }

    /// <summary>Removes all Redis subscriptions and closes the Redis connection.</summary>
    public void Dispose()
    {
        _bus.UnsubscribeAll();
        _redisServerConnection.Dispose();
    }

    /// <summary>
    /// Forwards whole lines written to it to an <see cref="ILogger"/> at Debug level. Single-character
    /// writes are discarded, so only <c>WriteLine(string)</c> output is logged.
    /// </summary>
    private class LoggerTextWriter : TextWriter
    {
        private readonly ILogger _logger;

        public LoggerTextWriter(ILogger logger)
        {
            _logger = logger;
        }

        public override Encoding Encoding => Encoding.UTF8;

        public override void Write(char value)
        {
        }

        public override void WriteLine(string value)
        {
            _logger.LogDebug(value);
        }
    }

    // Per-connection map from subscribed channel name to the handler registered for it,
    // so that exactly that handler can be unsubscribed on disconnect.
    private class SubscriptionMap : Dictionary<string, Action<RedisChannel, RedisValue>>
    {
    }

    // Local membership of one group plus the lock serializing its subscribe/unsubscribe decisions.
    private class GroupData
    {
        public readonly SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
        public readonly ConnectionList Connections = new ConnectionList();
    }
}
}