Parallel hub invocations (#23535)

This commit is contained in:
Brennan 2020-08-19 14:58:24 -07:00 committed by GitHub
parent df04381411
commit 85bde1da5e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 637 additions and 87 deletions

View File

@ -47,6 +47,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
var contextOptions = new HubConnectionContextOptions()
{
KeepAliveInterval = TimeSpan.Zero,
StreamBufferCapacity = 10,
};
_connectionContext = new NoErrorHubConnectionContext(connection, contextOptions, NullLoggerFactory.Instance);

View File

@ -73,6 +73,13 @@ namespace Microsoft.AspNetCore.SignalR
_systemClock = contextOptions.SystemClock ?? new SystemClock();
_lastSendTimeStamp = _systemClock.UtcNowTicks;
// We'll be avoiding using the semaphore when the limit is set to 1, so no need to allocate it
var maxInvokeLimit = contextOptions.MaximumParallelInvocations;
if (maxInvokeLimit != 1)
{
ActiveInvocationLimit = new SemaphoreSlim(maxInvokeLimit, maxInvokeLimit);
}
}
internal StreamTracker StreamTracker
@ -93,6 +100,8 @@ namespace Microsoft.AspNetCore.SignalR
internal Exception? CloseException { get; private set; }
internal SemaphoreSlim? ActiveInvocationLimit { get; }
/// <summary>
/// Gets a <see cref="CancellationToken"/> that notifies when the connection is aborted.
/// </summary>

View File

@ -32,5 +32,10 @@ namespace Microsoft.AspNetCore.SignalR
public long? MaximumReceiveMessageSize { get; set; }
internal ISystemClock SystemClock { get; set; } = default!;
/// <summary>
/// Gets or sets the maximum parallel hub method invocations.
/// </summary>
public int MaximumParallelInvocations { get; set; } = 1;
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
@ -31,6 +32,7 @@ namespace Microsoft.AspNetCore.SignalR
private readonly HubDispatcher<THub> _dispatcher;
private readonly bool _enableDetailedErrors;
private readonly long? _maximumMessageSize;
private readonly int _maxParallelInvokes;
// Internal for testing
internal ISystemClock SystemClock { get; set; } = new SystemClock();
@ -70,6 +72,7 @@ namespace Microsoft.AspNetCore.SignalR
{
_maximumMessageSize = _hubOptions.MaximumReceiveMessageSize;
_enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
_maxParallelInvokes = _hubOptions.MaximumParallelInvocationsPerClient;
if (_hubOptions.HubFilters != null)
{
@ -80,6 +83,7 @@ namespace Microsoft.AspNetCore.SignalR
{
_maximumMessageSize = _globalHubOptions.MaximumReceiveMessageSize;
_enableDetailedErrors = _globalHubOptions.EnableDetailedErrors ?? _enableDetailedErrors;
_maxParallelInvokes = _globalHubOptions.MaximumParallelInvocationsPerClient;
if (_globalHubOptions.HubFilters != null)
{
@ -116,6 +120,7 @@ namespace Microsoft.AspNetCore.SignalR
StreamBufferCapacity = _hubOptions.StreamBufferCapacity ?? _globalHubOptions.StreamBufferCapacity ?? HubOptionsSetup.DefaultStreamBufferCapacity,
MaximumReceiveMessageSize = _maximumMessageSize,
SystemClock = SystemClock,
MaximumParallelInvocations = _maxParallelInvokes,
};
Log.ConnectedStarting(_logger);
@ -235,7 +240,6 @@ namespace Microsoft.AspNetCore.SignalR
var protocol = connection.Protocol;
connection.BeginClientTimeout();
var binder = new HubConnectionBinder<THub>(_dispatcher, connection);
while (true)
@ -258,8 +262,9 @@ namespace Microsoft.AspNetCore.SignalR
{
while (protocol.TryParseMessage(ref buffer, binder, out var message))
{
messageReceived = true;
connection.StopClientTimeout();
// This lets us know the timeout has stopped and we need to re-enable it after dispatching the message
messageReceived = true;
await _dispatcher.DispatchMessageAsync(connection, message);
}
@ -286,9 +291,9 @@ namespace Microsoft.AspNetCore.SignalR
if (protocol.TryParseMessage(ref segment, binder, out var message))
{
messageReceived = true;
connection.StopClientTimeout();
// This lets us know the timeout has stopped and we need to re-enable it after dispatching the message
messageReceived = true;
await _dispatcher.DispatchMessageAsync(connection, message);
}
else if (overLength)

View File

@ -11,6 +11,8 @@ namespace Microsoft.AspNetCore.SignalR
/// </summary>
public class HubOptions
{
private int _maximumParallelInvocationsPerClient = 1;
// HandshakeTimeout and KeepAliveInterval are set to null here to help identify when
// local hub options have been set. Global default values are set in HubOptionsSetup.
// SupportedProtocols being null is the true default value, and it represents support
@ -53,5 +55,23 @@ namespace Microsoft.AspNetCore.SignalR
public int? StreamBufferCapacity { get; set; } = null;
internal List<IHubFilter>? HubFilters { get; set; }
/// <summary>
/// By default a client is only allowed to invoke a single Hub method at a time.
/// Changing this property will allow clients to invoke multiple methods at the same time before queueing.
/// </summary>
public int MaximumParallelInvocationsPerClient
{
get => _maximumParallelInvocationsPerClient;
set
{
if (value < 1)
{
throw new ArgumentOutOfRangeException(nameof(MaximumParallelInvocationsPerClient));
}
_maximumParallelInvocationsPerClient = value;
}
}
}
}

View File

@ -25,6 +25,7 @@ namespace Microsoft.AspNetCore.SignalR
options.EnableDetailedErrors = _hubOptions.EnableDetailedErrors;
options.MaximumReceiveMessageSize = _hubOptions.MaximumReceiveMessageSize;
options.StreamBufferCapacity = _hubOptions.StreamBufferCapacity;
options.MaximumParallelInvocationsPerClient = _hubOptions.MaximumParallelInvocationsPerClient;
options.UserHasSetValues = true;

View File

@ -79,6 +79,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal
private static readonly Action<ILogger, string, Exception> _invalidHubParameters =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(22, "InvalidHubParameters"), "Parameters to hub method '{HubMethod}' are incorrect.");
private static readonly Action<ILogger, string, Exception> _invocationIdInUse =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(23, "InvocationIdInUse"), "Invocation ID '{InvocationId}' is already in use.");
public static void ReceivedHubInvocation(ILogger logger, InvocationMessage invocationMessage)
{
_receivedHubInvocation(logger, invocationMessage, null);
@ -188,6 +191,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
_invalidHubParameters(logger, hubMethod, exception);
}
public static void InvocationIdInUse(ILogger logger, string InvocationId)
{
_invocationIdInUse(logger, InvocationId, null);
}
}
}
}

View File

@ -147,6 +147,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
// Messages are dispatched sequentially and will stop other messages from being processed until they complete.
// Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run.
// With parallel invokes enabled, messages run sequentially until they go async and then the next message will be allowed to start running.
switch (hubMessage)
{
case InvocationBindingFailureMessage bindingFailureMessage:
@ -229,7 +231,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
connection.StreamTracker.TryComplete(message);
// TODO: Send stream completion message to client when we add it
return Task.CompletedTask;
}
@ -258,7 +259,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal
else
{
bool isStreamCall = descriptor.StreamingParameters != null;
return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
if (connection.ActiveInvocationLimit != null && !isStreamCall && !isStreamResponse)
{
return connection.ActiveInvocationLimit.RunAsync(state =>
{
var (dispatcher, descriptor, connection, invocationMessage) = state;
return dispatcher.Invoke(descriptor, connection, invocationMessage, isStreamResponse: false, isStreamCall: false);
}, (this, descriptor, connection, hubMethodInvocationMessage));
}
else
{
return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
}
}
}
@ -305,68 +317,16 @@ namespace Microsoft.AspNetCore.SignalR.Internal
InitializeHub(hub, connection);
Task invocation = null;
CancellationTokenSource cts = null;
var arguments = hubMethodInvocationMessage.Arguments;
CancellationTokenSource cts = null;
if (descriptor.HasSyntheticArguments)
{
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
arguments = new object[descriptor.OriginalParameterTypes.Count];
var streamPointer = 0;
var hubInvocationArgumentPointer = 0;
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null ||
descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType())))
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
hubInvocationArgumentPointer++;
}
else
{
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
{
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds[streamPointer]);
var itemType = descriptor.StreamingParameters[streamPointer];
arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer],
itemType, descriptor.OriginalParameterTypes[parameterPointer]);
streamPointer++;
}
else
{
// This should never happen
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'.");
}
}
}
ReplaceArguments(descriptor, hubMethodInvocationMessage, isStreamCall, connection, ref arguments, out cts);
}
if (isStreamResponse)
{
var result = await ExecuteHubMethod(methodExecutor, hub, arguments, connection, scope.ServiceProvider);
if (result == null)
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.");
return;
}
cts = cts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
connection.ActiveRequestCancellationSources.TryAdd(hubMethodInvocationMessage.InvocationId, cts);
var enumerable = descriptor.FromReturnedStream(result, cts.Token);
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerable, scope, hubActivator, hub, cts, hubMethodInvocationMessage);
_ = StreamAsync(hubMethodInvocationMessage.InvocationId, connection, arguments, scope, hubActivator, hub, cts, hubMethodInvocationMessage, descriptor);
}
else
{
@ -456,13 +416,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return scope.DisposeAsync();
}
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerable<object> enumerable, IServiceScope scope,
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage)
private async Task StreamAsync(string invocationId, HubConnectionContext connection, object[] arguments, IServiceScope scope,
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage, HubMethodDescriptor descriptor)
{
string error = null;
streamCts = streamCts ?? CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
try
{
if (!connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts))
{
Log.InvocationIdInUse(_logger, invocationId);
error = $"Invocation ID '{invocationId}' is already in use.";
return;
}
object result;
try
{
result = await ExecuteHubMethod(descriptor.MethodExecutor, hub, arguments, connection, scope.ServiceProvider);
}
catch (Exception ex)
{
Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);
error = ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors);
return;
}
if (result == null)
{
Log.InvalidReturnValueFromStreamingMethod(_logger, descriptor.MethodExecutor.MethodInfo.Name);
error = $"The value returned by the streaming method '{descriptor.MethodExecutor.MethodInfo.Name}' is not a ChannelReader<> or IAsyncEnumerable<>.";
return;
}
var enumerable = descriptor.FromReturnedStream(result, streamCts.Token);
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, descriptor.MethodExecutor);
await foreach (var streamItem in enumerable)
{
// Send the stream item
@ -477,8 +469,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
catch (Exception ex)
{
// If the streaming method was canceled we don't want to send a HubException message - this is not an error case
if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts)
&& cts.IsCancellationRequested))
if (!(ex is OperationCanceledException && streamCts.IsCancellationRequested))
{
error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors);
}
@ -487,15 +478,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
await CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope);
// Dispose the linked CTS for the stream.
streamCts.Dispose();
connection.ActiveRequestCancellationSources.TryRemove(invocationId, out _);
await connection.WriteAsync(CompletionMessage.WithError(invocationId, error));
if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts))
{
cts.Dispose();
}
}
}
@ -612,6 +598,50 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}
private void ReplaceArguments(HubMethodDescriptor descriptor, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamCall,
HubConnectionContext connection, ref object[] arguments, out CancellationTokenSource cts)
{
cts = null;
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
arguments = new object[descriptor.OriginalParameterTypes.Count];
var streamPointer = 0;
var hubInvocationArgumentPointer = 0;
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer] == null ||
descriptor.OriginalParameterTypes[parameterPointer].IsAssignableFrom(hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType())))
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
hubInvocationArgumentPointer++;
}
else
{
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
{
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true))
{
Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds[streamPointer]);
var itemType = descriptor.StreamingParameters[streamPointer];
arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer],
itemType, descriptor.OriginalParameterTypes[parameterPointer]);
streamPointer++;
}
else
{
// This should never happen
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{descriptor.MethodExecutor.MethodInfo.Name}'.");
}
}
}
}
private void DiscoverHubMethods()
{
var hubType = typeof(THub);

View File

@ -0,0 +1,41 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal
{
internal static class SemaphoreSlimExtensions
{
public static Task RunAsync<TState>(this SemaphoreSlim semaphoreSlim, Func<TState, Task> callback, TState state)
{
if (semaphoreSlim.Wait(0))
{
_ = RunTask(callback, semaphoreSlim, state);
return Task.CompletedTask;
}
return RunSlowAsync(semaphoreSlim, callback, state);
}
private static async Task<Task> RunSlowAsync<TState>(this SemaphoreSlim semaphoreSlim, Func<TState, Task> callback, TState state)
{
await semaphoreSlim.WaitAsync();
return RunTask(callback, semaphoreSlim, state);
}
static async Task RunTask<TState>(Func<TState, Task> callback, SemaphoreSlim semaphoreSlim, TState state)
{
try
{
await callback(state);
}
finally
{
semaphoreSlim.Release();
}
}
}
}

View File

@ -7,6 +7,7 @@ using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
@ -78,11 +79,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
serviceCollection.AddSignalR().AddHubOptions<CustomHub>(options =>
{
options.SupportedProtocols.Clear();
options.AddFilter(new CustomHubFilter());
});
var serviceProvider = serviceCollection.BuildServiceProvider();
Assert.Equal(1, serviceProvider.GetRequiredService<IOptions<HubOptions>>().Value.SupportedProtocols.Count);
Assert.Equal(0, serviceProvider.GetRequiredService<IOptions<HubOptions<CustomHub>>>().Value.SupportedProtocols.Count);
Assert.Null(serviceProvider.GetRequiredService<IOptions<HubOptions>>().Value.HubFilters);
Assert.Single(serviceProvider.GetRequiredService<IOptions<HubOptions<CustomHub>>>().Value.HubFilters);
}
[Fact]
@ -105,6 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Equal(globalHubOptions.HandshakeTimeout, hubOptions.HandshakeTimeout);
Assert.Equal(globalHubOptions.SupportedProtocols, hubOptions.SupportedProtocols);
Assert.Equal(globalHubOptions.ClientTimeoutInterval, hubOptions.ClientTimeoutInterval);
Assert.Equal(globalHubOptions.MaximumParallelInvocationsPerClient, hubOptions.MaximumParallelInvocationsPerClient);
Assert.True(hubOptions.UserHasSetValues);
}
@ -138,6 +144,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
options.HandshakeTimeout = null;
options.SupportedProtocols = null;
options.ClientTimeoutInterval = TimeSpan.FromSeconds(1);
options.MaximumParallelInvocationsPerClient = 3;
});
var serviceProvider = serviceCollection.BuildServiceProvider();
@ -149,6 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Null(globalOptions.KeepAliveInterval);
Assert.Null(globalOptions.HandshakeTimeout);
Assert.Null(globalOptions.SupportedProtocols);
Assert.Equal(3, globalOptions.MaximumParallelInvocationsPerClient);
Assert.Equal(TimeSpan.FromSeconds(1), globalOptions.ClientTimeoutInterval);
}
@ -175,6 +183,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Equal("messagepack", p);
});
}
[Fact]
public void ThrowsIfSetInvalidValueForMaxInvokes()
{
Assert.Throws<ArgumentOutOfRangeException>(() => new HubOptions() { MaximumParallelInvocationsPerClient = 0 });
}
}
public class CustomHub : Hub
@ -333,6 +347,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
throw new NotImplementedException();
}
}
internal class CustomHubFilter : IHubFilter
{
public ValueTask<object> InvokeMethodAsync(HubInvocationContext invocationContext, Func<HubInvocationContext, ValueTask<object>> next)
{
throw new NotImplementedException();
}
}
}
namespace Microsoft.AspNetCore.SignalR.Internal

View File

@ -239,6 +239,22 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return results;
}
[Authorize("test")]
public async Task<List<object>> UploadArrayAuth(ChannelReader<object> source)
{
var results = new List<object>();
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
results.Add(item);
}
}
return results;
}
public async Task<string> TestTypeCastingErrors(ChannelReader<int> source)
{
try
@ -684,13 +700,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Channel.CreateUnbounded<string>().Reader;
}
public ChannelReader<int> ThrowStream()
public ChannelReader<int> ExceptionStream()
{
var channel = Channel.CreateUnbounded<int>();
channel.Writer.TryComplete(new Exception("Exception from channel"));
return channel.Reader;
}
public ChannelReader<int> ThrowStream()
{
throw new Exception("Throw from hub method");
}
public ChannelReader<int> NullStream()
{
return null;
}
public int NonStream()
{
return 42;
@ -1010,6 +1036,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return 21;
}
public async Task Upload(ChannelReader<string> stream)
{
_tcsService.StartedMethod.SetResult(null);
_ = await stream.ReadAndCollectAllAsync();
_tcsService.EndMethod.SetResult(null);
}
private class CustomAsyncEnumerable : IAsyncEnumerable<int>
{
private readonly TcsService _tcsService;

View File

@ -400,9 +400,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await client.Connection.Application.Output.WriteAsync(part3);
Assert.True(task.IsCompleted);
var completionMessage = await task as CompletionMessage;
var completionMessage = await task.OrTimeout() as CompletionMessage;
Assert.NotNull(completionMessage);
Assert.Equal("hello", completionMessage.Result);
Assert.Equal("1", completionMessage.InvocationId);
@ -2089,7 +2087,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await client.Connected.OrTimeout();
var messages = await client.StreamAsync(nameof(StreamingHub.ThrowStream));
var messages = await client.StreamAsync(nameof(StreamingHub.ExceptionStream));
Assert.Equal(1, messages.Count);
var completion = messages[0] as CompletionMessage;
@ -2923,7 +2921,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.Configure<HubOptions>(options =>
options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0));
{
options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0);
options.MaximumParallelInvocationsPerClient = 1;
});
services.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
@ -2963,6 +2964,42 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task HubMethodInvokeCountsTowardsClientTimeoutIfParallelNotMaxed()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.Configure<HubOptions>(options =>
{
options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0);
options.MaximumParallelInvocationsPerClient = 2;
});
services.AddSingleton(tcsService);
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient(new JsonHubProtocol()))
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
// This starts the timeout logic
await client.SendHubMessageAsync(PingMessage.Instance);
// Call long running hub method
var hubMethodTask = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod));
await tcsService.StartedMethod.Task.OrTimeout();
// Tick heartbeat while hub method is running
client.TickHeartbeat();
// Connection is closed
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task EndingConnectionSendsCloseMessageWithNoError()
{
@ -3040,7 +3077,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(null, LoggerFactory);
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options =>
{
options.MaximumParallelInvocationsPerClient = 1;
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
using (var client = new TestClient(new NewtonsoftJsonHubProtocol()))
@ -3062,7 +3105,119 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
[Fact]
public async Task InvocationsRunInOrder()
public async Task StreamMethodThatThrowsWillCleanup()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" &&
writeContext.EventId.Name == "FailedInvokingHubMethod";
}
using (StartVerifiableLog(ExpectedErrors))
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
var messages = await client.StreamAsync(nameof(StreamingHub.ThrowStream));
Assert.Equal(1, messages.Count);
var completion = messages[0] as CompletionMessage;
Assert.NotNull(completion);
var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>;
// OnConnectedAsync and ThrowStream hubs have been disposed
Assert.Equal(2, hubActivator.ReleaseCount);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task StreamMethodThatReturnsNullWillCleanup()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
var messages = await client.StreamAsync(nameof(StreamingHub.NullStream));
Assert.Equal(1, messages.Count);
var completion = messages[0] as CompletionMessage;
Assert.NotNull(completion);
var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>;
// OnConnectedAsync and NullStream hubs have been disposed
Assert.Equal(2, hubActivator.ReleaseCount);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task StreamMethodWithDuplicateIdFails()
{
using (StartVerifiableLog())
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
await client.Connected.OrTimeout();
await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty<object>())).OrTimeout();
await client.SendHubMessageAsync(new StreamInvocationMessage("123", nameof(StreamingHub.BlockingStream), Array.Empty<object>())).OrTimeout();
var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
Assert.Equal("Invocation ID '123' is already in use.", completion.Error);
var hubActivator = serviceProvider.GetService<IHubActivator<StreamingHub>>() as CustomHubActivator<StreamingHub>;
// OnConnectedAsync and BlockingStream hubs have been disposed
Assert.Equal(2, hubActivator.ReleaseCount);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task InvocationsRunInOrderWithNoParallelism()
{
using (StartVerifiableLog())
{
@ -3070,6 +3225,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
builder.AddSignalR(options =>
{
options.MaximumParallelInvocationsPerClient = 1;
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
@ -3112,7 +3272,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
[Fact]
public async Task StreamInvocationsBlockOtherInvocationsUntilTheyStartStreaming()
public async Task InvocationsCanRunOutOfOrderWithParallelism()
{
using (StartVerifiableLog())
{
@ -3120,7 +3280,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
builder.AddSignalR(options =>
{
options.MaximumParallelInvocationsPerClient = 2;
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
@ -3130,7 +3294,71 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
// Long running hub invocation to test that other invocations will not run until it is completed
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream), null).OrTimeout();
await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout();
// Wait for the long running method to start
await tcsService.StartedMethod.Task.OrTimeout();
for (var i = 0; i < 5; i++)
{
// Invoke another hub method which will wait for the first method to finish
await client.SendInvocationAsync(nameof(LongRunningHub.SimpleMethod), nonBlocking: false).OrTimeout();
// simple hub method result
var secondResult = await client.ReadAsync().OrTimeout();
var simpleCompletion = Assert.IsType<CompletionMessage>(secondResult);
Assert.Equal(21L, simpleCompletion.Result);
}
// Release the long running hub method
tcsService.EndMethod.TrySetResult(null);
// Long running hub method result
var firstResult = await client.ReadAsync().OrTimeout();
var longRunningCompletion = Assert.IsType<CompletionMessage>(firstResult);
Assert.Equal(12L, longRunningCompletion.Result);
// Shut down
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task PendingInvocationUnblockedWhenBlockingMethodCompletesWithParallelism()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
builder.AddSignalR(options =>
{
options.MaximumParallelInvocationsPerClient = 2;
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
// Because we use PipeScheduler.Inline the hub invocations will run inline until they wait, which happens inside the LongRunningMethod call
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
// Long running hub invocation to test that other invocations will not run until it is completed
await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout();
// Wait for the long running method to start
await tcsService.StartedMethod.Task.OrTimeout();
// Grab the tcs before resetting to use in the second long running method
var endTcs = tcsService.EndMethod;
tcsService.Reset();
// Long running hub invocation to test that other invocations will not run until it is completed
await client.SendInvocationAsync(nameof(LongRunningHub.LongRunningMethod), nonBlocking: false).OrTimeout();
// Wait for the long running method to start
await tcsService.StartedMethod.Task.OrTimeout();
@ -3139,21 +3367,79 @@ namespace Microsoft.AspNetCore.SignalR.Tests
// Both invocations should be waiting now
Assert.Null(client.TryRead());
// Release the long running hub method
// Release the second long running hub method
tcsService.EndMethod.TrySetResult(null);
// simple hub method result
var result = await client.ReadAsync().OrTimeout();
// Long running hub method result
var firstResult = await client.ReadAsync().OrTimeout();
var simpleCompletion = Assert.IsType<CompletionMessage>(result);
var longRunningCompletion = Assert.IsType<CompletionMessage>(firstResult);
Assert.Equal(12L, longRunningCompletion.Result);
// simple hub method result
var secondResult = await client.ReadAsync().OrTimeout();
var simpleCompletion = Assert.IsType<CompletionMessage>(secondResult);
Assert.Equal(21L, simpleCompletion.Result);
// Release the first long running hub method
endTcs.TrySetResult(null);
firstResult = await client.ReadAsync().OrTimeout();
longRunningCompletion = Assert.IsType<CompletionMessage>(firstResult);
Assert.Equal(12L, longRunningCompletion.Result);
// Shut down
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
}
[Fact]
public async Task StreamInvocationsDoNotBlockOtherInvocations()
{
using (StartVerifiableLog())
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
builder.AddSignalR(options =>
{
options.MaximumParallelInvocationsPerClient = 1;
});
}, LoggerFactory);
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
// Because we use PipeScheduler.Inline the hub invocations will run inline until they go async
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
// Long running stream invocation to test that other invocations can still run before it is completed
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream), null).OrTimeout();
// Wait for the long running method to start
await tcsService.StartedMethod.Task.OrTimeout();
// Invoke another hub method which will be able to run even though a streaming method is still running
var completion = await client.InvokeAsync(nameof(LongRunningHub.SimpleMethod)).OrTimeout();
Assert.Null(completion.Error);
Assert.Equal(21L, completion.Result);
// Release the long running hub method
tcsService.EndMethod.TrySetResult(null);
var hubActivator = serviceProvider.GetService<IHubActivator<LongRunningHub>>() as CustomHubActivator<LongRunningHub>;
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();
// Completion message for canceled Stream
await client.ReadAsync().OrTimeout();
completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
Assert.Equal(streamInvocationId, completion.InvocationId);
// Shut down
client.Dispose();
@ -3319,6 +3605,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
private class DelayRequirement : AuthorizationHandler<DelayRequirement, HubInvocationContext>, IAuthorizationRequirement
{
private readonly TcsService _tcsService;
public DelayRequirement(TcsService tcsService)
{
_tcsService = tcsService;
}
protected override async Task HandleRequirementAsync(AuthorizationHandlerContext context, DelayRequirement requirement, HubInvocationContext resource)
{
_tcsService.StartedMethod.SetResult(null);
await _tcsService.EndMethod.Task;
context.Succeed(requirement);
}
}
[Fact]
// Test to check if StreamItems can be processed before the Stream from the invocation is properly registered internally
public async Task UploadStreamStreamItemsSentAsSoonAsPossible()
{
// Use Auth as the delay injection point because it is one of the first things to run after the invocation message has been parsed
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.Requirements.Add(new DelayRequirement(tcsService));
});
});
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.UploadArrayAuth), new[] { "id" }, Array.Empty<object>());
await tcsService.StartedMethod.Task.OrTimeout();
var objects = new[] { new SampleObject("solo", 322), new SampleObject("ggez", 3145) };
foreach (var thing in objects)
{
await client.SendHubMessageAsync(new StreamItemMessage("id", thing)).OrTimeout();
}
tcsService.EndMethod.SetResult(null);
await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout();
var response = (CompletionMessage)await client.ReadAsync().OrTimeout();
var result = ((JArray)response.Result).ToArray<object>();
Assert.Equal(objects[0].Foo, ((JContainer)result[0])["foo"]);
Assert.Equal(objects[0].Bar, ((JContainer)result[0])["bar"]);
Assert.Equal(objects[1].Foo, ((JContainer)result[1])["foo"]);
Assert.Equal(objects[1].Bar, ((JContainer)result[1])["bar"]);
}
}
[Fact]
public async Task UploadStreamDoesNotCountTowardsMaxInvocationLimit()
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options => options.MaximumParallelInvocationsPerClient = 1);
services.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocation", nameof(LongRunningHub.Upload), new[] { "id" }, Array.Empty<object>());
await tcsService.StartedMethod.Task.OrTimeout();
var completion = await client.InvokeAsync(nameof(LongRunningHub.SimpleMethod)).OrTimeout();
Assert.Null(completion.Error);
Assert.Equal(21L, completion.Result);
await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout();
await tcsService.EndMethod.Task.OrTimeout();
var response = (CompletionMessage)await client.ReadAsync().OrTimeout();
Assert.Null(response.Result);
Assert.Null(response.Error);
}
}
[Fact]
public async Task ConnectionAbortedIfSendFailsWithProtocolError()
{