diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs index ff47b5ba48..524e29856f 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMessageHandler.cs @@ -53,11 +53,31 @@ namespace Microsoft.AspNetCore.HeaderPropagation { var outputName = string.IsNullOrEmpty(entry?.OutboundHeaderName) ? headerName : entry.OutboundHeaderName; - if (!request.Headers.Contains(outputName) && - _values.Headers.TryGetValue(headerName, out var values) && - !StringValues.IsNullOrEmpty(values)) + var hasContent = request.Content != null; + + if (!request.Headers.TryGetValues(outputName, out var _) && + !(hasContent && request.Content.Headers.TryGetValues(outputName, out var _))) { - request.Headers.TryAddWithoutValidation(outputName, (string[])values); + if (_values.Headers.TryGetValue(headerName, out var stringValues) && + !StringValues.IsNullOrEmpty(stringValues)) + { + if (stringValues.Count == 1) + { + var value = (string)stringValues; + if (!request.Headers.TryAddWithoutValidation(outputName, value) && hasContent) + { + request.Content.Headers.TryAddWithoutValidation(outputName, value); + } + } + else + { + var values = (string[])stringValues; + if (!request.Headers.TryAddWithoutValidation(outputName, values) && hasContent) + { + request.Content.Headers.TryAddWithoutValidation(outputName, values); + } + } + } } } diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs index 320cc108f1..a824fb13d1 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationMessageHandlerTest.cs @@ -52,6 +52,63 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests Assert.Equal(new[] { "test" }, Handler.Headers.GetValues("out")); } + [Fact] + public async Task HeaderInState_WithMultipleValues_AddAllValues() + { + // Arrange + Configuration.Headers.Add("in", new HeaderPropagationEntry { OutboundHeaderName = "out" }); + State.Headers.Add("in", new[] { "one", "two" }); + + // Act + await Client.SendAsync(new HttpRequestMessage()); + + // Assert + Assert.True(Handler.Headers.Contains("out")); + Assert.Equal(new[] { "one", "two" }, Handler.Headers.GetValues("out")); + } + + [Fact] + public async Task HeaderInState_RequestWithContent_ContentHeaderPresent_DoesNotAddIt() + { + Configuration.Headers.Add("in", new HeaderPropagationEntry() { OutboundHeaderName = "Content-Type" }); + State.Headers.Add("in", "test"); + + // Act + await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); + + // Assert + Assert.True(Handler.Content.Headers.Contains("Content-Type")); + Assert.Equal(new[] { "text/plain; charset=utf-8" }, Handler.Content.Headers.GetValues("Content-Type")); + } + + [Fact] + public async Task HeaderInState_RequestWithContent_ContentHeaderNotPresent_AddValue() + { + Configuration.Headers.Add("in", new HeaderPropagationEntry() { OutboundHeaderName = "Content-Language" }); + State.Headers.Add("in", "test"); + + // Act + await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); + + // Assert + Assert.True(Handler.Content.Headers.Contains("Content-Language")); + Assert.Equal(new[] { "test" }, Handler.Content.Headers.GetValues("Content-Language")); + } + + [Fact] + public async Task HeaderInState_WithMultipleValues_RequestWithContent_ContentHeaderNotPresent_AddAllValues() + { + Configuration.Headers.Add("in", new HeaderPropagationEntry() { OutboundHeaderName = "Content-Language" }); + State.Headers.Add("in", new[] { "one", "two" }); + + // Act + await Client.SendAsync(new HttpRequestMessage() { Content = new StringContent("test") }); + + // Assert + Assert.True(Handler.Content.Headers.Contains("Content-Language")); + Assert.Equal(new[] { "one", "two" }, Handler.Content.Headers.GetValues("Content-Language")); + } + [Fact] public async Task HeaderInState_NoOutputName_UseInputName() { @@ -168,11 +225,13 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests private class SimpleHandler : DelegatingHandler { public HttpHeaders Headers { get; private set; } + public HttpContent Content { get; private set; } protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { Headers = request.Headers; + Content = request.Content; return Task.FromResult(new HttpResponseMessage()); } }