// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Buffers;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
using System.Net;
using System.Runtime.ExceptionServices;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Protocols;
using Microsoft.AspNetCore.Protocols.Features;
using Microsoft.AspNetCore.SignalR.Core;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.Extensions.Logging;

namespace Microsoft.AspNetCore.SignalR
{
    /// <summary>
    /// Per-connection state for a hub connection. Wraps the underlying <see cref="ConnectionContext"/>,
    /// performs the SignalR handshake, serializes writes to the transport, sends keep-alive pings and
    /// coordinates aborting the connection.
    /// </summary>
    public class HubConnectionContext
    {
        // Cached delegate so Abort() does not allocate a new delegate each time it schedules the abort.
        private static readonly Action<object> _abortedCallback = AbortConnection;

        private readonly ConnectionContext _connectionContext;
        private readonly ILogger _logger;
        private readonly CancellationTokenSource _connectionAbortedTokenSource = new CancellationTokenSource();

        // Completed by AbortConnection once the ConnectionAborted callbacks have run. Continuations run
        // asynchronously so code awaiting AbortAsync() is never executed inline on the aborting thread.
        private readonly TaskCompletionSource<object> _abortCompletedTcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

        // Keep-alive interval expressed in Stopwatch ticks (the unit of Stopwatch.GetTimestamp()).
        private readonly long _keepAliveDuration;

        // Serializes every write to the transport output: hub messages, pings and handshake responses.
        private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);

        // Stopwatch timestamp consulted by KeepAliveTick. Only KeepAliveTick updates it; regular writes do not.
        private long _lastSendTimestamp = Stopwatch.GetTimestamp();

        // Ping message pre-serialized with the negotiated protocol; assigned during a successful handshake.
        private byte[] _cachedPingMessage;

        /// <summary>
        /// Initializes a new instance of the <see cref="HubConnectionContext"/> class.
        /// </summary>
        /// <param name="connectionContext">The underlying connection.</param>
        /// <param name="keepAliveInterval">Minimum time between keep-alive pings sent by the heartbeat.</param>
        /// <param name="loggerFactory">Factory used to create this connection's logger.</param>
        public HubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory)
        {
            _connectionContext = connectionContext;
            _logger = loggerFactory.CreateLogger<HubConnectionContext>();
            ConnectionAborted = _connectionAbortedTokenSource.Token;
            _keepAliveDuration = ToStopwatchTicks(keepAliveInterval);
        }

        /// <summary>
        /// A token that is canceled once the connection has been aborted via <see cref="Abort()"/>.
        /// </summary>
        public virtual CancellationToken ConnectionAborted { get; }

        /// <summary>The identifier of the underlying connection.</summary>
        public virtual string ConnectionId => _connectionContext.ConnectionId;

        /// <summary>
        /// The user exposed by the connection's <see cref="IConnectionUserFeature"/>, or <c>null</c> when that feature is absent.
        /// </summary>
        public virtual ClaimsPrincipal User => Features.Get<IConnectionUserFeature>()?.User;

        /// <summary>The feature collection of the underlying connection.</summary>
        public virtual IFeatureCollection Features => _connectionContext.Features;

        /// <summary>The item bag of the underlying connection.</summary>
        public virtual IDictionary<object, object> Items => _connectionContext.Items;

        /// <summary>The reader for data received from the transport.</summary>
        public virtual PipeReader Input => _connectionContext.Transport.Input;

        /// <summary>
        /// The user identifier obtained from the <see cref="IUserIdProvider"/> during a successful handshake.
        /// </summary>
        public string UserIdentifier { get; private set; }

        // The hub protocol selected during the handshake; null until HandshakeAsync resolves one.
        internal virtual IHubProtocol Protocol { get; set; }

        // The exception passed to Abort(Exception), captured with ExceptionDispatchInfo to preserve its stack trace.
        internal ExceptionDispatchInfo AbortException { get; private set; }

        // Currently used only for streaming methods
        internal ConcurrentDictionary<string, CancellationTokenSource> ActiveRequestCancellationSources { get; } = new ConcurrentDictionary<string, CancellationTokenSource>();

        /// <summary>
        /// Serializes <paramref name="message"/> with the negotiated protocol and writes it to the transport.
        /// </summary>
        /// <remarks>
        /// Writes are serialized with a semaphore. When the semaphore is free and the flush completes
        /// synchronously, no task is allocated. Serialization and write failures are logged, not thrown.
        /// </remarks>
        /// <param name="message">The message to send.</param>
        public virtual ValueTask WriteAsync(HubMessage message)
        {
            // We were unable to get the lock so take the slow async path of waiting for the semaphore
            if (!_writeLock.Wait(0))
            {
                return new ValueTask(WriteSlowAsync(message));
            }

            // This method should never throw synchronously
            var task = WriteCore(message);

            // The write didn't complete synchronously so await completion
            if (!task.IsCompletedSuccessfully)
            {
                return new ValueTask(CompleteWriteAsync(task));
            }

            // Otherwise, release the lock acquired when entering WriteAsync
            _writeLock.Release();
            return default;
        }

        // Writes and flushes a message. Must be called with _writeLock held. Synchronous failures are logged
        // and reported as a completed flush so callers never see a synchronous exception.
        private ValueTask<FlushResult> WriteCore(HubMessage message)
        {
            try
            {
                // This will internally cache the buffer for each unique HubProtocol
                // So that we don't serialize the HubMessage for every single connection
                var buffer = message.WriteMessage(Protocol);

                _connectionContext.Transport.Output.Write(buffer);

                return _connectionContext.Transport.Output.FlushAsync();
            }
            catch (Exception ex)
            {
                Log.FailedWritingMessage(_logger, ex);

                return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
            }
        }

        // Awaits a flush that did not complete synchronously, then releases the lock taken by WriteAsync.
        private async Task CompleteWriteAsync(ValueTask<FlushResult> task)
        {
            try
            {
                await task;
            }
            catch (Exception ex)
            {
                Log.FailedWritingMessage(_logger, ex);
            }
            finally
            {
                // Release the lock acquired when entering WriteAsync
                _writeLock.Release();
            }
        }

        // Slow path of WriteAsync: waits for the write lock, then writes.
        private async Task WriteSlowAsync(HubMessage message)
        {
            // Failed to get the lock immediately when entering WriteAsync so await until it is available.
            // Acquire before the try so the finally only releases a lock this call actually holds.
            await _writeLock.WaitAsync();

            try
            {
                await WriteCore(message);
            }
            catch (Exception ex)
            {
                Log.FailedWritingMessage(_logger, ex);
            }
            finally
            {
                _writeLock.Release();
            }
        }

        // Sends the cached ping only if the write lock is immediately available.
        private ValueTask TryWritePingAsync()
        {
            // Don't wait for the lock, if it returns false that means someone wrote to the connection
            // and we don't need to send a ping anymore
            if (!_writeLock.Wait(0))
            {
                return default;
            }

            return new ValueTask(TryWritePingSlowAsync());
        }

        // Writes and flushes the cached ping; called with _writeLock held and releases it when done.
        private async Task TryWritePingSlowAsync()
        {
            try
            {
                Debug.Assert(_cachedPingMessage != null);

                _connectionContext.Transport.Output.Write(_cachedPingMessage);

                await _connectionContext.Transport.Output.FlushAsync();

                Log.SentPing(_logger);
            }
            catch (Exception ex)
            {
                Log.FailedWritingMessage(_logger, ex);
            }
            finally
            {
                _writeLock.Release();
            }
        }

        // Serializes a handshake response and writes it to the transport under the write lock.
        private async Task WriteHandshakeResponseAsync(HandshakeResponseMessage message)
        {
            await _writeLock.WaitAsync();

            try
            {
                using (var ms = new MemoryStream())
                {
                    HandshakeProtocol.WriteResponseMessage(message, ms);

                    await _connectionContext.Transport.Output.WriteAsync(ms.ToArray());
                }
            }
            finally
            {
                _writeLock.Release();
            }
        }

        /// <summary>
        /// Aborts the connection. <see cref="ConnectionAborted"/> is canceled on a thread-pool thread;
        /// this method returns without waiting for that to happen. Calling it after the token has
        /// already been canceled is a no-op.
        /// </summary>
        public virtual void Abort()
        {
            // If we already triggered the token then noop, this isn't thread safe but it's good enough
            // to avoid spawning a new task in the most common cases
            if (_connectionAbortedTokenSource.IsCancellationRequested)
            {
                return;
            }

            // We fire and forget since this can trigger user code to run.
            // TaskScheduler.Default is passed explicitly: the shorter StartNew overload uses
            // TaskScheduler.Current, which would run the callback on whatever scheduler the caller is on.
            _ = Task.Factory.StartNew(_abortedCallback, this, CancellationToken.None, TaskCreationOptions.DenyChildAttach, TaskScheduler.Default);
        }

        /// <summary>
        /// Reads the client's handshake request, selects the hub protocol and writes the handshake response.
        /// </summary>
        /// <param name="timeout">Maximum time to wait for the handshake request before it is canceled.</param>
        /// <param name="supportedProtocols">Protocol names passed to <paramref name="protocolResolver"/>.</param>
        /// <param name="protocolResolver">Resolves the requested protocol name to an <see cref="IHubProtocol"/>.</param>
        /// <param name="userIdProvider">Supplies <see cref="UserIdentifier"/> once the protocol is selected.</param>
        /// <returns>
        /// <c>true</c> when the handshake succeeded; <c>false</c> when the protocol is unsupported or incompatible
        /// with the transport, the handshake timed out, the input completed before a full request arrived, or an
        /// unexpected error occurred.
        /// </returns>
        internal async Task<bool> HandshakeAsync(TimeSpan timeout, IList<string> supportedProtocols, IHubProtocolResolver protocolResolver, IUserIdProvider userIdProvider)
        {
            try
            {
                using (var cts = new CancellationTokenSource())
                {
                    cts.CancelAfter(timeout);

                    while (true)
                    {
                        var result = await _connectionContext.Transport.Input.ReadAsync(cts.Token);
                        var buffer = result.Buffer;
                        var consumed = buffer.End;
                        var examined = buffer.End;

                        try
                        {
                            if (!buffer.IsEmpty)
                            {
                                if (HandshakeProtocol.TryParseRequestMessage(buffer, out var handshakeRequestMessage, out consumed, out examined))
                                {
                                    Protocol = protocolResolver.GetProtocol(handshakeRequestMessage.Protocol, supportedProtocols, this);
                                    if (Protocol == null)
                                    {
                                        Log.HandshakeFailed(_logger, null);

                                        await WriteHandshakeResponseAsync(new HandshakeResponseMessage($"The protocol '{handshakeRequestMessage.Protocol}' is not supported."));
                                        return false;
                                    }

                                    // If there's a transfer format feature, we need to check if we're compatible and set the active format.
                                    // If there isn't a feature, it means that the transport supports binary data and doesn't need us to tell them
                                    // what format we're writing.
                                    var transferFormatFeature = Features.Get<ITransferFormatFeature>();
                                    if (transferFormatFeature != null)
                                    {
                                        if ((transferFormatFeature.SupportedFormats & Protocol.TransferFormat) == 0)
                                        {
                                            Log.HandshakeFailed(_logger, null);

                                            await WriteHandshakeResponseAsync(new HandshakeResponseMessage($"Cannot use the '{Protocol.Name}' protocol on the current transport. The transport does not support '{Protocol.TransferFormat}' transfer format."));
                                            return false;
                                        }

                                        transferFormatFeature.ActiveFormat = Protocol.TransferFormat;
                                    }

                                    _cachedPingMessage = Protocol.WriteToArray(PingMessage.Instance);

                                    UserIdentifier = userIdProvider.GetUserId(this);

                                    if (Features.Get<IConnectionInherentKeepAliveFeature>() == null)
                                    {
                                        // Only register KeepAlive after protocol handshake otherwise KeepAliveTick could try to write without having a ProtocolReaderWriter
                                        Features.Get<IConnectionHeartbeatFeature>()?.OnHeartbeat(state => ((HubConnectionContext)state).KeepAliveTick(), this);
                                    }

                                    Log.HandshakeComplete(_logger, Protocol.Name);
                                    await WriteHandshakeResponseAsync(HandshakeResponseMessage.Empty);
                                    return true;
                                }
                            }

                            // Checked independently of the buffer contents: if the input has completed while only a
                            // partial request is buffered, reading again would return the same data forever.
                            if (result.IsCompleted)
                            {
                                // connection was closed before we ever received a complete handshake request
                                // can't send a handshake response because there is no longer a connection
                                Log.HandshakeFailed(_logger, null);
                                return false;
                            }
                        }
                        finally
                        {
                            _connectionContext.Transport.Input.AdvanceTo(consumed, examined);
                        }
                    }
                }
            }
            catch (OperationCanceledException)
            {
                Log.HandshakeCanceled(_logger);
                await WriteHandshakeResponseAsync(new HandshakeResponseMessage("Handshake was canceled."));
                return false;
            }
            catch (Exception ex)
            {
                Log.HandshakeFailed(_logger, ex);
                await WriteHandshakeResponseAsync(new HandshakeResponseMessage($"An unexpected error occurred during connection handshake. {ex.GetType().Name}: {ex.Message}"));
                return false;
            }
        }

        // Records the exception that caused the abort, then aborts the connection.
        internal void Abort(Exception exception)
        {
            AbortException = ExceptionDispatchInfo.Capture(exception);
            Abort();
        }

        // Used by the HubEndPoint only
        // Returns a task that completes after the ConnectionAborted callbacks have run (faulted if they threw).
        internal Task AbortAsync()
        {
            Abort();
            return _abortCompletedTcs.Task;
        }

        private void KeepAliveTick()
        {
            var timestamp = Stopwatch.GetTimestamp();
            // Implements the keep-alive tick behavior
            // Each tick, we check if the time since the last send is larger than the keep alive duration (in ticks).
            // If it is, we send a ping frame, if not, we no-op on this tick. This means that in the worst case, the
            // true "ping rate" of the server could be (_hubOptions.KeepAliveInterval + HubEndPoint.KeepAliveTimerInterval),
            // because if the interval elapses right after the last tick of this timer, it won't be detected until the next tick.
            // NOTE: _lastSendTimestamp is only updated here, not by WriteAsync, so a ping goes out every
            // keep-alive interval even while regular messages are flowing.

            if (timestamp - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration)
            {
                // Haven't sent a message for the entire keep-alive duration, so send a ping.
                // If the transport channel is full, this will fail, but that's OK because
                // adding a Ping message when the transport is full is unnecessary since the
                // transport is still in the process of sending frames.
                _ = TryWritePingAsync();

                Interlocked.Exchange(ref _lastSendTimestamp, timestamp);
            }
        }

        // Converts a TimeSpan to Stopwatch ticks, saturating at the bounds of long. Casting TotalMilliseconds
        // to int would overflow for intervals longer than int.MaxValue milliseconds (about 24.8 days), e.g.
        // TimeSpan.MaxValue, and yield a meaningless (possibly negative) duration that pings on every tick.
        private static long ToStopwatchTicks(TimeSpan interval)
        {
            var ticks = interval.TotalSeconds * Stopwatch.Frequency;

            if (ticks >= long.MaxValue)
            {
                return long.MaxValue;
            }

            if (ticks <= long.MinValue)
            {
                return long.MinValue;
            }

            return (long)ticks;
        }

        // Runs on the thread pool (scheduled by Abort()): cancels ConnectionAborted, then completes the
        // AbortAsync() task, faulting it if a cancellation callback threw.
        private static void AbortConnection(object state)
        {
            var connection = (HubConnectionContext)state;
            try
            {
                connection._connectionAbortedTokenSource.Cancel();

                // Communicate the fact that we're finished triggering abort callbacks
                connection._abortCompletedTcs.TrySetResult(null);
            }
            catch (Exception ex)
            {
                // TODO: Should we log if the cancellation callback fails? This is more preventative to make sure
                // we don't end up with an unobserved task
                connection._abortCompletedTcs.TrySetException(ex);
            }
        }

        private static class Log
        {
            // Category: HubConnectionContext
            private static readonly Action<ILogger, string, Exception> _handshakeComplete =
                LoggerMessage.Define<string>(LogLevel.Information, new EventId(1, "HandshakeComplete"), "Completed connection handshake. Using HubProtocol '{Protocol}'.");

            private static readonly Action<ILogger, Exception> _handshakeCanceled =
                LoggerMessage.Define(LogLevel.Debug, new EventId(2, "HandshakeCanceled"), "Handshake was canceled.");

            private static readonly Action<ILogger, Exception> _sentPing =
                LoggerMessage.Define(LogLevel.Trace, new EventId(3, "SentPing"), "Sent a ping message to the client.");

            private static readonly Action<ILogger, Exception> _transportBufferFull =
                LoggerMessage.Define(LogLevel.Debug, new EventId(4, "TransportBufferFull"), "Unable to send Ping message to client, the transport buffer is full.");

            private static readonly Action<ILogger, Exception> _handshakeFailed =
                LoggerMessage.Define(LogLevel.Error, new EventId(5, "HandshakeFailed"), "Failed connection handshake.");

            private static readonly Action<ILogger, Exception> _failedWritingMessage =
                LoggerMessage.Define(LogLevel.Debug, new EventId(6, "FailedWritingMessage"), "Failed writing message.");

            public static void HandshakeComplete(ILogger logger, string hubProtocol)
            {
                _handshakeComplete(logger, hubProtocol, null);
            }

            public static void HandshakeCanceled(ILogger logger)
            {
                _handshakeCanceled(logger, null);
            }

            public static void SentPing(ILogger logger)
            {
                _sentPing(logger, null);
            }

            public static void TransportBufferFull(ILogger logger)
            {
                _transportBufferFull(logger, null);
            }

            public static void HandshakeFailed(ILogger logger, Exception exception)
            {
                _handshakeFailed(logger, exception);
            }

            public static void FailedWritingMessage(ILogger logger, Exception exception)
            {
                _failedWritingMessage(logger, exception);
            }
        }
    }
}