aspnetcore/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs

123 lines
3.9 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Channels;
using Microsoft.AspNetCore.Sockets;
namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private readonly ConnectionList _connections = new ConnectionList();
private readonly InvocationAdapterRegistry _registry;
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
{
_registry = registry;
}
public override Task AddGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
lock (groups)
{
groups.Add(groupName);
}
return Task.CompletedTask;
}
public override Task RemoveGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
lock (groups)
{
groups.Remove(groupName);
}
return Task.CompletedTask;
}
public override Task InvokeAllAsync(string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, c => true);
}
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
// TODO: serialize once per format by providing a different stream?
foreach (var connection in _connections)
{
if (!include(connection))
{
continue;
}
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
tasks.Add(invocationAdapter.WriteInvocationDescriptorAsync(message, connection.Channel.GetStream()));
}
return Task.WhenAll(tasks);
}
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
{
var connection = _connections[connectionId];
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return invocationAdapter.WriteInvocationDescriptorAsync(message, connection.Channel.GetStream());
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
return groups?.Contains(groupName) == true;
});
}
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
return connection.User.Identity.Name == userId;
});
}
public override Task OnConnectedAsync(Connection connection)
{
_connections.Add(connection);
return Task.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
return Task.CompletedTask;
}
}
}