diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubLifetimeManagerBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubLifetimeManagerBenchmark.cs new file mode 100644 index 0000000000..d6e5cc8860 --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubLifetimeManagerBenchmark.cs @@ -0,0 +1,125 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared; +using Microsoft.Extensions.Logging.Abstractions; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks +{ + public class DefaultHubLifetimeManagerBenchmark + { + private DefaultHubLifetimeManager _hubLifetimeManager; + private List _connectionIds; + private List _subsetConnectionIds; + private List _groupNames; + private List _userIdentifiers; + + [Params(true, false)] + public bool ForceAsync { get; set; } + + [GlobalSetup] + public void GlobalSetup() + { + _hubLifetimeManager = new DefaultHubLifetimeManager(NullLogger>.Instance); + _connectionIds = new List(); + _subsetConnectionIds = new List(); + _groupNames = new List(); + _userIdentifiers = new List(); + + var jsonHubProtocol = new JsonHubProtocol(); + + for (int i = 0; i < 100; i++) + { + string connectionId = "connection-" + i; + string groupName = "group-" + i % 10; + string userIdentifier = "user-" + i % 20; + AddUnique(_connectionIds, connectionId); + AddUnique(_groupNames, groupName); + AddUnique(_userIdentifiers, userIdentifier); + if (i % 3 == 0) + { + _subsetConnectionIds.Add(connectionId); + } + + var connectionContext = new TestConnectionContext + { + ConnectionId = connectionId, + Transport = new TestDuplexPipe(ForceAsync) + }; + var hubConnectionContext = new HubConnectionContext(connectionContext, TimeSpan.Zero, NullLoggerFactory.Instance); + hubConnectionContext.UserIdentifier = userIdentifier; + hubConnectionContext.Protocol = jsonHubProtocol; + + _hubLifetimeManager.OnConnectedAsync(hubConnectionContext).GetAwaiter().GetResult(); + _hubLifetimeManager.AddGroupAsync(connectionId, groupName); + } + } + + private void AddUnique(List list, string connectionId) + { + if (!list.Contains(connectionId)) + { + list.Add(connectionId); + } + } + + [Benchmark] + public Task SendAllAsync() + { + return _hubLifetimeManager.SendAllAsync("MethodName", Array.Empty()); + } + + [Benchmark] + public Task SendGroupAsync() + { + return _hubLifetimeManager.SendGroupAsync(_groupNames[0], "MethodName", Array.Empty()); + } + + [Benchmark] + public Task SendGroupsAsync() + { + return _hubLifetimeManager.SendGroupsAsync(_groupNames, "MethodName", Array.Empty()); + } + + [Benchmark] + public Task SendGroupExceptAsync() + { + return _hubLifetimeManager.SendGroupExceptAsync(_groupNames[0], "MethodName", Array.Empty(), _subsetConnectionIds); + } + + [Benchmark] + public Task SendAllExceptAsync() + { + return _hubLifetimeManager.SendAllExceptAsync("MethodName", Array.Empty(), _subsetConnectionIds); + } + + [Benchmark] + public Task SendConnectionAsync() + { + return _hubLifetimeManager.SendConnectionAsync(_connectionIds[0], "MethodName", Array.Empty()); + } + + [Benchmark] + public Task SendConnectionsAsync() + { + return _hubLifetimeManager.SendConnectionsAsync(_subsetConnectionIds, "MethodName", Array.Empty()); + } + + [Benchmark] + public Task SendUserAsync() + { + return _hubLifetimeManager.SendUserAsync(_userIdentifiers[0], "MethodName", Array.Empty()); + } + + [Benchmark] + public Task SendUsersAsync() + { + return _hubLifetimeManager.SendUsersAsync(_userIdentifiers, "MethodName", Array.Empty()); + } + } +} \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionContext.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionContext.cs new file mode 100644 index 0000000000..4225313743 --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionContext.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.IO.Pipelines; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.Http.Features; + +namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared +{ + public class TestConnectionContext : ConnectionContext + { + public override string ConnectionId { get; set; } + public override IFeatureCollection Features { get; } = new FeatureCollection(); + public override IDictionary Items { get; set; } + public override IDuplexPipe Transport { get; set; } + } +} \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionInherentKeepAliveFeature.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionInherentKeepAliveFeature.cs index 2b9fc07039..d91b3f8580 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionInherentKeepAliveFeature.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestConnectionInherentKeepAliveFeature.cs @@ -1,3 +1,6 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; using Microsoft.AspNetCore.Connections.Features; diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs index 91b8e96d3b..3e41887840 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestDuplexPipe.cs @@ -14,10 +14,13 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared public PipeWriter Output { get; } - public TestDuplexPipe() + public TestDuplexPipe(bool writerForceAsync = false) { _input = new TestPipeReader(); - Output = new TestPipeWriter(); + Output = new TestPipeWriter + { + ForceAsync = writerForceAsync + }; } public void AddReadResult(ValueTask readResult) diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs index 6bc98def01..0d3f89b749 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Shared/TestPipeWriter.cs @@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared // huge buffer that should be large enough for writing any content private readonly byte[] _buffer = new byte[10000]; + public bool ForceAsync { get; set; } + public override void Advance(int bytes) { } @@ -44,7 +46,17 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks.Shared public override ValueTask FlushAsync(CancellationToken cancellationToken = new CancellationToken()) { - return default; + if (!ForceAsync) + { + return default; + } + + return new ValueTask(ForceAsyncResult()); + } + + public async Task ForceAsyncResult() + { + return await Task.FromResult(default).ForceAsync(); } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs index de9c37256e..0971639f38 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs @@ -147,8 +147,8 @@ namespace Microsoft.AspNetCore.Http.Connections // Scan the registered connections looking for ones that have timed out foreach (var c in _connections) { - var status = HttpConnectionContext.ConnectionStatus.Inactive; - var lastSeenUtc = DateTimeOffset.UtcNow; + HttpConnectionContext.ConnectionStatus status; + DateTimeOffset lastSeenUtc; var connection = c.Value.Connection; try diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index ebfa3c312d..bbcd6b451b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; @@ -70,11 +71,27 @@ namespace Microsoft.AspNetCore.SignalR public override Task SendAllAsync(string methodName, object[] args) { - List tasks = null; - var message = CreateInvocationMessage(methodName, args); + return SendToAllConnections(methodName, args, null); + } + private Task SendToAllConnections(string methodName, object[] args, Func include) + { + List tasks = null; + SerializedHubMessage message = null; + + // foreach over HubConnectionStore avoids allocating an enumerator foreach (var connection in _connections) { + if (include != null && !include(connection)) + { + continue; + } + + if (message == null) + { + message = CreateSerializedInvocationMessage(methodName, args); + } + var task = connection.WriteAsync(message); if (!task.IsCompletedSuccessfully) @@ -88,7 +105,6 @@ namespace Microsoft.AspNetCore.SignalR } } - // No async if (tasks == null) { return Task.CompletedTask; @@ -98,19 +114,24 @@ namespace Microsoft.AspNetCore.SignalR return Task.WhenAll(tasks); } - private Task SendAllWhere(string methodName, object[] args, Func include) + // Tasks and message are passed by ref so they can be lazily created inside the method post-filtering, + // while still being re-usable when sending to multiple groups + private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary connections, Func include, ref List tasks, ref SerializedHubMessage message) { - List tasks = null; - var message = CreateInvocationMessage(methodName, args); - - foreach (var connection in _connections) + // foreach over ConcurrentDictionary avoids allocating an enumerator + foreach (var connection in connections) { - if (!include(connection)) + if (include != null && !include(connection.Value)) { continue; } - var task = connection.WriteAsync(message); + if (message == null) + { + message = CreateSerializedInvocationMessage(methodName, args); + } + + var task = connection.Value.WriteAsync(message); if (!task.IsCompletedSuccessfully) { @@ -122,14 +143,6 @@ namespace Microsoft.AspNetCore.SignalR tasks.Add(task.AsTask()); } } - - if (tasks == null) - { - return Task.CompletedTask; - } - - // Some connections are slow - return Task.WhenAll(tasks); } public override Task SendConnectionAsync(string connectionId, string methodName, object[] args) @@ -146,6 +159,8 @@ namespace Microsoft.AspNetCore.SignalR return Task.CompletedTask; } + // We're sending to a single connection + // Write message directly to connection without caching it in memory var message = CreateInvocationMessage(methodName, args); return connection.WriteAsync(message).AsTask(); @@ -161,9 +176,16 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - var message = CreateInvocationMessage(methodName, args); - var tasks = group.Values.Select(c => c.WriteAsync(message).AsTask()); - return Task.WhenAll(tasks); + // Can't optimize for sending to a single connection in a group because + // group might be modified inbetween checking and sending + List tasks = null; + SerializedHubMessage message = null; + SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + + if (tasks != null) + { + return Task.WhenAll(tasks); + } } return Task.CompletedTask; @@ -172,24 +194,29 @@ namespace Microsoft.AspNetCore.SignalR public override Task SendGroupsAsync(IReadOnlyList groupNames, string methodName, object[] args) { // Each task represents the list of tasks for each of the writes within a group - var tasks = new List(); - var message = CreateInvocationMessage(methodName, args); + List tasks = null; + SerializedHubMessage message = null; foreach (var groupName in groupNames) { if (string.IsNullOrEmpty(groupName)) { - throw new ArgumentException(nameof(groupName)); + throw new InvalidOperationException("Cannot send to an empty group name."); } var group = _groups[groupName]; if (group != null) { - tasks.Add(Task.WhenAll(group.Values.Select(c => c.WriteAsync(message).AsTask()))); + SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); } } - return Task.WhenAll(tasks); + if (tasks != null) + { + return Task.WhenAll(tasks); + } + + return Task.CompletedTask; } public override Task SendGroupExceptAsync(string groupName, string methodName, object[] args, IReadOnlyList excludedIds) @@ -202,24 +229,33 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - var message = CreateInvocationMessage(methodName, args); - var tasks = group.Values.Where(connection => !excludedIds.Contains(connection.ConnectionId)) - .Select(c => c.WriteAsync(message).AsTask()); - return Task.WhenAll(tasks); + List tasks = null; + SerializedHubMessage message = null; + + SendToGroupConnections(methodName, args, group, connection => !excludedIds.Contains(connection.ConnectionId), ref tasks, ref message); + + if (tasks != null) + { + return Task.WhenAll(tasks); + } } return Task.CompletedTask; } - private SerializedHubMessage CreateInvocationMessage(string methodName, object[] args) + private SerializedHubMessage CreateSerializedInvocationMessage(string methodName, object[] args) { - return new SerializedHubMessage(new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args)); + return new SerializedHubMessage(CreateInvocationMessage(methodName, args)); + } + + private HubMessage CreateInvocationMessage(string methodName, object[] args) + { + return new InvocationMessage(target: methodName, argumentBindingException: null, arguments: args); } public override Task SendUserAsync(string userId, string methodName, object[] args) { - return SendAllWhere(methodName, args, connection => - string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); + return SendToAllConnections(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); } public override Task OnConnectedAsync(HubConnectionContext connection) @@ -237,26 +273,17 @@ namespace Microsoft.AspNetCore.SignalR public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedIds) { - return SendAllWhere(methodName, args, connection => - { - return !excludedIds.Contains(connection.ConnectionId); - }); + return SendToAllConnections(methodName, args, connection => !excludedIds.Contains(connection.ConnectionId)); } public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args) { - return SendAllWhere(methodName, args, connection => - { - return connectionIds.Contains(connection.ConnectionId); - }); + return SendToAllConnections(methodName, args, connection => connectionIds.Contains(connection.ConnectionId)); } public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args) { - return SendAllWhere(methodName, args, connection => - { - return userIds.Contains(connection.UserIdentifier); - }); + return SendToAllConnections(methodName, args, connection => userIds.Contains(connection.UserIdentifier)); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index 02ccd24d62..57b21f1c97 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -91,6 +91,33 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task SendGroupExceptAsyncDoesNotWriteToExcludedConnections() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + + await manager.AddGroupAsync(connection1.ConnectionId, "gunit").OrTimeout(); + await manager.AddGroupAsync(connection2.ConnectionId, "gunit").OrTimeout(); + + await manager.SendGroupExceptAsync("gunit", "Hello", new object[] { "World" }, new []{ connection2.ConnectionId }).OrTimeout(); + + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + + Assert.Null(client2.TryRead()); + } + } + [Fact] public async Task SendConnectionAsyncWritesToConnectionOutput() {