diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index cbdae52a48..2433af1c8b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Client private readonly IHubProtocol _protocol; private readonly IServiceProvider _serviceProvider; private readonly IConnectionFactory _connectionFactory; - private readonly ConcurrentDictionary> _handlers = new ConcurrentDictionary>(StringComparer.Ordinal); + private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary(StringComparer.Ordinal); private bool _disposed; // Transient state to a connection @@ -88,7 +88,7 @@ namespace Microsoft.AspNetCore.SignalR.Client // It's OK to be disposed while registering a callback, we'll just never call the callback anyway (as with all the callbacks registered before disposal). var invocationHandler = new InvocationHandler(parameterTypes, handler, state); - var invocationList = _handlers.AddOrUpdate(methodName, _ => new List { invocationHandler }, + var invocationList = _handlers.AddOrUpdate(methodName, _ => new InvocationHandlerList(invocationHandler) , (_, invocations) => { lock (invocations) @@ -438,21 +438,14 @@ namespace Microsoft.AspNetCore.SignalR.Client await AwaitableThreadPool.Yield(); // Find the handler - if (!_handlers.TryGetValue(invocation.Target, out var handlers)) + if (!_handlers.TryGetValue(invocation.Target, out var invocationHandlerList)) { Log.MissingHandler(_logger, invocation.Target); return; } - // TODO: Optimize this! - // Copying the callbacks to avoid concurrency issues - InvocationHandler[] copiedHandlers; - lock (handlers) - { - copiedHandlers = new InvocationHandler[handlers.Count]; - handlers.CopyTo(copiedHandlers); - } - + // Grabbing the current handlers + var copiedHandlers = invocationHandlerList.GetHandlers(); foreach (var handler in copiedHandlers) { try @@ -793,9 +786,9 @@ namespace Microsoft.AspNetCore.SignalR.Client private class Subscription : IDisposable { private readonly InvocationHandler _handler; - private readonly List _handlerList; + private readonly InvocationHandlerList _handlerList; - public Subscription(InvocationHandler handler, List handlerList) + public Subscription(InvocationHandler handler, InvocationHandlerList handlerList) { _handler = handler; _handlerList = handlerList; @@ -803,9 +796,57 @@ namespace Microsoft.AspNetCore.SignalR.Client public void Dispose() { - lock (_handlerList) + _handlerList.Remove(_handler); + } + } + + private class InvocationHandlerList + { + private readonly List _invocationHandlers; + // A lazy cached copy of the handlers that doesn't change for thread safety. + // Adding or removing a handler sets this to null. + private InvocationHandler[] _copiedHandlers; + + internal InvocationHandlerList(InvocationHandler handler) + { + _invocationHandlers = new List() { handler }; + } + + internal InvocationHandler[] GetHandlers() + { + var handlers = _copiedHandlers; + if (handlers == null) { - _handlerList.Remove(_handler); + lock (_invocationHandlers) + { + // Check if the handlers are set, if not we'll copy them over. + if (_copiedHandlers == null) + { + _copiedHandlers = _invocationHandlers.ToArray(); + } + handlers = _copiedHandlers; + } + } + return handlers; + } + + internal void Add(InvocationHandler handler) + { + lock (_invocationHandlers) + { + _invocationHandlers.Add(handler); + _copiedHandlers = null; + } + } + + internal void Remove(InvocationHandler handler) + { + lock (_invocationHandlers) + { + if (_invocationHandlers.Remove(handler)) + { + _copiedHandlers = null; + } } } } @@ -964,21 +1005,19 @@ namespace Microsoft.AspNetCore.SignalR.Client IReadOnlyList IInvocationBinder.GetParameterTypes(string methodName) { - if (!_hubConnection._handlers.TryGetValue(methodName, out var handlers)) + if (!_hubConnection._handlers.TryGetValue(methodName, out var invocationHandlerList)) { Log.MissingHandler(_hubConnection._logger, methodName); return Type.EmptyTypes; } // We use the parameter types of the first handler - lock (handlers) + var handlers = invocationHandlerList.GetHandlers(); + if (handlers.Length > 0) { - if (handlers.Count > 0) - { - return handlers[0].ParameterTypes; - } - throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'"); + return handlers[0].ParameterTypes; } + throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'"); } } }