aspnetcore/src/Microsoft.AspNetCore.Signal.../DefaultHubLifetimeManager.cs

290 lines
9.7 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub> where THub : Hub
{
// Connections tracked by this manager: entries are added in OnConnectedAsync and removed in OnDisconnectedAsync.
private readonly HubConnectionStore _connections = new HubConnectionStore();
// Group membership: indexed by group name when sending, and updated by AddToGroupAsync / RemoveFromGroupAsync.
private readonly HubGroupList _groups = new HubGroupList();
// Logger received through the constructor.
private readonly ILogger _logger;
/// <summary>
/// Initializes a new <see cref="DefaultHubLifetimeManager{THub}"/> that uses the given logger.
/// </summary>
/// <param name="logger">The logger stored by this instance.</param>
public DefaultHubLifetimeManager(ILogger<DefaultHubLifetimeManager<THub>> logger) => _logger = logger;
/// <summary>
/// Adds the connection identified by <paramref name="connectionId"/> to the group <paramref name="groupName"/>.
/// </summary>
/// <param name="connectionId">Id used to look the connection up in the connection store.</param>
/// <param name="groupName">Name of the group to join.</param>
/// <returns>An already-completed task; the work is done synchronously.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="connectionId"/> or <paramref name="groupName"/> is <c>null</c>.
/// </exception>
public override Task AddToGroupAsync(string connectionId, string groupName)
{
    if (connectionId is null)
    {
        throw new ArgumentNullException(nameof(connectionId));
    }

    if (groupName is null)
    {
        throw new ArgumentNullException(nameof(groupName));
    }

    // The store yields null for an id it does not hold; in that case there is nothing to add.
    var target = _connections[connectionId];
    if (target != null)
    {
        _groups.Add(target, groupName);
    }

    return Task.CompletedTask;
}
/// <summary>
/// Removes the connection identified by <paramref name="connectionId"/> from the group <paramref name="groupName"/>.
/// </summary>
/// <param name="connectionId">Id used to look the connection up in the connection store.</param>
/// <param name="groupName">Name of the group to leave.</param>
/// <returns>An already-completed task; the work is done synchronously.</returns>
/// <exception cref="ArgumentNullException">
/// <paramref name="connectionId"/> or <paramref name="groupName"/> is <c>null</c>.
/// </exception>
public override Task RemoveFromGroupAsync(string connectionId, string groupName)
{
    if (connectionId is null)
    {
        throw new ArgumentNullException(nameof(connectionId));
    }

    if (groupName is null)
    {
        throw new ArgumentNullException(nameof(groupName));
    }

    // Only touch the group list when the connection store still holds this id.
    var target = _connections[connectionId];
    if (target != null)
    {
        _groups.Remove(connectionId, groupName);
    }

    return Task.CompletedTask;
}
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every connection in the store.
/// </summary>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>The task produced by <c>SendToAllConnections</c> with no filter.</returns>
public override Task SendAllAsync(string methodName, object[] args)
    => SendToAllConnections(methodName, args, null);
/// <summary>
/// Writes one invocation message to every stored connection accepted by <paramref name="include"/>.
/// </summary>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <param name="include">Optional filter; <c>null</c> means every connection is a recipient.</param>
/// <returns>
/// A completed task when every write finished synchronously and successfully; otherwise a task
/// that completes when all outstanding writes have completed.
/// </returns>
private Task SendToAllConnections(string methodName, object[] args, Func<HubConnectionContext, bool> include)
{
    List<Task> pendingWrites = null;
    SerializedHubMessage serialized = null;

    // foreach over HubConnectionStore avoids allocating an enumerator
    foreach (var candidate in _connections)
    {
        if (include == null || include(candidate))
        {
            // Build the message only once a first recipient exists, then share it between recipients.
            serialized = serialized ?? CreateSerializedInvocationMessage(methodName, args);

            var write = candidate.WriteAsync(serialized);
            if (!write.IsCompletedSuccessfully)
            {
                // Only writes that did not finish synchronously need to be tracked.
                pendingWrites = pendingWrites ?? new List<Task>();
                pendingWrites.Add(write.AsTask());
            }
        }
    }

    // Some connections are slow
    return pendingWrites == null ? Task.CompletedTask : Task.WhenAll(pendingWrites);
}
/// <summary>
/// Writes one invocation message to every connection in <paramref name="connections"/> accepted by
/// <paramref name="include"/>.
/// </summary>
/// <remarks>
/// <paramref name="tasks"/> and <paramref name="message"/> are ref parameters so that they can be created
/// lazily here, after filtering, and still be shared by the caller across several groups.
/// </remarks>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <param name="connections">The group's connections.</param>
/// <param name="include">Optional filter; <c>null</c> means every connection is a recipient.</param>
/// <param name="tasks">Receives the writes that did not complete synchronously; created on first use.</param>
/// <param name="message">The message to write; created on first use.</param>
private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary<string, HubConnectionContext> connections, Func<HubConnectionContext, bool> include, ref List<Task> tasks, ref SerializedHubMessage message)
{
    // foreach over ConcurrentDictionary avoids allocating an enumerator
    foreach (var entry in connections)
    {
        var member = entry.Value;
        if (include == null || include(member))
        {
            message = message ?? CreateSerializedInvocationMessage(methodName, args);

            var write = member.WriteAsync(message);
            if (!write.IsCompletedSuccessfully)
            {
                tasks = tasks ?? new List<Task>();
                tasks.Add(write.AsTask());
            }
        }
    }
}
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to the single connection identified by
/// <paramref name="connectionId"/>.
/// </summary>
/// <param name="connectionId">Id used to look the connection up in the connection store.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>
/// A completed task when the store has no such connection; otherwise the task for the write.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="connectionId"/> is <c>null</c>.</exception>
public override Task SendConnectionAsync(string connectionId, string methodName, object[] args)
{
    if (connectionId is null)
    {
        throw new ArgumentNullException(nameof(connectionId));
    }

    var target = _connections[connectionId];

    // With a single recipient the message is written directly to the connection,
    // without wrapping it in a SerializedHubMessage that would cache it in memory.
    return target == null
        ? Task.CompletedTask
        : target.WriteAsync(CreateInvocationMessage(methodName, args)).AsTask();
}
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every connection in the group
/// <paramref name="groupName"/>.
/// </summary>
/// <param name="groupName">Name of the target group.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>
/// A completed task when the group is unknown or every write finished synchronously; otherwise a task
/// that completes when all outstanding writes have completed.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="groupName"/> is <c>null</c>.</exception>
public override Task SendGroupAsync(string groupName, string methodName, object[] args)
{
    if (groupName is null)
    {
        throw new ArgumentNullException(nameof(groupName));
    }

    var members = _groups[groupName];
    if (members == null)
    {
        return Task.CompletedTask;
    }

    // Can't optimize for sending to a single connection in a group because
    // the group might be modified in between checking and sending
    List<Task> pendingWrites = null;
    SerializedHubMessage serialized = null;
    SendToGroupConnections(methodName, args, members, null, ref pendingWrites, ref serialized);

    return pendingWrites == null ? Task.CompletedTask : Task.WhenAll(pendingWrites);
}
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every connection in each of the groups
/// listed in <paramref name="groupNames"/>.
/// </summary>
/// <param name="groupNames">Names of the target groups; none may be null or empty.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>
/// A completed task when no write is left outstanding; otherwise a task that completes when all
/// outstanding writes have completed.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="groupNames"/> is <c>null</c>.</exception>
/// <exception cref="InvalidOperationException">A group name is null or empty.</exception>
public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
{
    if (groupNames == null)
    {
        throw new ArgumentNullException(nameof(groupNames));
    }

    // Validate every name before the first write. Throwing part-way through the send loop would
    // leave the message delivered to only some groups and abandon the writes already started.
    for (var i = 0; i < groupNames.Count; i++)
    {
        if (string.IsNullOrEmpty(groupNames[i]))
        {
            throw new InvalidOperationException("Cannot send to an empty group name.");
        }
    }

    // Each task represents one write that did not complete synchronously, across all groups
    List<Task> tasks = null;
    SerializedHubMessage message = null;

    foreach (var groupName in groupNames)
    {
        var group = _groups[groupName];
        if (group != null)
        {
            // tasks and message are shared so the message is built once for all groups.
            SendToGroupConnections(methodName, args, group, null, ref tasks, ref message);
        }
    }

    if (tasks != null)
    {
        return Task.WhenAll(tasks);
    }

    return Task.CompletedTask;
}
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to the connections in the group
/// <paramref name="groupName"/> whose connection id is not listed in <paramref name="excludedIds"/>.
/// </summary>
/// <param name="groupName">Name of the target group.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <param name="excludedIds">Connection ids that must not receive the message.</param>
/// <returns>
/// A completed task when the group is unknown or every write finished synchronously; otherwise a task
/// that completes when all outstanding writes have completed.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="groupName"/> is <c>null</c>.</exception>
public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedIds)
{
    if (groupName is null)
    {
        throw new ArgumentNullException(nameof(groupName));
    }

    var members = _groups[groupName];
    if (members == null)
    {
        return Task.CompletedTask;
    }

    List<Task> pendingWrites = null;
    SerializedHubMessage serialized = null;
    SendToGroupConnections(
        methodName,
        args,
        members,
        candidate => !excludedIds.Contains(candidate.ConnectionId),
        ref pendingWrites,
        ref serialized);

    return pendingWrites == null ? Task.CompletedTask : Task.WhenAll(pendingWrites);
}
/// <summary>
/// Builds the invocation message for <paramref name="methodName"/> and wraps it in a
/// <see cref="SerializedHubMessage"/>.
/// </summary>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>The wrapped invocation message.</returns>
private SerializedHubMessage CreateSerializedInvocationMessage(string methodName, object[] args)
{
    var invocation = CreateInvocationMessage(methodName, args);
    return new SerializedHubMessage(invocation);
}
/// <summary>
/// Creates an <see cref="InvocationMessage"/> for <paramref name="methodName"/> carrying
/// <paramref name="args"/>; the constructor's second argument is passed as <c>null</c>.
/// </summary>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>The new invocation message.</returns>
private HubMessage CreateInvocationMessage(string methodName, object[] args)
    => new InvocationMessage(methodName, null, args);
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every stored connection whose
/// <c>UserIdentifier</c> equals <paramref name="userId"/> under ordinal comparison.
/// </summary>
/// <remarks>
/// Ordinal <see cref="string.Equals(string, string, StringComparison)"/> treats two nulls as equal,
/// so a null <paramref name="userId"/> matches connections whose <c>UserIdentifier</c> is null.
/// </remarks>
/// <param name="userId">The user identifier to match.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>The task produced by <c>SendToAllConnections</c> for the filtered send.</returns>
public override Task SendUserAsync(string userId, string methodName, object[] args)
    => SendToAllConnections(
        methodName,
        args,
        candidate => string.Equals(candidate.UserIdentifier, userId, StringComparison.Ordinal));
/// <summary>
/// Registers <paramref name="connection"/> in the connection store so the send methods can reach it.
/// </summary>
/// <param name="connection">The connection to add.</param>
/// <returns>An already-completed task; the work is done synchronously.</returns>
public override Task OnConnectedAsync(HubConnectionContext connection)
{
_connections.Add(connection);
return Task.CompletedTask;
}
/// <summary>
/// Removes <paramref name="connection"/> from the connection store, then removes its connection id
/// from the group list.
/// </summary>
/// <param name="connection">The connection that has disconnected.</param>
/// <returns>An already-completed task; the work is done synchronously.</returns>
public override Task OnDisconnectedAsync(HubConnectionContext connection)
{
    // Drop the connection from the store first, then clear its group memberships.
    _connections.Remove(connection);

    var disconnectedId = connection.ConnectionId;
    _groups.RemoveDisconnectedConnection(disconnectedId);

    return Task.CompletedTask;
}
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every stored connection whose connection id
/// is not listed in <paramref name="excludedIds"/>.
/// </summary>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <param name="excludedIds">Connection ids that must not receive the message.</param>
/// <returns>The task produced by <c>SendToAllConnections</c> for the filtered send.</returns>
public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
    => SendToAllConnections(
        methodName,
        args,
        candidate => !excludedIds.Contains(candidate.ConnectionId));
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every stored connection whose connection id
/// is listed in <paramref name="connectionIds"/>.
/// </summary>
/// <param name="connectionIds">Connection ids that should receive the message.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>The task produced by <c>SendToAllConnections</c> for the filtered send.</returns>
public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
    => SendToAllConnections(
        methodName,
        args,
        candidate => connectionIds.Contains(candidate.ConnectionId));
/// <summary>
/// Sends an invocation of <paramref name="methodName"/> to every stored connection whose
/// <c>UserIdentifier</c> is listed in <paramref name="userIds"/>.
/// </summary>
/// <param name="userIds">User identifiers that should receive the message.</param>
/// <param name="methodName">Name of the method to invoke.</param>
/// <param name="args">Arguments passed with the invocation.</param>
/// <returns>The task produced by <c>SendToAllConnections</c> for the filtered send.</returns>
public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
    => SendToAllConnections(
        methodName,
        args,
        candidate => userIds.Contains(candidate.UserIdentifier));
}
}