From 85bde1da5e957856250985dc2ae3a8de732e86ba Mon Sep 17 00:00:00 2001 From: Brennan Date: Wed, 19 Aug 2020 14:58:24 -0700 Subject: [PATCH] Parallel hub invocations (#23535) --- .../DefaultHubDispatcherBenchmark.cs | 1 + .../server/Core/src/HubConnectionContext.cs | 9 + .../Core/src/HubConnectionContextOptions.cs | 5 + .../server/Core/src/HubConnectionHandler.cs | 13 +- src/SignalR/server/Core/src/HubOptions.cs | 20 + .../server/Core/src/HubOptionsSetup`T.cs | 1 + .../src/Internal/DefaultHubDispatcher.Log.cs | 8 + .../Core/src/Internal/DefaultHubDispatcher.cs | 164 ++++--- .../src/Internal/SemaphoreSlimExtensions.cs | 41 ++ .../server/SignalR/test/AddSignalRTests.cs | 22 + .../HubConnectionHandlerTestUtils/Hubs.cs | 35 +- .../SignalR/test/HubConnectionHandlerTests.cs | 405 +++++++++++++++++- 12 files changed, 637 insertions(+), 87 deletions(-) create mode 100644 src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs diff --git a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 6850989e3d..62ad679754 100644 --- a/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -47,6 +47,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks var contextOptions = new HubConnectionContextOptions() { KeepAliveInterval = TimeSpan.Zero, + StreamBufferCapacity = 10, }; _connectionContext = new NoErrorHubConnectionContext(connection, contextOptions, NullLoggerFactory.Instance); diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 0e98a6ee08..9f94b6a8b0 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -73,6 +73,13 @@ namespace Microsoft.AspNetCore.SignalR _systemClock = contextOptions.SystemClock ?? new SystemClock(); _lastSendTimeStamp = _systemClock.UtcNowTicks; + + // We'll be avoiding using the semaphore when the limit is set to 1, so no need to allocate it + var maxInvokeLimit = contextOptions.MaximumParallelInvocations; + if (maxInvokeLimit != 1) + { + ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit); + } } internal StreamTracker StreamTracker @@ -93,6 +100,8 @@ namespace Microsoft.AspNetCore.SignalR internal Exception? CloseException { get; private set; } + internal SemaphoreSlim? ActiveInvocationLimit { get; } + /// /// Gets a that notifies when the connection is aborted. /// diff --git a/src/SignalR/server/Core/src/HubConnectionContextOptions.cs b/src/SignalR/server/Core/src/HubConnectionContextOptions.cs index 4626d195cc..54ada054cf 100644 --- a/src/SignalR/server/Core/src/HubConnectionContextOptions.cs +++ b/src/SignalR/server/Core/src/HubConnectionContextOptions.cs @@ -32,5 +32,10 @@ namespace Microsoft.AspNetCore.SignalR public long? MaximumReceiveMessageSize { get; set; } internal ISystemClock SystemClock { get; set; } = default!; + + /// + /// Gets or sets the maximum parallel hub method invocations. + /// + public int MaximumParallelInvocations { get; set; } = 1; } } diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 40745494ad..d2d77e0fe6 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; @@ -31,6 +32,7 @@ namespace Microsoft.AspNetCore.SignalR private readonly HubDispatcher _dispatcher; private readonly bool _enableDetailedErrors; private readonly long? _maximumMessageSize; + private readonly int _maxParallelInvokes; // Internal for testing internal ISystemClock SystemClock { get; set; } = new SystemClock(); @@ -70,6 +72,7 @@ namespace Microsoft.AspNetCore.SignalR { _maximumMessageSize = _hubOptions.MaximumReceiveMessageSize; _enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors; + _maxParallelInvokes = _hubOptions.MaximumParallelInvocationsPerClient; if (_hubOptions.HubFilters != null) { @@ -80,6 +83,7 @@ namespace Microsoft.AspNetCore.SignalR { _maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize; _enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors; + _maxParallelInvokes = _globalHubOptions.MaximumParallelInvocationsPerClient; if (_globalHubOptions.HubFilters != null) { @@ -116,6 +120,7 @@ namespace Microsoft.AspNetCore.SignalR StreamBufferCapacity = _hubOptions.StreamBufferCapacity ?? _globalHubOptions.StreamBufferCapacity ?? HubOptionsSetup.DefaultStreamBufferCapacity, MaximumReceiveMessageSize = _maximumMessageSize, SystemClock = SystemClock, + MaximumParallelInvocations = _maxParallelInvokes, }; Log.ConnectedStarting(_logger); @@ -235,7 +240,6 @@ namespace Microsoft.AspNetCore.SignalR var protocol = connection.Protocol; connection.BeginClientTimeout(); - var binder = new HubConnectionBinder(_dispatcher, connection); while (true) @@ -258,8 +262,9 @@ namespace Microsoft.AspNetCore.SignalR { while (protocol.TryParseMessage(ref buffer, binder, out var message)) { - messageReceived = true; connection.StopClientTimeout(); + // This lets us know the timeout has stopped and we need to re-enable it after dispatching the message + messageReceived = true; await _dispatcher.DispatchMessageAsync(connection, message); } @@ -286,9 +291,9 @@ namespace Microsoft.AspNetCore.SignalR if (protocol.TryParseMessage(ref segment, binder, out var message)) { - messageReceived = true; connection.StopClientTimeout(); - + // This lets us know the timeout has stopped and we need to re-enable it after dispatching the message + messageReceived = true; await _dispatcher.DispatchMessageAsync(connection, message); } else if (overLength) diff --git a/src/SignalR/server/Core/src/HubOptions.cs b/src/SignalR/server/Core/src/HubOptions.cs index a9a889909d..684b8c343b 100644 --- a/src/SignalR/server/Core/src/HubOptions.cs +++ b/src/SignalR/server/Core/src/HubOptions.cs @@ -11,6 +11,8 @@ namespace Microsoft.AspNetCore.SignalR /// public class HubOptions { + private int _maximumParallelInvocationsPerClient = 1; + // HandshakeTimeout and KeepAliveInterval are set to null here to help identify when // local hub options have been set. Global default values are set in HubOptionsSetup. // SupportedProtocols being null is the true default value, and it represents support @@ -53,5 +55,23 @@ namespace Microsoft.AspNetCore.SignalR public int? StreamBufferCapacity { get; set; } = null; internal List? HubFilters { get; set; } + + /// + /// By default a client is only allowed to invoke a single Hub method at a time. + /// Changing this property will allow clients to invoke multiple methods at the same time before queueing. + /// + public int MaximumParallelInvocationsPerClient + { + get => _maximumParallelInvocationsPerClient; + set + { + if (value < 1) + { + throw new ArgumentOutOfRangeException(nameof(MaximumParallelInvocationsPerClient)); + } + + _maximumParallelInvocationsPerClient = value; + } + } } } diff --git a/src/SignalR/server/Core/src/HubOptionsSetup`T.cs b/src/SignalR/server/Core/src/HubOptionsSetup`T.cs index a935980e09..1dfae3de0c 100644 --- a/src/SignalR/server/Core/src/HubOptionsSetup`T.cs +++ b/src/SignalR/server/Core/src/HubOptionsSetup`T.cs @@ -25,6 +25,7 @@ namespace Microsoft.AspNetCore.SignalR options.EnableDetailedErrors = _hubOptions.EnableDetailedErrors; options.MaximumReceiveMessageSize = _hubOptions.MaximumReceiveMessageSize; options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity; + options.MaximumParallelInvocationsPerClient = _hubOptions.MaximumParallelInvocationsPerClient; options.UserHasSetValues = true; diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.Log.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.Log.cs index 359c78f2db..e6710f7e31 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.Log.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.Log.cs @@ -79,6 +79,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal private static readonly Action _invalidHubParameters = LoggerMessage.Define(LogLevel.Debug, new EventId(22, "InvalidHubParameters"), "Parameters to hub method '{HubMethod}' are incorrect."); + private static readonly Action _invocationIdInUse = + LoggerMessage.Define(LogLevel.Debug, new EventId(23, "InvocationIdInUse"), "Invocation ID '{InvocationId}' is already in use."); + public static void ReceivedHubInvocation(ILogger logger, InvocationMessage invocationMessage) { _receivedHubInvocation(logger, invocationMessage, null); @@ -188,6 +191,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal { _invalidHubParameters(logger, hubMethod, exception); } + + public static void InvocationIdInUse(ILogger logger, string InvocationId) + { + _invocationIdInUse(logger, InvocationId, null); + } } } } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 889ee5dfc8..ee6be9f57f 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -147,6 +147,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal // Messages are dispatched sequentially and will stop other messages from being processed until they complete. // Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run. + // With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running. + switch (hubMessage) { case InvocationBindingFailureMessage bindingFailureMessage: @@ -229,7 +231,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal connection.StreamTracker.TryComplete(message); // TODO: Send stream completion message to client when we add it - return Task.CompletedTask; } @@ -258,7 +259,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal else { bool isStreamCall = descriptor.StreamingParameters != null; - return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall); + if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse) + { + return connection.ActiveInvocationLimit.RunAsync(state => + { + var (dispatcher, descriptor, connection, invocationMessage) = state; + return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false); + }, (this, descriptor, connection, hubMethodInvocationMessage)); + } + else + { + return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall); + } } } @@ -305,68 +317,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal InitializeHub(hub, connection); Task invocation = null; - CancellationTokenSource cts = null; var arguments = hubMethodInvocationMessage.Arguments; + CancellationTokenSource cts = null; if (descriptor.HasSyntheticArguments) { - // In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments) - arguments = new object[descriptor.OriginalParameterTypes.Count]; - - var streamPointer = 0; - var hubInvocationArgumentPointer = 0; - for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) - { - if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer && - (hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null || - descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType()))) - { - // The types match so it isn't a synthetic argument, just copy it into the arguments array - arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]; - hubInvocationArgumentPointer++; - } - else - { - if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) - { - cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); - arguments[parameterPointer] = cts.Token; - } - else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) - { - Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds[streamPointer]); - var itemType = descriptor.StreamingParameters[streamPointer]; - arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], - itemType, descriptor.OriginalParameterTypes[parameterPointer]); - - streamPointer++; - } - else - { - // This should never happen - Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'."); - } - } - } + ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts); } if (isStreamResponse) { - var result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider); - - if (result == null) - { - Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); - await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, - $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>."); - return; - } - - cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); - connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts); - var enumerable = descriptor.FromReturnedStream(result, cts.Token); - - Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage); + _ = StreamAsync(hubMethodInvocationMessage.InvocationId, connection, arguments, scope, hubActivator, hub, cts, hubMethodInvocationMessage, descriptor); } else { @@ -456,13 +416,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal return scope.DisposeAsync(); } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable enumerable, IServiceScope scope, - IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) + private async Task StreamAsync(string invocationId, HubConnectionContext connection, object[] arguments, IServiceScope scope, + IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor) { string error = null; + streamCts = streamCts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + try { + if (!connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts)) + { + Log.InvocationIdInUse(_logger, invocationId); + error = $"Invocation ID '{invocationId}' is already in use."; + return; + } + + object result; + try + { + result = await ExecuteHubMethod(descriptor.MethodExecutor, hub, arguments, connection, scope.ServiceProvider); + } + catch (Exception ex) + { + Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex); + error = ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors); + return; + } + + if (result == null) + { + Log.InvalidReturnValueFromStreamingMethod(_logger, descriptor.MethodExecutor.MethodInfo.Name); + error = $"The value returned by the streaming method '{descriptor.MethodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>."; + return; + } + + var enumerable = descriptor.FromReturnedStream(result, streamCts.Token); + + Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, descriptor.MethodExecutor); + await foreach (var streamItem in enumerable) { // Send the stream item @@ -477,8 +469,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal catch (Exception ex) { // If the streaming method was canceled we don't want to send a HubException message - this is not an error case - if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts) - && cts.IsCancellationRequested)) + if (!(ex is OperationCanceledException && streamCts.IsCancellationRequested)) { error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors); } @@ -487,15 +478,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal { await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); - // Dispose the linked CTS for the stream. streamCts.Dispose(); + connection.ActiveRequestCancellationSources.TryRemove(invocationId, out _); await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); - - if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) - { - cts.Dispose(); - } } } @@ -612,6 +598,50 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; } + private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall, + HubConnectionContext connection, ref object[] arguments, out CancellationTokenSource cts) + { + cts = null; + // In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments) + arguments = new object[descriptor.OriginalParameterTypes.Count]; + + var streamPointer = 0; + var hubInvocationArgumentPointer = 0; + for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) + { + if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer && + (hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null || + descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType()))) + { + // The types match so it isn't a synthetic argument, just copy it into the arguments array + arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]; + hubInvocationArgumentPointer++; + } + else + { + if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) + { + cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + arguments[parameterPointer] = cts.Token; + } + else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) + { + Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds[streamPointer]); + var itemType = descriptor.StreamingParameters[streamPointer]; + arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], + itemType, descriptor.OriginalParameterTypes[parameterPointer]); + + streamPointer++; + } + else + { + // This should never happen + Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{descriptor.MethodExecutor.MethodInfo.Name}'."); + } + } + } + } + private void DiscoverHubMethods() { var hubType = typeof(THub); diff --git a/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs b/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs new file mode 100644 index 0000000000..4650fb11e8 --- /dev/null +++ b/src/SignalR/server/Core/src/Internal/SemaphoreSlimExtensions.cs @@ -0,0 +1,41 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + internal static class SemaphoreSlimExtensions + { + public static Task RunAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) + { + if (semaphoreSlim.Wait(0)) + { + _ = RunTask(callback, semaphoreSlim, state); + return Task.CompletedTask; + } + + return RunSlowAsync(semaphoreSlim, callback, state); + } + + private static async Task RunSlowAsync(this SemaphoreSlim semaphoreSlim, Func callback, TState state) + { + await semaphoreSlim.WaitAsync(); + return RunTask(callback, semaphoreSlim, state); + } + + static async Task RunTask(Func callback, SemaphoreSlim semaphoreSlim, TState state) + { + try + { + await callback(state); + } + finally + { + semaphoreSlim.Release(); + } + } + } +} diff --git a/src/SignalR/server/SignalR/test/AddSignalRTests.cs b/src/SignalR/server/SignalR/test/AddSignalRTests.cs index a8cd5a9342..1eca6c63ed 100644 --- a/src/SignalR/server/SignalR/test/AddSignalRTests.cs +++ b/src/SignalR/server/SignalR/test/AddSignalRTests.cs @@ -7,6 +7,7 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; @@ -78,11 +79,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests serviceCollection.AddSignalR().AddHubOptions(options => { options.SupportedProtocols.Clear(); + options.AddFilter(new CustomHubFilter()); }); var serviceProvider = serviceCollection.BuildServiceProvider(); Assert.Equal(1, serviceProvider.GetRequiredService>().Value.SupportedProtocols.Count); Assert.Equal(0, serviceProvider.GetRequiredService>>().Value.SupportedProtocols.Count); + + Assert.Null(serviceProvider.GetRequiredService>().Value.HubFilters); + Assert.Single(serviceProvider.GetRequiredService>>().Value.HubFilters); } [Fact] @@ -105,6 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(globalHubOptions.HandshakeTimeout, hubOptions.HandshakeTimeout); Assert.Equal(globalHubOptions.SupportedProtocols, hubOptions.SupportedProtocols); Assert.Equal(globalHubOptions.ClientTimeoutInterval, hubOptions.ClientTimeoutInterval); + Assert.Equal(globalHubOptions.MaximumParallelInvocationsPerClient, hubOptions.MaximumParallelInvocationsPerClient); Assert.True(hubOptions.UserHasSetValues); } @@ -138,6 +144,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests options.HandshakeTimeout = null; options.SupportedProtocols = null; options.ClientTimeoutInterval = TimeSpan.FromSeconds(1); + options.MaximumParallelInvocationsPerClient = 3; }); var serviceProvider = serviceCollection.BuildServiceProvider(); @@ -149,6 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Null(globalOptions.KeepAliveInterval); Assert.Null(globalOptions.HandshakeTimeout); Assert.Null(globalOptions.SupportedProtocols); + Assert.Equal(3, globalOptions.MaximumParallelInvocationsPerClient); Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval); } @@ -175,6 +183,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal("messagepack", p); }); } + + [Fact] + public void ThrowsIfSetInvalidValueForMaxInvokes() + { + Assert.Throws(() => new HubOptions() { MaximumParallelInvocationsPerClient = 0 }); + } } public class CustomHub : Hub @@ -333,6 +347,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests throw new NotImplementedException(); } } + + internal class CustomHubFilter : IHubFilter + { + public ValueTask InvokeMethodAsync(HubInvocationContext invocationContext, Func> next) + { + throw new NotImplementedException(); + } + } } namespace Microsoft.AspNetCore.SignalR.Internal diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index c2a4893fd4..498ce608e4 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -239,6 +239,22 @@ namespace Microsoft.AspNetCore.SignalR.Tests return results; } + [Authorize("test")] + public async Task> UploadArrayAuth(ChannelReader source) + { + var results = new List(); + + while (await source.WaitToReadAsync()) + { + while (source.TryRead(out var item)) + { + results.Add(item); + } + } + + return results; + } + public async Task TestTypeCastingErrors(ChannelReader source) { try @@ -684,13 +700,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests return Channel.CreateUnbounded().Reader; } - public ChannelReader ThrowStream() + public ChannelReader ExceptionStream() { var channel = Channel.CreateUnbounded(); channel.Writer.TryComplete(new Exception("Exception from channel")); return channel.Reader; } + public ChannelReader ThrowStream() + { + throw new Exception("Throw from hub method"); + } + + public ChannelReader NullStream() + { + return null; + } + public int NonStream() { return 42; @@ -1010,6 +1036,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests return 21; } + public async Task Upload(ChannelReader stream) + { + _tcsService.StartedMethod.SetResult(null); + _ = await stream.ReadAndCollectAllAsync(); + _tcsService.EndMethod.SetResult(null); + } + private class CustomAsyncEnumerable : IAsyncEnumerable { private readonly TcsService _tcsService; diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 32d198d2fd..862cde22d4 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -400,9 +400,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await client.Connection.Application.Output.WriteAsync(part3); - Assert.True(task.IsCompleted); - - var completionMessage = await task as CompletionMessage; + var completionMessage = await task.OrTimeout() as CompletionMessage; Assert.NotNull(completionMessage); Assert.Equal("hello", completionMessage.Result); Assert.Equal("1", completionMessage.InvocationId); @@ -2089,7 +2087,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await client.Connected.OrTimeout(); - var messages = await client.StreamAsync(nameof(StreamingHub.ThrowStream)); + var messages = await client.StreamAsync(nameof(StreamingHub.ExceptionStream)); Assert.Equal(1, messages.Count); var completion = messages[0] as CompletionMessage; @@ -2923,7 +2921,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => { services.Configure(options => - options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0)); + { + options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0); + options.MaximumParallelInvocationsPerClient = 1; + }); services.AddSingleton(tcsService); }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); @@ -2963,6 +2964,42 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task HubMethodInvokeCountsTowardsClientTimeoutIfParallelNotMaxed() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.Configure(options => + { + options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0); + options.MaximumParallelInvocationsPerClient = 2; + }); + services.AddSingleton(tcsService); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient(new JsonHubProtocol())) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + // This starts the timeout logic + await client.SendHubMessageAsync(PingMessage.Instance); + + // Call long running hub method + var hubMethodTask = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); + await tcsService.StartedMethod.Task.OrTimeout(); + + // Tick heartbeat while hub method is running + client.TickHeartbeat(); + + // Connection is closed + await connectionHandlerTask.OrTimeout(); + } + } + } + [Fact] public async Task EndingConnectionSendsCloseMessageWithNoError() { @@ -3040,7 +3077,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifiableLog()) { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => + { + options.MaximumParallelInvocationsPerClient = 1; + }); + }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); using (var client = new TestClient(new NewtonsoftJsonHubProtocol())) @@ -3062,7 +3105,119 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task InvocationsRunInOrder() + public async Task StreamMethodThatThrowsWillCleanup() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" && + writeContext.EventId.Name == "FailedInvokingHubMethod"; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + + var messages = await client.StreamAsync(nameof(StreamingHub.ThrowStream)); + + Assert.Equal(1, messages.Count); + var completion = messages[0] as CompletionMessage; + Assert.NotNull(completion); + + var hubActivator = serviceProvider.GetService>() as CustomHubActivator; + + // OnConnectedAsync and ThrowStream hubs have been disposed + Assert.Equal(2, hubActivator.ReleaseCount); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task StreamMethodThatReturnsNullWillCleanup() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + + var messages = await client.StreamAsync(nameof(StreamingHub.NullStream)); + + Assert.Equal(1, messages.Count); + var completion = messages[0] as CompletionMessage; + Assert.NotNull(completion); + + var hubActivator = serviceProvider.GetService>() as CustomHubActivator; + + // OnConnectedAsync and NullStream hubs have been disposed + Assert.Equal(2, hubActivator.ReleaseCount); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task StreamMethodWithDuplicateIdFails() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + await client.Connected.OrTimeout(); + + await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty())).OrTimeout(); + + await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty())).OrTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().OrTimeout()); + Assert.Equal("Invocation ID '123' is already in use.", completion.Error); + + var hubActivator = serviceProvider.GetService>() as CustomHubActivator; + + // OnConnectedAsync and BlockingStream hubs have been disposed + Assert.Equal(2, hubActivator.ReleaseCount); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task InvocationsRunInOrderWithNoParallelism() { using (StartVerifiableLog()) { @@ -3070,6 +3225,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { builder.AddSingleton(tcsService); + + builder.AddSignalR(options => + { + options.MaximumParallelInvocationsPerClient = 1; + }); }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); @@ -3112,7 +3272,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests } [Fact] - public async Task StreamInvocationsBlockOtherInvocationsUntilTheyStartStreaming() + public async Task InvocationsCanRunOutOfOrderWithParallelism() { using (StartVerifiableLog()) { @@ -3120,7 +3280,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => { builder.AddSingleton(tcsService); - builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + + builder.AddSignalR(options => + { + options.MaximumParallelInvocationsPerClient = 2; + }); }, LoggerFactory); var connectionHandler = serviceProvider.GetService>(); @@ -3130,7 +3294,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); // Long running hub invocation to test that other invocations will not run until it is completed - var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream), null).OrTimeout(); + await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout(); + // Wait for the long running method to start + await tcsService.StartedMethod.Task.OrTimeout(); + + for (var i = 0; i < 5; i++) + { + // Invoke another hub method which will wait for the first method to finish + await client.SendInvocationAsync(nameof(LongRunningHub.SimpleMethod), nonBlocking: false).OrTimeout(); + + // simple hub method result + var secondResult = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(secondResult); + Assert.Equal(21L, simpleCompletion.Result); + } + + // Release the long running hub method + tcsService.EndMethod.TrySetResult(null); + + // Long running hub method result + var firstResult = await client.ReadAsync().OrTimeout(); + + var longRunningCompletion = Assert.IsType(firstResult); + Assert.Equal(12L, longRunningCompletion.Result); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task PendingInvocationUnblockedWhenBlockingMethodCompletesWithParallelism() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + + builder.AddSignalR(options => + { + options.MaximumParallelInvocationsPerClient = 2; + }); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + // Because we use PipeScheduler.Inline the hub invocations will run inline until they wait, which happens inside the LongRunningMethod call + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + // Long running hub invocation to test that other invocations will not run until it is completed + await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout(); + // Wait for the long running method to start + await tcsService.StartedMethod.Task.OrTimeout(); + // Grab the tcs before resetting to use in the second long running method + var endTcs = tcsService.EndMethod; + tcsService.Reset(); + + // Long running hub invocation to test that other invocations will not run until it is completed + await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout(); // Wait for the long running method to start await tcsService.StartedMethod.Task.OrTimeout(); @@ -3139,21 +3367,79 @@ namespace Microsoft.AspNetCore.SignalR.Tests // Both invocations should be waiting now Assert.Null(client.TryRead()); - // Release the long running hub method + // Release the second long running hub method tcsService.EndMethod.TrySetResult(null); - // simple hub method result - var result = await client.ReadAsync().OrTimeout(); + // Long running hub method result + var firstResult = await client.ReadAsync().OrTimeout(); - var simpleCompletion = Assert.IsType(result); + var longRunningCompletion = Assert.IsType(firstResult); + Assert.Equal(12L, longRunningCompletion.Result); + + // simple hub method result + var secondResult = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(secondResult); Assert.Equal(21L, simpleCompletion.Result); + // Release the first long running hub method + endTcs.TrySetResult(null); + + firstResult = await client.ReadAsync().OrTimeout(); + longRunningCompletion = Assert.IsType(firstResult); + Assert.Equal(12L, longRunningCompletion.Result); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task StreamInvocationsDoNotBlockOtherInvocations() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + + builder.AddSignalR(options => + { + options.MaximumParallelInvocationsPerClient = 1; + }); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + // Because we use PipeScheduler.Inline the hub invocations will run inline until they go async + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + // Long running stream invocation to test that other invocations can still run before it is completed + var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream), null).OrTimeout(); + // Wait for the long running method to start + await tcsService.StartedMethod.Task.OrTimeout(); + + // Invoke another hub method which will be able to run even though a streaming method is still running + var completion = await client.InvokeAsync(nameof(LongRunningHub.SimpleMethod)).OrTimeout(); + Assert.Null(completion.Error); + Assert.Equal(21L, completion.Result); + + // Release the long running hub method + tcsService.EndMethod.TrySetResult(null); + var hubActivator = serviceProvider.GetService>() as CustomHubActivator; await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout(); // Completion message for canceled Stream - await client.ReadAsync().OrTimeout(); + completion = Assert.IsType(await client.ReadAsync().OrTimeout()); + Assert.Equal(streamInvocationId, completion.InvocationId); // Shut down client.Dispose(); @@ -3319,6 +3605,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + private class DelayRequirement : AuthorizationHandler, IAuthorizationRequirement + { + private readonly TcsService _tcsService; + public DelayRequirement(TcsService tcsService) + { + _tcsService = tcsService; + } + + protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context, DelayRequirement requirement, HubInvocationContext resource) + { + _tcsService.StartedMethod.SetResult(null); + await _tcsService.EndMethod.Task; + context.Succeed(requirement); + } + } + + [Fact] + // Test to check if StreamItems can be processed before the Stream from the invocation is properly registered internally + public async Task UploadStreamStreamItemsSentAsSoonAsPossible() + { + // Use Auth as the delay injection point because it is one of the first things to run after the invocation message has been parsed + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddAuthorization(options => + { + options.AddPolicy("test", policy => + { + policy.Requirements.Add(new DelayRequirement(tcsService)); + }); + }); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.UploadArrayAuth), new[] { "id" }, Array.Empty()); + await tcsService.StartedMethod.Task.OrTimeout(); + + var objects = new[] { new SampleObject("solo", 322), new SampleObject("ggez", 3145) }; + foreach (var thing in objects) + { + await client.SendHubMessageAsync(new StreamItemMessage("id", thing)).OrTimeout(); + } + + tcsService.EndMethod.SetResult(null); + + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout(); + var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); + var result = ((JArray)response.Result).ToArray(); + + Assert.Equal(objects[0].Foo, ((JContainer)result[0])["foo"]); + Assert.Equal(objects[0].Bar, ((JContainer)result[0])["bar"]); + Assert.Equal(objects[1].Foo, ((JContainer)result[1])["foo"]); + Assert.Equal(objects[1].Bar, ((JContainer)result[1])["bar"]); + } + } + + [Fact] + public async Task UploadStreamDoesNotCountTowardsMaxInvocationLimit() + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => options.MaximumParallelInvocationsPerClient = 1); + services.AddSingleton(tcsService); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(LongRunningHub.Upload), new[] { "id" }, Array.Empty()); + await tcsService.StartedMethod.Task.OrTimeout(); + + var completion = await client.InvokeAsync(nameof(LongRunningHub.SimpleMethod)).OrTimeout(); + Assert.Null(completion.Error); + Assert.Equal(21L, completion.Result); + + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout(); + + await tcsService.EndMethod.Task.OrTimeout(); + var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); + Assert.Null(response.Result); + Assert.Null(response.Error); + } + } + [Fact] public async Task ConnectionAbortedIfSendFailsWithProtocolError() {