// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Newtonsoft.Json;

namespace Microsoft.AspNetCore.SignalR.Client
{
    /// <summary>
    /// Client-side hub connection. Sends invocations over an <see cref="IConnection"/> using an
    /// <see cref="IHubProtocol"/>, tracks them until a completion message arrives, and dispatches
    /// invocations received from the other side to handlers registered with <see cref="On"/>.
    /// </summary>
    public class HubConnection
    {
        private readonly ILoggerFactory _loggerFactory;
        private readonly ILogger _logger;
        private readonly IConnection _connection;
        private readonly IHubProtocol _protocol;
        private readonly HubBinder _binder;

        // Only set when StartAsync created the client itself; disposed in DisposeAsync.
        private HttpClient _httpClient;

        // Guards every read and write of _pendingCalls.
        private readonly object _pendingCallsLock = new object();

        // Cancelled by Shutdown; once cancelled no further invocations may be registered.
        private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();

        // Outstanding invocations keyed by invocation ID.
        private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();

        // Handlers registered through On, keyed by method name.
        private readonly ConcurrentDictionary<string, InvocationHandler> _handlers = new ConcurrentDictionary<string, InvocationHandler>();

        private int _nextId = 0;

        /// <summary>Forwards subscriptions to the underlying connection's <c>Connected</c> event.</summary>
        public event Action Connected
        {
            add { _connection.Connected += value; }
            remove { _connection.Connected -= value; }
        }

        /// <summary>Forwards subscriptions to the underlying connection's <c>Closed</c> event.</summary>
        public event Action<Exception> Closed
        {
            add { _connection.Closed += value; }
            remove { _connection.Closed -= value; }
        }

        /// <summary>Creates a hub connection to <paramref name="url"/> using the JSON hub protocol and no logging.</summary>
        public HubConnection(Uri url)
            : this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), null)
        {
        }

        /// <summary>Creates a hub connection to <paramref name="url"/> using the JSON hub protocol.</summary>
        public HubConnection(Uri url, ILoggerFactory loggerFactory)
            : this(new Connection(url), new JsonHubProtocol(new JsonSerializer()), loggerFactory)
        {
        }

        // These are only really needed for tests now...
        /// <summary>Creates a hub connection over an existing <paramref name="connection"/> using the JSON hub protocol.</summary>
        public HubConnection(IConnection connection, ILoggerFactory loggerFactory)
            : this(connection, new JsonHubProtocol(new JsonSerializer()), loggerFactory)
        {
        }

        /// <summary>
        /// Creates a hub connection over an existing <paramref name="connection"/> and <paramref name="protocol"/>.
        /// </summary>
        /// <param name="connection">The transport-level connection; must not be null.</param>
        /// <param name="protocol">The hub protocol used to write and parse messages; must not be null.</param>
        /// <param name="loggerFactory">Logger factory; when null, <see cref="NullLoggerFactory.Instance"/> is used.</param>
        /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="protocol"/> is null.</exception>
        public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
        {
            if (connection == null)
            {
                throw new ArgumentNullException(nameof(connection));
            }

            if (protocol == null)
            {
                throw new ArgumentNullException(nameof(protocol));
            }

            _connection = connection;
            _binder = new HubBinder(this);
            _protocol = protocol;
            _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
            _logger = _loggerFactory.CreateLogger<HubConnection>();
            _connection.Received += OnDataReceived;
            _connection.Closed += Shutdown;
        }

        /// <summary>Starts the connection with <see cref="TransportType.All"/> and an internally created <see cref="HttpClient"/>.</summary>
        public Task StartAsync() => StartAsync(TransportType.All, httpClient: null);

        /// <summary>Starts the connection with <see cref="TransportType.All"/> using the supplied <paramref name="httpClient"/>.</summary>
        public Task StartAsync(HttpClient httpClient) => StartAsync(TransportType.All, httpClient: httpClient);

        /// <summary>Starts the connection with the given <paramref name="transportType"/> and an internally created <see cref="HttpClient"/>.</summary>
        public Task StartAsync(TransportType transportType) => StartAsync(transportType, httpClient: null);

        /// <summary>
        /// Starts the underlying connection using a <see cref="DefaultTransportFactory"/> for <paramref name="transportType"/>.
        /// </summary>
        /// <param name="transportType">The transport type(s) passed to the default transport factory.</param>
        /// <param name="httpClient">The client to use; when null, a new one is created and later disposed by <see cref="DisposeAsync"/>.</param>
        public async Task StartAsync(TransportType transportType, HttpClient httpClient)
        {
            if (httpClient == null)
            {
                // We are creating the client so store it to be able to dispose
                _httpClient = httpClient = new HttpClient();
            }

            await _connection.StartAsync(new DefaultTransportFactory(transportType, _loggerFactory, httpClient), httpClient);
        }

        /// <summary>Starts the underlying connection with a caller-supplied transport factory and HTTP client.</summary>
        public async Task StartAsync(ITransportFactory transportFactory, HttpClient httpClient)
        {
            await _connection.StartAsync(transportFactory, httpClient);
        }

        /// <summary>Disposes the underlying connection and, if this instance created one, its <see cref="HttpClient"/>.</summary>
        public async Task DisposeAsync()
        {
            await _connection.DisposeAsync();
            _httpClient?.Dispose();
        }

        // TODO: Client return values/tasks?
        /// <summary>
        /// Registers <paramref name="handler"/> for invocations targeting <paramref name="methodName"/>,
        /// replacing any handler previously registered under the same name.
        /// </summary>
        /// <param name="methodName">The target name to handle.</param>
        /// <param name="parameterTypes">The parameter types reported to the protocol binder for this method.</param>
        /// <param name="handler">The callback invoked with the received arguments.</param>
        public void On(string methodName, Type[] parameterTypes, Action<object[]> handler)
        {
            var invocationHandler = new InvocationHandler(parameterTypes, handler);
            _handlers.AddOrUpdate(methodName, invocationHandler, (_, __) => invocationHandler);
        }

        /// <summary>
        /// Sends an invocation of <paramref name="methodName"/> and waits for its completion.
        /// </summary>
        /// <param name="methodName">The target method name.</param>
        /// <param name="returnType">The type reported to the protocol binder as the expected result type.</param>
        /// <param name="cancellationToken">Cancels the wait for the result and is passed to the send operation.</param>
        /// <param name="args">The invocation arguments; individual arguments may be null.</param>
        /// <returns>
        /// The result carried by the completion message. The task faults with <see cref="HubException"/> when the
        /// completion carries an error, faults with the send exception when sending fails, and is cancelled when
        /// <paramref name="cancellationToken"/> fires or the connection shuts down without an error.
        /// </returns>
        /// <exception cref="InvalidOperationException">The connection has already been terminated.</exception>
        public async Task<object> Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args)
        {
            ThrowIfConnectionTerminated();
            _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, returnType.AssemblyQualifiedName, args.Length);

            // Create an invocation descriptor. Client invocations are always blocking
            var invocationMessage = new InvocationMessage(GetNextId(), nonBlocking: false, target: methodName, arguments: args);

            // I just want an excuse to use 'irq' as a variable name...
            _logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId);
            var irq = new InvocationRequest(cancellationToken, returnType, invocationMessage.InvocationId, _loggerFactory);

            AddInvocation(irq);

            // Trace the full invocation, but only if that logging level is enabled (because building the args list is a bit slow)
            if (_logger.IsEnabled(LogLevel.Trace))
            {
                // Arguments may legitimately be null, so guard the GetType() call.
                var argsList = string.Join(", ", args.Select(a => a?.GetType().FullName ?? "(null)"));
                _logger.LogTrace("Issuing Invocation '{invocationId}': {returnType} {methodName}({args})", invocationMessage.InvocationId, returnType.FullName, methodName, argsList);
            }

            try
            {
                var payload = await _protocol.WriteToArrayAsync(invocationMessage);

                _logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId);
                await _connection.SendAsync(payload, _protocol.MessageType, cancellationToken);
                _logger.LogInformation("Sending Invocation '{invocationId}' complete", invocationMessage.InvocationId);
            }
            catch (Exception ex)
            {
                _logger.LogError(0, ex, "Sending Invocation '{invocationId}' failed", invocationMessage.InvocationId);
                irq.Fail(ex);
                TryRemoveInvocation(invocationMessage.InvocationId, out _);

                // The request is no longer tracked, so nobody else will dispose it; release the
                // cancellation registration here. The completion is already faulted, so Dispose's
                // TrySetCanceled has no effect on the outcome.
                irq.Dispose();
            }

            // Return the completion task. It will be completed by OnDataReceived when the response is received.
            return await irq.Completion;
        }

        /// <summary>
        /// Handles data received from the connection: parses it with the hub protocol and routes the
        /// resulting message to a handler or to the pending invocation it belongs to.
        /// </summary>
        private void OnDataReceived(byte[] data, MessageType messageType)
        {
            var message = _protocol.ParseMessage(data, _binder);

            InvocationRequest irq;
            switch (message)
            {
                case InvocationMessage invocation:
                    if (_logger.IsEnabled(LogLevel.Trace))
                    {
                        // Arguments may legitimately be null, so guard the GetType() call.
                        var argsList = string.Join(", ", invocation.Arguments.Select(a => a?.GetType().FullName ?? "(null)"));
                        _logger.LogTrace("Received Invocation '{invocationId}': {methodName}({args})", invocation.InvocationId, invocation.Target, argsList);
                    }
                    DispatchInvocation(invocation, _connectionActive.Token);
                    break;
                case CompletionMessage completion:
                    if (!TryRemoveInvocation(completion.InvocationId, out irq))
                    {
                        _logger.LogWarning("Dropped unsolicited Completion message for invocation '{invocationId}'", completion.InvocationId);
                        return;
                    }
                    DispatchInvocationCompletion(completion, irq);
                    irq.Dispose();
                    break;
                case StreamItemMessage streamItem:
                    // Complete the invocation with an error, we don't support streaming (yet)
                    if (!TryRemoveInvocation(streamItem.InvocationId, out irq))
                    {
                        _logger.LogWarning("Dropped unsolicited Stream Item message for invocation '{invocationId}'", streamItem.InvocationId);
                        return;
                    }
                    irq.Fail(new NotSupportedException("Streaming method results are not supported"));

                    // The request has been removed from tracking, so dispose it here (as the
                    // completion path does) to release its cancellation registration.
                    irq.Dispose();
                    break;
                default:
                    throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}");
            }
        }

        /// <summary>
        /// Marks the connection as terminated and completes every pending invocation: failed with
        /// <paramref name="ex"/> when one is supplied, otherwise cancelled by disposal.
        /// </summary>
        private void Shutdown(Exception ex = null)
        {
            _logger.LogTrace("Shutting down connection");
            if (ex != null)
            {
                _logger.LogError(ex, "Connection is shutting down due to an error");
            }

            lock (_pendingCallsLock)
            {
                // We cancel inside the lock to make sure everyone who was part-way through registering an invocation
                // completes. This also ensures that nobody will add things to _pendingCalls after we leave this block
                // because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock)
                _connectionActive.Cancel();

                foreach (var outstandingCall in _pendingCalls.Values)
                {
                    _logger.LogTrace("Removing pending call {invocationId}", outstandingCall.InvocationId);
                    if (ex != null)
                    {
                        outstandingCall.Fail(ex);
                    }
                    outstandingCall.Dispose();
                }
                _pendingCalls.Clear();
            }
        }

        /// <summary>Invokes the handler registered for the invocation's target, or logs a warning when none exists.</summary>
        private void DispatchInvocation(InvocationMessage invocation, CancellationToken cancellationToken)
        {
            // Find the handler
            if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler))
            {
                _logger.LogWarning("Failed to find handler for '{target}' method", invocation.Target);
                return;
            }

            // TODO: Return values
            // TODO: Dispatch to a sync context to ensure we aren't blocking this loop.
            handler.Handler(invocation.Arguments);
        }

        /// <summary>
        /// Completes or fails <paramref name="irq"/> according to <paramref name="completion"/>, unless the
        /// invocation's cancellation token has already been signalled.
        /// </summary>
        private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq)
        {
            _logger.LogTrace("Received Completion for Invocation #{invocationId}", completion.InvocationId);

            if (irq.CancellationToken.IsCancellationRequested)
            {
                _logger.LogTrace("Cancelling dispatch of Completion message for Invocation {invocationId}. The invocation was cancelled.", irq.InvocationId);
            }
            else
            {
                if (!string.IsNullOrEmpty(completion.Error))
                {
                    irq.Fail(new HubException(completion.Error));
                }
                else
                {
                    irq.Complete(completion.Result);
                }
            }
        }

        /// <summary>Logs and throws <see cref="InvalidOperationException"/> when the connection has been shut down.</summary>
        private void ThrowIfConnectionTerminated()
        {
            if (_connectionActive.Token.IsCancellationRequested)
            {
                _logger.LogError("Invoke was called after the connection was terminated");
                throw new InvalidOperationException("Connection has been terminated.");
            }
        }

        /// <summary>Returns the next invocation ID from an atomically incremented counter.</summary>
        private string GetNextId() => Interlocked.Increment(ref _nextId).ToString();

        /// <summary>Registers a pending invocation; throws if the connection is terminated or the ID is already in use.</summary>
        private void AddInvocation(InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated();
                if (_pendingCalls.ContainsKey(irq.InvocationId))
                {
                    _logger.LogCritical("Invocation ID '{invocationId}' is already in use.", irq.InvocationId);
                    throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use.");
                }
                else
                {
                    _pendingCalls.Add(irq.InvocationId, irq);
                }
            }
        }

        /// <summary>Looks up a pending invocation without removing it; throws if the connection is terminated.</summary>
        private bool TryGetInvocation(string invocationId, out InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated();
                return _pendingCalls.TryGetValue(invocationId, out irq);
            }
        }

        /// <summary>Looks up a pending invocation and removes it when found; throws if the connection is terminated.</summary>
        private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated();
                if (_pendingCalls.TryGetValue(invocationId, out irq))
                {
                    _pendingCalls.Remove(invocationId);
                    return true;
                }
                else
                {
                    return false;
                }
            }
        }

        /// <summary>
        /// Supplies the protocol parser with the expected result type of a pending invocation and the
        /// parameter types of a registered handler.
        /// </summary>
        private class HubBinder : IInvocationBinder
        {
            private HubConnection _connection;

            public HubBinder(HubConnection connection)
            {
                _connection = connection;
            }

            /// <summary>Returns the result type of the pending invocation, or null when the ID is not pending.</summary>
            public Type GetReturnType(string invocationId)
            {
                InvocationRequest irq;
                bool found;

                // _pendingCalls is a plain Dictionary mutated under _pendingCallsLock by other members,
                // so the read has to take the same lock.
                lock (_connection._pendingCallsLock)
                {
                    found = _connection._pendingCalls.TryGetValue(invocationId, out irq);
                }

                if (!found)
                {
                    _connection._logger.LogError("Unsolicited response received for invocation '{invocationId}'", invocationId);
                    return null;
                }
                return irq.ResultType;
            }

            /// <summary>Returns the parameter types registered for the method, or an empty array when no handler exists.</summary>
            public Type[] GetParameterTypes(string methodName)
            {
                if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
                {
                    _connection._logger.LogWarning("Failed to find handler for '{target}' method", methodName);
                    return Type.EmptyTypes;
                }
                return handler.ParameterTypes;
            }
        }

        /// <summary>Pairs a handler callback with the parameter types it was registered with.</summary>
        private struct InvocationHandler
        {
            public Action<object[]> Handler { get; }
            public Type[] ParameterTypes { get; }

            public InvocationHandler(Type[] parameterTypes, Action<object[]> handler)
            {
                Handler = handler;
                ParameterTypes = parameterTypes;
            }
        }

        /// <summary>
        /// Tracks one outstanding invocation: its expected result type, its cancellation token, and the
        /// task source that is completed, failed or cancelled when the outcome is known.
        /// </summary>
        private class InvocationRequest : IDisposable
        {
            private readonly TaskCompletionSource<object> _completionSource = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
            private readonly CancellationTokenRegistration _cancellationTokenRegistration;
            private readonly ILogger _logger;

            public Type ResultType { get; }
            public CancellationToken CancellationToken { get; }
            public string InvocationId { get; }

            public Task<object> Completion => _completionSource.Task;

            public InvocationRequest(CancellationToken cancellationToken, Type resultType, string invocationId, ILoggerFactory loggerFactory)
            {
                _logger = loggerFactory.CreateLogger<InvocationRequest>();

                // Cancelling the caller's token cancels the completion task.
                _cancellationTokenRegistration = cancellationToken.Register(() => _completionSource.TrySetCanceled());

                InvocationId = invocationId;
                CancellationToken = cancellationToken;
                ResultType = resultType;

                _logger.LogTrace("Invocation {invocationId} created", InvocationId);
            }

            /// <summary>Faults the completion task with <paramref name="exception"/> if it has not already completed.</summary>
            public void Fail(Exception exception)
            {
                _logger.LogTrace("Invocation {invocationId} marked as failed", InvocationId);
                _completionSource.TrySetException(exception);
            }

            /// <summary>Completes the completion task with <paramref name="result"/> if it has not already completed.</summary>
            public void Complete(object result)
            {
                _logger.LogTrace("Invocation {invocationId} marked as completed", InvocationId);
                _completionSource.TrySetResult(result);
            }

            /// <summary>Cancels the completion task if it is still pending and releases the cancellation registration.</summary>
            public void Dispose()
            {
                _logger.LogTrace("Invocation {invocationId} disposed", InvocationId);

                // Just in case it hasn't already been completed
                _completionSource.TrySetCanceled();

                _cancellationTokenRegistration.Dispose();
            }
        }
    }
}