Support async access token factory (#1911)

This commit is contained in:
Vegard Løkken 2018-04-10 14:34:10 +02:00 committed by David Fowler
parent 6bc2ebb4c5
commit 31dfe91962
11 changed files with 74 additions and 34 deletions

View File

@ -532,6 +532,33 @@ describe("hubConnection", () => {
}
});
it("can connect to hub with authorization using async token factory", async (done) => {
const message = "你好,世界!";
try {
const hubConnection = new HubConnection("/authorizedhub", {
accessTokenFactory: () => getJwtToken("http://" + document.location.host + "/generateJwtToken"),
...commonOptions,
transport: transportType,
});
hubConnection.onclose((error) => {
expect(error).toBe(undefined);
done();
});
await hubConnection.start();
const response = await hubConnection.invoke("Echo", message);
expect(response).toEqual(message);
await hubConnection.stop();
done();
} catch (err) {
fail(err);
done();
}
});
if (transportType !== TransportType.LongPolling) {
it("terminates if no messages received within timeout interval", (done) => {
const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, {

View File

@ -13,7 +13,7 @@ export interface IHttpConnectionOptions {
httpClient?: HttpClient;
transport?: TransportType | ITransport;
logger?: ILogger | LogLevel;
accessTokenFactory?: () => string;
accessTokenFactory?: () => string | Promise<string>;
logMessageContent?: boolean;
}
@ -87,7 +87,7 @@ export class HttpConnection implements IConnection {
// No fallback or negotiate in this case.
await this.transport.connect(this.url, transferFormat, this);
} else {
const token = this.options.accessTokenFactory();
const token = await this.options.accessTokenFactory();
let headers;
if (token) {
headers = {

View File

@ -30,17 +30,17 @@ export interface ITransport {
export class WebSocketTransport implements ITransport {
private readonly logger: ILogger;
private readonly accessTokenFactory: () => string;
private readonly accessTokenFactory: () => string | Promise<string>;
private readonly logMessageContent: boolean;
private webSocket: WebSocket;
constructor(accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
constructor(accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
this.logger = logger;
this.accessTokenFactory = accessTokenFactory || (() => null);
this.logMessageContent = logMessageContent;
}
public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
Arg.isRequired(url, "url");
Arg.isRequired(transferFormat, "transferFormat");
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
@ -52,9 +52,9 @@ export class WebSocketTransport implements ITransport {
this.logger.log(LogLevel.Trace, "(WebSockets transport) Connecting");
const token = await this.accessTokenFactory();
return new Promise<void>((resolve, reject) => {
url = url.replace(/^http/, "ws");
const token = this.accessTokenFactory();
if (token) {
url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`;
}
@ -118,20 +118,20 @@ export class WebSocketTransport implements ITransport {
export class ServerSentEventsTransport implements ITransport {
private readonly httpClient: HttpClient;
private readonly accessTokenFactory: () => string;
private readonly accessTokenFactory: () => string | Promise<string>;
private readonly logger: ILogger;
private readonly logMessageContent: boolean;
private eventSource: EventSource;
private url: string;
constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
this.httpClient = httpClient;
this.accessTokenFactory = accessTokenFactory || (() => null);
this.logger = logger;
this.logMessageContent = logMessageContent;
}
public connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
public async connect(url: string, transferFormat: TransferFormat, connection: IConnection): Promise<void> {
Arg.isRequired(url, "url");
Arg.isRequired(transferFormat, "transferFormat");
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
@ -144,12 +144,12 @@ export class ServerSentEventsTransport implements ITransport {
this.logger.log(LogLevel.Trace, "(SSE transport) Connecting");
this.url = url;
const token = await this.accessTokenFactory();
return new Promise<void>((resolve, reject) => {
if (transferFormat !== TransferFormat.Text) {
reject(new Error("The Server-Sent Events transport only supports the 'Text' transfer format"));
}
const token = this.accessTokenFactory();
if (token) {
url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`;
}
@ -210,7 +210,7 @@ export class ServerSentEventsTransport implements ITransport {
export class LongPollingTransport implements ITransport {
private readonly httpClient: HttpClient;
private readonly accessTokenFactory: () => string;
private readonly accessTokenFactory: () => string | Promise<string>;
private readonly logger: ILogger;
private readonly logMessageContent: boolean;
@ -218,7 +218,7 @@ export class LongPollingTransport implements ITransport {
private pollXhr: XMLHttpRequest;
private pollAbort: AbortController;
constructor(httpClient: HttpClient, accessTokenFactory: () => string, logger: ILogger, logMessageContent: boolean) {
constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise<string>, logger: ILogger, logMessageContent: boolean) {
this.httpClient = httpClient;
this.accessTokenFactory = accessTokenFactory || (() => null);
this.logger = logger;
@ -259,7 +259,7 @@ export class LongPollingTransport implements ITransport {
pollOptions.responseType = "arraybuffer";
}
const token = this.accessTokenFactory();
const token = await this.accessTokenFactory();
if (token) {
// tslint:disable-next-line:no-string-literal
pollOptions.headers["Authorization"] = `Bearer ${token}`;
@ -356,12 +356,12 @@ function formatArrayBuffer(data: ArrayBuffer): string {
return str.substr(0, str.length - 1);
}
async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string, content: string | ArrayBuffer, logMessageContent: boolean): Promise<void> {
async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string | Promise<string>, content: string | ArrayBuffer, logMessageContent: boolean): Promise<void> {
let headers;
const token = accessTokenFactory();
const token = await accessTokenFactory();
if (token) {
headers = {
["Authorization"]: `Bearer ${accessTokenFactory()}`,
["Authorization"]: `Bearer ${token}`,
};
}

View File

@ -23,13 +23,13 @@ namespace JwtClientSample
private const string ServerUrl = "http://localhost:54543";
private readonly ConcurrentDictionary<string, string> _tokens = new ConcurrentDictionary<string, string>(StringComparer.Ordinal);
private readonly ConcurrentDictionary<string, Task<string>> _tokens = new ConcurrentDictionary<string, Task<string>>(StringComparer.Ordinal);
private readonly Random _random = new Random();
private async Task RunConnection(HttpTransportType transportType)
{
var userId = "C#" + transportType;
_tokens[userId] = await GetJwtToken(userId);
_tokens[userId] = GetJwtToken(userId);
var hubConnection = new HubConnectionBuilder()
.WithUrl(ServerUrl + "/broadcast", options =>
@ -60,7 +60,7 @@ namespace JwtClientSample
// no need to refresh the token for websockets
if (transportType != HttpTransportType.WebSockets)
{
_tokens[userId] = await GetJwtToken(userId);
_tokens[userId] = GetJwtToken(userId);
Console.WriteLine($"[{userId}] Token refreshed");
}
}

View File

@ -7,6 +7,7 @@ using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Http.Connections.Client
{
@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
public Func<HttpMessageHandler, HttpMessageHandler> HttpMessageHandlerFactory { get; set; }
public IReadOnlyCollection<KeyValuePair<string, string>> Headers { get; set; }
public Func<string> AccessTokenFactory { get; set; }
public Func<Task<string>> AccessTokenFactory { get; set; }
public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5);
public ICredentials Credentials { get; set; }
public X509CertificateCollection ClientCertificates { get; set; } = new X509CertificateCollection();

View File

@ -11,18 +11,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
public class AccessTokenHttpMessageHandler : DelegatingHandler
{
private readonly Func<string> _accessTokenFactory;
private readonly Func<Task<string>> _accessTokenFactory;
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<string> accessTokenFactory) : base(inner)
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<Task<string>> accessTokenFactory) : base(inner)
{
_accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory));
}
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory());
var accessToken = await _accessTokenFactory();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
return base.SendAsync(request, cancellationToken);
return await base.SendAsync(request, cancellationToken);
}
}
}

View File

@ -17,6 +17,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
public partial class WebSocketsTransport : ITransport
{
private readonly ClientWebSocket _webSocket;
private readonly Func<Task<string>> _accessTokenFactory;
private IDuplexPipe _application;
private WebSocketMessageType _webSocketMessageType;
private readonly ILogger _logger;
@ -80,7 +81,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
if (httpOptions.AccessTokenFactory != null)
{
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.AccessTokenFactory()}");
_accessTokenFactory = httpOptions.AccessTokenFactory;
}
httpOptions.WebSocketOptions?.Invoke(_webSocket.Options);
@ -115,6 +116,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
Log.StartTransport(_logger, transferFormat, resolvedUrl);
if (_accessTokenFactory != null)
{
var accessToken = await _accessTokenFactory();
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
}
await _webSocket.ConnectAsync(resolvedUrl, CancellationToken.None);
// Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa)

View File

@ -7,6 +7,7 @@ using System.Net;
using System.Net.Http;
using System.Net.WebSockets;
using System.Security.Cryptography.X509Certificates;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
@ -29,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public bool? UseDefaultCredentials { get; set; }
public ICredentials Credentials { get; set; }
public IWebProxy Proxy { get; set; }
public Func<string> AccessTokenFactory { get; set; }
public Func<Task<string>> AccessTokenFactory { get; set; }
public Action<ClientWebSocketOptions> WebSocketOptions { get; set; }
public X509CertificateCollection ClientCertificates

View File

@ -706,15 +706,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
using (StartLog(out var loggerFactory, $"{nameof(ClientCanUseJwtBearerTokenForAuthentication)}_{transportType}"))
{
var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken");
httpResponse.EnsureSuccessStatusCode();
var token = await httpResponse.Content.ReadAsStringAsync();
async Task<string> AccessTokenFactory()
{
var httpResponse = await new HttpClient().GetAsync(_serverFixture.Url + "/generateJwtToken");
httpResponse.EnsureSuccessStatusCode();
return await httpResponse.Content.ReadAsStringAsync();
};
var hubConnection = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
.WithUrl(_serverFixture.Url + "/authorizedhub", transportType, options =>
{
options.AccessTokenFactory = () => token;
options.AccessTokenFactory = AccessTokenFactory;
})
.Build();
try

View File

@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
ITransport transport = null,
ITransportFactory transportFactory = null,
HttpTransportType transportType = HttpTransportType.LongPolling,
Func<string> accessTokenFactory = null)
Func<Task<string>> accessTokenFactory = null)
{
var httpOptions = new HttpOptions
{

View File

@ -50,10 +50,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
return await next();
});
string AccessTokenFactory()
Task<string> AccessTokenFactory()
{
callCount++;
return callCount.ToString();
return Task.FromResult(callCount.ToString());
}
await WithConnectionAsync(