From 41f26dc69de11eaca1ae216cb526880b182d9b0b Mon Sep 17 00:00:00 2001 From: Ryan Nowak Date: Thu, 21 Sep 2017 15:35:06 -0700 Subject: [PATCH] Add endpoint disambiguation - Better sample of metadata - Sample shows how conventional routing would work - Added endpoint disambiguation --- .../DispatcherSample/DispatcherSample.csproj | 1 + .../HttpMethodEndpointSelector.cs | 86 ++++++++++ .../DispatcherSample/HttpMethodMetadata.cs | 23 +++ .../IAuthorizationPolicyMetadata.cs | 20 +++ .../DispatcherSample/ICorsPolicyMetadata.cs | 20 +++ .../DispatcherSample/IHttpMethodMetadata.cs | 12 ++ samples/DispatcherSample/Program.cs | 3 + samples/DispatcherSample/Startup.cs | 136 +++++++++------ .../IDispatcherFeature.cs | 2 + .../DispatcherBase.cs | 70 +++++++- .../DispatcherFeature.cs | 29 +--- .../DispatcherMiddleware.cs | 13 +- .../DispatcherValueEndpointSelector.cs | 55 ++++++ .../EndpointMiddleware.cs | 24 ++- .../EndpointSelector.cs | 12 ++ .../EndpointSelectorContext.cs | 88 ++++++++++ .../IDispatcherValueSelectableEndpoint.cs | 10 ++ .../SimpleEndpoint.cs | 33 ++-- .../DispatcherValueCollectionExtensions.cs | 21 +++ .../Dispatcher/RouteTemplateDispatcher.cs | 157 ++++++++++++++++++ .../Dispatcher/RouteValuesEndpoint.cs | 59 ------- .../Dispatcher/RouterDispatcher.cs | 36 +++- .../Dispatcher/RouterEndpointSelector.cs | 110 ------------ 23 files changed, 753 insertions(+), 267 deletions(-) create mode 100644 samples/DispatcherSample/HttpMethodEndpointSelector.cs create mode 100644 samples/DispatcherSample/HttpMethodMetadata.cs create mode 100644 samples/DispatcherSample/IAuthorizationPolicyMetadata.cs create mode 100644 samples/DispatcherSample/ICorsPolicyMetadata.cs create mode 100644 samples/DispatcherSample/IHttpMethodMetadata.cs create mode 100644 src/Microsoft.AspNetCore.Dispatcher/DispatcherValueEndpointSelector.cs create mode 100644 src/Microsoft.AspNetCore.Dispatcher/EndpointSelector.cs create mode 100644 src/Microsoft.AspNetCore.Dispatcher/EndpointSelectorContext.cs create mode 100644 src/Microsoft.AspNetCore.Dispatcher/IDispatcherValueSelectableEndpoint.cs create mode 100644 src/Microsoft.AspNetCore.Routing/Dispatcher/DispatcherValueCollectionExtensions.cs create mode 100644 src/Microsoft.AspNetCore.Routing/Dispatcher/RouteTemplateDispatcher.cs delete mode 100644 src/Microsoft.AspNetCore.Routing/Dispatcher/RouteValuesEndpoint.cs delete mode 100644 src/Microsoft.AspNetCore.Routing/Dispatcher/RouterEndpointSelector.cs diff --git a/samples/DispatcherSample/DispatcherSample.csproj b/samples/DispatcherSample/DispatcherSample.csproj index 29bcf6c954..add200e6c6 100644 --- a/samples/DispatcherSample/DispatcherSample.csproj +++ b/samples/DispatcherSample/DispatcherSample.csproj @@ -11,6 +11,7 @@ + diff --git a/samples/DispatcherSample/HttpMethodEndpointSelector.cs b/samples/DispatcherSample/HttpMethodEndpointSelector.cs new file mode 100644 index 0000000000..16b43234df --- /dev/null +++ b/samples/DispatcherSample/HttpMethodEndpointSelector.cs @@ -0,0 +1,86 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Dispatcher; + +namespace DispatcherSample +{ + public class HttpMethodEndpointSelector : EndpointSelector + { + public override async Task SelectAsync(EndpointSelectorContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var snapshot = context.CreateSnapshot(); + + var fallback = new List(); + for (var i = context.Endpoints.Count - 1; i >= 0; i--) + { + var endpoint = context.Endpoints[i]; + IHttpMethodMetadata metadata = null; + + for (var j = endpoint.Metadata.Count - 1; j >= 0; j--) + { + metadata = endpoint.Metadata[j] as IHttpMethodMetadata; + if (metadata != null) + { + break; + } + } + + if (metadata == null) + { + // No metadata. + fallback.Add(endpoint); + context.Endpoints.RemoveAt(i); + } + else if (Matches(metadata, context.HttpContext.Request.Method)) + { + // Do thing, this one matches + } + else + { + // Not a match. + context.Endpoints.RemoveAt(i); + } + } + + // Now the list of endpoints only contains those that have an HTTP method preference AND match the current + // request. + await context.InvokeNextAsync(); + + if (context.Endpoints.Count == 0) + { + // Nothing matched, do the fallback. + context.RestoreSnapshot(snapshot); + context.Endpoints.Clear(); + + for (var i = 0; i < fallback.Count; i++) + { + context.Endpoints.Add(fallback[i]); + } + + await context.InvokeNextAsync(); + } + } + + private bool Matches(IHttpMethodMetadata metadata, string httpMethod) + { + for (var i = 0; i < metadata.AllowedMethods.Count; i++) + { + if (string.Equals(metadata.AllowedMethods[i], httpMethod, StringComparison.OrdinalIgnoreCase)) + { + return true; + } + } + + return false; + } + } +} diff --git a/samples/DispatcherSample/HttpMethodMetadata.cs b/samples/DispatcherSample/HttpMethodMetadata.cs new file mode 100644 index 0000000000..c2266ead8c --- /dev/null +++ b/samples/DispatcherSample/HttpMethodMetadata.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace DispatcherSample +{ + public class HttpMethodMetadata : IHttpMethodMetadata + { + public HttpMethodMetadata(string httpMethod) + { + if (httpMethod == null) + { + throw new ArgumentNullException(nameof(httpMethod)); + } + + AllowedMethods = new[] { httpMethod, }; + } + + public IReadOnlyList AllowedMethods { get; } + } +} diff --git a/samples/DispatcherSample/IAuthorizationPolicyMetadata.cs b/samples/DispatcherSample/IAuthorizationPolicyMetadata.cs new file mode 100644 index 0000000000..a758f06f10 --- /dev/null +++ b/samples/DispatcherSample/IAuthorizationPolicyMetadata.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace DispatcherSample +{ + public interface IAuthorizationPolicyMetadata + { + string Name { get; } + } + + public class AuthorizationPolicyMetadata : IAuthorizationPolicyMetadata + { + public AuthorizationPolicyMetadata(string name) + { + Name = name; + } + + public string Name { get; } + } +} diff --git a/samples/DispatcherSample/ICorsPolicyMetadata.cs b/samples/DispatcherSample/ICorsPolicyMetadata.cs new file mode 100644 index 0000000000..d2d54dc4f5 --- /dev/null +++ b/samples/DispatcherSample/ICorsPolicyMetadata.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace DispatcherSample +{ + public interface ICorsPolicyMetadata + { + string Name { get; } + } + + public class CorsPolicyMetadata : ICorsPolicyMetadata + { + public CorsPolicyMetadata(string name) + { + Name = name; + } + + public string Name { get; } + } +} diff --git a/samples/DispatcherSample/IHttpMethodMetadata.cs b/samples/DispatcherSample/IHttpMethodMetadata.cs new file mode 100644 index 0000000000..b12731d8df --- /dev/null +++ b/samples/DispatcherSample/IHttpMethodMetadata.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; + +namespace DispatcherSample +{ + public interface IHttpMethodMetadata + { + IReadOnlyList AllowedMethods { get; } + } +} diff --git a/samples/DispatcherSample/Program.cs b/samples/DispatcherSample/Program.cs index 6ff193bc8c..70b401a66b 100644 --- a/samples/DispatcherSample/Program.cs +++ b/samples/DispatcherSample/Program.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Console; namespace DispatcherSample { @@ -13,6 +15,7 @@ namespace DispatcherSample .UseIISIntegration() .UseKestrel() .UseStartup() + .ConfigureLogging((c, b) => b.AddProvider(new ConsoleLoggerProvider((category, level) => true, includeScopes: false))) .Build(); host.Run(); diff --git a/samples/DispatcherSample/Startup.cs b/samples/DispatcherSample/Startup.cs index a07f62da66..878133d1d8 100644 --- a/samples/DispatcherSample/Startup.cs +++ b/samples/DispatcherSample/Startup.cs @@ -3,6 +3,7 @@ using System; using System.Linq; +using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Dispatcher; using Microsoft.AspNetCore.Hosting; @@ -10,6 +11,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing.Dispatcher; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; namespace DispatcherSample @@ -26,45 +28,24 @@ namespace DispatcherSample { services.Configure(options => { - options.Dispatchers.Add(CreateDispatcher( - "{Endpoint=example}", - new RouteValuesEndpoint( - new RouteValueDictionary(new { Endpoint = "First" }), - async (context) => - { - await context.Response.WriteAsync("Hello from the example!"); - }, - Array.Empty(), - "example"), - new RouteValuesEndpoint( - new RouteValueDictionary(new { Endpoint = "Second" }), - async (context) => - { - await context.Response.WriteAsync("Hello from the second example!"); - }, - Array.Empty(), - "example2"))); + options.Dispatchers.Add(new RouteTemplateDispatcher("{controller=Home}/{action=Index}/{id?}", ConstraintResolver) + { + Endpoints = + { + new SimpleEndpoint(Home_Index, Array.Empty(), new { controller = "Home", action = "Index", }, "Home:Index()"), + new SimpleEndpoint(Home_About, Array.Empty(), new { controller = "Home", action = "About", }, "Home:About()"), + new SimpleEndpoint(Admin_Index, Array.Empty(), new { controller = "Admin", action = "Index", }, "Admin:Index()"), + new SimpleEndpoint(Admin_GetUsers, new object[] { new HttpMethodMetadata("GET"), new AuthorizationPolicyMetadata("Admin"), }, new { controller = "Admin", action = "Users", }, "Admin:GetUsers()"), + new SimpleEndpoint(Admin_EditUsers, new object[] { new HttpMethodMetadata("POST"), new AuthorizationPolicyMetadata("Admin"), }, new { controller = "Admin", action = "Users", }, "Admin:EditUsers()"), + }, + Selectors = + { + new DispatcherValueEndpointSelector(), + new HttpMethodEndpointSelector(), + } + }.InvokeAsync); - options.Dispatchers.Add(CreateDispatcher( - "{Endpoint=example}/{Parameter=foo}", - new RouteValuesEndpoint( - new RouteValueDictionary(new { Endpoint = "First", Parameter = "param1" }), - async (context) => - { - await context.Response.WriteAsync("Hello from the example for foo!"); - }, - Array.Empty(), - "example"), - new RouteValuesEndpoint( - new RouteValueDictionary(new { Endpoint = "Second", Parameter = "param2" }), - async (context) => - { - await context.Response.WriteAsync("Hello from the second example for foo!"); - }, - Array.Empty(), - "example2"))); - - options.HandlerFactories.Add((endpoint) => (endpoint as RouteValuesEndpoint)?.HandlerFactory); + options.HandlerFactories.Add((endpoint) => (endpoint as SimpleEndpoint)?.HandlerFactory); }); services.AddSingleton(); @@ -72,35 +53,86 @@ namespace DispatcherSample services.AddDispatcher(); } - public void Configure(IApplicationBuilder app, IHostingEnvironment env) + public void Configure(IApplicationBuilder app, IHostingEnvironment env, ILogger logger) { - app.Use(async (context, next) => - { - await context.Response.WriteAsync("

Middleware 1

"); - await next.Invoke(); - }); - app.UseDispatcher(); app.Use(async (context, next) => { - await context.Response.WriteAsync("

Middleware 2

"); + logger.LogInformation("Executing fake CORS middleware"); + + var feature = context.Features.Get(); + var policy = feature.Endpoint?.Metadata.OfType().LastOrDefault(); + logger.LogInformation("using CORS policy {PolicyName}", policy?.Name ?? "default"); + await next.Invoke(); }); app.Use(async (context, next) => { - var urlGenerator = app.ApplicationServices.GetService(); - var url = urlGenerator.GenerateURL(new RouteValueDictionary(new { Movie = "The Lion King", Character = "Mufasa" }), context); - await context.Response.WriteAsync($"

Generated url: {url}

"); + logger.LogInformation("Executing fake AuthZ middleware"); + + var feature = context.Features.Get(); + var policy = feature.Endpoint?.Metadata.OfType().LastOrDefault(); + if (policy != null) + { + logger.LogInformation("using Auth policy {PolicyName}", policy.Name); + } + await next.Invoke(); }); } - private static RequestDelegate CreateDispatcher(string routeTemplate, RouteValuesEndpoint endpoint, params RouteValuesEndpoint[] endpoints) + public static Task Home_Index(HttpContext httpContext) { - var dispatcher = new RouterDispatcher(new Route(new RouterEndpointSelector(new[] { endpoint }.Concat(endpoints)), routeTemplate, ConstraintResolver)); - return dispatcher.InvokeAsync; + var urlGenerator = httpContext.RequestServices.GetService(); + var url = urlGenerator.GenerateURL(new RouteValueDictionary(new { Movie = "The Lion King", Character = "Mufasa" }), httpContext); + return httpContext.Response.WriteAsync( + $"" + + $"" + + $"

Generated url: {url}

" + + $"" + + $""); + } + + public static Task Home_About(HttpContext httpContext) + { + return httpContext.Response.WriteAsync( + $"" + + $"" + + $"

This is a dispatcher sample.

" + + $"" + + $""); + } + + public static Task Admin_Index(HttpContext httpContext) + { + return httpContext.Response.WriteAsync( + $"" + + $"" + + $"

This is the admin page.

" + + $"" + + $""); + } + + public static Task Admin_GetUsers(HttpContext httpContext) + { + return httpContext.Response.WriteAsync( + $"" + + $"" + + $"

Users: rynowak, jbagga

" + + $"" + + $""); + } + + public static Task Admin_EditUsers(HttpContext httpContext) + { + return httpContext.Response.WriteAsync( + $"" + + $"" + + $"

blerp

" + + $"" + + $""); } } } diff --git a/src/Microsoft.AspNetCore.Dispatcher.Abstractions/IDispatcherFeature.cs b/src/Microsoft.AspNetCore.Dispatcher.Abstractions/IDispatcherFeature.cs index 86f7e07737..bbfb0975ef 100644 --- a/src/Microsoft.AspNetCore.Dispatcher.Abstractions/IDispatcherFeature.cs +++ b/src/Microsoft.AspNetCore.Dispatcher.Abstractions/IDispatcherFeature.cs @@ -10,5 +10,7 @@ namespace Microsoft.AspNetCore.Dispatcher Endpoint Endpoint { get; set; } RequestDelegate RequestDelegate { get; set; } + + DispatcherValueCollection Values { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Dispatcher/DispatcherBase.cs b/src/Microsoft.AspNetCore.Dispatcher/DispatcherBase.cs index 85423f0aef..61ecfb087f 100644 --- a/src/Microsoft.AspNetCore.Dispatcher/DispatcherBase.cs +++ b/src/Microsoft.AspNetCore.Dispatcher/DispatcherBase.cs @@ -1,6 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; +using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -8,6 +11,71 @@ namespace Microsoft.AspNetCore.Dispatcher { public abstract class DispatcherBase { - public abstract Task InvokeAsync(HttpContext httpContext); + private IList _endpoints; + private IList _endpointSelectors; + + public virtual IList Endpoints + { + get + { + if (_endpoints == null) + { + _endpoints = new List(); + } + + return _endpoints; + } + } + + public virtual IList Selectors + { + get + { + if (_endpointSelectors == null) + { + _endpointSelectors = new List(); + } + + return _endpointSelectors; + } + } + + public virtual async Task InvokeAsync(HttpContext httpContext) + { + if (httpContext == null) + { + throw new ArgumentNullException(nameof(httpContext)); + } + + var feature = httpContext.Features.Get(); + if (await TryMatchAsync(httpContext)) + { + if (feature.RequestDelegate != null) + { + // Short circuit, no need to select an endpoint. + return; + } + + var selectorContext = new EndpointSelectorContext(httpContext, Endpoints.ToList(), Selectors); + await selectorContext.InvokeNextAsync(); + + switch (selectorContext.Endpoints.Count) + { + case 0: + break; + + case 1: + + feature.Endpoint = selectorContext.Endpoints[0]; + break; + + default: + throw new InvalidOperationException("Ambiguous bro!"); + + } + } + } + + protected abstract Task TryMatchAsync(HttpContext httpContext); } } diff --git a/src/Microsoft.AspNetCore.Dispatcher/DispatcherFeature.cs b/src/Microsoft.AspNetCore.Dispatcher/DispatcherFeature.cs index a33f2f41f1..bfece9b91d 100644 --- a/src/Microsoft.AspNetCore.Dispatcher/DispatcherFeature.cs +++ b/src/Microsoft.AspNetCore.Dispatcher/DispatcherFeature.cs @@ -7,33 +7,10 @@ namespace Microsoft.AspNetCore.Dispatcher { public class DispatcherFeature : IDispatcherFeature { - private Endpoint _endpoint; - private RequestDelegate _next; + public Endpoint Endpoint { get; set; } - public Endpoint Endpoint - { - get - { - return _endpoint; - } + public RequestDelegate RequestDelegate { get; set; } - set - { - _endpoint = value; - } - } - - public RequestDelegate RequestDelegate - { - get - { - return _next; - } - - set - { - _next = value; - } - } + public DispatcherValueCollection Values { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Dispatcher/DispatcherMiddleware.cs b/src/Microsoft.AspNetCore.Dispatcher/DispatcherMiddleware.cs index b0eb305ff1..4261b36eaf 100644 --- a/src/Microsoft.AspNetCore.Dispatcher/DispatcherMiddleware.cs +++ b/src/Microsoft.AspNetCore.Dispatcher/DispatcherMiddleware.cs @@ -4,28 +4,36 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.Dispatcher { public class DispatcherMiddleware { + private readonly ILogger _logger; private readonly DispatcherOptions _options; private readonly RequestDelegate _next; - public DispatcherMiddleware(IOptions options, RequestDelegate next) + public DispatcherMiddleware(IOptions options, ILogger logger, RequestDelegate next) { if (options == null) { throw new ArgumentNullException(nameof(options)); } + if (logger == null) + { + throw new ArgumentNullException(nameof(logger)); + } + if (next == null) { throw new ArgumentNullException(nameof(next)); } - + _options = options.Value; + _logger = logger; _next = next; } @@ -39,6 +47,7 @@ namespace Microsoft.AspNetCore.Dispatcher await entry(httpContext); if (feature.Endpoint != null || feature.RequestDelegate != null) { + _logger.LogInformation("Matched endpoint {Endpoint}", feature.Endpoint.DisplayName); break; } } diff --git a/src/Microsoft.AspNetCore.Dispatcher/DispatcherValueEndpointSelector.cs b/src/Microsoft.AspNetCore.Dispatcher/DispatcherValueEndpointSelector.cs new file mode 100644 index 0000000000..a8377321f1 --- /dev/null +++ b/src/Microsoft.AspNetCore.Dispatcher/DispatcherValueEndpointSelector.cs @@ -0,0 +1,55 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Dispatcher +{ + public class DispatcherValueEndpointSelector : EndpointSelector + { + public override Task SelectAsync(EndpointSelectorContext context) + { + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + var dispatcherFeature = context.HttpContext.Features.Get(); + + for (var i = context.Endpoints.Count - 1; i >= 0; i--) + { + var endpoint = context.Endpoints[i] as IDispatcherValueSelectableEndpoint; + if (!CompareRouteValues(dispatcherFeature.Values, endpoint.Values)) + { + context.Endpoints.RemoveAt(i); + } + } + + return context.InvokeNextAsync(); + } + + private bool CompareRouteValues(DispatcherValueCollection values, DispatcherValueCollection requiredValues) + { + foreach (var kvp in requiredValues) + { + if (string.IsNullOrEmpty(kvp.Value.ToString())) + { + if (values.TryGetValue(kvp.Key, out var routeValue) && !string.IsNullOrEmpty(routeValue.ToString())) + { + return false; + } + } + else + { + if (!values.TryGetValue(kvp.Key, out var routeValue) || !string.Equals(kvp.Value.ToString(), routeValue.ToString(), StringComparison.OrdinalIgnoreCase)) + { + return false; + } + } + } + + return true; + } + } +} diff --git a/src/Microsoft.AspNetCore.Dispatcher/EndpointMiddleware.cs b/src/Microsoft.AspNetCore.Dispatcher/EndpointMiddleware.cs index 2b3b9b92f1..2c41fa0d7a 100644 --- a/src/Microsoft.AspNetCore.Dispatcher/EndpointMiddleware.cs +++ b/src/Microsoft.AspNetCore.Dispatcher/EndpointMiddleware.cs @@ -4,28 +4,36 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.Dispatcher { public class EndpointMiddleware { + private readonly ILogger _logger; private readonly DispatcherOptions _options; - private RequestDelegate _next; + private readonly RequestDelegate _next; - public EndpointMiddleware(IOptions options, RequestDelegate next) + public EndpointMiddleware(IOptions options, ILogger logger, RequestDelegate next) { if (options == null) { throw new ArgumentNullException(nameof(options)); } + if (logger == null) + { + throw new ArgumentNullException(nameof(logger)); + } + if (next == null) { throw new ArgumentNullException(nameof(next)); } _options = options.Value; + _logger = logger; _next = next; } @@ -47,7 +55,17 @@ namespace Microsoft.AspNetCore.Dispatcher if (feature.RequestDelegate != null) { - await feature.RequestDelegate(context); + _logger.LogInformation("Executing endpoint {Endpoint}", feature.Endpoint.DisplayName); + try + { + await feature.RequestDelegate(context); + } + finally + { + _logger.LogInformation("Executed endpoint {Endpoint}", feature.Endpoint.DisplayName); + } + + return; } await _next(context); diff --git a/src/Microsoft.AspNetCore.Dispatcher/EndpointSelector.cs b/src/Microsoft.AspNetCore.Dispatcher/EndpointSelector.cs new file mode 100644 index 0000000000..d98bd9bf78 --- /dev/null +++ b/src/Microsoft.AspNetCore.Dispatcher/EndpointSelector.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Dispatcher +{ + public abstract class EndpointSelector + { + public abstract Task SelectAsync(EndpointSelectorContext context); + } +} diff --git a/src/Microsoft.AspNetCore.Dispatcher/EndpointSelectorContext.cs b/src/Microsoft.AspNetCore.Dispatcher/EndpointSelectorContext.cs new file mode 100644 index 0000000000..f6a5ae602c --- /dev/null +++ b/src/Microsoft.AspNetCore.Dispatcher/EndpointSelectorContext.cs @@ -0,0 +1,88 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Dispatcher +{ + public sealed class EndpointSelectorContext + { + private int _index; + + public EndpointSelectorContext(HttpContext httpContext, IList endpoints, IList selectors) + { + if (httpContext == null) + { + throw new ArgumentNullException(nameof(httpContext)); + } + + if (endpoints == null) + { + throw new ArgumentNullException(nameof(endpoints)); + } + + if (selectors == null) + { + throw new ArgumentNullException(nameof(selectors)); + } + + HttpContext = httpContext; + Endpoints = endpoints; + Selectors = selectors; + } + + public IList Endpoints { get; } + + public HttpContext HttpContext { get; } + + public IList Selectors { get; } + + public Task InvokeNextAsync() + { + if (_index >= Selectors.Count) + { + return Task.CompletedTask; + } + + var selector = Selectors[_index++]; + return selector.SelectAsync(this); + } + + public Snapshot CreateSnapshot() + { + return new Snapshot(_index, Endpoints); + } + + public void RestoreSnapshot(Snapshot snapshot) + { + snapshot.Apply(this); + } + + public struct Snapshot + { + private readonly int _index; + private readonly Endpoint[] _endpoints; + + internal Snapshot(int index, IList endpoints) + { + _index = index; + _endpoints = endpoints.ToArray(); + } + + internal void Apply(EndpointSelectorContext context) + { + context._index = _index; + + context.Endpoints.Clear(); + for (var i = 0; i < _endpoints.Length; i++) + { + context.Endpoints.Add(_endpoints[i]); + } + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Dispatcher/IDispatcherValueSelectableEndpoint.cs b/src/Microsoft.AspNetCore.Dispatcher/IDispatcherValueSelectableEndpoint.cs new file mode 100644 index 0000000000..61688390ae --- /dev/null +++ b/src/Microsoft.AspNetCore.Dispatcher/IDispatcherValueSelectableEndpoint.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Dispatcher +{ + public interface IDispatcherValueSelectableEndpoint + { + DispatcherValueCollection Values { get; } + } +} diff --git a/src/Microsoft.AspNetCore.Dispatcher/SimpleEndpoint.cs b/src/Microsoft.AspNetCore.Dispatcher/SimpleEndpoint.cs index 667667456b..5b81776fc0 100644 --- a/src/Microsoft.AspNetCore.Dispatcher/SimpleEndpoint.cs +++ b/src/Microsoft.AspNetCore.Dispatcher/SimpleEndpoint.cs @@ -8,29 +8,34 @@ using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.Dispatcher { - public class SimpleEndpoint : Endpoint + public class SimpleEndpoint : Endpoint, IDispatcherValueSelectableEndpoint { public SimpleEndpoint(RequestDelegate requestDelegate) - : this(requestDelegate, Array.Empty(), null) + : this(requestDelegate, Array.Empty(), null, null) { } public SimpleEndpoint(Func delegateFactory) - : this(delegateFactory, Array.Empty(), null) + : this(delegateFactory, Array.Empty(), null, null) { } public SimpleEndpoint(RequestDelegate requestDelegate, IEnumerable metadata) - : this(requestDelegate, metadata, null) + : this(requestDelegate, metadata, null, null) { } public SimpleEndpoint(Func delegateFactory, IEnumerable metadata) - : this(delegateFactory, metadata, null) + : this(delegateFactory, metadata, null, null) { } - public SimpleEndpoint(RequestDelegate requestDelegate, IEnumerable metadata, string displayName) + public SimpleEndpoint(Func delegateFactory, IEnumerable metadata, object values) + : this(delegateFactory, metadata, null, null) + { + } + + public SimpleEndpoint(RequestDelegate requestDelegate, IEnumerable metadata, object values, string displayName) { if (metadata == null) { @@ -42,12 +47,13 @@ namespace Microsoft.AspNetCore.Dispatcher throw new ArgumentNullException(nameof(requestDelegate)); } - DisplayName = displayName; + HandlerFactory = (next) => requestDelegate; Metadata = metadata.ToArray(); - DelegateFactory = (next) => requestDelegate; + Values = new DispatcherValueCollection(values); + DisplayName = displayName; } - public SimpleEndpoint(Func delegateFactory, IEnumerable metadata, string displayName) + public SimpleEndpoint(Func delegateFactory, IEnumerable metadata, object values, string displayName) { if (metadata == null) { @@ -59,15 +65,18 @@ namespace Microsoft.AspNetCore.Dispatcher throw new ArgumentNullException(nameof(delegateFactory)); } - DisplayName = displayName; + HandlerFactory = delegateFactory; Metadata = metadata.ToArray(); - DelegateFactory = delegateFactory; + Values = new DispatcherValueCollection(values); + DisplayName = displayName; } public override string DisplayName { get; } public override IReadOnlyList Metadata { get; } - public Func DelegateFactory { get; } + public Func HandlerFactory { get; } + + public DispatcherValueCollection Values { get; } } } diff --git a/src/Microsoft.AspNetCore.Routing/Dispatcher/DispatcherValueCollectionExtensions.cs b/src/Microsoft.AspNetCore.Routing/Dispatcher/DispatcherValueCollectionExtensions.cs new file mode 100644 index 0000000000..337b6bcd7c --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/Dispatcher/DispatcherValueCollectionExtensions.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Routing; + +namespace Microsoft.AspNetCore.Dispatcher +{ + public static class DispatcherValueCollectionExtensions + { + public static RouteValueDictionary AsRouteValueDictionary(this DispatcherValueCollection values) + { + if (values == null) + { + throw new ArgumentNullException(nameof(values)); + } + + return values as RouteValueDictionary ?? new RouteValueDictionary(values); + } + } +} diff --git a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouteTemplateDispatcher.cs b/src/Microsoft.AspNetCore.Routing/Dispatcher/RouteTemplateDispatcher.cs new file mode 100644 index 0000000000..551540883e --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/Dispatcher/RouteTemplateDispatcher.cs @@ -0,0 +1,157 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Dispatcher; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.Template; + +namespace Microsoft.AspNetCore.Routing.Dispatcher +{ + public class RouteTemplateDispatcher : DispatcherBase + { + private readonly IDictionary _constraints; + private readonly RouteValueDictionary _defaults; + private readonly TemplateMatcher _matcher; + private readonly RouteTemplate _parsedTemplate; + + public RouteTemplateDispatcher( + string routeTemplate, + IInlineConstraintResolver constraintResolver) + : this(routeTemplate, constraintResolver, null, null) + { + } + + public RouteTemplateDispatcher( + string routeTemplate, + IInlineConstraintResolver constraintResolver, + RouteValueDictionary defaults) + : this(routeTemplate, constraintResolver, defaults, null) + { + } + + public RouteTemplateDispatcher( + string routeTemplate, + IInlineConstraintResolver constraintResolver, + RouteValueDictionary defaults, + IDictionary constraints) + { + if (routeTemplate == null) + { + throw new ArgumentNullException(nameof(routeTemplate)); + } + + if (constraintResolver == null) + { + throw new ArgumentNullException(nameof(constraintResolver)); + } + + RouteTemplate = routeTemplate; + + try + { + // Data we parse from the template will be used to fill in the rest of the constraints or + // defaults. The parser will throw for invalid routes. + _parsedTemplate = TemplateParser.Parse(routeTemplate); + + _constraints = GetConstraints(constraintResolver, _parsedTemplate, constraints); + _defaults = GetDefaults(_parsedTemplate, defaults); + } + catch (Exception exception) + { + throw new RouteCreationException(Resources.FormatTemplateRoute_Exception(string.Empty, routeTemplate), exception); + } + + _matcher = new TemplateMatcher(_parsedTemplate, _defaults); + } + + public string RouteTemplate { get; } + + protected override Task TryMatchAsync(HttpContext httpContext) + { + if (httpContext == null) + { + throw new ArgumentNullException(nameof(httpContext)); + } + + var feature = httpContext.Features.Get(); + feature.Values = feature.Values ?? new RouteValueDictionary(); + + if (!_matcher.TryMatch(httpContext.Request.Path, (RouteValueDictionary)feature.Values)) + { + // If we got back a null value set, that means the URI did not match + return Task.FromResult(false); + } + + foreach (var kvp in _constraints) + { + var constraint = kvp.Value; + if (!constraint.Match(httpContext, null, kvp.Key, (RouteValueDictionary)feature.Values, RouteDirection.IncomingRequest)) + { + return Task.FromResult(false); + } + } + + return Task.FromResult(true); + } + + private static IDictionary GetConstraints( + IInlineConstraintResolver inlineConstraintResolver, + RouteTemplate parsedTemplate, + IDictionary constraints) + { + var constraintBuilder = new RouteConstraintBuilder(inlineConstraintResolver, parsedTemplate.TemplateText); + + if (constraints != null) + { + foreach (var kvp in constraints) + { + constraintBuilder.AddConstraint(kvp.Key, kvp.Value); + } + } + + foreach (var parameter in parsedTemplate.Parameters) + { + if (parameter.IsOptional) + { + constraintBuilder.SetOptional(parameter.Name); + } + + foreach (var inlineConstraint in parameter.InlineConstraints) + { + constraintBuilder.AddResolvedConstraint(parameter.Name, inlineConstraint.Constraint); + } + } + + return constraintBuilder.Build(); + } + + private static RouteValueDictionary GetDefaults( + RouteTemplate parsedTemplate, + RouteValueDictionary defaults) + { + var result = defaults == null ? new RouteValueDictionary() : new RouteValueDictionary(defaults); + + foreach (var parameter in parsedTemplate.Parameters) + { + if (parameter.DefaultValue != null) + { + if (result.ContainsKey(parameter.Name)) + { + throw new InvalidOperationException( + Resources.FormatTemplateRoute_CannotHaveDefaultValueSpecifiedInlineAndExplicitly( + parameter.Name)); + } + else + { + result.Add(parameter.Name, parameter.DefaultValue); + } + } + } + + return result; + } + } +} diff --git a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouteValuesEndpoint.cs b/src/Microsoft.AspNetCore.Routing/Dispatcher/RouteValuesEndpoint.cs deleted file mode 100644 index 71b0c746e4..0000000000 --- a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouteValuesEndpoint.cs +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using Microsoft.AspNetCore.Dispatcher; -using Microsoft.AspNetCore.Http; - -namespace Microsoft.AspNetCore.Routing.Dispatcher -{ - public class RouteValuesEndpoint : Endpoint - { - public RouteValuesEndpoint(RouteValueDictionary requiredValues, RequestDelegate requestDelegate) - : this(requiredValues, requestDelegate, Array.Empty(), null) - { - } - - public RouteValuesEndpoint(RouteValueDictionary requiredValues, RequestDelegate requestDelegate, IEnumerable metadata) - : this(requiredValues, requestDelegate, metadata, null) - { - } - - public RouteValuesEndpoint( - RouteValueDictionary requiredValues, - RequestDelegate requestDelegate, - IEnumerable metadata, - string displayName) - { - if (requiredValues == null) - { - throw new ArgumentNullException(nameof(requiredValues)); - } - - if (requestDelegate == null) - { - throw new ArgumentNullException(nameof(requestDelegate)); - } - - if (metadata == null) - { - throw new ArgumentNullException(nameof(metadata)); - } - - RequiredValues = requiredValues; - HandlerFactory = (next) => requestDelegate; - Metadata = metadata.ToArray(); - DisplayName = displayName; - } - - public override string DisplayName { get; } - - public override IReadOnlyList Metadata { get; } - - public Func HandlerFactory { get; set; } - - public RouteValueDictionary RequiredValues { get; set; } - } -} diff --git a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterDispatcher.cs b/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterDispatcher.cs index 54b3e97a06..a96d9ac9fb 100644 --- a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterDispatcher.cs +++ b/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterDispatcher.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Dispatcher; using Microsoft.AspNetCore.Http; @@ -13,6 +14,7 @@ namespace Microsoft.AspNetCore.Routing.Dispatcher /// public class RouterDispatcher : DispatcherBase { + private readonly Endpoint _fallbackEndpoint; private readonly IRouter _router; public RouterDispatcher(IRouter router) @@ -23,17 +25,47 @@ namespace Microsoft.AspNetCore.Routing.Dispatcher } _router = router; + _fallbackEndpoint = new UnknownEndpoint(_router); } - public async override Task InvokeAsync(HttpContext httpContext) + protected override async Task TryMatchAsync(HttpContext httpContext) { if (httpContext == null) { throw new ArgumentNullException(nameof(httpContext)); } - + var routeContext = new RouteContext(httpContext); await _router.RouteAsync(routeContext); + + var feature = httpContext.Features.Get(); + if (routeContext.Handler == null) + { + // The route did not match, clear everything as it may have been set by the route. + feature.Endpoint = null; + feature.RequestDelegate = null; + feature.Values = null; + return false; + } + else + { + feature.Endpoint = feature.Endpoint ?? _fallbackEndpoint; + feature.RequestDelegate = routeContext.Handler; + feature.Values = routeContext.RouteData.Values; + return true; + } + } + + private class UnknownEndpoint : Endpoint + { + public UnknownEndpoint(IRouter router) + { + DisplayName = $"Endpoint for '{router}"; + } + + public override string DisplayName { get; } + + public override IReadOnlyList Metadata => Array.Empty(); } } } diff --git a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterEndpointSelector.cs b/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterEndpointSelector.cs deleted file mode 100644 index 67c1e5e92e..0000000000 --- a/src/Microsoft.AspNetCore.Routing/Dispatcher/RouterEndpointSelector.cs +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Dispatcher; -using Microsoft.AspNetCore.Http; - -namespace Microsoft.AspNetCore.Routing.Dispatcher -{ - public class RouterEndpointSelector : IRouter, IRouteHandler - { - private readonly RouteValuesEndpoint[] _endpoints; - - public RouterEndpointSelector(IEnumerable endpoints) - { - if (endpoints == null) - { - throw new ArgumentNullException(nameof(endpoints)); - } - - _endpoints = endpoints.ToArray(); - } - - public RequestDelegate GetRequestHandler(HttpContext httpContext, RouteData routeData) - { - if (httpContext == null) - { - throw new ArgumentNullException(nameof(httpContext)); - } - - if (routeData == null) - { - throw new ArgumentNullException(nameof(routeData)); - } - - var dispatcherFeature = httpContext.Features.Get(); - if (dispatcherFeature == null) - { - throw new InvalidOperationException(Resources.FormatDispatcherFeatureIsRequired( - nameof(HttpContext), - nameof(IDispatcherFeature), - nameof(RouterEndpointSelector))); - } - - for (var i = 0; i < _endpoints.Length; i++) - { - var endpoint = _endpoints[i]; - if (CompareRouteValues(routeData.Values, endpoint.RequiredValues)) - { - dispatcherFeature.Endpoint = endpoint; - return null; - } - } - - return null; - } - - public VirtualPathData GetVirtualPath(VirtualPathContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - return null; - } - - public Task RouteAsync(RouteContext context) - { - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - var handler = GetRequestHandler(context.HttpContext, context.RouteData); - if (handler != null) - { - context.Handler = handler; - } - - return Task.CompletedTask; - } - - private bool CompareRouteValues(RouteValueDictionary values, RouteValueDictionary requiredValues) - { - foreach (var kvp in requiredValues) - { - if (string.IsNullOrEmpty(kvp.Value.ToString())) - { - if (values.TryGetValue(kvp.Key, out var routeValue) && !string.IsNullOrEmpty(routeValue.ToString())) - { - return false; - } - } - else - { - if (!values.TryGetValue(kvp.Key, out var routeValue) || !string.Equals(kvp.Value.ToString(), routeValue.ToString(), StringComparison.OrdinalIgnoreCase)) - { - return false; - } - } - } - - return true; - } - } -}