From 0b590ff46ff7d06028e1e746d98c5463d66447ef Mon Sep 17 00:00:00 2001 From: Ryan Nowak Date: Tue, 7 May 2019 15:17:06 -0700 Subject: [PATCH] Fix #6764 EndpointConventionBuilder API review --- ...NetCore.Components.Server.netcoreapp3.0.cs | 2 +- ...nentEndpointConventionBuilderExtensions.cs | 9 ++- ...Microsoft.AspNetCore.Cors.netcoreapp3.0.cs | 4 +- .../CORS/samples/SampleDestination/Startup.cs | 10 ++-- ...CorsEndpointConventionBuilderExtensions.cs | 4 +- ...ndpointConventionBuilderExtensionsTests.cs | 12 ++-- .../test/UnitTests/CorsMiddlewareTests.cs | 2 +- ...tionEndpointConventionBuilderExtensions.cs | 55 ++++++++++++------- ...ndpointConventionBuilderExtensionsTests.cs | 38 +++++++++++++ 9 files changed, 94 insertions(+), 42 deletions(-) diff --git a/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs b/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs index 762408f9b1..7bb0e6fd6f 100644 --- a/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs +++ b/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs @@ -5,7 +5,7 @@ namespace Microsoft.AspNetCore.Builder { public static partial class ComponentEndpointConventionBuilderExtensions { - public static TBuilder AddComponent(this TBuilder builder, System.Type componentType, string selector) where TBuilder : Microsoft.AspNetCore.SignalR.IHubEndpointConventionBuilder { throw null; } + public static Microsoft.AspNetCore.Components.Server.ComponentEndpointConventionBuilder AddComponent(this Microsoft.AspNetCore.Components.Server.ComponentEndpointConventionBuilder builder, System.Type componentType, string selector) { throw null; } } public static partial class ComponentEndpointRouteBuilderExtensions { diff --git a/src/Components/Server/src/Builder/ComponentEndpointConventionBuilderExtensions.cs b/src/Components/Server/src/Builder/ComponentEndpointConventionBuilderExtensions.cs index f44577bb2c..0a29e1790d 100644 --- a/src/Components/Server/src/Builder/ComponentEndpointConventionBuilderExtensions.cs +++ b/src/Components/Server/src/Builder/ComponentEndpointConventionBuilderExtensions.cs @@ -9,19 +9,18 @@ using Microsoft.AspNetCore.SignalR; namespace Microsoft.AspNetCore.Builder { /// - /// Extensions for . + /// Extensions for . /// public static class ComponentEndpointConventionBuilderExtensions { /// - /// Adds to the list of components registered with this instance. - /// The selector will default to the component name in lowercase. + /// Adds to the list of components registered with this hub instance. /// - /// The . + /// The . /// The component type. /// The component selector in the DOM for the . /// The . - public static TBuilder AddComponent(this TBuilder builder, Type componentType, string selector) where TBuilder : IHubEndpointConventionBuilder + public static ComponentEndpointConventionBuilder AddComponent(this ComponentEndpointConventionBuilder builder, Type componentType, string selector) { if (builder == null) { diff --git a/src/Middleware/CORS/ref/Microsoft.AspNetCore.Cors.netcoreapp3.0.cs b/src/Middleware/CORS/ref/Microsoft.AspNetCore.Cors.netcoreapp3.0.cs index f000293212..9baec2b608 100644 --- a/src/Middleware/CORS/ref/Microsoft.AspNetCore.Cors.netcoreapp3.0.cs +++ b/src/Middleware/CORS/ref/Microsoft.AspNetCore.Cors.netcoreapp3.0.cs @@ -5,8 +5,8 @@ namespace Microsoft.AspNetCore.Builder { public static partial class CorsEndpointConventionBuilderExtensions { - public static TBuilder WithCorsPolicy(this TBuilder builder, System.Action configurePolicy) where TBuilder : Microsoft.AspNetCore.Builder.IEndpointConventionBuilder { throw null; } - public static TBuilder WithCorsPolicy(this TBuilder builder, string policyName) where TBuilder : Microsoft.AspNetCore.Builder.IEndpointConventionBuilder { throw null; } + public static TBuilder RequireCors(this TBuilder builder, System.Action configurePolicy) where TBuilder : Microsoft.AspNetCore.Builder.IEndpointConventionBuilder { throw null; } + public static TBuilder RequireCors(this TBuilder builder, string policyName) where TBuilder : Microsoft.AspNetCore.Builder.IEndpointConventionBuilder { throw null; } } public static partial class CorsMiddlewareExtensions { diff --git a/src/Middleware/CORS/samples/SampleDestination/Startup.cs b/src/Middleware/CORS/samples/SampleDestination/Startup.cs index d2daaaa13c..9f3614c295 100644 --- a/src/Middleware/CORS/samples/SampleDestination/Startup.cs +++ b/src/Middleware/CORS/samples/SampleDestination/Startup.cs @@ -63,11 +63,11 @@ namespace SampleDestination app.UseEndpoints(endpoints => { - endpoints.Map("/allow-origin", HandleRequest).WithCorsPolicy("AllowOrigin"); - endpoints.Map("/allow-header-method", HandleRequest).WithCorsPolicy("AllowHeaderMethod"); - endpoints.Map("/allow-credentials", HandleRequest).WithCorsPolicy("AllowCredentials"); - endpoints.Map("/exposed-header", HandleRequest).WithCorsPolicy("ExposedHeader"); - endpoints.Map("/allow-all", HandleRequest).WithCorsPolicy("AllowAll"); + endpoints.Map("/allow-origin", HandleRequest).RequireCors("AllowOrigin"); + endpoints.Map("/allow-header-method", HandleRequest).RequireCors("AllowHeaderMethod"); + endpoints.Map("/allow-credentials", HandleRequest).RequireCors("AllowCredentials"); + endpoints.Map("/exposed-header", HandleRequest).RequireCors("ExposedHeader"); + endpoints.Map("/allow-all", HandleRequest).RequireCors("AllowAll"); }); app.Run(async (context) => diff --git a/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs b/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs index 1b82f90dfa..b22a1e5758 100644 --- a/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs +++ b/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs @@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Builder /// The endpoint convention builder. /// The CORS policy name. /// The original convention builder parameter. - public static TBuilder WithCorsPolicy(this TBuilder builder, string policyName) where TBuilder : IEndpointConventionBuilder + public static TBuilder RequireCors(this TBuilder builder, string policyName) where TBuilder : IEndpointConventionBuilder { if (builder == null) { @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.Builder /// The endpoint convention builder. /// A delegate which can use a policy builder to build a policy. /// The original convention builder parameter. - public static TBuilder WithCorsPolicy(this TBuilder builder, Action configurePolicy) where TBuilder : IEndpointConventionBuilder + public static TBuilder RequireCors(this TBuilder builder, Action configurePolicy) where TBuilder : IEndpointConventionBuilder { if (builder == null) { diff --git a/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs index 6b1ce451d0..e800e45813 100644 --- a/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs +++ b/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs @@ -13,13 +13,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure public class CorsEndpointConventionBuilderExtensionsTests { [Fact] - public void WithCorsPolicy_Name_MetadataAdded() + public void RequireCors_Name_MetadataAdded() { // Arrange var testConventionBuilder = new TestEndpointConventionBuilder(); // Act - testConventionBuilder.WithCorsPolicy("TestPolicyName"); + testConventionBuilder.RequireCors("TestPolicyName"); // Assert var addCorsPolicy = Assert.Single(testConventionBuilder.Conventions); @@ -34,13 +34,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure } [Fact] - public void WithCorsPolicy_Policy_MetadataAdded() + public void RequireCors_Policy_MetadataAdded() { // Arrange var testConventionBuilder = new TestEndpointConventionBuilder(); // Act - testConventionBuilder.WithCorsPolicy(builder => builder.AllowAnyOrigin()); + testConventionBuilder.RequireCors(builder => builder.AllowAnyOrigin()); // Assert var addCorsPolicy = Assert.Single(testConventionBuilder.Conventions); @@ -56,13 +56,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure } [Fact] - public void WithCorsPolicy_ChainedCall_ReturnedBuilderIsDerivedType() + public void RequireCors_ChainedCall_ReturnedBuilderIsDerivedType() { // Arrange var testConventionBuilder = new TestEndpointConventionBuilder(); // Act - var builder = testConventionBuilder.WithCorsPolicy("TestPolicyName"); + var builder = testConventionBuilder.RequireCors("TestPolicyName"); // Assert Assert.True(builder.TestProperty); diff --git a/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs b/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs index 759d1a4c8f..1a470da10d 100644 --- a/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs +++ b/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs @@ -717,7 +717,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure } [Fact] - public async Task Invoke_HasEndpointWithCorsPolicyMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName() + public async Task Invoke_HasEndpointRequireCorsMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName() { // Arrange var defaultPolicy = new CorsPolicyBuilder().Build(); diff --git a/src/Security/Authorization/Core/src/Policy/AuthorizationEndpointConventionBuilderExtensions.cs b/src/Security/Authorization/Core/src/Policy/AuthorizationEndpointConventionBuilderExtensions.cs index cefb534c55..6b7a03b1a6 100644 --- a/src/Security/Authorization/Core/src/Policy/AuthorizationEndpointConventionBuilderExtensions.cs +++ b/src/Security/Authorization/Core/src/Policy/AuthorizationEndpointConventionBuilderExtensions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Linq; using Microsoft.AspNetCore.Authorization; @@ -13,38 +14,25 @@ namespace Microsoft.AspNetCore.Builder public static class AuthorizationEndpointConventionBuilderExtensions { /// - /// Adds authorization policies with the specified to the endpoint(s). + /// Adds the default authorization policy to the endpoint(s). /// /// The endpoint convention builder. - /// A collection of . /// The original convention builder parameter. - public static TBuilder RequireAuthorization(this TBuilder builder, params IAuthorizeData[] authorizeData) where TBuilder : IEndpointConventionBuilder + public static TBuilder RequireAuthorization(this TBuilder builder) where TBuilder : IEndpointConventionBuilder { if (builder == null) { throw new ArgumentNullException(nameof(builder)); } - if (authorizeData == null) - { - throw new ArgumentNullException(nameof(authorizeData)); - } - - builder.Add(endpointBuilder => - { - foreach (var data in authorizeData) - { - endpointBuilder.Metadata.Add(data); - } - }); - return builder; + return builder.RequireAuthorization(new AuthorizeAttribute()); } /// /// Adds authorization policies with the specified names to the endpoint(s). /// /// The endpoint convention builder. - /// A collection of policy names. + /// A collection of policy names. If empty, the default authorization policy will be used. /// The original convention builder parameter. public static TBuilder RequireAuthorization(this TBuilder builder, params string[] policyNames) where TBuilder : IEndpointConventionBuilder { @@ -62,18 +50,45 @@ namespace Microsoft.AspNetCore.Builder } /// - /// Adds the default authorization policy to the endpoint(s). + /// Adds authorization policies with the specified to the endpoint(s). /// /// The endpoint convention builder. + /// + /// A collection of . If empty, the default authorization policy will be used. + /// /// The original convention builder parameter. - public static TBuilder RequireAuthorization(this TBuilder builder) where TBuilder : IEndpointConventionBuilder + public static TBuilder RequireAuthorization(this TBuilder builder, params IAuthorizeData[] authorizeData) + where TBuilder : IEndpointConventionBuilder { if (builder == null) { throw new ArgumentNullException(nameof(builder)); } - return builder.RequireAuthorization(new AuthorizeAttribute()); + if (authorizeData == null) + { + throw new ArgumentNullException(nameof(authorizeData)); + } + + if (authorizeData.Length == 0) + { + authorizeData = new IAuthorizeData[] { new AuthorizeAttribute(), }; + } + + RequireAuthorizationCore(builder, authorizeData); + return builder; + } + + private static void RequireAuthorizationCore(TBuilder builder, IEnumerable authorizeData) + where TBuilder : IEndpointConventionBuilder + { + builder.Add(endpointBuilder => + { + foreach (var data in authorizeData) + { + endpointBuilder.Metadata.Add(data); + } + }); } } } diff --git a/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs b/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs index f535bf7e69..c31041f075 100644 --- a/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs +++ b/src/Security/Authorization/test/AuthorizationEndpointConventionBuilderExtensionsTests.cs @@ -32,6 +32,25 @@ namespace Microsoft.AspNetCore.Authorization.Test Assert.Equal(metadata, Assert.Single(endpointModel.Metadata)); } + [Fact] + public void RequireAuthorization_IAuthorizeData_Empty() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + + // Act + builder.RequireAuthorization(Array.Empty()); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + convention(endpointModel); + + var authMetadata = Assert.IsAssignableFrom(Assert.Single(endpointModel.Metadata)); + Assert.Null(authMetadata.Policy); + } + [Fact] public void RequireAuthorization_PolicyName() { @@ -51,6 +70,25 @@ namespace Microsoft.AspNetCore.Authorization.Test Assert.Equal("policy", authMetadata.Policy); } + [Fact] + public void RequireAuthorization_PolicyName_Empty() + { + // Arrange + var builder = new TestEndpointConventionBuilder(); + + // Act + builder.RequireAuthorization(Array.Empty()); + + // Assert + var convention = Assert.Single(builder.Conventions); + + var endpointModel = new RouteEndpointBuilder((context) => Task.CompletedTask, RoutePatternFactory.Parse("/"), 0); + convention(endpointModel); + + var authMetadata = Assert.IsAssignableFrom(Assert.Single(endpointModel.Metadata)); + Assert.Null(authMetadata.Policy); + } + [Fact] public void RequireAuthorization_Default() {