Remove unnecessary state machines in DefautHubDispatcher (#2167)

* Remove unnecessary state machines in DefautHubDispatcher
- Remove state machines for async tail calls
- Added fast path for auth check when there are no policies
This commit is contained in:
David Fowler 2018-04-30 22:41:32 -07:00 committed by GitHub
parent c9746d43c9
commit f3ba1babfc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 13 deletions

View File

@ -73,7 +73,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
public override async Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
{
// Messages are dispatched sequentially and will stop other messages from being processed until they complete.
// Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run.
@ -81,18 +81,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal
switch (hubMessage)
{
case InvocationBindingFailureMessage bindingFailureMessage:
await ProcessBindingFailure(connection, bindingFailureMessage);
break;
return ProcessBindingFailure(connection, bindingFailureMessage);
case InvocationMessage invocationMessage:
Log.ReceivedHubInvocation(_logger, invocationMessage);
await ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false);
break;
return ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false);
case StreamInvocationMessage streamInvocationMessage:
Log.ReceivedStreamHubInvocation(_logger, streamInvocationMessage);
await ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true);
break;
return ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true);
case CancelInvocationMessage cancelInvocationMessage:
// Check if there is an associated active stream and cancel it if it exists.
@ -118,6 +115,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName);
throw new NotSupportedException($"Received unsupported message: {hubMessage}");
}
return Task.CompletedTask;
}
private Task ProcessBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage)
@ -142,19 +141,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return descriptor.ParameterTypes;
}
private async Task ProcessInvocation(HubConnectionContext connection,
private Task ProcessInvocation(HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamedInvocation)
{
if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor))
{
// Send an error to the client. Then let the normal completion process occur
Log.UnknownHubMethod(_logger, hubMethodInvocationMessage.Target);
await connection.WriteAsync(CompletionMessage.WithError(
hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'"));
return connection.WriteAsync(CompletionMessage.WithError(
hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask();
}
else
{
await Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamedInvocation);
return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamedInvocation);
}
}
@ -323,14 +322,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal
hub.Groups = _hubContext.Groups;
}
private async Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
{
// If there are no policies we don't need to run auth
if (!policies.Any())
{
return true;
return TaskCache.True;
}
return IsHubMethodAuthorizedSlow(provider, principal, policies);
}
private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
{
var authService = provider.GetRequiredService<IAuthorizationService>();
var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal
{
internal static class TaskCache
{
public static readonly Task<bool> True = Task.FromResult(true);
public static readonly Task<bool> False = Task.FromResult(false);
}
}