// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Linq;
using System.Net;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.HttpOverrides.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.HttpOverrides
{
    /// <summary>
    /// Middleware that applies forwarded-for, forwarded-proto and forwarded-host headers (header names
    /// configured through <see cref="ForwardedHeadersOptions"/>) to the current request's connection
    /// remote endpoint, scheme and host before invoking the next middleware.
    /// </summary>
    public class ForwardedHeadersMiddleware
    {
        // ASCII lookup tables; any character whose code is at or beyond a table's length is rejected.
        // 127 covers up to '~' (126) for hosts, 123 covers up to 'z' (122) for schemes.
        private static readonly bool[] HostCharValidity = new bool[127];
        private static readonly bool[] SchemeCharValidity = new bool[123];

        private readonly ForwardedHeadersOptions _options;
        private readonly RequestDelegate _next;
        private readonly ILogger _logger;

        static ForwardedHeadersMiddleware()
        {
            // RFC 3986 scheme = ALPHA *( ALPHA / DIGIT / "+" / "-" / "." )
            SchemeCharValidity['+'] = true;
            SchemeCharValidity['-'] = true;
            SchemeCharValidity['.'] = true;

            // Host Matches Http.Sys and Kestrel
            // Host Matches RFC 3986 except "*" / "+" / "," / ";" / "=" and "%" HEXDIG HEXDIG which are not allowed by Http.Sys
            HostCharValidity['!'] = true;
            HostCharValidity['$'] = true;
            HostCharValidity['&'] = true;
            HostCharValidity['\''] = true;
            HostCharValidity['('] = true;
            HostCharValidity[')'] = true;
            HostCharValidity['-'] = true;
            HostCharValidity['.'] = true;
            HostCharValidity['_'] = true;
            HostCharValidity['~'] = true;

            for (var ch = '0'; ch <= '9'; ch++)
            {
                SchemeCharValidity[ch] = true;
                HostCharValidity[ch] = true;
            }

            for (var ch = 'A'; ch <= 'Z'; ch++)
            {
                SchemeCharValidity[ch] = true;
                HostCharValidity[ch] = true;
            }

            for (var ch = 'a'; ch <= 'z'; ch++)
            {
                SchemeCharValidity[ch] = true;
                HostCharValidity[ch] = true;
            }
        }

        /// <summary>
        /// Creates a new <see cref="ForwardedHeadersMiddleware"/>.
        /// </summary>
        /// <param name="next">The next delegate in the request pipeline.</param>
        /// <param name="loggerFactory">Factory used to create this middleware's logger.</param>
        /// <param name="options">The forwarded-headers configuration.</param>
        /// <exception cref="ArgumentNullException"><paramref name="next"/>, <paramref name="loggerFactory"/> or <paramref name="options"/> is null.</exception>
        /// <exception cref="ArgumentException">One of the configured header names is null, empty or whitespace.</exception>
        public ForwardedHeadersMiddleware(RequestDelegate next, ILoggerFactory loggerFactory, IOptions<ForwardedHeadersOptions> options)
        {
            if (next == null)
            {
                throw new ArgumentNullException(nameof(next));
            }
            if (loggerFactory == null)
            {
                throw new ArgumentNullException(nameof(loggerFactory));
            }
            if (options == null)
            {
                throw new ArgumentNullException(nameof(options));
            }

            // Make sure required options is not null or whitespace
            EnsureOptionNotNullOrWhitespace(options.Value.ForwardedForHeaderName, nameof(options.Value.ForwardedForHeaderName));
            EnsureOptionNotNullOrWhitespace(options.Value.ForwardedHostHeaderName, nameof(options.Value.ForwardedHostHeaderName));
            EnsureOptionNotNullOrWhitespace(options.Value.ForwardedProtoHeaderName, nameof(options.Value.ForwardedProtoHeaderName));
            EnsureOptionNotNullOrWhitespace(options.Value.OriginalForHeaderName, nameof(options.Value.OriginalForHeaderName));
            EnsureOptionNotNullOrWhitespace(options.Value.OriginalHostHeaderName, nameof(options.Value.OriginalHostHeaderName));
            EnsureOptionNotNullOrWhitespace(options.Value.OriginalProtoHeaderName, nameof(options.Value.OriginalProtoHeaderName));

            _options = options.Value;
            _logger = loggerFactory.CreateLogger<ForwardedHeadersMiddleware>();
            _next = next;
        }

        // Throws ArgumentException (param name "options") when a required header-name option is blank.
        private static void EnsureOptionNotNullOrWhitespace(string value, string propertyName)
        {
            if (string.IsNullOrWhiteSpace(value))
            {
                throw new ArgumentException($"options.{propertyName} is required", "options");
            }
        }

        /// <summary>
        /// Applies the forwarded headers to <paramref name="context"/>, then invokes the next middleware.
        /// </summary>
        /// <param name="context">The current request context.</param>
        /// <returns>The task returned by the next middleware.</returns>
        public Task Invoke(HttpContext context)
        {
            ApplyForwarders(context);
            return _next(context);
        }

        /// <summary>
        /// Reads the enabled forwarded headers and walks their values right to left (last appended first).
        /// Accepted values are applied to the connection remote endpoint, request scheme and request host.
        /// The consumed values are trimmed from the forwarded headers, and the replaced originals are saved
        /// in the configured Original*HeaderName headers.
        /// </summary>
        /// <param name="context">The current request context; modified in place.</param>
        public void ApplyForwarders(HttpContext context)
        {
            // Gather expected headers. Enabled headers must have the same number of entries.
            string[] forwardedFor = null, forwardedProto = null, forwardedHost = null;
            bool checkFor = false, checkProto = false, checkHost = false;
            int entryCount = 0;

            if ((_options.ForwardedHeaders & ForwardedHeaders.XForwardedFor) == ForwardedHeaders.XForwardedFor)
            {
                checkFor = true;
                forwardedFor = context.Request.Headers.GetCommaSeparatedValues(_options.ForwardedForHeaderName);
                entryCount = Math.Max(forwardedFor.Length, entryCount);
            }

            if ((_options.ForwardedHeaders & ForwardedHeaders.XForwardedProto) == ForwardedHeaders.XForwardedProto)
            {
                checkProto = true;
                forwardedProto = context.Request.Headers.GetCommaSeparatedValues(_options.ForwardedProtoHeaderName);
                if (_options.RequireHeaderSymmetry && checkFor && forwardedFor.Length != forwardedProto.Length)
                {
                    _logger.LogWarning(1, "Parameter count mismatch between X-Forwarded-For and X-Forwarded-Proto.");
                    return;
                }
                entryCount = Math.Max(forwardedProto.Length, entryCount);
            }

            if ((_options.ForwardedHeaders & ForwardedHeaders.XForwardedHost) == ForwardedHeaders.XForwardedHost)
            {
                checkHost = true;
                forwardedHost = context.Request.Headers.GetCommaSeparatedValues(_options.ForwardedHostHeaderName);
                if (_options.RequireHeaderSymmetry
                    && ((checkFor && forwardedFor.Length != forwardedHost.Length)
                        || (checkProto && forwardedProto.Length != forwardedHost.Length)))
                {
                    _logger.LogWarning(1, "Parameter count mismatch between X-Forwarded-Host and X-Forwarded-For or X-Forwarded-Proto.");
                    return;
                }
                entryCount = Math.Max(forwardedHost.Length, entryCount);
            }

            // Apply ForwardLimit, if any
            if (_options.ForwardLimit.HasValue && entryCount > _options.ForwardLimit.Value)
            {
                entryCount = _options.ForwardLimit.Value;
            }

            // Group the data together.
            var sets = new SetOfForwarders[entryCount];
            for (int i = 0; i < sets.Length; i++)
            {
                // They get processed in reverse order, right to left.
                var set = new SetOfForwarders();
                if (checkFor && i < forwardedFor.Length)
                {
                    set.IpAndPortText = forwardedFor[forwardedFor.Length - i - 1];
                }
                if (checkProto && i < forwardedProto.Length)
                {
                    set.Scheme = forwardedProto[forwardedProto.Length - i - 1];
                }
                if (checkHost && i < forwardedHost.Length)
                {
                    set.Host = forwardedHost[forwardedHost.Length - i - 1];
                }
                sets[i] = set;
            }

            // Gather initial values
            var connection = context.Connection;
            var request = context.Request;
            var currentValues = new SetOfForwarders()
            {
                RemoteIpAndPort = connection.RemoteIpAddress != null ? new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort) : null,
                // Host and Scheme initial values are never inspected, no need to set them here.
            };

            var checkKnownIps = _options.KnownNetworks.Count > 0 || _options.KnownProxies.Count > 0;
            bool applyChanges = false;
            int entriesConsumed = 0;

            for ( ; entriesConsumed < sets.Length; entriesConsumed++)
            {
                var set = sets[entriesConsumed];
                if (checkFor)
                {
                    // For the first instance, allow remoteIp to be null for servers that don't support it natively.
                    if (currentValues.RemoteIpAndPort != null && checkKnownIps && !CheckKnownAddress(currentValues.RemoteIpAndPort.Address))
                    {
                        // Stop at the first unknown remote IP, but still apply changes processed so far.
                        // Message template (not interpolation): no formatting cost when Debug is disabled.
                        _logger.LogDebug(1, "Unknown proxy: {RemoteIpAndPort}", currentValues.RemoteIpAndPort);
                        break;
                    }

                    IPEndPoint parsedEndPoint;
                    if (IPEndPointParser.TryParse(set.IpAndPortText, out parsedEndPoint))
                    {
                        applyChanges = true;
                        set.RemoteIpAndPort = parsedEndPoint;
                        currentValues.IpAndPortText = set.IpAndPortText;
                        currentValues.RemoteIpAndPort = set.RemoteIpAndPort;
                    }
                    else if (!string.IsNullOrEmpty(set.IpAndPortText))
                    {
                        // Stop at the first unparsable IP, but still apply changes processed so far.
                        // The raw header text is passed as a template argument rather than spliced into the message.
                        _logger.LogDebug(1, "Unparsable IP: {IpAndPortText}", set.IpAndPortText);
                        break;
                    }
                    else if (_options.RequireHeaderSymmetry)
                    {
                        _logger.LogWarning(2, "Missing forwarded IPAddress.");
                        return;
                    }
                }

                if (checkProto)
                {
                    if (!string.IsNullOrEmpty(set.Scheme) && TryValidateScheme(set.Scheme))
                    {
                        applyChanges = true;
                        currentValues.Scheme = set.Scheme;
                    }
                    else if (_options.RequireHeaderSymmetry)
                    {
                        _logger.LogWarning(3, $"Forwarded scheme is not present, this is required by {nameof(_options.RequireHeaderSymmetry)}");
                        return;
                    }
                }

                if (checkHost)
                {
                    if (!string.IsNullOrEmpty(set.Host) && TryValidateHost(set.Host))
                    {
                        applyChanges = true;
                        currentValues.Host = set.Host;
                    }
                    else if (_options.RequireHeaderSymmetry)
                    {
                        // This branch concerns the host header; the message previously named x-forwarded-proto.
                        _logger.LogWarning(4, $"Incorrect number of x-forwarded-host header values, see {nameof(_options.RequireHeaderSymmetry)}.");
                        return;
                    }
                }
            }

            if (applyChanges)
            {
                if (checkFor && currentValues.RemoteIpAndPort != null)
                {
                    if (connection.RemoteIpAddress != null)
                    {
                        // Save the original
                        request.Headers[_options.OriginalForHeaderName] = new IPEndPoint(connection.RemoteIpAddress, connection.RemotePort).ToString();
                    }
                    if (forwardedFor.Length > entriesConsumed)
                    {
                        // Truncate the consumed header values
                        request.Headers[_options.ForwardedForHeaderName] = forwardedFor.Take(forwardedFor.Length - entriesConsumed).ToArray();
                    }
                    else
                    {
                        // All values were consumed
                        request.Headers.Remove(_options.ForwardedForHeaderName);
                    }
                    connection.RemoteIpAddress = currentValues.RemoteIpAndPort.Address;
                    connection.RemotePort = currentValues.RemoteIpAndPort.Port;
                }

                if (checkProto && currentValues.Scheme != null)
                {
                    // Save the original
                    request.Headers[_options.OriginalProtoHeaderName] = request.Scheme;
                    if (forwardedProto.Length > entriesConsumed)
                    {
                        // Truncate the consumed header values
                        request.Headers[_options.ForwardedProtoHeaderName] = forwardedProto.Take(forwardedProto.Length - entriesConsumed).ToArray();
                    }
                    else
                    {
                        // All values were consumed
                        request.Headers.Remove(_options.ForwardedProtoHeaderName);
                    }
                    request.Scheme = currentValues.Scheme;
                }

                if (checkHost && currentValues.Host != null)
                {
                    // Save the original
                    request.Headers[_options.OriginalHostHeaderName] = request.Host.ToString();
                    if (forwardedHost.Length > entriesConsumed)
                    {
                        // Truncate the consumed header values
                        request.Headers[_options.ForwardedHostHeaderName] = forwardedHost.Take(forwardedHost.Length - entriesConsumed).ToArray();
                    }
                    else
                    {
                        // All values were consumed
                        request.Headers.Remove(_options.ForwardedHostHeaderName);
                    }
                    request.Host = HostString.FromUriComponent(currentValues.Host);
                }
            }
        }

        // True when the address is listed in KnownProxies or falls inside any of the KnownNetworks.
        private bool CheckKnownAddress(IPAddress address)
        {
            if (_options.KnownProxies.Contains(address))
            {
                return true;
            }
            foreach (var network in _options.KnownNetworks)
            {
                if (network.Contains(address))
                {
                    return true;
                }
            }
            return false;
        }

        // One right-to-left position across the forwarded headers (raw text plus the parsed endpoint).
        private struct SetOfForwarders
        {
            public string IpAndPortText;
            public IPEndPoint RemoteIpAndPort;
            public string Host;
            public string Scheme;
        }

        // Empty was checked for by the caller
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool TryValidateScheme(string scheme)
        {
            for (var i = 0; i < scheme.Length; i++)
            {
                if (!IsValidSchemeChar(scheme[i]))
                {
                    return false;
                }
            }
            return true;
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool IsValidSchemeChar(char ch)
        {
            return ch < SchemeCharValidity.Length && SchemeCharValidity[ch];
        }

        // Empty was checked for by the caller.
        // Accepts "[ipv6]", "[ipv6]:port", "name" or "name:port", where name uses HostCharValidity chars.
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool TryValidateHost(string host)
        {
            if (host[0] == '[')
            {
                return TryValidateIPv6Host(host);
            }

            if (host[0] == ':')
            {
                // Only a port
                return false;
            }

            var i = 0;
            for (; i < host.Length; i++)
            {
                if (!IsValidHostChar(host[i]))
                {
                    break;
                }
            }
            // Whatever follows the host name must be empty or a valid ":port".
            return TryValidateHostPort(host, i);
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool IsValidHostChar(char ch)
        {
            return ch < HostCharValidity.Length && HostCharValidity[ch];
        }

        // The lead '[' was already checked
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool TryValidateIPv6Host(string hostText)
        {
            for (var i = 1; i < hostText.Length; i++)
            {
                var ch = hostText[i];
                if (ch == ']')
                {
                    // [::1] is the shortest valid IPv6 host
                    if (i < 4)
                    {
                        return false;
                    }
                    return TryValidateHostPort(hostText, i + 1);
                }

                if (!IsHex(ch) && ch != ':' && ch != '.')
                {
                    return false;
                }
            }

            // Must contain a ']'
            return false;
        }

        // Validates the suffix starting at offset: either nothing, or ':' followed by one or more digits.
        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool TryValidateHostPort(string hostText, int offset)
        {
            if (offset == hostText.Length)
            {
                // No port
                return true;
            }

            if (hostText[offset] != ':' || hostText.Length == offset + 1)
            {
                // Must have at least one number after the colon if present.
                return false;
            }

            for (var i = offset + 1; i < hostText.Length; i++)
            {
                if (!IsNumeric(hostText[i]))
                {
                    return false;
                }
            }

            return true;
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool IsNumeric(char ch)
        {
            return '0' <= ch && ch <= '9';
        }

        [MethodImpl(MethodImplOptions.AggressiveInlining)]
        private static bool IsHex(char ch)
        {
            return IsNumeric(ch)
                || ('a' <= ch && ch <= 'f')
                || ('A' <= ch && ch <= 'F');
        }
    }
}