From 921986d561d142b9eda940d872e8ca21711e6f21 Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Fri, 6 Apr 2018 15:46:36 +1200 Subject: [PATCH] Fix AccessTokenFactory not being called with each request (#1880) --- .../HttpConnection.cs | 25 +++++----- .../Internal/AccessTokenHttpMessageHandler.cs | 28 +++++++++++ .../HttpConnectionTests.Helpers.cs | 12 ++++- .../HttpConnectionTests.Transport.cs | 49 +++++++++++++++++++ 4 files changed, 98 insertions(+), 16 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs index a26bc27e43..f9d20d6d58 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs @@ -399,9 +399,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client throw new InvalidOperationException("Configured HttpMessageHandlerFactory did not return a value."); } } + + // Apply the authorization header in a handler instead of a default header because it can change with each request + if (_httpOptions.AccessTokenFactory != null) + { + httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, _httpOptions.AccessTokenFactory); + } } - // Wrap message handler in a logging handler last to ensure it is always present + // Wrap message handler after HttpMessageHandlerFactory to ensure not overriden httpMessageHandler = new LoggingHttpMessageHandler(httpMessageHandler, _loggerFactory); var httpClient = new HttpClient(httpMessageHandler); @@ -410,21 +416,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client // Start with the user agent header httpClient.DefaultRequestHeaders.UserAgent.Add(Constants.UserAgentHeader); - if (_httpOptions != null) + // Apply any headers configured on the HttpOptions + if (_httpOptions?.Headers != null) { - // Apply any headers configured on the HttpOptions - if (_httpOptions.Headers != null) + foreach (var header in _httpOptions.Headers) { - foreach (var header in _httpOptions.Headers) - { - httpClient.DefaultRequestHeaders.Add(header.Key, header.Value); - } - } - - // Apply the authorization header - if (_httpOptions.AccessTokenFactory != null) - { - httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {_httpOptions.AccessTokenFactory()}"); + httpClient.DefaultRequestHeaders.Add(header.Key, header.Value); } } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs new file mode 100644 index 0000000000..cc2050b7aa --- /dev/null +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs @@ -0,0 +1,28 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Net.Http; +using System.Net.Http.Headers; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Http.Connections.Client.Internal +{ + public class AccessTokenHttpMessageHandler : DelegatingHandler + { + private readonly Func _accessTokenFactory; + + public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func accessTokenFactory) : base(inner) + { + _accessTokenFactory = accessTokenFactory ?? throw new ArgumentNullException(nameof(accessTokenFactory)); + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessTokenFactory()); + + return base.SendAsync(request, cancellationToken); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs index d7459e7de9..4f0434f1da 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Helpers.cs @@ -13,11 +13,19 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HttpConnectionTests { - private static HttpConnection CreateConnection(HttpMessageHandler httpHandler = null, ILoggerFactory loggerFactory = null, string url = null, ITransport transport = null, ITransportFactory transportFactory = null, HttpTransportType transportType = HttpTransportType.LongPolling) + private static HttpConnection CreateConnection( + HttpMessageHandler httpHandler = null, + ILoggerFactory loggerFactory = null, + string url = null, + ITransport transport = null, + ITransportFactory transportFactory = null, + HttpTransportType transportType = HttpTransportType.LongPolling, + Func accessTokenFactory = null) { - var httpOptions = new HttpOptions() + var httpOptions = new HttpOptions { HttpMessageHandlerFactory = (httpMessageHandler) => httpHandler ?? TestHttpMessageHandler.CreateDefault(), + AccessTokenFactory = accessTokenFactory, }; return CreateConnection(httpOptions, loggerFactory, url, transport, transportFactory, transportType); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs index a09fd682b3..97e16b7394 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs @@ -19,6 +19,55 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public class Transport { + [Theory] + [InlineData(HttpTransportType.LongPolling)] + [InlineData(HttpTransportType.ServerSentEvents)] + public async Task HttpConnectionSetsAccessTokenOnAllRequests(HttpTransportType transportType) + { + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + var requestsExecuted = false; + var callCount = 0; + + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + + testHttpHandler.OnNegotiate((_, cancellationToken) => + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); + }); + + testHttpHandler.OnRequest(async (request, next, token) => + { + Assert.Equal("Bearer", request.Headers.Authorization.Scheme); + + // Call count increments with each call and is used as the access token + Assert.Equal(callCount.ToString(), request.Headers.Authorization.Parameter); + + requestsExecuted = true; + + return await next(); + }); + + string AccessTokenFactory() + { + callCount++; + return callCount.ToString(); + } + + await WithConnectionAsync( + CreateConnection(testHttpHandler, transportType: transportType, accessTokenFactory: AccessTokenFactory), + async (connection) => + { + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello world 1")); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello world 2")); + }); + // Fail safe in case the code is modified and some requests don't execute as a result + Assert.True(requestsExecuted); + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)]