diff --git a/src/Microsoft.AspNetCore.Routing/Matching/DefaultEndpointSelector.cs b/src/Microsoft.AspNetCore.Routing/Matching/DefaultEndpointSelector.cs index 56779c75e0..e503e2f769 100644 --- a/src/Microsoft.AspNetCore.Routing/Matching/DefaultEndpointSelector.cs +++ b/src/Microsoft.AspNetCore.Routing/Matching/DefaultEndpointSelector.cs @@ -24,16 +24,30 @@ namespace Microsoft.AspNetCore.Routing.Matching _selectorPolicies = matcherPolicies.OrderBy(p => p.Order).OfType().ToArray(); } - public override Task SelectAsync( + public override async Task SelectAsync( HttpContext httpContext, EndpointFeature feature, CandidateSet candidateSet) { + var selectorPolicies = _selectorPolicies; for (var i = 0; i < _selectorPolicies.Length; i++) { - _selectorPolicies[i].Apply(httpContext, candidateSet); + await selectorPolicies[i].ApplyAsync(httpContext, feature, candidateSet); + if (feature.Endpoint != null) + { + // This is a short circuit, the selector chose an endpoint. + return; + } } + ProcessFinalCandidates(httpContext, feature, candidateSet); + } + + private static void ProcessFinalCandidates( + HttpContext httpContext, + EndpointFeature feature, + CandidateSet candidateSet) + { RouteEndpoint endpoint = null; RouteValueDictionary values = null; int? foundScore = null; @@ -76,8 +90,6 @@ namespace Microsoft.AspNetCore.Routing.Matching feature.Endpoint = endpoint; feature.RouteValues = values; } - - return Task.CompletedTask; } private static void ReportAmbiguity(CandidateSet candidates) diff --git a/src/Microsoft.AspNetCore.Routing/Matching/IEndpointSelectorPolicy.cs b/src/Microsoft.AspNetCore.Routing/Matching/IEndpointSelectorPolicy.cs index 7dfd04b7bd..1617822f03 100644 --- a/src/Microsoft.AspNetCore.Routing/Matching/IEndpointSelectorPolicy.cs +++ b/src/Microsoft.AspNetCore.Routing/Matching/IEndpointSelectorPolicy.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.Http; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Routing.Matching { @@ -19,12 +20,21 @@ namespace Microsoft.AspNetCore.Routing.Matching /// /// The associated with the current request. /// + /// + /// The associated with the current request. + /// /// The . /// + /// /// Implementations of should implement this method /// and filter the set of candidates in the by setting /// to false where desired. + /// + /// + /// To signal an error condition, set to an + /// value that will produce the desired error when executed. + /// /// - void Apply(HttpContext httpContext, CandidateSet candidates); + Task ApplyAsync(HttpContext httpContext, EndpointFeature feature, CandidateSet candidates); } } diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Matching/DefaultEndpointSelectorTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/Matching/DefaultEndpointSelectorTest.cs index c6ef7cd4f3..769ca3200c 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/Matching/DefaultEndpointSelectorTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/Matching/DefaultEndpointSelectorTest.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; @@ -184,10 +185,11 @@ test: /test3", ex.Message); var policy = new Mock(); policy .As() - .Setup(p => p.Apply(It.IsAny(), It.IsAny())) - .Callback((c, cs) => + .Setup(p => p.ApplyAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((c, f, cs) => { cs[1].IsValidCandidate = false; + return Task.CompletedTask; }); candidateSet[0].IsValidCandidate = false; @@ -204,6 +206,48 @@ test: /test3", ex.Message); Assert.Same(endpoints[2], feature.Endpoint); } + [Fact] + public async Task SelectAsync_RunsEndpointSelectorPolicies_CanShortCircuit() + { + // Arrange + var endpoints = new RouteEndpoint[] { CreateEndpoint("/test1"), CreateEndpoint("/test2"), CreateEndpoint("/test3"), }; + var scores = new int[] { 0, 0, 1 }; + var candidateSet = CreateCandidateSet(endpoints, scores); + + var policy1 = new Mock(); + policy1 + .As() + .Setup(p => p.ApplyAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns((c, f, cs) => + { + f.Endpoint = cs[0].Endpoint; + return Task.CompletedTask; + }); + + // This should never run, it's after policy1 which short circuits + var policy2 = new Mock(); + policy2 + .SetupGet(p => p.Order) + .Returns(1000); + policy2 + .As() + .Setup(p => p.ApplyAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Throws(new InvalidOperationException()); + + candidateSet[0].IsValidCandidate = false; + candidateSet[1].IsValidCandidate = true; + candidateSet[2].IsValidCandidate = true; + + var (httpContext, feature) = CreateContext(); + var selector = CreateSelector(policy1.Object, policy2.Object); + + // Act + await selector.SelectAsync(httpContext, feature, candidateSet); + + // Assert + Assert.Same(endpoints[0], feature.Endpoint); + } + private static (HttpContext httpContext, EndpointFeature feature) CreateContext() { var feature = new EndpointFeature();