// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Reflection;
using System.Security.Claims;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.SignalR.Internal
{
    /// <summary>
    /// Default hub dispatcher: routes hub protocol messages received on a connection to methods of
    /// a <typeparamref name="THub"/> instance activated in a fresh DI scope for each call.
    /// </summary>
    /// <typeparam name="THub">The hub type whose methods are exposed to clients.</typeparam>
    public partial class DefaultHubDispatcher<THub> : HubDispatcher<THub> where THub : Hub
    {
        // Hub methods keyed by their exposed name (HubMethodName attribute or CLR name); lookups ignore case.
        private readonly Dictionary<string, HubMethodDescriptor> _methods = new Dictionary<string, HubMethodDescriptor>(StringComparer.OrdinalIgnoreCase);
        private readonly IServiceScopeFactory _serviceScopeFactory;
        private readonly IHubContext<THub> _hubContext;
        private readonly ILogger<HubDispatcher<THub>> _logger;
        private readonly bool _enableDetailedErrors;

        public DefaultHubDispatcher(IServiceScopeFactory serviceScopeFactory, IHubContext<THub> hubContext, IOptions<HubOptions<THub>> hubOptions,
            IOptions<HubOptions> globalHubOptions, ILogger<DefaultHubDispatcher<THub>> logger)
        {
            _serviceScopeFactory = serviceScopeFactory;
            _hubContext = hubContext;
            // The per-hub setting takes precedence over the global one; if neither is set, detailed errors are off.
            _enableDetailedErrors = hubOptions.Value.EnableDetailedErrors ?? globalHubOptions.Value.EnableDetailedErrors ?? false;
            _logger = logger;
            DiscoverHubMethods();
        }

        /// <summary>
        /// Activates a hub in a new DI scope and runs its OnConnectedAsync callback for <paramref name="connection"/>.
        /// </summary>
        public override async Task OnConnectedAsync(HubConnectionContext connection)
        {
            using (var scope = _serviceScopeFactory.CreateScope())
            {
                var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
                var hub = hubActivator.Create();
                try
                {
                    InitializeHub(hub, connection);
                    await hub.OnConnectedAsync();
                }
                finally
                {
                    hubActivator.Release(hub);
                }
            }
        }

        /// <summary>
        /// Activates a hub in a new DI scope and runs its OnDisconnectedAsync callback with the
        /// exception (if any) that ended the connection.
        /// </summary>
        public override async Task OnDisconnectedAsync(HubConnectionContext connection, Exception exception)
        {
            using (var scope = _serviceScopeFactory.CreateScope())
            {
                var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
                var hub = hubActivator.Create();
                try
                {
                    InitializeHub(hub, connection);
                    await hub.OnDisconnectedAsync(exception);
                }
                finally
                {
                    hubActivator.Release(hub);
                }
            }
        }

        /// <summary>
        /// Routes a single incoming hub message to the appropriate handler.
        /// </summary>
        /// <exception cref="NotSupportedException">The message type is not one this dispatcher handles.</exception>
        public override Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
        {
            // Messages are dispatched sequentially and will stop other messages from being processed until they complete.
            // Streaming methods will run sequentially until they start streaming, then they will fire-and-forget allowing other messages to run.
            switch (hubMessage)
            {
                case InvocationBindingFailureMessage bindingFailureMessage:
                    return ProcessInvocationBindingFailure(connection, bindingFailureMessage);

                case StreamBindingFailureMessage bindingFailureMessage:
                    return ProcessStreamBindingFailure(connection, bindingFailureMessage);

                case InvocationMessage invocationMessage:
                    Log.ReceivedHubInvocation(_logger, invocationMessage);
                    return ProcessInvocation(connection, invocationMessage, isStreamResponse: false);

                case StreamInvocationMessage streamInvocationMessage:
                    Log.ReceivedStreamHubInvocation(_logger, streamInvocationMessage);
                    return ProcessInvocation(connection, streamInvocationMessage, isStreamResponse: true);

                case CancelInvocationMessage cancelInvocationMessage:
                    // Check if there is an associated active stream and cancel it if it exists.
                    // The cts will be removed when the streaming method completes executing
                    if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts))
                    {
                        Log.CancelStream(_logger, cancelInvocationMessage.InvocationId);
                        cts.Cancel();
                    }
                    else
                    {
                        // Stream can be canceled on the server while client is canceling stream.
                        Log.UnexpectedCancel(_logger);
                    }
                    break;

                case PingMessage _:
                    connection.StartClientTimeout();
                    break;

                case StreamDataMessage streamItem:
                    // Logged here (like the invocation cases above); ProcessStreamItem does not log again.
                    Log.ReceivedStreamItem(_logger, streamItem);
                    return ProcessStreamItem(connection, streamItem);

                case StreamCompleteMessage streamCompleteMessage:
                    // closes channels, removes from Lookup dict
                    // user's method can see the channel is complete and begin wrapping up
                    Log.CompletingStream(_logger, streamCompleteMessage);
                    connection.StreamTracker.Complete(streamCompleteMessage);
                    break;

                // Other kind of message we weren't expecting
                default:
                    Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName);
                    throw new NotSupportedException($"Received unsupported message: {hubMessage}");
            }

            return Task.CompletedTask;
        }

        /// <summary>
        /// Logs an invocation whose arguments could not be bound, and sends an error completion to the caller
        /// (when the invocation expects one).
        /// </summary>
        private Task ProcessInvocationBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage)
        {
            Log.FailedInvokingHubMethod(_logger, bindingFailureMessage.Target, bindingFailureMessage.BindingFailure.SourceException);

            var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.",
                bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
            return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage);
        }

        /// <summary>
        /// Completes an upload stream with an error when one of its items could not be bound to the expected type.
        /// </summary>
        private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage)
        {
            var errorString = ErrorMessageHelper.BuildErrorMessage(
                "Failed to bind Stream Item arguments to proper type.",
                bindingFailureMessage.BindingFailure.SourceException,
                _enableDetailedErrors);

            var message = new StreamCompleteMessage(bindingFailureMessage.Id, errorString);
            Log.ClosingStreamWithBindingError(_logger, message);
            connection.StreamTracker.Complete(message);

            return Task.CompletedTask;
        }

        /// <summary>
        /// Forwards an uploaded stream item to the connection's stream tracker.
        /// Logging is done by the caller (DispatchMessageAsync).
        /// </summary>
        private Task ProcessStreamItem(HubConnectionContext connection, StreamDataMessage message)
        {
            return connection.StreamTracker.ProcessItem(message);
        }

        /// <summary>
        /// Resolves the target hub method and invokes it, or reports an unknown method back to the client.
        /// </summary>
        /// <exception cref="NotSupportedException">
        /// The client requested a streaming response from a method that also takes streaming parameters.
        /// </exception>
        private Task ProcessInvocation(HubConnectionContext connection,
            HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse)
        {
            if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor))
            {
                // Send an error to the client. Then let the normal completion process occur
                Log.UnknownHubMethod(_logger, hubMethodInvocationMessage.Target);
                return connection.WriteAsync(CompletionMessage.WithError(
                    hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'")).AsTask();
            }
            else
            {
                bool isStreamCall = descriptor.HasStreamingParameters;
                if (isStreamResponse && isStreamCall)
                {
                    throw new NotSupportedException("Streaming responses for streaming uploads are not supported.");
                }
                return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
            }
        }

        /// <summary>
        /// Authorizes, validates and executes a hub method invocation.
        /// </summary>
        /// <remarks>
        /// Non-streaming invocations are awaited here and their scope/hub released before returning.
        /// Streaming responses hand the scope and hub to <c>StreamResultsAsync</c>; streaming uploads hand them
        /// to <c>AwaitStreamingUploadAsync</c>. Both run in the background so the dispatcher can keep processing
        /// messages (including the uploaded stream items).
        /// </remarks>
        private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
            HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall)
        {
            var methodExecutor = descriptor.MethodExecutor;

            // True while this method is still responsible for releasing the hub and disposing the scope;
            // cleared when that responsibility is handed to a background task.
            var disposeScope = true;
            var scope = _serviceScopeFactory.CreateScope();
            IHubActivator<THub> hubActivator = null;
            THub hub = null;
            try
            {
                if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
                {
                    Log.HubMethodNotAuthorized(_logger, hubMethodInvocationMessage.Target);
                    await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                        $"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized");
                    return;
                }

                if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection))
                {
                    return;
                }

                hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
                hub = hubActivator.Create();

                if (isStreamCall)
                {
                    // Swap each StreamPlaceholder argument for the stream registered with the connection's StreamTracker.
                    // The item type is taken from the parameter's first generic argument
                    // (NOTE(review): presumably ChannelReader<T> — confirm against HubMethodDescriptor).
                    var args = hubMethodInvocationMessage.Arguments;
                    for (int i = 0; i < args.Length; i++)
                    {
                        var placeholder = args[i] as StreamPlaceholder;
                        if (placeholder == null)
                        {
                            continue;
                        }

                        Log.StartingParameterStream(_logger, placeholder.StreamId);
                        var itemType = methodExecutor.MethodParameters[i].ParameterType.GetGenericArguments()[0];
                        args[i] = connection.StreamTracker.AddStream(placeholder.StreamId, itemType);
                    }
                }

                try
                {
                    InitializeHub(hub, connection);
                    Task invocation = null;

                    if (isStreamResponse)
                    {
                        var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);

                        if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
                        {
                            Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);

                            await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                                $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
                            return;
                        }

                        Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
                        _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
                    }
                    else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
                    {
                        // Send Async, no response expected
                        invocation = ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
                    }
                    else
                    {
                        // Invoke Async, one response expected
                        async Task ExecuteInvocation()
                        {
                            var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
                            Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
                            await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
                        }
                        invocation = ExecuteInvocation();
                    }

                    if (isStreamResponse)
                    {
                        // StreamResultsAsync now owns the scope and hub and releases them when the stream ends.
                        disposeScope = false;
                    }
                    else if (isStreamCall)
                    {
                        // Don't await streaming uploads: the dispatcher must stay free to deliver the stream items
                        // this invocation consumes. The background task reports failures and then releases the
                        // hub and scope (previously they were never released and failures were silently lost).
                        disposeScope = false;
                        _ = AwaitStreamingUploadAsync(invocation, hubMethodInvocationMessage, connection, scope, hubActivator, hub);
                    }
                    else
                    {
                        // complete the non-streaming calls now
                        await invocation;
                    }
                }
                catch (TargetInvocationException ex)
                {
                    Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);
                    await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                        ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex.InnerException, _enableDetailedErrors));
                }
                catch (Exception ex)
                {
                    Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);
                    await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                        ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors));
                }
            }
            finally
            {
                if (disposeScope)
                {
                    hubActivator?.Release(hub);
                    scope.Dispose();
                }
            }
        }

        /// <summary>
        /// Awaits a streaming-upload invocation that was deliberately left running in the background, reports
        /// any failure to the caller (when it expects a completion), and finally releases the hub and disposes
        /// the DI scope the hub was activated from.
        /// </summary>
        private async Task AwaitStreamingUploadAsync(Task invocation, HubMethodInvocationMessage hubMethodInvocationMessage,
            HubConnectionContext connection, IServiceScope scope, IHubActivator<THub> hubActivator, THub hub)
        {
            try
            {
                await invocation;
            }
            catch (Exception ex)
            {
                Log.FailedInvokingHubMethod(_logger, hubMethodInvocationMessage.Target, ex);

                // Same error shaping as Invoke: unwrap TargetInvocationException to the hub method's own exception.
                var reportedException = ex is TargetInvocationException targetInvocationException
                    ? targetInvocationException.InnerException
                    : ex;
                await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
                    ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", reportedException, _enableDetailedErrors));
            }
            finally
            {
                hubActivator.Release(hub);
                scope.Dispose();
            }
        }

        /// <summary>
        /// Pumps items from a streaming hub method to the client, then sends the completion message
        /// (with an error when the stream failed, other than by a client-requested cancellation).
        /// Owns and disposes <paramref name="scope"/>, releases <paramref name="hub"/>, and disposes both
        /// <paramref name="streamCts"/> and the per-invocation cancellation source.
        /// </summary>
        private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
            IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts)
        {
            string error = null;

            using (scope)
            {
                try
                {
                    while (await enumerator.MoveNextAsync())
                    {
                        // Send the stream item
                        await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current));
                    }
                }
                catch (ChannelClosedException ex)
                {
                    // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method
                    error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex.InnerException ?? ex, _enableDetailedErrors);
                }
                catch (Exception ex)
                {
                    // If the streaming method was canceled we don't want to send a HubException message - this is not an error case
                    if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts)
                        && cts.IsCancellationRequested))
                    {
                        error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors);
                    }
                }
                finally
                {
                    (enumerator as IDisposable)?.Dispose();

                    hubActivator.Release(hub);

                    // Dispose the linked CTS for the stream.
                    streamCts.Dispose();

                    try
                    {
                        await connection.WriteAsync(CompletionMessage.WithError(invocationId, error));
                    }
                    finally
                    {
                        // Always unregister the invocation's cancellation source, even if the completion write
                        // fails (e.g. the connection is already closed), so it doesn't stay in the map.
                        if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts))
                        {
                            cts.Dispose();
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Executes a hub method and returns its result: the awaited value for async methods returning a
        /// result, the direct return value for synchronous methods, and <c>null</c> for methods returning plain Task.
        /// </summary>
        private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
        {
            if (methodExecutor.IsMethodAsync)
            {
                if (methodExecutor.MethodReturnType == typeof(Task))
                {
                    // A plain Task carries no result to unwrap.
                    await (Task)methodExecutor.Execute(hub, arguments);
                }
                else
                {
                    return await methodExecutor.ExecuteAsync(hub, arguments);
                }
            }
            else
            {
                return methodExecutor.Execute(hub, arguments);
            }

            return null;
        }

        /// <summary>
        /// Sends an error completion for <paramref name="invocationId"/>; does nothing for fire-and-forget
        /// (null/empty id) invocations, since no completion is expected.
        /// </summary>
        private async Task SendInvocationError(string invocationId,
            HubConnectionContext connection, string errorMessage)
        {
            if (string.IsNullOrEmpty(invocationId))
            {
                return;
            }

            await connection.WriteAsync(CompletionMessage.WithError(invocationId, errorMessage));
        }

        /// <summary>
        /// Wires a freshly activated hub to the caller's connection (Clients, Context, Groups).
        /// </summary>
        private void InitializeHub(THub hub, HubConnectionContext connection)
        {
            hub.Clients = new HubCallerClients(_hubContext.Clients, connection.ConnectionId);
            hub.Context = new DefaultHubCallerContext(connection);
            hub.Groups = _hubContext.Groups;
        }

        /// <summary>
        /// Checks whether <paramref name="principal"/> satisfies the method's authorization policies;
        /// completes synchronously with <c>true</c> when the method has no policies.
        /// </summary>
        private Task<bool> IsHubMethodAuthorized(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
        {
            // If there are no policies we don't need to run auth
            if (!policies.Any())
            {
                return TaskCache.True;
            }

            return IsHubMethodAuthorizedSlow(provider, principal, policies);
        }

        /// <summary>
        /// Combines the method's policies and evaluates them with the scoped authorization service.
        /// </summary>
        private static async Task<bool> IsHubMethodAuthorizedSlow(IServiceProvider provider, ClaimsPrincipal principal, IList<IAuthorizeData> policies)
        {
            var authService = provider.GetRequiredService<IAuthorizationService>();
            var policyProvider = provider.GetRequiredService<IAuthorizationPolicyProvider>();

            var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
            // AuthorizationPolicy.CombineAsync only returns null if there are no policies and we check that above
            Debug.Assert(authorizePolicy != null);

            var authorizationResult = await authService.AuthorizeAsync(principal, authorizePolicy);
            // Only check authorization success, challenge or forbid wouldn't make sense from a hub method invocation
            return authorizationResult.Succeeded;
        }

        /// <summary>
        /// Verifies that the invocation kind (streaming vs. non-streaming) matches the target method.
        /// On mismatch, sends an error completion where one is expected and returns <c>false</c>.
        /// </summary>
        private async Task<bool> ValidateInvocationMode(HubMethodDescriptor hubMethodDescriptor, bool isStreamedInvocation,
            HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
        {
            if (hubMethodDescriptor.IsStreamable && !isStreamedInvocation)
            {
                // Non-null/empty InvocationId? Blocking
                if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
                {
                    Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage);
                    await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
                        $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation."));
                }

                return false;
            }

            if (!hubMethodDescriptor.IsStreamable && isStreamedInvocation)
            {
                Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
                await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
                    $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation."));

                return false;
            }

            return true;
        }

        /// <summary>
        /// Converts a streaming method's return value into an enumerator when the method is channel-based.
        /// On success, registers a per-invocation cancellation source on the connection (so a
        /// CancelInvocationMessage can stop the stream) and returns a CTS linked to it and to ConnectionAborted.
        /// </summary>
        private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
        {
            if (result != null)
            {
                if (hubMethodDescriptor.IsChannel)
                {
                    streamCts = CreateCancellation();
                    enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
                    return true;
                }
            }

            streamCts = null;
            enumerator = null;
            return false;

            CancellationTokenSource CreateCancellation()
            {
                var userCts = new CancellationTokenSource();
                // NOTE(review): a failed TryAdd (duplicate invocation id) is ignored here — confirm ids are unique per connection.
                connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);

                return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
            }
        }

        /// <summary>
        /// Reflects over <typeparamref name="THub"/> and registers a descriptor for each hub method.
        /// </summary>
        /// <exception cref="NotSupportedException">Two methods map to the same (case-insensitive) name.</exception>
        private void DiscoverHubMethods()
        {
            var hubType = typeof(THub);
            var hubTypeInfo = hubType.GetTypeInfo();
            var hubName = hubType.Name;

            foreach (var methodInfo in HubReflectionHelper.GetHubMethods(hubType))
            {
                var methodName =
                    methodInfo.GetCustomAttribute<HubMethodNameAttribute>()?.Name ??
                    methodInfo.Name;

                if (_methods.ContainsKey(methodName))
                {
                    throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported.");
                }

                var executor = ObjectMethodExecutor.Create(methodInfo, hubTypeInfo);
                var authorizeAttributes = methodInfo.GetCustomAttributes<AuthorizeAttribute>(inherit: true);
                _methods[methodName] = new HubMethodDescriptor(executor, authorizeAttributes);

                Log.HubMethodBound(_logger, hubName, methodName);
            }
        }

        /// <summary>
        /// Returns the parameter types of the named hub method.
        /// </summary>
        /// <exception cref="HubException">No hub method with that name exists.</exception>
        public override IReadOnlyList<Type> GetParameterTypes(string methodName)
        {
            if (!_methods.TryGetValue(methodName, out var descriptor))
            {
                throw new HubException("Method does not exist.");
            }
            return descriptor.ParameterTypes;
        }
    }
}