// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Client.Internal;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Microsoft.AspNetCore.SignalR.Client
{
    /// <summary>
    /// Client side of a hub connection, layered on an <see cref="IConnection"/>. Outgoing invocations are
    /// serialized with an <see cref="IHubProtocol"/> and, when blocking, tracked by invocation ID until a
    /// completion arrives; incoming invocations are dispatched to handlers registered with <c>On</c>.
    /// </summary>
    public class HubConnection
    {
        private readonly ILoggerFactory _loggerFactory;
        private readonly ILogger _logger;
        private readonly IConnection _connection;
        private readonly IHubProtocol _protocol;
        private readonly HubBinder _binder;

        // Created in StartAsyncCore, once the transfer mode actually granted by the transport is known.
        private HubProtocolReaderWriter _protocolReaderWriter;

        // Guards every access to _pendingCalls. Shutdown cancels _connectionActive while holding this lock,
        // and AddInvocation checks _connectionActive while holding it, so nothing can be registered after
        // Shutdown has drained the table.
        private readonly object _pendingCallsLock = new object();
        private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
        private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();

        // Each handler list is read and mutated under a lock on the list instance itself.
        private readonly ConcurrentDictionary<string, List<InvocationHandler>> _handlers = new ConcurrentDictionary<string, List<InvocationHandler>>();

        private int _nextId = 0;

        /// <summary>
        /// Subscriptions are forwarded directly to the underlying connection's <c>Closed</c> event.
        /// </summary>
        public event Func<Exception, Task> Closed
        {
            add { _connection.Closed += value; }
            remove { _connection.Closed -= value; }
        }

        /// <summary>
        /// Creates a hub connection over <paramref name="connection"/> that speaks <paramref name="protocol"/>.
        /// </summary>
        /// <param name="connection">The underlying connection. Must not be null.</param>
        /// <param name="protocol">The hub protocol used to read and write hub messages. Must not be null.</param>
        /// <param name="loggerFactory">Logger factory; <see cref="NullLoggerFactory.Instance"/> is used when null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="connection"/> or <paramref name="protocol"/> is null.</exception>
        public HubConnection(IConnection connection, IHubProtocol protocol, ILoggerFactory loggerFactory)
        {
            if (connection == null)
            {
                throw new ArgumentNullException(nameof(connection));
            }

            if (protocol == null)
            {
                throw new ArgumentNullException(nameof(protocol));
            }

            _connection = connection;
            _binder = new HubBinder(this);
            _protocol = protocol;
            _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
            _logger = _loggerFactory.CreateLogger<HubConnection>();
            _connection.OnReceived((data, state) => ((HubConnection)state).OnDataReceivedAsync(data), this);
            _connection.Closed += Shutdown;
        }

        /// <summary>
        /// Starts the underlying connection, sets up message encoding for the transfer mode the transport
        /// granted, and sends the protocol negotiation message.
        /// </summary>
        public async Task StartAsync() => await StartAsyncCore().ForceAsync();

        private async Task StartAsyncCore()
        {
            var transferModeFeature = _connection.Features.Get<ITransferModeFeature>();
            if (transferModeFeature == null)
            {
                transferModeFeature = new TransferModeFeature();
                _connection.Features.Set<ITransferModeFeature>(transferModeFeature);
            }

            var requestedTransferMode = _protocol.Type == ProtocolType.Binary
                ? TransferMode.Binary
                : TransferMode.Text;

            transferModeFeature.TransferMode = requestedTransferMode;
            await _connection.StartAsync();

            // Read the mode back: the transport may not honor the request (see GetDataEncoder).
            var actualTransferMode = transferModeFeature.TransferMode;

            _protocolReaderWriter = new HubProtocolReaderWriter(_protocol, GetDataEncoder(requestedTransferMode, actualTransferMode));

            _logger.HubProtocol(_protocol.Name);

            using (var memoryStream = new MemoryStream())
            {
                NegotiationProtocol.WriteMessage(new NegotiationMessage(_protocol.Name), memoryStream);
                await _connection.SendAsync(memoryStream.ToArray(), _connectionActive.Token);
            }
        }

        /// <summary>
        /// Chooses how hub payloads are encoded on the wire: Base64 when a binary protocol runs over a
        /// text-only transport, otherwise pass-through.
        /// </summary>
        private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, TransferMode actualTransferMode)
        {
            if (requestedTransferMode == TransferMode.Binary && actualTransferMode == TransferMode.Text)
            {
                // This is for instance for SSE which is a Text protocol and the user wants to use a binary
                // protocol so we need to encode messages.
                return new Base64Encoder();
            }

            Debug.Assert(requestedTransferMode == actualTransferMode, "All transports besides SSE are expected to support binary mode.");

            return new PassThroughEncoder();
        }

        /// <summary>
        /// Disposes the underlying connection.
        /// </summary>
        public async Task DisposeAsync() => await DisposeAsyncCore().ForceAsync();

        private async Task DisposeAsyncCore()
        {
            await _connection.DisposeAsync();
        }

        // TODO: Client return values/tasks?
        /// <summary>
        /// Registers <paramref name="handler"/> for server-to-client invocations of <paramref name="methodName"/>.
        /// </summary>
        /// <param name="methodName">The hub method name to handle.</param>
        /// <param name="parameterTypes">Types used to bind incoming arguments. When several handlers are registered
        /// for the same method, the first registered handler's types are used for binding.</param>
        /// <param name="handler">Callback invoked with the bound arguments and <paramref name="state"/>.</param>
        /// <param name="state">Opaque value passed back to <paramref name="handler"/>.</param>
        /// <returns>An <see cref="IDisposable"/> that removes this handler when disposed.</returns>
        public IDisposable On(string methodName, Type[] parameterTypes, Func<object[], object, Task> handler, object state)
        {
            var invocationHandler = new InvocationHandler(parameterTypes, handler, state);
            var invocationList = _handlers.AddOrUpdate(methodName, _ => new List<InvocationHandler> { invocationHandler },
                (_, invocations) =>
                {
                    lock (invocations)
                    {
                        invocations.Add(invocationHandler);
                    }
                    return invocations;
                });

            return new Subscription(invocationHandler, invocationList);
        }

        /// <summary>
        /// Invokes a streaming hub method and returns a channel that receives the streamed items.
        /// </summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="returnType">The type each stream item is bound to.</param>
        /// <param name="args">The invocation arguments.</param>
        /// <param name="cancellationToken">While the invocation is being sent, cancelling aborts the send. After it
        /// has been sent, cancelling sends a CancelInvocationMessage to the server and, if the stream is still
        /// pending, completes it locally.</param>
        public async Task<ReadableChannel<object>> StreamAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default)
        {
            return await StreamAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync();
        }

        private async Task<ReadableChannel<object>> StreamAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken)
        {
            var invokeCts = new CancellationTokenSource();
            var irq = InvocationRequest.Stream(invokeCts.Token, returnType, GetNextId(), _loggerFactory, this, out var channel);

            // After InvokeCore we don't want the irq cancellation token to be triggered.
            // The stream invocation will be canceled by the CancelInvocationMessage, connection closing, or channel finishing.
            using (cancellationToken.Register(state => ((CancellationTokenSource)state).Cancel(), invokeCts))
            {
                await InvokeCore(methodName, irq, args, nonBlocking: false);
            }

            if (cancellationToken.CanBeCanceled)
            {
                // All data is passed through 'state' so this callback does not capture locals.
                cancellationToken.Register(state =>
                {
                    var invocationReq = (InvocationRequest)state;
                    if (!invocationReq.HubConnection._connectionActive.IsCancellationRequested)
                    {
                        // Fire and forget, if it fails that means we aren't connected anymore.
                        _ = invocationReq.HubConnection.SendHubMessage(new CancelInvocationMessage(invocationReq.InvocationId), invocationReq);

                        if (invocationReq.HubConnection.TryRemoveInvocation(invocationReq.InvocationId, out _))
                        {
                            invocationReq.Complete(new StreamCompletionMessage(invocationReq.InvocationId, error: null));
                        }

                        invocationReq.Dispose();
                    }
                }, irq);
            }

            return channel;
        }

        /// <summary>
        /// Invokes a hub method and waits for the server's completion.
        /// </summary>
        /// <param name="methodName">The hub method to invoke.</param>
        /// <param name="returnType">The type the result is bound to.</param>
        /// <param name="args">The invocation arguments.</param>
        /// <param name="cancellationToken">Passed to the invocation request for this call.</param>
        /// <returns>The invocation result.</returns>
        public async Task<object> InvokeAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default) =>
            await InvokeAsyncCore(methodName, returnType, args, cancellationToken).ForceAsync();

        private async Task<object> InvokeAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken)
        {
            var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, this, out var task);
            await InvokeCore(methodName, irq, args, nonBlocking: false);
            return await task;
        }

        /// <summary>
        /// Sends a non-blocking (fire-and-forget) invocation. It is not tracked, so no result is delivered.
        /// </summary>
        /// <remarks>
        /// NOTE(review): send failures are logged and reported to the invocation request created here. That
        /// request's task is discarded, so such failures are generally not surfaced to the caller.
        /// </remarks>
        public async Task SendAsync(string methodName, object[] args, CancellationToken cancellationToken = default) =>
            await SendAsyncCore(methodName, args, cancellationToken).ForceAsync();

        private Task SendAsyncCore(string methodName, object[] args, CancellationToken cancellationToken)
        {
            var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, this, out _);
            return InvokeCore(methodName, irq, args, nonBlocking: true);
        }

        /// <summary>
        /// Builds the <see cref="InvocationMessage"/> for <paramref name="irq"/>, registers the request in
        /// _pendingCalls when it is blocking, and starts sending it.
        /// </summary>
        /// <exception cref="InvalidOperationException">The connection has been terminated, or the invocation ID is already registered.</exception>
        private Task InvokeCore(string methodName, InvocationRequest irq, object[] args, bool nonBlocking)
        {
            ThrowIfConnectionTerminated(irq.InvocationId);
            if (nonBlocking)
            {
                _logger.PreparingNonBlockingInvocation(irq.InvocationId, methodName, args.Length);
            }
            else
            {
                _logger.PreparingBlockingInvocation(irq.InvocationId, methodName, irq.ResultType.FullName, args.Length);
            }

            // Create an invocation descriptor. It is blocking unless nonBlocking is set.
            var invocationMessage = new InvocationMessage(irq.InvocationId, nonBlocking, methodName,
                argumentBindingException: null, arguments: args);

            // Fire-and-forget calls are not tracked. Blocking calls stay in _pendingCalls until a completion
            // message, a send failure or Shutdown removes them.
            if (!nonBlocking)
            {
                _logger.RegisterInvocation(invocationMessage.InvocationId);
                AddInvocation(irq);
            }

            // Trace the full invocation
            _logger.IssueInvocation(invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args);

            // We don't need to wait for this to complete. It will signal back to the invocation request.
            return SendHubMessage(invocationMessage, irq);
        }

        /// <summary>
        /// Serializes and sends <paramref name="hubMessage"/>. On failure the error is logged, reported to
        /// <paramref name="irq"/> via <c>Fail</c>, and the invocation is removed from _pendingCalls.
        /// </summary>
        private async Task SendHubMessage(HubMessage hubMessage, InvocationRequest irq)
        {
            try
            {
                var payload = _protocolReaderWriter.WriteMessage(hubMessage);
                _logger.SendInvocation(hubMessage.InvocationId);

                await _connection.SendAsync(payload, irq.CancellationToken);
                _logger.SendInvocationCompleted(hubMessage.InvocationId);
            }
            catch (Exception ex)
            {
                _logger.SendInvocationFailed(hubMessage.InvocationId, ex);
                irq.Fail(ex);
                TryRemoveInvocation(hubMessage.InvocationId, out _);
            }
        }

        /// <summary>
        /// Parses a received payload into hub messages and dispatches each one in order.
        /// </summary>
        private async Task OnDataReceivedAsync(byte[] data)
        {
            if (_protocolReaderWriter.ReadMessages(data, _binder, out var messages))
            {
                foreach (var message in messages)
                {
                    InvocationRequest irq;
                    switch (message)
                    {
                        case InvocationMessage invocation:
                            _logger.ReceivedInvocation(invocation.InvocationId, invocation.Target,
                                invocation.ArgumentBindingException != null ? null : invocation.Arguments);
                            await DispatchInvocationAsync(invocation, _connectionActive.Token);
                            break;
                        case CompletionMessage completion:
                            if (!TryRemoveInvocation(completion.InvocationId, out irq))
                            {
                                // 'break' leaves only the switch. A 'return' here would silently drop every
                                // remaining message in this batch.
                                _logger.DropCompletionMessage(completion.InvocationId);
                                break;
                            }
                            DispatchInvocationCompletion(completion, irq);
                            irq.Dispose();
                            break;
                        case StreamItemMessage streamItem:
                            // The invocation stays registered; only a stream completion removes it.
                            if (!TryGetInvocation(streamItem.InvocationId, out irq))
                            {
                                _logger.DropStreamMessage(streamItem.InvocationId);
                                break;
                            }
                            DispatchInvocationStreamItemAsync(streamItem, irq);
                            break;
                        case StreamCompletionMessage streamCompletion:
                            if (!TryRemoveInvocation(streamCompletion.InvocationId, out irq))
                            {
                                _logger.DropStreamCompletionMessage(streamCompletion.InvocationId);
                                break;
                            }
                            DispatchStreamCompletion(streamCompletion, irq);
                            irq.Dispose();
                            break;
                        default:
                            throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}");
                    }
                }
            }
        }

        /// <summary>
        /// Handler for the underlying connection's Closed event. It marks the connection as terminated, fails
        /// outstanding invocations with <paramref name="ex"/> when one is given, and disposes and clears them.
        /// </summary>
        private Task Shutdown(Exception ex = null)
        {
            _logger.ShutdownConnection();
            if (ex != null)
            {
                _logger.ShutdownWithError(ex);
            }

            lock (_pendingCallsLock)
            {
                // We cancel inside the lock to make sure everyone who was part-way through registering an invocation
                // completes. This also ensures that nobody will add things to _pendingCalls after we leave this block
                // because everything that adds to _pendingCalls checks _connectionActive first (inside the _pendingCallsLock)
                _connectionActive.Cancel();

                foreach (var outstandingCall in _pendingCalls.Values)
                {
                    _logger.RemoveInvocation(outstandingCall.InvocationId);
                    if (ex != null)
                    {
                        outstandingCall.Fail(ex);
                    }
                    outstandingCall.Dispose();
                }
                _pendingCalls.Clear();
            }

            return Task.CompletedTask;
        }

        /// <summary>
        /// Invokes every handler registered for the invocation's target, one after another. An exception
        /// from one handler is logged and does not stop the remaining handlers.
        /// </summary>
        private async Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken)
        {
            // Find the handler
            if (!_handlers.TryGetValue(invocation.Target, out var handlers))
            {
                _logger.MissingHandler(invocation.Target);
                return;
            }

            //TODO: Optimize this!
            // Snapshot the list under its lock so that handlers added or removed concurrently
            // (On / Subscription.Dispose) do not disturb this enumeration.
            InvocationHandler[] copiedHandlers;
            lock (handlers)
            {
                copiedHandlers = handlers.ToArray();
            }

            foreach (var handler in copiedHandlers)
            {
                try
                {
                    await handler.InvokeAsync(invocation.Arguments);
                }
                catch (Exception ex)
                {
                    _logger.ErrorInvokingClientSideMethod(invocation.Target, ex);
                }
            }
        }

        // This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel
        // and there's nobody to actually wait for us to finish.
        private async void DispatchInvocationStreamItemAsync(StreamItemMessage streamItem, InvocationRequest irq)
        {
            _logger.ReceivedStreamItem(streamItem.InvocationId);

            if (irq.CancellationToken.IsCancellationRequested)
            {
                _logger.CancelingStreamItem(irq.InvocationId);
            }
            else if (!await irq.StreamItem(streamItem.Item))
            {
                _logger.ReceivedStreamItemAfterClose(irq.InvocationId);
            }
        }

        /// <summary>
        /// Completes <paramref name="irq"/> with <paramref name="completion"/> unless the request was cancelled.
        /// </summary>
        private void DispatchInvocationCompletion(CompletionMessage completion, InvocationRequest irq)
        {
            _logger.ReceivedInvocationCompletion(completion.InvocationId);

            if (irq.CancellationToken.IsCancellationRequested)
            {
                _logger.CancelingInvocationCompletion(irq.InvocationId);
            }
            else
            {
                irq.Complete(completion);
            }
        }

        /// <summary>
        /// Completes the streaming request <paramref name="irq"/> unless it was cancelled.
        /// </summary>
        private void DispatchStreamCompletion(StreamCompletionMessage completion, InvocationRequest irq)
        {
            _logger.ReceivedStreamCompletion(completion.InvocationId);

            if (irq.CancellationToken.IsCancellationRequested)
            {
                _logger.CancelingStreamCompletion(irq.InvocationId);
            }
            else
            {
                irq.Complete(completion);
            }
        }

        /// <exception cref="InvalidOperationException">Shutdown has already run for this connection.</exception>
        private void ThrowIfConnectionTerminated(string invocationId)
        {
            if (_connectionActive.Token.IsCancellationRequested)
            {
                _logger.InvokeAfterTermination(invocationId);
                throw new InvalidOperationException("Connection has been terminated.");
            }
        }

        // Invocation IDs are a per-connection counter rendered as a string; Interlocked makes them unique across threads.
        private string GetNextId() => Interlocked.Increment(ref _nextId).ToString();

        private void AddInvocation(InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                // Checked inside the lock; see the comment in Shutdown.
                ThrowIfConnectionTerminated(irq.InvocationId);
                if (_pendingCalls.ContainsKey(irq.InvocationId))
                {
                    _logger.InvocationAlreadyInUse(irq.InvocationId);
                    throw new InvalidOperationException($"Invocation ID '{irq.InvocationId}' is already in use.");
                }
                else
                {
                    _pendingCalls.Add(irq.InvocationId, irq);
                }
            }
        }

        private bool TryGetInvocation(string invocationId, out InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated(invocationId);
                return _pendingCalls.TryGetValue(invocationId, out irq);
            }
        }

        private bool TryRemoveInvocation(string invocationId, out InvocationRequest irq)
        {
            lock (_pendingCallsLock)
            {
                ThrowIfConnectionTerminated(invocationId);
                if (_pendingCalls.TryGetValue(invocationId, out irq))
                {
                    _pendingCalls.Remove(invocationId);
                    return true;
                }
                else
                {
                    return false;
                }
            }
        }

        /// <summary>
        /// Returned by <c>On</c>; removes its handler from the method's handler list when disposed.
        /// </summary>
        private class Subscription : IDisposable
        {
            private readonly InvocationHandler _handler;
            private readonly List<InvocationHandler> _handlerList;

            public Subscription(InvocationHandler handler, List<InvocationHandler> handlerList)
            {
                _handler = handler;
                _handlerList = handlerList;
            }

            public void Dispose()
            {
                lock (_handlerList)
                {
                    _handlerList.Remove(_handler);
                }
            }
        }

        /// <summary>
        /// Tells the protocol reader which CLR types to bind incoming results and arguments to.
        /// </summary>
        private class HubBinder : IInvocationBinder
        {
            private readonly HubConnection _connection;

            public HubBinder(HubConnection connection)
            {
                _connection = connection;
            }

            /// <summary>
            /// Returns the result type of the pending invocation <paramref name="invocationId"/>, or null
            /// (after logging) when no such invocation is pending.
            /// </summary>
            public Type GetReturnType(string invocationId)
            {
                // _pendingCalls is a plain Dictionary that other threads mutate under _pendingCallsLock
                // (AddInvocation, TryRemoveInvocation, Shutdown), so it must be read under the same lock.
                bool found;
                InvocationRequest irq;
                lock (_connection._pendingCallsLock)
                {
                    found = _connection._pendingCalls.TryGetValue(invocationId, out irq);
                }

                if (!found)
                {
                    _connection._logger.ReceivedUnexpectedResponse(invocationId);
                    return null;
                }

                return irq.ResultType;
            }

            /// <summary>
            /// Returns the parameter types of the first handler registered for <paramref name="methodName"/>,
            /// or <see cref="Type.EmptyTypes"/> (after logging) when the method was never registered.
            /// </summary>
            /// <exception cref="InvalidOperationException">The method was registered but all its handlers have been removed.</exception>
            public Type[] GetParameterTypes(string methodName)
            {
                if (!_connection._handlers.TryGetValue(methodName, out var handlers))
                {
                    _connection._logger.MissingHandler(methodName);
                    return Type.EmptyTypes;
                }

                // We use the parameter types of the first handler
                lock (handlers)
                {
                    if (handlers.Count > 0)
                    {
                        return handlers[0].ParameterTypes;
                    }
                    throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'");
                }
            }
        }

        /// <summary>
        /// A registered client-side handler: its binding types plus the callback and its state.
        /// </summary>
        private struct InvocationHandler
        {
            public Type[] ParameterTypes { get; }
            private readonly Func<object[], object, Task> _callback;
            private readonly object _state;

            public InvocationHandler(Type[] parameterTypes, Func<object[], object, Task> callback, object state)
            {
                _callback = callback;
                ParameterTypes = parameterTypes;
                _state = state;
            }

            public Task InvokeAsync(object[] parameters)
            {
                return _callback(parameters, _state);
            }
        }

        /// <summary>
        /// Fallback transfer-mode feature installed when the connection does not provide one.
        /// </summary>
        private class TransferModeFeature : ITransferModeFeature
        {
            public TransferMode TransferMode { get; set; }
        }
    }
}