Fix AccessTokenFactory not being called with each request (#1880)
This commit is contained in:
parent
b0c4e9d0f7
commit
921986d561
|
|
@ -399,9 +399,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
throw new InvalidOperationException("Configured HttpMessageHandlerFactory did not return a value.");
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the authorization header in a handler instead of a default header because it can change with each request
|
||||
if (_httpOptions.AccessTokenFactory != null)
|
||||
{
|
||||
httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, _httpOptions.AccessTokenFactory);
|
||||
}
|
||||
}
|
||||
|
||||
// Wrap message handler in a logging handler last to ensure it is always present
|
||||
// Wrap message handler after HttpMessageHandlerFactory to ensure not overriden
|
||||
httpMessageHandler = new LoggingHttpMessageHandler(httpMessageHandler, _loggerFactory);
|
||||
|
||||
var httpClient = new HttpClient(httpMessageHandler);
|
||||
|
|
@ -410,21 +416,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
// Start with the user agent header
|
||||
httpClient.DefaultRequestHeaders.UserAgent.Add(Constants.UserAgentHeader);
|
||||
|
||||
if (_httpOptions != null)
|
||||
// Apply any headers configured on the HttpOptions
|
||||
if (_httpOptions?.Headers != null)
|
||||
{
|
||||
// Apply any headers configured on the HttpOptions
|
||||
if (_httpOptions.Headers != null)
|
||||
foreach (var header in _httpOptions.Headers)
|
||||
{
|
||||
foreach (var header in _httpOptions.Headers)
|
||||
{
|
||||
httpClient.DefaultRequestHeaders.Add(header.Key, header.Value);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply the authorization header
|
||||
if (_httpOptions.AccessTokenFactory != null)
|
||||
{
|
||||
httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {_httpOptions.AccessTokenFactory()}");
|
||||
httpClient.DefaultRequestHeaders.Add(header.Key, header.Value);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Headers;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
|
||||
{
|
||||
public class AccessTokenHttpMessageHandler : DelegatingHandler
|
||||
{
|
||||
private readonly Func<string> _accessTokenFactory;
|
||||
|
||||
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<string> accessTokenFactory) : base(inner)
|
||||
{
|
||||
_accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory));
|
||||
}
|
||||
|
||||
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
|
||||
{
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory());
|
||||
|
||||
return base.SendAsync(request, cancellationToken);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -13,11 +13,19 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
{
|
||||
public partial class HttpConnectionTests
|
||||
{
|
||||
private static HttpConnection CreateConnection(HttpMessageHandler httpHandler = null, ILoggerFactory loggerFactory = null, string url = null, ITransport transport = null, ITransportFactory transportFactory = null, HttpTransportType transportType = HttpTransportType.LongPolling)
|
||||
private static HttpConnection CreateConnection(
|
||||
HttpMessageHandler httpHandler = null,
|
||||
ILoggerFactory loggerFactory = null,
|
||||
string url = null,
|
||||
ITransport transport = null,
|
||||
ITransportFactory transportFactory = null,
|
||||
HttpTransportType transportType = HttpTransportType.LongPolling,
|
||||
Func<string> accessTokenFactory = null)
|
||||
{
|
||||
var httpOptions = new HttpOptions()
|
||||
var httpOptions = new HttpOptions
|
||||
{
|
||||
HttpMessageHandlerFactory = (httpMessageHandler) => httpHandler ?? TestHttpMessageHandler.CreateDefault(),
|
||||
AccessTokenFactory = accessTokenFactory,
|
||||
};
|
||||
|
||||
return CreateConnection(httpOptions, loggerFactory, url, transport, transportFactory, transportType);
|
||||
|
|
|
|||
|
|
@ -19,6 +19,55 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
{
|
||||
public class Transport
|
||||
{
|
||||
[Theory]
|
||||
[InlineData(HttpTransportType.LongPolling)]
|
||||
[InlineData(HttpTransportType.ServerSentEvents)]
|
||||
public async Task HttpConnectionSetsAccessTokenOnAllRequests(HttpTransportType transportType)
|
||||
{
|
||||
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
|
||||
var requestsExecuted = false;
|
||||
var callCount = 0;
|
||||
|
||||
testHttpHandler.OnRequest((request, next, token) =>
|
||||
{
|
||||
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
|
||||
});
|
||||
|
||||
testHttpHandler.OnNegotiate((_, cancellationToken) =>
|
||||
{
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent());
|
||||
});
|
||||
|
||||
testHttpHandler.OnRequest(async (request, next, token) =>
|
||||
{
|
||||
Assert.Equal("Bearer", request.Headers.Authorization.Scheme);
|
||||
|
||||
// Call count increments with each call and is used as the access token
|
||||
Assert.Equal(callCount.ToString(), request.Headers.Authorization.Parameter);
|
||||
|
||||
requestsExecuted = true;
|
||||
|
||||
return await next();
|
||||
});
|
||||
|
||||
string AccessTokenFactory()
|
||||
{
|
||||
callCount++;
|
||||
return callCount.ToString();
|
||||
}
|
||||
|
||||
await WithConnectionAsync(
|
||||
CreateConnection(testHttpHandler, transportType: transportType, accessTokenFactory: AccessTokenFactory),
|
||||
async (connection) =>
|
||||
{
|
||||
await connection.StartAsync(TransferFormat.Text).OrTimeout();
|
||||
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello world 1"));
|
||||
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello world 2"));
|
||||
});
|
||||
// Fail safe in case the code is modified and some requests don't execute as a result
|
||||
Assert.True(requestsExecuted);
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(HttpTransportType.LongPolling)]
|
||||
[InlineData(HttpTransportType.ServerSentEvents)]
|
||||
|
|
|
|||
Loading…
Reference in New Issue