123 lines
3.9 KiB
C#
123 lines
3.9 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
|
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Threading.Tasks;
|
|
using Channels;
|
|
using Microsoft.AspNetCore.Sockets;
|
|
|
|
namespace Microsoft.AspNetCore.SignalR
|
|
{
|
|
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
|
|
{
|
|
private readonly ConnectionList _connections = new ConnectionList();
|
|
private readonly InvocationAdapterRegistry _registry;
|
|
|
|
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
|
|
{
|
|
_registry = registry;
|
|
}
|
|
|
|
public override Task AddGroupAsync(Connection connection, string groupName)
|
|
{
|
|
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
|
|
|
|
lock (groups)
|
|
{
|
|
groups.Add(groupName);
|
|
}
|
|
|
|
return Task.CompletedTask;
|
|
}
|
|
|
|
public override Task RemoveGroupAsync(Connection connection, string groupName)
|
|
{
|
|
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
|
|
|
lock (groups)
|
|
{
|
|
groups.Remove(groupName);
|
|
}
|
|
|
|
return Task.CompletedTask;
|
|
}
|
|
|
|
public override Task InvokeAllAsync(string methodName, object[] args)
|
|
{
|
|
return InvokeAllWhere(methodName, args, c => true);
|
|
}
|
|
|
|
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
|
|
{
|
|
var tasks = new List<Task>(_connections.Count);
|
|
var message = new InvocationDescriptor
|
|
{
|
|
Method = methodName,
|
|
Arguments = args
|
|
};
|
|
|
|
// TODO: serialize once per format by providing a different stream?
|
|
foreach (var connection in _connections)
|
|
{
|
|
if (!include(connection))
|
|
{
|
|
continue;
|
|
}
|
|
|
|
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
|
|
|
|
tasks.Add(invocationAdapter.WriteInvocationDescriptorAsync(message, connection.Channel.GetStream()));
|
|
}
|
|
|
|
return Task.WhenAll(tasks);
|
|
}
|
|
|
|
public override Task InvokeConnectionAsync(string connectionId, string methodName, object[] args)
|
|
{
|
|
var connection = _connections[connectionId];
|
|
|
|
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
|
|
|
|
var message = new InvocationDescriptor
|
|
{
|
|
Method = methodName,
|
|
Arguments = args
|
|
};
|
|
|
|
return invocationAdapter.WriteInvocationDescriptorAsync(message, connection.Channel.GetStream());
|
|
}
|
|
|
|
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
|
|
{
|
|
return InvokeAllWhere(methodName, args, connection =>
|
|
{
|
|
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
|
return groups?.Contains(groupName) == true;
|
|
});
|
|
}
|
|
|
|
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
|
|
{
|
|
return InvokeAllWhere(methodName, args, connection =>
|
|
{
|
|
return connection.User.Identity.Name == userId;
|
|
});
|
|
}
|
|
|
|
public override Task OnConnectedAsync(Connection connection)
|
|
{
|
|
_connections.Add(connection);
|
|
return Task.CompletedTask;
|
|
}
|
|
|
|
public override Task OnDisconnectedAsync(Connection connection)
|
|
{
|
|
_connections.Remove(connection);
|
|
return Task.CompletedTask;
|
|
}
|
|
}
|
|
|
|
}
|