Improved allocations and throughput for broadcast scenarios (#1660)

- Don't allocate when enumerating connections
- Don't allocate tasks unless we truly go async
- Don't get the timestamp, just write the pings always (if there's no ongoing write)
- Track the time since last keep alive write instead of the last write
- ValueTask all the things!
- Renamed HubConnectionList to HubConnectionStore
This commit is contained in:
David Fowler 2018-03-21 09:03:36 -07:00 committed by GitHub
parent a2764109b0
commit 6583e5fb47
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 189 additions and 112 deletions

View File

@ -40,13 +40,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
protocol = new MessagePackHubProtocol();
}
var options = new PipeOptions();
for (var i = 0; i < Connections; ++i)
{
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default);
var pair = DuplexPipe.CreateConnectionPair(options, options);
var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport);
var hubConnection = new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance);
hubConnection.Protocol = protocol;
_hubLifetimeManager.OnConnectedAsync(hubConnection).Wait();
_hubLifetimeManager.OnConnectedAsync(hubConnection).GetAwaiter().GetResult();
_ = ConsumeAsync(connection.Application);
}
_hubContext = new HubContext<Hub>(_hubLifetimeManager);
@ -57,5 +60,26 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
return _hubContext.Clients.All.SendAsync("Method");
}
// Consume the data written to the transport
private static async Task ConsumeAsync(IDuplexPipe application)
{
while (true)
{
var result = await application.Input.ReadAsync();
var buffer = result.Buffer;
if (!buffer.IsEmpty)
{
application.Input.AdvanceTo(buffer.End);
}
else if (result.IsCompleted)
{
break;
}
}
application.Input.Complete();
}
}
}

View File

@ -68,7 +68,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
}
public override Task WriteAsync(HubMessage message)
public override ValueTask WriteAsync(HubMessage message)
{
if (message is CompletionMessage completionMessage)
{
@ -78,7 +78,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
}
}
return Task.CompletedTask;
return default;
}
}

View File

@ -36,7 +36,7 @@ namespace ChatSample
where THubLifetimeManager : HubLifetimeManager<THub>
where THub : HubWithPresence
{
private readonly HubConnectionList _connections = new HubConnectionList();
private readonly HubConnectionStore _connections = new HubConnectionStore();
private readonly IUserTracker<THub> _userTracker;
private readonly IServiceScopeFactory _serviceScopeFactory;
private readonly ILogger _logger;

View File

@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub> where THub : Hub
{
private readonly HubConnectionList _connections = new HubConnectionList();
private readonly HubConnectionStore _connections = new HubConnectionStore();
private readonly HubGroupList _groups = new HubGroupList();
private readonly ILogger _logger;
@ -69,18 +69,37 @@ namespace Microsoft.AspNetCore.SignalR
public override Task SendAllAsync(string methodName, object[] args)
{
return SendAllWhere(methodName, args, c => true);
}
List<Task> tasks = null;
var message = CreateInvocationMessage(methodName, args);
private Task SendAllWhere(string methodName, object[] args, Func<HubConnectionContext, bool> include)
{
var count = _connections.Count;
if (count == 0)
foreach (var connection in _connections)
{
var task = connection.WriteAsync(message);
if (!task.IsCompletedSuccessfully)
{
if (tasks == null)
{
tasks = new List<Task>();
}
tasks.Add(task.AsTask());
}
}
// No async
if (tasks == null)
{
return Task.CompletedTask;
}
var tasks = new List<Task>(count);
// Some connections are slow
return Task.WhenAll(tasks);
}
private Task SendAllWhere(string methodName, object[] args, Func<HubConnectionContext, bool> include)
{
List<Task> tasks = null;
var message = CreateInvocationMessage(methodName, args);
foreach (var connection in _connections)
@ -90,9 +109,25 @@ namespace Microsoft.AspNetCore.SignalR
continue;
}
tasks.Add(SafeWriteAsync(connection, message));
var task = connection.WriteAsync(message);
if (!task.IsCompletedSuccessfully)
{
if (tasks == null)
{
tasks = new List<Task>();
}
tasks.Add(task.AsTask());
}
}
if (tasks == null)
{
return Task.CompletedTask;
}
// Some connections are slow
return Task.WhenAll(tasks);
}
@ -112,7 +147,7 @@ namespace Microsoft.AspNetCore.SignalR
var message = CreateInvocationMessage(methodName, args);
return SafeWriteAsync(connection, message);
return connection.WriteAsync(message).AsTask();
}
public override Task SendGroupAsync(string groupName, string methodName, object[] args)
@ -126,7 +161,7 @@ namespace Microsoft.AspNetCore.SignalR
if (group != null)
{
var message = CreateInvocationMessage(methodName, args);
var tasks = group.Values.Select(c => SafeWriteAsync(c, message));
var tasks = group.Values.Select(c => c.WriteAsync(message).AsTask());
return Task.WhenAll(tasks);
}
@ -149,7 +184,7 @@ namespace Microsoft.AspNetCore.SignalR
var group = _groups[groupName];
if (group != null)
{
tasks.Add(Task.WhenAll(group.Values.Select(c => SafeWriteAsync(c, message))));
tasks.Add(Task.WhenAll(group.Values.Select(c => c.WriteAsync(message).AsTask())));
}
}
@ -168,7 +203,7 @@ namespace Microsoft.AspNetCore.SignalR
{
var message = CreateInvocationMessage(methodName, args);
var tasks = group.Values.Where(connection => !excludedIds.Contains(connection.ConnectionId))
.Select(c => SafeWriteAsync(c, message));
.Select(c => c.WriteAsync(message).AsTask());
return Task.WhenAll(tasks);
}
@ -222,30 +257,5 @@ namespace Microsoft.AspNetCore.SignalR
return userIds.Contains(connection.UserIdentifier);
});
}
// This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to
private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message)
{
try
{
await connection.WriteAsync(message);
}
// This exception isn't interesting to users
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
}
}
private static class Log
{
private static readonly Action<ILogger, Exception> _failedWritingMessage =
LoggerMessage.Define(LogLevel.Warning, new EventId(1, "FailedWritingMessage"), "Failed writing message.");
public static void FailedWritingMessage(ILogger logger, Exception exception)
{
_failedWritingMessage(logger, exception);
}
}
}
}

View File

@ -74,20 +74,78 @@ namespace Microsoft.AspNetCore.SignalR
public int? LocalPort => Features.Get<IHttpConnectionFeature>()?.LocalPort;
public virtual async Task WriteAsync(HubMessage message)
public virtual ValueTask WriteAsync(HubMessage message)
{
await _writeLock.WaitAsync();
// We were unable to get the lock so take the slow async path of waiting for the semaphore
if (!_writeLock.Wait(0))
{
return new ValueTask(WriteSlowAsync(message));
}
// This method should never throw synchronously
var task = WriteCore(message);
// The write didn't complete synchronously so await completion
if (!task.IsCompletedSuccessfully)
{
return new ValueTask(CompleteWriteAsync(task));
}
// Otherwise, release the lock acquired when entering WriteAsync
_writeLock.Release();
return default;
}
private ValueTask<FlushResult> WriteCore(HubMessage message)
{
try
{
// This will internally cache the buffer for each unique HubProtocol/DataEncoder combination
// This will internally cache the buffer for each unique HubProtocol
// So that we don't serialize the HubMessage for every single connection
var buffer = message.WriteMessage(Protocol);
_connectionContext.Transport.Output.Write(buffer);
Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp());
return _connectionContext.Transport.Output.FlushAsync();
}
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
await _connectionContext.Transport.Output.FlushAsync();
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, isCompleted: true));
}
}
private async Task CompleteWriteAsync(ValueTask<FlushResult> task)
{
try
{
await task;
}
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
}
finally
{
// Release the lock acquired when entering WriteAsync
_writeLock.Release();
}
}
private async Task WriteSlowAsync(HubMessage message)
{
try
{
// Failed to get the lock immediately when entering WriteAsync so await until it is available
await _writeLock.WaitAsync();
await WriteCore(message);
}
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
}
finally
{
@ -95,23 +153,33 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task TryWritePingAsync()
private ValueTask TryWritePingAsync()
{
// Don't wait for the lock, if it returns false that means someone wrote to the connection
// and we don't need to send a ping anymore
if (!await _writeLock.WaitAsync(0))
if (!_writeLock.Wait(0))
{
return;
return default;
}
return new ValueTask(TryWritePingSlowAsync());
}
private async Task TryWritePingSlowAsync()
{
try
{
Debug.Assert(_cachedPingMessage != null);
_connectionContext.Transport.Output.Write(_cachedPingMessage);
Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp());
await _connectionContext.Transport.Output.FlushAsync();
Log.SentPing(_logger);
}
catch (Exception ex)
{
Log.FailedWritingMessage(_logger, ex);
}
finally
{
@ -254,21 +322,21 @@ namespace Microsoft.AspNetCore.SignalR
private void KeepAliveTick()
{
var timestamp = Stopwatch.GetTimestamp();
// Implements the keep-alive tick behavior
// Each tick, we check if the time since the last send is larger than the keep alive duration (in ticks).
// If it is, we send a ping frame, if not, we no-op on this tick. This means that in the worst case, the
// true "ping rate" of the server could be (_hubOptions.KeepAliveInterval + HubEndPoint.KeepAliveTimerInterval),
// because if the interval elapses right after the last tick of this timer, it won't be detected until the next tick.
if (Stopwatch.GetTimestamp() - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration)
if (timestamp - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration)
{
// Haven't sent a message for the entire keep-alive duration, so send a ping.
// If the transport channel is full, this will fail, but that's OK because
// adding a Ping message when the transport is full is unnecessary since the
// transport is still in the process of sending frames.
_ = TryWritePingAsync();
Log.SentPing(_logger);
Interlocked.Exchange(ref _lastSendTimestamp, timestamp);
}
}
@ -308,6 +376,9 @@ namespace Microsoft.AspNetCore.SignalR
private static readonly Action<ILogger, Exception> _handshakeFailed =
LoggerMessage.Define(LogLevel.Error, new EventId(5, "HandshakeFailed"), "Failed connection handshake.");
private static readonly Action<ILogger, Exception> _failedWritingMessage =
LoggerMessage.Define(LogLevel.Debug, new EventId(6, "FailedWritingMessage"), "Failed writing message.");
public static void HandshakeComplete(ILogger logger, string hubProtocol)
{
_handshakeComplete(logger, hubProtocol, null);
@ -332,6 +403,11 @@ namespace Microsoft.AspNetCore.SignalR
{
_handshakeFailed(logger, exception);
}
public static void FailedWritingMessage(ILogger logger, Exception exception)
{
_failedWritingMessage(logger, exception);
}
}
}

View File

@ -7,7 +7,7 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.SignalR
{
public class HubConnectionList : IReadOnlyCollection<HubConnectionContext>
public class HubConnectionStore
{
private readonly ConcurrentDictionary<string, HubConnectionContext> _connections = new ConcurrentDictionary<string, HubConnectionContext>();
@ -32,14 +32,29 @@ namespace Microsoft.AspNetCore.SignalR
_connections.TryRemove(connection.ConnectionId, out _);
}
public IEnumerator<HubConnectionContext> GetEnumerator()
public Enumerator GetEnumerator()
{
return _connections.Values.GetEnumerator();
return new Enumerator(this);
}
IEnumerator IEnumerable.GetEnumerator()
public readonly struct Enumerator : IEnumerator<HubConnectionContext>
{
return GetEnumerator();
private readonly IEnumerator<KeyValuePair<string, HubConnectionContext>> _enumerator;
public Enumerator(HubConnectionStore hubConnectionList)
{
_enumerator = hubConnectionList._connections.GetEnumerator();
}
public HubConnectionContext Current => _enumerator.Current.Value;
object IEnumerator.Current => Current;
public void Dispose() => _enumerator.Dispose();
public bool MoveNext() => _enumerator.MoveNext();
public void Reset() => _enumerator.Reset();
}
}
}

View File

@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable where THub : Hub
{
private readonly HubConnectionList _connections = new HubConnectionList();
private readonly HubConnectionStore _connections = new HubConnectionStore();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
private readonly IConnectionMultiplexer _redisServerConnection;
@ -665,7 +665,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private class GroupData
{
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public HubConnectionList Connections = new HubConnectionList();
public HubConnectionStore Connections = new HubConnectionStore();
}
private interface IRedisFeature

View File

@ -110,54 +110,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task SendConnectionAsyncDoesNotThrowIfConnectionFailsToWrite()
{
using (var client = new TestClient())
{
var manager = new DefaultHubLifetimeManager<MyHub>(new Logger<DefaultHubLifetimeManager<MyHub>>(NullLoggerFactory.Instance));
var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection);
// Force an exception when writing to connection
connectionMock.Setup(m => m.WriteAsync(It.IsAny<HubMessage>())).Throws(new Exception("Message"));
var connection = connectionMock.Object;
await manager.OnConnectedAsync(connection).OrTimeout();
await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
}
}
[Fact]
public async Task SendAllAsyncSendsToAllConnectionsEvenWhenSomeFailToSend()
{
using (var client = new TestClient())
using (var client2 = new TestClient())
{
var manager = new DefaultHubLifetimeManager<MyHub>(new Logger<DefaultHubLifetimeManager<MyHub>>(NullLoggerFactory.Instance));
var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection);
var connectionMock2 = HubConnectionContextUtils.CreateMock(client2.Connection);
var tcs = new TaskCompletionSource<object>();
var tcs2 = new TaskCompletionSource<object>();
// Force an exception when writing to connection
connectionMock.Setup(m => m.WriteAsync(It.IsAny<HubMessage>())).Callback(() => tcs.TrySetResult(null)).Throws(new Exception("Message"));
connectionMock2.Setup(m => m.WriteAsync(It.IsAny<HubMessage>())).Callback(() => tcs2.TrySetResult(null)).Throws(new Exception("Message"));
var connection = connectionMock.Object;
var connection2 = connectionMock2.Object;
await manager.OnConnectedAsync(connection).OrTimeout();
await manager.OnConnectedAsync(connection2).OrTimeout();
await manager.SendAllAsync("Hello", new object[] { "World" }).OrTimeout();
// Check that all connections were "written" to
await tcs.Task.OrTimeout();
await tcs2.Task.OrTimeout();
}
}
[Fact]
public async Task SendConnectionAsyncOnNonExistentConnectionNoops()
{