// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.SignalR
{
    /// <summary>
    /// Convenience endpoint for hubs whose client type is <see cref="IClientProxy"/>;
    /// it only forwards its dependencies to <see cref="HubEndPoint{THub, TClient}"/>.
    /// </summary>
    /// <typeparam name="THub">The hub type served by this endpoint.</typeparam>
    public class HubEndPoint<THub> : HubEndPoint<THub, IClientProxy> where THub : Hub<IClientProxy>
    {
        public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
                           IHubContext<THub> hubContext,
                           InvocationAdapterRegistry registry,
                           ILogger<HubEndPoint<THub>> logger,
                           IServiceScopeFactory serviceScopeFactory)
            : base(lifetimeManager, hubContext, registry, logger, serviceScopeFactory)
        {
        }
    }

    /// <summary>
    /// Socket endpoint that reads hub invocations from a connection, dispatches each one to the
    /// matching public method declared on <typeparamref name="THub"/>, and writes the result back.
    /// It is also the <see cref="IInvocationBinder"/> handed to the invocation adapter, telling it
    /// which parameter types each hub method expects.
    /// </summary>
    /// <typeparam name="THub">The hub type whose declared public methods are exposed.</typeparam>
    /// <typeparam name="TClient">The client proxy type used by the hub.</typeparam>
    public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient>
    {
        // Both maps are keyed by hub method name and MUST share a comparer: a name that resolves a
        // callback case-insensitively has to resolve its parameter types the same way, otherwise the
        // adapter binds the arguments of e.g. "send" against Type.EmptyTypes instead of Send's signature.
        private readonly Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
            = new Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
        private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>(StringComparer.OrdinalIgnoreCase);

        private readonly HubLifetimeManager<THub> _lifetimeManager;
        private readonly IHubContext<THub, TClient> _hubContext;
        private readonly ILogger<HubEndPoint<THub, TClient>> _logger;
        private readonly InvocationAdapterRegistry _registry;
        private readonly IServiceScopeFactory _serviceScopeFactory;

        /// <summary>
        /// Stores the dependencies and eagerly discovers the hub methods of <typeparamref name="THub"/>.
        /// </summary>
        /// <exception cref="NotSupportedException">Thrown when the hub declares overloaded methods.</exception>
        public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
                           IHubContext<THub, TClient> hubContext,
                           InvocationAdapterRegistry registry,
                           ILogger<HubEndPoint<THub, TClient>> logger,
                           IServiceScopeFactory serviceScopeFactory)
        {
            _lifetimeManager = lifetimeManager;
            _hubContext = hubContext;
            _registry = registry;
            _logger = logger;
            _serviceScopeFactory = serviceScopeFactory;

            DiscoverHubMethods();
        }

        /// <summary>
        /// Registers the connection with the lifetime manager, runs the hub for it, and always
        /// unregisters it afterwards, even if running the hub throws.
        /// </summary>
        public override async Task OnConnectedAsync(Connection connection)
        {
            try
            {
                await _lifetimeManager.OnConnectedAsync(connection);
                await RunHubAsync(connection);
            }
            finally
            {
                await _lifetimeManager.OnDisconnectedAsync(connection);
            }
        }

        // Calls the hub's OnConnectedAsync, processes messages until the input ends, then calls the
        // hub's OnDisconnectedAsync with the processing exception (or null) before rethrowing it.
        private async Task RunHubAsync(Connection connection)
        {
            await HubOnConnectedAsync(connection);

            try
            {
                await DispatchMessagesAsync(connection);
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Error when processing requests.");
                await HubOnDisconnectedAsync(connection, ex);
                throw;
            }

            await HubOnDisconnectedAsync(connection, null);
        }

        // Creates a scoped hub instance, initializes it for the connection and runs OnConnectedAsync on it.
        private async Task HubOnConnectedAsync(Connection connection)
        {
            try
            {
                using (var scope = _serviceScopeFactory.CreateScope())
                {
                    var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
                    var hub = hubActivator.Create();
                    try
                    {
                        InitializeHub(hub, connection);
                        await hub.OnConnectedAsync();
                    }
                    finally
                    {
                        hubActivator.Release(hub);
                    }
                }
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Error when invoking OnConnectedAsync on hub.");
                throw;
            }
        }

        // Creates a scoped hub instance, initializes it for the connection and runs OnDisconnectedAsync on it.
        private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
        {
            try
            {
                using (var scope = _serviceScopeFactory.CreateScope())
                {
                    var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
                    var hub = hubActivator.Create();
                    try
                    {
                        InitializeHub(hub, connection);
                        await hub.OnDisconnectedAsync(exception);
                    }
                    finally
                    {
                        hubActivator.Release(hub);
                    }
                }
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Error when invoking OnDisconnectedAsync on hub.");
                throw;
            }
        }

        // Reads messages from the transport and fires off one invocation per message without waiting
        // for it. If an invocation fails unexpectedly, the loop is stopped and that failure is rethrown.
        private async Task DispatchMessagesAsync(Connection connection)
        {
            var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));

            // We use these for error handling. Since we dispatch multiple hub invocations
            // in parallel, we need a way to communicate failure back to the main processing loop. The
            // cancellation token is used to stop reading from the channel, the tcs
            // is used to get the exception so we can bubble it up the stack.
            // NOTE: cts is deliberately not disposed: in-flight ProcessInvocation tasks may still call
            // Cancel() on it after this method has returned.
            var cts = new CancellationTokenSource();
            var tcs = new TaskCompletionSource<object>();

            try
            {
                while (await connection.Transport.Input.WaitToReadAsync(cts.Token))
                {
                    Message incomingMessage;
                    while (connection.Transport.Input.TryRead(out incomingMessage))
                    {
                        InvocationDescriptor invocationDescriptor;
                        using (incomingMessage)
                        {
                            var inputStream = new MemoryStream(incomingMessage.Payload.Buffer.ToArray());

                            // TODO: Handle receiving InvocationResultDescriptor
                            invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor;
                        }

                        // Is there a better way of detecting that a connection was closed?
                        // NOTE: this only leaves the inner drain loop; the outer loop then waits on the
                        // channel again, which ends the dispatch only if the channel reports no more data.
                        if (invocationDescriptor == null)
                        {
                            break;
                        }

                        if (_logger.IsEnabled(LogLevel.Debug))
                        {
                            _logger.LogDebug("Received hub invocation: {invocation}", invocationDescriptor);
                        }

                        // Don't wait on the result of execution, continue processing other
                        // incoming messages on this connection.
                        var ignore = ProcessInvocation(connection, invocationAdapter, invocationDescriptor, cts, tcs);
                    }
                }
            }
            catch (OperationCanceledException) when (cts.IsCancellationRequested)
            {
                // Only ProcessInvocation cancels cts, and it faults tcs *before* cancelling, so this await
                // completes immediately and bubbles the invocation's exception up to the caller.
                // Cancellations that did not originate from cts are not caught here: awaiting tcs for
                // them would hang forever because nothing would ever complete it.
                await tcs.Task;
            }
        }

        // Runs a single invocation; an unexpected exception faults tcs and then cancels cts so the
        // dispatch loop stops reading and rethrows it.
        private async Task ProcessInvocation(Connection connection,
                                             IInvocationAdapter invocationAdapter,
                                             InvocationDescriptor invocationDescriptor,
                                             CancellationTokenSource cts,
                                             TaskCompletionSource<object> tcs)
        {
            try
            {
                // If an unexpected exception occurs then we want to kill the entire connection
                // by ending the processing loop
                await Execute(connection, invocationAdapter, invocationDescriptor);
            }
            catch (Exception ex)
            {
                // Set the exception on the task completion source first, so the loop can observe it
                tcs.TrySetException(ex);

                // Cancel reading operation
                cts.Cancel();
            }
        }

        // Invokes the bound hub method (or builds an "unknown method" error result), serializes the
        // result with the connection's adapter and writes it to the transport output.
        private async Task Execute(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor invocationDescriptor)
        {
            InvocationResultDescriptor result;
            Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
            if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback))
            {
                result = await callback(connection, invocationDescriptor);
            }
            else
            {
                // If there's no method then return a failed response for this request
                result = new InvocationResultDescriptor
                {
                    Id = invocationDescriptor.Id,
                    Error = $"Unknown hub method '{invocationDescriptor.Method}'"
                };

                _logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method);
            }

            // TODO: Pool memory
            var outStream = new MemoryStream();
            await invocationAdapter.WriteMessageAsync(result, outStream);

            var buffer = ReadableBuffer.Create(outStream.ToArray()).Preserve();
            var outMessage = new Message(buffer, Format.Text, endOfMessage: true);

            var written = false;
            while (!written && await connection.Transport.Output.WaitToWriteAsync())
            {
                written = connection.Transport.Output.TryWrite(outMessage);
            }

            if (!written)
            {
                // The output channel closed before the message was accepted; nobody else owns it
                // now, so release the preserved buffer here instead of leaking it.
                outMessage.Dispose();
            }
        }

        // Wires the per-call hub instance to the shared client context, the calling connection and
        // a group manager for that connection.
        private void InitializeHub(THub hub, Connection connection)
        {
            hub.Clients = _hubContext.Clients;
            hub.Context = new HubCallerContext(connection);
            hub.Groups = new GroupManager<THub>(connection, _lifetimeManager);
        }

        // Binds every method declared on THub that passes IsHubMethod, recording its parameter types
        // and a callback that invokes it. Names are compared case-insensitively, so overloads and
        // names differing only by case are rejected.
        private void DiscoverHubMethods()
        {
            var type = typeof(THub);

            foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => IsHubMethod(m)))
            {
                var methodName = methodInfo.Name;

                if (_callbacks.ContainsKey(methodName))
                {
                    throw new NotSupportedException($"Duplicate definitions of '{methodInfo.Name}'. Overloading is not supported.");
                }

                var parameters = methodInfo.GetParameters();
                _paramTypes[methodName] = parameters.Select(p => p.ParameterType).ToArray();

                if (_logger.IsEnabled(LogLevel.Debug))
                {
                    _logger.LogDebug("Hub method '{methodName}' is bound", methodName);
                }

                var boundMethod = methodInfo;
                _callbacks[methodName] = (connection, invocationDescriptor) =>
                    InvokeHubMethodAsync(boundMethod, connection, invocationDescriptor);
            }
        }

        // Invokes one hub method on a fresh scoped hub instance. Failures thrown by the method (either
        // synchronously or from its returned Task) are logged and reported through the result's Error;
        // they are not rethrown.
        private async Task<InvocationResultDescriptor> InvokeHubMethodAsync(MethodInfo methodInfo, Connection connection, InvocationDescriptor invocationDescriptor)
        {
            var invocationResult = new InvocationResultDescriptor
            {
                Id = invocationDescriptor.Id
            };

            using (var scope = _serviceScopeFactory.CreateScope())
            {
                var hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub, TClient>>();
                var hub = hubActivator.Create();

                try
                {
                    InitializeHub(hub, connection);

                    var result = methodInfo.Invoke(hub, invocationDescriptor.Arguments);
                    var resultTask = result as Task;
                    if (resultTask != null)
                    {
                        await resultTask;

                        // Only Task<T> carries a value; read it reflectively since T is not known statically.
                        if (methodInfo.ReturnType.GetTypeInfo().IsGenericType)
                        {
                            var property = resultTask.GetType().GetProperty("Result");
                            invocationResult.Result = property?.GetValue(resultTask);
                        }
                    }
                    else
                    {
                        invocationResult.Result = result;
                    }
                }
                catch (TargetInvocationException ex)
                {
                    // Reflection wraps exceptions thrown synchronously by the method; report the real one.
                    _logger.LogError(0, ex, "Failed to invoke hub method");
                    invocationResult.Error = ex.InnerException?.Message ?? ex.Message;
                }
                catch (Exception ex)
                {
                    _logger.LogError(0, ex, "Failed to invoke hub method");
                    invocationResult.Error = ex.Message;
                }
                finally
                {
                    hubActivator.Release(hub);
                }
            }

            return invocationResult;
        }

        // A hub method is a public, non-special-name method that does not originate from Hub<T> itself.
        private static bool IsHubMethod(MethodInfo methodInfo)
        {
            // TODO: Add more checks
            if (!methodInfo.IsPublic || methodInfo.IsSpecialName)
            {
                return false;
            }

            var baseDefinition = methodInfo.GetBaseDefinition().DeclaringType;
            var baseType = baseDefinition.GetTypeInfo().IsGenericType ? baseDefinition.GetGenericTypeDefinition() : baseDefinition;
            if (typeof(Hub<>) == baseType)
            {
                return false;
            }

            return true;
        }

        Type IInvocationBinder.GetReturnType(string invocationId)
        {
            return typeof(object);
        }

        // Returns the bound parameter types for a method name, or an empty array when it is unknown.
        // Uses the same case-insensitive lookup as the callback table.
        Type[] IInvocationBinder.GetParameterTypes(string methodName)
        {
            Type[] types;
            if (!_paramTypes.TryGetValue(methodName, out types))
            {
                return Type.EmptyTypes;
            }
            return types;
        }
    }
}