diff --git a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index 4f6d82cf9c..e44367c27f 100644 --- a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -532,6 +532,33 @@ describe("hubConnection", () => { } }); + it("can connect to hub with authorization using async token factory", async (done) => { + const message = "你好,世界!"; + + try { + const hubConnection = new HubConnection("/authorizedhub", { + accessTokenFactory: () => getJwtToken("http://" + document.location.host + "/generateJwtToken"), + ...commonOptions, + transport: transportType, + }); + hubConnection.onclose((error) => { + expect(error).toBe(undefined); + done(); + }); + await hubConnection.start(); + const response = await hubConnection.invoke("Echo", message); + + expect(response).toEqual(message); + + await hubConnection.stop(); + + done(); + } catch (err) { + fail(err); + done(); + } + }); + if (transportType !== TransportType.LongPolling) { it("terminates if no messages received within timeout interval", (done) => { const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, { diff --git a/clients/ts/signalr/src/HttpConnection.ts b/clients/ts/signalr/src/HttpConnection.ts index e8c489828f..e0fdd754fe 100644 --- a/clients/ts/signalr/src/HttpConnection.ts +++ b/clients/ts/signalr/src/HttpConnection.ts @@ -13,7 +13,7 @@ export interface IHttpConnectionOptions { httpClient?: HttpClient; transport?: TransportType | ITransport; logger?: ILogger | LogLevel; - accessTokenFactory?: () => string; + accessTokenFactory?: () => string | Promise; logMessageContent?: boolean; } @@ -87,7 +87,7 @@ export class HttpConnection implements IConnection { // No fallback or negotiate in this case. await this.transport.connect(this.url, transferFormat, this); } else { - const token = this.options.accessTokenFactory(); + const token = await this.options.accessTokenFactory(); let headers; if (token) { headers = { diff --git a/clients/ts/signalr/src/Transports.ts b/clients/ts/signalr/src/Transports.ts index 0d8a79d8da..a62414390a 100644 --- a/clients/ts/signalr/src/Transports.ts +++ b/clients/ts/signalr/src/Transports.ts @@ -30,17 +30,17 @@ export interface ITransport { export class WebSocketTransport implements ITransport { private readonly logger: ILogger; - private readonly accessTokenFactory: () => string; + private readonly accessTokenFactory: () => string | Promise; private readonly logMessageContent: boolean; private webSocket: WebSocket; - constructor(accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) { + constructor(accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { this.logger = logger; this.accessTokenFactory = accessTokenFactory || (() => null); this.logMessageContent = logMessageContent; } - public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise { + public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise { Arg.isRequired(url, "url"); Arg.isRequired(transferFormat, "transferFormat"); Arg.isIn(transferFormat, TransferFormat, "transferFormat"); @@ -52,9 +52,9 @@ export class WebSocketTransport implements ITransport { this.logger.log(LogLevel.Trace, "(WebSockets transport) Connecting"); + const token = await this.accessTokenFactory(); return new Promise((resolve, reject) => { url = url.replace(/^http/, "ws"); - const token = this.accessTokenFactory(); if (token) { url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; } @@ -118,20 +118,20 @@ export class WebSocketTransport implements ITransport { export class ServerSentEventsTransport implements ITransport { private readonly httpClient: HttpClient; - private readonly accessTokenFactory: () => string; + private readonly accessTokenFactory: () => string | Promise; private readonly logger: ILogger; private readonly logMessageContent: boolean; private eventSource: EventSource; private url: string; - constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) { + constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { this.httpClient = httpClient; this.accessTokenFactory = accessTokenFactory || (() => null); this.logger = logger; this.logMessageContent = logMessageContent; } - public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise { + public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise { Arg.isRequired(url, "url"); Arg.isRequired(transferFormat, "transferFormat"); Arg.isIn(transferFormat, TransferFormat, "transferFormat"); @@ -144,12 +144,12 @@ export class ServerSentEventsTransport implements ITransport { this.logger.log(LogLevel.Trace, "(SSE transport) Connecting"); this.url = url; + const token = await this.accessTokenFactory(); return new Promise((resolve, reject) => { if (transferFormat !== TransferFormat.Text) { reject(new Error("The Server-Sent Events transport only supports the 'Text' transfer format")); } - const token = this.accessTokenFactory(); if (token) { url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; } @@ -210,7 +210,7 @@ export class ServerSentEventsTransport implements ITransport { export class LongPollingTransport implements ITransport { private readonly httpClient: HttpClient; - private readonly accessTokenFactory: () => string; + private readonly accessTokenFactory: () => string | Promise; private readonly logger: ILogger; private readonly logMessageContent: boolean; @@ -218,7 +218,7 @@ export class LongPollingTransport implements ITransport { private pollXhr: XMLHttpRequest; private pollAbort: AbortController; - constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) { + constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { this.httpClient = httpClient; this.accessTokenFactory = accessTokenFactory || (() => null); this.logger = logger; @@ -259,7 +259,7 @@ export class LongPollingTransport implements ITransport { pollOptions.responseType = "arraybuffer"; } - const token = this.accessTokenFactory(); + const token = await this.accessTokenFactory(); if (token) { // tslint:disable-next-line:no-string-literal pollOptions.headers["Authorization"] = `Bearer ${token}`; @@ -356,12 +356,12 @@ function formatArrayBuffer(data: ArrayBuffer): string { return str.substr(0, str.length - 1); } -async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string, content: string | ArrayBuffer, logMessageContent: boolean): Promise { +async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string | Promise, content: string | ArrayBuffer, logMessageContent: boolean): Promise { let headers; - const token = accessTokenFactory(); + const token = await accessTokenFactory(); if (token) { headers = { - ["Authorization"]: `Bearer ${accessTokenFactory()}`, + ["Authorization"]: `Bearer ${token}`, }; } diff --git a/samples/JwtClientSample/Program.cs b/samples/JwtClientSample/Program.cs index 324d7b9a8f..42fe1623df 100644 --- a/samples/JwtClientSample/Program.cs +++ b/samples/JwtClientSample/Program.cs @@ -23,13 +23,13 @@ namespace JwtClientSample private const string ServerUrl = "http://localhost:54543"; - private readonly ConcurrentDictionary _tokens = new ConcurrentDictionary(StringComparer.Ordinal); + private readonly ConcurrentDictionary> _tokens = new ConcurrentDictionary>(StringComparer.Ordinal); private readonly Random _random = new Random(); private async Task RunConnection(HttpTransportType transportType) { var userId = "C#" + transportType; - _tokens[userId] = await GetJwtToken(userId); + _tokens[userId] = GetJwtToken(userId); var hubConnection = new HubConnectionBuilder() .WithUrl(ServerUrl + "/broadcast", options => @@ -60,7 +60,7 @@ namespace JwtClientSample // no need to refresh the token for websockets if (transportType != HttpTransportType.WebSockets) { - _tokens[userId] = await GetJwtToken(userId); + _tokens[userId] = GetJwtToken(userId); Console.WriteLine($"[{userId}] Token refreshed"); } } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpOptions.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpOptions.cs index fabff436b1..38cea6822c 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpOptions.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpOptions.cs @@ -7,6 +7,7 @@ using System.Net; using System.Net.Http; using System.Net.WebSockets; using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Http.Connections.Client { @@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client public Func HttpMessageHandlerFactory { get; set; } public IReadOnlyCollection> Headers { get; set; } - public Func AccessTokenFactory { get; set; } + public Func> AccessTokenFactory { get; set; } public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5); public ICredentials Credentials { get; set; } public X509CertificateCollection ClientCertificates { get; set; } = new X509CertificateCollection(); diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs index cc2050b7aa..9965370492 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs @@ -11,18 +11,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { public class AccessTokenHttpMessageHandler : DelegatingHandler { - private readonly Func _accessTokenFactory; + private readonly Func> _accessTokenFactory; - public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func accessTokenFactory) : base(inner) + public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func> accessTokenFactory) : base(inner) { _accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory)); } - protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory()); + var accessToken = await _accessTokenFactory(); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); - return base.SendAsync(request, cancellationToken); + return await base.SendAsync(request, cancellationToken); } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs index c951f0fbb7..4a3930180a 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs @@ -17,6 +17,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal public partial class WebSocketsTransport : ITransport { private readonly ClientWebSocket _webSocket; + private readonly Func> _accessTokenFactory; private IDuplexPipe _application; private WebSocketMessageType _webSocketMessageType; private readonly ILogger _logger; @@ -80,7 +81,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal if (httpOptions.AccessTokenFactory != null) { - _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.AccessTokenFactory()}"); + _accessTokenFactory = httpOptions.AccessTokenFactory; } httpOptions.WebSocketOptions?.Invoke(_webSocket.Options); @@ -115,6 +116,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal Log.StartTransport(_logger, transferFormat, resolvedUrl); + if (_accessTokenFactory != null) + { + var accessToken = await _accessTokenFactory(); + _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}"); + } + await _webSocket.ConnectAsync(resolvedUrl, CancellationToken.None); // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HttpConnectionOptions.cs b/src/Microsoft.AspNetCore.SignalR.Client/HttpConnectionOptions.cs index 745691667d..b1a938c401 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HttpConnectionOptions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HttpConnectionOptions.cs @@ -7,6 +7,7 @@ using System.Net; using System.Net.Http; using System.Net.WebSockets; using System.Security.Cryptography.X509Certificates; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Internal; @@ -29,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public bool? UseDefaultCredentials { get; set; } public ICredentials Credentials { get; set; } public IWebProxy Proxy { get; set; } - public Func AccessTokenFactory { get; set; } + public Func> AccessTokenFactory { get; set; } public Action WebSocketOptions { get; set; } public X509CertificateCollection ClientCertificates diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 74ecb3a98b..901e3bc1f8 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -706,15 +706,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { using (StartLog(out var loggerFactory, $"{nameof(ClientCanUseJwtBearerTokenForAuthentication)}_{transportType}")) { - var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken"); - httpResponse.EnsureSuccessStatusCode(); - var token = await httpResponse.Content.ReadAsStringAsync(); + async Task AccessTokenFactory() + { + var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken"); + httpResponse.EnsureSuccessStatusCode(); + return await httpResponse.Content.ReadAsStringAsync(); + }; var hubConnection = new HubConnectionBuilder() .WithLoggerFactory(loggerFactory) .WithUrl(_serverFixture.Url + "/authorizedhub", transportType, options => { - options.AccessTokenFactory = () => token; + options.AccessTokenFactory = AccessTokenFactory; }) .Build(); try diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs index 85c243162d..aba4998f42 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests ITransport transport = null, ITransportFactory transportFactory = null, HttpTransportType transportType = HttpTransportType.LongPolling, - Func accessTokenFactory = null) + Func> accessTokenFactory = null) { var httpOptions = new HttpOptions { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs index 97e16b7394..f388328bd4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs @@ -50,10 +50,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); - string AccessTokenFactory() + Task AccessTokenFactory() { callCount++; - return callCount.ToString(); + return Task.FromResult(callCount.ToString()); } await WithConnectionAsync(