diff --git a/src/Microsoft.AspNetCore.Authentication/AuthenticationBuilder.cs b/src/Microsoft.AspNetCore.Authentication/AuthenticationBuilder.cs index 3bce55ea10..7bf8fe96ee 100644 --- a/src/Microsoft.AspNetCore.Authentication/AuthenticationBuilder.cs +++ b/src/Microsoft.AspNetCore.Authentication/AuthenticationBuilder.cs @@ -25,18 +25,10 @@ namespace Microsoft.AspNetCore.Authentication /// public virtual IServiceCollection Services { get; } - /// - /// Adds a which can be used by . - /// - /// The type to configure the handler."/>. - /// The used to handle this scheme. - /// The name of this scheme. - /// The display name of this scheme. - /// Used to configure the scheme options. - /// The builder. - public virtual AuthenticationBuilder AddScheme(string authenticationScheme, string displayName, Action configureOptions) - where TOptions : AuthenticationSchemeOptions, new() - where THandler : AuthenticationHandler + + private AuthenticationBuilder AddSchemeHelper(string authenticationScheme, string displayName, Action configureOptions) + where TOptions : class, new() + where THandler : class, IAuthenticationHandler { Services.Configure(o => { @@ -53,6 +45,20 @@ namespace Microsoft.AspNetCore.Authentication return this; } + /// + /// Adds a which can be used by . + /// + /// The type to configure the handler."/>. + /// The used to handle this scheme. + /// The name of this scheme. + /// The display name of this scheme. + /// Used to configure the scheme options. + /// The builder. + public virtual AuthenticationBuilder AddScheme(string authenticationScheme, string displayName, Action configureOptions) + where TOptions : AuthenticationSchemeOptions, new() + where THandler : AuthenticationHandler + => AddSchemeHelper(authenticationScheme, displayName, configureOptions); + /// /// Adds a which can be used by . /// @@ -84,6 +90,17 @@ namespace Microsoft.AspNetCore.Authentication return AddScheme(authenticationScheme, displayName, configureOptions: configureOptions); } + /// + /// Adds a based authentication handler which can be used to + /// redirect to other authentication schemes. + /// + /// The name of this scheme. + /// The display name of this scheme. + /// Used to configure the scheme options. + /// The builder. + public virtual AuthenticationBuilder AddVirtualScheme(string authenticationScheme, string displayName, Action configureOptions) + => AddSchemeHelper(authenticationScheme, displayName, configureOptions); + // Used to ensure that there's always a default sign in scheme that's not itself private class EnsureSignInScheme : IPostConfigureOptions where TOptions : RemoteAuthenticationOptions { diff --git a/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs index 9728e5ff05..ef4292100a 100644 --- a/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs @@ -22,12 +22,12 @@ namespace Microsoft.AspNetCore.Authentication protected HttpRequest Request { - get { return Context.Request; } + get => Context.Request; } protected HttpResponse Response { - get { return Context.Response; } + get => Context.Response; } protected PathString OriginalPath => Context.Features.Get()?.OriginalPath ?? Request.Path; @@ -52,10 +52,7 @@ namespace Microsoft.AspNetCore.Authentication protected string CurrentUri { - get - { - return Request.Scheme + "://" + Request.Host + Request.PathBase + Request.Path + Request.QueryString; - } + get => Request.Scheme + "://" + Request.Host + Request.PathBase + Request.Path + Request.QueryString; } protected AuthenticationHandler(IOptionsMonitor options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock) @@ -116,15 +113,10 @@ namespace Microsoft.AspNetCore.Authentication /// Called after options/events have been initialized for the handler to finish initializing itself. /// /// A task - protected virtual Task InitializeHandlerAsync() - { - return Task.CompletedTask; - } + protected virtual Task InitializeHandlerAsync() => Task.CompletedTask; protected string BuildRedirectUri(string targetPath) - { - return Request.Scheme + "://" + Request.Host + OriginalPathBase + targetPath; - } + => Request.Scheme + "://" + Request.Host + OriginalPathBase + targetPath; public async Task AuthenticateAsync() { diff --git a/src/Microsoft.AspNetCore.Authentication/VirtualAuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication/VirtualAuthenticationHandler.cs new file mode 100644 index 0000000000..4a023bec2c --- /dev/null +++ b/src/Microsoft.AspNetCore.Authentication/VirtualAuthenticationHandler.cs @@ -0,0 +1,71 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Authentication +{ + /// + /// Forwards calls to another authentication scheme. + /// + public class VirtualAuthenticationHandler : IAuthenticationHandler, IAuthenticationSignInHandler + { + protected IOptionsMonitor OptionsMonitor { get; } + public AuthenticationScheme Scheme { get; private set; } + public VirtualSchemeOptions Options { get; private set; } + protected HttpContext Context { get; private set; } + + public VirtualAuthenticationHandler(IOptionsMonitor options) + { + OptionsMonitor = options; + } + + /// + /// Initialize the handler, resolve the options and validate them. + /// + /// + /// + /// A Task. + public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) + { + if (scheme == null) + { + throw new ArgumentNullException(nameof(scheme)); + } + if (context == null) + { + throw new ArgumentNullException(nameof(context)); + } + + Scheme = scheme; + Context = context; + + Options = OptionsMonitor.Get(Scheme.Name) ?? new VirtualSchemeOptions(); + Options.Validate(); + + return Task.CompletedTask; + } + + protected virtual string ResolveTarget(string scheme) + => scheme ?? Options.DefaultSelector?.Invoke(Context) ?? Options.Default; + + public virtual Task AuthenticateAsync() + => Context.AuthenticateAsync(ResolveTarget(Options.Authenticate)); + + public virtual Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties) + => Context.SignInAsync(ResolveTarget(Options.SignIn), user, properties); + + public virtual Task SignOutAsync(AuthenticationProperties properties) + => Context.SignOutAsync(ResolveTarget(Options.SignOut), properties); + + public virtual Task ChallengeAsync(AuthenticationProperties properties) + => Context.ChallengeAsync(ResolveTarget(Options.Challenge), properties); + + public virtual Task ForbidAsync(AuthenticationProperties properties) + => Context.ForbidAsync(ResolveTarget(Options.Forbid), properties); + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Authentication/VirtualSchemeOptions.cs b/src/Microsoft.AspNetCore.Authentication/VirtualSchemeOptions.cs new file mode 100644 index 0000000000..38d819bf59 --- /dev/null +++ b/src/Microsoft.AspNetCore.Authentication/VirtualSchemeOptions.cs @@ -0,0 +1,33 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Authentication +{ + /// + /// Used to redirect authentication methods to another scheme + /// + public class VirtualSchemeOptions + { + public string Default { get; set; } + + public string Authenticate { get; set; } + public string Challenge { get; set; } + public string Forbid { get; set; } + public string SignIn { get; set; } + public string SignOut { get; set; } + + /// + /// Used to select a default scheme to target based on the request. + /// + public Func DefaultSelector { get; set; } + + + /// + /// Check that the options are valid. Should throw an exception if things are not ok. + /// + public virtual void Validate() { } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs index 51bc67cc38..944a4827c3 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs @@ -990,7 +990,7 @@ namespace Microsoft.AspNetCore.Authentication.Google var res = context.Response; if (req.Path == new PathString("/challenge")) { - await context.ChallengeAsync("Google"); + await context.ChallengeAsync(); } else if (req.Path == new PathString("/challengeFacebook")) { @@ -1061,19 +1061,19 @@ namespace Microsoft.AspNetCore.Authentication.Google .ConfigureServices(services => { services.AddTransient(); - services.AddAuthentication(o => - { - o.DefaultScheme = TestExtensions.CookieAuthenticationScheme; - o.DefaultChallengeScheme = GoogleDefaults.AuthenticationScheme; - }); - services.AddAuthentication() + services.AddAuthentication("Auth") + .AddVirtualScheme("Auth", "Auth", o => + { + o.Default = TestExtensions.CookieAuthenticationScheme; + o.Challenge = GoogleDefaults.AuthenticationScheme; + }) .AddCookie(TestExtensions.CookieAuthenticationScheme) .AddGoogle(configureOptions) .AddFacebook(o => - { - o.AppId = "Test AppId"; - o.AppSecret = "Test AppSecrent"; - }); + { + o.AppId = "Test AppId"; + o.AppSecret = "Test AppSecrent"; + }); }); return new TestServer(builder); } diff --git a/test/Microsoft.AspNetCore.Authentication.Test/VirtualHandlerTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/VirtualHandlerTests.cs new file mode 100644 index 0000000000..a43478c949 --- /dev/null +++ b/test/Microsoft.AspNetCore.Authentication.Test/VirtualHandlerTests.cs @@ -0,0 +1,525 @@ +// Copyright (c) .NET Foundation. All rights reserved. See License.txt in the project root for license information. + +using System; +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Authentication +{ + public class VirtualHandlerTests + { + [Fact] + public async Task CanDispatch() + { + var server = CreateServer(services => + { + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + o.AddScheme("auth3", "auth3"); + }) + .AddVirtualScheme("policy1", "policy1", p => + { + p.Default = "auth1"; + }) + .AddVirtualScheme("policy2", "policy2", p => + { + p.Authenticate = "auth2"; + }); + }); + + var transaction = await server.SendAsync("http://example.com/auth/policy1"); + Assert.Equal("auth1", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth1")); + + transaction = await server.SendAsync("http://example.com/auth/auth1"); + Assert.Equal("auth1", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth1")); + + transaction = await server.SendAsync("http://example.com/auth/auth2"); + Assert.Equal("auth2", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth2")); + + transaction = await server.SendAsync("http://example.com/auth/auth3"); + Assert.Equal("auth3", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth3")); + + transaction = await server.SendAsync("http://example.com/auth/policy2"); + Assert.Equal("auth2", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth2")); + } + + [Fact] + public async Task DefaultTargetSelectorWinsOverDefaultTarget() + { + var services = new ServiceCollection().AddOptions(); + + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + }) + .AddVirtualScheme("forward", "forward", p => { + p.Default = "auth2"; + p.DefaultSelector = ctx => "auth1"; + }); + + var handler1 = new TestHandler(); + services.AddSingleton(handler1); + var handler2 = new TestHandler2(); + services.AddSingleton(handler2); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, handler1.AuthenticateCount); + Assert.Equal(0, handler1.ForbidCount); + Assert.Equal(0, handler1.ChallengeCount); + Assert.Equal(0, handler1.SignInCount); + Assert.Equal(0, handler1.SignOutCount); + Assert.Equal(0, handler2.AuthenticateCount); + Assert.Equal(0, handler2.ForbidCount); + Assert.Equal(0, handler2.ChallengeCount); + Assert.Equal(0, handler2.SignInCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.AuthenticateAsync("forward"); + Assert.Equal(1, handler1.AuthenticateCount); + Assert.Equal(0, handler2.AuthenticateCount); + + await context.ForbidAsync("forward"); + Assert.Equal(1, handler1.ForbidCount); + Assert.Equal(0, handler2.ForbidCount); + + await context.ChallengeAsync("forward"); + Assert.Equal(1, handler1.ChallengeCount); + Assert.Equal(0, handler2.ChallengeCount); + + await context.SignOutAsync("forward"); + Assert.Equal(1, handler1.SignOutCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.SignInAsync("forward", new ClaimsPrincipal()); + Assert.Equal(1, handler1.SignInCount); + Assert.Equal(0, handler2.SignInCount); + } + + [Fact] + public async Task NullDefaultTargetSelectorFallsBacktoDefaultTarget() + { + var services = new ServiceCollection().AddOptions(); + + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + }) + .AddVirtualScheme("forward", "forward", p => { + p.Default = "auth1"; + p.DefaultSelector = ctx => null; + }); + + var handler1 = new TestHandler(); + services.AddSingleton(handler1); + var handler2 = new TestHandler2(); + services.AddSingleton(handler2); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, handler1.AuthenticateCount); + Assert.Equal(0, handler1.ForbidCount); + Assert.Equal(0, handler1.ChallengeCount); + Assert.Equal(0, handler1.SignInCount); + Assert.Equal(0, handler1.SignOutCount); + Assert.Equal(0, handler2.AuthenticateCount); + Assert.Equal(0, handler2.ForbidCount); + Assert.Equal(0, handler2.ChallengeCount); + Assert.Equal(0, handler2.SignInCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.AuthenticateAsync("forward"); + Assert.Equal(1, handler1.AuthenticateCount); + Assert.Equal(0, handler2.AuthenticateCount); + + await context.ForbidAsync("forward"); + Assert.Equal(1, handler1.ForbidCount); + Assert.Equal(0, handler2.ForbidCount); + + await context.ChallengeAsync("forward"); + Assert.Equal(1, handler1.ChallengeCount); + Assert.Equal(0, handler2.ChallengeCount); + + await context.SignOutAsync("forward"); + Assert.Equal(1, handler1.SignOutCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.SignInAsync("forward", new ClaimsPrincipal()); + Assert.Equal(1, handler1.SignInCount); + Assert.Equal(0, handler2.SignInCount); + } + + [Fact] + public async Task SpecificTargetAlwaysWinsOverDefaultTarget() + { + var services = new ServiceCollection().AddOptions(); + + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + }) + .AddVirtualScheme("forward", "forward", p => { + p.Default = "auth2"; + p.DefaultSelector = ctx => "auth2"; + p.Authenticate = "auth1"; + p.SignIn = "auth1"; + p.SignOut = "auth1"; + p.Forbid = "auth1"; + p.Challenge = "auth1"; + }); + + var handler1 = new TestHandler(); + services.AddSingleton(handler1); + var handler2 = new TestHandler2(); + services.AddSingleton(handler2); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, handler1.AuthenticateCount); + Assert.Equal(0, handler1.ForbidCount); + Assert.Equal(0, handler1.ChallengeCount); + Assert.Equal(0, handler1.SignInCount); + Assert.Equal(0, handler1.SignOutCount); + Assert.Equal(0, handler2.AuthenticateCount); + Assert.Equal(0, handler2.ForbidCount); + Assert.Equal(0, handler2.ChallengeCount); + Assert.Equal(0, handler2.SignInCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.AuthenticateAsync("forward"); + Assert.Equal(1, handler1.AuthenticateCount); + Assert.Equal(0, handler2.AuthenticateCount); + + await context.ForbidAsync("forward"); + Assert.Equal(1, handler1.ForbidCount); + Assert.Equal(0, handler2.ForbidCount); + + await context.ChallengeAsync("forward"); + Assert.Equal(1, handler1.ChallengeCount); + Assert.Equal(0, handler2.ChallengeCount); + + await context.SignOutAsync("forward"); + Assert.Equal(1, handler1.SignOutCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.SignInAsync("forward", new ClaimsPrincipal()); + Assert.Equal(1, handler1.SignInCount); + Assert.Equal(0, handler2.SignInCount); + } + + [Fact] + public async Task VirtualSchemeTargetsForwardWithDefaultTarget() + { + var services = new ServiceCollection().AddOptions(); + + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + }) + .AddVirtualScheme("forward", "forward", p => p.Default = "auth1"); + + var handler1 = new TestHandler(); + services.AddSingleton(handler1); + var handler2 = new TestHandler2(); + services.AddSingleton(handler2); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, handler1.AuthenticateCount); + Assert.Equal(0, handler1.ForbidCount); + Assert.Equal(0, handler1.ChallengeCount); + Assert.Equal(0, handler1.SignInCount); + Assert.Equal(0, handler1.SignOutCount); + Assert.Equal(0, handler2.AuthenticateCount); + Assert.Equal(0, handler2.ForbidCount); + Assert.Equal(0, handler2.ChallengeCount); + Assert.Equal(0, handler2.SignInCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.AuthenticateAsync("forward"); + Assert.Equal(1, handler1.AuthenticateCount); + Assert.Equal(0, handler2.AuthenticateCount); + + await context.ForbidAsync("forward"); + Assert.Equal(1, handler1.ForbidCount); + Assert.Equal(0, handler2.ForbidCount); + + await context.ChallengeAsync("forward"); + Assert.Equal(1, handler1.ChallengeCount); + Assert.Equal(0, handler2.ChallengeCount); + + await context.SignOutAsync("forward"); + Assert.Equal(1, handler1.SignOutCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.SignInAsync("forward", new ClaimsPrincipal()); + Assert.Equal(1, handler1.SignInCount); + Assert.Equal(0, handler2.SignInCount); + } + + [Fact] + public async Task VirtualSchemeTargetsOverrideDefaultTarget() + { + var services = new ServiceCollection().AddOptions(); + + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + }) + .AddVirtualScheme("forward", "forward", p => + { + p.Default = "auth1"; + p.Challenge = "auth2"; + p.SignIn = "auth2"; + }); + + var handler1 = new TestHandler(); + services.AddSingleton(handler1); + var handler2 = new TestHandler2(); + services.AddSingleton(handler2); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, handler1.AuthenticateCount); + Assert.Equal(0, handler1.ForbidCount); + Assert.Equal(0, handler1.ChallengeCount); + Assert.Equal(0, handler1.SignInCount); + Assert.Equal(0, handler1.SignOutCount); + Assert.Equal(0, handler2.AuthenticateCount); + Assert.Equal(0, handler2.ForbidCount); + Assert.Equal(0, handler2.ChallengeCount); + Assert.Equal(0, handler2.SignInCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.AuthenticateAsync("forward"); + Assert.Equal(1, handler1.AuthenticateCount); + Assert.Equal(0, handler2.AuthenticateCount); + + await context.ForbidAsync("forward"); + Assert.Equal(1, handler1.ForbidCount); + Assert.Equal(0, handler2.ForbidCount); + + await context.ChallengeAsync("forward"); + Assert.Equal(0, handler1.ChallengeCount); + Assert.Equal(1, handler2.ChallengeCount); + + await context.SignOutAsync("forward"); + Assert.Equal(1, handler1.SignOutCount); + Assert.Equal(0, handler2.SignOutCount); + + await context.SignInAsync("forward", new ClaimsPrincipal()); + Assert.Equal(0, handler1.SignInCount); + Assert.Equal(1, handler2.SignInCount); + } + + [Fact] + public async Task CanDynamicTargetBasedOnQueryString() + { + var server = CreateServer(services => + { + services.AddAuthentication(o => + { + o.AddScheme("auth1", "auth1"); + o.AddScheme("auth2", "auth2"); + o.AddScheme("auth3", "auth3"); + }) + .AddVirtualScheme("dynamic", "dynamic", p => + { + p.DefaultSelector = c => c.Request.QueryString.Value.Substring(1); + }); + }); + + var transaction = await server.SendAsync("http://example.com/auth/dynamic?auth1"); + Assert.Equal("auth1", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth1")); + transaction = await server.SendAsync("http://example.com/auth/dynamic?auth2"); + Assert.Equal("auth2", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth2")); + transaction = await server.SendAsync("http://example.com/auth/dynamic?auth3"); + Assert.Equal("auth3", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "auth3")); + } + + [Fact] + public async Task TargetsDefaultSchemeByDefault() + { + var server = CreateServer(services => + { + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("default", "default"); + }) + .AddVirtualScheme("virtual", "virtual", p => { }); + }); + + var transaction = await server.SendAsync("http://example.com/auth/virtual"); + Assert.Equal("default", transaction.FindClaimValue(ClaimTypes.NameIdentifier, "default")); + } + + [Fact] + public async Task TargetsDefaultSchemeThrowsWithNoDefault() + { + var server = CreateServer(services => + { + services.AddAuthentication(o => + { + o.AddScheme("default", "default"); + }) + .AddVirtualScheme("virtual", "virtual", p => { }); + }); + + var error = await Assert.ThrowsAsync(() => server.SendAsync("http://example.com/auth/virtual")); + Assert.Contains("No authenticationScheme was specified", error.Message); + } + + private class TestHandler : IAuthenticationSignInHandler + { + public AuthenticationScheme Scheme { get; set; } + public int SignInCount { get; set; } + public int SignOutCount { get; set; } + public int ForbidCount { get; set; } + public int ChallengeCount { get; set; } + public int AuthenticateCount { get; set; } + + public Task AuthenticateAsync() + { + AuthenticateCount++; + var principal = new ClaimsPrincipal(); + var id = new ClaimsIdentity(); + id.AddClaim(new Claim(ClaimTypes.NameIdentifier, Scheme.Name, ClaimValueTypes.String, Scheme.Name)); + principal.AddIdentity(id); + return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(principal, new AuthenticationProperties(), Scheme.Name))); + } + + public Task ChallengeAsync(AuthenticationProperties properties) + { + ChallengeCount++; + return Task.CompletedTask; + } + + public Task ForbidAsync(AuthenticationProperties properties) + { + ForbidCount++; + return Task.CompletedTask; + } + + public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) + { + Scheme = scheme; + return Task.CompletedTask; + } + + public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties) + { + SignInCount++; + return Task.CompletedTask; + } + + public Task SignOutAsync(AuthenticationProperties properties) + { + SignOutCount++; + return Task.CompletedTask; + } + } + + private class TestHandler2 : IAuthenticationSignInHandler + { + public AuthenticationScheme Scheme { get; set; } + public int SignInCount { get; set; } + public int SignOutCount { get; set; } + public int ForbidCount { get; set; } + public int ChallengeCount { get; set; } + public int AuthenticateCount { get; set; } + + public Task AuthenticateAsync() + { + AuthenticateCount++; + var principal = new ClaimsPrincipal(); + var id = new ClaimsIdentity(); + id.AddClaim(new Claim(ClaimTypes.NameIdentifier, Scheme.Name, ClaimValueTypes.String, Scheme.Name)); + principal.AddIdentity(id); + return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(principal, new AuthenticationProperties(), Scheme.Name))); + } + + public Task ChallengeAsync(AuthenticationProperties properties) + { + ChallengeCount++; + return Task.CompletedTask; + } + + public Task ForbidAsync(AuthenticationProperties properties) + { + ForbidCount++; + return Task.CompletedTask; + } + + public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) + { + Scheme = scheme; + return Task.CompletedTask; + } + + public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties) + { + SignInCount++; + return Task.CompletedTask; + } + + public Task SignOutAsync(AuthenticationProperties properties) + { + SignOutCount++; + return Task.CompletedTask; + } + } + + private static TestServer CreateServer(Action configure = null, string defaultScheme = null) + { + var builder = new WebHostBuilder() + .Configure(app => + { + app.UseAuthentication(); + app.Use(async (context, next) => + { + var req = context.Request; + var res = context.Response; + if (req.Path.StartsWithSegments(new PathString("/auth"), out var remainder)) + { + var name = (remainder.Value.Length > 0) ? remainder.Value.Substring(1) : null; + var result = await context.AuthenticateAsync(name); + res.Describe(result?.Ticket?.Principal); + } + else + { + await next(); + } + }); + }) + .ConfigureServices(services => + { + configure?.Invoke(services); + }); + return new TestServer(builder); + } + } +}