// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.SignalR
{
    /// <summary>
    /// Convenience endpoint for hubs that address their clients through <see cref="IClientProxy"/>;
    /// all behavior is inherited from <see cref="HubEndPoint{THub, TClient}"/>.
    /// </summary>
    /// <typeparam name="THub">The hub type served by this endpoint.</typeparam>
    public class HubEndPoint<THub> : HubEndPoint<THub, IClientProxy> where THub : Hub<IClientProxy>
    {
        public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
                           IHubProtocolResolver protocolResolver,
                           IHubContext<THub> hubContext,
                           ILogger<HubEndPoint<THub>> logger,
                           IServiceScopeFactory serviceScopeFactory)
            : base(lifetimeManager, protocolResolver, hubContext, logger, serviceScopeFactory)
        {
        }
    }

    /// <summary>
    /// Endpoint that runs a hub over a connection: it resolves the hub protocol, registers the
    /// connection with the lifetime manager, calls the hub's OnConnectedAsync/OnDisconnectedAsync
    /// hooks and dispatches incoming invocation messages to the hub's public methods.
    /// It also acts as the <see cref="IInvocationBinder"/> the protocol uses to learn parameter types.
    /// </summary>
    /// <typeparam name="THub">The hub type served by this endpoint.</typeparam>
    /// <typeparam name="TClient">The client proxy type exposed through the hub's Clients property.</typeparam>
    public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient>
    {
        // Bound hub methods keyed by name. Lookups are case-insensitive; overloads are rejected
        // in DiscoverHubMethods because a name must map to exactly one descriptor.
        private readonly Dictionary<string, HubMethodDescriptor> _methods =
            new Dictionary<string, HubMethodDescriptor>(StringComparer.OrdinalIgnoreCase);

        private readonly HubLifetimeManager<THub> _lifetimeManager;
        private readonly IHubContext<THub, TClient> _hubContext;
        private readonly ILogger<HubEndPoint<THub, TClient>> _logger;
        private readonly IServiceScopeFactory _serviceScopeFactory;
        private readonly IHubProtocolResolver _protocolResolver;

        /// <summary>
        /// Stores the collaborators and eagerly binds the public methods of <typeparamref name="THub"/>.
        /// </summary>
        /// <exception cref="NotSupportedException">The hub declares two bindable methods with the same name.</exception>
        public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
                           IHubProtocolResolver protocolResolver,
                           IHubContext<THub, TClient> hubContext,
                           ILogger<HubEndPoint<THub, TClient>> logger,
                           IServiceScopeFactory serviceScopeFactory)
        {
            _protocolResolver = protocolResolver;
            _lifetimeManager = lifetimeManager;
            _hubContext = hubContext;
            _logger = logger;
            _serviceScopeFactory = serviceScopeFactory;

            DiscoverHubMethods();
        }

        /// <summary>
        /// Drives the whole lifetime of one connection. The lifetime manager's disconnect hook
        /// always runs, even if the hub or the dispatch loop throws.
        /// </summary>
        public override async Task OnConnectedAsync(ConnectionContext connection)
        {
            try
            {
                // Resolve the Hub Protocol for the connection and store it in metadata
                // Other components, outside the Hub, may need to know what protocol is in use
                // for a particular connection, so we store it here.
                connection.Metadata[HubConnectionMetadataNames.HubProtocol] = _protocolResolver.GetProtocol(connection);

                await _lifetimeManager.OnConnectedAsync(connection);
                await RunHubAsync(connection);
            }
            finally
            {
                await _lifetimeManager.OnDisconnectedAsync(connection);
            }
        }

        /// <summary>
        /// Calls the hub's connect hook, pumps messages until the transport closes, then calls the
        /// disconnect hook, passing along the exception that ended the dispatch loop (if any).
        /// </summary>
        private async Task RunHubAsync(ConnectionContext connection)
        {
            await HubOnConnectedAsync(connection);

            try
            {
                await DispatchMessagesAsync(connection);
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Error when processing requests.");
                await HubOnDisconnectedAsync(connection, ex);
                throw;
            }

            await HubOnDisconnectedAsync(connection, null);
        }

        /// <summary>Creates a scoped hub instance and invokes its OnConnectedAsync hook.</summary>
        private async Task HubOnConnectedAsync(ConnectionContext connection)
        {
            try
            {
                using (var scope = _serviceScopeFactory.CreateScope())
                {
                    var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
                    var hub = hubActivator.Create();
                    try
                    {
                        InitializeHub(hub, connection);
                        await hub.OnConnectedAsync();
                    }
                    finally
                    {
                        hubActivator.Release(hub);
                    }
                }
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Error when invoking OnConnectedAsync on hub.");
                throw;
            }
        }

        /// <summary>
        /// Creates a scoped hub instance and invokes its OnDisconnectedAsync hook.
        /// </summary>
        /// <param name="exception">The error that terminated the connection, or <c>null</c> on a clean close.</param>
        private async Task HubOnDisconnectedAsync(ConnectionContext connection, Exception exception)
        {
            try
            {
                using (var scope = _serviceScopeFactory.CreateScope())
                {
                    var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
                    var hub = hubActivator.Create();
                    try
                    {
                        InitializeHub(hub, connection);
                        await hub.OnDisconnectedAsync(exception);
                    }
                    finally
                    {
                        hubActivator.Release(hub);
                    }
                }
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Error when invoking OnDisconnectedAsync on hub.");
                throw;
            }
        }

        /// <summary>
        /// Reads messages from the transport and starts a hub invocation for each one without
        /// awaiting it. If any invocation fails unexpectedly, the loop is cancelled and that
        /// invocation's exception is rethrown from here.
        /// </summary>
        /// <exception cref="NotSupportedException">A message other than an invocation was received.</exception>
        private async Task DispatchMessagesAsync(ConnectionContext connection)
        {
            // We use these for error handling. Since we dispatch multiple hub invocations
            // in parallel, we need a way to communicate failure back to the main processing loop. The
            // cancellation token is used to stop reading from the channel, the tcs
            // is used to get the exception so we can bubble it up the stack.
            // NOTE: cts is intentionally not disposed: an in-flight invocation may still call
            // Cancel() on it after this loop has exited.
            var cts = new CancellationTokenSource();
            var completion = new TaskCompletionSource<object>();
            var protocol = connection.Metadata.Get<IHubProtocol>(HubConnectionMetadataNames.HubProtocol);

            try
            {
                while (await connection.Transport.Input.WaitToReadAsync(cts.Token))
                {
                    while (connection.Transport.Input.TryRead(out var incomingMessage))
                    {
                        var hubMessage = protocol.ParseMessage(incomingMessage.Payload, this);

                        switch (hubMessage)
                        {
                            case InvocationMessage invocationMessage:
                                if (_logger.IsEnabled(LogLevel.Debug))
                                {
                                    _logger.LogDebug("Received hub invocation: {invocation}", invocationMessage);
                                }

                                // Don't wait on the result of execution, continue processing other
                                // incoming messages on this connection.
                                _ = ProcessInvocation(connection, protocol, invocationMessage, cts, completion);
                                break;

                            // Other kind of message we weren't expecting
                            default:
                                _logger.LogError("Received unsupported message of type '{messageType}'", hubMessage.GetType().FullName);
                                throw new NotSupportedException($"Received unsupported message: {hubMessage}");
                        }
                    }
                }
            }
            catch (OperationCanceledException) when (cts.IsCancellationRequested)
            {
                // Only our own cancellation lands here. ProcessInvocation always records the
                // exception on the TCS *before* cancelling, so this await completes and rethrows it.
                // Any other OperationCanceledException propagates unchanged instead of waiting
                // forever on a TCS that nobody will ever complete.
                await completion.Task;
            }
        }

        /// <summary>
        /// Runs one invocation. An unexpected exception is reported to the dispatch loop through
        /// <paramref name="dispatcherCompletion"/>, and the loop is stopped through
        /// <paramref name="dispatcherCancellation"/>. The exception is recorded first, then the
        /// loop is cancelled, and that order is relied upon by DispatchMessagesAsync.
        /// </summary>
        private async Task ProcessInvocation(ConnectionContext connection,
                                             IHubProtocol protocol,
                                             InvocationMessage invocationMessage,
                                             CancellationTokenSource dispatcherCancellation,
                                             TaskCompletionSource<object> dispatcherCompletion)
        {
            try
            {
                // If an unexpected exception occurs then we want to kill the entire connection
                // by ending the processing loop
                await Execute(connection, protocol, invocationMessage);
            }
            catch (Exception ex)
            {
                // Set the exception on the task completion source
                dispatcherCompletion.TrySetException(ex);

                // Cancel reading operation
                dispatcherCancellation.Cancel();
            }
        }

        /// <summary>
        /// Looks up the target hub method and sends either its completion result or an
        /// "unknown method" error completion back to the caller.
        /// </summary>
        private async Task Execute(ConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
        {
            if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor))
            {
                // Send an error to the client. Then let the normal completion process occur
                _logger.LogError("Unknown hub method '{method}'", invocationMessage.Target);
                await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
            }
            else
            {
                var result = await Invoke(descriptor, connection, invocationMessage);
                await SendMessageAsync(connection, protocol, result);
            }
        }

        /// <summary>Serializes <paramref name="hubMessage"/> and writes it to the transport output.</summary>
        /// <exception cref="OperationCanceledException">The output channel closed before the message could be written.</exception>
        private async Task SendMessageAsync(ConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
        {
            var payload = await protocol.WriteToArrayAsync(hubMessage);
            var message = new Message(payload, protocol.MessageType, endOfMessage: true);

            while (await connection.Transport.Output.WaitToWriteAsync())
            {
                if (connection.Transport.Output.TryWrite(message))
                {
                    return;
                }
            }

            // Output is closed. Cancel this invocation completely
            _logger.LogWarning("Outbound channel was closed while trying to write hub message");
            throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
        }

        /// <summary>
        /// Invokes the bound method on a freshly activated, scoped hub instance. Exceptions thrown
        /// by the hub method are converted into an error completion instead of propagating.
        /// </summary>
        private async Task<CompletionMessage> Invoke(HubMethodDescriptor descriptor, ConnectionContext connection, InvocationMessage invocationMessage)
        {
            var methodExecutor = descriptor.MethodExecutor;

            using (var scope = _serviceScopeFactory.CreateScope())
            {
                var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
                var hub = hubActivator.Create();

                try
                {
                    InitializeHub(hub, connection);

                    object result = null;
                    if (methodExecutor.IsMethodAsync)
                    {
                        if (methodExecutor.MethodReturnType == typeof(Task))
                        {
                            // Non-generic Task: nothing to return, just wait for completion.
                            await (Task)methodExecutor.Execute(hub, invocationMessage.Arguments);
                        }
                        else
                        {
                            result = await methodExecutor.ExecuteAsync(hub, invocationMessage.Arguments);
                        }
                    }
                    else
                    {
                        result = methodExecutor.Execute(hub, invocationMessage.Arguments);
                    }

                    return CompletionMessage.WithResult(invocationMessage.InvocationId, result);
                }
                catch (TargetInvocationException ex)
                {
                    _logger.LogError(0, ex, "Failed to invoke hub method");
                    // Report the hub method's own exception; fall back to the wrapper if none is attached.
                    return CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException?.Message ?? ex.Message);
                }
                catch (Exception ex)
                {
                    _logger.LogError(0, ex, "Failed to invoke hub method");
                    return CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message);
                }
                finally
                {
                    hubActivator.Release(hub);
                }
            }
        }

        /// <summary>Populates the per-call properties (Clients, Context, Groups) of a hub instance.</summary>
        private void InitializeHub(THub hub, ConnectionContext connection)
        {
            hub.Clients = _hubContext.Clients;
            hub.Context = new HubCallerContext(connection);
            hub.Groups = new GroupManager<THub>(connection, _lifetimeManager);
        }

        /// <summary>
        /// Reflects over <typeparamref name="THub"/> and builds an executor for each method that
        /// <see cref="IsHubMethod"/> accepts.
        /// </summary>
        /// <exception cref="NotSupportedException">Two bindable methods share a name (compared case-insensitively).</exception>
        private void DiscoverHubMethods()
        {
            var hubType = typeof(THub);

            foreach (var methodInfo in hubType.GetMethods().Where(IsHubMethod))
            {
                var methodName = methodInfo.Name;

                if (_methods.ContainsKey(methodName))
                {
                    throw new NotSupportedException($"Duplicate definitions of '{methodName}'. Overloading is not supported.");
                }

                var executor = ObjectMethodExecutor.Create(methodInfo, hubType.GetTypeInfo());
                _methods[methodName] = new HubMethodDescriptor(executor);

                if (_logger.IsEnabled(LogLevel.Debug))
                {
                    _logger.LogDebug("Hub method '{methodName}' is bound", methodName);
                }
            }
        }

        /// <summary>
        /// Decides whether a reflected method may be invoked remotely by clients.
        /// </summary>
        /// <returns>
        /// <c>false</c> for non-public or special-name methods (property accessors, event accessors,
        /// operators), for methods originating on <see cref="object"/> and for methods originating
        /// on <c>Hub&lt;TClient&gt;</c> itself; <c>true</c> otherwise.
        /// </returns>
        private static bool IsHubMethod(MethodInfo methodInfo)
        {
            // TODO: Add more checks
            if (!methodInfo.IsPublic || methodInfo.IsSpecialName)
            {
                return false;
            }

            var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType;

            // ToString, GetHashCode, Equals and GetType (and overrides of the virtual ones) resolve
            // to System.Object. They must never be exposed to clients as invocable hub methods.
            if (baseDefinition == typeof(object))
            {
                return false;
            }

            var baseType = baseDefinition.GetTypeInfo().IsGenericType
                ? baseDefinition.GetGenericTypeDefinition()
                : baseDefinition;

            // Methods that originate on Hub<TClient> itself (e.g. the OnConnectedAsync /
            // OnDisconnectedAsync hooks this endpoint calls directly) are infrastructure, not hub methods.
            return typeof(Hub<>) != baseType;
        }

        /// <summary>Invocation results are not typed by this binder; <see cref="object"/> is always reported.</summary>
        Type IInvocationBinder.GetReturnType(string invocationId)
        {
            return typeof(object);
        }

        /// <summary>
        /// Returns the parameter types of the named hub method, or an empty array when no such
        /// method is bound.
        /// </summary>
        Type[] IInvocationBinder.GetParameterTypes(string methodName)
        {
            if (!_methods.TryGetValue(methodName, out var descriptor))
            {
                return Type.EmptyTypes;
            }

            return descriptor.ParameterTypes;
        }

        // REVIEW: We can decide to move this out of here if we want pluggable hub discovery
        /// <summary>A bound hub method: its executor plus its parameter types, cached once at discovery.</summary>
        private class HubMethodDescriptor
        {
            public HubMethodDescriptor(ObjectMethodExecutor methodExecutor)
            {
                MethodExecutor = methodExecutor;
                ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
            }

            public ObjectMethodExecutor MethodExecutor { get; }

            public Type[] ParameterTypes { get; }
        }
    }
}