Merge pull request #2104 from aspnet/release/2.1

fix #2078 by adding locking (#2079)
This commit is contained in:
Andrew Stanton-Nurse 2018-04-19 15:32:27 -07:00 committed by GitHub
commit bea09f5f94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 206 additions and 53 deletions

View File

@ -16,11 +16,13 @@ namespace Microsoft.AspNetCore.SignalR
private SerializedMessage _cachedItem1;
private SerializedMessage _cachedItem2;
private IList<SerializedMessage> _cachedItems;
private readonly object _lock = new object();
public HubMessage Message { get; }
public SerializedHubMessage(IReadOnlyList<SerializedMessage> messages)
{
// A lock isn't needed here because nobody has access to this type until the constructor finishes.
for (var i = 0; i < messages.Count; i++)
{
var message = messages[i];
@ -35,23 +37,58 @@ namespace Microsoft.AspNetCore.SignalR
public ReadOnlyMemory<byte> GetSerializedMessage(IHubProtocol protocol)
{
if (!TryGetCached(protocol.Name, out var serialized))
lock (_lock)
{
if (Message == null)
if (!TryGetCached(protocol.Name, out var serialized))
{
throw new InvalidOperationException(
"This message was received from another server that did not have the requested protocol available.");
if (Message == null)
{
throw new InvalidOperationException(
"This message was received from another server that did not have the requested protocol available.");
}
serialized = protocol.GetMessageBytes(Message);
SetCache(protocol.Name, serialized);
}
serialized = protocol.GetMessageBytes(Message);
SetCache(protocol.Name, serialized);
return serialized;
}
}
return serialized;
// Used for unit testing.
internal IReadOnlyList<SerializedMessage> GetAllSerializations()
{
// Even if this is only used in tests, let's do it right.
lock (_lock)
{
if (_cachedItem1.ProtocolName == null)
{
return Array.Empty<SerializedMessage>();
}
var list = new List<SerializedMessage>(2);
list.Add(_cachedItem1);
if (_cachedItem2.ProtocolName != null)
{
list.Add(_cachedItem2);
if (_cachedItems != null)
{
list.AddRange(_cachedItems);
}
}
return list;
}
}
private void SetCache(string protocolName, ReadOnlyMemory<byte> serialized)
{
// We set the fields before moving on to the list, if we need it to hold more than 2 items.
// We have to read/write these fields under the lock because the structs might tear and another
// thread might observe them half-assigned
if (_cachedItem1.ProtocolName == null)
{
_cachedItem1 = new SerializedMessage(protocolName, serialized);

View File

@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.AspNetCore.SignalR.Redis.Internal;
using Microsoft.AspNetCore.SignalR.Tests;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Redis.Tests
@ -150,7 +151,15 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
Assert.Equal(
expected.Message.GetSerializedMessage(hubProtocol).ToArray(),
decoded.Message.GetSerializedMessage(hubProtocol).ToArray());
Assert.Equal(1, hubProtocol.SerializationCount);
var writtenMessages = hubProtocol.GetWrittenMessages();
Assert.Collection(writtenMessages,
actualMessage =>
{
var invocation = Assert.IsType<InvocationMessage>(actualMessage);
Assert.Same(_testMessage.Target, invocation.Target);
Assert.Same(_testMessage.Arguments, invocation.Arguments);
});
}
}
@ -186,46 +195,5 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
Encoded = encoded;
}
}
public class DummyHubProtocol : IHubProtocol
{
public int SerializationCount { get; private set; }
public string Name { get; }
public int Version => 1;
public TransferFormat TransferFormat => TransferFormat.Text;
public DummyHubProtocol(string name)
{
Name = name;
}
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
throw new NotSupportedException();
}
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
output.Write(GetMessageBytes(message).Span);
}
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
{
SerializationCount += 1;
// Assert that we got the test message
var invocation = Assert.IsType<InvocationMessage>(message);
Assert.Same(_testMessage.Target, invocation.Target);
Assert.Same(_testMessage.Arguments, invocation.Arguments);
return new byte[] { 0x2A };
}
public bool IsVersionSupported(int version)
{
throw new NotSupportedException();
}
}
}
}

View File

@ -0,0 +1,61 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class DummyHubProtocol : IHubProtocol
{
private readonly Action _onWrite;
private readonly object _lock = new object();
private readonly List<HubMessage> _writtenMessages = new List<HubMessage>();
public static readonly byte[] DummySerialization = new byte[] { 0x2A };
public string Name { get; }
public int Version => 1;
public TransferFormat TransferFormat => TransferFormat.Text;
public DummyHubProtocol(string name, Action onWrite = null)
{
_onWrite = onWrite ?? (() => { });
Name = name;
}
public IReadOnlyList<HubMessage> GetWrittenMessages()
{
lock (_lock)
{
return _writtenMessages.ToArray();
}
}
public bool TryParseMessage(ref ReadOnlySequence<byte> input, IInvocationBinder binder, out HubMessage message)
{
throw new NotSupportedException();
}
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
output.Write(GetMessageBytes(message).Span);
}
public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message)
{
_onWrite();
lock (_lock)
{
_writtenMessages.Add(message);
}
return DummySerialization;
}
public bool IsVersionSupported(int version)
{
throw new NotSupportedException();
}
}
}

View File

@ -4,13 +4,12 @@
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
namespace Microsoft.AspNetCore.SignalR.Tests
{
// Possibly useful as a general-purpose async testing helper?
public class SyncPoint
{
private readonly TaskCompletionSource<object> _atSyncPoint = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _continueFromSyncPoint = new TaskCompletionSource<object>();
private readonly TaskCompletionSource<object> _atSyncPoint = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
private readonly TaskCompletionSource<object> _continueFromSyncPoint = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
/// <summary>
/// Waits for the code-under-test to reach <see cref="WaitToContinue"/>.

View File

@ -0,0 +1,88 @@
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Protocol;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class SerializedHubMessageTests
{
[Fact]
public void GetSerializedMessageSerializesUsingHubProtocolIfNoCacheAvailable()
{
var invocation = new InvocationMessage("Foo", new object[0]);
var message = new SerializedHubMessage(invocation);
var protocol = new DummyHubProtocol("p1");
var serialized = message.GetSerializedMessage(protocol);
Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray());
Assert.Collection(protocol.GetWrittenMessages(),
actualMessage => Assert.Same(invocation, actualMessage));
}
[Fact]
public void GetSerializedMessageReturnsCachedSerializationIfAvailable()
{
var invocation = new InvocationMessage("Foo", new object[0]);
var message = new SerializedHubMessage(invocation);
var protocol = new DummyHubProtocol("p1");
// This should cache it
_ = message.GetSerializedMessage(protocol);
// Get it again
var serialized = message.GetSerializedMessage(protocol);
Assert.Equal(DummyHubProtocol.DummySerialization, serialized.ToArray());
// We should still only have written one message
Assert.Collection(protocol.GetWrittenMessages(),
actualMessage => Assert.Same(invocation, actualMessage));
}
[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(2)]
[InlineData(5)]
public async Task SerializingTwoMessagesFromTheSameProtocolSimultaneouslyResultsInOneCachedItemAsync(int numberOfSerializationsToPreCache)
{
var invocation = new InvocationMessage("Foo", new object[0]);
var message = new SerializedHubMessage(invocation);
// "Pre-cache" the requested number of serializations (so we can test scenarios involving each of the fields and the fallback list)
for (var i = 0; i < numberOfSerializationsToPreCache; i++)
{
_ = message.GetSerializedMessage(new DummyHubProtocol($"p{i}"));
}
var onWrite = SyncPoint.Create(2, out var syncPoints);
var protocol = new DummyHubProtocol("test", () => onWrite().Wait());
// Serialize once, but hold at the Hub Protocol
var firstSerialization = Task.Run(() => message.GetSerializedMessage(protocol));
await syncPoints[0].WaitForSyncPoint();
// Serialize again, which should hit the lock
var secondSerialization = Task.Run(() => message.GetSerializedMessage(protocol));
Assert.False(secondSerialization.IsCompleted);
// Release both instances of the syncpoint
syncPoints[0].Continue();
syncPoints[1].Continue();
// Everything should finish and only one serialization should be written
await firstSerialization.OrTimeout();
await secondSerialization.OrTimeout();
Assert.Collection(message.GetAllSerializations().Skip(numberOfSerializationsToPreCache).ToArray(),
serializedMessage =>
{
Assert.Equal("test", serializedMessage.ProtocolName);
Assert.Equal(DummyHubProtocol.DummySerialization, serializedMessage.Serialized.ToArray());
});
}
}
}