diff --git a/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersDefaults.cs b/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersDefaults.cs new file mode 100644 index 0000000000..959dc295d3 --- /dev/null +++ b/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersDefaults.cs @@ -0,0 +1,42 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.HttpOverrides +{ + /// + /// Default values related to middleware + /// + /// + public static class ForwardedHeadersDefaults + { + /// + /// X-Forwarded-For + /// + public static string XForwardedForHeaderName { get; } = "X-Forwarded-For"; + + /// + /// X-Forwarded-Host + /// + public static string XForwardedHostHeaderName { get; } = "X-Forwarded-Host"; + + /// + /// X-Forwarded-Proto + /// + public static string XForwardedProtoHeaderName { get; } = "X-Forwarded-Proto"; + + /// + /// X-Original-For + /// + public static string XOriginalForHeaderName { get; } = "X-Original-For"; + + /// + /// X-Original-Host + /// + public static string XOriginalHostHeaderName { get; } = "X-Original-Host"; + + /// + /// X-Original-Proto + /// + public static string XOriginalProtoHeaderName { get; } = "X-Original-Proto"; + } +} diff --git a/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersMiddleware.cs b/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersMiddleware.cs index 09cdfce9ca..2882dfdf8d 100644 --- a/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersMiddleware.cs +++ b/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersMiddleware.cs @@ -17,13 +17,6 @@ namespace Microsoft.AspNetCore.HttpOverrides { public class ForwardedHeadersMiddleware { - private const string XForwardedForHeaderName = "X-Forwarded-For"; - private const string XForwardedHostHeaderName = "X-Forwarded-Host"; - private const string XForwardedProtoHeaderName = "X-Forwarded-Proto"; - private const string XOriginalForName = "X-Original-For"; - private const string XOriginalHostName = "X-Original-Host"; - private const string XOriginalProtoName = "X-Original-Proto"; - private readonly ForwardedHeadersOptions _options; private readonly RequestDelegate _next; private readonly ILogger _logger; @@ -43,11 +36,27 @@ namespace Microsoft.AspNetCore.HttpOverrides throw new ArgumentNullException(nameof(options)); } + // Make sure required options is not null or whitespace + EnsureOptionNotNullorWhitespace(options.Value.ForwardedForHeaderName, nameof(options.Value.ForwardedForHeaderName)); + EnsureOptionNotNullorWhitespace(options.Value.ForwardedHostHeaderName, nameof(options.Value.ForwardedHostHeaderName)); + EnsureOptionNotNullorWhitespace(options.Value.ForwardedProtoHeaderName, nameof(options.Value.ForwardedProtoHeaderName)); + EnsureOptionNotNullorWhitespace(options.Value.OriginalForHeaderName, nameof(options.Value.OriginalForHeaderName)); + EnsureOptionNotNullorWhitespace(options.Value.OriginalHostHeaderName, nameof(options.Value.OriginalHostHeaderName)); + EnsureOptionNotNullorWhitespace(options.Value.OriginalProtoHeaderName, nameof(options.Value.OriginalProtoHeaderName)); + _options = options.Value; _logger = loggerFactory.CreateLogger(); _next = next; } + private static void EnsureOptionNotNullorWhitespace(string value, string propertyName) + { + if (string.IsNullOrWhiteSpace(value)) + { + throw new ArgumentException($"options.{propertyName} is required", "options"); + } + } + public Task Invoke(HttpContext context) { ApplyForwarders(context); @@ -64,14 +73,14 @@ namespace Microsoft.AspNetCore.HttpOverrides if ((_options.ForwardedHeaders & ForwardedHeaders.XForwardedFor) == ForwardedHeaders.XForwardedFor) { checkFor = true; - forwardedFor = context.Request.Headers.GetCommaSeparatedValues(XForwardedForHeaderName); + forwardedFor = context.Request.Headers.GetCommaSeparatedValues(_options.ForwardedForHeaderName); entryCount = Math.Max(forwardedFor.Length, entryCount); } if ((_options.ForwardedHeaders & ForwardedHeaders.XForwardedProto) == ForwardedHeaders.XForwardedProto) { checkProto = true; - forwardedProto = context.Request.Headers.GetCommaSeparatedValues(XForwardedProtoHeaderName); + forwardedProto = context.Request.Headers.GetCommaSeparatedValues(_options.ForwardedProtoHeaderName); if (_options.RequireHeaderSymmetry && checkFor && forwardedFor.Length != forwardedProto.Length) { _logger.LogWarning(1, "Parameter count mismatch between X-Forwarded-For and X-Forwarded-Proto."); @@ -83,7 +92,7 @@ namespace Microsoft.AspNetCore.HttpOverrides if ((_options.ForwardedHeaders & ForwardedHeaders.XForwardedHost) == ForwardedHeaders.XForwardedHost) { checkHost = true; - forwardedHost = context.Request.Headers.GetCommaSeparatedValues(XForwardedHostHeaderName); + forwardedHost = context.Request.Headers.GetCommaSeparatedValues(_options.ForwardedHostHeaderName); if (_options.RequireHeaderSymmetry && ((checkFor && forwardedFor.Length != forwardedHost.Length) || (checkProto && forwardedProto.Length != forwardedHost.Length))) @@ -198,17 +207,17 @@ namespace Microsoft.AspNetCore.HttpOverrides if (connection.RemoteIpAddress != null) { // Save the original - request.Headers[XOriginalForName] = new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort).ToString(); + request.Headers[_options.OriginalForHeaderName] = new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort).ToString(); } if (forwardedFor.Length > entriesConsumed) { // Truncate the consumed header values - request.Headers[XForwardedForHeaderName] = forwardedFor.Take(forwardedFor.Length - entriesConsumed).ToArray(); + request.Headers[_options.ForwardedForHeaderName] = forwardedFor.Take(forwardedFor.Length - entriesConsumed).ToArray(); } else { // All values were consumed - request.Headers.Remove(XForwardedForHeaderName); + request.Headers.Remove(_options.ForwardedForHeaderName); } connection.RemoteIpAddress = currentValues.RemoteIpAndPort.Address; connection.RemotePort = currentValues.RemoteIpAndPort.Port; @@ -217,16 +226,16 @@ namespace Microsoft.AspNetCore.HttpOverrides if (checkProto && currentValues.Scheme != null) { // Save the original - request.Headers[XOriginalProtoName] = request.Scheme; + request.Headers[_options.OriginalProtoHeaderName] = request.Scheme; if (forwardedProto.Length > entriesConsumed) { // Truncate the consumed header values - request.Headers[XForwardedProtoHeaderName] = forwardedProto.Take(forwardedProto.Length - entriesConsumed).ToArray(); + request.Headers[_options.ForwardedProtoHeaderName] = forwardedProto.Take(forwardedProto.Length - entriesConsumed).ToArray(); } else { // All values were consumed - request.Headers.Remove(XForwardedProtoHeaderName); + request.Headers.Remove(_options.ForwardedProtoHeaderName); } request.Scheme = currentValues.Scheme; } @@ -234,16 +243,16 @@ namespace Microsoft.AspNetCore.HttpOverrides if (checkHost && currentValues.Host != null) { // Save the original - request.Headers[XOriginalHostName] = request.Host.ToString(); + request.Headers[_options.OriginalHostHeaderName] = request.Host.ToString(); if (forwardedHost.Length > entriesConsumed) { // Truncate the consumed header values - request.Headers[XForwardedHostHeaderName] = forwardedHost.Take(forwardedHost.Length - entriesConsumed).ToArray(); + request.Headers[_options.ForwardedHostHeaderName] = forwardedHost.Take(forwardedHost.Length - entriesConsumed).ToArray(); } else { // All values were consumed - request.Headers.Remove(XForwardedHostHeaderName); + request.Headers.Remove(_options.ForwardedHostHeaderName); } request.Host = HostString.FromUriComponent(currentValues.Host); } diff --git a/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersOptions.cs b/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersOptions.cs index 602dd2ec1d..556c8acc3d 100644 --- a/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersOptions.cs +++ b/src/Microsoft.AspNetCore.HttpOverrides/ForwardedHeadersOptions.cs @@ -9,6 +9,42 @@ namespace Microsoft.AspNetCore.Builder { public class ForwardedHeadersOptions { + /// + /// Use this header instead of + /// + /// + public string ForwardedForHeaderName { get; set; } = ForwardedHeadersDefaults.XForwardedForHeaderName; + + /// + /// Use this header instead of + /// + /// + public string ForwardedHostHeaderName { get; set; } = ForwardedHeadersDefaults.XForwardedHostHeaderName; + + /// + /// Use this header instead of + /// + /// + public string ForwardedProtoHeaderName { get; set; } = ForwardedHeadersDefaults.XForwardedProtoHeaderName; + + /// + /// Use this header instead of + /// + /// + public string OriginalForHeaderName { get; set; } = ForwardedHeadersDefaults.XOriginalForHeaderName; + + /// + /// Use this header instead of + /// + /// + public string OriginalHostHeaderName { get; set; } = ForwardedHeadersDefaults.XOriginalHostHeaderName; + + /// + /// Use this header instead of + /// + /// + public string OriginalProtoHeaderName { get; set; } = ForwardedHeadersDefaults.XOriginalProtoHeaderName; + /// /// Identifies which forwarders should be processed. ///