Groups collection without lock

This commit is contained in:
gurgen 2017-10-07 21:38:44 +04:00 committed by Pawel Kadluczka
parent 0aea1e851b
commit af286c81bb
2 changed files with 122 additions and 38 deletions

View File

@ -14,6 +14,7 @@ namespace Microsoft.AspNetCore.SignalR
{
private long _nextInvocationId = 0;
private readonly HubConnectionList _connections = new HubConnectionList();
private readonly HubGroupList _groups = new HubGroupList();
public override Task AddGroupAsync(string connectionId, string groupName)
{
@ -33,13 +34,7 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
var feature = connection.Features.Get<IHubGroupsFeature>();
var groups = feature.Groups;
lock (groups)
{
groups.Add(groupName);
}
_groups.Add(connection, groupName);
return Task.CompletedTask;
}
@ -62,13 +57,7 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
var feature = connection.Features.Get<IHubGroupsFeature>();
var groups = feature.Groups;
lock (groups)
{
groups.Remove(groupName);
}
_groups.Remove(connectionId, groupName);
return Task.CompletedTask;
}
@ -81,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR
private Task InvokeAllWhere(string methodName, object[] args, Func<HubConnectionContext, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
var message = CreateInvocationMessage(methodName, args);
// TODO: serialize once per format by providing a different stream?
foreach (var connection in _connections)
@ -111,7 +100,7 @@ namespace Microsoft.AspNetCore.SignalR
return Task.CompletedTask;
}
var message = new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
var message = CreateInvocationMessage(methodName, args);
return WriteAsync(connection, message);
}
@ -123,29 +112,30 @@ namespace Microsoft.AspNetCore.SignalR
throw new ArgumentNullException(nameof(groupName));
}
return InvokeAllWhere(methodName, args, connection =>
var group = _groups[groupName];
if (group != null)
{
var feature = connection.Features.Get<IHubGroupsFeature>();
var groups = feature.Groups;
var message = CreateInvocationMessage(methodName, args);
var tasks = group.Values.Select(c => WriteAsync(c, message));
return Task.WhenAll(tasks);
}
// PERF: ...
lock (groups)
{
return groups.Contains(groupName) == true;
}
});
return Task.CompletedTask;
}
private InvocationMessage CreateInvocationMessage(string methodName, object[] args)
{
return new InvocationMessage(GetInvocationId(), nonBlocking: true, target: methodName, arguments: args);
}
public override Task InvokeUserAsync(string userId, string methodName, object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
return InvokeAllWhere(methodName, args, connection =>
string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal));
}
public override Task OnConnectedAsync(HubConnectionContext connection)
{
// Set the hub groups feature
connection.Features.Set<IHubGroupsFeature>(new HubGroupsFeature());
_connections.Add(connection);
return Task.CompletedTask;
}
@ -153,6 +143,7 @@ namespace Microsoft.AspNetCore.SignalR
public override Task OnDisconnectedAsync(HubConnectionContext connection)
{
_connections.Remove(connection);
_groups.RemoveDisconnectedConnection(connection.ConnectionId);
return Task.CompletedTask;
}
@ -180,15 +171,5 @@ namespace Microsoft.AspNetCore.SignalR
return !excludedIds.Contains(connection.ConnectionId);
});
}
private interface IHubGroupsFeature
{
HashSet<string> Groups { get; }
}
private class HubGroupsFeature : IHubGroupsFeature
{
public HashSet<string> Groups { get; } = new HashSet<string>(StringComparer.OrdinalIgnoreCase);
}
}
}

View File

@ -0,0 +1,103 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.AspNetCore.SignalR
{
public class HubGroupList : IReadOnlyCollection<ConcurrentDictionary<string, HubConnectionContext>>
{
private readonly ConcurrentDictionary<string, GroupConnectionList> _groups =
new ConcurrentDictionary<string, GroupConnectionList>();
private static readonly GroupConnectionList EmptyGroupConnectionList = new GroupConnectionList();
public ConcurrentDictionary<string, HubConnectionContext> this[string groupName]
{
get
{
_groups.TryGetValue(groupName, out var group);
return group;
}
}
public void Add(HubConnectionContext connection, string groupName)
{
CreateOrUpdateGroupWithConnection(groupName, connection);
}
public void Remove(string connectionId, string groupName)
{
if (_groups.TryGetValue(groupName, out var connections))
{
if (connections.TryRemove(connectionId, out var _) && connections.IsEmpty)
{
// If group is empty after connection remove, don't need empty group in dictionary.
// Why this way? Because ICollection.Remove implementation of dictionary checks for key and value. When we remove empty group,
// it checks if no connection added from another thread.
var groupToRemove = new KeyValuePair<string, GroupConnectionList>(groupName, EmptyGroupConnectionList);
((ICollection<KeyValuePair<string, GroupConnectionList>>)(_groups)).Remove(groupToRemove);
}
}
}
public void RemoveDisconnectedConnection(string connectionId)
{
var groupNames = _groups.Where(x => x.Value.Keys.Contains(connectionId)).Select(x => x.Key);
foreach (var groupName in groupNames)
{
Remove(connectionId, groupName);
}
}
public int Count => _groups.Count;
public IEnumerator<ConcurrentDictionary<string, HubConnectionContext>> GetEnumerator()
{
return _groups.Values.GetEnumerator();
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
private void CreateOrUpdateGroupWithConnection(string groupName, HubConnectionContext connection)
{
_groups.AddOrUpdate(groupName, _ => AddConnectionToGroup(connection, new GroupConnectionList()),
(key, oldCollection) =>
{
AddConnectionToGroup(connection, oldCollection);
return oldCollection;
});
}
private static GroupConnectionList AddConnectionToGroup(
HubConnectionContext connection, GroupConnectionList group)
{
group.AddOrUpdate(connection.ConnectionId, connection, (_, __) => connection);
return group;
}
}
internal class GroupConnectionList : ConcurrentDictionary<string, HubConnectionContext>
{
public override bool Equals(object obj)
{
if (obj is ConcurrentDictionary<string, HubConnectionContext> list)
{
return list.Count == Count;
}
return false;
}
public override int GetHashCode()
{
return base.GetHashCode();
}
}
}