diff --git a/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs b/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs index 50c4584c37..887719c624 100644 --- a/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs +++ b/src/Middleware/HeaderPropagation/ref/Microsoft.AspNetCore.HeaderPropagation.netcoreapp3.0.cs @@ -21,9 +21,9 @@ namespace Microsoft.AspNetCore.HeaderPropagation } public partial class HeaderPropagationEntry { - public HeaderPropagationEntry(string inboundHeaderName, string outboundHeaderName, System.Func valueFilter) { } + public HeaderPropagationEntry(string inboundHeaderName, string capturedHeaderName, System.Func valueFilter) { } + public string CapturedHeaderName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } public string InboundHeaderName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } - public string OutboundHeaderName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } public System.Func ValueFilter { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } } public sealed partial class HeaderPropagationEntryCollection : System.Collections.ObjectModel.Collection @@ -36,9 +36,26 @@ namespace Microsoft.AspNetCore.HeaderPropagation } public partial class HeaderPropagationMessageHandler : System.Net.Http.DelegatingHandler { - public HeaderPropagationMessageHandler(Microsoft.Extensions.Options.IOptions options, Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationValues values) { } + public HeaderPropagationMessageHandler(Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationMessageHandlerOptions options, Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationValues values) { } protected override System.Threading.Tasks.Task SendAsync(System.Net.Http.HttpRequestMessage request, System.Threading.CancellationToken cancellationToken) { throw null; } } + public partial class HeaderPropagationMessageHandlerEntry + { + public HeaderPropagationMessageHandlerEntry(string capturedHeaderName, string outboundHeaderName) { } + public string CapturedHeaderName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + public string OutboundHeaderName { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } } + } + public sealed partial class HeaderPropagationMessageHandlerEntryCollection : System.Collections.ObjectModel.Collection + { + public HeaderPropagationMessageHandlerEntryCollection() { } + public void Add(string headerName) { } + public void Add(string capturedHeaderName, string outboundHeaderName) { } + } + public partial class HeaderPropagationMessageHandlerOptions + { + public HeaderPropagationMessageHandlerOptions() { } + public Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationMessageHandlerEntryCollection Headers { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } + } public partial class HeaderPropagationMiddleware { public HeaderPropagationMiddleware(Microsoft.AspNetCore.Http.RequestDelegate next, Microsoft.Extensions.Options.IOptions options, Microsoft.AspNetCore.HeaderPropagation.HeaderPropagationValues values) { } @@ -60,6 +77,7 @@ namespace Microsoft.Extensions.DependencyInjection public static partial class HeaderPropagationHttpClientBuilderExtensions { public static Microsoft.Extensions.DependencyInjection.IHttpClientBuilder AddHeaderPropagation(this Microsoft.Extensions.DependencyInjection.IHttpClientBuilder builder) { throw null; } + public static Microsoft.Extensions.DependencyInjection.IHttpClientBuilder AddHeaderPropagation(this Microsoft.Extensions.DependencyInjection.IHttpClientBuilder builder, System.Action configure) { throw null; } } public static partial class HeaderPropagationServiceCollectionExtensions { diff --git a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/HeaderPropagationSample.csproj b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/HeaderPropagationSample.csproj index 34c4dfc8b6..8410f6d290 100644 --- a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/HeaderPropagationSample.csproj +++ b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/HeaderPropagationSample.csproj @@ -8,6 +8,7 @@ + diff --git a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Program.cs b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Program.cs index bbfb777399..648d98e459 100644 --- a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Program.cs +++ b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Program.cs @@ -15,6 +15,7 @@ namespace HeaderPropagationSample Host.CreateDefaultBuilder(args) .ConfigureWebHost(webBuilder => { + webBuilder.UseKestrel(); webBuilder.UseStartup(); }); } diff --git a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs index e404fc15e4..6787910f19 100644 --- a/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs +++ b/src/Middleware/HeaderPropagation/samples/HeaderPropagationSample/Startup.cs @@ -49,6 +49,9 @@ namespace HeaderPropagationSample .AddHttpClient("test") .AddHeaderPropagation(); + services + .AddHttpClient("another") + .AddHeaderPropagation(options => options.Headers.Add("X-BetaFeatures", "X-Experiments")); } public void Configure(IApplicationBuilder app, IWebHostEnvironment env, IHttpClientFactory clientFactory) @@ -71,19 +74,23 @@ namespace HeaderPropagationSample await context.Response.WriteAsync($"'/' Got Header '{header.Key}': {string.Join(", ", header.Value)}\r\n"); } - await context.Response.WriteAsync("Sending request to /forwarded\r\n"); - - var uri = UriHelper.BuildAbsolute(context.Request.Scheme, context.Request.Host, context.Request.PathBase, "/forwarded"); - var client = clientFactory.CreateClient("test"); - var response = await client.GetAsync(uri); - - foreach (var header in response.RequestMessage.Headers) + var clientNames = new[] { "test", "another" }; + foreach (var clientName in clientNames) { - await context.Response.WriteAsync($"Sent Header '{header.Key}': {string.Join(", ", header.Value)}\r\n"); - } + await context.Response.WriteAsync("Sending request to /forwarded\r\n"); - await context.Response.WriteAsync("Got response\r\n"); - await context.Response.WriteAsync(await response.Content.ReadAsStringAsync()); + var uri = UriHelper.BuildAbsolute(context.Request.Scheme, context.Request.Host, context.Request.PathBase, "/forwarded"); + var client = clientFactory.CreateClient(clientName); + var response = await client.GetAsync(uri); + + foreach (var header in response.RequestMessage.Headers) + { + await context.Response.WriteAsync($"Sent Header '{header.Key}': {string.Join(", ", header.Value)}\r\n"); + } + + await context.Response.WriteAsync("Got response\r\n"); + await context.Response.WriteAsync(await response.Content.ReadAsStringAsync()); + } }); endpoints.MapGet("/forwarded", async context => diff --git a/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationHttpClientBuilderExtensions.cs b/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationHttpClientBuilderExtensions.cs index f71614392f..cf28277a62 100644 --- a/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationHttpClientBuilderExtensions.cs +++ b/src/Middleware/HeaderPropagation/src/DependencyInjection/HeaderPropagationHttpClientBuilderExtensions.cs @@ -3,6 +3,7 @@ using System; using Microsoft.AspNetCore.HeaderPropagation; +using Microsoft.Extensions.Options; namespace Microsoft.Extensions.DependencyInjection { @@ -11,6 +12,9 @@ namespace Microsoft.Extensions.DependencyInjection /// /// Adds a message handler for propagating headers collected by the to a outgoing request. /// + /// + /// When using this method, all the configured headers will be applied to the outgoing HTTP requests. + /// /// The to add the message handler to. /// The so that additional calls can be chained. public static IHttpClientBuilder AddHeaderPropagation(this IHttpClientBuilder builder) @@ -22,7 +26,49 @@ namespace Microsoft.Extensions.DependencyInjection builder.Services.AddHeaderPropagation(); - builder.AddHttpMessageHandler(); + builder.AddHttpMessageHandler(services => + { + var options = new HeaderPropagationMessageHandlerOptions(); + var middlewareOptions = services.GetRequiredService>(); + for (var i = 0; i < middlewareOptions.Value.Headers.Count; i++) + { + var header = middlewareOptions.Value.Headers[i]; + options.Headers.Add(header.CapturedHeaderName, header.CapturedHeaderName); + } + return new HeaderPropagationMessageHandler(options, services.GetRequiredService()); + }); + + return builder; + } + + /// + /// Adds a message handler for propagating headers collected by the to a outgoing request, + /// explicitly specifying which headers to propagate. + /// + /// This also allows to redefine the name to use for a header in the outgoing request. + /// The to add the message handler to. + /// A delegate used to configure the . + /// The so that additional calls can be chained. + public static IHttpClientBuilder AddHeaderPropagation(this IHttpClientBuilder builder, Action configure) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (configure == null) + { + throw new ArgumentNullException(nameof(configure)); + } + + builder.Services.AddHeaderPropagation(); + + builder.AddHttpMessageHandler(services => + { + var options = new HeaderPropagationMessageHandlerOptions(); + configure(options); + return new HeaderPropagationMessageHandler(options, services.GetRequiredService()); + }); return builder; } diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationEntry.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationEntry.cs index 197c7d71ee..2955a2fe1f 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationEntry.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationEntry.cs @@ -13,12 +13,12 @@ namespace Microsoft.AspNetCore.HeaderPropagation { /// /// Creates a new with the provided , - /// , and + /// and . /// /// /// The name of the header to be captured by . /// - /// + /// /// The name of the header to be added by . /// /// @@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// public HeaderPropagationEntry( string inboundHeaderName, - string outboundHeaderName, + string capturedHeaderName, Func valueFilter) { if (inboundHeaderName == null) @@ -34,13 +34,13 @@ namespace Microsoft.AspNetCore.HeaderPropagation throw new ArgumentNullException(nameof(inboundHeaderName)); } - if (outboundHeaderName == null) + if (capturedHeaderName == null) { - throw new ArgumentNullException(nameof(outboundHeaderName)); + throw new ArgumentNullException(nameof(capturedHeaderName)); } InboundHeaderName = inboundHeaderName; - OutboundHeaderName = outboundHeaderName; + CapturedHeaderName = capturedHeaderName; ValueFilter = valueFilter; // May be null } @@ -50,10 +50,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation public string InboundHeaderName { get; } /// - /// Gets the name of the header to be used by the for the + /// Gets the name of the header to be used by default by the for the /// outbound http requests. /// - public string OutboundHeaderName { get; } + public string CapturedHeaderName { get; } /// /// Gets or sets a filter delegate that can be used to transform the header value. diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationEntryCollection.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationEntryCollection.cs index 86eec5d9df..c0d2617674 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationEntryCollection.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationEntryCollection.cs @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// /// Adds an that will use as /// the value of and - /// . + /// . /// /// The header name to be propagated. public void Add(string headerName) @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// /// Adds an that will use as /// the value of and - /// . + /// . /// /// The header name to be propagated. /// diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs index c9d0df9d87..7d07dc3572 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs @@ -6,7 +6,6 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; -using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.HeaderPropagation @@ -17,7 +16,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation public class HeaderPropagationMessageHandler : DelegatingHandler { private readonly HeaderPropagationValues _values; - private readonly HeaderPropagationOptions _options; + private readonly HeaderPropagationMessageHandlerOptions _options; /// /// Creates a new instance of the . @@ -25,15 +24,9 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// The options that define which headers are propagated. /// The values of the headers to be propagated populated by the /// . - public HeaderPropagationMessageHandler(IOptions options, HeaderPropagationValues values) + public HeaderPropagationMessageHandler(HeaderPropagationMessageHandlerOptions options, HeaderPropagationValues values) { - if (options == null) - { - throw new ArgumentNullException(nameof(options)); - } - - _options = options.Value; - + _options = options ?? throw new ArgumentNullException(nameof(options)); _values = values ?? throw new ArgumentNullException(nameof(values)); } @@ -71,7 +64,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation if (!request.Headers.TryGetValues(entry.OutboundHeaderName, out var _) && !(hasContent && request.Content.Headers.TryGetValues(entry.OutboundHeaderName, out var _))) { - if (captured.TryGetValue(entry.OutboundHeaderName, out var stringValues) && + if (captured.TryGetValue(entry.CapturedHeaderName, out var stringValues) && !StringValues.IsNullOrEmpty(stringValues)) { if (stringValues.Count == 1) diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerEntry.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerEntry.cs new file mode 100644 index 0000000000..8c982696a4 --- /dev/null +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerEntry.cs @@ -0,0 +1,51 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.HeaderPropagation +{ + /// + /// Define the configuration of an header for the . + /// + public class HeaderPropagationMessageHandlerEntry + { + /// + /// Creates a new with the provided + /// and . + /// + /// + /// The name of the header to be used to lookup the headers captured by the . + /// + /// + /// The name of the header to be added to the outgoing http requests by the . + /// + public HeaderPropagationMessageHandlerEntry( + string capturedHeaderName, + string outboundHeaderName) + { + if (capturedHeaderName == null) + { + throw new ArgumentNullException(nameof(capturedHeaderName)); + } + + if (outboundHeaderName == null) + { + throw new ArgumentNullException(nameof(outboundHeaderName)); + } + + CapturedHeaderName = capturedHeaderName; + OutboundHeaderName = outboundHeaderName; + } + + /// + /// Gets the name of the header to be used to lookup the headers captured by the . + /// + public string CapturedHeaderName { get; } + + /// + /// Gets the name of the header to be added to the outgoing http requests by the . + /// + public string OutboundHeaderName { get; } + } +} diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerEntryCollection.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerEntryCollection.cs new file mode 100644 index 0000000000..002f3a3368 --- /dev/null +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerEntryCollection.cs @@ -0,0 +1,57 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.ObjectModel; + +namespace Microsoft.AspNetCore.HeaderPropagation +{ + /// + /// A collection of items. + /// + public sealed class HeaderPropagationMessageHandlerEntryCollection : Collection + { + /// + /// Adds an that will use as + /// the value of and + /// . + /// + /// + /// The name of the header to be added by the . + /// + public void Add(string headerName) + { + if (headerName == null) + { + throw new ArgumentNullException(nameof(headerName)); + } + + Add(new HeaderPropagationMessageHandlerEntry(headerName, headerName)); + } + + /// + /// Adds an that will use the provided + /// and . + /// + /// + /// The name of the header captured by the . + /// + /// + /// The name of the header to be added by the . + /// + public void Add(string capturedHeaderName, string outboundHeaderName) + { + if (capturedHeaderName == null) + { + throw new ArgumentNullException(nameof(capturedHeaderName)); + } + + if (outboundHeaderName == null) + { + throw new ArgumentNullException(nameof(outboundHeaderName)); + } + + Add(new HeaderPropagationMessageHandlerEntry(capturedHeaderName, outboundHeaderName)); + } + } +} diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerOptions.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerOptions.cs new file mode 100644 index 0000000000..27dbb74b2d --- /dev/null +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandlerOptions.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.HeaderPropagation +{ + /// + /// Provides configuration for the . + /// + public class HeaderPropagationMessageHandlerOptions + { + /// + /// Gets or sets the headers to be propagated by the . + /// + /// + /// If is empty, all the headers captured by the are propagated. + /// Entries in are processed in order while adding headers inside + /// . This can cause an earlier entry to take precedence + /// over a later entry if they have the same . + /// + public HeaderPropagationMessageHandlerEntryCollection Headers { get; set; } = new HeaderPropagationMessageHandlerEntryCollection(); + } +} diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs index ed18acd3ac..f62c9e4a72 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs @@ -46,12 +46,12 @@ namespace Microsoft.AspNetCore.HeaderPropagation // We intentionally process entries in order, and allow earlier entries to // take precedence over later entries when they have the same output name. - if (!headers.ContainsKey(entry.OutboundHeaderName)) + if (!headers.ContainsKey(entry.CapturedHeaderName)) { var value = GetValue(context, entry); if (!StringValues.IsNullOrEmpty(value)) { - headers.Add(entry.OutboundHeaderName, value); + headers.Add(entry.CapturedHeaderName, value); } } } diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationOptions.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationOptions.cs index e5f26e831b..731c2ca286 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationOptions.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationOptions.cs @@ -13,9 +13,9 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// and to be propagated by the . /// /// - /// Entries in are processes in order while capturing headers inside + /// Entries in are processed in order while capturing headers inside /// . This can cause an earlier entry to take precedence - /// over a later entry if they have the same . + /// over a later entry if they have the same . /// public HeaderPropagationEntryCollection Headers { get; set; } = new HeaderPropagationEntryCollection(); } diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationValues.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationValues.cs index fdf0067a58..741b02e7d6 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationValues.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationValues.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation /// that can be propagated. /// /// - /// The keys of correspond to . + /// The keys of correspond to . /// public IDictionary Headers { diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs index 60a417e883..585bb91472 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationIntegrationTest.cs @@ -91,6 +91,35 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests Assert.Equal(new[] { "test" }, handler.Headers.GetValues("out")); } + [Fact] + public async Task MultipleHeaders_HeadersInRequest_AddAllHeaders() + { + // Arrange + var handler = new SimpleHandler(); + var builder = CreateBuilder(c => + { + c.Headers.Add("first"); + c.Headers.Add("second"); + }, + handler); + var server = new TestServer(builder); + var client = server.CreateClient(); + + var request = new HttpRequestMessage(); + request.Headers.Add("first", "value"); + request.Headers.Add("second", "other"); + + // Act + var response = await client.SendAsync(request); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.True(handler.Headers.Contains("first")); + Assert.Equal(new[] { "value" }, handler.Headers.GetValues("first")); + Assert.True(handler.Headers.Contains("second")); + Assert.Equal(new[] { "other" }, handler.Headers.GetValues("second")); + } + [Fact] public void Builder_UseHeaderPropagation_Without_AddHeaderPropagation_Throws() { @@ -106,7 +135,31 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests exception.Message); } - private IWebHostBuilder CreateBuilder(Action configure, HttpMessageHandler primaryHandler) + [Fact] + public async Task HeaderInRequest_OverrideHeaderPerClient_AddCorrectValue() + { + // Arrange + var handler = new SimpleHandler(); + var builder = CreateBuilder( + c => c.Headers.Add("in", "out"), + handler, + c => c.Headers.Add("out", "different")); + var server = new TestServer(builder); + var client = server.CreateClient(); + + var request = new HttpRequestMessage(); + request.Headers.Add("in", "test"); + + // Act + var response = await client.SendAsync(request); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.True(handler.Headers.Contains("different")); + Assert.Equal(new[] { "test" }, handler.Headers.GetValues("different")); + } + + private IWebHostBuilder CreateBuilder(Action configure, HttpMessageHandler primaryHandler, Action configureClient = null) { return new WebHostBuilder() .Configure(app => @@ -116,13 +169,21 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests }) .ConfigureServices(services => { - services.AddHttpClient("example.com", c => c.BaseAddress = new Uri("http://example.com")) + services.AddHeaderPropagation(configure); + var client = services.AddHttpClient("example.com", c => c.BaseAddress = new Uri("http://example.com")) .ConfigureHttpMessageHandlerBuilder(b => { b.PrimaryHandler = primaryHandler; - }) - .AddHeaderPropagation(); - services.AddHeaderPropagation(configure); + }); + + if (configureClient != null) + { + client.AddHeaderPropagation(configureClient); + } + else + { + client.AddHeaderPropagation(); + } }); } diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerEntryCollectionTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerEntryCollectionTest.cs new file mode 100644 index 0000000000..483585c6d2 --- /dev/null +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerEntryCollectionTest.cs @@ -0,0 +1,34 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Xunit; + +namespace Microsoft.AspNetCore.HeaderPropagation.Tests +{ + public class HeaderPropagationMessageHandlerEntryCollectionTest + { + [Fact] + public void Add_SingleValue_UseValueForBothProperties() + { + var collection = new HeaderPropagationMessageHandlerEntryCollection(); + collection.Add("foo"); + + Assert.Single(collection); + var entry = collection[0]; + Assert.Equal("foo", entry.CapturedHeaderName); + Assert.Equal("foo", entry.OutboundHeaderName); + } + + [Fact] + public void Add_BothValues_UseCorrectValues() + { + var collection = new HeaderPropagationMessageHandlerEntryCollection(); + collection.Add("foo", "bar"); + + Assert.Single(collection); + var entry = collection[0]; + Assert.Equal("foo", entry.CapturedHeaderName); + Assert.Equal("bar", entry.OutboundHeaderName); + } + } +} diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs index d19348c3d0..fe51ea0e1f 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs @@ -7,7 +7,6 @@ using System.Net.Http; using System.Net.Http.Headers; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Options; using Microsoft.Extensions.Primitives; using Xunit; @@ -22,10 +21,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests State = new HeaderPropagationValues(); State.Headers = new Dictionary(StringComparer.OrdinalIgnoreCase); - Configuration = new HeaderPropagationOptions(); + Configuration = new HeaderPropagationMessageHandlerOptions(); var headerPropagationMessageHandler = - new HeaderPropagationMessageHandler(Options.Create(Configuration), State) + new HeaderPropagationMessageHandler(Configuration, State) { InnerHandler = Handler }; @@ -38,14 +37,14 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests private SimpleHandler Handler { get; } public HeaderPropagationValues State { get; set; } - public HeaderPropagationOptions Configuration { get; set; } + public HeaderPropagationMessageHandlerOptions Configuration { get; set; } public HttpClient Client { get; set; } [Fact] public async Task HeaderInState_AddCorrectValue() { // Arrange - Configuration.Headers.Add("in", "out"); + Configuration.Headers.Add("out"); State.Headers.Add("out", "test"); // Act @@ -60,7 +59,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests public async Task HeaderInState_WithMultipleValues_AddAllValues() { // Arrange - Configuration.Headers.Add("in", "out"); + Configuration.Headers.Add("out"); State.Headers.Add("out", new[] { "one", "two" }); // Act @@ -74,8 +73,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests [Fact] public async Task HeaderInState_RequestWithContent_ContentHeaderPresent_DoesNotAddIt() { - Configuration.Headers.Add("in", "Content-Type"); - State.Headers.Add("in", "test"); + Configuration.Headers.Add("Content-Type"); + State.Headers.Add("Content-Type", "test"); // Act await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); @@ -88,7 +87,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests [Fact] public async Task HeaderInState_RequestWithContent_ContentHeaderNotPresent_AddValue() { - Configuration.Headers.Add("in", "Content-Language"); + Configuration.Headers.Add("Content-Language"); State.Headers.Add("Content-Language", "test"); // Act @@ -102,7 +101,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests [Fact] public async Task HeaderInState_WithMultipleValues_RequestWithContent_ContentHeaderNotPresent_AddAllValues() { - Configuration.Headers.Add("in", "Content-Language"); + Configuration.Headers.Add("Content-Language"); State.Headers.Add("Content-Language", new[] { "one", "two" }); // Act @@ -114,18 +113,18 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests } [Fact] - public async Task HeaderInState_NoOutputName_UseInputName() + public async Task HeaderInState_WithOutboundName_UseOutboundName() { // Arrange - Configuration.Headers.Add("in"); - State.Headers.Add("in", "test"); + Configuration.Headers.Add("state", "out"); + State.Headers.Add("state", "test"); // Act await Client.SendAsync(new HttpRequestMessage()); // Assert - Assert.True(Handler.Headers.Contains("in")); - Assert.Equal(new[] { "test" }, Handler.Headers.GetValues("in")); + Assert.True(Handler.Headers.Contains("out")); + Assert.Equal(new[] { "test" }, Handler.Headers.GetValues("out")); } [Fact] @@ -193,7 +192,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests [InlineData("", new[] { "" })] [InlineData(null, new[] { "" })] [InlineData("42", new[] { "42" })] - public async Task HeaderInState_HeaderAlreadyInOutgoingRequest(string outgoingValue, + public async Task HeaderInState_HeaderAlreadyInOutgoingRequest_DoesNotOverrideIt(string outgoingValue, string[] expectedValues) { // Arrange @@ -211,6 +210,57 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests Assert.Equal(expectedValues, Handler.Headers.GetValues("inout")); } + [Fact] + public async Task HeaderInState_HeaderTwiceInOptions_DoesNotAddItTwice() + { + // Arrange + State.Headers.Add("name", "value"); + Configuration.Headers.Add("name"); + Configuration.Headers.Add("name"); + + // Act + await Client.SendAsync(new HttpRequestMessage()); + + // Assert + Assert.True(Handler.Headers.Contains("name")); + Assert.Equal(new[] { "value" }, Handler.Headers.GetValues("name")); + } + + [Fact] + public async Task HeaderInState_HeaderTwiceInOptionsWithDifferentNames_AddsBoth() + { + // Arrange + State.Headers.Add("name", "value"); + Configuration.Headers.Add("name"); + Configuration.Headers.Add("name", "other"); + + // Act + await Client.SendAsync(new HttpRequestMessage()); + + // Assert + Assert.True(Handler.Headers.Contains("name")); + Assert.Equal(new[] { "value" }, Handler.Headers.GetValues("name")); + Assert.True(Handler.Headers.Contains("other")); + Assert.Equal(new[] { "value" }, Handler.Headers.GetValues("name")); + } + + [Fact] + public async Task TwoHeadersInState_BothHeadersInOptionsWithSameName_AddsFirst() + { + // Arrange + State.Headers.Add("name", "value"); + State.Headers.Add("other", "override"); + Configuration.Headers.Add("name"); + Configuration.Headers.Add("other", "name"); + + // Act + await Client.SendAsync(new HttpRequestMessage()); + + // Assert + Assert.True(Handler.Headers.Contains("name")); + Assert.Equal(new[] { "value" }, Handler.Headers.GetValues("name")); + } + private class SimpleHandler : DelegatingHandler { public HttpHeaders Headers { get; private set; }