diff --git a/src/Microsoft.AspNetCore.Authorization/AuthorizationPolicyBuilder.cs b/src/Microsoft.AspNetCore.Authorization/AuthorizationPolicyBuilder.cs index 0cc0195e60..965ffe02ef 100644 --- a/src/Microsoft.AspNetCore.Authorization/AuthorizationPolicyBuilder.cs +++ b/src/Microsoft.AspNetCore.Authorization/AuthorizationPolicyBuilder.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization.Infrastructure; namespace Microsoft.AspNetCore.Authorization @@ -139,6 +140,22 @@ namespace Microsoft.AspNetCore.Authorization return this; } + /// + /// Requires that this Function returns true + /// + /// Function that must return true + /// + public AuthorizationPolicyBuilder RequireAssertion(Func> assert) + { + if (assert == null) + { + throw new ArgumentNullException(nameof(assert)); + } + + Requirements.Add(new AssertionRequirement(assert)); + return this; + } + public AuthorizationPolicy Build() { return new AuthorizationPolicy(Requirements, AuthenticationSchemes.Distinct()); diff --git a/src/Microsoft.AspNetCore.Authorization/Infrastructure/AssertionRequirement.cs b/src/Microsoft.AspNetCore.Authorization/Infrastructure/AssertionRequirement.cs index 0d3ab2bf28..0cc1751a49 100644 --- a/src/Microsoft.AspNetCore.Authorization/Infrastructure/AssertionRequirement.cs +++ b/src/Microsoft.AspNetCore.Authorization/Infrastructure/AssertionRequirement.cs @@ -2,15 +2,16 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Authorization.Infrastructure { - public class AssertionRequirement : AuthorizationHandler, IAuthorizationRequirement + public class AssertionRequirement : IAuthorizationHandler, IAuthorizationRequirement { /// /// Function that is called to handle this requirement /// - public Func Handler { get; } + public Func> Handler { get; } public AssertionRequirement(Func assert) { @@ -19,14 +20,24 @@ namespace Microsoft.AspNetCore.Authorization.Infrastructure throw new ArgumentNullException(nameof(assert)); } + Handler = context => Task.FromResult(assert(context)); + } + + public AssertionRequirement(Func> assert) + { + if (assert == null) + { + throw new ArgumentNullException(nameof(assert)); + } + Handler = assert; } - protected override void Handle(AuthorizationContext context, AssertionRequirement requirement) + public async Task HandleAsync(AuthorizationContext context) { - if (Handler(context)) + if (await Handler(context)) { - context.Succeed(requirement); + context.Succeed(this); } } } diff --git a/test/Microsoft.AspNetCore.Authorization.Test/DefaultAuthorizationServiceTests.cs b/test/Microsoft.AspNetCore.Authorization.Test/DefaultAuthorizationServiceTests.cs index 60988448ba..7f11916940 100644 --- a/test/Microsoft.AspNetCore.Authorization.Test/DefaultAuthorizationServiceTests.cs +++ b/test/Microsoft.AspNetCore.Authorization.Test/DefaultAuthorizationServiceTests.cs @@ -923,6 +923,25 @@ namespace Microsoft.AspNetCore.Authorization.Test Assert.True(allowed); } + [Fact] + public async Task CanAuthorizeWithAsyncAssertionRequirement() + { + var authorizationService = BuildAuthorizationService(services => + { + services.AddAuthorization(options => + { + options.AddPolicy("Basic", policy => policy.RequireAssertion(context => Task.FromResult(true))); + }); + }); + var user = new ClaimsPrincipal(); + + // Act + var allowed = await authorizationService.AuthorizeAsync(user, "Basic"); + + // Assert + Assert.True(allowed); + } + public class StaticPolicyProvider : IAuthorizationPolicyProvider { public Task GetPolicyAsync(string policyName)