diff --git a/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherAzureBenchmark.cs b/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherAzureBenchmark.cs index 8a97e2a9b2..23e3948aad 100644 --- a/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherAzureBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherAzureBenchmark.cs @@ -13,7 +13,6 @@ namespace Microsoft.AspNetCore.Routing.Matchers private BarebonesMatcher _baseline; private Matcher _dfa; - private Matcher _tree; private int[] _samples; private EndpointFeature _feature; @@ -31,7 +30,6 @@ namespace Microsoft.AspNetCore.Routing.Matchers _baseline = (BarebonesMatcher)SetupMatcher(new BarebonesMatcherBuilder()); _dfa = SetupMatcher(CreateDfaMatcherBuilder()); - _tree = SetupMatcher(new TreeRouterMatcherBuilder()); _feature = new EndpointFeature(); } @@ -61,22 +59,5 @@ namespace Microsoft.AspNetCore.Routing.Matchers Validate(httpContext, Endpoints[sample], feature.Endpoint); } } - - [Benchmark(OperationsPerInvoke = SampleCount)] - public async Task LegacyTreeRouter() - { - var feature = _feature; - for (var i = 0; i < SampleCount; i++) - { - var sample = _samples[i]; - var httpContext = Requests[sample]; - - // This is required to make the legacy router implementation work with global routing. - httpContext.Features.Set(feature); - - await _tree.MatchAsync(httpContext, feature); - Validate(httpContext, Endpoints[sample], feature.Endpoint); - } - } } } \ No newline at end of file diff --git a/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherBenchmarkBase.cs b/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherBenchmarkBase.cs index bf6a18297b..612e4db9f3 100644 --- a/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherBenchmarkBase.cs +++ b/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherBenchmarkBase.cs @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.Routing.Matchers var metadata = new List(); if (httpMethod != null) { - metadata.Add(new HttpMethodEndpointConstraint(new string[] { httpMethod, })); + metadata.Add(new HttpMethodMetadata(new string[] { httpMethod, })); } return new MatcherEndpoint( diff --git a/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherGithubBenchmark.cs b/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherGithubBenchmark.cs index b47aa4a592..2cdb76f11f 100644 --- a/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherGithubBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.Routing.Performance/Matchers/MatcherGithubBenchmark.cs @@ -12,7 +12,6 @@ namespace Microsoft.AspNetCore.Routing.Matchers { private BarebonesMatcher _baseline; private Matcher _dfa; - private Matcher _tree; private EndpointFeature _feature; @@ -25,7 +24,6 @@ namespace Microsoft.AspNetCore.Routing.Matchers _baseline = (BarebonesMatcher)SetupMatcher(new BarebonesMatcherBuilder()); _dfa = SetupMatcher(CreateDfaMatcherBuilder()); - _tree = SetupMatcher(new TreeRouterMatcherBuilder()); _feature = new EndpointFeature(); } @@ -53,21 +51,5 @@ namespace Microsoft.AspNetCore.Routing.Matchers Validate(httpContext, Endpoints[i], feature.Endpoint); } } - - [Benchmark(OperationsPerInvoke = EndpointCount)] - public async Task LegacyTreeRouter() - { - var feature = _feature; - for (var i = 0; i < EndpointCount; i++) - { - var httpContext = Requests[i]; - - // This is required to make the legacy router implementation work with global routing. - httpContext.Features.Set(feature); - - await _tree.MatchAsync(httpContext, feature); - Validate(httpContext, Endpoints[i], feature.Endpoint); - } - } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Routing/CompositeEndpointDataSource.cs b/src/Microsoft.AspNetCore.Routing/CompositeEndpointDataSource.cs index a806808d20..27629fd1a3 100644 --- a/src/Microsoft.AspNetCore.Routing/CompositeEndpointDataSource.cs +++ b/src/Microsoft.AspNetCore.Routing/CompositeEndpointDataSource.cs @@ -9,6 +9,7 @@ using System.Text; using System.Threading; using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.AspNetCore.Routing.Matchers; +using Microsoft.AspNetCore.Routing.Metadata; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Routing @@ -134,15 +135,16 @@ namespace Microsoft.AspNetCore.Routing sb.Append(", Order:"); sb.Append(matcherEndpoint.Order); - var httpEndpointConstraints = matcherEndpoint.Metadata.GetOrderedMetadata() - .OfType(); - foreach (var constraint in httpEndpointConstraints) + var httpMethodMetadata = matcherEndpoint.Metadata.GetMetadata(); + if (httpMethodMetadata != null) { - sb.Append(", Http Methods: "); - sb.Append(string.Join(", ", constraint.HttpMethods)); - sb.Append(", Constraint Order:"); - sb.Append(constraint.Order); + foreach (var httpMethod in httpMethodMetadata.HttpMethods) + { + sb.Append(", Http Methods: "); + sb.Append(string.Join(", ", httpMethod)); + } } + sb.AppendLine(); } else diff --git a/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs index fac20603a5..0e335d7d46 100644 --- a/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs @@ -79,7 +79,7 @@ namespace Microsoft.Extensions.DependencyInjection // services.TryAddSingleton(); services.TryAddSingleton(); - services.TryAddEnumerable(ServiceDescriptor.Singleton()); + services.TryAddEnumerable(ServiceDescriptor.Singleton()); // Will be cached by the EndpointSelector services.TryAddEnumerable( diff --git a/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs index 7631e10b22..296986d523 100644 --- a/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs +++ b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs @@ -43,6 +43,8 @@ namespace Microsoft.AspNetCore.Routing.EndpointConstraints IReadOnlyList IHttpMethodMetadata.HttpMethods => _httpMethods; + bool IHttpMethodMetadata.AcceptCorsPreflight => false; + public virtual bool Accept(EndpointConstraintContext context) { if (context == null) diff --git a/src/Microsoft.AspNetCore.Routing/Matchers/HttpMethodMatcherPolicy.cs b/src/Microsoft.AspNetCore.Routing/Matchers/HttpMethodMatcherPolicy.cs index 17802e61c0..e593713bb6 100644 --- a/src/Microsoft.AspNetCore.Routing/Matchers/HttpMethodMatcherPolicy.cs +++ b/src/Microsoft.AspNetCore.Routing/Matchers/HttpMethodMatcherPolicy.cs @@ -8,11 +8,18 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing.Metadata; using Microsoft.AspNetCore.Routing.Patterns; +using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Routing.Matchers { - public sealed class HttpMethodEndpointSelectorPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy + public sealed class HttpMethodMatcherPolicy : MatcherPolicy, IEndpointComparerPolicy, INodeBuilderPolicy { + // Used in tests + internal static readonly string OriginHeader = "Origin"; + internal static readonly string AccessControlRequestMethod = "Access-Control-Request-Method"; + internal static readonly string PreflightHttpMethod = "OPTIONS"; + // Used in tests internal const string Http405EndpointDisplayName = "405 HTTP Method Not Supported"; @@ -47,38 +54,94 @@ namespace Microsoft.AspNetCore.Routing.Matchers { // The algorithm here is designed to be preserve the order of the endpoints // while also being relatively simple. Preserving order is important. - var allHttpMethods = endpoints - .SelectMany(e => GetHttpMethods(e)) - .Distinct() - .OrderBy(m => m); // Sort for testability - - var dictionary = new Dictionary>(); - foreach (var httpMethod in allHttpMethods) - { - dictionary.Add(httpMethod, new List()); - } - - dictionary.Add(AnyMethod, new List()); + // First, build a dictionary of all possible HTTP method/CORS combinations + // that exist in this list of endpoints. + // + // For now we're just building up the set of keys. We don't add any endpoints + // to lists now because we don't want ordering problems. + var allHttpMethods = new HashSet(StringComparer.OrdinalIgnoreCase); + var edges = new Dictionary>(); for (var i = 0; i < endpoints.Count; i++) { var endpoint = endpoints[i]; + var (httpMethods, acceptCorsPreFlight) = GetHttpMethods(endpoint); - var httpMethods = GetHttpMethods(endpoint); + // If the action doesn't list HTTP methods then it supports all methods. + // In this phase we use a sentinel value to represent the *other* HTTP method + // a state that represents any HTTP method that doesn't have a match. if (httpMethods.Count == 0) { - // This endpoint suports all HTTP methods. - foreach (var kvp in dictionary) - { - kvp.Value.Add(endpoint); - } - - continue; + httpMethods = new[] { AnyMethod, }; } for (var j = 0; j < httpMethods.Count; j++) { - dictionary[httpMethods[j]].Add(endpoint); + // An endpoint that allows CORS reqests will match both CORS and non-CORS + // so we model it as both. + var httpMethod = httpMethods[j]; + var key = new EdgeKey(httpMethod, acceptCorsPreFlight); + if (!edges.ContainsKey(key)) + { + edges.Add(key, new List()); + } + + // An endpoint that allows CORS reqests will match both CORS and non-CORS + // so we model it as both. + if (acceptCorsPreFlight) + { + key = new EdgeKey(httpMethod, false); + if (!edges.ContainsKey(key)) + { + edges.Add(key, new List()); + } + } + + // Also if it's not the *any* method key, then track it. + if (!string.Equals(AnyMethod, httpMethod, StringComparison.OrdinalIgnoreCase)) + { + allHttpMethods.Add(httpMethod); + } + } + } + + // Now in a second loop, add endpoints to these lists. We've enumerated all of + // the states, so we want to see which states this endpoint matches. + for (var i = 0; i < endpoints.Count; i++) + { + var endpoint = endpoints[i]; + var (httpMethods, acceptCorsPreFlight) = GetHttpMethods(endpoint); + + if (httpMethods.Count == 0) + { + // OK this means that this endpoint matches *all* HTTP methods. + // So, loop and add it to all states. + foreach (var kvp in edges) + { + if (acceptCorsPreFlight || !kvp.Key.IsCorsPreflightRequest) + { + kvp.Value.Add(endpoint); + } + } + } + else + { + // OK this endpoint matches specific methods. + for (var j = 0; j < httpMethods.Count; j++) + { + var httpMethod = httpMethods[j]; + var key = new EdgeKey(httpMethod, acceptCorsPreFlight); + + edges[key].Add(endpoint); + + // An endpoint that allows CORS reqests will match both CORS and non-CORS + // so we model it as both. + if (acceptCorsPreFlight) + { + key = new EdgeKey(httpMethod, false); + edges[key].Add(endpoint); + } + } } } @@ -95,42 +158,67 @@ namespace Microsoft.AspNetCore.Routing.Matchers // // This will make 405 much more likely in API-focused applications, and somewhat // unlikely in a traditional MVC application. That's good. - if (dictionary[AnyMethod].Count == 0) + // + // We don't bother returning a 405 when the CORS preflight method doesn't exist. + // The developer calling the API will see it as a CORS error, which is fine because + // there isn't an endpoint to check for a CORS policy. + if (!edges.TryGetValue(new EdgeKey(AnyMethod, false), out var matches)) { - dictionary[AnyMethod].Add(CreateRejectionEndpoint(allHttpMethods)); + // Methods sorted for testability. + var endpoint = CreateRejectionEndpoint(allHttpMethods.OrderBy(m => m)); + matches = new List() { endpoint, }; + edges[new EdgeKey(AnyMethod, false)] = matches; } - var edges = new List(); - foreach (var kvp in dictionary) - { - edges.Add(new PolicyNodeEdge(kvp.Key, kvp.Value)); - } + return edges + .Select(kvp => new PolicyNodeEdge(kvp.Key, kvp.Value)) + .ToArray(); - return edges; - - IReadOnlyList GetHttpMethods(Endpoint e) + (IReadOnlyList httpMethods, bool acceptCorsPreflight) GetHttpMethods(Endpoint e) { - return e.Metadata.GetMetadata()?.HttpMethods ?? Array.Empty(); + var metadata = e.Metadata.GetMetadata(); + return metadata == null ? (Array.Empty(), false) : (metadata.HttpMethods, metadata.AcceptCorsPreflight); } } public PolicyJumpTable BuildJumpTable(int exitDestination, IReadOnlyList edges) { - var dictionary = new Dictionary(StringComparer.OrdinalIgnoreCase); + var destinations = new Dictionary(StringComparer.OrdinalIgnoreCase); + var corsPreflightDestinations = new Dictionary(StringComparer.OrdinalIgnoreCase); for (var i = 0; i < edges.Count; i++) { - // We create this data, so it's safe to cast it to a string. - dictionary.Add((string)edges[i].State, edges[i].Destination); + // We create this data, so it's safe to cast it. + var key = (EdgeKey)edges[i].State; + if (key.IsCorsPreflightRequest) + { + corsPreflightDestinations.Add(key.HttpMethod, edges[i].Destination); + } + else + { + destinations.Add(key.HttpMethod, edges[i].Destination); + } } - if (dictionary.TryGetValue(AnyMethod, out var matchesAnyVerb)) + int corsPreflightExitDestination = exitDestination; + if (corsPreflightDestinations.TryGetValue(AnyMethod, out var matchesAnyVerb)) + { + // If we have endpoints that match any HTTP method, use that as the exit. + corsPreflightExitDestination = matchesAnyVerb; + corsPreflightDestinations.Remove(AnyMethod); + } + + if (destinations.TryGetValue(AnyMethod, out matchesAnyVerb)) { // If we have endpoints that match any HTTP method, use that as the exit. exitDestination = matchesAnyVerb; - dictionary.Remove(AnyMethod); + destinations.Remove(AnyMethod); } - return new DictionaryPolicyJumpTable(exitDestination, dictionary); + return new HttpMethodPolicyJumpTable( + exitDestination, + destinations, + corsPreflightExitDestination, + corsPreflightDestinations); } private Endpoint CreateRejectionEndpoint(IEnumerable httpMethods) @@ -150,21 +238,46 @@ namespace Microsoft.AspNetCore.Routing.Matchers Http405EndpointDisplayName); } - private class DictionaryPolicyJumpTable : PolicyJumpTable + private class HttpMethodPolicyJumpTable : PolicyJumpTable { private readonly int _exitDestination; private readonly Dictionary _destinations; + private readonly int _corsPreflightExitDestination; + private readonly Dictionary _corsPreflightDestinations; - public DictionaryPolicyJumpTable(int exitDestination, Dictionary destinations) + private readonly bool _supportsCorsPreflight; + + public HttpMethodPolicyJumpTable( + int exitDestination, + Dictionary destinations, + int corsPreflightExitDestination, + Dictionary corsPreflightDestinations) { _exitDestination = exitDestination; _destinations = destinations; + _corsPreflightExitDestination = corsPreflightExitDestination; + _corsPreflightDestinations = corsPreflightDestinations; + + _supportsCorsPreflight = _corsPreflightDestinations.Count > 0; } public override int GetDestination(HttpContext httpContext) { + int destination; + var httpMethod = httpContext.Request.Method; - return _destinations.TryGetValue(httpMethod, out var destination) ? destination : _exitDestination; + if (_supportsCorsPreflight && + string.Equals(httpMethod, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) && + httpContext.Request.Headers.ContainsKey(OriginHeader) && + httpContext.Request.Headers.TryGetValue(AccessControlRequestMethod, out var accessControlRequestMethod) && + !StringValues.IsNullOrEmpty(accessControlRequestMethod)) + { + return _corsPreflightDestinations.TryGetValue(accessControlRequestMethod, out destination) + ? destination + : _corsPreflightExitDestination; + } + + return _destinations.TryGetValue(httpMethod, out destination) ? destination : _exitDestination; } } @@ -178,5 +291,58 @@ namespace Microsoft.AspNetCore.Routing.Matchers y?.HttpMethods.Count > 0 ? y : null); } } + + internal readonly struct EdgeKey : IEquatable, IComparable, IComparable + { + // Note that in contrast with the metadata, the edge represents a possible state change + // rather than a list of what's allowed. We represent CORS and non-CORS requests as separate + // states. + public readonly bool IsCorsPreflightRequest; + public readonly string HttpMethod; + + public EdgeKey(string httpMethod, bool isCorsPreflightRequest) + { + HttpMethod = httpMethod; + IsCorsPreflightRequest = isCorsPreflightRequest; + } + + // These are comparable so they can be sorted in tests. + public int CompareTo(EdgeKey other) + { + var compare = HttpMethod.CompareTo(other.HttpMethod); + if (compare != 0) + { + return compare; + } + + return IsCorsPreflightRequest.CompareTo(other.IsCorsPreflightRequest); + } + + public int CompareTo(object obj) + { + return CompareTo((EdgeKey)obj); + } + + public bool Equals(EdgeKey other) + { + return + IsCorsPreflightRequest == other.IsCorsPreflightRequest && + string.Equals(HttpMethod, other.HttpMethod, StringComparison.OrdinalIgnoreCase); + } + + public override bool Equals(object obj) + { + var other = obj as EdgeKey?; + return other == null ? false : Equals(other.Value); + } + + public override int GetHashCode() + { + var hash = new HashCodeCombiner(); + hash.Add(IsCorsPreflightRequest); + hash.Add(HttpMethod, StringComparer.Ordinal); + return hash; + } + } } } diff --git a/src/Microsoft.AspNetCore.Routing/Metadata/HttpMethodMetadata.cs b/src/Microsoft.AspNetCore.Routing/Metadata/HttpMethodMetadata.cs index 90cfd057a3..2dc98a5a12 100644 --- a/src/Microsoft.AspNetCore.Routing/Metadata/HttpMethodMetadata.cs +++ b/src/Microsoft.AspNetCore.Routing/Metadata/HttpMethodMetadata.cs @@ -3,13 +3,20 @@ using System; using System.Collections.Generic; +using System.Diagnostics; using System.Linq; namespace Microsoft.AspNetCore.Routing.Metadata { + [DebuggerDisplay("{DebuggerToString,nq}")] public sealed class HttpMethodMetadata : IHttpMethodMetadata { public HttpMethodMetadata(IEnumerable httpMethods) + : this(httpMethods, acceptCorsPreflight: false) + { + } + + public HttpMethodMetadata(IEnumerable httpMethods, bool acceptCorsPreflight) { if (httpMethods == null) { @@ -17,8 +24,16 @@ namespace Microsoft.AspNetCore.Routing.Metadata } HttpMethods = httpMethods.ToArray(); + AcceptCorsPreflight = acceptCorsPreflight; } + public bool AcceptCorsPreflight { get; } + public IReadOnlyList HttpMethods { get; } + + private string DebuggerToString() + { + return $"HttpMethods: {string.Join(",", HttpMethods)} - Cors: {AcceptCorsPreflight}"; + } } } diff --git a/src/Microsoft.AspNetCore.Routing/Metadata/IHttpMethodMetadata.cs b/src/Microsoft.AspNetCore.Routing/Metadata/IHttpMethodMetadata.cs index a77a617f13..9d6447d1e0 100644 --- a/src/Microsoft.AspNetCore.Routing/Metadata/IHttpMethodMetadata.cs +++ b/src/Microsoft.AspNetCore.Routing/Metadata/IHttpMethodMetadata.cs @@ -7,6 +7,8 @@ namespace Microsoft.AspNetCore.Routing.Metadata { public interface IHttpMethodMetadata { + bool AcceptCorsPreflight { get; } + IReadOnlyList HttpMethods { get; } } } diff --git a/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointConstraintEndpointSelectorTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointConstraintEndpointSelectorTest.cs index fe5d9bb3b2..72bf124785 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointConstraintEndpointSelectorTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointConstraintEndpointSelectorTest.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Collections.ObjectModel; using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Internal/HttpMethodEndpointConstraintTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/Internal/HttpMethodEndpointConstraintTest.cs deleted file mode 100644 index 53b38fea1a..0000000000 --- a/test/Microsoft.AspNetCore.Routing.Tests/Internal/HttpMethodEndpointConstraintTest.cs +++ /dev/null @@ -1,91 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Routing.EndpointConstraints; -using Microsoft.AspNetCore.Routing.TestObjects; -using Microsoft.Extensions.Primitives; -using System; -using System.Collections.Generic; -using System.Text; -using Xunit; - -namespace Microsoft.AspNetCore.Routing.Internal -{ - public class HttpMethodEndpointConstraintTest - { - public static TheoryData AcceptCaseInsensitiveData = - new TheoryData, string> - { - { new string[] { "get", "Get", "GET", "GEt"}, "gEt" }, - { new string[] { "POST", "PoSt", "GEt"}, "GET" }, - { new string[] { "get" }, "get" }, - { new string[] { "post" }, "POST" }, - { new string[] { "gEt" }, "get" }, - { new string[] { "get", "PoST" }, "pOSt" } - }; - - [Theory] - [MemberData(nameof(AcceptCaseInsensitiveData))] - public void HttpMethodEndpointConstraint_IgnoresPreflightRequests(IEnumerable httpMethods, string accessControlMethod) - { - // Arrange - var constraint = new HttpMethodEndpointConstraint(httpMethods); - var context = CreateEndpointConstraintContext(constraint); - context.HttpContext = CreateHttpContext("oPtIoNs", accessControlMethod); - - // Act - var result = constraint.Accept(context); - - // Assert - Assert.False(result, "Request should have been rejected."); - } - - [Theory] - [MemberData(nameof(AcceptCaseInsensitiveData))] - public void HttpMethodEndpointConstraint_Accept_CaseInsensitive(IEnumerable httpMethods, string expectedMethod) - { - // Arrange - var constraint = new HttpMethodEndpointConstraint(httpMethods); - var context = CreateEndpointConstraintContext(constraint); - context.HttpContext = CreateHttpContext(expectedMethod); - - // Act - var result = constraint.Accept(context); - - // Assert - Assert.True(result, "Request should have been accepted."); - } - - private static EndpointConstraintContext CreateEndpointConstraintContext(HttpMethodEndpointConstraint constraint) - { - var context = new EndpointConstraintContext(); - - var endpointSelectorCandidate = new EndpointSelectorCandidate( - new TestEndpoint(EndpointMetadataCollection.Empty, string.Empty), - 0, - new RouteValueDictionary(), - new List { constraint }); - - context.Candidates = new List { endpointSelectorCandidate }; - context.CurrentCandidate = context.Candidates[0]; - - return context; - } - - private static HttpContext CreateHttpContext(string requestedMethod, string accessControlMethod = null) - { - var httpContext = new DefaultHttpContext(); - - httpContext.Request.Method = requestedMethod; - - if (accessControlMethod != null) - { - httpContext.Request.Headers.Add("Origin", StringValues.Empty); - httpContext.Request.Headers.Add("Access-Control-Request-Method", accessControlMethod); - } - - return httpContext; - } - } -} diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/DfaMatcherTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/DfaMatcherTest.cs index ee6d64cfc1..4b9f4a721d 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/DfaMatcherTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/DfaMatcherTest.cs @@ -4,9 +4,9 @@ using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.Extensions.DependencyInjection; +using Moq; using Xunit; namespace Microsoft.AspNetCore.Routing.Matchers @@ -26,13 +26,19 @@ namespace Microsoft.AspNetCore.Routing.Matchers template); } - private Matcher CreateDfaMatcher(EndpointDataSource dataSource) + private Matcher CreateDfaMatcher(EndpointDataSource dataSource, EndpointSelector endpointSelector = null) { - var services = new ServiceCollection() + var serviceCollection = new ServiceCollection() .AddLogging() .AddOptions() - .AddRouting() - .BuildServiceProvider(); + .AddRouting(); + + if (endpointSelector != null) + { + serviceCollection.AddSingleton(endpointSelector); + } + + var services = serviceCollection.BuildServiceProvider(); var factory = services.GetRequiredService(); return Assert.IsType(factory.CreateMatcher(dataSource)); @@ -115,22 +121,39 @@ namespace Microsoft.AspNetCore.Routing.Matchers public async Task MatchAsync_MultipleMatches_EndpointSelectorCalled() { // Arrange - var endpointWithoutConstraint = CreateEndpoint("/Teams", 0); - var endpointWithConstraint = CreateEndpoint( - "/Teams", - 0, - metadata: new EndpointMetadataCollection(new object[] { new HttpMethodEndpointConstraint(new[] { "POST" }) })); + var endpoint1 = CreateEndpoint("/Teams", 0); + var endpoint2 = CreateEndpoint("/Teams", 1); + + var endpointSelector = new Mock(); + endpointSelector + .Setup(s => s.SelectAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Callback((c, f, cs) => + { + Assert.Equal(2, cs.Count); + + Assert.Same(endpoint1, cs[0].Endpoint); + Assert.True(cs[0].IsValidCandidate); + Assert.Equal(0, cs[0].Score); + Assert.Empty(cs[0].Values); + + Assert.Same(endpoint2, cs[1].Endpoint); + Assert.True(cs[1].IsValidCandidate); + Assert.Equal(1, cs[1].Score); + Assert.Empty(cs[1].Values); + + f.Endpoint = endpoint2; + }) + .Returns(Task.CompletedTask); var endpointDataSource = new DefaultEndpointDataSource(new List { - endpointWithoutConstraint, - endpointWithConstraint + endpoint1, + endpoint2 }); - var matcher = CreateDfaMatcher(endpointDataSource); + var matcher = CreateDfaMatcher(endpointDataSource, endpointSelector.Object); var httpContext = new DefaultHttpContext(); - httpContext.Request.Method = "POST"; httpContext.Request.Path = "/Teams"; var endpointFeature = new EndpointFeature(); @@ -139,7 +162,7 @@ namespace Microsoft.AspNetCore.Routing.Matchers await matcher.MatchAsync(httpContext, endpointFeature); // Assert - Assert.Equal(endpointWithConstraint, endpointFeature.Endpoint); + Assert.Equal(endpoint2, endpointFeature.Endpoint); } } } diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyIntegrationTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyIntegrationTest.cs index 26bf2173de..2f48332732 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyIntegrationTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyIntegrationTest.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Routing.Metadata; using Microsoft.AspNetCore.Routing.Patterns; using Microsoft.Extensions.DependencyInjection; using Xunit; +using static Microsoft.AspNetCore.Routing.Matchers.HttpMethodMatcherPolicy; namespace Microsoft.AspNetCore.Routing.Matchers { @@ -31,6 +32,56 @@ namespace Microsoft.AspNetCore.Routing.Matchers MatcherAssert.AssertMatch(feature, endpoint); } + [Fact] + public async Task Match_HttpMethod_CORS() + { + // Arrange + var endpoint = CreateEndpoint("/hello", httpMethods: new string[] { "GET", }, acceptCorsPreflight: true); + + var matcher = CreateMatcher(endpoint); + var (httpContext, feature) = CreateContext("/hello", "GET"); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + MatcherAssert.AssertMatch(feature, endpoint); + } + + [Fact] + public async Task Match_HttpMethod_CORS_Preflight() + { + // Arrange + var endpoint = CreateEndpoint("/hello", httpMethods: new string[] { "GET", }, acceptCorsPreflight: true); + + var matcher = CreateMatcher(endpoint); + var (httpContext, feature) = CreateContext("/hello", "GET", corsPreflight: true); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + MatcherAssert.AssertMatch(feature, endpoint); + } + + + [Fact] // Nothing here supports OPTIONS, so it goes to a 405. + public async Task NotMatch_HttpMethod_CORS_Preflight() + { + // Arrange + var endpoint = CreateEndpoint("/hello", httpMethods: new string[] { "GET", }, acceptCorsPreflight: false); + + var matcher = CreateMatcher(endpoint); + var (httpContext, feature) = CreateContext("/hello", "GET", corsPreflight: true); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + Assert.NotSame(endpoint, feature.Endpoint); + Assert.Same(HttpMethodMatcherPolicy.Http405EndpointDisplayName, feature.Endpoint.DisplayName); + } + [Fact] public async Task Match_HttpMethod_CaseInsensitive() { @@ -47,6 +98,22 @@ namespace Microsoft.AspNetCore.Routing.Matchers MatcherAssert.AssertMatch(feature, endpoint); } + [Fact] + public async Task Match_HttpMethod_CaseInsensitive_CORS_Preflight() + { + // Arrange + var endpoint = CreateEndpoint("/hello", httpMethods: new string[] { "GeT", }, acceptCorsPreflight: true); + + var matcher = CreateMatcher(endpoint); + var (httpContext, feature) = CreateContext("/hello", "GET", corsPreflight: true); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + MatcherAssert.AssertMatch(feature, endpoint); + } + [Fact] public async Task Match_NoMetadata_MatchesAnyHttpMethod() { @@ -63,6 +130,38 @@ namespace Microsoft.AspNetCore.Routing.Matchers MatcherAssert.AssertMatch(feature, endpoint); } + [Fact] + public async Task Match_NoMetadata_MatchesAnyHttpMethod_CORS_Preflight() + { + // Arrange + var endpoint = CreateEndpoint("/hello", acceptCorsPreflight: true); + + var matcher = CreateMatcher(endpoint); + var (httpContext, feature) = CreateContext("/hello", "GET", corsPreflight: true); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + MatcherAssert.AssertMatch(feature, endpoint); + } + + [Fact] // This matches because the endpoint accepts OPTIONS + public async Task Match_NoMetadata_MatchesAnyHttpMethod_CORS_Preflight_DoesNotSupportPreflight() + { + // Arrange + var endpoint = CreateEndpoint("/hello", acceptCorsPreflight: false); + + var matcher = CreateMatcher(endpoint); + var (httpContext, feature) = CreateContext("/hello", "GET", corsPreflight: true); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + MatcherAssert.AssertMatch(feature, endpoint); + } + [Fact] public async Task Match_EmptyMethodList_MatchesAnyHttpMethod() { @@ -96,7 +195,7 @@ namespace Microsoft.AspNetCore.Routing.Matchers Assert.NotSame(endpoint1, feature.Endpoint); Assert.NotSame(endpoint2, feature.Endpoint); - Assert.Same(HttpMethodEndpointSelectorPolicy.Http405EndpointDisplayName, feature.Endpoint.DisplayName); + Assert.Same(HttpMethodMatcherPolicy.Http405EndpointDisplayName, feature.Endpoint.DisplayName); // Invoke the endpoint await feature.Invoker((c) => Task.CompletedTask)(httpContext); @@ -104,6 +203,23 @@ namespace Microsoft.AspNetCore.Routing.Matchers Assert.Equal("DELETE, GET, PUT", httpContext.Response.Headers["Allow"]); } + [Fact] // When all of the candidates handles specific verbs, use a 405 endpoint + public async Task NotMatch_HttpMethod_CORS_DoesNotReturn405() + { + // Arrange + var endpoint1 = CreateEndpoint("/hello", httpMethods: new string[] { "GET", "PUT" }, acceptCorsPreflight: true); + var endpoint2 = CreateEndpoint("/hello", httpMethods: new string[] { "DELETE" }); + + var matcher = CreateMatcher(endpoint1, endpoint2); + var (httpContext, feature) = CreateContext("/hello", "POST", corsPreflight: true); + + // Act + await matcher.MatchAsync(httpContext, feature); + + // Assert + MatcherAssert.AssertNotMatch(feature); + } + [Fact] // When one of the candidates handles all verbs, dont use a 405 endpoint public async Task NotMatch_HttpMethod_WithAllMethodEndpoint_DoesNotReturn405() { @@ -189,12 +305,21 @@ namespace Microsoft.AspNetCore.Routing.Matchers return builder.Build(); } - internal static (HttpContext httpContext, IEndpointFeature feature) CreateContext(string path, string httpMethod) + internal static (HttpContext httpContext, IEndpointFeature feature) CreateContext( + string path, + string httpMethod, + bool corsPreflight = false) { var httpContext = new DefaultHttpContext(); - httpContext.Request.Method = httpMethod; + httpContext.Request.Method = corsPreflight ? PreflightHttpMethod : httpMethod; httpContext.Request.Path = path; + if (corsPreflight) + { + httpContext.Request.Headers[OriginHeader] = "example.com"; + httpContext.Request.Headers[AccessControlRequestMethod] = httpMethod; + } + var feature = new EndpointFeature(); httpContext.Features.Set(feature); @@ -205,12 +330,13 @@ namespace Microsoft.AspNetCore.Routing.Matchers object defaults = null, object constraints = null, int order = 0, - string[] httpMethods = null) + string[] httpMethods = null, + bool acceptCorsPreflight = false) { var metadata = new List(); if (httpMethods != null) { - metadata.Add(new HttpMethodMetadata(httpMethods)); + metadata.Add(new HttpMethodMetadata(httpMethods ?? Array.Empty(), acceptCorsPreflight)); } var displayName = "endpoint: " + template + " " + string.Join(", ", httpMethods ?? new[] { "(any)" }); diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyTest.cs index d4e3149852..9c977341a3 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/HttpMethodMatcherPolicyTest.cs @@ -4,10 +4,10 @@ using System; using System.Collections.Generic; using System.Linq; -using System.Text; using Microsoft.AspNetCore.Routing.Metadata; using Microsoft.AspNetCore.Routing.Patterns; using Xunit; +using static Microsoft.AspNetCore.Routing.Matchers.HttpMethodMatcherPolicy; namespace Microsoft.AspNetCore.Routing.Matchers { @@ -32,7 +32,10 @@ namespace Microsoft.AspNetCore.Routing.Matchers public void AppliesToNode_EndpointWithoutHttpMethods_ReturnsFalse() { // Arrange - var endpoints = new[] { CreateEndpoint("/", Array.Empty()), }; + var endpoints = new[] + { + CreateEndpoint("/", new HttpMethodMetadata(Array.Empty())), + }; var policy = CreatePolicy(); @@ -47,7 +50,11 @@ namespace Microsoft.AspNetCore.Routing.Matchers public void AppliesToNode_EndpointHasHttpMethods_ReturnsTrue() { // Arrange - var endpoints = new[] { CreateEndpoint("/", Array.Empty()), CreateEndpoint("/", new[] { "GET", })}; + var endpoints = new[] + { + CreateEndpoint("/", new HttpMethodMetadata(Array.Empty())), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", })), + }; var policy = CreatePolicy(); @@ -66,11 +73,11 @@ namespace Microsoft.AspNetCore.Routing.Matchers { // These are arrange in an order that we won't actually see in a product scenario. It's done // this way so we can verify that ordering is preserved by GetEdges. - CreateEndpoint("/", new[] { "GET", }), - CreateEndpoint("/", Array.Empty()), - CreateEndpoint("/", new[] { "GET", "PUT", "POST" }), - CreateEndpoint("/", new[] { "PUT", "POST" }), - CreateEndpoint("/", Array.Empty()), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", })), + CreateEndpoint("/", new HttpMethodMetadata(Array.Empty())), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", "PUT", "POST" })), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "PUT", "POST" })), + CreateEndpoint("/", new HttpMethodMetadata(Array.Empty())), }; var policy = CreatePolicy(); @@ -83,26 +90,91 @@ namespace Microsoft.AspNetCore.Routing.Matchers edges.OrderBy(e => e.State), e => { - Assert.Equal(HttpMethodEndpointSelectorPolicy.AnyMethod, e.State); + Assert.Equal(new EdgeKey(AnyMethod, isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[1], endpoints[4], }, e.Endpoints.ToArray()); }, e => { - Assert.Equal("GET", e.State); + Assert.Equal(new EdgeKey("GET", isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[2], endpoints[4], }, e.Endpoints.ToArray()); }, e => { - Assert.Equal("POST", e.State); + Assert.Equal(new EdgeKey("POST", isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray()); }, e => { - Assert.Equal("PUT", e.State); + Assert.Equal(new EdgeKey("PUT", isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray()); }); } + [Fact] + public void GetEdges_GroupsByHttpMethod_Cors() + { + // Arrange + var endpoints = new[] + { + // These are arrange in an order that we won't actually see in a product scenario. It's done + // this way so we can verify that ordering is preserved by GetEdges. + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", })), + CreateEndpoint("/", new HttpMethodMetadata(Array.Empty())), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", "PUT", "POST" }, acceptCorsPreflight: true)), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "PUT", "POST" })), + CreateEndpoint("/", new HttpMethodMetadata(Array.Empty(), acceptCorsPreflight: true)), + }; + + var policy = CreatePolicy(); + + // Act + var edges = policy.GetEdges(endpoints); + + // Assert + Assert.Collection( + edges.OrderBy(e => e.State), + e => + { + Assert.Equal(new EdgeKey(AnyMethod, isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[1], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey(AnyMethod, isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("GET", isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[0], endpoints[1], endpoints[2], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("GET", isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[2], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("POST", isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("POST", isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[2], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("PUT", isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[1], endpoints[2], endpoints[3], endpoints[4], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("PUT", isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[2], endpoints[4], }, e.Endpoints.ToArray()); + }); + } + [Fact] // See explanation in GetEdges for how this case is different public void GetEdges_GroupsByHttpMethod_CreatesHttp405Endpoint() { @@ -111,9 +183,9 @@ namespace Microsoft.AspNetCore.Routing.Matchers { // These are arrange in an order that we won't actually see in a product scenario. It's done // this way so we can verify that ordering is preserved by GetEdges. - CreateEndpoint("/", new[] { "GET", }), - CreateEndpoint("/", new[] { "GET", "PUT", "POST" }), - CreateEndpoint("/", new[] { "PUT", "POST" }), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", })), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", "PUT", "POST" })), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "PUT", "POST" })), }; var policy = CreatePolicy(); @@ -126,32 +198,91 @@ namespace Microsoft.AspNetCore.Routing.Matchers edges.OrderBy(e => e.State), e => { - Assert.Equal(HttpMethodEndpointSelectorPolicy.AnyMethod, e.State); - Assert.Equal(HttpMethodEndpointSelectorPolicy.Http405EndpointDisplayName, e.Endpoints.Single().DisplayName); + Assert.Equal(new EdgeKey(AnyMethod, isCorsPreflightRequest: false), e.State); + Assert.Equal(Http405EndpointDisplayName, e.Endpoints.Single().DisplayName); }, e => { - Assert.Equal("GET", e.State); + Assert.Equal(new EdgeKey("GET", isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[0], endpoints[1], }, e.Endpoints.ToArray()); }, e => { - Assert.Equal("POST", e.State); + Assert.Equal(new EdgeKey("POST", isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[1], endpoints[2], }, e.Endpoints.ToArray()); }, e => { - Assert.Equal("PUT", e.State); + Assert.Equal(new EdgeKey("PUT", isCorsPreflightRequest: false), e.State); Assert.Equal(new[] { endpoints[1], endpoints[2], }, e.Endpoints.ToArray()); }); + + } + + [Fact] // See explanation in GetEdges for how this case is different + public void GetEdges_GroupsByHttpMethod_CreatesHttp405Endpoint_CORS() + { + // Arrange + var endpoints = new[] + { + // These are arrange in an order that we won't actually see in a product scenario. It's done + // this way so we can verify that ordering is preserved by GetEdges. + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", })), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "GET", "PUT", "POST" }, acceptCorsPreflight: true)), + CreateEndpoint("/", new HttpMethodMetadata(new[] { "PUT", "POST" })), + }; + + var policy = CreatePolicy(); + + // Act + var edges = policy.GetEdges(endpoints); + + // Assert + Assert.Collection( + edges.OrderBy(e => e.State), + e => + { + Assert.Equal(new EdgeKey(AnyMethod, isCorsPreflightRequest: false), e.State); + Assert.Equal(Http405EndpointDisplayName, e.Endpoints.Single().DisplayName); + }, + e => + { + Assert.Equal(new EdgeKey("GET", isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[0], endpoints[1], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("GET", isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[1], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("POST", isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[1], endpoints[2], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("POST", isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[1], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("PUT", isCorsPreflightRequest: false), e.State); + Assert.Equal(new[] { endpoints[1], endpoints[2], }, e.Endpoints.ToArray()); + }, + e => + { + Assert.Equal(new EdgeKey("PUT", isCorsPreflightRequest: true), e.State); + Assert.Equal(new[] { endpoints[1], }, e.Endpoints.ToArray()); + }); } - private static MatcherEndpoint CreateEndpoint(string template, string[] httpMethods) + private static MatcherEndpoint CreateEndpoint(string template, HttpMethodMetadata httpMethodMetadata) { var metadata = new List(); - if (httpMethods != null) + if (httpMethodMetadata != null) { - metadata.Add(new HttpMethodMetadata(httpMethods)); + metadata.Add(httpMethodMetadata); } return new MatcherEndpoint( @@ -163,9 +294,9 @@ namespace Microsoft.AspNetCore.Routing.Matchers $"test: {template}"); } - private static HttpMethodEndpointSelectorPolicy CreatePolicy() + private static HttpMethodMatcherPolicy CreatePolicy() { - return new HttpMethodEndpointSelectorPolicy(); + return new HttpMethodMatcherPolicy(); } } }