diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 8afb75b247..e43015ca4e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -22,6 +22,8 @@ namespace Microsoft.AspNetCore.Sockets.Client { public class HttpConnection : IConnection { + private static readonly TimeSpan HttpClientTimeout = TimeSpan.FromSeconds(120); + private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; @@ -77,6 +79,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); _httpClient = httpMessageHandler == null ? new HttpClient() : new HttpClient(httpMessageHandler); + _httpClient.Timeout = HttpClientTimeout; _transportFactory = new DefaultTransportFactory(transportType, _loggerFactory, _httpClient); } @@ -86,6 +89,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); _httpClient = httpMessageHandler == null ? new HttpClient() : new HttpClient(httpMessageHandler); + _httpClient.Timeout = HttpClientTimeout; _transportFactory = transportFactory ?? throw new ArgumentNullException(nameof(transportFactory)); } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs index cd0b970db4..9229219a52 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs @@ -92,7 +92,20 @@ namespace Microsoft.AspNetCore.Sockets.Client var request = new HttpRequestMessage(HttpMethod.Get, pollUrl); request.Headers.UserAgent.Add(Constants.UserAgentHeader); - var response = await _httpClient.SendAsync(request, cancellationToken); + HttpResponseMessage response; + + try + { + response = await _httpClient.SendAsync(request, cancellationToken); + } + catch (OperationCanceledException) + { + // SendAsync will throw the OperationCanceledException if the passed cancellationToken is canceled + // or if the http request times out due to HttpClient.Timeout expiring. In the latter case we + // just want to start a new poll. + continue; + } + response.EnsureSuccessStatusCode(); if (response.StatusCode == HttpStatusCode.NoContent || cancellationToken.IsCancellationRequested) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 9c6f9fbb6c..1892508f13 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -437,5 +437,48 @@ namespace Microsoft.AspNetCore.Client.Tests Assert.Equal("requestedTransferMode", exception.ParamName); } } + + [Fact] + public async Task LongPollingTransportRePollsIfRequestCancelled() + { + var numPolls = 0; + var completionTcs = new TaskCompletionSource(); + + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + + if (Interlocked.Increment(ref numPolls) < 3) + { + throw new OperationCanceledException(); + } + + completionTcs.SetResult(null); + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + + try + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); + + var completedTask = await Task.WhenAny(completionTcs.Task, longPollingTransport.Running).OrTimeout(); + Assert.Equal(completionTcs.Task, completedTask); + } + finally + { + await longPollingTransport.StopAsync(); + } + } + } } }