diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs index 8133b10e55..3f5697c7b7 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/BroadcastBenchmark.cs @@ -40,13 +40,16 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks protocol = new MessagePackHubProtocol(); } + var options = new PipeOptions(); for (var i = 0; i < Connections; ++i) { - var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, PipeOptions.Default); + var pair = DuplexPipe.CreateConnectionPair(options, options); var connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Application, pair.Transport); var hubConnection = new HubConnectionContext(connection, Timeout.InfiniteTimeSpan, NullLoggerFactory.Instance); hubConnection.Protocol = protocol; - _hubLifetimeManager.OnConnectedAsync(hubConnection).Wait(); + _hubLifetimeManager.OnConnectedAsync(hubConnection).GetAwaiter().GetResult(); + + _ = ConsumeAsync(connection.Application); } _hubContext = new HubContext(_hubLifetimeManager); @@ -57,5 +60,26 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { return _hubContext.Clients.All.SendAsync("Method"); } + + // Consume the data written to the transport + private static async Task ConsumeAsync(IDuplexPipe application) + { + while (true) + { + var result = await application.Input.ReadAsync(); + var buffer = result.Buffer; + + if (!buffer.IsEmpty) + { + application.Input.AdvanceTo(buffer.End); + } + else if (result.IsCompleted) + { + break; + } + } + + application.Input.Complete(); + } } } diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 341a1959de..9137c79a5f 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -68,7 +68,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { } - public override Task WriteAsync(HubMessage message) + public override ValueTask WriteAsync(HubMessage message) { if (message is CompletionMessage completionMessage) { @@ -78,7 +78,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks } } - return Task.CompletedTask; + return default; } } diff --git a/samples/ChatSample/PresenceHubLifetimeManager.cs b/samples/ChatSample/PresenceHubLifetimeManager.cs index 587b97de3e..485bac061c 100644 --- a/samples/ChatSample/PresenceHubLifetimeManager.cs +++ b/samples/ChatSample/PresenceHubLifetimeManager.cs @@ -36,7 +36,7 @@ namespace ChatSample where THubLifetimeManager : HubLifetimeManager where THub : HubWithPresence { - private readonly HubConnectionList _connections = new HubConnectionList(); + private readonly HubConnectionStore _connections = new HubConnectionStore(); private readonly IUserTracker _userTracker; private readonly IServiceScopeFactory _serviceScopeFactory; private readonly ILogger _logger; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs index 03f07a5b6d..2385bc76db 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/DefaultHubLifetimeManager.cs @@ -12,7 +12,7 @@ namespace Microsoft.AspNetCore.SignalR { public class DefaultHubLifetimeManager : HubLifetimeManager where THub : Hub { - private readonly HubConnectionList _connections = new HubConnectionList(); + private readonly HubConnectionStore _connections = new HubConnectionStore(); private readonly HubGroupList _groups = new HubGroupList(); private readonly ILogger _logger; @@ -69,18 +69,37 @@ namespace Microsoft.AspNetCore.SignalR public override Task SendAllAsync(string methodName, object[] args) { - return SendAllWhere(methodName, args, c => true); - } + List tasks = null; + var message = CreateInvocationMessage(methodName, args); - private Task SendAllWhere(string methodName, object[] args, Func include) - { - var count = _connections.Count; - if (count == 0) + foreach (var connection in _connections) + { + var task = connection.WriteAsync(message); + + if (!task.IsCompletedSuccessfully) + { + if (tasks == null) + { + tasks = new List(); + } + + tasks.Add(task.AsTask()); + } + } + + // No async + if (tasks == null) { return Task.CompletedTask; } - var tasks = new List(count); + // Some connections are slow + return Task.WhenAll(tasks); + } + + private Task SendAllWhere(string methodName, object[] args, Func include) + { + List tasks = null; var message = CreateInvocationMessage(methodName, args); foreach (var connection in _connections) @@ -90,9 +109,25 @@ namespace Microsoft.AspNetCore.SignalR continue; } - tasks.Add(SafeWriteAsync(connection, message)); + var task = connection.WriteAsync(message); + + if (!task.IsCompletedSuccessfully) + { + if (tasks == null) + { + tasks = new List(); + } + + tasks.Add(task.AsTask()); + } } + if (tasks == null) + { + return Task.CompletedTask; + } + + // Some connections are slow return Task.WhenAll(tasks); } @@ -112,7 +147,7 @@ namespace Microsoft.AspNetCore.SignalR var message = CreateInvocationMessage(methodName, args); - return SafeWriteAsync(connection, message); + return connection.WriteAsync(message).AsTask(); } public override Task SendGroupAsync(string groupName, string methodName, object[] args) @@ -126,7 +161,7 @@ namespace Microsoft.AspNetCore.SignalR if (group != null) { var message = CreateInvocationMessage(methodName, args); - var tasks = group.Values.Select(c => SafeWriteAsync(c, message)); + var tasks = group.Values.Select(c => c.WriteAsync(message).AsTask()); return Task.WhenAll(tasks); } @@ -149,7 +184,7 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - tasks.Add(Task.WhenAll(group.Values.Select(c => SafeWriteAsync(c, message)))); + tasks.Add(Task.WhenAll(group.Values.Select(c => c.WriteAsync(message).AsTask()))); } } @@ -168,7 +203,7 @@ namespace Microsoft.AspNetCore.SignalR { var message = CreateInvocationMessage(methodName, args); var tasks = group.Values.Where(connection => !excludedIds.Contains(connection.ConnectionId)) - .Select(c => SafeWriteAsync(c, message)); + .Select(c => c.WriteAsync(message).AsTask()); return Task.WhenAll(tasks); } @@ -222,30 +257,5 @@ namespace Microsoft.AspNetCore.SignalR return userIds.Contains(connection.UserIdentifier); }); } - - // This method is to protect against connections throwing synchronously when writing to them and preventing other connections from being written to - private async Task SafeWriteAsync(HubConnectionContext connection, InvocationMessage message) - { - try - { - await connection.WriteAsync(message); - } - // This exception isn't interesting to users - catch (Exception ex) - { - Log.FailedWritingMessage(_logger, ex); - } - } - - private static class Log - { - private static readonly Action _failedWritingMessage = - LoggerMessage.Define(LogLevel.Warning, new EventId(1, "FailedWritingMessage"), "Failed writing message."); - - public static void FailedWritingMessage(ILogger logger, Exception exception) - { - _failedWritingMessage(logger, exception); - } - } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs index 741d91d8b9..8026a1ad00 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionContext.cs @@ -74,20 +74,78 @@ namespace Microsoft.AspNetCore.SignalR public int? LocalPort => Features.Get()?.LocalPort; - public virtual async Task WriteAsync(HubMessage message) + public virtual ValueTask WriteAsync(HubMessage message) { - await _writeLock.WaitAsync(); + // We were unable to get the lock so take the slow async path of waiting for the semaphore + if (!_writeLock.Wait(0)) + { + return new ValueTask(WriteSlowAsync(message)); + } + // This method should never throw synchronously + var task = WriteCore(message); + + // The write didn't complete synchronously so await completion + if (!task.IsCompletedSuccessfully) + { + return new ValueTask(CompleteWriteAsync(task)); + } + + // Otherwise, release the lock acquired when entering WriteAsync + _writeLock.Release(); + + return default; + } + + private ValueTask WriteCore(HubMessage message) + { try { - // This will internally cache the buffer for each unique HubProtocol/DataEncoder combination + // This will internally cache the buffer for each unique HubProtocol // So that we don't serialize the HubMessage for every single connection var buffer = message.WriteMessage(Protocol); + _connectionContext.Transport.Output.Write(buffer); - Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); + return _connectionContext.Transport.Output.FlushAsync(); + } + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); - await _connectionContext.Transport.Output.FlushAsync(); + return new ValueTask(new FlushResult(isCanceled: false, isCompleted: true)); + } + } + + private async Task CompleteWriteAsync(ValueTask task) + { + try + { + await task; + } + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); + } + finally + { + // Release the lock acquired when entering WriteAsync + _writeLock.Release(); + } + } + + private async Task WriteSlowAsync(HubMessage message) + { + try + { + // Failed to get the lock immediately when entering WriteAsync so await until it is available + await _writeLock.WaitAsync(); + + await WriteCore(message); + } + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); } finally { @@ -95,23 +153,33 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task TryWritePingAsync() + private ValueTask TryWritePingAsync() { // Don't wait for the lock, if it returns false that means someone wrote to the connection // and we don't need to send a ping anymore - if (!await _writeLock.WaitAsync(0)) + if (!_writeLock.Wait(0)) { - return; + return default; } + return new ValueTask(TryWritePingSlowAsync()); + } + + private async Task TryWritePingSlowAsync() + { try { Debug.Assert(_cachedPingMessage != null); + _connectionContext.Transport.Output.Write(_cachedPingMessage); - Interlocked.Exchange(ref _lastSendTimestamp, Stopwatch.GetTimestamp()); - await _connectionContext.Transport.Output.FlushAsync(); + + Log.SentPing(_logger); + } + catch (Exception ex) + { + Log.FailedWritingMessage(_logger, ex); } finally { @@ -254,21 +322,21 @@ namespace Microsoft.AspNetCore.SignalR private void KeepAliveTick() { + var timestamp = Stopwatch.GetTimestamp(); // Implements the keep-alive tick behavior // Each tick, we check if the time since the last send is larger than the keep alive duration (in ticks). // If it is, we send a ping frame, if not, we no-op on this tick. This means that in the worst case, the // true "ping rate" of the server could be (_hubOptions.KeepAliveInterval + HubEndPoint.KeepAliveTimerInterval), // because if the interval elapses right after the last tick of this timer, it won't be detected until the next tick. - - if (Stopwatch.GetTimestamp() - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration) + if (timestamp - Interlocked.Read(ref _lastSendTimestamp) > _keepAliveDuration) { // Haven't sent a message for the entire keep-alive duration, so send a ping. // If the transport channel is full, this will fail, but that's OK because // adding a Ping message when the transport is full is unnecessary since the // transport is still in the process of sending frames. - _ = TryWritePingAsync(); - Log.SentPing(_logger); + + Interlocked.Exchange(ref _lastSendTimestamp, timestamp); } } @@ -308,6 +376,9 @@ namespace Microsoft.AspNetCore.SignalR private static readonly Action _handshakeFailed = LoggerMessage.Define(LogLevel.Error, new EventId(5, "HandshakeFailed"), "Failed connection handshake."); + private static readonly Action _failedWritingMessage = + LoggerMessage.Define(LogLevel.Debug, new EventId(6, "FailedWritingMessage"), "Failed writing message."); + public static void HandshakeComplete(ILogger logger, string hubProtocol) { _handshakeComplete(logger, hubProtocol, null); @@ -332,6 +403,11 @@ namespace Microsoft.AspNetCore.SignalR { _handshakeFailed(logger, exception); } + + public static void FailedWritingMessage(ILogger logger, Exception exception) + { + _failedWritingMessage(logger, exception); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionList.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionStore.cs similarity index 58% rename from src/Microsoft.AspNetCore.SignalR.Core/HubConnectionList.cs rename to src/Microsoft.AspNetCore.SignalR.Core/HubConnectionStore.cs index f398f87f2b..2a2842950c 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionList.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubConnectionStore.cs @@ -7,7 +7,7 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.SignalR { - public class HubConnectionList : IReadOnlyCollection + public class HubConnectionStore { private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); @@ -32,14 +32,29 @@ namespace Microsoft.AspNetCore.SignalR _connections.TryRemove(connection.ConnectionId, out _); } - public IEnumerator GetEnumerator() + public Enumerator GetEnumerator() { - return _connections.Values.GetEnumerator(); + return new Enumerator(this); } - IEnumerator IEnumerable.GetEnumerator() + public readonly struct Enumerator : IEnumerator { - return GetEnumerator(); + private readonly IEnumerator> _enumerator; + + public Enumerator(HubConnectionStore hubConnectionList) + { + _enumerator = hubConnectionList._connections.GetEnumerator(); + } + + public HubConnectionContext Current => _enumerator.Current.Value; + + object IEnumerator.Current => Current; + + public void Dispose() => _enumerator.Dispose(); + + public bool MoveNext() => _enumerator.MoveNext(); + + public void Reset() => _enumerator.Reset(); } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 125924eeb0..9fd3d2ca12 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { public class RedisHubLifetimeManager : HubLifetimeManager, IDisposable where THub : Hub { - private readonly HubConnectionList _connections = new HubConnectionList(); + private readonly HubConnectionStore _connections = new HubConnectionStore(); // TODO: Investigate "memory leak" entries never get removed private readonly ConcurrentDictionary _groups = new ConcurrentDictionary(); private readonly IConnectionMultiplexer _redisServerConnection; @@ -665,7 +665,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis private class GroupData { public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); - public HubConnectionList Connections = new HubConnectionList(); + public HubConnectionStore Connections = new HubConnectionStore(); } private interface IRedisFeature diff --git a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs index cbad435b55..f48f30f3ba 100644 --- a/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs +++ b/test/Microsoft.AspNetCore.SignalR.Redis.Tests/Docker.cs @@ -63,8 +63,8 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests // create and run docker container, remove automatically when stopped, map 6379 from the container to 6379 localhost // use static name 'redisTestContainer' so if the container doesn't get removed we don't keep adding more // use redis base docker image - // 10 second timeout to allow redis image to be downloaded - RunProcessAndThrowIfFailed(_path, $"run --rm -p 6379:6379 --name {_dockerContainerName} -d redis", logger, TimeSpan.FromSeconds(10)); + // 20 second timeout to allow redis image to be downloaded, should be a rare occurance, only happening when a new version is released + RunProcessAndThrowIfFailed(_path, $"run --rm -p 6379:6379 --name {_dockerContainerName} -d redis", logger, TimeSpan.FromSeconds(20)); } public void Stop(ILogger logger) @@ -125,7 +125,11 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests process.BeginErrorReadLine(); process.BeginOutputReadLine(); - process.WaitForExit((int)timeout.TotalMilliseconds); + if (!process.WaitForExit((int)timeout.TotalMilliseconds)) + { + process.Close(); + logger.LogError("Closing process '{processName}' because it is running longer than the configured timeout.", fileName); + } output = string.Join(Environment.NewLine, lines); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs index dd87b3527f..02ccd24d62 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultHubLifetimeManagerTests.cs @@ -110,54 +110,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - [Fact] - public async Task SendConnectionAsyncDoesNotThrowIfConnectionFailsToWrite() - { - using (var client = new TestClient()) - { - var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); - - var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); - // Force an exception when writing to connection - connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Throws(new Exception("Message")); - var connection = connectionMock.Object; - - await manager.OnConnectedAsync(connection).OrTimeout(); - - await manager.SendConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout(); - } - } - - [Fact] - public async Task SendAllAsyncSendsToAllConnectionsEvenWhenSomeFailToSend() - { - using (var client = new TestClient()) - using (var client2 = new TestClient()) - { - var manager = new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); - - var connectionMock = HubConnectionContextUtils.CreateMock(client.Connection); - var connectionMock2 = HubConnectionContextUtils.CreateMock(client2.Connection); - - var tcs = new TaskCompletionSource(); - var tcs2 = new TaskCompletionSource(); - // Force an exception when writing to connection - connectionMock.Setup(m => m.WriteAsync(It.IsAny())).Callback(() => tcs.TrySetResult(null)).Throws(new Exception("Message")); - connectionMock2.Setup(m => m.WriteAsync(It.IsAny())).Callback(() => tcs2.TrySetResult(null)).Throws(new Exception("Message")); - var connection = connectionMock.Object; - var connection2 = connectionMock2.Object; - - await manager.OnConnectedAsync(connection).OrTimeout(); - await manager.OnConnectedAsync(connection2).OrTimeout(); - - await manager.SendAllAsync("Hello", new object[] { "World" }).OrTimeout(); - - // Check that all connections were "written" to - await tcs.Task.OrTimeout(); - await tcs2.Task.OrTimeout(); - } - } - [Fact] public async Task SendConnectionAsyncOnNonExistentConnectionNoops() {