From a0183b1facb6534e8be17793b1287b59ad9de2a1 Mon Sep 17 00:00:00 2001 From: Brennan Conroy Date: Mon, 25 Nov 2019 19:37:46 +0000 Subject: [PATCH] Merged PR 4506: [SignalR] Wait to complete pipe and cancel long sends --- .azure/pipelines/ci.yml | 2 +- eng/Baseline.Designer.props | 18 +- eng/Baseline.xml | 17 +- .../src/Internal/HttpConnectionContext.cs | 66 +++- .../src/Internal/HttpConnectionDispatcher.cs | 27 +- .../src/Internal/HttpConnectionManager.cs | 16 +- .../src/Internal/TaskExtensions.cs | 24 ++ .../Transports/LongPollingServerTransport.cs | 27 +- .../ServerSentEventsServerTransport.cs | 15 +- .../Transports/WebSocketsServerTransport.cs | 4 +- .../src/ServerSentEventsMessageFormatter.cs | 13 +- .../test/HttpConnectionDispatcherTests.cs | 282 +++++++++++++++++- .../test/HttpConnectionManagerTests.cs | 9 +- .../ServerSentEventsMessageFormatterTests.cs | 4 +- .../test/TestWebSocketConnectionFeature.cs | 28 +- src/SignalR/common/Shared/PipeWriterStream.cs | 11 +- .../testassets/Tests.Utils/TestClient.cs | 32 +- .../ServerSentEventsBenchmark.cs | 4 +- .../Core/src/DefaultHubLifetimeManager.cs | 27 +- .../server/Core/src/HubConnectionContext.cs | 126 ++++++-- .../server/Core/src/HubConnectionHandler.cs | 22 +- .../server/Core/src/Internal/Proxies.cs | 2 +- .../test/DefaultHubLifetimeManagerTests.cs | 243 ++++++++++++++- .../SignalR/test/HubConnectionHandlerTests.cs | 41 +++ 24 files changed, 940 insertions(+), 120 deletions(-) create mode 100644 src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs diff --git a/.azure/pipelines/ci.yml b/.azure/pipelines/ci.yml index 0418caefef..61497948bb 100644 --- a/.azure/pipelines/ci.yml +++ b/.azure/pipelines/ci.yml @@ -62,7 +62,7 @@ variables: - name: _BuildArgs value: '' - name: _SignType - valule: test + value: test - name: _PublishArgs value: '' # used for post-build phases, internal builds only diff --git a/eng/Baseline.Designer.props b/eng/Baseline.Designer.props index 1e28129d00..aeb0b64b50 100644 --- a/eng/Baseline.Designer.props +++ b/eng/Baseline.Designer.props @@ -2,7 +2,7 @@ $(MSBuildAllProjects);$(MSBuildThisFileFullPath) - 3.0.0 + 3.0.1 @@ -35,7 +35,7 @@ - 3.0.0 + 3.0.1 @@ -407,14 +407,14 @@ - 3.0.0 + 3.0.1 - + - + @@ -635,19 +635,19 @@ - 3.0.0 + 3.0.1 - 3.0.0 + 3.0.1 - 3.0.0 + 3.0.1 - 3.0.0 + 3.0.1 diff --git a/eng/Baseline.xml b/eng/Baseline.xml index 52cbb358f4..638d42dd71 100644 --- a/eng/Baseline.xml +++ b/eng/Baseline.xml @@ -1,16 +1,15 @@  - - + - + @@ -53,7 +52,7 @@ Update this list when preparing for a new patch. - + @@ -78,10 +77,10 @@ Update this list when preparing for a new patch. - - - - + + + + diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index dac620efa8..abf6b69524 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -31,6 +31,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal IHttpTransportFeature, IConnectionInherentKeepAliveFeature { + private static long _tenSeconds = TimeSpan.FromSeconds(10).Ticks; + private readonly object _stateLock = new object(); private readonly object _itemsLock = new object(); private readonly object _heartbeatLock = new object(); @@ -40,6 +42,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private IDuplexPipe _application; private IDictionary _items; + private CancellationTokenSource _sendCts; + private bool _activeSend; + private long _startedSendTime; + private readonly object _sendingLock = new object(); + internal CancellationToken SendingToken { get; private set; } + // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task private readonly TaskCompletionSource _disposeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); @@ -258,8 +266,26 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal } else { - // The other transports don't close their own output, so we can do it here safely - Application?.Output.Complete(); + // Normally it isn't safe to try and acquire this lock because the Send can hold onto it for a long time if there is backpressure + // It is safe to wait for this lock now because the Send will be in one of 4 states + // 1. In the middle of a write which is in the middle of being canceled by the CancelPendingFlush above, when it throws + // an OperationCanceledException it will complete the PipeWriter which will make any other Send waiting on the lock + // throw an InvalidOperationException if they call Write + // 2. About to write and see that there is a pending cancel from the CancelPendingFlush, go to 1 to see what happens + // 3. Enters the Send and sees the Dispose state from DisposeAndRemoveAsync and releases the lock + // 4. No Send in progress + await WriteLock.WaitAsync(); + try + { + // Complete the applications read loop + Application?.Output.Complete(); + } + finally + { + WriteLock.Release(); + } + + Application?.Input.CancelPendingRead(); } } @@ -401,7 +427,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal nonClonedContext.Response.RegisterForDispose(timeoutSource); nonClonedContext.Response.RegisterForDispose(tokenSource); - var longPolling = new LongPollingServerTransport(timeoutSource.Token, Application.Input, loggerFactory); + var longPolling = new LongPollingServerTransport(timeoutSource.Token, Application.Input, loggerFactory, this); // Start the transport TransportTask = longPolling.ProcessRequestAsync(nonClonedContext, tokenSource.Token); @@ -507,6 +533,40 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal await connectionDelegate(this); } + internal void StartSendCancellation() + { + lock (_sendingLock) + { + if (_sendCts == null || _sendCts.IsCancellationRequested) + { + _sendCts = new CancellationTokenSource(); + SendingToken = _sendCts.Token; + } + _startedSendTime = DateTime.UtcNow.Ticks; + _activeSend = true; + } + } + internal void TryCancelSend(long currentTicks) + { + lock (_sendingLock) + { + if (_activeSend) + { + if (currentTicks - _startedSendTime > _tenSeconds) + { + _sendCts.Cancel(); + } + } + } + } + internal void StopSendCancellation() + { + lock (_sendingLock) + { + _activeSend = false; + } + } + private static class Log { private static readonly Action _disposingConnection = diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 9da1ea0c18..c40a80b9ce 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -142,7 +142,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal connection.SupportedFormats = TransferFormat.Text; // We only need to provide the Input channel since writing to the application is handled through /send. - var sse = new ServerSentEventsServerTransport(connection.Application.Input, connection.ConnectionId, _loggerFactory); + var sse = new ServerSentEventsServerTransport(connection.Application.Input, connection.ConnectionId, connection, _loggerFactory); await DoPersistentConnection(connectionDelegate, sse, context, connection); } @@ -216,7 +216,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal connection.Transport.Output.Complete(connection.ApplicationTask.Exception); // Wait for the transport to run - await connection.TransportTask; + // Ignore exceptions, it has been logged if there is one and the application has finished + // So there is no one to give the exception to + await connection.TransportTask.NoThrow(); // If the status code is a 204 it means the connection is done if (context.Response.StatusCode == StatusCodes.Status204NoContent) @@ -234,12 +236,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal connection.MarkInactive(); } } - else if (resultTask.IsFaulted) + else if (resultTask.IsFaulted || resultTask.IsCanceled) { // Cancel current request to release any waiting poll and let dispose acquire the lock currentRequestTcs.TrySetCanceled(); - - // transport task was faulted, we should remove the connection + // We should be able to safely dispose because there's no more data being written + // We don't need to wait for close here since we've already waited for both sides await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); } else @@ -434,6 +436,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal context.Response.StatusCode = StatusCodes.Status404NotFound; context.Response.ContentType = "text/plain"; + + // There are no writes anymore (since this is the write "loop") + // So it is safe to complete the writer + // We complete the writer here because we already have the WriteLock acquired + // and it's unsafe to complete outside of the lock + // Other code isn't guaranteed to be able to acquire the lock before another write + // even if CancelPendingFlush is called, and the other write could hang if there is backpressure + connection.Application.Output.Complete(); return; } catch (IOException ex) @@ -481,11 +491,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Log.TerminatingConection(_logger); - // Complete the receiving end of the pipe - connection.Application.Output.Complete(); - - // Dispose the connection gracefully, but don't wait for it. We assign it here so we can wait in tests - connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); + // Dispose the connection, but don't wait for it. We assign it here so we can wait in tests + connection.DisposeAndRemoveTask = _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); context.Response.StatusCode = StatusCodes.Status202Accepted; context.Response.ContentType = "text/plain"; diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs index 4a97681fc0..b0f4b079fb 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionManager.cs @@ -31,6 +31,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private readonly TimerAwaitable _nextHeartbeat; private readonly ILogger _logger; private readonly ILogger _connectionLogger; + private readonly bool _useSendTimeout = true; private readonly TimeSpan _disconnectTimeout; public HttpConnectionManager(ILoggerFactory loggerFactory, IHostApplicationLifetime appLifetime) @@ -44,6 +45,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal _connectionLogger = loggerFactory.CreateLogger(); _nextHeartbeat = new TimerAwaitable(_heartbeatTickRate, _heartbeatTickRate); _disconnectTimeout = connectionOptions.Value.DisconnectTimeout ?? ConnectionOptionsSetup.DefaultDisconectTimeout; + if (AppContext.TryGetSwitch("Microsoft.AspNetCore.Http.Connections.DoNotUseSendTimeout", out var timeoutDisabled)) + { + _useSendTimeout = !timeoutDisabled; + } + // Register these last as the callbacks could run immediately appLifetime.ApplicationStarted.Register(() => Start()); appLifetime.ApplicationStopping.Register(() => CloseConnections()); @@ -155,20 +161,26 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // Capture the connection state var lastSeenUtc = connection.LastSeenUtcIfInactive; + var utcNow = DateTimeOffset.UtcNow; // Once the decision has been made to dispose we don't check the status again // But don't clean up connections while the debugger is attached. - if (!Debugger.IsAttached && lastSeenUtc.HasValue && (DateTimeOffset.UtcNow - lastSeenUtc.Value).TotalSeconds > _disconnectTimeout.TotalSeconds) + if (!Debugger.IsAttached && lastSeenUtc.HasValue && (utcNow - lastSeenUtc.Value).TotalSeconds > _disconnectTimeout.TotalSeconds) { Log.ConnectionTimedOut(_logger, connection.ConnectionId); HttpConnectionsEventSource.Log.ConnectionTimedOut(connection.ConnectionId); // This is most likely a long polling connection. The transport here ends because - // a poll completed and has been inactive for > 5 seconds so we wait for the + // a poll completed and has been inactive for > 5 seconds so we wait for the // application to finish gracefully _ = DisposeAndRemoveAsync(connection, closeGracefully: true); } else { + if (!Debugger.IsAttached && _useSendTimeout) + { + connection.TryCancelSend(utcNow.Ticks); + } + // Tick the heartbeat, if the connection is still active connection.TickHeartbeat(); } diff --git a/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs b/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs new file mode 100644 index 0000000000..9608a67272 --- /dev/null +++ b/src/SignalR/common/Http.Connections/src/Internal/TaskExtensions.cs @@ -0,0 +1,24 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Runtime.CompilerServices; +namespace System.Threading.Tasks +{ + internal static class TaskExtensions + { + public static async Task NoThrow(this Task task) + { + await new NoThrowAwaiter(task); + } + } + internal readonly struct NoThrowAwaiter : ICriticalNotifyCompletion + { + private readonly Task _task; + public NoThrowAwaiter(Task task) { _task = task; } + public NoThrowAwaiter GetAwaiter() => this; + public bool IsCompleted => _task.IsCompleted; + // Observe exception + public void GetResult() { _ = _task.Exception; } + public void OnCompleted(Action continuation) => _task.GetAwaiter().OnCompleted(continuation); + public void UnsafeOnCompleted(Action continuation) => OnCompleted(continuation); + } +} \ No newline at end of file diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs index 02ff32ab8f..3432e37039 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/LongPollingServerTransport.cs @@ -16,12 +16,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports private readonly PipeReader _application; private readonly ILogger _logger; private readonly CancellationToken _timeoutToken; + private readonly HttpConnectionContext _connection; public LongPollingServerTransport(CancellationToken timeoutToken, PipeReader application, ILoggerFactory loggerFactory) + : this(timeoutToken, application, loggerFactory, connection: null) + { } + + public LongPollingServerTransport(CancellationToken timeoutToken, PipeReader application, ILoggerFactory loggerFactory, HttpConnectionContext connection) { _timeoutToken = timeoutToken; _application = application; + _connection = connection; + // We create the logger with a string to preserve the logging namespace after the server side transport renames. _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.LongPollingTransport"); } @@ -33,7 +40,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports var result = await _application.ReadAsync(token); var buffer = result.Buffer; - if (buffer.IsEmpty && result.IsCompleted) + if (buffer.IsEmpty && (result.IsCompleted || result.IsCanceled)) { Log.LongPolling204(_logger); context.Response.ContentType = "text/plain"; @@ -51,19 +58,22 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports try { - await context.Response.Body.WriteAsync(buffer); + _connection?.StartSendCancellation(); + await context.Response.Body.WriteAsync(buffer, _connection?.SendingToken ?? default); } finally { + _connection?.StopSendCancellation(); _application.AdvanceTo(buffer.End); } } catch (OperationCanceledException) { - // 3 cases: + // 4 cases: // 1 - Request aborted, the client disconnected (no response) // 2 - The poll timeout is hit (200) - // 3 - A new request comes in and cancels this request (204) + // 3 - SendingToken was canceled, abort the connection + // 4 - A new request comes in and cancels this request (204) // Case 1 if (context.RequestAborted.IsCancellationRequested) @@ -81,9 +91,16 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status200OK; } - else + else if (_connection?.SendingToken.IsCancellationRequested == true) { // Case 3 + context.Response.ContentType = "text/plain"; + context.Response.StatusCode = StatusCodes.Status204NoContent; + throw; + } + else + { + // Case 4 Log.LongPolling204(_logger); context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status204NoContent; diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs index 54f2ed8f38..3d5e1f6f4b 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/ServerSentEventsServerTransport.cs @@ -16,11 +16,17 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports private readonly PipeReader _application; private readonly string _connectionId; private readonly ILogger _logger; + private readonly HttpConnectionContext _connection; public ServerSentEventsServerTransport(PipeReader application, string connectionId, ILoggerFactory loggerFactory) + : this(application, connectionId, connection: null, loggerFactory) + { } + + public ServerSentEventsServerTransport(PipeReader application, string connectionId, HttpConnectionContext connection, ILoggerFactory loggerFactory) { _application = application; _connectionId = connectionId; + _connection = connection; // We create the logger with a string to preserve the logging namespace after the server side transport renames. _logger = loggerFactory.CreateLogger("Microsoft.AspNetCore.Http.Connections.Internal.Transports.ServerSentEventsTransport"); @@ -51,11 +57,17 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports try { + if (result.IsCanceled) + { + break; + } + if (!buffer.IsEmpty) { Log.SSEWritingMessage(_logger, buffer.Length); - await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, context.Response.Body); + _connection?.StartSendCancellation(); + await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, context.Response.Body, _connection?.SendingToken ?? default); } else if (result.IsCompleted) { @@ -64,6 +76,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports } finally { + _connection?.StopSendCancellation(); _application.AdvanceTo(buffer.End); } } diff --git a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs index d5c2c1fefb..a95041c48a 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/Transports/WebSocketsServerTransport.cs @@ -231,7 +231,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports if (WebSocketCanSend(socket)) { - await socket.SendAsync(buffer, webSocketMessageType); + _connection.StartSendCancellation(); + await socket.SendAsync(buffer, webSocketMessageType, _connection.SendingToken); } else { @@ -254,6 +255,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports } finally { + _connection.StopSendCancellation(); _application.Input.AdvanceTo(buffer.End); } } diff --git a/src/SignalR/common/Http.Connections/src/ServerSentEventsMessageFormatter.cs b/src/SignalR/common/Http.Connections/src/ServerSentEventsMessageFormatter.cs index efd2e24f0f..6e723c5168 100644 --- a/src/SignalR/common/Http.Connections/src/ServerSentEventsMessageFormatter.cs +++ b/src/SignalR/common/Http.Connections/src/ServerSentEventsMessageFormatter.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; using System.IO; +using System.Threading; using System.Threading.Tasks; namespace Microsoft.AspNetCore.Http.Connections @@ -15,19 +16,19 @@ namespace Microsoft.AspNetCore.Http.Connections private const byte LineFeed = (byte)'\n'; - public static async Task WriteMessageAsync(ReadOnlySequence payload, Stream output) + public static async Task WriteMessageAsync(ReadOnlySequence payload, Stream output, CancellationToken token) { // Payload does not contain a line feed so write it directly to output if (payload.PositionOf(LineFeed) == null) { if (payload.Length > 0) { - await output.WriteAsync(DataPrefix, 0, DataPrefix.Length); - await output.WriteAsync(payload); - await output.WriteAsync(Newline, 0, Newline.Length); + await output.WriteAsync(DataPrefix, 0, DataPrefix.Length, token); + await output.WriteAsync(payload, token); + await output.WriteAsync(Newline, 0, Newline.Length, token); } - await output.WriteAsync(Newline, 0, Newline.Length); + await output.WriteAsync(Newline, 0, Newline.Length, token); return; } @@ -37,7 +38,7 @@ namespace Microsoft.AspNetCore.Http.Connections await WriteMessageToMemory(ms, payload); ms.Position = 0; - await ms.CopyToAsync(output); + await ms.CopyToAsync(output, token); } /// diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index d543cd4f9a..2dabe1adbc 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -1050,6 +1050,178 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + private class BlockingStream : Stream + { + private readonly SyncPoint _sync; + private bool _isSSE; + public BlockingStream(SyncPoint sync, bool isSSE = false) + { + _sync = sync; + _isSSE = isSSE; + } + public override bool CanRead => throw new NotImplementedException(); + public override bool CanSeek => throw new NotImplementedException(); + public override bool CanWrite => throw new NotImplementedException(); + public override long Length => throw new NotImplementedException(); + public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); } + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + throw new NotImplementedException(); + } + public override void Flush() + { + } + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + if (_isSSE) + { + // SSE does an initial write of :\r\n that we want to ignore in testing + _isSSE = false; + return; + } + await _sync.WaitToContinue(); + cancellationToken.ThrowIfCancellationRequested(); + } +#if NETCOREAPP2_1 + public override async ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default) + { + if (_isSSE) + { + // SSE does an initial write of :\r\n that we want to ignore in testing + _isSSE = false; + return; + } + await _sync.WaitToContinue(); + cancellationToken.ThrowIfCancellationRequested(); + } +#endif + } + + [Fact] + [LogLevel(LogLevel.Debug)] + public async Task LongPollingConnectionClosesWhenSendTimeoutReached() + { + bool ExpectedErrors(WriteContext writeContext) + { + return (writeContext.LoggerName == typeof(Internal.Transports.LongPollingServerTransport).FullName && + writeContext.EventId.Name == "LongPollingTerminated") || + (writeContext.LoggerName == typeof(HttpConnectionManager).FullName && writeContext.EventId.Name == "FailedDispose"); + } + + using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors)) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = MakeRequest("/foo", connection); + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + // First poll completes immediately + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + var sync = new SyncPoint(); + context.Response.Body = new BlockingStream(sync); + var dispatcherTask = dispatcher.ExecuteAsync(context, options, app); + await connection.Transport.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await sync.WaitForSyncPoint().OrTimeout(); + // Cancel write to response body + connection.TryCancelSend(long.MaxValue); + sync.Continue(); + await dispatcherTask.OrTimeout(); + // Connection should be removed on canceled write + Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + } + } + + [Fact] + [LogLevel(LogLevel.Debug)] + public async Task SSEConnectionClosesWhenSendTimeoutReached() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var context = MakeRequest("/foo", connection); + SetTransport(context, connection.TransportType); + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var sync = new SyncPoint(); + context.Response.Body = new BlockingStream(sync, isSSE: true); + var options = new HttpConnectionDispatcherOptions(); + var dispatcherTask = dispatcher.ExecuteAsync(context, options, app); + await connection.Transport.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await sync.WaitForSyncPoint().OrTimeout(); + // Cancel write to response body + connection.TryCancelSend(long.MaxValue); + sync.Continue(); + await dispatcherTask.OrTimeout(); + // Connection should be removed on canceled write + Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + } + } + + [Fact] + [LogLevel(LogLevel.Debug)] + public async Task WebSocketConnectionClosesWhenSendTimeoutReached() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(Internal.Transports.WebSocketsServerTransport).FullName && + writeContext.EventId.Name == "ErrorWritingFrame"; + } + using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors)) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.WebSockets; + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var sync = new SyncPoint(); + var context = MakeRequest("/foo", connection); + SetTransport(context, connection.TransportType, sync); + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(0); + var dispatcherTask = dispatcher.ExecuteAsync(context, options, app); + await connection.Transport.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await sync.WaitForSyncPoint().OrTimeout(); + // Cancel write to response body + connection.TryCancelSend(long.MaxValue); + sync.Continue(); + await dispatcherTask.OrTimeout(); + // Connection should be removed on canceled write + Assert.False(manager.TryGetConnection(connection.ConnectionId, out var _)); + } + } + [Fact] [LogLevel(LogLevel.Trace)] public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived() @@ -1622,6 +1794,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); Assert.Equal("text/plain", deleteContext.Response.ContentType); + await connection.DisposeAndRemoveTask.OrTimeout(); + // Verify the connection was removed from the manager Assert.False(manager.TryGetConnection(connection.ConnectionToken, out _)); } @@ -1675,6 +1849,110 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Fact] + public async Task DeleteEndpointTerminatesLongPollingWithHangingApplication() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var pipeOptions = new PipeOptions(pauseWriterThreshold: 2, resumeWriterThreshold: 1); + var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var pollTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(pollTask.IsCompleted); + + // Now send the second poll + pollTask = dispatcher.ExecuteAsync(context, options, app); + + // Issue the delete request and make sure the poll completes + var deleteContext = new DefaultHttpContext(); + deleteContext.Request.Path = "/foo"; + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.Method = "DELETE"; + + Assert.False(pollTask.IsCompleted); + + await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); + + await pollTask.OrTimeout(); + + // Verify that transport shuts down + await connection.TransportTask.OrTimeout(); + + // Verify the response from the DELETE request + Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); + Assert.Equal("text/plain", deleteContext.Response.ContentType); + Assert.Equal(HttpConnectionStatus.Disposed, connection.Status); + + // Verify the connection not removed because application is hanging + Assert.True(manager.TryGetConnection(connection.ConnectionId, out _)); + } + } + + [Fact] + public async Task PollCanReceiveFinalMessageAfterAppCompletes() + { + using (StartVerifiableLog()) + { + var transportType = HttpTransportType.LongPolling; + var manager = CreateConnectionManager(LoggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = transportType; + + var waitForMessageTcs1 = new TaskCompletionSource(); + var messageTcs1 = new TaskCompletionSource(); + var waitForMessageTcs2 = new TaskCompletionSource(); + var messageTcs2 = new TaskCompletionSource(); + ConnectionDelegate connectionDelegate = async c => + { + await waitForMessageTcs1.Task.OrTimeout(); + await c.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Message1")).OrTimeout(); + messageTcs1.TrySetResult(null); + await waitForMessageTcs2.Task.OrTimeout(); + await c.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Message2")).OrTimeout(); + messageTcs2.TrySetResult(null); + }; + { + var options = new HttpConnectionDispatcherOptions(); + var context = MakeRequest("/foo", connection); + await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); + + // second poll should have data + waitForMessageTcs1.SetResult(null); + await messageTcs1.Task.OrTimeout(); + + var ms = new MemoryStream(); + context.Response.Body = ms; + // Now send the second poll + await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); + Assert.Equal("Message1", Encoding.UTF8.GetString(ms.ToArray())); + + waitForMessageTcs2.SetResult(null); + await messageTcs2.Task.OrTimeout(); + + context = MakeRequest("/foo", connection); + ms.Seek(0, SeekOrigin.Begin); + context.Response.Body = ms; + // This is the third poll which gets the final message after the app is complete + await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); + Assert.Equal("Message2", Encoding.UTF8.GetString(ms.ToArray())); + } + } + } + [Fact] public async Task NegotiateDoesNotReturnWebSocketsWhenNotAvailable() { @@ -1987,12 +2265,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests return context; } - private static void SetTransport(HttpContext context, HttpTransportType transportType) + private static void SetTransport(HttpContext context, HttpTransportType transportType, SyncPoint sync = null) { switch (transportType) { case HttpTransportType.WebSockets: - context.Features.Set(new TestWebSocketConnectionFeature()); + context.Features.Set(new TestWebSocketConnectionFeature(sync)); break; case HttpTransportType.ServerSentEvents: context.Request.Headers["Accept"] = "text/event-stream"; diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs index ade605b08a..05a29f0e73 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionManagerTests.cs @@ -235,9 +235,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests try { Assert.True(result.IsCompleted); - - // We should be able to write - await connection.Transport.Output.WriteAsync(new byte[] { 1 }); } finally { @@ -248,13 +245,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests connection.TransportTask = Task.Run(async () => { var result = await connection.Application.Input.ReadAsync(); - Assert.Equal(new byte[] { 1 }, result.Buffer.ToArray()); - connection.Application.Input.AdvanceTo(result.Buffer.End); - - result = await connection.Application.Input.ReadAsync(); try { - Assert.True(result.IsCompleted); + Assert.True(result.IsCanceled); } finally { diff --git a/src/SignalR/common/Http.Connections/test/ServerSentEventsMessageFormatterTests.cs b/src/SignalR/common/Http.Connections/test/ServerSentEventsMessageFormatterTests.cs index 2a58e8d4dd..1640752056 100644 --- a/src/SignalR/common/Http.Connections/test/ServerSentEventsMessageFormatterTests.cs +++ b/src/SignalR/common/Http.Connections/test/ServerSentEventsMessageFormatterTests.cs @@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var buffer = new ReadOnlySequence(Encoding.UTF8.GetBytes(payload)); var output = new MemoryStream(); - await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output); + await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output, default); Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); } @@ -32,7 +32,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var buffer = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent(Encoding.UTF8.GetBytes(payload)); var output = new MemoryStream(); - await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output); + await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output, default); Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); } diff --git a/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs b/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs index f67dd94003..9bbb6894db 100644 --- a/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs +++ b/src/SignalR/common/Http.Connections/test/TestWebSocketConnectionFeature.cs @@ -5,11 +5,21 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Internal; +using Microsoft.AspNetCore.SignalR.Tests; namespace Microsoft.AspNetCore.Http.Connections.Tests { internal class TestWebSocketConnectionFeature : IHttpWebSocketFeature, IDisposable { + public TestWebSocketConnectionFeature() + { } + public TestWebSocketConnectionFeature(SyncPoint sync) + { + _sync = sync; + } + + private readonly SyncPoint _sync; private readonly TaskCompletionSource _accepted = new TaskCompletionSource(); public bool IsWebSocketRequest => true; @@ -27,8 +37,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var clientToServer = Channel.CreateUnbounded(); var serverToClient = Channel.CreateUnbounded(); - var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer); - var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer); + var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer, _sync); + var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer, _sync); Client = clientSocket; SubProtocol = context.SubProtocol; @@ -45,16 +55,18 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { private readonly ChannelReader _input; private readonly ChannelWriter _output; + private readonly SyncPoint _sync; private WebSocketCloseStatus? _closeStatus; private string _closeStatusDescription; private WebSocketState _state; private WebSocketMessage _internalBuffer = new WebSocketMessage(); - public WebSocketChannel(ChannelReader input, ChannelWriter output) + public WebSocketChannel(ChannelReader input, ChannelWriter output, SyncPoint sync = null) { _input = input; _output = output; + _sync = sync; } public override WebSocketCloseStatus? CloseStatus => _closeStatus; @@ -173,11 +185,17 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests throw new InvalidOperationException("Unexpected close"); } - public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) + public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { + if (_sync != null) + { + await _sync.WaitToContinue(); + } + cancellationToken.ThrowIfCancellationRequested(); + var copy = new byte[buffer.Count]; Buffer.BlockCopy(buffer.Array, buffer.Offset, copy, 0, buffer.Count); - return SendMessageAsync(new WebSocketMessage + await SendMessageAsync(new WebSocketMessage { Buffer = copy, MessageType = messageType, diff --git a/src/SignalR/common/Shared/PipeWriterStream.cs b/src/SignalR/common/Shared/PipeWriterStream.cs index ddb7960b63..ee271aaf05 100644 --- a/src/SignalR/common/Shared/PipeWriterStream.cs +++ b/src/SignalR/common/Shared/PipeWriterStream.cs @@ -77,7 +77,16 @@ namespace System.IO.Pipelines _length += source.Length; var task = _pipeWriter.WriteAsync(source); - if (!task.IsCompletedSuccessfully) + + if (task.IsCompletedSuccessfully) + { + // Cancellation can be triggered by PipeWriter.CancelPendingFlush + if (task.Result.IsCanceled) + { + throw new OperationCanceledException(); + } + } + else { return WriteSlowAsync(task); } diff --git a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs index 6183691ad4..ddb7dee201 100644 --- a/src/SignalR/common/testassets/Tests.Utils/TestClient.cs +++ b/src/SignalR/common/testassets/Tests.Utils/TestClient.cs @@ -37,9 +37,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests public TransferFormat ActiveFormat { get; set; } - public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null) + public TestClient(IHubProtocol protocol = null, IInvocationBinder invocationBinder = null, string userIdentifier = null, long pauseWriterThreshold = 32768) { - var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var options = new PipeOptions(readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false, + pauseWriterThreshold: pauseWriterThreshold, resumeWriterThreshold: pauseWriterThreshold / 2); var pair = DuplexPipe.CreateConnectionPair(options, options); Connection = new DefaultConnectionContext(Guid.NewGuid().ToString(), pair.Transport, pair.Application); @@ -70,16 +71,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { if (sendHandshakeRequestMessage) { - var memoryBufferWriter = MemoryBufferWriter.Get(); - try - { - HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); - await Connection.Application.Output.WriteAsync(memoryBufferWriter.ToArray()); - } - finally - { - MemoryBufferWriter.Return(memoryBufferWriter); - } + await Connection.Application.Output.WriteAsync(GetHandshakeRequestMessage()); } var connection = handler.OnConnectedAsync(Connection); @@ -257,7 +249,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests } else { - // read first message out of the incoming data + // read first message out of the incoming data if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var responseMessage)) { return responseMessage; @@ -312,6 +304,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + public byte[] GetHandshakeRequestMessage() + { + var memoryBufferWriter = MemoryBufferWriter.Get(); + try + { + HandshakeProtocol.WriteRequestMessage(new HandshakeRequestMessage(_protocol.Name, _protocol.Version), memoryBufferWriter); + return memoryBufferWriter.ToArray(); + } + finally + { + MemoryBufferWriter.Return(memoryBufferWriter); + } + } + private class DefaultInvocationBinder : IInvocationBinder { public IReadOnlyList GetParameterTypes(string methodName) diff --git a/src/SignalR/perf/Microbenchmarks/ServerSentEventsBenchmark.cs b/src/SignalR/perf/Microbenchmarks/ServerSentEventsBenchmark.cs index fd4357c952..5b20e7209d 100644 --- a/src/SignalR/perf/Microbenchmarks/ServerSentEventsBenchmark.cs +++ b/src/SignalR/perf/Microbenchmarks/ServerSentEventsBenchmark.cs @@ -61,7 +61,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks _parser = new ServerSentEventsMessageParser(); _rawData = new ReadOnlySequence(protocol.GetMessageBytes(hubMessage)); var ms = new MemoryStream(); - ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, ms).GetAwaiter().GetResult(); + ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, ms, default).GetAwaiter().GetResult(); _sseFormattedData = ms.ToArray(); } @@ -81,7 +81,7 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks [Benchmark] public Task WriteSingleMessage() { - return ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, Stream.Null); + return ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, Stream.Null, default); } public enum Message diff --git a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs index 3c835ab933..02609bce1e 100644 --- a/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs +++ b/src/SignalR/server/Core/src/DefaultHubLifetimeManager.cs @@ -82,10 +82,10 @@ namespace Microsoft.AspNetCore.SignalR /// public override Task SendAllAsync(string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, null); + return SendToAllConnections(methodName, args, include: null, cancellationToken); } - private Task SendToAllConnections(string methodName, object[] args, Func include) + private Task SendToAllConnections(string methodName, object[] args, Func include, CancellationToken cancellationToken) { List tasks = null; SerializedHubMessage message = null; @@ -103,7 +103,7 @@ namespace Microsoft.AspNetCore.SignalR message = CreateSerializedInvocationMessage(methodName, args); } - var task = connection.WriteAsync(message); + var task = connection.WriteAsync(message, cancellationToken); if (!task.IsCompletedSuccessfully) { @@ -127,7 +127,8 @@ namespace Microsoft.AspNetCore.SignalR // Tasks and message are passed by ref so they can be lazily created inside the method post-filtering, // while still being re-usable when sending to multiple groups - private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary connections, Func include, ref List tasks, ref SerializedHubMessage message) + private void SendToGroupConnections(string methodName, object[] args, ConcurrentDictionary connections, Func include, + ref List tasks, ref SerializedHubMessage message, CancellationToken cancellationToken) { // foreach over ConcurrentDictionary avoids allocating an enumerator foreach (var connection in connections) @@ -142,7 +143,7 @@ namespace Microsoft.AspNetCore.SignalR message = CreateSerializedInvocationMessage(methodName, args); } - var task = connection.Value.WriteAsync(message); + var task = connection.Value.WriteAsync(message, cancellationToken); if (!task.IsCompletedSuccessfully) { @@ -175,7 +176,7 @@ namespace Microsoft.AspNetCore.SignalR // Write message directly to connection without caching it in memory var message = CreateInvocationMessage(methodName, args); - return connection.WriteAsync(message).AsTask(); + return connection.WriteAsync(message, cancellationToken).AsTask(); } /// @@ -193,7 +194,7 @@ namespace Microsoft.AspNetCore.SignalR // group might be modified inbetween checking and sending List tasks = null; SerializedHubMessage message = null; - SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + SendToGroupConnections(methodName, args, group, null, ref tasks, ref message, cancellationToken); if (tasks != null) { @@ -221,7 +222,7 @@ namespace Microsoft.AspNetCore.SignalR var group = _groups[groupName]; if (group != null) { - SendToGroupConnections(methodName, args, group, null, ref tasks, ref message); + SendToGroupConnections(methodName, args, group, null, ref tasks, ref message, cancellationToken); } } @@ -247,7 +248,7 @@ namespace Microsoft.AspNetCore.SignalR List tasks = null; SerializedHubMessage message = null; - SendToGroupConnections(methodName, args, group, connection => !excludedConnectionIds.Contains(connection.ConnectionId), ref tasks, ref message); + SendToGroupConnections(methodName, args, group, connection => !excludedConnectionIds.Contains(connection.ConnectionId), ref tasks, ref message, cancellationToken); if (tasks != null) { @@ -271,7 +272,7 @@ namespace Microsoft.AspNetCore.SignalR /// public override Task SendUserAsync(string userId, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal)); + return SendToAllConnections(methodName, args, connection => string.Equals(connection.UserIdentifier, userId, StringComparison.Ordinal), cancellationToken); } /// @@ -292,19 +293,19 @@ namespace Microsoft.AspNetCore.SignalR /// public override Task SendAllExceptAsync(string methodName, object[] args, IReadOnlyList excludedConnectionIds, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => !excludedConnectionIds.Contains(connection.ConnectionId)); + return SendToAllConnections(methodName, args, connection => !excludedConnectionIds.Contains(connection.ConnectionId), cancellationToken); } /// public override Task SendConnectionsAsync(IReadOnlyList connectionIds, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => connectionIds.Contains(connection.ConnectionId)); + return SendToAllConnections(methodName, args, connection => connectionIds.Contains(connection.ConnectionId), cancellationToken); } /// public override Task SendUsersAsync(IReadOnlyList userIds, string methodName, object[] args, CancellationToken cancellationToken = default) { - return SendToAllConnections(methodName, args, connection => userIds.Contains(connection.UserIdentifier)); + return SendToAllConnections(methodName, args, connection => userIds.Contains(connection.UserIdentifier), cancellationToken); } } } diff --git a/src/SignalR/server/Core/src/HubConnectionContext.cs b/src/SignalR/server/Core/src/HubConnectionContext.cs index 8e9216d35d..11e05c177c 100644 --- a/src/SignalR/server/Core/src/HubConnectionContext.cs +++ b/src/SignalR/server/Core/src/HubConnectionContext.cs @@ -34,6 +34,8 @@ namespace Microsoft.AspNetCore.SignalR private readonly long _keepAliveInterval; private readonly long _clientTimeoutInterval; private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1); + private readonly bool _useAbsoluteClientTimeout; + private readonly object _receiveMessageTimeoutLock = new object(); private StreamTracker _streamTracker; private long _lastSendTimeStamp = DateTime.UtcNow.Ticks; @@ -41,10 +43,13 @@ namespace Microsoft.AspNetCore.SignalR private bool _receivedMessageThisInterval = false; private ReadOnlyMemory _cachedPingMessage; private bool _clientTimeoutActive; - private bool _connectionAborted; + private volatile bool _connectionAborted; private volatile bool _allowReconnect = true; private int _streamBufferCapacity; private long? _maxMessageSize; + private bool _receivedMessageTimeoutEnabled = false; + private long _receivedMessageElapsedTicks = 0; + private long _receivedMessageTimestamp; /// /// Initializes a new instance of the class. @@ -64,6 +69,11 @@ namespace Microsoft.AspNetCore.SignalR ConnectionAborted = _connectionAbortedTokenSource.Token; HubCallerContext = new DefaultHubCallerContext(this); + + if (AppContext.TryGetSwitch("Microsoft.AspNetCore.SignalR.UseAbsoluteClientTimeout", out var useAbsoluteClientTimeout)) + { + _useAbsoluteClientTimeout = useAbsoluteClientTimeout; + } } internal StreamTracker StreamTracker @@ -131,7 +141,7 @@ namespace Microsoft.AspNetCore.SignalR // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { - return new ValueTask(WriteSlowAsync(message)); + return new ValueTask(WriteSlowAsync(message, cancellationToken)); } if (_connectionAborted) @@ -141,7 +151,7 @@ namespace Microsoft.AspNetCore.SignalR } // This method should never throw synchronously - var task = WriteCore(message); + var task = WriteCore(message, cancellationToken); // The write didn't complete synchronously so await completion if (!task.IsCompletedSuccessfully) @@ -167,7 +177,7 @@ namespace Microsoft.AspNetCore.SignalR // Try to grab the lock synchronously, if we fail, go to the slower path if (!_writeLock.Wait(0)) { - return new ValueTask(WriteSlowAsync(message)); + return new ValueTask(WriteSlowAsync(message, cancellationToken)); } if (_connectionAborted) @@ -177,7 +187,7 @@ namespace Microsoft.AspNetCore.SignalR } // This method should never throw synchronously - var task = WriteCore(message); + var task = WriteCore(message, cancellationToken); // The write didn't complete synchronously so await completion if (!task.IsCompletedSuccessfully) @@ -191,7 +201,7 @@ namespace Microsoft.AspNetCore.SignalR return default; } - private ValueTask WriteCore(HubMessage message) + private ValueTask WriteCore(HubMessage message, CancellationToken cancellationToken) { try { @@ -199,7 +209,7 @@ namespace Microsoft.AspNetCore.SignalR // write it without caching. Protocol.WriteMessage(message, _connectionContext.Transport.Output); - return _connectionContext.Transport.Output.FlushAsync(); + return _connectionContext.Transport.Output.FlushAsync(cancellationToken); } catch (Exception ex) { @@ -211,14 +221,14 @@ namespace Microsoft.AspNetCore.SignalR } } - private ValueTask WriteCore(SerializedHubMessage message) + private ValueTask WriteCore(SerializedHubMessage message, CancellationToken cancellationToken) { try { // Grab a preserialized buffer for this protocol. var buffer = message.GetSerializedMessage(Protocol); - return _connectionContext.Transport.Output.WriteAsync(buffer); + return _connectionContext.Transport.Output.WriteAsync(buffer, cancellationToken); } catch (Exception ex) { @@ -249,10 +259,10 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task WriteSlowAsync(HubMessage message) + private async Task WriteSlowAsync(HubMessage message, CancellationToken cancellationToken) { // Failed to get the lock immediately when entering WriteAsync so await until it is available - await _writeLock.WaitAsync(); + await _writeLock.WaitAsync(cancellationToken); try { @@ -261,7 +271,7 @@ namespace Microsoft.AspNetCore.SignalR return; } - await WriteCore(message); + await WriteCore(message, cancellationToken); } catch (Exception ex) { @@ -274,7 +284,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task WriteSlowAsync(SerializedHubMessage message) + private async Task WriteSlowAsync(SerializedHubMessage message, CancellationToken cancellationToken) { // Failed to get the lock immediately when entering WriteAsync so await until it is available await _writeLock.WaitAsync(); @@ -286,7 +296,7 @@ namespace Microsoft.AspNetCore.SignalR return; } - await WriteCore(message); + await WriteCore(message, cancellationToken); } catch (Exception ex) { @@ -370,6 +380,9 @@ namespace Microsoft.AspNetCore.SignalR private void AbortAllowReconnect() { _connectionAborted = true; + // Cancel any current writes or writes that are about to happen and have already gone past the _connectionAborted bool + // We have to do this outside of the lock otherwise it could hang if the write is observing backpressure + _connectionContext.Transport.Output.CancelPendingFlush(); // If we already triggered the token then noop, this isn't thread safe but it's good enough // to avoid spawning a new task in the most common cases @@ -525,9 +538,23 @@ namespace Microsoft.AspNetCore.SignalR internal Task AbortAsync() { AbortAllowReconnect(); + + // Acquire lock to make sure all writes are completed + if (!_writeLock.Wait(0)) + { + return AbortAsyncSlow(); + } + _writeLock.Release(); return _abortCompletedTcs.Task; } + private async Task AbortAsyncSlow() + { + await _writeLock.WaitAsync(); + _writeLock.Release(); + await _abortCompletedTcs.Task; + } + private void KeepAliveTick() { var currentTime = DateTime.UtcNow.Ticks; @@ -564,17 +591,41 @@ namespace Microsoft.AspNetCore.SignalR private void CheckClientTimeout() { - // If it's been too long since we've heard from the client, then close this - if (DateTime.UtcNow.Ticks - Volatile.Read(ref _lastReceivedTimeStamp) > _clientTimeoutInterval) + if (Debugger.IsAttached) { - if (!_receivedMessageThisInterval) - { - Log.ClientTimeout(_logger, TimeSpan.FromTicks(_clientTimeoutInterval)); - AbortAllowReconnect(); - } + return; + } - _receivedMessageThisInterval = false; - Volatile.Write(ref _lastReceivedTimeStamp, DateTime.UtcNow.Ticks); + if (_useAbsoluteClientTimeout) + { + // If it's been too long since we've heard from the client, then close this + if (DateTime.UtcNow.Ticks - Volatile.Read(ref _lastReceivedTimeStamp) > _clientTimeoutInterval) + { + if (!_receivedMessageThisInterval) + { + Log.ClientTimeout(_logger, TimeSpan.FromTicks(_clientTimeoutInterval)); + AbortAllowReconnect(); + } + + _receivedMessageThisInterval = false; + Volatile.Write(ref _lastReceivedTimeStamp, DateTime.UtcNow.Ticks); + } + } + else + { + lock (_receiveMessageTimeoutLock) + { + if (_receivedMessageTimeoutEnabled) + { + _receivedMessageElapsedTicks = DateTime.UtcNow.Ticks - _receivedMessageTimestamp; + + if (_receivedMessageElapsedTicks >= _clientTimeoutInterval) + { + Log.ClientTimeout(_logger, TimeSpan.FromTicks(_clientTimeoutInterval)); + AbortAllowReconnect(); + } + } + } } } @@ -623,6 +674,35 @@ namespace Microsoft.AspNetCore.SignalR _receivedMessageThisInterval = true; } + internal void BeginClientTimeout() + { + // check if new timeout behavior is in use + if (!_useAbsoluteClientTimeout) + { + lock (_receiveMessageTimeoutLock) + { + _receivedMessageTimeoutEnabled = true; + _receivedMessageTimestamp = DateTime.UtcNow.Ticks; + } + } + } + + internal void StopClientTimeout() + { + // check if new timeout behavior is in use + if (!_useAbsoluteClientTimeout) + { + lock (_receiveMessageTimeoutLock) + { + // we received a message so stop the timer and reset it + // it will resume after the message has been processed + _receivedMessageElapsedTicks = 0; + _receivedMessageTimestamp = 0; + _receivedMessageTimeoutEnabled = false; + } + } + } + private static class Log { // Category: HubConnectionContext diff --git a/src/SignalR/server/Core/src/HubConnectionHandler.cs b/src/SignalR/server/Core/src/HubConnectionHandler.cs index 663864cbb9..0a8f3380f9 100644 --- a/src/SignalR/server/Core/src/HubConnectionHandler.cs +++ b/src/SignalR/server/Core/src/HubConnectionHandler.cs @@ -213,6 +213,8 @@ namespace Microsoft.AspNetCore.SignalR { var input = connection.Input; var protocol = connection.Protocol; + connection.BeginClientTimeout(); + var binder = new HubConnectionBinder(_dispatcher, connection); @@ -221,6 +223,8 @@ namespace Microsoft.AspNetCore.SignalR var result = await input.ReadAsync(); var buffer = result.Buffer; + connection.ResetClientTimeout(); + try { if (result.IsCanceled) @@ -230,15 +234,21 @@ namespace Microsoft.AspNetCore.SignalR if (!buffer.IsEmpty) { - connection.ResetClientTimeout(); - + bool messageReceived = false; // No message limit, just parse and dispatch if (_maximumMessageSize == null) { while (protocol.TryParseMessage(ref buffer, binder, out var message)) { + messageReceived = true; + connection.StopClientTimeout(); await _dispatcher.DispatchMessageAsync(connection, message); } + + if (messageReceived) + { + connection.BeginClientTimeout(); + } } else { @@ -258,6 +268,9 @@ namespace Microsoft.AspNetCore.SignalR if (protocol.TryParseMessage(ref segment, binder, out var message)) { + messageReceived = true; + connection.StopClientTimeout(); + await _dispatcher.DispatchMessageAsync(connection, message); } else if (overLength) @@ -273,6 +286,11 @@ namespace Microsoft.AspNetCore.SignalR // Update the buffer to the remaining segment buffer = buffer.Slice(segment.Start); } + + if (messageReceived) + { + connection.BeginClientTimeout(); + } } } diff --git a/src/SignalR/server/Core/src/Internal/Proxies.cs b/src/SignalR/server/Core/src/Internal/Proxies.cs index 9a3edd56bd..8a2beb26de 100644 --- a/src/SignalR/server/Core/src/Internal/Proxies.cs +++ b/src/SignalR/server/Core/src/Internal/Proxies.cs @@ -105,7 +105,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal public Task SendCoreAsync(string method, object[] args, CancellationToken cancellationToken = default) { - return _lifetimeManager.SendAllAsync(method, args); + return _lifetimeManager.SendAllAsync(method, args, cancellationToken); } } diff --git a/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs b/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs index 0e00c9a9ab..ee312dbf3e 100644 --- a/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs +++ b/src/SignalR/server/SignalR/test/DefaultHubLifetimeManagerTests.cs @@ -1,9 +1,14 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; +using System.Threading.Tasks; +using System.Threading; +using Microsoft.AspNetCore.SignalR.Protocol; +using Microsoft.AspNetCore.SignalR.Specification.Tests; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; -using Microsoft.AspNetCore.SignalR.Specification.Tests; +using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests { @@ -13,5 +18,241 @@ namespace Microsoft.AspNetCore.SignalR.Tests { return new DefaultHubLifetimeManager(new Logger>(NullLoggerFactory.Instance)); } + + [Fact] + public async Task SendAllAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendAllAsync("Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + } + } + + [Fact] + public async Task SendAllExceptAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendAllExceptAsync("Hello", new object[] { "World" }, new List { connection1.ConnectionId }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + Assert.Null(client1.TryRead()); + } + } + + [Fact] + public async Task SendConnectionAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendConnectionAsync(connection1.ConnectionId, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendConnectionsAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendConnectionsAsync(new List { connection1.ConnectionId }, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendGroupAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.AddToGroupAsync(connection1.ConnectionId, "group").OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendGroupAsync("group", "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendGroupExceptAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + var connection2 = HubConnectionContextUtils.Create(client2.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + await manager.AddToGroupAsync(connection1.ConnectionId, "group").OrTimeout(); + await manager.AddToGroupAsync(connection2.ConnectionId, "group").OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendGroupExceptAsync("group", "Hello", new object[] { "World" }, new List { connection1.ConnectionId }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + Assert.Null(client1.TryRead()); + } + } + + [Fact] + public async Task SendGroupsAsyncWillCancelWithToken() + { + using (var client1 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.AddToGroupAsync(connection1.ConnectionId, "group").OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendGroupsAsync(new List { "group" }, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection1.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + } + } + + [Fact] + public async Task SendUserAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "user"); + var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "user"); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendUserAsync("user", "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + } + } + + [Fact] + public async Task SendUsersAsyncWillCancelWithToken() + { + using (var client1 = new TestClient()) + using (var client2 = new TestClient(pauseWriterThreshold: 2)) + { + var manager = CreateNewHubLifetimeManager(); + var connection1 = HubConnectionContextUtils.Create(client1.Connection, userIdentifier: "user1"); + var connection2 = HubConnectionContextUtils.Create(client2.Connection, userIdentifier: "user2"); + await manager.OnConnectedAsync(connection1).OrTimeout(); + await manager.OnConnectedAsync(connection2).OrTimeout(); + var cts = new CancellationTokenSource(); + var sendTask = manager.SendUsersAsync(new List { "user1", "user2" }, "Hello", new object[] { "World" }, cts.Token).OrTimeout(); + Assert.False(sendTask.IsCompleted); + cts.Cancel(); + await sendTask.OrTimeout(); + var message = Assert.IsType(client1.TryRead()); + Assert.Equal("Hello", message.Target); + Assert.Single(message.Arguments); + Assert.Equal("World", (string)message.Arguments[0]); + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + connection2.ConnectionAborted.Register(t => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + await tcs.Task.OrTimeout(); + Assert.False(connection1.ConnectionAborted.IsCancellationRequested); + } + } } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 9f727d6523..79f2122d3c 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2798,6 +2798,47 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task HubMethodInvokeDoesNotCountTowardsClientTimeout() + { + using (StartVerifiableLog()) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.Configure(options => + options.ClientTimeoutInterval = TimeSpan.FromMilliseconds(0)); + services.AddSingleton(tcsService); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient(new JsonHubProtocol())) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler); + // This starts the timeout logic + await client.SendHubMessageAsync(PingMessage.Instance); + + // Call long running hub method + var hubMethodTask = client.InvokeAsync(nameof(LongRunningHub.LongRunningMethod)); + await tcsService.StartedMethod.Task.OrTimeout(); + + // Tick heartbeat while hub method is running to show that close isn't triggered + client.TickHeartbeat(); + + // Unblock long running hub method + tcsService.EndMethod.SetResult(null); + + await hubMethodTask.OrTimeout(); + + // Tick heartbeat again now that we're outside of the hub method + client.TickHeartbeat(); + + // Connection is closed + await connectionHandlerTask.OrTimeout(); + } + } + } + [Fact] public async Task EndingConnectionSendsCloseMessageWithNoError() {