// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Newtonsoft.Json;

namespace Microsoft.AspNetCore.SignalR.Client
{
    /// <summary>
    /// Client-side hub connection: sends hub method invocations over an <see cref="IConnection"/>,
    /// tracks them until a completion arrives, and dispatches invocations received from the other
    /// side to handlers registered through <see cref="On"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the generic type arguments in this file were reconstructed from how each member
    /// is used (they were missing from the text this was restored from) - confirm against the
    /// declarations of <c>IConnection</c> and <c>InvocationRequest</c>.
    /// </remarks>
    public class HubConnection
    {
        private readonly ILoggerFactory _loggerFactory;
        private readonly ILogger _logger;
        private readonly IConnection _connection;
        private readonly IHubProtocol _protocol;
        private readonly HubBinder _binder;

        // Guards every access to _pendingCalls.
        private readonly object _pendingCallsLock = new object();

        // Cancelled by Shutdown; once cancelled no further invocations may be registered.
        private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();

        // Outstanding invocations keyed by invocation ID.
        private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();

        // Handlers registered via On(), keyed by method name.
        private readonly ConcurrentDictionary<string, InvocationHandler> _handlers = new ConcurrentDictionary<string, InvocationHandler>();

        private int _nextId = 0;

        /// <summary>Forwards subscriptions to the underlying connection's <c>Connected</c> event.</summary>
        public event Action Connected
        {
            add { _connection.Connected += value; }
            remove { _connection.Connected -= value; }
        }

        /// <summary>Forwards subscriptions to the underlying connection's <c>Closed</c> event.</summary>
        public event Action<Exception> Closed
        {
            add { _connection.Closed += value; }
            remove { _connection.Closed -= value; }
        }

        /// <summary>Creates a hub connection using the JSON hub protocol and no logging.</summary>
        /// <param name="connection">The transport connection to communicate over.</param>
        public HubConnection(IConnection connection)
            : this(connection, new JsonHubProtocol(new JsonSerializer()), null)
        {
        }

        // These are only really needed for tests now...
        /// <summary>Creates a hub connection using the JSON hub protocol.</summary>
        /// <param name="connection">The transport connection to communicate over.</param>
        /// <param name="loggerFactory">Logger factory; <c>null</c> disables logging.</param>
        public HubConnection(IConnection connection, ILoggerFactory loggerFactory)
            : this(connection, new JsonHubProtocol(new JsonSerializer()), loggerFactory)
        {
        }

        /// <summary>Creates a hub connection using the given protocol.</summary>
        /// <param name="connection">The transport connection to communicate over.</param>
        /// <param name="protocol">The hub protocol used to serialize and parse messages.</param>
        /// <param name="loggerFactory">Logger factory; <c>null</c> falls back to <see cref="NullLoggerFactory"/>.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="connection"/> or <paramref name="protocol"/> is <c>null</c>.
        /// </exception>
        public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
        {
            if (connection == null)
            {
                throw new ArgumentNullException(nameof(connection));
            }

            if (protocol == null)
            {
                throw new ArgumentNullException(nameof(protocol));
            }

            _connection = connection;
            _binder = new HubBinder(this);
            _protocol = protocol;
            _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
            _logger = _loggerFactory.CreateLogger<HubConnection>();

            _connection.Received += OnDataReceived;
            _connection.Closed += Shutdown;
        }

        /// <summary>Starts the underlying connection.</summary>
        public async Task StartAsync()
        {
            await _connection.StartAsync();
        }

        /// <summary>Disposes the underlying connection.</summary>
        public async Task DisposeAsync()
        {
            await _connection.DisposeAsync();
        }

        // TODO: Client return values/tasks?
        /// <summary>
        /// Registers <paramref name="handler"/> for invocations targeting <paramref name="methodName"/>,
        /// replacing any handler previously registered under that name.
        /// </summary>
        /// <param name="methodName">The method name to handle.</param>
        /// <param name="parameterTypes">The types the received arguments are bound to.</param>
        /// <param name="handler">Callback invoked with the received arguments.</param>
        public void On(string methodName, Type[] parameterTypes, Action<object[]> handler)
        {
            // Last registration wins.
            _handlers[methodName] = new InvocationHandler(parameterTypes, handler);
        }

        /// <summary>Invokes a streaming hub method and returns the channel its items are written to.</summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="returnType">The type of each streamed item.</param>
        /// <param name="cancellationToken">Token associated with the invocation request.</param>
        /// <param name="args">The arguments to send.</param>
        /// <exception cref="InvalidOperationException">The connection has been terminated.</exception>
        // NOTE(review): the channel's element type is assumed to be object (the item type is only
        // known at runtime through returnType) - confirm against InvocationRequest.Stream.
        public ReadableChannel<object> Stream(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
        {
            var irq = InvocationRequest.Stream(cancellationToken, returnType, GetNextId(), _loggerFactory, out var channel);
            InvokeCore(methodName, irq, args);
            return channel;
        }

        /// <summary>Invokes a hub method and returns the task produced for the invocation request.</summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="returnType">The expected result type.</param>
        /// <param name="cancellationToken">Token associated with the invocation request.</param>
        /// <param name="args">The arguments to send.</param>
        /// <exception cref="InvalidOperationException">The connection has been terminated.</exception>
        // NOTE(review): the task's result type is assumed to be object (the result type is only
        // known at runtime through returnType) - confirm against InvocationRequest.Invoke.
        public Task<object> Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
        {
            var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task);
            InvokeCore(methodName, irq, args);
            return task;
        }

        // Registers the request for tracking and kicks off the send without waiting for it.
        private void InvokeCore(string methodName, InvocationRequest irq, object[] args)
        {
            ThrowIfConnectionTerminated();

            _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, irq.ResultType.AssemblyQualifiedName, args.Length);

            // Create an invocation descriptor. Client invocations are always blocking
            var invocationMessage = new InvocationMessage(irq.InvocationId, nonBlocking: false, target: methodName, arguments: args);

            // I just want an excuse to use 'irq' as a variable name...
            _logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId);

            AddInvocation(irq);

            // Trace the full invocation, but only if that logging level is enabled (because building the args list is a bit slow)
            if (_logger.IsEnabled(LogLevel.Trace))
            {
                var argsList = DescribeArgumentTypes(args);
                _logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, irq.ResultType.FullName, methodName, argsList);
            }

            // We don't need to wait for this to complete. It will signal back to the invocation request.
            _ = SendInvocation(invocationMessage, irq);
        }

        // Serializes and sends the message; on failure fails the request and stops tracking it.
        private async Task SendInvocation(InvocationMessage invocationMessage, InvocationRequest irq)
        {
            try
            {
                var payload = _protocol.WriteToArray(invocationMessage);

                _logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId);
                await _connection.SendAsync(payload, irq.CancellationToken);
                _logger.LogInformation("Sending Invocation '{invocationId}' complete", invocationMessage.InvocationId);
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Sending Invocation '{invocationId}' failed", invocationMessage.InvocationId);
                irq.Fail(ex);
                TryRemoveInvocation(invocationMessage.InvocationId, out _);
            }
        }

        // Parses a received payload and routes each message by type.
        private void OnDataReceived(byte[] data)
        {
            if (!_protocol.TryParseMessages(data, _binder, out var messages))
            {
                return;
            }

            foreach (var message in messages)
            {
                InvocationRequest irq;
                switch (message)
                {
                    case InvocationMessage invocation:
                        if (_logger.IsEnabled(LogLevel.Trace))
                        {
                            var argsList = DescribeArgumentTypes(invocation.Arguments);
                            _logger.LogTrace("Received Invocation '{invocationId}': {methodName}({args})", invocation.InvocationId, invocation.Target, argsList);
                        }
                        DispatchInvocation(invocation, _connectionActive.Token);
                        break;

                    case CompletionMessage completion:
                        if (!TryRemoveInvocation(completion.InvocationId, out irq))
                        {
                            _logger.LogWarning("Dropped unsolicited Completion message for invocation '{invocationId}'", completion.InvocationId);

                            // Skip only this message; the rest of the batch must still be processed.
                            continue;
                        }

                        try
                        {
                            DispatchInvocationCompletion(completion, irq);
                        }
                        finally
                        {
                            // The request is no longer tracked, so dispose it even if dispatch throws.
                            irq.Dispose();
                        }
                        break;

                    case StreamItemMessage streamItem:
                        // The invocation stays registered: more items (and a completion) may follow.
                        if (!TryGetInvocation(streamItem.InvocationId, out irq))
                        {
                            _logger.LogWarning("Dropped unsolicited Stream Item message for invocation '{invocationId}'", streamItem.InvocationId);

                            // Skip only this message; the rest of the batch must still be processed.
                            continue;
                        }
                        DispatchInvocationStreamItemAsync(streamItem, irq);
                        break;

                    default:
                        throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}");
                }
            }
        }

        // Handler for the connection's Closed event: marks the connection inactive and
        // fails (when an error is supplied) and disposes every outstanding invocation.
        private void Shutdown(Exception ex = null)
        {
            _logger.LogTrace("Shutting down connection");
            if (ex != null)
            {
                _logger.LogError(ex, "Connection is shutting down due to an error");
            }

            lock (_pendingCallsLock)
            {
                // We cancel inside the lock to make sure everyone who was part-way through registering an invocation
                // completes. This also ensures that nobody will add things to _pendingCalls after we leave this block
                // because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock)
                _connectionActive.Cancel();

                foreach (var outstandingCall in _pendingCalls.Values)
                {
                    _logger.LogTrace("Removing pending call {invocationId}", outstandingCall.InvocationId);
                    if (ex != null)
                    {
                        outstandingCall.Fail(ex);
                    }
                    outstandingCall.Dispose();
                }
                _pendingCalls.Clear();
            }
        }

        // Invokes the handler registered for the invocation's target, if any.
        private void DispatchInvocation(InvocationMessage invocation, CancellationToken cancellationToken)
        {
            // Find the handler
            if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler))
            {
                _logger.LogWarning("Failed to find handler for '{target}' method", invocation.Target);
                return;
            }

            // TODO: Return values
            // TODO: Dispatch to a sync context to ensure we aren't blocking this loop.
            handler.Handler(invocation.Arguments);
        }

        // This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel
        // and there's nobody to actually wait for us to finish.
        private async void DispatchInvocationStreamItemAsync(StreamItemMessage streamItem, InvocationRequest irq)
        {
            _logger.LogTrace("Received StreamItem for Invocation #{invocationId}", streamItem.InvocationId);

            if (irq.CancellationToken.IsCancellationRequested)
            {
                _logger.LogTrace("Canceling dispatch of StreamItem message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId);
            }
            else if (!await irq.StreamItem(streamItem.Item))
            {
                _logger.LogWarning("Invocation {invocationId} received stream item after channel was closed.", irq.InvocationId);
            }
        }

        // Completes or fails the request according to the completion message, unless it was cancelled.
        private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq)
        {
            _logger.LogTrace("Received Completion for Invocation #{invocationId}", completion.InvocationId);

            if (irq.CancellationToken.IsCancellationRequested)
            {
                _logger.LogTrace("Cancelling dispatch of Completion message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId);
                return;
            }

            if (!string.IsNullOrEmpty(completion.Error))
            {
                irq.Fail(new HubException(completion.Error));
            }
            else
            {
                irq.Complete(completion.Result);
            }
        }

        // Throws InvalidOperationException once Shutdown has cancelled _connectionActive.
        private void ThrowIfConnectionTerminated()
        {
            if (_connectionActive.Token.IsCancellationRequested)
            {
                _logger.LogError("Invoke was called after the connection was terminated");
                throw new InvalidOperationException("Connection has been terminated.");
            }
        }

        private string GetNextId() => Interlocked.Increment(ref _nextId).ToString();

        // Formats the runtime type names of the arguments for trace logging; null arguments are
        // rendered as "(null)" instead of throwing.
        private static string DescribeArgumentTypes(IEnumerable<object> arguments)
        {
            return string.Join(", ", arguments.Select(a => a?.GetType().FullName ?? "(null)"));
        }

        // Starts tracking the request; throws if the connection is terminated or the ID is already tracked.
        private void AddInvocation(InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated();
                if (_pendingCalls.ContainsKey(irq.InvocationId))
                {
                    _logger.LogCritical("Invocation ID '{invocationId}' is already in use.", irq.InvocationId);
                    throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use.");
                }

                _pendingCalls.Add(irq.InvocationId, irq);
            }
        }

        // Looks up a tracked request without removing it; throws if the connection is terminated.
        private bool TryGetInvocation(string invocationId, out InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated();
                return _pendingCalls.TryGetValue(invocationId, out irq);
            }
        }

        // Looks up a tracked request and stops tracking it; throws if the connection is terminated.
        private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated();
                if (!_pendingCalls.TryGetValue(invocationId, out irq))
                {
                    return false;
                }

                _pendingCalls.Remove(invocationId);
                return true;
            }
        }

        // Supplies the protocol parser with the types needed to bind incoming messages.
        private class HubBinder : IInvocationBinder
        {
            private readonly HubConnection _connection;

            public HubBinder(HubConnection connection)
            {
                _connection = connection;
            }

            // Returns the result type of the pending invocation, or null when the ID is not tracked.
            public Type GetReturnType(string invocationId)
            {
                InvocationRequest irq;
                bool found;

                // _pendingCalls is a plain Dictionary that other threads mutate under this lock,
                // so the read has to take it as well.
                lock (_connection._pendingCallsLock)
                {
                    found = _connection._pendingCalls.TryGetValue(invocationId, out irq);
                }

                if (!found)
                {
                    _connection._logger.LogError("Unsolicited response received for invocation '{invocationId}'", invocationId);
                    return null;
                }

                return irq.ResultType;
            }

            // Returns the parameter types of the registered handler, or an empty array when none is registered.
            public Type[] GetParameterTypes(string methodName)
            {
                if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
                {
                    _connection._logger.LogWarning("Failed to find handler for '{target}' method", methodName);
                    return Type.EmptyTypes;
                }

                return handler.ParameterTypes;
            }
        }

        // A registered callback together with the types its arguments are bound to.
        private struct InvocationHandler
        {
            public Action<object[]> Handler { get; }
            public Type[] ParameterTypes { get; }

            public InvocationHandler(Type[] parameterTypes, Action<object[]> handler)
            {
                Handler = handler;
                ParameterTypes = parameterTypes;
            }
        }
    }
}