// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.SignalR
{
    /// <summary>
    /// Handles incoming connections and implements the SignalR Hub Protocol.
    /// </summary>
    public class HubConnectionHandler<THub> : ConnectionHandler where THub : Hub
    {
        private readonly HubLifetimeManager<THub> _lifetimeManager;
        private readonly ILoggerFactory _loggerFactory;
        private readonly ILogger<HubConnectionHandler<THub>> _logger;
        private readonly IHubProtocolResolver _protocolResolver;
        private readonly HubOptions<THub> _hubOptions;
        private readonly HubOptions _globalHubOptions;
        private readonly IUserIdProvider _userIdProvider;
        private readonly HubDispatcher<THub> _dispatcher;
        private readonly bool _enableDetailedErrors;

        /// <summary>
        /// Initializes a new instance of the <see cref="HubConnectionHandler{THub}"/> class.
        /// </summary>
        /// <param name="lifetimeManager">The hub lifetime manager.</param>
        /// <param name="protocolResolver">The protocol resolver used to resolve the protocols between client and server.</param>
        /// <param name="globalHubOptions">The global options used to initialize hubs.</param>
        /// <param name="hubOptions">Hub specific options used to initialize hubs. These options override the global options.</param>
        /// <param name="loggerFactory">The logger factory.</param>
        /// <param name="userIdProvider">The user ID provider used to get the user ID from a hub connection.</param>
        /// <param name="dispatcher">The hub dispatcher used to dispatch incoming messages to hubs.</param>
        /// <remarks>This class is typically created via dependency injection.</remarks>
        public HubConnectionHandler(HubLifetimeManager<THub> lifetimeManager,
                                    IHubProtocolResolver protocolResolver,
                                    IOptions<HubOptions> globalHubOptions,
                                    IOptions<HubOptions<THub>> hubOptions,
                                    ILoggerFactory loggerFactory,
                                    IUserIdProvider userIdProvider,
#pragma warning disable PUB0001 // Pubternal type in public API
                                    HubDispatcher<THub> dispatcher
#pragma warning restore PUB0001
            )
        {
            _protocolResolver = protocolResolver;
            _lifetimeManager = lifetimeManager;
            _loggerFactory = loggerFactory;
            _hubOptions = hubOptions.Value;
            _globalHubOptions = globalHubOptions.Value;
            _logger = loggerFactory.CreateLogger<HubConnectionHandler<THub>>();
            _userIdProvider = userIdProvider;
            _dispatcher = dispatcher;

            // Hub-specific setting wins over the global one; detailed errors are off unless opted into.
            _enableDetailedErrors = _hubOptions.EnableDetailedErrors ?? _globalHubOptions.EnableDetailedErrors ?? false;
        }

        /// <inheritdoc />
        /// <exception cref="InvalidOperationException">
        /// Thrown when the resolved list of supported protocols is null or empty.
        /// </exception>
        public override async Task OnConnectedAsync(ConnectionContext connection)
        {
            // We check to see if HubOptions<THub> are set because those take precedence over global hub options.
            // Then set the keepAlive and handshakeTimeout values to the defaults in HubOptionsSetup in case they were explicitly set to null.
            var keepAlive = _hubOptions.KeepAliveInterval ?? _globalHubOptions.KeepAliveInterval ?? HubOptionsSetup.DefaultKeepAliveInterval;
            var handshakeTimeout = _hubOptions.HandshakeTimeout ?? _globalHubOptions.HandshakeTimeout ?? HubOptionsSetup.DefaultHandshakeTimeout;
            var supportedProtocols = _hubOptions.SupportedProtocols ?? _globalHubOptions.SupportedProtocols;

            // A null list cannot be materialized for the handshake below (Enumerable.ToList would throw
            // ArgumentNullException after the connection context was already created), so reject it
            // up front exactly like an empty list.
            if (supportedProtocols == null || supportedProtocols.Count == 0)
            {
                throw new InvalidOperationException("There are no supported protocols");
            }

            Log.ConnectedStarting(_logger);

            var connectionContext = new HubConnectionContext(connection, keepAlive, _loggerFactory);

            // Avoid a copy when the configured list already satisfies IReadOnlyList<string>.
            var resolvedSupportedProtocols = (supportedProtocols as IReadOnlyList<string>) ?? supportedProtocols.ToList();
            if (!await connectionContext.HandshakeAsync(handshakeTimeout, resolvedSupportedProtocols, _protocolResolver, _userIdProvider, _enableDetailedErrors))
            {
                // The handshake did not succeed; the lifetime manager was never told about this connection.
                return;
            }

            try
            {
                await _lifetimeManager.OnConnectedAsync(connectionContext);
                await RunHubAsync(connectionContext);
            }
            finally
            {
                Log.ConnectedEnding(_logger);
                await _lifetimeManager.OnDisconnectedAsync(connectionContext);
            }
        }

        /// <summary>
        /// Runs the hub's OnConnectedAsync, then the message loop, then the disconnect sequence.
        /// If the hub's OnConnectedAsync throws, a close message carrying the error is sent and the
        /// method returns without running the message loop or the hub's OnDisconnectedAsync.
        /// </summary>
        private async Task RunHubAsync(HubConnectionContext connection)
        {
            try
            {
                await _dispatcher.OnConnectedAsync(connection);
            }
            catch (Exception ex)
            {
                Log.ErrorDispatchingHubEvent(_logger, "OnConnectedAsync", ex);

                await SendCloseAsync(connection, ex);

                // return instead of throw to let close message send successfully
                return;
            }

            try
            {
                await DispatchMessagesAsync(connection);
            }
            catch (OperationCanceledException)
            {
                // Don't treat OperationCanceledException as an error, it's basically a "control flow"
                // exception to stop things from running
            }
            catch (Exception ex)
            {
                Log.ErrorProcessingRequest(_logger, ex);

                await HubOnDisconnectedAsync(connection, ex);

                // return instead of throw to let close message send successfully
                return;
            }

            await HubOnDisconnectedAsync(connection, null);
        }

        /// <summary>
        /// Sends a close message (with <paramref name="exception"/> if non-null), awaits the connection
        /// abort, and then invokes the hub's OnDisconnectedAsync. Exceptions from the hub's
        /// OnDisconnectedAsync are logged and rethrown.
        /// </summary>
        private async Task HubOnDisconnectedAsync(HubConnectionContext connection, Exception exception)
        {
            // send close message before aborting the connection
            await SendCloseAsync(connection, exception);

            // We wait on abort to complete, this is so that we can guarantee that all callbacks have fired
            // before OnDisconnectedAsync

            // Ensure the connection is aborted before firing disconnect
            await connection.AbortAsync();

            try
            {
                await _dispatcher.OnDisconnectedAsync(connection, exception);
            }
            catch (Exception ex)
            {
                Log.ErrorDispatchingHubEvent(_logger, "OnDisconnectedAsync", ex);
                throw;
            }
        }

        /// <summary>
        /// Writes a close message to the connection. When <paramref name="exception"/> is non-null the
        /// message carries an error text built with the configured detailed-errors setting. Failures
        /// while writing are logged (Debug level) and swallowed, so closing is best-effort.
        /// </summary>
        private async Task SendCloseAsync(HubConnectionContext connection, Exception exception)
        {
            var closeMessage = CloseMessage.Empty;

            if (exception != null)
            {
                var errorMessage = ErrorMessageHelper.BuildErrorMessage("Connection closed with an error.", exception, _enableDetailedErrors);
                closeMessage = new CloseMessage(errorMessage);
            }

            try
            {
                await connection.WriteAsync(closeMessage);
            }
            catch (Exception ex)
            {
                Log.ErrorSendingClose(_logger, ex);
            }
        }

        /// <summary>
        /// Reads from the connection's input, parses every complete message with the negotiated protocol
        /// and dispatches it to the hub. Returns when the read is canceled or the input completes.
        /// </summary>
        /// <exception cref="InvalidDataException">
        /// Thrown when the input completes while a partial message is still buffered.
        /// </exception>
        private async Task DispatchMessagesAsync(HubConnectionContext connection)
        {
            var input = connection.Input;
            var protocol = connection.Protocol;
            while (true)
            {
                var result = await input.ReadAsync();
                var buffer = result.Buffer;

                try
                {
                    if (result.IsCanceled)
                    {
                        break;
                    }

                    if (!buffer.IsEmpty)
                    {
                        // _dispatcher is handed to the parser (presumably as the invocation binder — confirm against IHubProtocol).
                        while (protocol.TryParseMessage(ref buffer, _dispatcher, out var message))
                        {
                            await _dispatcher.DispatchMessageAsync(connection, message);
                        }
                    }

                    if (result.IsCompleted)
                    {
                        if (!buffer.IsEmpty)
                        {
                            throw new InvalidDataException("Connection terminated while reading a message.");
                        }
                        break;
                    }
                }
                finally
                {
                    // The buffer was sliced up to where it was consumed, so we can just advance to the start.
                    // We mark examined as buffer.End so that if we didn't receive a full frame, we'll wait for more data
                    // before yielding the read again.
                    input.AdvanceTo(buffer.Start, buffer.End);
                }
            }
        }

        /// <summary>
        /// Cached <see cref="LoggerMessage"/> delegates for this handler's log events.
        /// </summary>
        private static class Log
        {
            private static readonly Action<ILogger, string, Exception> _errorDispatchingHubEvent =
                LoggerMessage.Define<string>(LogLevel.Error, new EventId(1, "ErrorDispatchingHubEvent"), "Error when dispatching '{HubMethod}' on hub.");

            private static readonly Action<ILogger, Exception> _errorProcessingRequest =
                LoggerMessage.Define(LogLevel.Error, new EventId(2, "ErrorProcessingRequest"), "Error when processing requests.");

            private static readonly Action<ILogger, Exception> _abortFailed =
                LoggerMessage.Define(LogLevel.Trace, new EventId(3, "AbortFailed"), "Abort callback failed.");

            private static readonly Action<ILogger, Exception> _errorSendingClose =
                LoggerMessage.Define(LogLevel.Debug, new EventId(4, "ErrorSendingClose"), "Error when sending Close message.");

            private static readonly Action<ILogger, Exception> _connectedStarting =
                LoggerMessage.Define(LogLevel.Debug, new EventId(5, "ConnectedStarting"), "OnConnectedAsync started.");

            private static readonly Action<ILogger, Exception> _connectedEnding =
                LoggerMessage.Define(LogLevel.Debug, new EventId(6, "ConnectedEnding"), "OnConnectedAsync ending.");

            public static void ErrorDispatchingHubEvent(ILogger logger, string hubMethod, Exception exception)
            {
                _errorDispatchingHubEvent(logger, hubMethod, exception);
            }

            public static void ErrorProcessingRequest(ILogger logger, Exception exception)
            {
                _errorProcessingRequest(logger, exception);
            }

            public static void AbortFailed(ILogger logger, Exception exception)
            {
                _abortFailed(logger, exception);
            }

            public static void ErrorSendingClose(ILogger logger, Exception exception)
            {
                _errorSendingClose(logger, exception);
            }

            public static void ConnectedStarting(ILogger logger)
            {
                _connectedStarting(logger, null);
            }

            public static void ConnectedEnding(ILogger logger)
            {
                _connectedEnding(logger, null);
            }
        }
    }
}