Merge pull request #2104 from aspnet/release/2.1
fix #2078 by adding locking (#2079)
This commit is contained in:
commit
bea09f5f94
|
|
@ -16,11 +16,13 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
private SerializedMessage _cachedItem1;
|
private SerializedMessage _cachedItem1;
|
||||||
private SerializedMessage _cachedItem2;
|
private SerializedMessage _cachedItem2;
|
||||||
private IList<SerializedMessage> _cachedItems;
|
private IList<SerializedMessage> _cachedItems;
|
||||||
|
private readonly object _lock = new object();
|
||||||
|
|
||||||
public HubMessage Message { get; }
|
public HubMessage Message { get; }
|
||||||
|
|
||||||
public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
|
public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
|
||||||
{
|
{
|
||||||
|
// A lock isn't needed here because nobody has access to this type until the constructor finishes.
|
||||||
for (var i = 0; i < messages.Count; i++)
|
for (var i = 0; i < messages.Count; i++)
|
||||||
{
|
{
|
||||||
var message = messages[i];
|
var message = messages[i];
|
||||||
|
|
@ -35,23 +37,58 @@ namespace Microsoft.AspNetCore.SignalR
|
||||||
|
|
||||||
public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
|
public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
|
||||||
{
|
{
|
||||||
if (!TryGetCached(protocol.Name, out var serialized))
|
lock (_lock)
|
||||||
{
|
{
|
||||||
if (Message == null)
|
if (!TryGetCached(protocol.Name, out var serialized))
|
||||||
{
|
{
|
||||||
throw new InvalidOperationException(
|
if (Message == null)
|
||||||
"This message was received from another server that did not have the requested protocol available.");
|
{
|
||||||
|
throw new InvalidOperationException(
|
||||||
|
"This message was received from another server that did not have the requested protocol available.");
|
||||||
|
}
|
||||||
|
|
||||||
|
serialized = protocol.GetMessageBytes(Message);
|
||||||
|
SetCache(protocol.Name, serialized);
|
||||||
}
|
}
|
||||||
|
|
||||||
serialized = protocol.GetMessageBytes(Message);
|
return serialized;
|
||||||
SetCache(protocol.Name, serialized);
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return serialized;
|
// Used for unit testing.
|
||||||
|
internal IReadOnlyList<SerializedMessage> GetAllSerializations()
|
||||||
|
{
|
||||||
|
// Even if this is only used in tests, let's do it right.
|
||||||
|
lock (_lock)
|
||||||
|
{
|
||||||
|
if (_cachedItem1.ProtocolName == null)
|
||||||
|
{
|
||||||
|
return Array.Empty<SerializedMessage>();
|
||||||
|
}
|
||||||
|
|
||||||
|
var list = new List<SerializedMessage>(2);
|
||||||
|
list.Add(_cachedItem1);
|
||||||
|
|
||||||
|
if (_cachedItem2.ProtocolName != null)
|
||||||
|
{
|
||||||
|
list.Add(_cachedItem2);
|
||||||
|
|
||||||
|
if (_cachedItems != null)
|
||||||
|
{
|
||||||
|
list.AddRange(_cachedItems);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return list;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
|
private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
|
||||||
{
|
{
|
||||||
|
// We set the fields before moving on to the list, if we need it to hold more than 2 items.
|
||||||
|
// We have to read/write these fields under the lock because the structs might tear and another
|
||||||
|
// thread might observe them half-assigned
|
||||||
|
|
||||||
if (_cachedItem1.ProtocolName == null)
|
if (_cachedItem1.ProtocolName == null)
|
||||||
{
|
{
|
||||||
_cachedItem1 = new SerializedMessage(protocolName, serialized);
|
_cachedItem1 = new SerializedMessage(protocolName, serialized);
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Connections;
|
||||||
using Microsoft.AspNetCore.SignalR.Internal;
|
using Microsoft.AspNetCore.SignalR.Internal;
|
||||||
using Microsoft.AspNetCore.SignalR.Protocol;
|
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||||
using Microsoft.AspNetCore.SignalR.Redis.Internal;
|
using Microsoft.AspNetCore.SignalR.Redis.Internal;
|
||||||
|
using Microsoft.AspNetCore.SignalR.Tests;
|
||||||
using Xunit;
|
using Xunit;
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||||
|
|
@ -150,7 +151,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||||
Assert.Equal(
|
Assert.Equal(
|
||||||
expected.Message.GetSerializedMessage(hubProtocol).ToArray(),
|
expected.Message.GetSerializedMessage(hubProtocol).ToArray(),
|
||||||
decoded.Message.GetSerializedMessage(hubProtocol).ToArray());
|
decoded.Message.GetSerializedMessage(hubProtocol).ToArray());
|
||||||
Assert.Equal(1, hubProtocol.SerializationCount);
|
|
||||||
|
var writtenMessages = hubProtocol.GetWrittenMessages();
|
||||||
|
Assert.Collection(writtenMessages,
|
||||||
|
actualMessage =>
|
||||||
|
{
|
||||||
|
var invocation = Assert.IsType<InvocationMessage>(actualMessage);
|
||||||
|
Assert.Same(_testMessage.Target, invocation.Target);
|
||||||
|
Assert.Same(_testMessage.Arguments, invocation.Arguments);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -186,46 +195,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||||
Encoded = encoded;
|
Encoded = encoded;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public class DummyHubProtocol : IHubProtocol
|
|
||||||
{
|
|
||||||
public int SerializationCount { get; private set; }
|
|
||||||
|
|
||||||
public string Name { get; }
|
|
||||||
public int Version => 1;
|
|
||||||
public TransferFormat TransferFormat => TransferFormat.Text;
|
|
||||||
|
|
||||||
public DummyHubProtocol(string name)
|
|
||||||
{
|
|
||||||
Name = name;
|
|
||||||
}
|
|
||||||
|
|
||||||
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
|
|
||||||
{
|
|
||||||
throw new NotSupportedException();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
|
|
||||||
{
|
|
||||||
output.Write(GetMessageBytes(message).Span);
|
|
||||||
}
|
|
||||||
|
|
||||||
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
|
|
||||||
{
|
|
||||||
SerializationCount += 1;
|
|
||||||
|
|
||||||
// Assert that we got the test message
|
|
||||||
var invocation = Assert.IsType<InvocationMessage>(message);
|
|
||||||
Assert.Same(_testMessage.Target, invocation.Target);
|
|
||||||
Assert.Same(_testMessage.Arguments, invocation.Arguments);
|
|
||||||
|
|
||||||
return new byte[] { 0x2A };
|
|
||||||
}
|
|
||||||
|
|
||||||
public bool IsVersionSupported(int version)
|
|
||||||
{
|
|
||||||
throw new NotSupportedException();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
using System;
|
||||||
|
using System.Buffers;
|
||||||
|
using System.Collections.Generic;
|
||||||
|
using Microsoft.AspNetCore.Connections;
|
||||||
|
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
{
|
||||||
|
public class DummyHubProtocol : IHubProtocol
|
||||||
|
{
|
||||||
|
private readonly Action _onWrite;
|
||||||
|
private readonly object _lock = new object();
|
||||||
|
private readonly List<HubMessage> _writtenMessages = new List<HubMessage>();
|
||||||
|
|
||||||
|
public static readonly byte[] DummySerialization = new byte[] { 0x2A };
|
||||||
|
|
||||||
|
public string Name { get; }
|
||||||
|
public int Version => 1;
|
||||||
|
public TransferFormat TransferFormat => TransferFormat.Text;
|
||||||
|
|
||||||
|
public DummyHubProtocol(string name, Action onWrite = null)
|
||||||
|
{
|
||||||
|
_onWrite = onWrite ?? (() => { });
|
||||||
|
Name = name;
|
||||||
|
}
|
||||||
|
|
||||||
|
public IReadOnlyList<HubMessage> GetWrittenMessages()
|
||||||
|
{
|
||||||
|
lock (_lock)
|
||||||
|
{
|
||||||
|
return _writtenMessages.ToArray();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
|
||||||
|
{
|
||||||
|
throw new NotSupportedException();
|
||||||
|
}
|
||||||
|
|
||||||
|
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
|
||||||
|
{
|
||||||
|
output.Write(GetMessageBytes(message).Span);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
|
||||||
|
{
|
||||||
|
_onWrite();
|
||||||
|
lock (_lock)
|
||||||
|
{
|
||||||
|
_writtenMessages.Add(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
return DummySerialization;
|
||||||
|
}
|
||||||
|
|
||||||
|
public bool IsVersionSupported(int version)
|
||||||
|
{
|
||||||
|
throw new NotSupportedException();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -4,13 +4,12 @@
|
||||||
using System;
|
using System;
|
||||||
using System.Threading.Tasks;
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
{
|
{
|
||||||
// Possibly useful as a general-purpose async testing helper?
|
|
||||||
public class SyncPoint
|
public class SyncPoint
|
||||||
{
|
{
|
||||||
private readonly TaskCompletionSource<object> _atSyncPoint = new TaskCompletionSource<object>();
|
private readonly TaskCompletionSource<object> _atSyncPoint = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
private readonly TaskCompletionSource<object> _continueFromSyncPoint = new TaskCompletionSource<object>();
|
private readonly TaskCompletionSource<object> _continueFromSyncPoint = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
/// Waits for the code-under-test to reach <see cref="WaitToContinue"/>.
|
/// Waits for the code-under-test to reach <see cref="WaitToContinue"/>.
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
using System.Linq;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||||
|
using Xunit;
|
||||||
|
|
||||||
|
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||||
|
{
|
||||||
|
public class SerializedHubMessageTests
|
||||||
|
{
|
||||||
|
[Fact]
|
||||||
|
public void GetSerializedMessageSerializesUsingHubProtocolIfNoCacheAvailable()
|
||||||
|
{
|
||||||
|
var invocation = new InvocationMessage("Foo", new object[0]);
|
||||||
|
var message = new SerializedHubMessage(invocation);
|
||||||
|
var protocol = new DummyHubProtocol("p1");
|
||||||
|
|
||||||
|
var serialized = message.GetSerializedMessage(protocol);
|
||||||
|
|
||||||
|
Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray());
|
||||||
|
Assert.Collection(protocol.GetWrittenMessages(),
|
||||||
|
actualMessage => Assert.Same(invocation, actualMessage));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public void GetSerializedMessageReturnsCachedSerializationIfAvailable()
|
||||||
|
{
|
||||||
|
var invocation = new InvocationMessage("Foo", new object[0]);
|
||||||
|
var message = new SerializedHubMessage(invocation);
|
||||||
|
var protocol = new DummyHubProtocol("p1");
|
||||||
|
|
||||||
|
// This should cache it
|
||||||
|
_ = message.GetSerializedMessage(protocol);
|
||||||
|
|
||||||
|
// Get it again
|
||||||
|
var serialized = message.GetSerializedMessage(protocol);
|
||||||
|
|
||||||
|
|
||||||
|
Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray());
|
||||||
|
|
||||||
|
// We should still only have written one message
|
||||||
|
Assert.Collection(protocol.GetWrittenMessages(),
|
||||||
|
actualMessage => Assert.Same(invocation, actualMessage));
|
||||||
|
}
|
||||||
|
|
||||||
|
[Theory]
|
||||||
|
[InlineData(0)]
|
||||||
|
[InlineData(1)]
|
||||||
|
[InlineData(2)]
|
||||||
|
[InlineData(5)]
|
||||||
|
public async Task SerializingTwoMessagesFromTheSameProtocolSimultaneouslyResultsInOneCachedItemAsync(int numberOfSerializationsToPreCache)
|
||||||
|
{
|
||||||
|
var invocation = new InvocationMessage("Foo", new object[0]);
|
||||||
|
var message = new SerializedHubMessage(invocation);
|
||||||
|
|
||||||
|
// "Pre-cache" the requested number of serializations (so we can test scenarios involving each of the fields and the fallback list)
|
||||||
|
for (var i = 0; i < numberOfSerializationsToPreCache; i++)
|
||||||
|
{
|
||||||
|
_ = message.GetSerializedMessage(new DummyHubProtocol($"p{i}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
var onWrite = SyncPoint.Create(2, out var syncPoints);
|
||||||
|
var protocol = new DummyHubProtocol("test", () => onWrite().Wait());
|
||||||
|
|
||||||
|
// Serialize once, but hold at the Hub Protocol
|
||||||
|
var firstSerialization = Task.Run(() => message.GetSerializedMessage(protocol));
|
||||||
|
await syncPoints[0].WaitForSyncPoint();
|
||||||
|
|
||||||
|
// Serialize again, which should hit the lock
|
||||||
|
var secondSerialization = Task.Run(() => message.GetSerializedMessage(protocol));
|
||||||
|
Assert.False(secondSerialization.IsCompleted);
|
||||||
|
|
||||||
|
// Release both instances of the syncpoint
|
||||||
|
syncPoints[0].Continue();
|
||||||
|
syncPoints[1].Continue();
|
||||||
|
|
||||||
|
// Everything should finish and only one serialization should be written
|
||||||
|
await firstSerialization.OrTimeout();
|
||||||
|
await secondSerialization.OrTimeout();
|
||||||
|
|
||||||
|
Assert.Collection(message.GetAllSerializations().Skip(numberOfSerializationsToPreCache).ToArray(),
|
||||||
|
serializedMessage =>
|
||||||
|
{
|
||||||
|
Assert.Equal("test", serializedMessage.ProtocolName);
|
||||||
|
Assert.Equal(DummyHubProtocol.DummySerialization, serializedMessage.Serialized.ToArray());
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue