diff --git a/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs index d1f5d96f5f..2bb82e0933 100644 --- a/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs @@ -3,6 +3,7 @@ using System; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.AspNetCore.Routing.Matchers; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; @@ -33,6 +34,16 @@ namespace Microsoft.Extensions.DependencyInjection // services.TryAddSingleton(); + // + // Endpoint Selection + // + services.TryAddSingleton(); + services.TryAddSingleton(); + + // Will be cached by the EndpointSelector + services.TryAddEnumerable( + ServiceDescriptor.Transient()); + return services; } diff --git a/src/Microsoft.AspNetCore.Routing/DispatcherMiddleware.cs b/src/Microsoft.AspNetCore.Routing/DispatcherMiddleware.cs index 68328c49ec..cccebfe7b1 100644 --- a/src/Microsoft.AspNetCore.Routing/DispatcherMiddleware.cs +++ b/src/Microsoft.AspNetCore.Routing/DispatcherMiddleware.cs @@ -15,14 +15,14 @@ namespace Microsoft.AspNetCore.Routing { private readonly MatcherFactory _matcherFactory; private readonly ILogger _logger; - private readonly IOptions _options; + private readonly CompositeEndpointDataSource _endpointDataSource; private readonly RequestDelegate _next; private Task _initializationTask; public DispatcherMiddleware( MatcherFactory matcherFactory, - IOptions options, + CompositeEndpointDataSource endpointDataSource, ILogger logger, RequestDelegate next) { @@ -31,9 +31,9 @@ namespace Microsoft.AspNetCore.Routing throw new ArgumentNullException(nameof(matcherFactory)); } - if (options == null) + if (endpointDataSource == null) { - throw new ArgumentNullException(nameof(options)); + throw new ArgumentNullException(nameof(endpointDataSource)); } if (logger == null) @@ -47,7 +47,7 @@ namespace Microsoft.AspNetCore.Routing } _matcherFactory = matcherFactory; - _options = options; + _endpointDataSource = endpointDataSource; _logger = logger; _next = next; } @@ -94,8 +94,7 @@ namespace Microsoft.AspNetCore.Routing null) == null) { // This thread won the race, do the initialization. - var dataSource = new CompositeEndpointDataSource(_options.Value.DataSources); - var matcher = _matcherFactory.CreateMatcher(dataSource); + var matcher = _matcherFactory.CreateMatcher(_endpointDataSource); initializationTask.SetResult(matcher); } diff --git a/src/Microsoft.AspNetCore.Routing/EndpointConstraints/EndpointConstraintCache.cs b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/EndpointConstraintCache.cs new file mode 100644 index 0000000000..0728dcc9b6 --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/EndpointConstraintCache.cs @@ -0,0 +1,193 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Http; +using System.Collections.Concurrent; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; + +namespace Microsoft.AspNetCore.Routing.EndpointConstraints +{ + internal class EndpointConstraintCache + { + private readonly CompositeEndpointDataSource _dataSource; + private readonly IEndpointConstraintProvider[] _endpointConstraintProviders; + + private volatile InnerCache _currentCache; + + public EndpointConstraintCache( + CompositeEndpointDataSource dataSource, + IEnumerable endpointConstraintProviders) + { + _dataSource = dataSource; + _endpointConstraintProviders = endpointConstraintProviders.OrderBy(item => item.Order).ToArray(); + } + + private InnerCache CurrentCache + { + get + { + var current = _currentCache; + var endpointDescriptors = _dataSource.Endpoints; + + if (current == null) + { + current = new InnerCache(); + _currentCache = current; + } + + return current; + } + } + + public IReadOnlyList GetEndpointConstraints(HttpContext httpContext, Endpoint endpoint) + { + var cache = CurrentCache; + + if (cache.Entries.TryGetValue(endpoint, out var entry)) + { + return GetEndpointConstraintsFromEntry(entry, httpContext, endpoint); + } + + if (endpoint.Metadata == null || endpoint.Metadata.Count == 0) + { + return null; + } + + var items = endpoint.Metadata + .OfType() + .Select(m => new EndpointConstraintItem(m)) + .ToList(); + + ExecuteProviders(httpContext, endpoint, items); + + var endpointConstraints = ExtractEndpointConstraints(items); + + var allEndpointConstraintsCached = true; + for (var i = 0; i < items.Count; i++) + { + var item = items[i]; + if (!item.IsReusable) + { + item.Constraint = null; + allEndpointConstraintsCached = false; + } + } + + if (allEndpointConstraintsCached) + { + entry = new CacheEntry(endpointConstraints); + } + else + { + entry = new CacheEntry(items); + } + + cache.Entries.TryAdd(endpoint, entry); + return endpointConstraints; + } + + private IReadOnlyList GetEndpointConstraintsFromEntry(CacheEntry entry, HttpContext httpContext, Endpoint endpoint) + { + Debug.Assert(entry.EndpointConstraints != null || entry.Items != null); + + if (entry.EndpointConstraints != null) + { + return entry.EndpointConstraints; + } + + var items = new List(entry.Items.Count); + for (var i = 0; i < entry.Items.Count; i++) + { + var item = entry.Items[i]; + if (item.IsReusable) + { + items.Add(item); + } + else + { + items.Add(new EndpointConstraintItem(item.Metadata)); + } + } + + ExecuteProviders(httpContext, endpoint, items); + + return ExtractEndpointConstraints(items); + } + + private void ExecuteProviders(HttpContext httpContext, Endpoint endpoint, List items) + { + var context = new EndpointConstraintProviderContext(httpContext, endpoint, items); + + for (var i = 0; i < _endpointConstraintProviders.Length; i++) + { + _endpointConstraintProviders[i].OnProvidersExecuting(context); + } + + for (var i = _endpointConstraintProviders.Length - 1; i >= 0; i--) + { + _endpointConstraintProviders[i].OnProvidersExecuted(context); + } + } + + private IReadOnlyList ExtractEndpointConstraints(List items) + { + var count = 0; + for (var i = 0; i < items.Count; i++) + { + if (items[i].Constraint != null) + { + count++; + } + } + + if (count == 0) + { + return null; + } + + var endpointConstraints = new IEndpointConstraint[count]; + var endpointConstraintIndex = 0; + for (int i = 0; i < items.Count; i++) + { + var endpointConstraint = items[i].Constraint; + if (endpointConstraint != null) + { + endpointConstraints[endpointConstraintIndex++] = endpointConstraint; + } + } + + return endpointConstraints; + } + + private class InnerCache + { + public InnerCache() + { + } + + public ConcurrentDictionary Entries { get; } = + new ConcurrentDictionary(); + } + + private struct CacheEntry + { + public CacheEntry(IReadOnlyList endpointConstraints) + { + EndpointConstraints = endpointConstraints; + Items = null; + } + + public CacheEntry(List items) + { + Items = items; + EndpointConstraints = null; + } + + public IReadOnlyList EndpointConstraints { get; } + + public List Items { get; } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Routing/EndpointConstraints/EndpointSelector.cs b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/EndpointSelector.cs new file mode 100644 index 0000000000..4ade56e671 --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/EndpointSelector.cs @@ -0,0 +1,225 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.Matchers; +using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq; +using System.Threading; + +namespace Microsoft.AspNetCore.Routing.EndpointConstraints +{ + internal class EndpointSelector + { + private static readonly IReadOnlyList EmptyEndpoints = Array.Empty(); + + private readonly CompositeEndpointDataSource _dataSource; + private readonly EndpointConstraintCache _endpointConstraintCache; + private readonly ILogger _logger; + + public EndpointSelector( + CompositeEndpointDataSource dataSource, + EndpointConstraintCache endpointConstraintCache, + ILoggerFactory loggerFactory) + { + _dataSource = dataSource; + _logger = loggerFactory.CreateLogger(); + _endpointConstraintCache = endpointConstraintCache; + } + + public Endpoint SelectBestCandidate(HttpContext context, IReadOnlyList candidates) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + if (candidates == null) + { + throw new ArgumentNullException(nameof(candidates)); + } + + var finalMatches = EvaluateEndpointConstraints(context, candidates); + + if (finalMatches == null || finalMatches.Count == 0) + { + return null; + } + else if (finalMatches.Count == 1) + { + var selectedEndpoint = finalMatches[0]; + + return selectedEndpoint; + } + else + { + var endpointNames = string.Join( + Environment.NewLine, + finalMatches.Select(a => a.DisplayName)); + + Log.MatchAmbiguous(_logger, context, finalMatches); + + var message = Resources.FormatAmbiguousEndpoints( + Environment.NewLine, + string.Join(Environment.NewLine, endpointNames)); + + throw new AmbiguousMatchException(message); + } + } + + private IReadOnlyList EvaluateEndpointConstraints( + HttpContext context, + IReadOnlyList endpoints) + { + var candidates = new List(); + + // Perf: Avoid allocations + for (var i = 0; i < endpoints.Count; i++) + { + var endpoint = endpoints[i]; + var constraints = _endpointConstraintCache.GetEndpointConstraints(context, endpoint); + candidates.Add(new EndpointSelectorCandidate(endpoint, constraints)); + } + + var matches = EvaluateEndpointConstraintsCore(context, candidates, startingOrder: null); + + List results = null; + if (matches != null) + { + results = new List(matches.Count); + // Perf: Avoid allocations + for (var i = 0; i < matches.Count; i++) + { + var candidate = matches[i]; + results.Add(candidate.Endpoint); + } + } + + return results; + } + + private IReadOnlyList EvaluateEndpointConstraintsCore( + HttpContext context, + IReadOnlyList candidates, + int? startingOrder) + { + // Find the next group of constraints to process. This will be the lowest value of + // order that is higher than startingOrder. + int? order = null; + + // Perf: Avoid allocations + for (var i = 0; i < candidates.Count; i++) + { + var candidate = candidates[i]; + if (candidate.Constraints != null) + { + for (var j = 0; j < candidate.Constraints.Count; j++) + { + var constraint = candidate.Constraints[j]; + if ((startingOrder == null || constraint.Order > startingOrder) && + (order == null || constraint.Order < order)) + { + order = constraint.Order; + } + } + } + } + + // If we don't find a next then there's nothing left to do. + if (order == null) + { + return candidates; + } + + // Since we have a constraint to process, bisect the set of endpoints into those with and without a + // constraint for the current order. + var endpointsWithConstraint = new List(); + var endpointsWithoutConstraint = new List(); + + var constraintContext = new EndpointConstraintContext(); + constraintContext.Candidates = candidates; + constraintContext.HttpContext = context; + + // Perf: Avoid allocations + for (var i = 0; i < candidates.Count; i++) + { + var candidate = candidates[i]; + var isMatch = true; + var foundMatchingConstraint = false; + + if (candidate.Constraints != null) + { + constraintContext.CurrentCandidate = candidate; + for (var j = 0; j < candidate.Constraints.Count; j++) + { + var constraint = candidate.Constraints[j]; + if (constraint.Order == order) + { + foundMatchingConstraint = true; + + if (!constraint.Accept(constraintContext)) + { + isMatch = false; + //_logger.ConstraintMismatch( + // candidate.Endpoint.DisplayName, + // candidate.Endpoint.Id, + // constraint); + break; + } + } + } + } + + if (isMatch && foundMatchingConstraint) + { + endpointsWithConstraint.Add(candidate); + } + else if (isMatch) + { + endpointsWithoutConstraint.Add(candidate); + } + } + + // If we have matches with constraints, those are better so try to keep processing those + if (endpointsWithConstraint.Count > 0) + { + var matches = EvaluateEndpointConstraintsCore(context, endpointsWithConstraint, order); + if (matches?.Count > 0) + { + return matches; + } + } + + // If the set of matches with constraints can't work, then process the set without constraints. + if (endpointsWithoutConstraint.Count == 0) + { + return null; + } + else + { + return EvaluateEndpointConstraintsCore(context, endpointsWithoutConstraint, order); + } + } + + private static class Log + { + private static readonly Action, Exception> _matchAmbiguous = LoggerMessage.Define>( + LogLevel.Error, + new EventId(1, "MatchAmbiguous"), + "Request matched multiple endpoints for request path '{Path}'. Matching endpoints: {AmbiguousEndpoints}"); + + public static void MatchAmbiguous(ILogger logger, HttpContext httpContext, IEnumerable endpoints) + { + if (logger.IsEnabled(LogLevel.Error)) + { + _matchAmbiguous(logger, httpContext.Request.Path, endpoints.Select(e => e.DisplayName), null); + } + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs new file mode 100644 index 0000000000..9c890732bc --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/HttpMethodEndpointConstraint.cs @@ -0,0 +1,70 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Collections.ObjectModel; + +namespace Microsoft.AspNetCore.Routing.EndpointConstraints +{ + public class HttpMethodEndpointConstraint : IEndpointConstraint + { + public static readonly int HttpMethodConstraintOrder = 100; + + private readonly IReadOnlyList _httpMethods; + + // Empty collection means any method will be accepted. + public HttpMethodEndpointConstraint(IEnumerable httpMethods) + { + if (httpMethods == null) + { + throw new ArgumentNullException(nameof(httpMethods)); + } + + var methods = new List(); + + foreach (var method in httpMethods) + { + if (string.IsNullOrEmpty(method)) + { + throw new ArgumentException("httpMethod cannot be null or empty"); + } + + methods.Add(method); + } + + _httpMethods = new ReadOnlyCollection(methods); + } + + public IEnumerable HttpMethods => _httpMethods; + + public int Order => HttpMethodConstraintOrder; + + public virtual bool Accept(EndpointConstraintContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + if (_httpMethods.Count == 0) + { + return true; + } + + var request = context.HttpContext.Request; + var method = request.Method; + + for (var i = 0; i < _httpMethods.Count; i++) + { + var supportedMethod = _httpMethods[i]; + if (string.Equals(supportedMethod, method, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + return false; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Routing/EndpointConstraints/IEndpointConstraint.cs b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/IEndpointConstraint.cs new file mode 100644 index 0000000000..ace843104d --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/EndpointConstraints/IEndpointConstraint.cs @@ -0,0 +1,165 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Http; +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Routing.EndpointConstraints +{ + public class EndpointConstraintContext + { + public IReadOnlyList Candidates { get; set; } + + public EndpointSelectorCandidate CurrentCandidate { get; set; } + + public HttpContext HttpContext { get; set; } + } + + public interface IEndpointConstraint : IEndpointConstraintMetadata + { + int Order { get; } + + bool Accept(EndpointConstraintContext context); + } + + public interface IEndpointConstraintMetadata + { + } + + public struct EndpointSelectorCandidate + { + public EndpointSelectorCandidate(Endpoint endpoint, IReadOnlyList constraints) + { + if (endpoint == null) + { + throw new ArgumentNullException(nameof(endpoint)); + } + + Endpoint = endpoint; + Constraints = constraints; + } + + public Endpoint Endpoint { get; } + + public IReadOnlyList Constraints { get; } + } + + public class EndpointConstraintItem + { + public EndpointConstraintItem(IEndpointConstraintMetadata metadata) + { + if (metadata == null) + { + throw new ArgumentNullException(nameof(metadata)); + } + + Metadata = metadata; + } + + public IEndpointConstraint Constraint { get; set; } + + public IEndpointConstraintMetadata Metadata { get; } + + public bool IsReusable { get; set; } + } + + public interface IEndpointConstraintProvider + { + int Order { get; } + + void OnProvidersExecuting(EndpointConstraintProviderContext context); + + void OnProvidersExecuted(EndpointConstraintProviderContext context); + } + + public class EndpointConstraintProviderContext + { + public EndpointConstraintProviderContext( + HttpContext context, + Endpoint endpoint, + IList items) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + if (endpoint == null) + { + throw new ArgumentNullException(nameof(endpoint)); + } + + if (items == null) + { + throw new ArgumentNullException(nameof(items)); + } + + HttpContext = context; + Endpoint = endpoint; + Results = items; + } + + public HttpContext HttpContext { get; } + + public Endpoint Endpoint { get; } + + public IList Results { get; } + } + + public class DefaultEndpointConstraintProvider : IEndpointConstraintProvider + { + /// + public int Order => -1000; + + /// + public void OnProvidersExecuting(EndpointConstraintProviderContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + for (var i = 0; i < context.Results.Count; i++) + { + ProvideConstraint(context.Results[i], context.HttpContext.RequestServices); + } + } + + /// + public void OnProvidersExecuted(EndpointConstraintProviderContext context) + { + } + + private void ProvideConstraint(EndpointConstraintItem item, IServiceProvider services) + { + // Don't overwrite anything that was done by a previous provider. + if (item.Constraint != null) + { + return; + } + + if (item.Metadata is IEndpointConstraint constraint) + { + item.Constraint = constraint; + item.IsReusable = true; + return; + } + + if (item.Metadata is IEndpointConstraintFactory factory) + { + item.Constraint = factory.CreateInstance(services); + item.IsReusable = factory.IsReusable; + return; + } + } + } + + public interface IEndpointConstraintFactory : IEndpointConstraintMetadata + { + bool IsReusable { get; } + + IEndpointConstraint CreateInstance(IServiceProvider services); + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcher.cs b/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcher.cs index 26923cd1c9..1e93f17d02 100644 --- a/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcher.cs +++ b/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcher.cs @@ -6,6 +6,7 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.AspNetCore.Routing.Internal; using Microsoft.AspNetCore.Routing.Template; using Microsoft.AspNetCore.Routing.Tree; @@ -18,12 +19,14 @@ namespace Microsoft.AspNetCore.Routing.Matchers { private readonly IInlineConstraintResolver _constraintFactory; private readonly ILogger _logger; + private readonly EndpointSelector _endpointSelector; private readonly DataSourceDependantCache _cache; public TreeMatcher( IInlineConstraintResolver constraintFactory, ILogger logger, - EndpointDataSource dataSource) + EndpointDataSource dataSource, + EndpointSelector endpointSelector) { if (constraintFactory == null) { @@ -42,6 +45,7 @@ namespace Microsoft.AspNetCore.Routing.Matchers _constraintFactory = constraintFactory; _logger = logger; + _endpointSelector = endpointSelector; _cache = new DataSourceDependantCache(dataSource, CreateTrees); _cache.EnsureInitialized(); } @@ -137,6 +141,8 @@ namespace Microsoft.AspNetCore.Routing.Matchers private Task SelectEndpointAsync(HttpContext httpContext, IEndpointFeature feature, IReadOnlyList endpoints) { + var bestEndpoint = _endpointSelector.SelectBestCandidate(httpContext, endpoints); + // REVIEW: Note that this code doesn't do anything significant now. This will eventually incorporate something like IActionConstraint switch (endpoints.Count) { diff --git a/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcherFactory.cs b/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcherFactory.cs index 1e90b82a14..2d4ef3d864 100644 --- a/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcherFactory.cs +++ b/src/Microsoft.AspNetCore.Routing/Matchers/TreeMatcherFactory.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Routing.Matchers @@ -10,8 +11,12 @@ namespace Microsoft.AspNetCore.Routing.Matchers { private readonly IInlineConstraintResolver _constraintFactory; private readonly ILogger _logger; + private readonly EndpointSelector _endpointSelector; - public TreeMatcherFactory(IInlineConstraintResolver constraintFactory, ILogger logger) + public TreeMatcherFactory( + IInlineConstraintResolver constraintFactory, + ILogger logger, + EndpointSelector endpointSelector) { if (constraintFactory == null) { @@ -23,8 +28,14 @@ namespace Microsoft.AspNetCore.Routing.Matchers throw new ArgumentNullException(nameof(logger)); } + if (endpointSelector == null) + { + throw new ArgumentNullException(nameof(endpointSelector)); + } + _constraintFactory = constraintFactory; _logger = logger; + _endpointSelector = endpointSelector; } public override Matcher CreateMatcher(EndpointDataSource dataSource) @@ -34,7 +45,7 @@ namespace Microsoft.AspNetCore.Routing.Matchers throw new ArgumentNullException(nameof(dataSource)); } - return new TreeMatcher(_constraintFactory, _logger, dataSource); + return new TreeMatcher(_constraintFactory, _logger, dataSource, _endpointSelector); } } } diff --git a/test/Microsoft.AspNetCore.Routing.Tests/DispatcherMiddlewareTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/DispatcherMiddlewareTest.cs index 7d93f42951..ad0855f4c8 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/DispatcherMiddlewareTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/DispatcherMiddlewareTest.cs @@ -24,12 +24,7 @@ namespace Microsoft.AspNetCore.Routing var httpContext = new DefaultHttpContext(); httpContext.RequestServices = new TestServiceProvider(); - RequestDelegate next = (c) => Task.FromResult(null); - - var logger = new Logger(NullLoggerFactory.Instance); - var options = Options.Create(new DispatcherOptions()); - var matcherFactory = new TestMatcherFactory(false); - var middleware = new DispatcherMiddleware(matcherFactory, options, logger, next); + var middleware = CreateMiddleware(); // Act await middleware.Invoke(httpContext); @@ -53,12 +48,8 @@ namespace Microsoft.AspNetCore.Routing var httpContext = new DefaultHttpContext(); httpContext.RequestServices = new TestServiceProvider(); - RequestDelegate next = (c) => Task.FromResult(null); - var logger = new Logger(loggerFactory); - var options = Options.Create(new DispatcherOptions()); - var matcherFactory = new TestMatcherFactory(true); - var middleware = new DispatcherMiddleware(matcherFactory, options, logger, next); + var middleware = CreateMiddleware(logger); // Act await middleware.Invoke(httpContext); @@ -68,5 +59,22 @@ namespace Microsoft.AspNetCore.Routing var write = Assert.Single(sink.Writes); Assert.Equal(expectedMessage, write.State?.ToString()); } + + private DispatcherMiddleware CreateMiddleware(Logger logger = null) + { + RequestDelegate next = (c) => Task.FromResult(null); + + logger = logger ?? new Logger(NullLoggerFactory.Instance); + + var options = Options.Create(new DispatcherOptions()); + var matcherFactory = new TestMatcherFactory(true); + var middleware = new DispatcherMiddleware( + matcherFactory, + new CompositeEndpointDataSource(Array.Empty()), + logger, + next); + + return middleware; + } } } diff --git a/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointSelectorTests.cs b/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointSelectorTests.cs new file mode 100644 index 0000000000..5fea074c4c --- /dev/null +++ b/test/Microsoft.AspNetCore.Routing.Tests/EndpointConstraints/EndpointSelectorTests.cs @@ -0,0 +1,438 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.Matchers; +using Microsoft.AspNetCore.Routing.TestObjects; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging.Testing; +using Moq; +using System; +using System.Collections.Generic; +using System.Linq; +using Xunit; + +namespace Microsoft.AspNetCore.Routing.EndpointConstraints +{ + public class EndpointSelectorTests + { + [Fact] + public void SelectBestCandidate_MultipleEndpoints_BestMatchSelected() + { + // Arrange + var defaultEndpoint = new TestEndpoint( + EndpointMetadataCollection.Empty, + "No constraint endpoint"); + + var postEndpoint = new TestEndpoint( + new EndpointMetadataCollection(new object[] { new HttpMethodEndpointConstraint(new[] { "POST" }) }), + "POST constraint endpoint"); + + var endpoints = new Endpoint[] + { + defaultEndpoint, + postEndpoint + }; + + var endpointSelector = CreateSelector(endpoints); + + var httpContext = new DefaultHttpContext(); + httpContext.Request.Method = "POST"; + + // Act + var bestCandidateEndpoint = endpointSelector.SelectBestCandidate(httpContext, endpoints); + + // Assert + Assert.NotNull(postEndpoint); + } + + [Fact] + public void SelectBestCandidate_MultipleEndpoints_AmbiguousMatchExceptionThrown() + { + // Arrange + var expectedMessage = + "The request matched multiple endpoints. Matches: " + Environment.NewLine + + Environment.NewLine + + "Ambiguous1" + Environment.NewLine + + "Ambiguous2"; + + var defaultEndpoint1 = new TestEndpoint( + EndpointMetadataCollection.Empty, + "Ambiguous1"); + + var defaultEndpoint2 = new TestEndpoint( + EndpointMetadataCollection.Empty, + "Ambiguous2"); + + var endpoints = new Endpoint[] + { + defaultEndpoint1, + defaultEndpoint2 + }; + + var endpointSelector = CreateSelector(endpoints); + + var httpContext = new DefaultHttpContext(); + httpContext.Request.Method = "POST"; + + // Act + var ex = Assert.ThrowsAny(() => + { + endpointSelector.SelectBestCandidate(httpContext, endpoints); + }); + + // Assert + Assert.Equal(expectedMessage, ex.Message); + } + + [Fact] + public void SelectBestCandidate_AmbiguousEndpoints_LogIsCorrect() + { + // Arrange + var sink = new TestSink(); + var loggerFactory = new TestLoggerFactory(sink, enabled: true); + + var actions = new Endpoint[] + { + new TestEndpoint(EndpointMetadataCollection.Empty, "A1"), + new TestEndpoint(EndpointMetadataCollection.Empty, "A2"), + }; + var selector = CreateSelector(actions, loggerFactory); + + var httpContext = CreateHttpContext("POST"); + var actionNames = string.Join(", ", actions.Select(action => action.DisplayName)); + var expectedMessage = $"Request matched multiple endpoints for request path '/test'. Matching endpoints: {actionNames}"; + + // Act + Assert.Throws(() => { selector.SelectBestCandidate(httpContext, actions); }); + + // Assert + Assert.Empty(sink.Scopes); + var write = Assert.Single(sink.Writes); + Assert.Equal(expectedMessage, write.State?.ToString()); + } + + [Fact] + public void SelectBestCandidate_PrefersEndpointWithConstraints() + { + // Arrange + var actionWithConstraints = new TestEndpoint( + new EndpointMetadataCollection(new[] { new HttpMethodEndpointConstraint(new string[] { "POST" }) }), + "Has constraint"); + + var actionWithoutConstraints = new TestEndpoint(EndpointMetadataCollection.Empty, "No constraint"); + + var actions = new Endpoint[] { actionWithConstraints, actionWithoutConstraints }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, actionWithConstraints); + } + + [Fact] + public void SelectBestCandidate_ConstraintsRejectAll() + { + // Arrange + var action1 = new TestEndpoint(new EndpointMetadataCollection(new[] { new BooleanConstraint() { Pass = false, } }), "action1"); + + var action2 = new TestEndpoint(new EndpointMetadataCollection(new[] { new BooleanConstraint() { Pass = false, } }), "action2"); + + var actions = new Endpoint[] { action1, action2 }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Null(action); + } + + [Fact] + public void SelectBestCandidate_ConstraintsRejectAll_DifferentStages() + { + // Arrange + var action1 = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = false, Order = 0 }, + new BooleanConstraint() { Pass = true, Order = 1 }, + }), "action1"); + + var action2 = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 0 }, + new BooleanConstraint() { Pass = false, Order = 1 }, + }), "action2"); + + var actions = new Endpoint[] { action1, action2 }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Null(action); + } + + [Fact] + public void SelectBestCandidate_EndpointConstraintFactory() + { + // Arrange + var actionWithConstraints = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new ConstraintFactory() + { + Constraint = new BooleanConstraint() { Pass = true }, + }, + }), "actionWithConstraints"); + + var actionWithoutConstraints = new TestEndpoint(EndpointMetadataCollection.Empty, "actionWithoutConstraints"); + + var actions = new Endpoint[] { actionWithConstraints, actionWithoutConstraints }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, actionWithConstraints); + } + + [Fact] + public void SelectBestCandidate_EndpointConstraintFactory_ReturnsNull() + { + // Arrange + var nullConstraint = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new ConstraintFactory(), + }), "nullConstraint"); + + var actions = new Endpoint[] { nullConstraint }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, nullConstraint); + } + + // There's a custom constraint provider registered that only understands BooleanConstraintMarker + [Fact] + public void SelectBestCandidate_CustomProvider() + { + // Arrange + var actionWithConstraints = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraintMarker() { Pass = true }, + }), "actionWithConstraints"); + + var actionWithoutConstraints = new TestEndpoint(EndpointMetadataCollection.Empty, "actionWithoutConstraints"); + + var actions = new Endpoint[] { actionWithConstraints, actionWithoutConstraints, }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, actionWithConstraints); + } + + // Due to ordering of stages, the first action will be better. + [Fact] + public void SelectBestCandidate_ConstraintsInOrder() + { + // Arrange + var best = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 0, }, + }), "best"); + + var worst = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 1, }, + }), "worst"); + + var actions = new Endpoint[] { best, worst }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, best); + } + + // Due to ordering of stages, the first action will be better. + [Fact] + public void SelectBestCandidate_ConstraintsInOrder_MultipleStages() + { + // Arrange + var best = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 0, }, + new BooleanConstraint() { Pass = true, Order = 1, }, + new BooleanConstraint() { Pass = true, Order = 2, }, + }), "best"); + + var worst = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 0, }, + new BooleanConstraint() { Pass = true, Order = 1, }, + new BooleanConstraint() { Pass = true, Order = 3, }, + }), "worst"); + + var actions = new Endpoint[] { best, worst }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, best); + } + + [Fact] + public void SelectBestCandidate_Fallback_ToEndpointWithoutConstraints() + { + // Arrange + var nomatch1 = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 0, }, + new BooleanConstraint() { Pass = true, Order = 1, }, + new BooleanConstraint() { Pass = false, Order = 2, }, + }), "nomatch1"); + + var nomatch2 = new TestEndpoint(new EndpointMetadataCollection(new[] + { + new BooleanConstraint() { Pass = true, Order = 0, }, + new BooleanConstraint() { Pass = true, Order = 1, }, + new BooleanConstraint() { Pass = false, Order = 3, }, + }), "nomatch2"); + + var best = new TestEndpoint(EndpointMetadataCollection.Empty, "best"); + + var actions = new Endpoint[] { best, nomatch1, nomatch2 }; + + var selector = CreateSelector(actions); + var context = CreateHttpContext("POST"); + + // Act + var action = selector.SelectBestCandidate(context, actions); + + // Assert + Assert.Same(action, best); + } + + private static EndpointSelector CreateSelector(IReadOnlyList actions, ILoggerFactory loggerFactory = null) + { + loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; + + var endpointDataSource = new CompositeEndpointDataSource(new[] { new DefaultEndpointDataSource(actions) }); + + var actionConstraintProviders = new IEndpointConstraintProvider[] { + new DefaultEndpointConstraintProvider(), + new BooleanConstraintProvider(), + }; + + return new EndpointSelector( + endpointDataSource, + GetEndpointConstraintCache(actionConstraintProviders), + loggerFactory); + } + + private static HttpContext CreateHttpContext(string httpMethod) + { + var serviceProvider = new ServiceCollection().BuildServiceProvider(); + + var httpContext = new Mock(MockBehavior.Strict); + + var request = new Mock(MockBehavior.Strict); + request.SetupGet(r => r.Method).Returns(httpMethod); + request.SetupGet(r => r.Path).Returns(new PathString("/test")); + request.SetupGet(r => r.Headers).Returns(new HeaderDictionary()); + httpContext.SetupGet(c => c.Request).Returns(request.Object); + httpContext.SetupGet(c => c.RequestServices).Returns(serviceProvider); + + return httpContext.Object; + } + + private static EndpointConstraintCache GetEndpointConstraintCache(IEndpointConstraintProvider[] actionConstraintProviders = null) + { + return new EndpointConstraintCache( + new CompositeEndpointDataSource(Array.Empty()), + actionConstraintProviders.AsEnumerable() ?? new List()); + } + + private class BooleanConstraint : IEndpointConstraint + { + public bool Pass { get; set; } + + public int Order { get; set; } + + public bool Accept(EndpointConstraintContext context) + { + return Pass; + } + } + + private class ConstraintFactory : IEndpointConstraintFactory + { + public IEndpointConstraint Constraint { get; set; } + + public bool IsReusable => true; + + public IEndpointConstraint CreateInstance(IServiceProvider services) + { + return Constraint; + } + } + + private class BooleanConstraintMarker : IEndpointConstraintMetadata + { + public bool Pass { get; set; } + } + + private class BooleanConstraintProvider : IEndpointConstraintProvider + { + public int Order { get; set; } + + public void OnProvidersExecuting(EndpointConstraintProviderContext context) + { + foreach (var item in context.Results) + { + if (item.Metadata is BooleanConstraintMarker marker) + { + Assert.Null(item.Constraint); + item.Constraint = new BooleanConstraint() { Pass = marker.Pass }; + } + } + } + + public void OnProvidersExecuted(EndpointConstraintProviderContext context) + { + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Internal/HttpMethodEndpointConstraintTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/Internal/HttpMethodEndpointConstraintTest.cs new file mode 100644 index 0000000000..81e35c4535 --- /dev/null +++ b/test/Microsoft.AspNetCore.Routing.Tests/Internal/HttpMethodEndpointConstraintTest.cs @@ -0,0 +1,87 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.EndpointConstraints; +using Microsoft.AspNetCore.Routing.TestObjects; +using Microsoft.Extensions.Primitives; +using System; +using System.Collections.Generic; +using System.Text; +using Xunit; + +namespace Microsoft.AspNetCore.Routing.Internal +{ + public class HttpMethodEndpointConstraintTest + { + public static TheoryData AcceptCaseInsensitiveData = + new TheoryData, string> + { + { new string[] { "get", "Get", "GET", "GEt"}, "gEt" }, + { new string[] { "POST", "PoSt", "GEt"}, "GET" }, + { new string[] { "get" }, "get" }, + { new string[] { "post" }, "POST" }, + { new string[] { "gEt" }, "get" }, + { new string[] { "get", "PoST" }, "pOSt" } + }; + + [Theory] + [MemberData(nameof(AcceptCaseInsensitiveData))] + public void HttpMethodEndpointConstraint_IgnoresPreflightRequests(IEnumerable httpMethods, string accessControlMethod) + { + // Arrange + var constraint = new HttpMethodEndpointConstraint(httpMethods); + var context = CreateEndpointConstraintContext(constraint); + context.HttpContext = CreateHttpContext("oPtIoNs", accessControlMethod); + + // Act + var result = constraint.Accept(context); + + // Assert + Assert.False(result, "Request should have been rejected."); + } + + [Theory] + [MemberData(nameof(AcceptCaseInsensitiveData))] + public void HttpMethodEndpointConstraint_Accept_CaseInsensitive(IEnumerable httpMethods, string expectedMethod) + { + // Arrange + var constraint = new HttpMethodEndpointConstraint(httpMethods); + var context = CreateEndpointConstraintContext(constraint); + context.HttpContext = CreateHttpContext(expectedMethod); + + // Act + var result = constraint.Accept(context); + + // Assert + Assert.True(result, "Request should have been accepted."); + } + + private static EndpointConstraintContext CreateEndpointConstraintContext(HttpMethodEndpointConstraint constraint) + { + var context = new EndpointConstraintContext(); + + var endpointSelectorCandidate = new EndpointSelectorCandidate(new TestEndpoint(EndpointMetadataCollection.Empty, string.Empty), new List { constraint }); + + context.Candidates = new List { endpointSelectorCandidate }; + context.CurrentCandidate = context.Candidates[0]; + + return context; + } + + private static HttpContext CreateHttpContext(string requestedMethod, string accessControlMethod = null) + { + var httpContext = new DefaultHttpContext(); + + httpContext.Request.Method = requestedMethod; + + if (accessControlMethod != null) + { + httpContext.Request.Headers.Add("Origin", StringValues.Empty); + httpContext.Request.Headers.Add("Access-Control-Request-Method", accessControlMethod); + } + + return httpContext; + } + } +} diff --git a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/TreeMatcherTests.cs b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/TreeMatcherTests.cs index 68072aaaea..8bebe119c8 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/Matchers/TreeMatcherTests.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/Matchers/TreeMatcherTests.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Options; using System; @@ -21,16 +22,20 @@ namespace Microsoft.AspNetCore.Routing.Matchers private TreeMatcher CreateTreeMatcher(EndpointDataSource endpointDataSource) { + var compositeDataSource = new CompositeEndpointDataSource(new[] { endpointDataSource }); var defaultInlineConstraintResolver = new DefaultInlineConstraintResolver(Options.Create(new RouteOptions())); - return new TreeMatcher(defaultInlineConstraintResolver, NullLogger.Instance, endpointDataSource); + var endpointSelector = new EndpointSelector( + compositeDataSource, + new EndpointConstraintCache(compositeDataSource, new IEndpointConstraintProvider[] { new DefaultEndpointConstraintProvider() }), + NullLoggerFactory.Instance); + + return new TreeMatcher(defaultInlineConstraintResolver, NullLogger.Instance, endpointDataSource, endpointSelector); } [Fact] public async Task MatchAsync_DuplicateTemplatesAndDifferentOrder_LowerOrderEndpointMatched() { // Arrange - var defaultInlineConstraintResolver = new DefaultInlineConstraintResolver(Options.Create(new RouteOptions())); - var higherOrderEndpoint = CreateEndpoint("/Teams", 1); var lowerOrderEndpoint = CreateEndpoint("/Teams", 0);