// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; using System.IO.Pipelines; using System.Linq; using System.Reflection; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.SignalR { public class HubEndPoint : HubEndPoint where THub : Hub { public HubEndPoint(HubLifetimeManager lifetimeManager, IHubContext hubContext, InvocationAdapterRegistry registry, ILogger> logger, IServiceScopeFactory serviceScopeFactory) : base(lifetimeManager, hubContext, registry, logger, serviceScopeFactory) { } } public class HubEndPoint : StreamingEndPoint, IInvocationBinder where THub : Hub { private readonly Dictionary>> _callbacks = new Dictionary>>(StringComparer.OrdinalIgnoreCase); private readonly Dictionary _paramTypes = new Dictionary(); private readonly HubLifetimeManager _lifetimeManager; private readonly IHubContext _hubContext; private readonly ILogger> _logger; private readonly InvocationAdapterRegistry _registry; private readonly IServiceScopeFactory _serviceScopeFactory; public HubEndPoint(HubLifetimeManager lifetimeManager, IHubContext hubContext, InvocationAdapterRegistry registry, ILogger> logger, IServiceScopeFactory serviceScopeFactory) { _lifetimeManager = lifetimeManager; _hubContext = hubContext; _registry = registry; _logger = logger; _serviceScopeFactory = serviceScopeFactory; DiscoverHubMethods(); } public override async Task OnConnectedAsync(StreamingConnection connection) { // TODO: Dispatch from the caller await Task.Yield(); try { await _lifetimeManager.OnConnectedAsync(connection); await RunHubAsync(connection); } finally { await _lifetimeManager.OnDisconnectedAsync(connection); } } private async Task RunHubAsync(StreamingConnection connection) { await HubOnConnectedAsync(connection); try { await DispatchMessagesAsync(connection); } catch (Exception ex) { _logger.LogError(0, ex, "Error when processing requests."); await HubOnDisconnectedAsync(connection, ex); throw; } await HubOnDisconnectedAsync(connection, null); } private async Task HubOnConnectedAsync(StreamingConnection connection) { try { using (var scope = _serviceScopeFactory.CreateScope()) { var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); try { InitializeHub(hub, connection); await hub.OnConnectedAsync(); } finally { hubActivator.Release(hub); } } } catch (Exception ex) { _logger.LogError(0, ex, "Error when invoking OnConnectedAsync on hub."); throw; } } private async Task HubOnDisconnectedAsync(StreamingConnection connection, Exception exception) { try { using (var scope = _serviceScopeFactory.CreateScope()) { var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); try { InitializeHub(hub, connection); await hub.OnDisconnectedAsync(exception); } finally { hubActivator.Release(hub); } } } catch (Exception ex) { _logger.LogError(0, ex, "Error when invoking OnDisconnectedAsync on hub."); throw; } } private async Task DispatchMessagesAsync(StreamingConnection connection) { var stream = connection.Transport.GetStream(); var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get("formatType")); while (true) { // TODO: Handle receiving InvocationResultDescriptor var invocationDescriptor = await invocationAdapter.ReadMessageAsync(stream, this) as InvocationDescriptor; // Is there a better way of detecting that a connection was closed? if (invocationDescriptor == null) { break; } if (_logger.IsEnabled(LogLevel.Debug)) { _logger.LogDebug("Received hub invocation: {invocation}", invocationDescriptor); } InvocationResultDescriptor result; Func> callback; if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback)) { result = await callback(connection, invocationDescriptor); } else { // If there's no method then return a failed response for this request result = new InvocationResultDescriptor { Id = invocationDescriptor.Id, Error = $"Unknown hub method '{invocationDescriptor.Method}'" }; _logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method); } await invocationAdapter.WriteMessageAsync(result, stream); } } private void InitializeHub(THub hub, StreamingConnection connection) { hub.Clients = _hubContext.Clients; hub.Context = new HubCallerContext(connection); hub.Groups = new GroupManager(connection, _lifetimeManager); } private void DiscoverHubMethods() { var type = typeof(THub); foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => IsHubMethod(m))) { var methodName = methodInfo.Name; if (_callbacks.ContainsKey(methodName)) { throw new NotSupportedException($"Duplicate definitions of '{methodInfo.Name}'. Overloading is not supported."); } var parameters = methodInfo.GetParameters(); _paramTypes[methodName] = parameters.Select(p => p.ParameterType).ToArray(); if (_logger.IsEnabled(LogLevel.Debug)) { _logger.LogDebug("Hub method '{methodName}' is bound", methodName); } _callbacks[methodName] = async (connection, invocationDescriptor) => { var invocationResult = new InvocationResultDescriptor() { Id = invocationDescriptor.Id }; using (var scope = _serviceScopeFactory.CreateScope()) { var hubActivator = scope.ServiceProvider.GetRequiredService>(); var hub = hubActivator.Create(); try { InitializeHub(hub, connection); var result = methodInfo.Invoke(hub, invocationDescriptor.Arguments); var resultTask = result as Task; if (resultTask != null) { await resultTask; if (methodInfo.ReturnType.GetTypeInfo().IsGenericType) { var property = resultTask.GetType().GetProperty("Result"); invocationResult.Result = property?.GetValue(resultTask); } } else { invocationResult.Result = result; } } catch (TargetInvocationException ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); invocationResult.Error = ex.InnerException.Message; } catch (Exception ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); invocationResult.Error = ex.Message; } finally { hubActivator.Release(hub); } } return invocationResult; }; }; } private static bool IsHubMethod(MethodInfo m) { // TODO: Add more checks return m.IsPublic && !m.IsSpecialName; } Type IInvocationBinder.GetReturnType(string invocationId) { return typeof(object); } Type[] IInvocationBinder.GetParameterTypes(string methodName) { Type[] types; if (!_paramTypes.TryGetValue(methodName, out types)) { throw new InvalidOperationException($"The hub method '{methodName}' could not be resolved."); } return types; } } }