// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Redis;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace ChatSample
{
    // NOTE(review): the generic type arguments in this file were reconstructed from how the
    // types are used (THub / THubLifetimeManager constraints, GetRequiredService calls).
    // In particular IUserTracker<THub>, IHubActivator<THub>, DefaultHubLifetimeManager<THub>
    // and RedisHubLifetimeManager<THub> are assumed shapes -- confirm against their declarations.

    /// <summary>
    /// Presence-aware lifetime manager that wraps the default hub lifetime manager.
    /// </summary>
    /// <typeparam name="THub">The hub type; must derive from <see cref="HubWithPresence"/>.</typeparam>
    public class DefaultPresenceHublifetimeManager<THub> : PresenceHubLifetimeManager<THub, DefaultHubLifetimeManager<THub>>
        where THub : HubWithPresence
    {
        /// <summary>
        /// Forwards all dependencies to <see cref="PresenceHubLifetimeManager{THub, THubLifetimeManager}"/>.
        /// </summary>
        public DefaultPresenceHublifetimeManager(
            IUserTracker<THub> userTracker,
            IServiceScopeFactory serviceScopeFactory,
            ILoggerFactory loggerFactory,
            IServiceProvider serviceProvider)
            : base(userTracker, serviceScopeFactory, loggerFactory, serviceProvider)
        {
        }
    }

    /// <summary>
    /// Presence-aware lifetime manager that wraps the Redis hub lifetime manager.
    /// </summary>
    /// <typeparam name="THub">The hub type; must derive from <see cref="HubWithPresence"/>.</typeparam>
    public class RedisPresenceHublifetimeManager<THub> : PresenceHubLifetimeManager<THub, RedisHubLifetimeManager<THub>>
        where THub : HubWithPresence
    {
        /// <summary>
        /// Forwards all dependencies to <see cref="PresenceHubLifetimeManager{THub, THubLifetimeManager}"/>.
        /// </summary>
        public RedisPresenceHublifetimeManager(
            IUserTracker<THub> userTracker,
            IServiceScopeFactory serviceScopeFactory,
            ILoggerFactory loggerFactory,
            IServiceProvider serviceProvider)
            : base(userTracker, serviceScopeFactory, loggerFactory, serviceProvider)
        {
        }
    }

    /// <summary>
    /// A <see cref="HubLifetimeManager{THub}"/> decorator. Every send/group operation is delegated
    /// to a wrapped <typeparamref name="THubLifetimeManager"/>; in addition, connections are
    /// registered with an <see cref="IUserTracker{THub}"/> and the tracker's join/leave events are
    /// relayed to a hub instance created for each locally stored connection.
    /// </summary>
    /// <typeparam name="THub">The hub type; must derive from <see cref="HubWithPresence"/>.</typeparam>
    /// <typeparam name="THubLifetimeManager">The lifetime manager being wrapped; resolved from the service provider.</typeparam>
    public class PresenceHubLifetimeManager<THub, THubLifetimeManager> : HubLifetimeManager<THub>, IDisposable
        where THubLifetimeManager : HubLifetimeManager<THub>
        where THub : HubWithPresence
    {
        private readonly HubConnectionStore _connections = new HubConnectionStore();
        private readonly IUserTracker<THub> _userTracker;
        private readonly IServiceScopeFactory _serviceScopeFactory;
        private readonly ILogger _logger;
        private readonly IServiceProvider _serviceProvider;
        private readonly HubLifetimeManager<THub> _wrappedHubLifetimeManager;

        // Resolved lazily in Notify; cannot be constructor-injected due to a circular dependency.
        private IHubContext<THub> _hubContext;

        /// <summary>
        /// Stores the dependencies, resolves the wrapped lifetime manager from
        /// <paramref name="serviceProvider"/> and subscribes to the tracker's join/leave events.
        /// </summary>
        /// <param name="userTracker">Source of the <c>UsersJoined</c> / <c>UsersLeft</c> events and sink for add/remove calls.</param>
        /// <param name="serviceScopeFactory">Used to create one DI scope per notified connection.</param>
        /// <param name="loggerFactory">Used to create this type's logger.</param>
        /// <param name="serviceProvider">Used to resolve the wrapped manager now and the hub context later.</param>
        public PresenceHubLifetimeManager(
            IUserTracker<THub> userTracker,
            IServiceScopeFactory serviceScopeFactory,
            ILoggerFactory loggerFactory,
            IServiceProvider serviceProvider)
        {
            _userTracker = userTracker;
            _serviceScopeFactory = serviceScopeFactory;
            _serviceProvider = serviceProvider;
            _logger = loggerFactory.CreateLogger<PresenceHubLifetimeManager<THub, THubLifetimeManager>>();
            _wrappedHubLifetimeManager = serviceProvider.GetRequiredService<THubLifetimeManager>();

            // Subscribe last: the handlers read the fields assigned above, and if resolving the
            // wrapped manager throws we must not leave a half-constructed instance subscribed.
            _userTracker.UsersJoined += OnUsersJoined;
            _userTracker.UsersLeft += OnUsersLeft;
        }

        /// <summary>
        /// Lets the wrapped manager handle the connection, stores it locally and registers the
        /// user with the tracker.
        /// </summary>
        /// <param name="connection">The connection that was established.</param>
        public override async Task OnConnectedAsync(HubConnectionContext connection)
        {
            await _wrappedHubLifetimeManager.OnConnectedAsync(connection);
            _connections.Add(connection);

            // User/Identity are accessed null-conditionally so a connection without them is
            // tracked with a null name instead of failing with a NullReferenceException.
            var userName = connection.User?.Identity?.Name;
            await _userTracker.AddUser(connection, new UserDetails(connection.ConnectionId, userName));
        }

        /// <summary>
        /// Lets the wrapped manager handle the disconnect, removes the connection from the local
        /// store and removes the user from the tracker.
        /// </summary>
        /// <param name="connection">The connection that was closed.</param>
        public override async Task OnDisconnectedAsync(HubConnectionContext connection)
        {
            await _wrappedHubLifetimeManager.OnDisconnectedAsync(connection);
            _connections.Remove(connection);
            await _userTracker.RemoveUser(connection);
        }

        // Event handler (hence async void): relays joined users to every local connection's hub,
        // excluding the entry that describes the receiving connection itself.
        private async void OnUsersJoined(UserDetails[] users)
        {
            await Notify(hub =>
            {
                var receiverId = hub.Context.ConnectionId;
                var others = users.Where(u => u.ConnectionId != receiverId).ToArray();

                // Nothing left after filtering: do not invoke the hub with an empty array.
                return others.Length == 0
                    ? Task.CompletedTask
                    : hub.OnUsersJoined(others);
            });
        }

        // Event handler (hence async void): relays departed users to every local connection's hub.
        private async void OnUsersLeft(UserDetails[] users)
        {
            await Notify(hub => hub.OnUsersLeft(users));
        }

        /// <summary>
        /// Runs <paramref name="invocation"/> once per stored connection against a hub instance
        /// created in its own DI scope. A failure for one connection is logged as a warning and
        /// does not stop the remaining connections from being notified.
        /// </summary>
        /// <param name="invocation">The hub callback to execute for each connection.</param>
        private async Task Notify(Func<THub, Task> invocation)
        {
            foreach (var connection in _connections)
            {
                // The whole per-connection pipeline is guarded: this method is awaited from
                // async void event handlers, where an escaping exception cannot be observed.
                try
                {
                    using (var scope = _serviceScopeFactory.CreateScope())
                    {
                        var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
                        var hub = hubActivator.Create();

                        try
                        {
                            if (_hubContext == null)
                            {
                                // Cannot be injected due to circular dependency
                                _hubContext = _serviceProvider.GetRequiredService<IHubContext<THub>>();
                            }

                            hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
                            hub.Context = new DefaultHubCallerContext(connection);
                            hub.Groups = _hubContext.Groups;

                            await invocation(hub);
                        }
                        finally
                        {
                            // Always hand the hub back, even when setup or the invocation failed.
                            hubActivator.Release(hub);
                        }
                    }
                }
                catch (Exception ex)
                {
                    _logger.LogWarning(ex, "Presence notification failed.");
                }
            }
        }

        /// <summary>
        /// Unsubscribes from the tracker's events so the tracker no longer references this instance.
        /// </summary>
        public void Dispose()
        {
            _userTracker.UsersJoined -= OnUsersJoined;
            _userTracker.UsersLeft -= OnUsersLeft;
        }

        // Everything below is a straight pass-through to the wrapped lifetime manager.

        /// <inheritdoc />
        public override Task SendAllAsync(string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendAllAsync(methodName, args);

        /// <inheritdoc />
        public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList<string> excludedIds)
            => _wrappedHubLifetimeManager.SendAllExceptAsync(methodName, args, excludedIds);

        /// <inheritdoc />
        public override Task SendConnectionAsync(string connectionId, string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendConnectionAsync(connectionId, methodName, args);

        /// <inheritdoc />
        public override Task SendConnectionsAsync(IReadOnlyList<string> connectionIds, string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendConnectionsAsync(connectionIds, methodName, args);

        /// <inheritdoc />
        public override Task SendGroupAsync(string groupName, string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendGroupAsync(groupName, methodName, args);

        /// <inheritdoc />
        public override Task SendGroupsAsync(IReadOnlyList<string> groupNames, string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendGroupsAsync(groupNames, methodName, args);

        /// <inheritdoc />
        public override Task SendUserAsync(string userId, string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendUserAsync(userId, methodName, args);

        /// <inheritdoc />
        public override Task AddGroupAsync(string connectionId, string groupName)
            => _wrappedHubLifetimeManager.AddGroupAsync(connectionId, groupName);

        /// <inheritdoc />
        public override Task RemoveGroupAsync(string connectionId, string groupName)
            => _wrappedHubLifetimeManager.RemoveGroupAsync(connectionId, groupName);

        /// <inheritdoc />
        public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList<string> excludedIds)
            => _wrappedHubLifetimeManager.SendGroupExceptAsync(groupName, methodName, args, excludedIds);

        /// <inheritdoc />
        public override Task SendUsersAsync(IReadOnlyList<string> userIds, string methodName, object[] args)
            => _wrappedHubLifetimeManager.SendUsersAsync(userIds, methodName, args);
    }
}