// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR.Protocol;

namespace Microsoft.AspNetCore.SignalR
{
    /// <summary>
    /// Represents a serialization cache for a single message.
    /// </summary>
    /// <remarks>
    /// The first two serializations are stored in dedicated fields; a list is only allocated
    /// when a third protocol is cached. Lookups and insertions made through
    /// <see cref="GetSerializedMessage"/> are performed under a private lock.
    /// </remarks>
    public class SerializedHubMessage
    {
        // Inline slots for the first two cached serializations. A slot whose ProtocolName is null is empty.
        private SerializedMessage _cachedItem1;
        private SerializedMessage _cachedItem2;

        // Overflow storage, created lazily once both inline slots are occupied.
        private List<SerializedMessage> _cachedItems;

        private readonly object _lock = new object();

        /// <summary>
        /// Gets the hub message for the serialization cache.
        /// </summary>
        /// <remarks>
        /// This is <c>null</c> when the instance was created from already serialized messages.
        /// </remarks>
        public HubMessage Message { get; }

        /// <summary>
        /// Initializes a new instance of the <see cref="SerializedHubMessage"/> class.
        /// </summary>
        /// <param name="messages">A collection of already serialized messages to cache.</param>
        public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
        {
            // A lock isn't needed here because nobody has access to this type until the constructor finishes.
            for (var i = 0; i < messages.Count; i++)
            {
                var message = messages[i];
                SetCache(message.ProtocolName, message.Serialized);
            }
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="SerializedHubMessage"/> class.
        /// </summary>
        /// <param name="message">
        /// The hub message for the cache. This will be serialized with an <see cref="IHubProtocol"/>
        /// in <see cref="GetSerializedMessage"/> to get the message's serialized representation.
        /// </param>
        public SerializedHubMessage(HubMessage message)
        {
            Message = message;
        }

        /// <summary>
        /// Gets the serialized representation of the <see cref="HubMessage"/> using the specified <see cref="IHubProtocol"/>.
        /// </summary>
        /// <param name="protocol">The protocol used to create the serialized representation.</param>
        /// <returns>The serialized representation of the <see cref="HubMessage"/>.</returns>
        /// <exception cref="InvalidOperationException">
        /// No serialization is cached for <paramref name="protocol"/> and <see cref="Message"/> is <c>null</c>,
        /// so a new serialization cannot be produced.
        /// </exception>
        public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
        {
            lock (_lock)
            {
                if (TryGetCached(protocol.Name, out var serialized))
                {
                    return serialized;
                }

                if (Message == null)
                {
                    throw new InvalidOperationException(
                        "This message was received from another server that did not have the requested protocol available.");
                }

                // Cache miss: serialize once and remember the result for later callers using the same protocol.
                serialized = protocol.GetMessageBytes(Message);
                SetCache(protocol.Name, serialized);
                return serialized;
            }
        }

        // Used for unit testing.
        internal IReadOnlyList<SerializedMessage> GetAllSerializations()
        {
            // Even if this is only used in tests, let's do it right.
            lock (_lock)
            {
                if (_cachedItem1.ProtocolName == null)
                {
                    return Array.Empty<SerializedMessage>();
                }

                var list = new List<SerializedMessage>(2) { _cachedItem1 };

                // The slots fill in order (item1, item2, then the list), so the list can only
                // hold entries once the second slot is occupied.
                if (_cachedItem2.ProtocolName != null)
                {
                    list.Add(_cachedItem2);

                    if (_cachedItems != null)
                    {
                        list.AddRange(_cachedItems);
                    }
                }

                return list;
            }
        }

        // Stores a serialization for the given protocol unless one is already cached for that protocol.
        private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
        {
            // We set the fields before moving on to the list, if we need it to hold more than 2 items.
            // We have to read/write these fields under the lock because the structs might tear and another
            // thread might observe them half-assigned
            if (_cachedItem1.ProtocolName == null)
            {
                _cachedItem1 = new SerializedMessage(protocolName, serialized);
                return;
            }

            if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
            {
                // Already cached in the first slot; keep the first serialization, same as the list handling below.
                return;
            }

            if (_cachedItem2.ProtocolName == null)
            {
                _cachedItem2 = new SerializedMessage(protocolName, serialized);
                return;
            }

            if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal))
            {
                // Already cached in the second slot.
                return;
            }

            if (_cachedItems == null)
            {
                _cachedItems = new List<SerializedMessage>();
            }

            foreach (var item in _cachedItems)
            {
                if (string.Equals(item.ProtocolName, protocolName, StringComparison.Ordinal))
                {
                    // No need to add
                    return;
                }
            }

            _cachedItems.Add(new SerializedMessage(protocolName, serialized));
        }

        // Looks up the cached serialization for the given protocol name (ordinal comparison),
        // checking the two inline slots first and then the overflow list.
        private bool TryGetCached(string protocolName, out ReadOnlyMemory<byte> result)
        {
            if (string.Equals(_cachedItem1.ProtocolName, protocolName, StringComparison.Ordinal))
            {
                result = _cachedItem1.Serialized;
                return true;
            }

            if (string.Equals(_cachedItem2.ProtocolName, protocolName, StringComparison.Ordinal))
            {
                result = _cachedItem2.Serialized;
                return true;
            }

            if (_cachedItems != null)
            {
                foreach (var serializedMessage in _cachedItems)
                {
                    if (string.Equals(serializedMessage.ProtocolName, protocolName, StringComparison.Ordinal))
                    {
                        result = serializedMessage.Serialized;
                        return true;
                    }
                }
            }

            result = default;
            return false;
        }
    }
}