diff --git a/src/Microsoft.AspNetCore.SignalR.Core/SerializedHubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Core/SerializedHubMessage.cs index 5b2d507528..be46ed61e8 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/SerializedHubMessage.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/SerializedHubMessage.cs @@ -16,11 +16,13 @@ namespace Microsoft.AspNetCore.SignalR private SerializedMessage _cachedItem1; private SerializedMessage _cachedItem2; private IList _cachedItems; + private readonly object _lock = new object(); public HubMessage Message { get; } public SerializedHubMessage(IReadOnlyList messages) { + // A lock isn't needed here because nobody has access to this type until the constructor finishes. for (var i = 0; i < messages.Count; i++) { var message = messages[i]; @@ -35,23 +37,58 @@ namespace Microsoft.AspNetCore.SignalR public ReadOnlyMemory GetSerializedMessage(IHubProtocol protocol) { - if (!TryGetCached(protocol.Name, out var serialized)) + lock (_lock) { - if (Message == null) + if (!TryGetCached(protocol.Name, out var serialized)) { - throw new InvalidOperationException( - "This message was received from another server that did not have the requested protocol available."); + if (Message == null) + { + throw new InvalidOperationException( + "This message was received from another server that did not have the requested protocol available."); + } + + serialized = protocol.GetMessageBytes(Message); + SetCache(protocol.Name, serialized); } - serialized = protocol.GetMessageBytes(Message); - SetCache(protocol.Name, serialized); + return serialized; } + } - return serialized; + // Used for unit testing. + internal IReadOnlyList GetAllSerializations() + { + // Even if this is only used in tests, let's do it right. + lock (_lock) + { + if (_cachedItem1.ProtocolName == null) + { + return Array.Empty(); + } + + var list = new List(2); + list.Add(_cachedItem1); + + if (_cachedItem2.ProtocolName != null) + { + list.Add(_cachedItem2); + + if (_cachedItems != null) + { + list.AddRange(_cachedItems); + } + } + + return list; + } } private void SetCache(string protocolName, ReadOnlyMemory serialized) { + // We set the fields before moving on to the list, if we need it to hold more than 2 items. + // We have to read/write these fields under the lock because the structs might tear and another + // thread might observe them half-assigned + if (_cachedItem1.ProtocolName == null) { _cachedItem1 = new SerializedMessage(protocolName, serialized); diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs index 3a710f42cc..b3ab182d5d 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/RedisProtocolTests.cs @@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Redis.Internal; +using Microsoft.AspNetCore.SignalR.Tests; using Xunit; namespace Microsoft.AspNetCore.SignalR.Redis.Tests @@ -150,7 +151,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests Assert.Equal( expected.Message.GetSerializedMessage(hubProtocol).ToArray(), decoded.Message.GetSerializedMessage(hubProtocol).ToArray()); - Assert.Equal(1, hubProtocol.SerializationCount); + + var writtenMessages = hubProtocol.GetWrittenMessages(); + Assert.Collection(writtenMessages, + actualMessage => + { + var invocation = Assert.IsType(actualMessage); + Assert.Same(_testMessage.Target, invocation.Target); + Assert.Same(_testMessage.Arguments, invocation.Arguments); + }); } } @@ -186,46 +195,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests Encoded = encoded; } } - - public class DummyHubProtocol : IHubProtocol - { - public int SerializationCount { get; private set; } - - public string Name { get; } - public int Version => 1; - public TransferFormat TransferFormat => TransferFormat.Text; - - public DummyHubProtocol(string name) - { - Name = name; - } - - public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) - { - throw new NotSupportedException(); - } - - public void WriteMessage(HubMessage message, IBufferWriter output) - { - output.Write(GetMessageBytes(message).Span); - } - - public ReadOnlyMemory GetMessageBytes(HubMessage message) - { - SerializationCount += 1; - - // Assert that we got the test message - var invocation = Assert.IsType(message); - Assert.Same(_testMessage.Target, invocation.Target); - Assert.Same(_testMessage.Arguments, invocation.Arguments); - - return new byte[] { 0x2A }; - } - - public bool IsVersionSupported(int version) - { - throw new NotSupportedException(); - } - } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/DummyHubProtocol.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/DummyHubProtocol.cs new file mode 100644 index 0000000000..74ca4eaf04 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/DummyHubProtocol.cs @@ -0,0 +1,61 @@ +using System; +using System.Buffers; +using System.Collections.Generic; +using Microsoft.AspNetCore.Connections; +using Microsoft.AspNetCore.SignalR.Protocol; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class DummyHubProtocol : IHubProtocol + { + private readonly Action _onWrite; + private readonly object _lock = new object(); + private readonly List _writtenMessages = new List(); + + public static readonly byte[] DummySerialization = new byte[] { 0x2A }; + + public string Name { get; } + public int Version => 1; + public TransferFormat TransferFormat => TransferFormat.Text; + + public DummyHubProtocol(string name, Action onWrite = null) + { + _onWrite = onWrite ?? (() => { }); + Name = name; + } + + public IReadOnlyList GetWrittenMessages() + { + lock (_lock) + { + return _writtenMessages.ToArray(); + } + } + + public bool TryParseMessage(ref ReadOnlySequence input, IInvocationBinder binder, out HubMessage message) + { + throw new NotSupportedException(); + } + + public void WriteMessage(HubMessage message, IBufferWriter output) + { + output.Write(GetMessageBytes(message).Span); + } + + public ReadOnlyMemory GetMessageBytes(HubMessage message) + { + _onWrite(); + lock (_lock) + { + _writtenMessages.Add(message); + } + + return DummySerialization; + } + + public bool IsVersionSupported(int version) + { + throw new NotSupportedException(); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/SyncPoint.cs similarity index 92% rename from test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs rename to test/Microsoft.AspNetCore.SignalR.Tests.Utils/SyncPoint.cs index d39d24af55..55f4a034d5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/SyncPoint.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/SyncPoint.cs @@ -4,13 +4,12 @@ using System; using System.Threading.Tasks; -namespace Microsoft.AspNetCore.SignalR.Client.Tests +namespace Microsoft.AspNetCore.SignalR.Tests { - // Possibly useful as a general-purpose async testing helper? public class SyncPoint { - private readonly TaskCompletionSource _atSyncPoint = new TaskCompletionSource(); - private readonly TaskCompletionSource _continueFromSyncPoint = new TaskCompletionSource(); + private readonly TaskCompletionSource _atSyncPoint = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + private readonly TaskCompletionSource _continueFromSyncPoint = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); /// /// Waits for the code-under-test to reach . diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/SerializedHubMessageTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/SerializedHubMessageTests.cs new file mode 100644 index 0000000000..b69952b523 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/SerializedHubMessageTests.cs @@ -0,0 +1,88 @@ +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Protocol; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public class SerializedHubMessageTests + { + [Fact] + public void GetSerializedMessageSerializesUsingHubProtocolIfNoCacheAvailable() + { + var invocation = new InvocationMessage("Foo", new object[0]); + var message = new SerializedHubMessage(invocation); + var protocol = new DummyHubProtocol("p1"); + + var serialized = message.GetSerializedMessage(protocol); + + Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray()); + Assert.Collection(protocol.GetWrittenMessages(), + actualMessage => Assert.Same(invocation, actualMessage)); + } + + [Fact] + public void GetSerializedMessageReturnsCachedSerializationIfAvailable() + { + var invocation = new InvocationMessage("Foo", new object[0]); + var message = new SerializedHubMessage(invocation); + var protocol = new DummyHubProtocol("p1"); + + // This should cache it + _ = message.GetSerializedMessage(protocol); + + // Get it again + var serialized = message.GetSerializedMessage(protocol); + + + Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray()); + + // We should still only have written one message + Assert.Collection(protocol.GetWrittenMessages(), + actualMessage => Assert.Same(invocation, actualMessage)); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(5)] + public async Task SerializingTwoMessagesFromTheSameProtocolSimultaneouslyResultsInOneCachedItemAsync(int numberOfSerializationsToPreCache) + { + var invocation = new InvocationMessage("Foo", new object[0]); + var message = new SerializedHubMessage(invocation); + + // "Pre-cache" the requested number of serializations (so we can test scenarios involving each of the fields and the fallback list) + for (var i = 0; i < numberOfSerializationsToPreCache; i++) + { + _ = message.GetSerializedMessage(new DummyHubProtocol($"p{i}")); + } + + var onWrite = SyncPoint.Create(2, out var syncPoints); + var protocol = new DummyHubProtocol("test", () => onWrite().Wait()); + + // Serialize once, but hold at the Hub Protocol + var firstSerialization = Task.Run(() => message.GetSerializedMessage(protocol)); + await syncPoints[0].WaitForSyncPoint(); + + // Serialize again, which should hit the lock + var secondSerialization = Task.Run(() => message.GetSerializedMessage(protocol)); + Assert.False(secondSerialization.IsCompleted); + + // Release both instances of the syncpoint + syncPoints[0].Continue(); + syncPoints[1].Continue(); + + // Everything should finish and only one serialization should be written + await firstSerialization.OrTimeout(); + await secondSerialization.OrTimeout(); + + Assert.Collection(message.GetAllSerializations().Skip(numberOfSerializationsToPreCache).ToArray(), + serializedMessage => + { + Assert.Equal("test", serializedMessage.ProtocolName); + Assert.Equal(DummyHubProtocol.DummySerialization, serializedMessage.Serialized.ToArray()); + }); + } + } +}