Fix AccessTokenFactory not being called with each request (#1880)

This commit is contained in:
James Newton-King 2018-04-06 15:46:36 +12:00 committed by GitHub
parent b0c4e9d0f7
commit 921986d561
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 16 deletions

View File

@ -399,9 +399,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
throw new InvalidOperationException("Configured HttpMessageHandlerFactory did not return a value.");
}
}
// Apply the authorization header in a handler instead of a default header because it can change with each request
if (_httpOptions.AccessTokenFactory != null)
{
httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, _httpOptions.AccessTokenFactory);
}
}
// Wrap message handler in a logging handler last to ensure it is always present
// Wrap message handler after HttpMessageHandlerFactory to ensure not overriden
httpMessageHandler = new LoggingHttpMessageHandler(httpMessageHandler, _loggerFactory);
var httpClient = new HttpClient(httpMessageHandler);
@ -410,21 +416,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
// Start with the user agent header
httpClient.DefaultRequestHeaders.UserAgent.Add(Constants.UserAgentHeader);
if (_httpOptions != null)
// Apply any headers configured on the HttpOptions
if (_httpOptions?.Headers != null)
{
// Apply any headers configured on the HttpOptions
if (_httpOptions.Headers != null)
foreach (var header in _httpOptions.Headers)
{
foreach (var header in _httpOptions.Headers)
{
httpClient.DefaultRequestHeaders.Add(header.Key, header.Value);
}
}
// Apply the authorization header
if (_httpOptions.AccessTokenFactory != null)
{
httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {_httpOptions.AccessTokenFactory()}");
httpClient.DefaultRequestHeaders.Add(header.Key, header.Value);
}
}

View File

@ -0,0 +1,28 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
public class AccessTokenHttpMessageHandler : DelegatingHandler
{
private readonly Func<string> _accessTokenFactory;
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<string> accessTokenFactory) : base(inner)
{
_accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory));
}
protected override Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory());
return base.SendAsync(request, cancellationToken);
}
}
}

View File

@ -13,11 +13,19 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
public partial class HttpConnectionTests
{
private static HttpConnection CreateConnection(HttpMessageHandler httpHandler = null, ILoggerFactory loggerFactory = null, string url = null, ITransport transport = null, ITransportFactory transportFactory = null, HttpTransportType transportType = HttpTransportType.LongPolling)
private static HttpConnection CreateConnection(
HttpMessageHandler httpHandler = null,
ILoggerFactory loggerFactory = null,
string url = null,
ITransport transport = null,
ITransportFactory transportFactory = null,
HttpTransportType transportType = HttpTransportType.LongPolling,
Func<string> accessTokenFactory = null)
{
var httpOptions = new HttpOptions()
var httpOptions = new HttpOptions
{
HttpMessageHandlerFactory = (httpMessageHandler) => httpHandler ?? TestHttpMessageHandler.CreateDefault(),
AccessTokenFactory = accessTokenFactory,
};
return CreateConnection(httpOptions, loggerFactory, url, transport, transportFactory, transportType);

View File

@ -19,6 +19,55 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
public class Transport
{
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task HttpConnectionSetsAccessTokenOnAllRequests(HttpTransportType transportType)
{
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
var requestsExecuted = false;
var callCount = 0;
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
testHttpHandler.OnNegotiate((_, cancellationToken) =>
{
return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent());
});
testHttpHandler.OnRequest(async (request, next, token) =>
{
Assert.Equal("Bearer", request.Headers.Authorization.Scheme);
// Call count increments with each call and is used as the access token
Assert.Equal(callCount.ToString(), request.Headers.Authorization.Parameter);
requestsExecuted = true;
return await next();
});
string AccessTokenFactory()
{
callCount++;
return callCount.ToString();
}
await WithConnectionAsync(
CreateConnection(testHttpHandler, transportType: transportType, accessTokenFactory: AccessTokenFactory),
async (connection) =>
{
await connection.StartAsync(TransferFormat.Text).OrTimeout();
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello world 1"));
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello world 2"));
});
// Fail safe in case the code is modified and some requests don't execute as a result
Assert.True(requestsExecuted);
}
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]