From d711916ad66adbc250ad7d45486c7b153017c46e Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Tue, 1 May 2018 16:14:24 -0700 Subject: [PATCH 1/2] fix #2140 by ensuring the access token flows to WebSocketTransport (#2173) --- .../HttpConnection.cs | 2 +- .../Internal/AccessTokenHttpMessageHandler.cs | 1 - .../Internal/DefaultTransportFactory.cs | 11 +++++--- .../Internal/WebSocketsTransport.cs | 21 ++++++-------- .../HubConnectionTests.cs | 28 +++++++++++++++++++ .../Startup.cs | 9 ++++++ .../HttpConnectionTests.Negotiate.cs | 1 - .../DefaultTransportFactoryTests.cs | 16 +++++------ .../WebSocketsTransportTests.cs | 16 +++++------ 9 files changed, 70 insertions(+), 35 deletions(-) diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs index 24709e4860..74e8577c94 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs @@ -137,7 +137,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client _httpClient = CreateHttpClient(); } - _transportFactory = new DefaultTransportFactory(httpConnectionOptions.Transports, _loggerFactory, _httpClient, httpConnectionOptions); + _transportFactory = new DefaultTransportFactory(httpConnectionOptions.Transports, _loggerFactory, _httpClient, httpConnectionOptions, GetAccessTokenAsync); _logScope = new ConnectionLogScope(); _scopeDisposable = _logger.BeginScope(_logScope); diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs index 327a81a5d7..eb9a18d96e 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs @@ -1,7 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Net.Http; using System.Net.Http.Headers; using System.Threading; diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/DefaultTransportFactory.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/DefaultTransportFactory.cs index 6332d276f5..6e66673ad9 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/DefaultTransportFactory.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/DefaultTransportFactory.cs @@ -3,8 +3,7 @@ using System; using System.Net.Http; -using Microsoft.AspNetCore.Http.Connections.Client; -using Microsoft.AspNetCore.Http.Connections.Client.Internal; +using System.Threading.Tasks; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Http.Connections.Client.Internal @@ -13,11 +12,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { private readonly HttpClient _httpClient; private readonly HttpConnectionOptions _httpConnectionOptions; + private readonly Func> _accessTokenProvider; private readonly HttpTransportType _requestedTransportType; private readonly ILoggerFactory _loggerFactory; private static volatile bool _websocketsSupported = true; - public DefaultTransportFactory(HttpTransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient, HttpConnectionOptions httpConnectionOptions) + public DefaultTransportFactory(HttpTransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient, HttpConnectionOptions httpConnectionOptions, Func> accessTokenProvider) { if (httpClient == null && requestedTransportType != HttpTransportType.WebSockets) { @@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _loggerFactory = loggerFactory; _httpClient = httpClient; _httpConnectionOptions = httpConnectionOptions; + _accessTokenProvider = accessTokenProvider; } public ITransport CreateTransport(HttpTransportType availableServerTransports) @@ -36,7 +37,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { try { - return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory); + return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory, _accessTokenProvider); } catch (PlatformNotSupportedException) { @@ -46,11 +47,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal if ((availableServerTransports & HttpTransportType.ServerSentEvents & _requestedTransportType) == HttpTransportType.ServerSentEvents) { + // We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us. return new ServerSentEventsTransport(_httpClient, _loggerFactory); } if ((availableServerTransports & HttpTransportType.LongPolling & _requestedTransportType) == HttpTransportType.LongPolling) { + // We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us. return new LongPollingTransport(_httpClient, _loggerFactory); } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs index cfa225b46d..d471216b7d 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs @@ -32,12 +32,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal public PipeWriter Output => _transport.Output; - public WebSocketsTransport() - : this(null, null) - { - } - - public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory) + public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func> accessTokenProvider) { _webSocket = new ClientWebSocket(); @@ -79,11 +74,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _webSocket.Options.UseDefaultCredentials = httpConnectionOptions.UseDefaultCredentials.Value; } - if (httpConnectionOptions.AccessTokenProvider != null) - { - _accessTokenProvider = httpConnectionOptions.AccessTokenProvider; - } - httpConnectionOptions.WebSocketConfiguration?.Invoke(_webSocket.Options); _closeTimeout = httpConnectionOptions.CloseTimeout; @@ -94,6 +84,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _webSocket.Options.SetRequestHeader("X-Requested-With", "XMLHttpRequest"); _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); + + // Ignore the HttpConnectionOptions access token provider. We were given an updated delegate from the HttpConnection. + _accessTokenProvider = accessTokenProvider; } public async Task StartAsync(Uri url, TransferFormat transferFormat) @@ -116,10 +109,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal Log.StartTransport(_logger, transferFormat, resolvedUrl); + // We don't need to capture to a local because we never change this delegate. if (_accessTokenProvider != null) { var accessToken = await _accessTokenProvider(); - _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}"); + if (!string.IsNullOrEmpty(accessToken)) + { + _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}"); + } } await _webSocket.ConnectAsync(resolvedUrl, CancellationToken.None); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index ff6b2cbf04..4f7af0fc85 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -797,6 +797,34 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [MemberData(nameof(TransportTypes))] + public async Task ClientCanUseJwtBearerTokenForAuthenticationWhenRedirected(HttpTransportType transportType) + { + using (StartVerifableLog(out var loggerFactory, $"{nameof(ClientCanUseJwtBearerTokenForAuthenticationWhenRedirected)}_{transportType}")) + { + var hubConnection = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(ServerFixture.Url + "/redirect", transportType) + .Build(); + try + { + await hubConnection.StartAsync().OrTimeout(); + var message = await hubConnection.InvokeAsync(nameof(TestHub.Echo), "Hello, World!").OrTimeout(); + Assert.Equal("Hello, World!", message); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(TransportTypes))] public async Task ClientCanSendHeaders(HttpTransportType transportType) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs index 81a67d13a4..c9a98272a5 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs @@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Connections; using Microsoft.Extensions.DependencyInjection; using Microsoft.IdentityModel.Tokens; +using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { @@ -69,6 +70,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await context.Response.WriteAsync(GenerateJwtToken()); return; } + else if (context.Request.Path.StartsWithSegments("/redirect")) + { + await context.Response.WriteAsync(JsonConvert.SerializeObject(new + { + url = $"{context.Request.Scheme}://{context.Request.Host}/authorizedHub", + accessToken = GenerateJwtToken() + })); + } }); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs index 635aca510f..1dd265aa57 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs @@ -8,7 +8,6 @@ using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections; -using Microsoft.AspNetCore.Http.Connections.Client; using Microsoft.AspNetCore.Http.Connections.Client.Internal; using Microsoft.AspNetCore.SignalR.Tests; using Moq; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs index 2cb7b0a835..11fdab9749 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/DefaultTransportFactoryTests.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [InlineData((HttpTransportType)int.MaxValue)] public void DefaultTransportFactoryCanBeCreatedWithNoOrUnknownTransportTypeFlags(HttpTransportType transportType) { - Assert.NotNull(new DefaultTransportFactory(transportType, new LoggerFactory(), new HttpClient(), httpConnectionOptions: null)); + Assert.NotNull(new DefaultTransportFactory(transportType, new LoggerFactory(), new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null)); } [Theory] @@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void DefaultTransportFactoryCannotBeCreatedWithoutHttpClient(HttpTransportType transportType) { var exception = Assert.Throws( - () => new DefaultTransportFactory(transportType, new LoggerFactory(), httpClient: null, httpConnectionOptions: null)); + () => new DefaultTransportFactory(transportType, new LoggerFactory(), httpClient: null, httpConnectionOptions: null, accessTokenProvider: null)); Assert.Equal("httpClient", exception.ParamName); } @@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [Fact] public void DefaultTransportFactoryCanBeCreatedWithoutHttpClientIfWebSocketsTransportRequestedExplicitly() { - new DefaultTransportFactory(HttpTransportType.WebSockets, new LoggerFactory(), httpClient: null, httpConnectionOptions: null); + new DefaultTransportFactory(HttpTransportType.WebSockets, new LoggerFactory(), httpClient: null, httpConnectionOptions: null, accessTokenProvider: null); } [ConditionalTheory] @@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [WebSocketsSupportedCondition] public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(HttpTransportType requestedTransport, Type expectedTransportType) { - var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null); + var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, transportFactory.CreateTransport(AllTransportTypes)); } @@ -64,7 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(HttpTransportType requestedTransport) { var transportFactory = - new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null); + new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( () => transportFactory.CreateTransport(~requestedTransport)); @@ -76,7 +76,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public void DefaultTransportFactoryCreatesWebSocketsTransportIfAvailable() { Assert.IsType( - new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null) + new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null) .CreateTransport(AllTransportTypes)); } @@ -88,7 +88,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { if (!TestHelpers.IsWebSocketsSupported()) { - var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null); + var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); Assert.IsType(expectedTransportType, transportFactory.CreateTransport(AllTransportTypes)); } @@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests if (!TestHelpers.IsWebSocketsSupported()) { var transportFactory = - new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null); + new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null); var ex = Assert.Throws( () => transportFactory.CreateTransport(AllTransportTypes)); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 90dc73b020..b43368a0ff 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests httpOptions.Proxy = Mock.Of(); httpOptions.WebSocketConfiguration = options => webSocketsOptions = options; - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: httpOptions, loggerFactory: null); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: httpOptions, loggerFactory: null, accessTokenProvider: null); Assert.NotNull(webSocketsTransport); Assert.NotNull(webSocketsOptions); @@ -59,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echo"), TransferFormat.Binary).OrTimeout(); await webSocketsTransport.StopAsync().OrTimeout(); @@ -73,7 +73,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/httpheader"), TransferFormat.Binary).OrTimeout(); @@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/httpheader"), TransferFormat.Binary).OrTimeout(); @@ -124,7 +124,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echo"), TransferFormat.Binary); webSocketsTransport.Output.Complete(); @@ -140,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echoAndClose"), transferFormat); await webSocketsTransport.Output.WriteAsync(new byte[] { 0x42 }); @@ -162,7 +162,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echo"), transferFormat).OrTimeout(); @@ -180,7 +180,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { using (StartVerifableLog(out var loggerFactory)) { - var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory); + var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null); var exception = await Assert.ThrowsAsync(() => webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), transferFormat)); From ae329edd2af6ff1edb11d8ba7f8c54a4904be41d Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Tue, 1 May 2018 16:37:22 -0700 Subject: [PATCH 2/2] Fix #2169 by correcting shutdown timeout (#2170) --- .../FunctionalTests/ts/HubConnectionTests.ts | 5 -- .../signalr/spec/LongPollingTransport.spec.ts | 67 +++++++++++++++++++ .../ts/signalr/src/LongPollingTransport.ts | 12 ++-- 3 files changed, 74 insertions(+), 10 deletions(-) create mode 100644 clients/ts/signalr/spec/LongPollingTransport.spec.ts diff --git a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index bd539851d6..0567382a29 100644 --- a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -577,13 +577,8 @@ describe("hubConnection", () => { const hubConnection = getConnectionBuilder(transportType).build(); hubConnection.serverTimeoutInMilliseconds = 100; - const timeout = setTimeout(200, () => { - fail("Server timeout did not fire within expected interval"); - }); - hubConnection.start().then(() => { hubConnection.onclose((error) => { - clearTimeout(timeout); expect(error).toEqual(new Error("Server timeout elapsed without receiving a message from the server.")); done(); }); diff --git a/clients/ts/signalr/spec/LongPollingTransport.spec.ts b/clients/ts/signalr/spec/LongPollingTransport.spec.ts new file mode 100644 index 0000000000..8023ab27d4 --- /dev/null +++ b/clients/ts/signalr/spec/LongPollingTransport.spec.ts @@ -0,0 +1,67 @@ +import { HttpResponse } from "../src/HttpClient"; +import { LogLevel } from "../src/ILogger"; +import { TransferFormat } from "../src/ITransport"; +import { NullLogger } from "../src/Loggers"; +import { LongPollingTransport } from "../src/LongPollingTransport"; +import { ConsoleLogger } from "../src/Utils"; + +import { TestHttpClient } from "./TestHttpClient"; +import { asyncit as it, PromiseSource } from "./Utils"; + +describe("LongPollingTransport", () => { + it("shuts down poll after timeout even if server doesn't shut it down on receiving the DELETE", async () => { + let firstPoll = true; + const pollCompleted = new PromiseSource(); + const client = new TestHttpClient() + .on("GET", async (r) => { + if (firstPoll) { + firstPoll = false; + return new HttpResponse(200); + } else { + // Turn 'onabort' into a promise. + const abort = new Promise((resolve, reject) => r.abortSignal.onabort = resolve); + await abort; + + // Signal that the poll has completed. + pollCompleted.resolve(); + return new HttpResponse(204); + } + }) + .on("DELETE", (r) => new HttpResponse(202)); + const transport = new LongPollingTransport(client, null, NullLogger.instance, false, 100); + + await transport.connect("http://example.com", TransferFormat.Text); + await transport.stop(); + + // This should complete within the shutdown timeout + await pollCompleted.promise; + }); + + it("sends DELETE request on stop", async () => { + let firstPoll = true; + const deleteReceived = new PromiseSource(); + const pollCompleted = new PromiseSource(); + const client = new TestHttpClient() + .on("GET", async (r) => { + if (firstPoll) { + firstPoll = false; + return new HttpResponse(200); + } else { + await deleteReceived.promise; + pollCompleted.resolve(); + return new HttpResponse(204); + } + }) + .on("DELETE", (r) => { + deleteReceived.resolve(); + return new HttpResponse(202); + }); + const transport = new LongPollingTransport(client, null, NullLogger.instance, false); + + await transport.connect("http://example.com", TransferFormat.Text); + await transport.stop(); + + // This should complete, because the DELETE request triggers it to stop. + await pollCompleted.promise; + }); +}); \ No newline at end of file diff --git a/clients/ts/signalr/src/LongPollingTransport.ts b/clients/ts/signalr/src/LongPollingTransport.ts index db0c2edf5c..b2b456f8a2 100644 --- a/clients/ts/signalr/src/LongPollingTransport.ts +++ b/clients/ts/signalr/src/LongPollingTransport.ts @@ -19,15 +19,17 @@ export class LongPollingTransport implements ITransport { private url: string; private pollXhr: XMLHttpRequest; private pollAbort: AbortController; + private shutdownTimer: any; // We use 'any' because this is an object in NodeJS. But it still gets passed to clearTimeout, so it doesn't really matter private shutdownTimeout: number; private running: boolean; - constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { + constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean, shutdownTimeout?: number) { this.httpClient = httpClient; this.accessTokenFactory = accessTokenFactory || (() => null); this.logger = logger; this.pollAbort = new AbortController(); this.logMessageContent = logMessageContent; + this.shutdownTimeout = shutdownTimeout || SHUTDOWN_TIMEOUT; } public async connect(url: string, transferFormat: TransferFormat): Promise { @@ -107,7 +109,7 @@ export class LongPollingTransport implements ITransport { this.logger.log(LogLevel.Information, "(LongPolling transport) Poll terminated by server"); // If we were on a timeout waiting for shutdown, unregister it. - clearTimeout(this.shutdownTimeout); + clearTimeout(this.shutdownTimer); this.running = false; } else if (response.statusCode !== 200) { @@ -179,10 +181,10 @@ export class LongPollingTransport implements ITransport { } finally { // Abort the poll after 5 seconds if the server doesn't stop it. if (!this.pollAbort.aborted) { - this.shutdownTimeout = setTimeout(SHUTDOWN_TIMEOUT, () => { - this.logger.log(LogLevel.Warning, "(LongPolling transport) server did not terminate within 5 seconds after DELETE request, cancelling poll."); + this.shutdownTimer = setTimeout(() => { + this.logger.log(LogLevel.Warning, "(LongPolling transport) server did not terminate after DELETE request, canceling poll."); this.pollAbort.abort(); - }); + }, this.shutdownTimeout); } } }