From af286c81bbca00a532ad5297a8f35c9eb1310684 Mon Sep 17 00:00:00 2001 From: gurgen Date: Sat, 7 Oct 2017 21:38:44 +0400 Subject: [PATCH] Groups collection without lock --- .../DefaultHubLifetimeManager.cs | 57 ++++------ .../HubGroupList.cs | 103 ++++++++++++++++++ 2 files changed, 122 insertions(+), 38 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR.Core/HubGroupList.cs diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 43dddfa54f..fcad6dd4a6 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -14,6 +14,7 @@ namespace Microsoft.AspNetCore.SignalR { private long _nextInvocationId = 0; private readonly HubConnectionList _connections = new HubConnectionList(); + private readonly HubGroupList _groups = new HubGroupList(); public override Task AddGroupAsync(string connectionId, string groupName) { @@ -33,13 +34,7 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var feature = connection.Features.Get(); - var groups = feature.Groups; - - lock (groups) - { - groups.Add(groupName); - } + _groups.Add(connection, groupName); return Task.CompletedTask; } @@ -62,13 +57,7 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var feature = connection.Features.Get(); - var groups = feature.Groups; - - lock (groups) - { - groups.Remove(groupName); - } + _groups.Remove(connectionId, groupName); return Task.CompletedTask; } @@ -81,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR private Task InvokeAllWhere(string methodName, object[] args, Func include) { var tasks = new List(_connections.Count); - var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); + var message = CreateInvocationMessage(methodName, args); // TODO: serialize once per format by providing a different stream? foreach (var connection in _connections) @@ -111,7 +100,7 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } - var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); + var message = CreateInvocationMessage(methodName, args); return WriteAsync(connection, message); } @@ -123,29 +112,30 @@ namespace Microsoft.AspNetCore.SignalR throw new ArgumentNullException(nameof(groupName)); } - return InvokeAllWhere(methodName, args, connection => + var group = _groups[groupName]; + if (group != null) { - var feature = connection.Features.Get(); - var groups = feature.Groups; + var message = CreateInvocationMessage(methodName, args); + var tasks = group.Values.Select(c => WriteAsync(c, message)); + return Task.WhenAll(tasks); + } - // PERF: ... - lock (groups) - { - return groups.Contains(groupName) == true; - } - }); + return Task.CompletedTask; + } + + private InvocationMessage CreateInvocationMessage(string methodName, object[] args) + { + return new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args); } public override Task InvokeUserAsync(string userId, string methodName, object[] args) { - return InvokeAllWhere(methodName, args, connection => + return InvokeAllWhere(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); } public override Task OnConnectedAsync(HubConnectionContext connection) { - // Set the hub groups feature - connection.Features.Set(new HubGroupsFeature()); _connections.Add(connection); return Task.CompletedTask; } @@ -153,6 +143,7 @@ namespace Microsoft.AspNetCore.SignalR public override Task OnDisconnectedAsync(HubConnectionContext connection) { _connections.Remove(connection); + _groups.RemoveDisconnectedConnection(connection.ConnectionId); return Task.CompletedTask; } @@ -180,15 +171,5 @@ namespace Microsoft.AspNetCore.SignalR return !excludedIds.Contains(connection.ConnectionId); }); } - - private interface IHubGroupsFeature - { - HashSet Groups { get; } - } - - private class HubGroupsFeature : IHubGroupsFeature - { - public HashSet Groups { get; } = new HashSet(StringComparer.OrdinalIgnoreCase); - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubGroupList.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubGroupList.cs new file mode 100644 index 0000000000..b8b28ce051 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubGroupList.cs @@ -0,0 +1,103 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Linq; + +namespace Microsoft.AspNetCore.SignalR +{ + public class HubGroupList : IReadOnlyCollection> + { + private readonly ConcurrentDictionary _groups = + new ConcurrentDictionary(); + + private static readonly GroupConnectionList EmptyGroupConnectionList = new GroupConnectionList(); + + public ConcurrentDictionary this[string groupName] + { + get + { + _groups.TryGetValue(groupName, out var group); + return group; + } + } + + public void Add(HubConnectionContext connection, string groupName) + { + CreateOrUpdateGroupWithConnection(groupName, connection); + } + + public void Remove(string connectionId, string groupName) + { + if (_groups.TryGetValue(groupName, out var connections)) + { + if (connections.TryRemove(connectionId, out var _) && connections.IsEmpty) + { + // If group is empty after connection remove, don't need empty group in dictionary. + // Why this way? Because ICollection.Remove implementation of dictionary checks for key and value. When we remove empty group, + // it checks if no connection added from another thread. + var groupToRemove = new KeyValuePair(groupName, EmptyGroupConnectionList); + ((ICollection>)(_groups)).Remove(groupToRemove); + } + } + } + + public void RemoveDisconnectedConnection(string connectionId) + { + var groupNames = _groups.Where(x => x.Value.Keys.Contains(connectionId)).Select(x => x.Key); + foreach (var groupName in groupNames) + { + Remove(connectionId, groupName); + } + } + + public int Count => _groups.Count; + + public IEnumerator> GetEnumerator() + { + return _groups.Values.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + + private void CreateOrUpdateGroupWithConnection(string groupName, HubConnectionContext connection) + { + _groups.AddOrUpdate(groupName, _ => AddConnectionToGroup(connection, new GroupConnectionList()), + (key, oldCollection) => + { + AddConnectionToGroup(connection, oldCollection); + return oldCollection; + }); + } + + private static GroupConnectionList AddConnectionToGroup( + HubConnectionContext connection, GroupConnectionList group) + { + group.AddOrUpdate(connection.ConnectionId, connection, (_, __) => connection); + return group; + } + } + + internal class GroupConnectionList : ConcurrentDictionary + { + public override bool Equals(object obj) + { + if (obj is ConcurrentDictionary list) + { + return list.Count == Count; + } + + return false; + } + + public override int GetHashCode() + { + return base.GetHashCode(); + } + } +} \ No newline at end of file