From f3ba1babfcee37a96a25579f5b7e208e78b4658c Mon Sep 17 00:00:00 2001 From: David Fowler Date: Mon, 30 Apr 2018 22:41:32 -0700 Subject: [PATCH] Remove unnecessary state machines in DefautHubDispatcher (#2167) * Remove unnecessary state machines in DefautHubDispatcher - Remove state machines for async tail calls - Added fast path for auth check when there are no policies --- .../Internal/DefaultHubDispatcher.cs | 30 +++++++++++-------- .../Internal/TaskCache.cs | 13 ++++++++ 2 files changed, 30 insertions(+), 13 deletions(-) create mode 100644 src/Microsoft.AspNetCore.SignalR.Core/Internal/TaskCache.cs diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index ec60f108e5..51468349ea 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -73,7 +73,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal } } - public override async Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage) + public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage) { // Messages are dispatched sequentially and will stop other messages from being processed until they complete. // Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run. @@ -81,18 +81,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal switch (hubMessage) { case InvocationBindingFailureMessage bindingFailureMessage: - await ProcessBindingFailure(connection, bindingFailureMessage); - break; + return ProcessBindingFailure(connection, bindingFailureMessage); case InvocationMessage invocationMessage: Log.ReceivedHubInvocation(_logger, invocationMessage); - await ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false); - break; + return ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false); case StreamInvocationMessage streamInvocationMessage: Log.ReceivedStreamHubInvocation(_logger, streamInvocationMessage); - await ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true); - break; + return ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true); case CancelInvocationMessage cancelInvocationMessage: // Check if there is an associated active stream and cancel it if it exists. @@ -118,6 +115,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName); throw new NotSupportedException($"Received unsupported message: {hubMessage}"); } + + return Task.CompletedTask; } private Task ProcessBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage) @@ -142,19 +141,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal return descriptor.ParameterTypes; } - private async Task ProcessInvocation(HubConnectionContext connection, + private Task ProcessInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamedInvocation) { if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor)) { // Send an error to the client. Then let the normal completion process occur Log.UnknownHubMethod(_logger, hubMethodInvocationMessage.Target); - await connection.WriteAsync(CompletionMessage.WithError( - hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")); + return connection.WriteAsync(CompletionMessage.WithError( + hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask(); } else { - await Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamedInvocation); + return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamedInvocation); } } @@ -323,14 +322,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal hub.Groups = _hubContext.Groups; } - private async Task IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList policies) + private Task IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList policies) { // If there are no policies we don't need to run auth if (!policies.Any()) { - return true; + return TaskCache.True; } + return IsHubMethodAuthorizedSlow(provider, principal, policies); + } + + private static async Task IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList policies) + { var authService = provider.GetRequiredService(); var policyProvider = provider.GetRequiredService(); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/TaskCache.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/TaskCache.cs new file mode 100644 index 0000000000..e11f53506e --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/TaskCache.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Internal +{ + internal static class TaskCache + { + public static readonly Task True = Task.FromResult(true); + public static readonly Task False = Task.FromResult(false); + } +} \ No newline at end of file