Only run claims transformation once per ClaimsPrincipal instance by default (#12028)

This commit is contained in:
Hao Kung 2019-10-29 15:48:06 -07:00 committed by GitHub
parent b0d6b0edf9
commit 814a37548b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 100 additions and 8 deletions

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Collections.Generic;
using System.Linq; using System.Linq;
using System.Security.Claims; using System.Security.Claims;
using System.Threading.Tasks; using System.Threading.Tasks;
@ -15,6 +16,8 @@ namespace Microsoft.AspNetCore.Authentication
/// </summary> /// </summary>
public class AuthenticationService : IAuthenticationService public class AuthenticationService : IAuthenticationService
{ {
private HashSet<ClaimsPrincipal> _transformCache;
/// <summary> /// <summary>
/// Constructor. /// Constructor.
/// </summary> /// </summary>
@ -77,8 +80,20 @@ namespace Microsoft.AspNetCore.Authentication
var result = await handler.AuthenticateAsync(); var result = await handler.AuthenticateAsync();
if (result != null && result.Succeeded) if (result != null && result.Succeeded)
{ {
var transformed = await Transform.TransformAsync(result.Principal); var principal = result.Principal;
return AuthenticateResult.Success(new AuthenticationTicket(transformed, result.Properties, result.Ticket.AuthenticationScheme)); var doTransform = true;
_transformCache ??= new HashSet<ClaimsPrincipal>();
if (_transformCache.Contains(principal))
{
doTransform = false;
}
if (doTransform)
{
principal = await Transform.TransformAsync(principal);
_transformCache.Add(principal);
}
return AuthenticateResult.Success(new AuthenticationTicket(principal, result.Properties, result.Ticket.AuthenticationScheme));
} }
return result; return result;
} }

View File

@ -1,4 +1,4 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Globalization; using System.Globalization;
using System.Linq; using System.Linq;
@ -298,4 +298,4 @@ namespace Microsoft.AspNetCore.Authentication.Core.Test
} }
} }
} }
} }

View File

@ -11,7 +11,7 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.Authentication namespace Microsoft.AspNetCore.Authentication.Core.Test
{ {
public class AuthenticationSchemeProviderTests public class AuthenticationSchemeProviderTests
{ {

View File

@ -8,7 +8,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.Authentication namespace Microsoft.AspNetCore.Authentication.Core.Test
{ {
public class AuthenticationServiceTests public class AuthenticationServiceTests
{ {
@ -27,6 +27,30 @@ namespace Microsoft.AspNetCore.Authentication
Assert.Contains("base", ex.Message); Assert.Contains("base", ex.Message);
} }
[Fact]
public async Task CustomHandlersAuthenticateRunsClaimsTransformationEveryTime()
{
var transform = new RunOnce();
var services = new ServiceCollection().AddOptions().AddAuthenticationCore(o =>
{
o.AddScheme<BaseHandler>("base", "whatever");
})
.AddSingleton<IClaimsTransformation>(transform)
.BuildServiceProvider();
var context = new DefaultHttpContext();
context.RequestServices = services;
// Because base handler returns a different principal per call, its run multiple times
await context.AuthenticateAsync("base");
Assert.Equal(1, transform.Ran);
await context.AuthenticateAsync("base");
Assert.Equal(2, transform.Ran);
await context.AuthenticateAsync("base");
Assert.Equal(3, transform.Ran);
}
[Fact] [Fact]
public async Task ChallengeThrowsForSchemeMismatch() public async Task ChallengeThrowsForSchemeMismatch()
{ {
@ -219,12 +243,25 @@ namespace Microsoft.AspNetCore.Authentication
await context.ForbidAsync(); await context.ForbidAsync();
} }
private class RunOnce : IClaimsTransformation
{
public int Ran = 0;
public Task<ClaimsPrincipal> TransformAsync(ClaimsPrincipal principal)
{
Ran++;
return Task.FromResult(new ClaimsPrincipal());
}
}
private class BaseHandler : IAuthenticationHandler private class BaseHandler : IAuthenticationHandler
{ {
public Task<AuthenticateResult> AuthenticateAsync() public Task<AuthenticateResult> AuthenticateAsync()
{ {
return Task.FromResult(AuthenticateResult.NoResult()); return Task.FromResult(AuthenticateResult.Success(
new AuthenticationTicket(
new ClaimsPrincipal(new ClaimsIdentity("whatever")),
new AuthenticationProperties(),
"whatever")));
} }
public Task ChallengeAsync(AuthenticationProperties properties) public Task ChallengeAsync(AuthenticationProperties properties)

View File

@ -10,7 +10,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.Authentication namespace Microsoft.AspNetCore.Authentication.Core.Test
{ {
public class TokenExtensionTests public class TokenExtensionTests
{ {

View File

@ -203,6 +203,46 @@ namespace Microsoft.AspNetCore.Authentication
Assert.Equal(0, forwardDefault.SignOutCount); Assert.Equal(0, forwardDefault.SignOutCount);
} }
private class RunOnce : IClaimsTransformation
{
public int Ran = 0;
public Task<ClaimsPrincipal> TransformAsync(ClaimsPrincipal principal)
{
Ran++;
return Task.FromResult(new ClaimsPrincipal());
}
}
[Fact]
public async Task ForwardAuthenticateOnlyRunsTransformOnceByDefault()
{
var services = new ServiceCollection().AddLogging();
var transform = new RunOnce();
var builder = services.AddSingleton<IClaimsTransformation>(transform).AddAuthentication(o =>
{
o.DefaultScheme = DefaultScheme;
o.AddScheme<TestHandler2>("auth1", "auth1");
o.AddScheme<TestHandler>("specific", "specific");
});
RegisterAuth(builder, o =>
{
o.ForwardDefault = "auth1";
o.ForwardAuthenticate = "specific";
});
var specific = new TestHandler();
services.AddSingleton(specific);
var forwardDefault = new TestHandler2();
services.AddSingleton(forwardDefault);
var sp = services.BuildServiceProvider();
var context = new DefaultHttpContext();
context.RequestServices = sp;
await context.AuthenticateAsync();
Assert.Equal(1, transform.Ran);
}
[Fact] [Fact]
public async Task ForwardAuthenticateWinsOverDefault() public async Task ForwardAuthenticateWinsOverDefault()
{ {