From 80f87e7730e04e6979ee934126be7d608fc43724 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Wed, 4 Apr 2018 15:54:42 -0700 Subject: [PATCH] Add Handshake timeout to C# Client (#1840) --- .../HubConnection.cs | 80 ++++++++++--------- .../HubConnectionTests.ConnectionLifecycle.cs | 66 +++++++++++++++ 2 files changed, 110 insertions(+), 36 deletions(-) diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 5e6e0154c9..601e7376d7 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; -using System.IO.Pipelines; using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Channels; @@ -23,6 +22,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public partial class HubConnection { public static readonly TimeSpan DefaultServerTimeout = TimeSpan.FromSeconds(30); // Server ping rate is 15 sec, this is 2 times that. + public static readonly TimeSpan DefaultHandshakeTimeout = TimeSpan.FromSeconds(15); // This lock protects the connection state. private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1); @@ -46,6 +46,7 @@ namespace Microsoft.AspNetCore.SignalR.Client /// will not be applied until the Keep Alive timer is next reset. /// public TimeSpan ServerTimeout { get; set; } = DefaultServerTimeout; + public TimeSpan HandshakeTimeout { get; set; } = DefaultHandshakeTimeout; public HubConnection(Func connectionFactory, IHubProtocol protocol, IServiceProvider serviceProvider, ILoggerFactory loggerFactory) { @@ -57,10 +58,10 @@ namespace Microsoft.AspNetCore.SignalR.Client _logger = _loggerFactory.CreateLogger(); } - public async Task StartAsync() + public async Task StartAsync(CancellationToken cancellationToken = default) { CheckDisposed(); - await StartAsyncCore().ForceAsync(); + await StartAsyncCore(cancellationToken).ForceAsync(); } public async Task StopAsync() @@ -109,7 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public async Task SendAsync(string methodName, object[] args, CancellationToken cancellationToken = default) => await SendAsyncCore(methodName, args, cancellationToken).ForceAsync(); - private async Task StartAsyncCore() + private async Task StartAsyncCore(CancellationToken cancellationToken) { await WaitConnectionLockAsync(); try @@ -120,6 +121,8 @@ namespace Microsoft.AspNetCore.SignalR.Client return; } + cancellationToken.ThrowIfCancellationRequested(); + CheckDisposed(); Log.Starting(_logger); @@ -134,7 +137,7 @@ namespace Microsoft.AspNetCore.SignalR.Client try { Log.HubProtocol(_logger, _protocol.Name, _protocol.Version); - await HandshakeAsync(startingConnectionState); + await HandshakeAsync(startingConnectionState, cancellationToken); } catch (Exception ex) { @@ -492,7 +495,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private async Task HandshakeAsync(ConnectionState startingConnectionState) + private async Task HandshakeAsync(ConnectionState startingConnectionState, CancellationToken cancellationToken) { // Send the Handshake request Log.SendingHubHandshake(_logger); @@ -510,47 +513,52 @@ namespace Microsoft.AspNetCore.SignalR.Client try { - while (true) + using (var handshakeCts = new CancellationTokenSource(HandshakeTimeout)) + using (var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, handshakeCts.Token)) { - var result = await startingConnectionState.Connection.Transport.Input.ReadAsync(); - var buffer = result.Buffer; - var consumed = buffer.Start; - var examined = buffer.End; - - try + while (true) { - // Read first message out of the incoming data - if (!buffer.IsEmpty) + var result = await startingConnectionState.Connection.Transport.Input.ReadAsync(cts.Token); + + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; + + try { - if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) + // Read first message out of the incoming data + if (!buffer.IsEmpty) { - // Adjust consumed and examined to point to the end of the handshake - // response, this handles the case where invocations are sent in the same payload - // as the the negotiate response. - consumed = buffer.Start; - examined = consumed; - - if (message.Error != null) + if (HandshakeProtocol.TryParseResponseMessage(ref buffer, out var message)) { - Log.HandshakeServerError(_logger, message.Error); - throw new HubException( - $"Unable to complete handshake with the server due to an error: {message.Error}"); - } + // Adjust consumed and examined to point to the end of the handshake + // response, this handles the case where invocations are sent in the same payload + // as the the negotiate response. + consumed = buffer.Start; + examined = consumed; - break; + if (message.Error != null) + { + Log.HandshakeServerError(_logger, message.Error); + throw new HubException( + $"Unable to complete handshake with the server due to an error: {message.Error}"); + } + + break; + } + } + else if (result.IsCompleted) + { + // Not enough data, and we won't be getting any more data. + throw new InvalidOperationException( + "The server disconnected before sending a handshake response"); } } - else if (result.IsCompleted) + finally { - // Not enough data, and we won't be getting any more data. - throw new InvalidOperationException( - "The server disconnected before sending a handshake response"); + startingConnectionState.Connection.Transport.Input.AdvanceTo(consumed, examined); } } - finally - { - startingConnectionState.Connection.Transport.Input.AdvanceTo(consumed, examined); - } } } // Ignore HubException because we throw it when we receive a handshake response with an error diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs index 13882a283b..28b2e6ad06 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.ConnectionLifecycle.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.SignalR.Internal.Protocol; @@ -347,6 +348,71 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } + [Fact] + public async Task ClientTimesoutWhenHandshakeResponseTakesTooLong() + { + var connection = new TestConnection(autoHandshake: false); + var hubConnection = CreateHubConnection(() => connection); + try + { + hubConnection.HandshakeTimeout = TimeSpan.FromMilliseconds(1); + + await Assert.ThrowsAsync(() => hubConnection.StartAsync().OrTimeout()); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StartAsyncWithTriggeredCancellationTokenIsCanceled() + { + var onStartCalled = false; + var connection = new TestConnection(onStart: () => + { + onStartCalled = true; + return Task.CompletedTask; + }); + var hubConnection = CreateHubConnection(() => connection); + try + { + await Assert.ThrowsAsync(() => hubConnection.StartAsync(new CancellationToken(canceled: true)).OrTimeout()); + Assert.False(onStartCalled); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task StartAsyncCanTriggerCancellationTokenToCancelHandshake() + { + var cts = new CancellationTokenSource(); + var connection = new TestConnection(onStart: () => + { + cts.Cancel(); + return Task.CompletedTask; + }, autoHandshake: false); + var hubConnection = CreateHubConnection(() => connection); + // We want to make sure the cancellation is because of the token passed to StartAsync + hubConnection.HandshakeTimeout = Timeout.InfiniteTimeSpan; + try + { + var startTask = hubConnection.StartAsync(cts.Token); + var exception = await Assert.ThrowsAnyAsync(() => startTask.OrTimeout()); + Assert.Equal("The operation was canceled.", exception.Message); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + private static async Task ForceLastInvocationToComplete(TestConnection testConnection) { // We need to "complete" the invocation