Support async access token factory (#1911)
This commit is contained in:
parent
6bc2ebb4c5
commit
31dfe91962
|
|
@ -532,6 +532,33 @@ describe("hubConnection", () => {
|
|||
}
|
||||
});
|
||||
|
||||
it("can connect to hub with authorization using async token factory", async (done) => {
|
||||
const message = "你好,世界!";
|
||||
|
||||
try {
|
||||
const hubConnection = new HubConnection("/authorizedhub", {
|
||||
accessTokenFactory: () => getJwtToken("http://" + document.location.host + "/generateJwtToken"),
|
||||
...commonOptions,
|
||||
transport: transportType,
|
||||
});
|
||||
hubConnection.onclose((error) => {
|
||||
expect(error).toBe(undefined);
|
||||
done();
|
||||
});
|
||||
await hubConnection.start();
|
||||
const response = await hubConnection.invoke("Echo", message);
|
||||
|
||||
expect(response).toEqual(message);
|
||||
|
||||
await hubConnection.stop();
|
||||
|
||||
done();
|
||||
} catch (err) {
|
||||
fail(err);
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
if (transportType !== TransportType.LongPolling) {
|
||||
it("terminates if no messages received within timeout interval", (done) => {
|
||||
const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, {
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ export interface IHttpConnectionOptions {
|
|||
httpClient?: HttpClient;
|
||||
transport?: TransportType | ITransport;
|
||||
logger?: ILogger | LogLevel;
|
||||
accessTokenFactory?: () => string;
|
||||
accessTokenFactory?: () => string | Promise<string>;
|
||||
logMessageContent?: boolean;
|
||||
}
|
||||
|
||||
|
|
@ -87,7 +87,7 @@ export class HttpConnection implements IConnection {
|
|||
// No fallback or negotiate in this case.
|
||||
await this.transport.connect(this.url, transferFormat, this);
|
||||
} else {
|
||||
const token = this.options.accessTokenFactory();
|
||||
const token = await this.options.accessTokenFactory();
|
||||
let headers;
|
||||
if (token) {
|
||||
headers = {
|
||||
|
|
|
|||
|
|
@ -30,17 +30,17 @@ export interface ITransport {
|
|||
|
||||
export class WebSocketTransport implements ITransport {
|
||||
private readonly logger: ILogger;
|
||||
private readonly accessTokenFactory: () => string;
|
||||
private readonly accessTokenFactory: () => string | Promise<string>;
|
||||
private readonly logMessageContent: boolean;
|
||||
private webSocket: WebSocket;
|
||||
|
||||
constructor(accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
|
||||
constructor(accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
|
||||
this.logger = logger;
|
||||
this.accessTokenFactory = accessTokenFactory || (() => null);
|
||||
this.logMessageContent = logMessageContent;
|
||||
}
|
||||
|
||||
public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
|
||||
public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
|
||||
Arg.isRequired(url, "url");
|
||||
Arg.isRequired(transferFormat, "transferFormat");
|
||||
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
|
||||
|
|
@ -52,9 +52,9 @@ export class WebSocketTransport implements ITransport {
|
|||
|
||||
this.logger.log(LogLevel.Trace, "(WebSockets transport) Connecting");
|
||||
|
||||
const token = await this.accessTokenFactory();
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
url = url.replace(/^http/, "ws");
|
||||
const token = this.accessTokenFactory();
|
||||
if (token) {
|
||||
url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`;
|
||||
}
|
||||
|
|
@ -118,20 +118,20 @@ export class WebSocketTransport implements ITransport {
|
|||
|
||||
export class ServerSentEventsTransport implements ITransport {
|
||||
private readonly httpClient: HttpClient;
|
||||
private readonly accessTokenFactory: () => string;
|
||||
private readonly accessTokenFactory: () => string | Promise<string>;
|
||||
private readonly logger: ILogger;
|
||||
private readonly logMessageContent: boolean;
|
||||
private eventSource: EventSource;
|
||||
private url: string;
|
||||
|
||||
constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
|
||||
constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
|
||||
this.httpClient = httpClient;
|
||||
this.accessTokenFactory = accessTokenFactory || (() => null);
|
||||
this.logger = logger;
|
||||
this.logMessageContent = logMessageContent;
|
||||
}
|
||||
|
||||
public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
|
||||
public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
|
||||
Arg.isRequired(url, "url");
|
||||
Arg.isRequired(transferFormat, "transferFormat");
|
||||
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
|
||||
|
|
@ -144,12 +144,12 @@ export class ServerSentEventsTransport implements ITransport {
|
|||
this.logger.log(LogLevel.Trace, "(SSE transport) Connecting");
|
||||
|
||||
this.url = url;
|
||||
const token = await this.accessTokenFactory();
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
if (transferFormat !== TransferFormat.Text) {
|
||||
reject(new Error("The Server-Sent Events transport only supports the 'Text' transfer format"));
|
||||
}
|
||||
|
||||
const token = this.accessTokenFactory();
|
||||
if (token) {
|
||||
url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`;
|
||||
}
|
||||
|
|
@ -210,7 +210,7 @@ export class ServerSentEventsTransport implements ITransport {
|
|||
|
||||
export class LongPollingTransport implements ITransport {
|
||||
private readonly httpClient: HttpClient;
|
||||
private readonly accessTokenFactory: () => string;
|
||||
private readonly accessTokenFactory: () => string | Promise<string>;
|
||||
private readonly logger: ILogger;
|
||||
private readonly logMessageContent: boolean;
|
||||
|
||||
|
|
@ -218,7 +218,7 @@ export class LongPollingTransport implements ITransport {
|
|||
private pollXhr: XMLHttpRequest;
|
||||
private pollAbort: AbortController;
|
||||
|
||||
constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
|
||||
constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
|
||||
this.httpClient = httpClient;
|
||||
this.accessTokenFactory = accessTokenFactory || (() => null);
|
||||
this.logger = logger;
|
||||
|
|
@ -259,7 +259,7 @@ export class LongPollingTransport implements ITransport {
|
|||
pollOptions.responseType = "arraybuffer";
|
||||
}
|
||||
|
||||
const token = this.accessTokenFactory();
|
||||
const token = await this.accessTokenFactory();
|
||||
if (token) {
|
||||
// tslint:disable-next-line:no-string-literal
|
||||
pollOptions.headers["Authorization"] = `Bearer ${token}`;
|
||||
|
|
@ -356,12 +356,12 @@ function formatArrayBuffer(data: ArrayBuffer): string {
|
|||
return str.substr(0, str.length - 1);
|
||||
}
|
||||
|
||||
async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string, content: string | ArrayBuffer, logMessageContent: boolean): Promise<void> {
|
||||
async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string | Promise<string>, content: string | ArrayBuffer, logMessageContent: boolean): Promise<void> {
|
||||
let headers;
|
||||
const token = accessTokenFactory();
|
||||
const token = await accessTokenFactory();
|
||||
if (token) {
|
||||
headers = {
|
||||
["Authorization"]: `Bearer ${accessTokenFactory()}`,
|
||||
["Authorization"]: `Bearer ${token}`,
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,13 +23,13 @@ namespace JwtClientSample
|
|||
|
||||
private const string ServerUrl = "http://localhost:54543";
|
||||
|
||||
private readonly ConcurrentDictionary<string, string> _tokens = new ConcurrentDictionary<string, string>(StringComparer.Ordinal);
|
||||
private readonly ConcurrentDictionary<string, Task<string>> _tokens = new ConcurrentDictionary<string, Task<string>>(StringComparer.Ordinal);
|
||||
private readonly Random _random = new Random();
|
||||
|
||||
private async Task RunConnection(HttpTransportType transportType)
|
||||
{
|
||||
var userId = "C#" + transportType;
|
||||
_tokens[userId] = await GetJwtToken(userId);
|
||||
_tokens[userId] = GetJwtToken(userId);
|
||||
|
||||
var hubConnection = new HubConnectionBuilder()
|
||||
.WithUrl(ServerUrl + "/broadcast", options =>
|
||||
|
|
@ -60,7 +60,7 @@ namespace JwtClientSample
|
|||
// no need to refresh the token for websockets
|
||||
if (transportType != HttpTransportType.WebSockets)
|
||||
{
|
||||
_tokens[userId] = await GetJwtToken(userId);
|
||||
_tokens[userId] = GetJwtToken(userId);
|
||||
Console.WriteLine($"[{userId}] Token refreshed");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ using System.Net;
|
|||
using System.Net.Http;
|
||||
using System.Net.WebSockets;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Http.Connections.Client
|
||||
{
|
||||
|
|
@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
public Func<HttpMessageHandler, HttpMessageHandler> HttpMessageHandlerFactory { get; set; }
|
||||
|
||||
public IReadOnlyCollection<KeyValuePair<string, string>> Headers { get; set; }
|
||||
public Func<string> AccessTokenFactory { get; set; }
|
||||
public Func<Task<string>> AccessTokenFactory { get; set; }
|
||||
public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
||||
public ICredentials Credentials { get; set; }
|
||||
public X509CertificateCollection ClientCertificates { get; set; } = new X509CertificateCollection();
|
||||
|
|
|
|||
|
|
@ -11,18 +11,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
|
|||
{
|
||||
public class AccessTokenHttpMessageHandler : DelegatingHandler
|
||||
{
|
||||
private readonly Func<string> _accessTokenFactory;
|
||||
private readonly Func<Task<string>> _accessTokenFactory;
|
||||
|
||||
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<string> accessTokenFactory) : base(inner)
|
||||
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<Task<string>> accessTokenFactory) : base(inner)
|
||||
{
|
||||
_accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory));
|
||||
}
|
||||
|
||||
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
|
||||
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
|
||||
{
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory());
|
||||
var accessToken = await _accessTokenFactory();
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
|
||||
|
||||
return base.SendAsync(request, cancellationToken);
|
||||
return await base.SendAsync(request, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
|
|||
public partial class WebSocketsTransport : ITransport
|
||||
{
|
||||
private readonly ClientWebSocket _webSocket;
|
||||
private readonly Func<Task<string>> _accessTokenFactory;
|
||||
private IDuplexPipe _application;
|
||||
private WebSocketMessageType _webSocketMessageType;
|
||||
private readonly ILogger _logger;
|
||||
|
|
@ -80,7 +81,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
|
|||
|
||||
if (httpOptions.AccessTokenFactory != null)
|
||||
{
|
||||
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.AccessTokenFactory()}");
|
||||
_accessTokenFactory = httpOptions.AccessTokenFactory;
|
||||
}
|
||||
|
||||
httpOptions.WebSocketOptions?.Invoke(_webSocket.Options);
|
||||
|
|
@ -115,6 +116,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
|
|||
|
||||
Log.StartTransport(_logger, transferFormat, resolvedUrl);
|
||||
|
||||
if (_accessTokenFactory != null)
|
||||
{
|
||||
var accessToken = await _accessTokenFactory();
|
||||
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
|
||||
}
|
||||
|
||||
await _webSocket.ConnectAsync(resolvedUrl, CancellationToken.None);
|
||||
|
||||
// Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ using System.Net;
|
|||
using System.Net.Http;
|
||||
using System.Net.WebSockets;
|
||||
using System.Security.Cryptography.X509Certificates;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http.Connections;
|
||||
using Microsoft.AspNetCore.Http.Connections.Internal;
|
||||
|
||||
|
|
@ -29,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
public bool? UseDefaultCredentials { get; set; }
|
||||
public ICredentials Credentials { get; set; }
|
||||
public IWebProxy Proxy { get; set; }
|
||||
public Func<string> AccessTokenFactory { get; set; }
|
||||
public Func<Task<string>> AccessTokenFactory { get; set; }
|
||||
public Action<ClientWebSocketOptions> WebSocketOptions { get; set; }
|
||||
|
||||
public X509CertificateCollection ClientCertificates
|
||||
|
|
|
|||
|
|
@ -706,15 +706,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
{
|
||||
using (StartLog(out var loggerFactory, $"{nameof(ClientCanUseJwtBearerTokenForAuthentication)}_{transportType}"))
|
||||
{
|
||||
var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken");
|
||||
httpResponse.EnsureSuccessStatusCode();
|
||||
var token = await httpResponse.Content.ReadAsStringAsync();
|
||||
async Task<string> AccessTokenFactory()
|
||||
{
|
||||
var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken");
|
||||
httpResponse.EnsureSuccessStatusCode();
|
||||
return await httpResponse.Content.ReadAsStringAsync();
|
||||
};
|
||||
|
||||
var hubConnection = new HubConnectionBuilder()
|
||||
.WithLoggerFactory(loggerFactory)
|
||||
.WithUrl(_serverFixture.Url + "/authorizedhub", transportType, options =>
|
||||
{
|
||||
options.AccessTokenFactory = () => token;
|
||||
options.AccessTokenFactory = AccessTokenFactory;
|
||||
})
|
||||
.Build();
|
||||
try
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
ITransport transport = null,
|
||||
ITransportFactory transportFactory = null,
|
||||
HttpTransportType transportType = HttpTransportType.LongPolling,
|
||||
Func<string> accessTokenFactory = null)
|
||||
Func<Task<string>> accessTokenFactory = null)
|
||||
{
|
||||
var httpOptions = new HttpOptions
|
||||
{
|
||||
|
|
|
|||
|
|
@ -50,10 +50,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
return await next();
|
||||
});
|
||||
|
||||
string AccessTokenFactory()
|
||||
Task<string> AccessTokenFactory()
|
||||
{
|
||||
callCount++;
|
||||
return callCount.ToString();
|
||||
return Task.FromResult(callCount.ToString());
|
||||
}
|
||||
|
||||
await WithConnectionAsync(
|
||||
|
|
|
|||
Loading…
Reference in New Issue