// aspnetcore/src/Microsoft.AspNetCore.Signal.../Internal/DefaultHubDispatcher.cs
// (506 lines, 23 KiB, C#)

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Security.Claims;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> where THub : Hub
{
    // Hub methods keyed by the name clients invoke (HubMethodNameAttribute value or the CLR method name),
    // matched case-insensitively.
    private readonly Dictionary<string, HubMethodDescriptor> _methods = new Dictionary<string, HubMethodDescriptor>(StringComparer.OrdinalIgnoreCase);
    private readonly IServiceScopeFactory _serviceScopeFactory;
    private readonly IHubContext<THub> _hubContext;
    private readonly ILogger<HubDispatcher<THub>> _logger;
    private readonly bool _enableDetailedErrors;

    /// <summary>
    /// Creates a dispatcher for <typeparamref name="THub"/> and eagerly discovers its hub methods.
    /// </summary>
    /// <param name="serviceScopeFactory">Creates the DI scope each hub instance is activated from.</param>
    /// <param name="hubContext">Supplies the client proxies and group manager assigned to every hub instance.</param>
    /// <param name="hubOptions">Hub-specific options; its <c>EnableDetailedErrors</c> takes precedence.</param>
    /// <param name="globalHubOptions">Global options used when the hub-specific value is not set.</param>
    /// <param name="logger">Logger for dispatch diagnostics.</param>
    /// <exception cref="NotSupportedException">Two hub methods resolve to the same name.</exception>
    public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, IOptions<HubOptions<THub>> hubOptions,
        IOptions<HubOptions> globalHubOptions, ILogger<DefaultHubDispatcher<THub>> logger)
    {
        _serviceScopeFactory = serviceScopeFactory;
        _hubContext = hubContext;
        // Per-hub setting wins, then the global setting; exception details are hidden unless explicitly enabled.
        _enableDetailedErrors = hubOptions.Value.EnableDetailedErrors ?? globalHubOptions.Value.EnableDetailedErrors ?? false;
        // The logger must be assigned before discovery because DiscoverHubMethods logs each bound method.
        _logger = logger;
        DiscoverHubMethods();
    }

    /// <summary>
    /// Activates a hub in a new DI scope and runs its <c>OnConnectedAsync</c> for <paramref name="connection"/>.
    /// The hub is released and the scope disposed even if the hook throws.
    /// </summary>
    public override async Task OnConnectedAsync(HubConnectionContext connection)
    {
        using (var scope = _serviceScopeFactory.CreateScope())
        {
            var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
            var hub = hubActivator.Create();
            try
            {
                InitializeHub(hub, connection);
                await hub.OnConnectedAsync();
            }
            finally
            {
                hubActivator.Release(hub);
            }
        }
    }

    /// <summary>
    /// Activates a hub in a new DI scope and runs its <c>OnDisconnectedAsync</c> for <paramref name="connection"/>.
    /// The hub is released and the scope disposed even if the hook throws.
    /// </summary>
    /// <param name="connection">The connection that is going away.</param>
    /// <param name="exception">Passed through unchanged to the hub's <c>OnDisconnectedAsync</c>.</param>
    public override async Task OnDisconnectedAsync(HubConnectionContext connection, Exception exception)
    {
        using (var scope = _serviceScopeFactory.CreateScope())
        {
            var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
            var hub = hubActivator.Create();
            try
            {
                InitializeHub(hub, connection);
                await hub.OnDisconnectedAsync(exception);
            }
            finally
            {
                hubActivator.Release(hub);
            }
        }
    }

    /// <summary>
    /// Routes one incoming protocol message to the matching handler.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// The message type is not handled here, or a streaming invocation targets a method with streaming parameters.
    /// </exception>
    public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
    {
        // Messages are dispatched sequentially and will stop other messages from being processed until they complete.
        // Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run.
        switch (hubMessage)
        {
            case InvocationBindingFailureMessage bindingFailureMessage:
                return ProcessInvocationBindingFailure(connection, bindingFailureMessage);

            case StreamBindingFailureMessage bindingFailureMessage:
                return ProcessStreamBindingFailure(connection, bindingFailureMessage);

            case InvocationMessage invocationMessage:
                Log.ReceivedHubInvocation(_logger, invocationMessage);
                return ProcessInvocation(connection, invocationMessage, isStreamResponse: false);

            case StreamInvocationMessage streamInvocationMessage:
                Log.ReceivedStreamHubInvocation(_logger, streamInvocationMessage);
                return ProcessInvocation(connection, streamInvocationMessage, isStreamResponse: true);

            case CancelInvocationMessage cancelInvocationMessage:
                // Check if there is an associated active stream and cancel it if it exists.
                // The cts will be removed when the streaming method completes executing
                if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts))
                {
                    Log.CancelStream(_logger, cancelInvocationMessage.InvocationId);
                    cts.Cancel();
                }
                else
                {
                    // Stream can be canceled on the server while client is canceling stream.
                    Log.UnexpectedCancel(_logger);
                }
                break;

            case PingMessage _:
                connection.StartClientTimeout();
                break;

            case StreamDataMessage streamItem:
                // ProcessStreamItem logs the item itself; logging here as well would record it twice.
                return ProcessStreamItem(connection, streamItem);

            case StreamCompleteMessage streamCompleteMessage:
                // closes channels, removes from Lookup dict
                // user's method can see the channel is complete and begin wrapping up
                Log.CompletingStream(_logger, streamCompleteMessage);
                connection.StreamTracker.Complete(streamCompleteMessage);
                break;

            // Other kind of message we weren't expecting
            default:
                Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName);
                throw new NotSupportedException($"Received unsupported message: {hubMessage}");
        }

        return Task.CompletedTask;
    }

    /// <summary>
    /// Logs an argument-binding failure for an invocation and, when the caller expects a reply
    /// (non-empty invocation ID), sends it an error completion.
    /// </summary>
    private Task ProcessInvocationBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage)
    {
        Log.FailedInvokingHubMethod(_logger, bindingFailureMessage.Target, bindingFailureMessage.BindingFailure.SourceException);

        var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.",
            bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
        return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage);
    }

    /// <summary>
    /// Completes the client-to-server stream whose item failed to bind, passing the binding error
    /// to the stream tracker.
    /// </summary>
    private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage)
    {
        var errorString = ErrorMessageHelper.BuildErrorMessage(
            $"Failed to bind Stream Item arguments to proper type.",
            bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);

        var message = new StreamCompleteMessage(bindingFailureMessage.Id, errorString);
        Log.ClosingStreamWithBindingError(_logger, message);
        connection.StreamTracker.Complete(message);

        return Task.CompletedTask;
    }

    /// <summary>
    /// Logs a received client-to-server stream item and hands it to the connection's stream tracker.
    /// </summary>
    private Task ProcessStreamItem(HubConnectionContext connection, StreamDataMessage message)
    {
        Log.ReceivedStreamItem(_logger, message);
        return connection.StreamTracker.ProcessItem(message);
    }

    /// <summary>
    /// Resolves the invocation target and starts the invocation.
    /// Unknown targets are answered with an error completion instead of an exception.
    /// </summary>
    /// <exception cref="NotSupportedException">
    /// A streaming response was requested from a method that has streaming parameters.
    /// </exception>
    private Task ProcessInvocation(HubConnectionContext connection,
        HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse)
    {
        if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor))
        {
            // Send an error to the client. Then let the normal completion process occur
            Log.UnknownHubMethod(_logger, hubMethodInvocationMessage.Target);
            return connection.WriteAsync(CompletionMessage.WithError(
                hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask();
        }

        var isStreamCall = descriptor.HasStreamingParameters;
        if (isStreamResponse && isStreamCall)
        {
            throw new NotSupportedException("Streaming responses for streaming uploads are not supported.");
        }

        return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
    }

    /// <summary>
    /// Authorizes, validates and executes one hub method invocation in its own DI scope.
    /// </summary>
    /// <remarks>
    /// Ordinary invocations are awaited, and the hub and scope are cleaned up here.
    /// Streaming responses and methods with streaming parameters keep running in the background, so the
    /// dispatcher can process further messages. In those cases ownership of the scope and hub passes to
    /// <see cref="StreamResultsAsync"/> or <see cref="CompleteStreamingUploadAsync"/>, which perform the
    /// cleanup when the work finishes.
    /// </remarks>
    private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
        HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall)
    {
        var methodExecutor = descriptor.MethodExecutor;

        // Flipped to false when a background task takes ownership of the scope and hub.
        var disposeScope = true;
        var scope = _serviceScopeFactory.CreateScope();
        IHubActivator<THub> hubActivator = null;
        THub hub = null;

        try
        {
            if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
            {
                Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
                await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                    $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized");
                return;
            }

            if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection))
            {
                return;
            }

            hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
            hub = hubActivator.Create();

            if (isStreamCall)
            {
                ReplaceStreamPlaceholders(connection, methodExecutor, hubMethodInvocationMessage.Arguments);
            }

            try
            {
                InitializeHub(hub, connection);

                if (isStreamResponse)
                {
                    var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);

                    if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
                    {
                        Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
                        await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                            $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
                        return;
                    }

                    Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);

                    // Not awaited: results stream in the background. StreamResultsAsync now owns the scope and hub.
                    disposeScope = false;
                    _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
                    return;
                }

                var invocation = ExecuteInvocationAsync(methodExecutor, hub, connection, hubMethodInvocationMessage);

                if (isStreamCall)
                {
                    // Not awaited: the method keeps reading its stream arguments while the dispatcher delivers
                    // stream items. The background task reports failures and releases the hub and scope
                    // once the method completes; previously they were never released.
                    disposeScope = false;
                    _ = CompleteStreamingUploadAsync(invocation, connection, hubMethodInvocationMessage, scope, hubActivator, hub);
                }
                else
                {
                    // complete the non-streaming calls now
                    await invocation;
                }
            }
            catch (Exception ex)
            {
                await SendInvocationFailure(connection, hubMethodInvocationMessage, ex);
            }
        }
        finally
        {
            if (disposeScope)
            {
                hubActivator?.Release(hub);
                scope.Dispose();
            }
        }
    }

    /// <summary>
    /// Swaps every <c>StreamPlaceholder</c> argument for a stream registered with the connection's stream tracker.
    /// </summary>
    private void ReplaceStreamPlaceholders(HubConnectionContext connection, ObjectMethodExecutor methodExecutor, object[] arguments)
    {
        for (var i = 0; i < arguments.Length; i++)
        {
            var placeholder = arguments[i] as StreamPlaceholder;
            if (placeholder == null)
            {
                continue;
            }

            Log.StartingParameterStream(_logger, placeholder.StreamId);
            // The first generic argument of the parameter type is taken as the item type
            // (presumably ChannelReader<T>; a non-generic parameter here would throw).
            var itemType = methodExecutor.MethodParameters[i].ParameterType.GetGenericArguments()[0];
            arguments[i] = connection.StreamTracker.AddStream(placeholder.StreamId, itemType);
        }
    }

    /// <summary>
    /// Runs a non-streaming-response hub method.
    /// When the call has an invocation ID, the method's result is then written back as a completion.
    /// A Send (no invocation ID) gets no response.
    /// </summary>
    private async Task ExecuteInvocationAsync(ObjectMethodExecutor methodExecutor, THub hub,
        HubConnectionContext connection, HubMethodInvocationMessage hubMethodInvocationMessage)
    {
        var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);

        if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
        {
            // Send Async, no response expected
            return;
        }

        // Invoke Async, one response expected
        Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
        await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
    }

    /// <summary>
    /// Awaits a background invocation of a method with streaming parameters.
    /// Any failure is reported to the caller, and the hub and DI scope are always released afterwards.
    /// </summary>
    private async Task CompleteStreamingUploadAsync(Task invocation, HubConnectionContext connection,
        HubMethodInvocationMessage hubMethodInvocationMessage, IServiceScope scope, IHubActivator<THub> hubActivator, THub hub)
    {
        try
        {
            await invocation;
        }
        catch (Exception ex)
        {
            await SendInvocationFailure(connection, hubMethodInvocationMessage, ex);
        }
        finally
        {
            hubActivator.Release(hub);
            scope.Dispose();
        }
    }

    /// <summary>
    /// Logs a failed hub method invocation and, when the caller expects a reply, sends it an error completion.
    /// </summary>
    private Task SendInvocationFailure(HubConnectionContext connection, HubMethodInvocationMessage hubMethodInvocationMessage, Exception ex)
    {
        Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);

        // A TargetInvocationException wraps the exception thrown by the invoked method; describe that one instead.
        var reportedException = ex is TargetInvocationException targetInvocationException && targetInvocationException.InnerException != null
            ? targetInvocationException.InnerException
            : ex;

        return SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
            ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", reportedException, _enableDetailedErrors));
    }

    /// <summary>
    /// Pumps items from a streaming hub method to the client as <c>StreamItemMessage</c>s.
    /// When the stream ends it sends a final completion carrying <c>error</c>, which stays null on normal
    /// completion and on client-requested cancellation.
    /// </summary>
    /// <remarks>
    /// Owns and disposes <paramref name="scope"/> and <paramref name="streamCts"/>, releases
    /// <paramref name="hub"/>, and removes and disposes the invocation's cancellation source.
    /// </remarks>
    private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
        IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts)
    {
        string error = null;

        using (scope)
        {
            try
            {
                while (await enumerator.MoveNextAsync())
                {
                    // Send the stream item
                    await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
                }
            }
            catch (ChannelClosedException ex)
            {
                // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method
                error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex.InnerException ?? ex, _enableDetailedErrors);
            }
            catch (Exception ex)
            {
                // If the streaming method was canceled we don't want to send a HubException message - this is not an error case
                if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts)
                    && cts.IsCancellationRequested))
                {
                    error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors);
                }
            }
            finally
            {
                (enumerator as IDisposable)?.Dispose();

                hubActivator.Release(hub);

                // Dispose the linked CTS for the stream.
                streamCts.Dispose();

                try
                {
                    await connection.WriteAsync(CompletionMessage.WithError(invocationId, error));
                }
                finally
                {
                    // Always unregister the per-invocation CTS, even if the final write fails,
                    // so it does not linger in ActiveRequestCancellationSources.
                    if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts))
                    {
                        cts.Dispose();
                    }
                }
            }
        }
    }

    /// <summary>
    /// Executes a hub method through <paramref name="methodExecutor"/>.
    /// </summary>
    /// <returns>
    /// The method's result. Returns null for methods returning plain <see cref="Task"/>.
    /// For other async methods, returns the awaited result.
    /// </returns>
    private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
    {
        if (methodExecutor.IsMethodAsync)
        {
            if (methodExecutor.MethodReturnType == typeof(Task))
            {
                await (Task)methodExecutor.Execute(hub, arguments);
            }
            else
            {
                return await methodExecutor.ExecuteAsync(hub, arguments);
            }
        }
        else
        {
            return methodExecutor.Execute(hub, arguments);
        }

        return null;
    }

    /// <summary>
    /// Sends an error completion for <paramref name="invocationId"/>.
    /// Does nothing when the ID is null or empty, because the caller expects no response.
    /// </summary>
    private async Task SendInvocationError(string invocationId,
        HubConnectionContext connection, string errorMessage)
    {
        if (string.IsNullOrEmpty(invocationId))
        {
            return;
        }

        await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage));
    }

    /// <summary>
    /// Binds a freshly activated hub to <paramref name="connection"/>: client proxies, caller context and groups.
    /// </summary>
    private void InitializeHub(THub hub, HubConnectionContext connection)
    {
        hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
        hub.Context = new DefaultHubCallerContext(connection);
        hub.Groups = _hubContext.Groups;
    }

    /// <summary>
    /// Checks <paramref name="principal"/> against the method's authorization data.
    /// Completes synchronously with true when there is none.
    /// </summary>
    private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
    {
        // If there are no policies we don't need to run auth
        if (policies.Count == 0)
        {
            return TaskCache.True;
        }

        return IsHubMethodAuthorizedSlow(provider, principal, policies);
    }

    /// <summary>
    /// Combines <paramref name="policies"/> into one policy and evaluates it with the scoped <c>IAuthorizationService</c>.
    /// </summary>
    private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
    {
        var authService = provider.GetRequiredService<IAuthorizationService>();
        var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();

        var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
        // AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
        Debug.Assert(authorizePolicy != null);

        var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
        // Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
        return authorizationResult.Succeeded;
    }

    /// <summary>
    /// Checks that the invocation style matches the method: streaming methods need a stream invocation,
    /// and non-streaming methods need a regular one.
    /// On a mismatch, writes an error completion where a reply is expected and returns false.
    /// </summary>
    private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
        HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
    {
        if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
        {
            // Non-null/empty InvocationId? Blocking
            if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
            {
                Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage);
                await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
                    $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation."));
            }

            return false;
        }

        if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
        {
            Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
            await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
                $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation."));

            return false;
        }

        return true;
    }

    /// <summary>
    /// Wraps a channel returned by a streaming hub method in an async enumerator.
    /// </summary>
    /// <remarks>
    /// On success, <paramref name="streamCts"/> is linked to both the connection-aborted token and a per-invocation
    /// CTS. That per-invocation CTS is registered so a <c>CancelInvocationMessage</c> can cancel it.
    /// </remarks>
    /// <returns>True only when <paramref name="result"/> is non-null and the descriptor reports a channel.</returns>
    private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
    {
        if (result != null && hubMethodDescriptor.IsChannel)
        {
            streamCts = CreateCancellation();
            enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
            return true;
        }

        streamCts = null;
        enumerator = null;
        return false;

        CancellationTokenSource CreateCancellation()
        {
            var userCts = new CancellationTokenSource();
            // NOTE(review): the TryAdd result is ignored. If a client reuses an in-flight invocation ID, this CTS
            // is never registered, and the StreamResultsAsync call for this stream later removes and disposes
            // the other stream's CTS. Consider rejecting duplicate IDs.
            connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);

            return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
        }
    }

    /// <summary>
    /// Builds the method table for <typeparamref name="THub"/>, keyed by <c>HubMethodNameAttribute</c> or the method name.
    /// </summary>
    /// <exception cref="NotSupportedException">Two methods map to the same name (overloads are not supported).</exception>
    private void DiscoverHubMethods()
    {
        var hubType = typeof(THub);
        var hubTypeInfo = hubType.GetTypeInfo();
        var hubName = hubType.Name;

        foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
        {
            var methodName =
                methodInfo.GetCustomAttribute<HubMethodNameAttribute>()?.Name ??
                methodInfo.Name;

            if (_methods.ContainsKey(methodName))
            {
                throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported.");
            }

            var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
            var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
            _methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes);

            Log.HubMethodBound(_logger, hubName, methodName);
        }
    }

    /// <summary>
    /// Returns the parameter types of the hub method registered under <paramref name="methodName"/>.
    /// The lookup is case-insensitive.
    /// </summary>
    /// <exception cref="HubException">No method with that name exists.</exception>
    public override IReadOnlyList<Type> GetParameterTypes(string methodName)
    {
        if (!_methods.TryGetValue(methodName, out var descriptor))
        {
            throw new HubException("Method does not exist.");
        }

        return descriptor.ParameterTypes;
    }
}
}