Merge pull request #2104 from aspnet/release/2.1
fix #2078 by adding locking (#2079)
This commit is contained in:
commit
bea09f5f94
|
|
@ -16,11 +16,13 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
private SerializedMessage _cachedItem1;
|
||||
private SerializedMessage _cachedItem2;
|
||||
private IList<SerializedMessage> _cachedItems;
|
||||
private readonly object _lock = new object();
|
||||
|
||||
public HubMessage Message { get; }
|
||||
|
||||
public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
|
||||
{
|
||||
// A lock isn't needed here because nobody has access to this type until the constructor finishes.
|
||||
for (var i = 0; i < messages.Count; i++)
|
||||
{
|
||||
var message = messages[i];
|
||||
|
|
@ -35,23 +37,58 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
|
||||
{
|
||||
if (!TryGetCached(protocol.Name, out var serialized))
|
||||
lock (_lock)
|
||||
{
|
||||
if (Message == null)
|
||||
if (!TryGetCached(protocol.Name, out var serialized))
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"This message was received from another server that did not have the requested protocol available.");
|
||||
if (Message == null)
|
||||
{
|
||||
throw new InvalidOperationException(
|
||||
"This message was received from another server that did not have the requested protocol available.");
|
||||
}
|
||||
|
||||
serialized = protocol.GetMessageBytes(Message);
|
||||
SetCache(protocol.Name, serialized);
|
||||
}
|
||||
|
||||
serialized = protocol.GetMessageBytes(Message);
|
||||
SetCache(protocol.Name, serialized);
|
||||
return serialized;
|
||||
}
|
||||
}
|
||||
|
||||
return serialized;
|
||||
// Used for unit testing.
|
||||
internal IReadOnlyList<SerializedMessage> GetAllSerializations()
|
||||
{
|
||||
// Even if this is only used in tests, let's do it right.
|
||||
lock (_lock)
|
||||
{
|
||||
if (_cachedItem1.ProtocolName == null)
|
||||
{
|
||||
return Array.Empty<SerializedMessage>();
|
||||
}
|
||||
|
||||
var list = new List<SerializedMessage>(2);
|
||||
list.Add(_cachedItem1);
|
||||
|
||||
if (_cachedItem2.ProtocolName != null)
|
||||
{
|
||||
list.Add(_cachedItem2);
|
||||
|
||||
if (_cachedItems != null)
|
||||
{
|
||||
list.AddRange(_cachedItems);
|
||||
}
|
||||
}
|
||||
|
||||
return list;
|
||||
}
|
||||
}
|
||||
|
||||
private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
|
||||
{
|
||||
// We set the fields before moving on to the list, if we need it to hold more than 2 items.
|
||||
// We have to read/write these fields under the lock because the structs might tear and another
|
||||
// thread might observe them half-assigned
|
||||
|
||||
if (_cachedItem1.ProtocolName == null)
|
||||
{
|
||||
_cachedItem1 = new SerializedMessage(protocolName, serialized);
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Connections;
|
|||
using Microsoft.AspNetCore.SignalR.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||
using Microsoft.AspNetCore.SignalR.Redis.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Tests;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
||||
|
|
@ -150,7 +151,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
Assert.Equal(
|
||||
expected.Message.GetSerializedMessage(hubProtocol).ToArray(),
|
||||
decoded.Message.GetSerializedMessage(hubProtocol).ToArray());
|
||||
Assert.Equal(1, hubProtocol.SerializationCount);
|
||||
|
||||
var writtenMessages = hubProtocol.GetWrittenMessages();
|
||||
Assert.Collection(writtenMessages,
|
||||
actualMessage =>
|
||||
{
|
||||
var invocation = Assert.IsType<InvocationMessage>(actualMessage);
|
||||
Assert.Same(_testMessage.Target, invocation.Target);
|
||||
Assert.Same(_testMessage.Arguments, invocation.Arguments);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -186,46 +195,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
|
|||
Encoded = encoded;
|
||||
}
|
||||
}
|
||||
|
||||
public class DummyHubProtocol : IHubProtocol
|
||||
{
|
||||
public int SerializationCount { get; private set; }
|
||||
|
||||
public string Name { get; }
|
||||
public int Version => 1;
|
||||
public TransferFormat TransferFormat => TransferFormat.Text;
|
||||
|
||||
public DummyHubProtocol(string name)
|
||||
{
|
||||
Name = name;
|
||||
}
|
||||
|
||||
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
|
||||
{
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
|
||||
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
|
||||
{
|
||||
output.Write(GetMessageBytes(message).Span);
|
||||
}
|
||||
|
||||
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
|
||||
{
|
||||
SerializationCount += 1;
|
||||
|
||||
// Assert that we got the test message
|
||||
var invocation = Assert.IsType<InvocationMessage>(message);
|
||||
Assert.Same(_testMessage.Target, invocation.Target);
|
||||
Assert.Same(_testMessage.Arguments, invocation.Arguments);
|
||||
|
||||
return new byte[] { 0x2A };
|
||||
}
|
||||
|
||||
public bool IsVersionSupported(int version)
|
||||
{
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
using System;
|
||||
using System.Buffers;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.AspNetCore.Connections;
|
||||
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
public class DummyHubProtocol : IHubProtocol
|
||||
{
|
||||
private readonly Action _onWrite;
|
||||
private readonly object _lock = new object();
|
||||
private readonly List<HubMessage> _writtenMessages = new List<HubMessage>();
|
||||
|
||||
public static readonly byte[] DummySerialization = new byte[] { 0x2A };
|
||||
|
||||
public string Name { get; }
|
||||
public int Version => 1;
|
||||
public TransferFormat TransferFormat => TransferFormat.Text;
|
||||
|
||||
public DummyHubProtocol(string name, Action onWrite = null)
|
||||
{
|
||||
_onWrite = onWrite ?? (() => { });
|
||||
Name = name;
|
||||
}
|
||||
|
||||
public IReadOnlyList<HubMessage> GetWrittenMessages()
|
||||
{
|
||||
lock (_lock)
|
||||
{
|
||||
return _writtenMessages.ToArray();
|
||||
}
|
||||
}
|
||||
|
||||
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
|
||||
{
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
|
||||
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
|
||||
{
|
||||
output.Write(GetMessageBytes(message).Span);
|
||||
}
|
||||
|
||||
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
|
||||
{
|
||||
_onWrite();
|
||||
lock (_lock)
|
||||
{
|
||||
_writtenMessages.Add(message);
|
||||
}
|
||||
|
||||
return DummySerialization;
|
||||
}
|
||||
|
||||
public bool IsVersionSupported(int version)
|
||||
{
|
||||
throw new NotSupportedException();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -4,13 +4,12 @@
|
|||
using System;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
// Possibly useful as a general-purpose async testing helper?
|
||||
public class SyncPoint
|
||||
{
|
||||
private readonly TaskCompletionSource<object> _atSyncPoint = new TaskCompletionSource<object>();
|
||||
private readonly TaskCompletionSource<object> _continueFromSyncPoint = new TaskCompletionSource<object>();
|
||||
private readonly TaskCompletionSource<object> _atSyncPoint = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
private readonly TaskCompletionSource<object> _continueFromSyncPoint = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
/// <summary>
|
||||
/// Waits for the code-under-test to reach <see cref="WaitToContinue"/>.
|
||||
|
|
@ -0,0 +1,88 @@
|
|||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.SignalR.Protocol;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
public class SerializedHubMessageTests
|
||||
{
|
||||
[Fact]
|
||||
public void GetSerializedMessageSerializesUsingHubProtocolIfNoCacheAvailable()
|
||||
{
|
||||
var invocation = new InvocationMessage("Foo", new object[0]);
|
||||
var message = new SerializedHubMessage(invocation);
|
||||
var protocol = new DummyHubProtocol("p1");
|
||||
|
||||
var serialized = message.GetSerializedMessage(protocol);
|
||||
|
||||
Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray());
|
||||
Assert.Collection(protocol.GetWrittenMessages(),
|
||||
actualMessage => Assert.Same(invocation, actualMessage));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetSerializedMessageReturnsCachedSerializationIfAvailable()
|
||||
{
|
||||
var invocation = new InvocationMessage("Foo", new object[0]);
|
||||
var message = new SerializedHubMessage(invocation);
|
||||
var protocol = new DummyHubProtocol("p1");
|
||||
|
||||
// This should cache it
|
||||
_ = message.GetSerializedMessage(protocol);
|
||||
|
||||
// Get it again
|
||||
var serialized = message.GetSerializedMessage(protocol);
|
||||
|
||||
|
||||
Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray());
|
||||
|
||||
// We should still only have written one message
|
||||
Assert.Collection(protocol.GetWrittenMessages(),
|
||||
actualMessage => Assert.Same(invocation, actualMessage));
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(0)]
|
||||
[InlineData(1)]
|
||||
[InlineData(2)]
|
||||
[InlineData(5)]
|
||||
public async Task SerializingTwoMessagesFromTheSameProtocolSimultaneouslyResultsInOneCachedItemAsync(int numberOfSerializationsToPreCache)
|
||||
{
|
||||
var invocation = new InvocationMessage("Foo", new object[0]);
|
||||
var message = new SerializedHubMessage(invocation);
|
||||
|
||||
// "Pre-cache" the requested number of serializations (so we can test scenarios involving each of the fields and the fallback list)
|
||||
for (var i = 0; i < numberOfSerializationsToPreCache; i++)
|
||||
{
|
||||
_ = message.GetSerializedMessage(new DummyHubProtocol($"p{i}"));
|
||||
}
|
||||
|
||||
var onWrite = SyncPoint.Create(2, out var syncPoints);
|
||||
var protocol = new DummyHubProtocol("test", () => onWrite().Wait());
|
||||
|
||||
// Serialize once, but hold at the Hub Protocol
|
||||
var firstSerialization = Task.Run(() => message.GetSerializedMessage(protocol));
|
||||
await syncPoints[0].WaitForSyncPoint();
|
||||
|
||||
// Serialize again, which should hit the lock
|
||||
var secondSerialization = Task.Run(() => message.GetSerializedMessage(protocol));
|
||||
Assert.False(secondSerialization.IsCompleted);
|
||||
|
||||
// Release both instances of the syncpoint
|
||||
syncPoints[0].Continue();
|
||||
syncPoints[1].Continue();
|
||||
|
||||
// Everything should finish and only one serialization should be written
|
||||
await firstSerialization.OrTimeout();
|
||||
await secondSerialization.OrTimeout();
|
||||
|
||||
Assert.Collection(message.GetAllSerializations().Skip(numberOfSerializationsToPreCache).ToArray(),
|
||||
serializedMessage =>
|
||||
{
|
||||
Assert.Equal("test", serializedMessage.ProtocolName);
|
||||
Assert.Equal(DummyHubProtocol.DummySerialization, serializedMessage.Serialized.ToArray());
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue