// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Microsoft.AspNetCore.Authentication.Tests;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using System;
using System.Security.Claims;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.AspNetCore.Authentication
{
    /// <summary>
    /// Tests shared by authentication handlers. They verify how a handler honors the forwarding
    /// options (<c>ForwardDefault</c>, <c>ForwardDefaultSelector</c> and the per-operation
    /// <c>Forward*</c> options) and the defaults of the scheme it registers.
    /// </summary>
    /// <typeparam name="TOptions">The options type of the handler under test.</typeparam>
    public abstract class SharedAuthenticationTests<TOptions> where TOptions : AuthenticationSchemeOptions
    {
        /// <summary>A test clock available to derived test classes.</summary>
        protected TestClock Clock { get; } = new TestClock();

        /// <summary>
        /// Name of the scheme registered by <see cref="RegisterAuth"/>. It is also configured as the
        /// application's default scheme in the forwarding tests.
        /// </summary>
        protected abstract string DefaultScheme { get; }

        /// <summary>Expected display name of the registered scheme (<c>null</c> unless overridden).</summary>
        protected virtual string DisplayName { get; }

        /// <summary>Expected handler type of the registered scheme.</summary>
        protected abstract Type HandlerType { get; }

        /// <summary>
        /// Whether the handler supports sign-in. When <c>false</c>, signing in is expected to throw
        /// <see cref="InvalidOperationException"/> even if a forward target is configured.
        /// </summary>
        protected virtual bool SupportsSignIn => true;

        /// <summary>
        /// Whether the handler supports sign-out. When <c>false</c>, signing out is expected to throw
        /// <see cref="InvalidOperationException"/> even if a forward target is configured.
        /// </summary>
        protected virtual bool SupportsSignOut => true;

        /// <summary>
        /// Registers the handler under test on <paramref name="services"/> under
        /// <see cref="DefaultScheme"/>, applying <paramref name="configure"/> to its options.
        /// </summary>
        /// <param name="services">The authentication builder to register the handler on.</param>
        /// <param name="configure">Callback used by the tests to set the forwarding options.</param>
        protected abstract void RegisterAuth(AuthenticationBuilder services, Action<TOptions> configure);

        [Fact]
        public async Task CanForwardDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler>("auth1", "auth1");
            });
            RegisterAuth(builder, o => o.ForwardDefault = "auth1");

            var forwardDefault = new TestHandler();
            services.AddSingleton(forwardDefault);

            var context = CreateContext(services);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);

            await context.AuthenticateAsync();
            Assert.Equal(1, forwardDefault.AuthenticateCount);

            await context.ForbidAsync();
            Assert.Equal(1, forwardDefault.ForbidCount);

            await context.ChallengeAsync();
            Assert.Equal(1, forwardDefault.ChallengeCount);

            await AssertForwardedOrThrowsAsync(SupportsSignOut, () => context.SignOutAsync(), () => forwardDefault.SignOutCount);
            await AssertForwardedOrThrowsAsync(SupportsSignIn, () => context.SignInAsync(new ClaimsPrincipal()), () => forwardDefault.SignInCount);
        }

        [Fact]
        public async Task ForwardSignInWinsOverDefault()
        {
            // Nothing to verify for handlers that cannot sign in.
            if (!SupportsSignIn)
            {
                return;
            }

            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardSignIn = "specific";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);

            var context = CreateContext(services);

            await context.SignInAsync(new ClaimsPrincipal());

            Assert.Equal(1, specific.SignInCount);
            Assert.Equal(0, specific.AuthenticateCount);
            Assert.Equal(0, specific.ForbidCount);
            Assert.Equal(0, specific.ChallengeCount);
            Assert.Equal(0, specific.SignOutCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);
        }

        [Fact]
        public async Task ForwardSignOutWinsOverDefault()
        {
            // Nothing to verify for handlers that cannot sign out.
            if (!SupportsSignOut)
            {
                return;
            }

            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardSignOut = "specific";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);

            var context = CreateContext(services);

            await context.SignOutAsync();

            Assert.Equal(1, specific.SignOutCount);
            Assert.Equal(0, specific.AuthenticateCount);
            Assert.Equal(0, specific.ForbidCount);
            Assert.Equal(0, specific.ChallengeCount);
            Assert.Equal(0, specific.SignInCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);
        }

        [Fact]
        public async Task ForwardForbidWinsOverDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardForbid = "specific";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);

            var context = CreateContext(services);

            await context.ForbidAsync();

            Assert.Equal(0, specific.SignOutCount);
            Assert.Equal(0, specific.AuthenticateCount);
            Assert.Equal(1, specific.ForbidCount);
            Assert.Equal(0, specific.ChallengeCount);
            Assert.Equal(0, specific.SignInCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);
        }

        [Fact]
        public async Task ForwardAuthenticateWinsOverDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardAuthenticate = "specific";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);

            var context = CreateContext(services);

            await context.AuthenticateAsync();

            Assert.Equal(0, specific.SignOutCount);
            Assert.Equal(1, specific.AuthenticateCount);
            Assert.Equal(0, specific.ForbidCount);
            Assert.Equal(0, specific.ChallengeCount);
            Assert.Equal(0, specific.SignInCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);
        }

        [Fact]
        public async Task ForwardChallengeWinsOverDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardChallenge = "specific";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);

            var context = CreateContext(services);

            await context.ChallengeAsync();

            Assert.Equal(0, specific.SignOutCount);
            Assert.Equal(0, specific.AuthenticateCount);
            Assert.Equal(0, specific.ForbidCount);
            Assert.Equal(1, specific.ChallengeCount);
            Assert.Equal(0, specific.SignInCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);
        }

        [Fact]
        public async Task ForwardSelectorWinsOverDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler3>("selector", "selector");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardDefaultSelector = _ => "selector";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);
            var selector = new TestHandler3();
            services.AddSingleton(selector);

            var context = CreateContext(services);

            await context.AuthenticateAsync();
            Assert.Equal(1, selector.AuthenticateCount);

            await context.ForbidAsync();
            Assert.Equal(1, selector.ForbidCount);

            await context.ChallengeAsync();
            Assert.Equal(1, selector.ChallengeCount);

            await AssertForwardedOrThrowsAsync(SupportsSignOut, () => context.SignOutAsync(), () => selector.SignOutCount);
            await AssertForwardedOrThrowsAsync(SupportsSignIn, () => context.SignInAsync(new ClaimsPrincipal()), () => selector.SignInCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);

            Assert.Equal(0, specific.AuthenticateCount);
            Assert.Equal(0, specific.ForbidCount);
            Assert.Equal(0, specific.ChallengeCount);
            Assert.Equal(0, specific.SignInCount);
            Assert.Equal(0, specific.SignOutCount);
        }

        [Fact]
        public async Task NullForwardSelectorUsesDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler3>("selector", "selector");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                // A selector that returns null must fall back to ForwardDefault.
                o.ForwardDefaultSelector = _ => null;
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);
            var selector = new TestHandler3();
            services.AddSingleton(selector);

            var context = CreateContext(services);

            await context.AuthenticateAsync();
            Assert.Equal(1, forwardDefault.AuthenticateCount);

            await context.ForbidAsync();
            Assert.Equal(1, forwardDefault.ForbidCount);

            await context.ChallengeAsync();
            Assert.Equal(1, forwardDefault.ChallengeCount);

            await AssertForwardedOrThrowsAsync(SupportsSignOut, () => context.SignOutAsync(), () => forwardDefault.SignOutCount);
            await AssertForwardedOrThrowsAsync(SupportsSignIn, () => context.SignInAsync(new ClaimsPrincipal()), () => forwardDefault.SignInCount);

            Assert.Equal(0, selector.AuthenticateCount);
            Assert.Equal(0, selector.ForbidCount);
            Assert.Equal(0, selector.ChallengeCount);
            Assert.Equal(0, selector.SignInCount);
            Assert.Equal(0, selector.SignOutCount);

            Assert.Equal(0, specific.AuthenticateCount);
            Assert.Equal(0, specific.ForbidCount);
            Assert.Equal(0, specific.ChallengeCount);
            Assert.Equal(0, specific.SignInCount);
            Assert.Equal(0, specific.SignOutCount);
        }

        [Fact]
        public async Task SpecificForwardWinsOverSelectorAndDefault()
        {
            var services = new ServiceCollection().AddLogging();
            var builder = services.AddAuthentication(o =>
            {
                o.DefaultScheme = DefaultScheme;
                o.AddScheme<TestHandler2>("auth1", "auth1");
                o.AddScheme<TestHandler3>("selector", "selector");
                o.AddScheme<TestHandler>("specific", "specific");
            });
            RegisterAuth(builder, o =>
            {
                o.ForwardDefault = "auth1";
                o.ForwardDefaultSelector = _ => "selector";
                o.ForwardAuthenticate = "specific";
                o.ForwardChallenge = "specific";
                o.ForwardSignIn = "specific";
                o.ForwardSignOut = "specific";
                o.ForwardForbid = "specific";
            });

            var specific = new TestHandler();
            services.AddSingleton(specific);
            var forwardDefault = new TestHandler2();
            services.AddSingleton(forwardDefault);
            var selector = new TestHandler3();
            services.AddSingleton(selector);

            var context = CreateContext(services);

            await context.AuthenticateAsync();
            Assert.Equal(1, specific.AuthenticateCount);

            await context.ForbidAsync();
            Assert.Equal(1, specific.ForbidCount);

            await context.ChallengeAsync();
            Assert.Equal(1, specific.ChallengeCount);

            await AssertForwardedOrThrowsAsync(SupportsSignOut, () => context.SignOutAsync(), () => specific.SignOutCount);
            await AssertForwardedOrThrowsAsync(SupportsSignIn, () => context.SignInAsync(new ClaimsPrincipal()), () => specific.SignInCount);

            Assert.Equal(0, forwardDefault.AuthenticateCount);
            Assert.Equal(0, forwardDefault.ForbidCount);
            Assert.Equal(0, forwardDefault.ChallengeCount);
            Assert.Equal(0, forwardDefault.SignInCount);
            Assert.Equal(0, forwardDefault.SignOutCount);

            Assert.Equal(0, selector.AuthenticateCount);
            Assert.Equal(0, selector.ForbidCount);
            Assert.Equal(0, selector.ChallengeCount);
            Assert.Equal(0, selector.SignInCount);
            Assert.Equal(0, selector.SignOutCount);
        }

        [Fact]
        public async Task VerifySchemeDefaults()
        {
            var services = new ServiceCollection();
            var builder = services.AddAuthentication();
            RegisterAuth(builder, o => { });
            var sp = services.BuildServiceProvider();

            var schemeProvider = sp.GetRequiredService<IAuthenticationSchemeProvider>();
            var scheme = await schemeProvider.GetSchemeAsync(DefaultScheme);

            Assert.NotNull(scheme);
            Assert.Equal(HandlerType, scheme.HandlerType);
            Assert.Equal(DisplayName, scheme.DisplayName);
        }

        /// <summary>Builds the service provider from <paramref name="services"/> and wraps it in a fresh request context.</summary>
        private static HttpContext CreateContext(IServiceCollection services)
        {
            return new DefaultHttpContext { RequestServices = services.BuildServiceProvider() };
        }

        /// <summary>
        /// Runs an operation that the handler may not support. If it is supported, the operation must
        /// be forwarded exactly once, as reported by <paramref name="forwardedCount"/>. Otherwise the
        /// operation must throw <see cref="InvalidOperationException"/>.
        /// </summary>
        /// <param name="supported">Whether the handler under test supports the operation.</param>
        /// <param name="operation">The sign-in or sign-out call to run against the context.</param>
        /// <param name="forwardedCount">Reads the invocation count of the handler expected to receive the call.</param>
        private static async Task AssertForwardedOrThrowsAsync(bool supported, Func<Task> operation, Func<int> forwardedCount)
        {
            if (!supported)
            {
                await Assert.ThrowsAsync<InvalidOperationException>(operation);
                return;
            }

            await operation();
            Assert.Equal(1, forwardedCount());
        }
    }
}