Add HostPolicyMatcher (#6214)

This commit is contained in:
James Newton-King 2019-01-11 10:46:09 +13:00 committed by GitHub
parent 55ec35bb80
commit 90511e6039
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 1221 additions and 2 deletions

View File

@ -0,0 +1,43 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Builder
{
/// <summary>
/// Extension methods for adding routing metadata to endpoint instances using <see cref="IEndpointConventionBuilder"/>.
/// </summary>
public static class RoutingEndpointConventionBuilderExtensions
{
/// <summary>
/// Requires that endpoints match one of the specified hosts during routing.
/// </summary>
/// <param name="builder">The <see cref="IEndpointConventionBuilder"/> to add the metadata to.</param>
/// <param name="hosts">
/// The hosts used during routing.
/// Hosts should be Unicode rather than punycode, and may have a port.
/// An empty collection means any host will be accepted.
/// </param>
/// <returns>A reference to this instance after the operation has completed.</returns>
public static IEndpointConventionBuilder RequireHost(this IEndpointConventionBuilder builder, params string[] hosts)
{
if (builder == null)
{
throw new ArgumentNullException(nameof(builder));
}
if (hosts == null)
{
throw new ArgumentNullException(nameof(hosts));
}
builder.Add(endpointBuilder =>
{
endpointBuilder.Metadata.Add(new HostAttribute(hosts));
});
return builder;
}
}
}

View File

@ -87,6 +87,7 @@ namespace Microsoft.Extensions.DependencyInjection
//
services.TryAddSingleton<EndpointSelector, DefaultEndpointSelector>();
services.TryAddEnumerable(ServiceDescriptor.Singleton<MatcherPolicy, HttpMethodMatcherPolicy>());
services.TryAddEnumerable(ServiceDescriptor.Singleton<MatcherPolicy, HostMatcherPolicy>());
//
// Misc infrastructure

View File

@ -0,0 +1,67 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
namespace Microsoft.AspNetCore.Routing
{
/// <summary>
/// Attribute for providing host metdata that is used during routing.
/// </summary>
[DebuggerDisplay("{DebuggerToString(),nq}")]
[AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = false, Inherited = false)]
public sealed class HostAttribute : Attribute, IHostMetadata
{
/// <summary>
/// Initializes a new instance of the <see cref="HostAttribute" /> class.
/// </summary>
/// <param name="host">
/// The host used during routing.
/// Host should be Unicode rather than punycode, and may have a port.
/// </param>
public HostAttribute(string host) : this(new[] { host })
{
if (host == null)
{
throw new ArgumentNullException(nameof(host));
}
}
/// <summary>
/// Initializes a new instance of the <see cref="HostAttribute" /> class.
/// </summary>
/// <param name="hosts">
/// The hosts used during routing.
/// Hosts should be Unicode rather than punycode, and may have a port.
/// An empty collection means any host will be accepted.
/// </param>
public HostAttribute(params string[] hosts)
{
if (hosts == null)
{
throw new ArgumentNullException(nameof(hosts));
}
Hosts = hosts.ToArray();
}
/// <summary>
/// Returns a read-only collection of hosts used during routing.
/// Hosts will be Unicode rather than punycode, and may have a port.
/// An empty collection means any host will be accepted.
/// </summary>
public IReadOnlyList<string> Hosts { get; }
private string DebuggerToString()
{
var hostsDisplay = (Hosts.Count == 0)
? "*:*"
: string.Join(",", Hosts.Select(h => h.Contains(':') ? h : h + ":*"));
return $"Hosts: {hostsDisplay}";
}
}
}

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;

View File

@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Routing
{
/// <summary>
/// Represents host metadata used during routing.
/// </summary>
public interface IHostMetadata
{
/// <summary>
/// Returns a read-only collection of hosts used during routing.
/// Hosts will be Unicode rather than punycode, and may have a port.
/// An empty collection means any host will be accepted.
/// </summary>
IReadOnlyList<string> Hosts { get; }
}
}

View File

@ -0,0 +1,366 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Routing.Matching
{
/// <summary>
/// An <see cref="MatcherPolicy"/> that implements filtering and selection by
/// the host header of a request.
/// </summary>
public sealed class HostMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy
{
// Run after HTTP methods, but before 'default'.
public override int Order { get; } = -100;
public IComparer<Endpoint> Comparer { get; } = new HostMetadataEndpointComparer();
public bool AppliesToEndpoints(IReadOnlyList<Endpoint> endpoints)
{
if (endpoints == null)
{
throw new ArgumentNullException(nameof(endpoints));
}
return endpoints.Any(e =>
{
var hosts = e.Metadata.GetMetadata<IHostMetadata>()?.Hosts;
if (hosts == null || hosts.Count == 0)
{
return false;
}
foreach (var host in hosts)
{
// Don't run policy on endpoints that match everything
var key = CreateEdgeKey(host);
if (!key.MatchesAll)
{
return true;
}
}
return false;
});
}
private static EdgeKey CreateEdgeKey(string host)
{
if (host == null)
{
return EdgeKey.WildcardEdgeKey;
}
var hostParts = host.Split(':');
if (hostParts.Length == 1)
{
if (!string.IsNullOrEmpty(hostParts[0]))
{
return new EdgeKey(hostParts[0], null);
}
}
if (hostParts.Length == 2)
{
if (!string.IsNullOrEmpty(hostParts[0]))
{
if (int.TryParse(hostParts[1], out var port))
{
return new EdgeKey(hostParts[0], port);
}
else if (string.Equals(hostParts[1], "*", StringComparison.Ordinal))
{
return new EdgeKey(hostParts[0], null);
}
}
}
throw new InvalidOperationException($"Could not parse host: {host}");
}
public IReadOnlyList<PolicyNodeEdge> GetEdges(IReadOnlyList<Endpoint> endpoints)
{
if (endpoints == null)
{
throw new ArgumentNullException(nameof(endpoints));
}
// The algorithm here is designed to be preserve the order of the endpoints
// while also being relatively simple. Preserving order is important.
// First, build a dictionary of all of the hosts that are included
// at this node.
//
// For now we're just building up the set of keys. We don't add any endpoints
// to lists now because we don't want ordering problems.
var edges = new Dictionary<EdgeKey, List<Endpoint>>();
for (var i = 0; i < endpoints.Count; i++)
{
var endpoint = endpoints[i];
var hosts = endpoint.Metadata.GetMetadata<IHostMetadata>()?.Hosts.Select(h => CreateEdgeKey(h)).ToArray();
if (hosts == null || hosts.Length == 0)
{
hosts = new[] { EdgeKey.WildcardEdgeKey };
}
for (var j = 0; j < hosts.Length; j++)
{
var host = hosts[j];
if (!edges.ContainsKey(host))
{
edges.Add(host, new List<Endpoint>());
}
}
}
// Now in a second loop, add endpoints to these lists. We've enumerated all of
// the states, so we want to see which states this endpoint matches.
for (var i = 0; i < endpoints.Count; i++)
{
var endpoint = endpoints[i];
var endpointKeys = endpoint.Metadata.GetMetadata<IHostMetadata>()?.Hosts.Select(h => CreateEdgeKey(h)).ToArray() ?? Array.Empty<EdgeKey>();
if (endpointKeys.Length == 0)
{
// OK this means that this endpoint matches *all* hosts.
// So, loop and add it to all states.
foreach (var kvp in edges)
{
kvp.Value.Add(endpoint);
}
}
else
{
// OK this endpoint matches specific hosts
foreach (var kvp in edges)
{
// The edgeKey maps to a possible request header value
var edgeKey = kvp.Key;
for (var j = 0; j < endpointKeys.Length; j++)
{
var endpointKey = endpointKeys[j];
if (edgeKey.Equals(endpointKey))
{
kvp.Value.Add(endpoint);
break;
}
else if (edgeKey.HasHostWildcard && endpointKey.HasHostWildcard &&
edgeKey.Port == endpointKey.Port && edgeKey.MatchHost(endpointKey.Host))
{
kvp.Value.Add(endpoint);
break;
}
}
}
}
}
return edges
.Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value))
.ToArray();
}
public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList<PolicyJumpTableEdge> edges)
{
if (edges == null)
{
throw new ArgumentNullException(nameof(edges));
}
// Since our 'edges' can have wildcards, we do a sort based on how wildcard-ey they
// are then then execute them in linear order.
var ordered = edges
.Select(e => (host: (EdgeKey)e.State, destination: e.Destination))
.OrderBy(e => GetScore(e.host))
.ToArray();
return new HostPolicyJumpTable(exitDestination, ordered);
}
private int GetScore(in EdgeKey key)
{
// Higher score == lower priority.
if (key.MatchesHost && !key.HasHostWildcard && key.MatchesPort)
{
return 1; // Has host AND port, e.g. www.consoto.com:8080
}
else if (key.MatchesHost && !key.HasHostWildcard)
{
return 2; // Has host, e.g. www.consoto.com
}
else if (key.MatchesHost && key.MatchesPort)
{
return 3; // Has wildcard host AND port, e.g. *.consoto.com:8080
}
else if (key.MatchesHost)
{
return 4; // Has wildcard host, e.g. *.consoto.com
}
else if (key.MatchesPort)
{
return 5; // Has port, e.g. *:8080
}
else
{
return 6; // Has neither, e.g. *:* (or no metadata)
}
}
private class HostMetadataEndpointComparer : EndpointMetadataComparer<IHostMetadata>
{
protected override int CompareMetadata(IHostMetadata x, IHostMetadata y)
{
// Ignore the metadata if it has an empty list of hosts.
return base.CompareMetadata(
x?.Hosts.Count > 0 ? x : null,
y?.Hosts.Count > 0 ? y : null);
}
}
private class HostPolicyJumpTable : PolicyJumpTable
{
private (EdgeKey host, int destination)[] _destinations;
private int _exitDestination;
public HostPolicyJumpTable(int exitDestination, (EdgeKey host, int destination)[] destinations)
{
_exitDestination = exitDestination;
_destinations = destinations;
}
public override int GetDestination(HttpContext httpContext)
{
// HostString can allocate when accessing the host or port
// Store host and port locally and reuse
var requestHost = httpContext.Request.Host;
var host = requestHost.Host;
var port = ResolvePort(httpContext, requestHost);
var destinations = _destinations;
for (var i = 0; i < destinations.Length; i++)
{
var destination = destinations[i];
if ((!destination.host.MatchesPort || destination.host.Port == port) &&
destination.host.MatchHost(host))
{
return destination.destination;
}
}
return _exitDestination;
}
private static int? ResolvePort(HttpContext httpContext, HostString requestHost)
{
if (requestHost.Port != null)
{
return requestHost.Port;
}
else if (string.Equals("https", httpContext.Request.Scheme, StringComparison.OrdinalIgnoreCase))
{
return 443;
}
else if (string.Equals("http", httpContext.Request.Scheme, StringComparison.OrdinalIgnoreCase))
{
return 80;
}
else
{
return null;
}
}
}
private readonly struct EdgeKey : IEquatable<EdgeKey>, IComparable<EdgeKey>, IComparable
{
private const string WildcardHost = "*";
internal static readonly EdgeKey WildcardEdgeKey = new EdgeKey(null, null);
public readonly int? Port;
public readonly string Host;
private readonly string _wildcardEndsWith;
public EdgeKey(string host, int? port)
{
Host = host ?? WildcardHost;
Port = port;
HasHostWildcard = Host.StartsWith("*.", StringComparison.Ordinal);
_wildcardEndsWith = HasHostWildcard ? Host.Substring(1) : null;
}
public bool HasHostWildcard { get; }
public bool MatchesHost => !string.Equals(Host, WildcardHost, StringComparison.Ordinal);
public bool MatchesPort => Port != null;
public bool MatchesAll => !MatchesHost && !MatchesPort;
public int CompareTo(EdgeKey other)
{
var result = Comparer<string>.Default.Compare(Host, other.Host);
if (result != 0)
{
return result;
}
return Comparer<int?>.Default.Compare(Port, other.Port);
}
public int CompareTo(object obj)
{
return CompareTo((EdgeKey)obj);
}
public bool Equals(EdgeKey other)
{
return string.Equals(Host, other.Host, StringComparison.Ordinal) && Port == other.Port;
}
public bool MatchHost(string host)
{
if (MatchesHost)
{
if (HasHostWildcard)
{
return host.EndsWith(_wildcardEndsWith, StringComparison.OrdinalIgnoreCase);
}
else
{
return string.Equals(host, Host, StringComparison.OrdinalIgnoreCase);
}
}
return true;
}
public override int GetHashCode()
{
return (Host?.GetHashCode() ?? 0) ^ (Port?.GetHashCode() ?? 0);
}
public override bool Equals(object obj)
{
if (obj is EdgeKey key)
{
return Equals(key);
}
return false;
}
public override string ToString()
{
return $"{Host}:{Port?.ToString() ?? "*"}";
}
}
}
}

View File

@ -0,0 +1,119 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using RoutingWebSite;
using Xunit;
namespace Microsoft.AspNetCore.Routing.FunctionalTests
{
public class HostMatchingTests : IClassFixture<RoutingTestFixture<UseEndpointRoutingStartup>>
{
private readonly RoutingTestFixture<UseEndpointRoutingStartup> _fixture;
public HostMatchingTests(RoutingTestFixture<UseEndpointRoutingStartup> fixture)
{
_fixture = fixture;
}
private HttpClient CreateClient(string baseAddress)
{
var client = _fixture.CreateClient(baseAddress);
return client;
}
[Theory]
[InlineData("http://localhost")]
[InlineData("http://localhost:5001")]
public async Task Get_CatchAll(string baseAddress)
{
// Arrange
var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard");
// Act
var client = CreateClient(baseAddress);
var response = await client.SendAsync(request);
var responseContent = await response.Content.ReadAsStringAsync();
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("*:*", responseContent);
}
[Theory]
[InlineData("http://9000.0.0.1")]
[InlineData("http://9000.0.0.1:8888")]
public async Task Get_MatchWildcardDomain(string baseAddress)
{
// Arrange
var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard");
// Act
var client = CreateClient(baseAddress);
var response = await client.SendAsync(request);
var responseContent = await response.Content.ReadAsStringAsync();
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("*.0.0.1:*", responseContent);
}
[Theory]
[InlineData("http://127.0.0.1")]
[InlineData("http://127.0.0.1:8888")]
public async Task Get_MatchDomain(string baseAddress)
{
// Arrange
var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard");
// Act
var client = CreateClient(baseAddress);
var response = await client.SendAsync(request);
var responseContent = await response.Content.ReadAsStringAsync();
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("127.0.0.1:*", responseContent);
}
[Theory]
[InlineData("http://9000.0.0.1:5000")]
[InlineData("http://9000.0.0.1:5001")]
public async Task Get_MatchWildcardDomainAndPort(string baseAddress)
{
// Arrange
var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard");
// Act
var client = CreateClient(baseAddress);
var response = await client.SendAsync(request);
var responseContent = await response.Content.ReadAsStringAsync();
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("*.0.0.1:5000,*.0.0.1:5001", responseContent);
}
[Theory]
[InlineData("http://www.contoso.com")]
[InlineData("http://contoso.com")]
public async Task Get_MatchWildcardDomainAndSubdomain(string baseAddress)
{
// Arrange
var request = new HttpRequestMessage(HttpMethod.Get, "api/DomainWildcard");
// Act
var client = CreateClient(baseAddress);
var response = await client.SendAsync(request);
var responseContent = await response.Content.ReadAsStringAsync();
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
Assert.Equal("contoso.com:*,*.contoso.com:*", responseContent);
}
}
}

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>netcoreapp3.0</TargetFramework>

View File

@ -25,6 +25,14 @@ namespace Microsoft.AspNetCore.Routing.FunctionalTests
public HttpClient Client { get; }
public HttpClient CreateClient(string baseAddress)
{
var client = _server.CreateClient();
client.BaseAddress = new Uri(baseAddress);
return client;
}
public void Dispose()
{
Client.Dispose();

View File

@ -0,0 +1,339 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Routing.Patterns;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Microsoft.AspNetCore.Routing.Matching
{
// End-to-end tests for the host matching functionality
public class HostMatcherPolicyIntegrationTest
{
[Fact]
public async Task Match_Host()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_HostWithPort()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:8080", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com:8080");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_Host_Unicode()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "æon.contoso.com", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "æon.contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_HostWithPort_IncorrectPort()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:8080", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com:1111");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertNotMatch(context, httpContext);
}
[Fact]
public async Task Match_HostWithPort_IncorrectHost()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:8080", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "www.contoso.com:8080");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertNotMatch(context, httpContext);
}
[Fact]
public async Task Match_HostWithWildcard()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*.contoso.com:8080", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "æon.contoso.com:8080");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_HostWithWildcard_Unicode()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*.contoso.com:8080", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "www.contoso.com:8080");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_Host_CaseInsensitive()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "Contoso.COM", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_HostWithPort_InferHttpPort()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:80", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com", "http");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_HostWithPort_InferHttpsPort()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:443", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com", "https");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_HostWithPort_NoHostHeader()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "contoso.com:443", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", null, "https");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertNotMatch(context, httpContext);
}
[Fact]
public async Task Match_Port_NoHostHeader_InferHttpsPort()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*:443", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", null, "https");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_NoMetadata_MatchesAnyHost()
{
// Arrange
var endpoint = CreateEndpoint("/hello");
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_EmptyHostList_MatchesAnyHost()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_WildcardHost_MatchesAnyHost()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
[Fact]
public async Task Match_WildcardHostAndWildcardPort_MatchesAnyHost()
{
// Arrange
var endpoint = CreateEndpoint("/hello", hosts: new string[] { "*:*", });
var matcher = CreateMatcher(endpoint);
var (httpContext, context) = CreateContext("/hello", "contoso.com");
// Act
await matcher.MatchAsync(httpContext, context);
// Assert
MatcherAssert.AssertMatch(context, httpContext, endpoint);
}
private static Matcher CreateMatcher(params RouteEndpoint[] endpoints)
{
var services = new ServiceCollection()
.AddOptions()
.AddLogging()
.AddRouting()
.BuildServiceProvider();
var builder = services.GetRequiredService<DfaMatcherBuilder>();
for (var i = 0; i < endpoints.Length; i++)
{
builder.AddEndpoint(endpoints[i]);
}
return builder.Build();
}
internal static (HttpContext httpContext, EndpointSelectorContext context) CreateContext(
string path,
string host,
string scheme = null)
{
var httpContext = new DefaultHttpContext();
if (host != null)
{
httpContext.Request.Host = new HostString(host);
}
httpContext.Request.Path = path;
httpContext.Request.Scheme = scheme;
var context = new EndpointSelectorContext();
httpContext.Features.Set<IEndpointFeature>(context);
httpContext.Features.Set<IRouteValuesFeature>(context);
return (httpContext, context);
}
internal static RouteEndpoint CreateEndpoint(
string template,
object defaults = null,
object constraints = null,
int order = 0,
string[] hosts = null)
{
var metadata = new List<object>();
if (hosts != null)
{
metadata.Add(new HostAttribute(hosts ?? Array.Empty<string>()));
}
var displayName = "endpoint: " + template + " " + string.Join(", ", hosts ?? new[] { "*:*" });
return new RouteEndpoint(
TestConstants.EmptyRequestDelegate,
RoutePatternFactory.Parse(template, defaults, constraints),
order,
new EndpointMetadataCollection(metadata),
displayName);
}
internal (Matcher matcher, RouteEndpoint endpoint) CreateMatcher(string template)
{
var endpoint = CreateEndpoint(template);
return (CreateMatcher(endpoint), endpoint);
}
}
}

View File

@ -0,0 +1,176 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.Routing.Matching;
using Microsoft.AspNetCore.Routing.Patterns;
using Xunit;
namespace Microsoft.AspNetCore.Routing.Matching
{
public class HostMatcherPolicyTest
{
[Fact]
public void AppliesToEndpoints_EndpointWithoutMetadata_ReturnsFalse()
{
// Arrange
var endpoints = new[] { CreateEndpoint("/", null), };
var policy = CreatePolicy();
// Act
var result = policy.AppliesToEndpoints(endpoints);
// Assert
Assert.False(result);
}
[Fact]
public void AppliesToEndpoints_EndpointWithoutHosts_ReturnsFalse()
{
// Arrange
var endpoints = new[]
{
CreateEndpoint("/", new HostAttribute(Array.Empty<string>())),
};
var policy = CreatePolicy();
// Act
var result = policy.AppliesToEndpoints(endpoints);
// Assert
Assert.False(result);
}
[Fact]
public void AppliesToEndpoints_EndpointHasHosts_ReturnsTrue()
{
// Arrange
var endpoints = new[]
{
CreateEndpoint("/", new HostAttribute(Array.Empty<string>())),
CreateEndpoint("/", new HostAttribute(new[] { "localhost", })),
};
var policy = CreatePolicy();
// Act
var result = policy.AppliesToEndpoints(endpoints);
// Assert
Assert.True(result);
}
[Theory]
[InlineData(":")]
[InlineData(":80")]
[InlineData("80:")]
[InlineData("")]
[InlineData("::")]
[InlineData("*:test")]
public void AppliesToEndpoints_InvalidHosts(string host)
{
// Arrange
var endpoints = new[] { CreateEndpoint("/", new HostAttribute(new[] { host })), };
var policy = CreatePolicy();
// Act & Assert
Assert.Throws<InvalidOperationException>(() =>
{
policy.AppliesToEndpoints(endpoints);
});
}
[Fact]
public void GetEdges_GroupsByHost()
{
// Arrange
var endpoints = new[]
{
CreateEndpoint("/", new HostAttribute(new[] { "*:5000", "*:5001", })),
CreateEndpoint("/", new HostAttribute(Array.Empty<string>())),
CreateEndpoint("/", hostMetadata: null),
CreateEndpoint("/", new HostAttribute("*.contoso.com:*")),
CreateEndpoint("/", new HostAttribute("*.sub.contoso.com:*")),
CreateEndpoint("/", new HostAttribute("www.contoso.com:*")),
CreateEndpoint("/", new HostAttribute("www.contoso.com:5000")),
CreateEndpoint("/", new HostAttribute("*:*")),
};
var policy = CreatePolicy();
// Act
var edges = policy.GetEdges(endpoints);
var data = edges.OrderBy(e => e.State).ToList();
// Assert
Assert.Collection(
data,
e =>
{
Assert.Equal("*:*", e.State.ToString());
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[7], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("*:5000", e.State.ToString());
Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[2], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("*:5001", e.State.ToString());
Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[2], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("*.contoso.com:*", e.State.ToString());
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("*.sub.contoso.com:*", e.State.ToString());
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[4], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("www.contoso.com:*", e.State.ToString());
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[5], }, e.Endpoints.ToArray());
},
e =>
{
Assert.Equal("www.contoso.com:5000", e.State.ToString());
Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[6], }, e.Endpoints.ToArray());
});
}
private static RouteEndpoint CreateEndpoint(string template, IHostMetadata hostMetadata)
{
var metadata = new List<object>();
if (hostMetadata != null)
{
metadata.Add(hostMetadata);
}
return new RouteEndpoint(
(context) => Task.CompletedTask,
RoutePatternFactory.Parse(template),
0,
new EndpointMetadataCollection(metadata),
$"test: {template} - {string.Join(", ", hostMetadata?.Hosts ?? Array.Empty<string>())}");
}
private static HostMatcherPolicy CreatePolicy()
{
return new HostMatcherPolicy();
}
}
}

View File

@ -356,6 +356,7 @@ namespace Microsoft.AspNetCore.Routing.Matching
return (httpContext, context);
}
internal static RouteEndpoint CreateEndpoint(
string template,
object defaults = null,

View File

@ -0,0 +1,47 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Routing.Patterns;
using Xunit;
namespace Microsoft.AspNetCore.Routing
{
public class RoutingEndpointConventionBuilderExtensionsTests
{
[Fact]
public void RequireHost_HostNames()
{
// Arrange
var builder = new TestEndpointConventionBuilder();
// Act
builder.RequireHost("contoso.com:8080");
// Assert
var convention = Assert.Single(builder.Conventions);
var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0);
convention(endpointModel);
var hostMetadata = Assert.IsType<HostAttribute>(Assert.Single(endpointModel.Metadata));
Assert.Equal("contoso.com:8080", hostMetadata.Hosts.Single());
}
private class TestEndpointConventionBuilder : IEndpointConventionBuilder
{
public IList<Action<EndpointBuilder>> Conventions { get; } = new List<Action<EndpointBuilder>>();
public void Add(Action<EndpointBuilder> convention)
{
Conventions.Add(convention);
}
}
}
}

View File

@ -8,6 +8,7 @@ using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Endpoints;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.Routing;
@ -112,6 +113,12 @@ namespace RoutingWebSite
"Link: " + linkGenerator.GetPathByRouteValues(httpContext, "WithDoubleAsteriskCatchAll", new { }));
},
new RouteNameMetadata(routeName: "WithDoubleAsteriskCatchAll"));
MapHostEndpoint(routes);
MapHostEndpoint(routes, "*.0.0.1");
MapHostEndpoint(routes, "127.0.0.1");
MapHostEndpoint(routes, "*.0.0.1:5000", "*.0.0.1:5001");
MapHostEndpoint(routes, "contoso.com:*", "*.contoso.com:*");
});
app.Map("/Branch1", branch => SetupBranch(branch, "Branch1"));
@ -124,6 +131,31 @@ namespace RoutingWebSite
app.UseEndpoint();
}
private IEndpointConventionBuilder MapHostEndpoint(IEndpointRouteBuilder routes, params string[] hosts)
{
var hostsDisplay = (hosts == null || hosts.Length == 0)
? "*:*"
: string.Join(",", hosts.Select(h => h.Contains(':') ? h : h + ":*"));
var conventionBuilder = routes.MapGet(
"api/DomainWildcard",
httpContext =>
{
var response = httpContext.Response;
response.StatusCode = 200;
response.ContentType = "text/plain";
return response.WriteAsync(hostsDisplay);
});
conventionBuilder.Add(endpointBuilder =>
{
endpointBuilder.Metadata.Add(new HostAttribute(hosts));
endpointBuilder.DisplayName += " HOST: " + hostsDisplay;
});
return conventionBuilder;
}
private void SetupBranch(IApplicationBuilder app, string name)
{
app.UseRouting(routes =>