diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 61a3aceb55..576e69326f 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -539,7 +539,7 @@ namespace Microsoft.AspNetCore.SignalR.Client /// public IAsyncEnumerable StreamAsyncCore(string methodName, object[] args, CancellationToken cancellationToken = default) { - var cts = cancellationToken.CanBeCanceled ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken) : new CancellationTokenSource(); + var cts = cancellationToken.CanBeCanceled ? CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, default) : new CancellationTokenSource(); var stream = CastIAsyncEnumerable(methodName, args, cts); var cancelableStream = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, cts); return cancelableStream; diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 3c835ab933..a8a9fe2095 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -85,7 +85,7 @@ namespace Microsoft.AspNetCore.SignalR return SendToAllConnections(methodName, args, null); } - private Task SendToAllConnections(string methodName, object[] args, Func include) + private Task SendToAllConnections(string methodName, object[] args, Func include, object state = null) { List tasks = null; SerializedHubMessage message = null; @@ -93,7 +93,7 @@ namespace Microsoft.AspNetCore.SignalR // foreach over HubConnectionStore avoids allocating an enumerator foreach (var connection in _connections) { - if (include != null && !include(connection)) + if (include != null && !include(connection, state)) { continue; } @@ -127,12 +127,12 @@ namespace Microsoft.AspNetCore.SignalR // Tasks and message are passed by ref so they can be lazily created inside the method post-filtering, // while still being re-usable when sending to multiple groups - private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary connections, Func include, ref List tasks, ref SerializedHubMessage message) + private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary connections, Func include, object state, ref List tasks, ref SerializedHubMessage message) { // foreach over ConcurrentDictionary avoids allocating an enumerator foreach (var connection in connections) { - if (include != null && !include(connection.Value)) + if (include != null && !include(connection.Value, state)) { continue; } @@ -193,7 +193,7 @@ namespace Microsoft.AspNetCore.SignalR // group might be modified inbetween checking and sending List tasks = null; SerializedHubMessage message = null; - SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + SendToGroupConnections(methodName, args, group, null, null, ref tasks, ref message); if (tasks != null) { @@ -221,7 +221,7 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + SendToGroupConnections(methodName, args, group, null, null, ref tasks, ref message); } } @@ -247,7 +247,7 @@ namespace Microsoft.AspNetCore.SignalR List tasks = null; SerializedHubMessage message = null; - SendToGroupConnections(methodName, args, group, connection => !excludedConnectionIds.Contains(connection.ConnectionId), ref tasks, ref message); + SendToGroupConnections(methodName, args, group, (connection, state) => !((IReadOnlyList)state).Contains(connection.ConnectionId), excludedConnectionIds, ref tasks, ref message); if (tasks != null) { @@ -271,7 +271,7 @@ namespace Microsoft.AspNetCore.SignalR /// public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); + return SendToAllConnections(methodName, args, (connection, state) => string.Equals(connection.UserIdentifier, (string)state, StringComparison.Ordinal), userId); } /// @@ -292,19 +292,19 @@ namespace Microsoft.AspNetCore.SignalR /// public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => !excludedConnectionIds.Contains(connection.ConnectionId)); + return SendToAllConnections(methodName, args, (connection, state) => !((IReadOnlyList)state).Contains(connection.ConnectionId), excludedConnectionIds); } /// public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => connectionIds.Contains(connection.ConnectionId)); + return SendToAllConnections(methodName, args, (connection, state) => ((IReadOnlyList)state).Contains(connection.ConnectionId), connectionIds); } /// public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => userIds.Contains(connection.UserIdentifier)); + return SendToAllConnections(methodName, args, (connection, state) => ((IReadOnlyList)state).Contains(connection.UserIdentifier), userIds); } } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index cead347976..29ae26319c 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -275,7 +275,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) { - cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, default); arguments[parameterPointer] = cts.Token; } else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) @@ -308,7 +308,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal return; } - cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, default); connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts); var enumerable = descriptor.FromReturnedStream(result, cts.Token);