From c0b8be58ba0ccdb3a4cd38825c30536f60ab99ce Mon Sep 17 00:00:00 2001 From: Hao Kung Date: Thu, 1 Feb 2018 14:40:56 -0800 Subject: [PATCH] Add scheme forwarding (authN policies) (#1625) --- .../CookieAuthenticationHandler.cs | 14 + .../OpenIdConnectHandler.cs | 7 + .../AuthenticationHandler.cs | 30 ++ .../AuthenticationSchemeOptions.cs | 53 +++ .../CookieTests.cs | 412 ++++++++++++++++- .../FacebookTests.cs | 397 ++++++++++++++++ .../GoogleTests.cs | 409 ++++++++++++++++- .../JwtBearerTests.cs | 399 ++++++++++++++++- .../MicrosoftAccountTests.cs | 395 ++++++++++++++++ .../OAuthTests.cs | 422 +++++++++++++++++- .../OpenIdConnectConfigurationTests.cs | 418 +++++++++++++++++ .../TestHandlers.cs | 115 +++++ .../TwitterTests.cs | 396 ++++++++++++++++ 13 files changed, 3441 insertions(+), 26 deletions(-) create mode 100644 test/Microsoft.AspNetCore.Authentication.Test/TestHandlers.cs diff --git a/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs index 5d0afba46b..9a2fbfbc74 100644 --- a/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.Cookies/CookieAuthenticationHandler.cs @@ -240,6 +240,13 @@ namespace Microsoft.AspNetCore.Authentication.Cookies throw new ArgumentNullException(nameof(user)); } + var target = ResolveTarget(Options.ForwardSignIn); + if (target != null) + { + await Context.SignInAsync(target, user, properties); + return; + } + properties = properties ?? new AuthenticationProperties(); _signInCalled = true; @@ -322,6 +329,13 @@ namespace Microsoft.AspNetCore.Authentication.Cookies public async virtual Task SignOutAsync(AuthenticationProperties properties) { + var target = ResolveTarget(Options.ForwardSignOut); + if (target != null) + { + await Context.SignOutAsync(target, properties); + return; + } + properties = properties ?? new AuthenticationProperties(); _signOutCalled = true; diff --git a/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs b/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs index 4f722323dc..ce7494fb4a 100644 --- a/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.OpenIdConnect/OpenIdConnectHandler.cs @@ -155,6 +155,13 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect /// A task executing the sign out procedure public async virtual Task SignOutAsync(AuthenticationProperties properties) { + var target = ResolveTarget(Options.ForwardSignOut); + if (target != null) + { + await Context.SignOutAsync(target, properties); + return; + } + properties = properties ?? new AuthenticationProperties(); Logger.EnteringOpenIdAuthenticationHandlerHandleSignOutAsync(GetType().FullName); diff --git a/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs index ef4292100a..4399ce5f74 100644 --- a/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication/AuthenticationHandler.cs @@ -118,8 +118,24 @@ namespace Microsoft.AspNetCore.Authentication protected string BuildRedirectUri(string targetPath) => Request.Scheme + "://" + Request.Host + OriginalPathBase + targetPath; + protected virtual string ResolveTarget(string scheme) + { + var target = scheme ?? Options.ForwardDefaultSelector?.Invoke(Context) ?? Options.ForwardDefault; + + // Prevent self targetting + return string.Equals(target, Scheme.Name, StringComparison.Ordinal) + ? null + : target; + } + public async Task AuthenticateAsync() { + var target = ResolveTarget(Options.ForwardAuthenticate); + if (target != null) + { + return await Context.AuthenticateAsync(target); + } + // Calling Authenticate more than once should always return the original value. var result = await HandleAuthenticateOnceAsync(); if (result?.Failure == null) @@ -200,6 +216,13 @@ namespace Microsoft.AspNetCore.Authentication public async Task ChallengeAsync(AuthenticationProperties properties) { + var target = ResolveTarget(Options.ForwardChallenge); + if (target != null) + { + await Context.ChallengeAsync(target, properties); + return; + } + properties = properties ?? new AuthenticationProperties(); await HandleChallengeAsync(properties); Logger.AuthenticationSchemeChallenged(Scheme.Name); @@ -207,6 +230,13 @@ namespace Microsoft.AspNetCore.Authentication public async Task ForbidAsync(AuthenticationProperties properties) { + var target = ResolveTarget(Options.ForwardForbid); + if (target != null) + { + await Context.ForbidAsync(target, properties); + return; + } + properties = properties ?? new AuthenticationProperties(); await HandleForbiddenAsync(properties); Logger.AuthenticationSchemeForbidden(Scheme.Name); diff --git a/src/Microsoft.AspNetCore.Authentication/AuthenticationSchemeOptions.cs b/src/Microsoft.AspNetCore.Authentication/AuthenticationSchemeOptions.cs index 18d4c97881..a547d203b4 100644 --- a/src/Microsoft.AspNetCore.Authentication/AuthenticationSchemeOptions.cs +++ b/src/Microsoft.AspNetCore.Authentication/AuthenticationSchemeOptions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using Microsoft.AspNetCore.Http; namespace Microsoft.AspNetCore.Authentication { @@ -36,5 +37,57 @@ namespace Microsoft.AspNetCore.Authentication /// If set, will be used as the service type to get the Events instance instead of the property. /// public Type EventsType { get; set; } + + /// + /// If set, this specifies a default scheme that authentication handlers should forward all authentication operations to + /// by default. The default forwarding logic will check the most specific ForwardAuthenticate/Challenge/Forbid/SignIn/SignOut + /// setting first, followed by checking the ForwardDefaultSelector, followed by ForwardDefault. The first non null result + /// will be used as the target scheme to forward to. + /// + public string ForwardDefault { get; set; } + + /// + /// If set, this specifies the target scheme that this scheme should forward AuthenticateAsync calls to. + /// For example Context.AuthenticateAsync("ThisScheme") => Context.AuthenticateAsync("ForwardAuthenticateValue"); + /// Set the target to the current scheme to disable forwarding and allow normal processing. + /// + public string ForwardAuthenticate { get; set; } + + /// + /// If set, this specifies the target scheme that this scheme should forward ChallengeAsync calls to. + /// For example Context.ChallengeAsync("ThisScheme") => Context.ChallengeAsync("ForwardChallengeValue"); + /// Set the target to the current scheme to disable forwarding and allow normal processing. + /// + public string ForwardChallenge { get; set; } + + /// + /// If set, this specifies the target scheme that this scheme should forward ForbidAsync calls to. + /// For example Context.ForbidAsync("ThisScheme") => Context.ForbidAsync("ForwardForbidValue"); + /// Set the target to the current scheme to disable forwarding and allow normal processing. + /// + public string ForwardForbid { get; set; } + + /// + /// If set, this specifies the target scheme that this scheme should forward SignInAsync calls to. + /// For example Context.SignInAsync("ThisScheme") => Context.SignInAsync("ForwardSignInValue"); + /// Set the target to the current scheme to disable forwarding and allow normal processing. + /// + public string ForwardSignIn { get; set; } + + /// + /// If set, this specifies the target scheme that this scheme should forward SignOutAsync calls to. + /// For example Context.SignOutAsync("ThisScheme") => Context.SignInAsync("ForwardSignOutValue"); + /// Set the target to the current scheme to disable forwarding and allow normal processing. + /// + public string ForwardSignOut { get; set; } + + /// + /// Used to select a default scheme for the current request that authentication handlers should forward all authentication operations to + /// by default. The default forwarding logic will check the most specific ForwardAuthenticate/Challenge/Forbid/SignIn/SignOut + /// setting first, followed by checking the ForwardDefaultSelector, followed by ForwardDefault. The first non null result + /// will be used as the target scheme to forward to. + /// + public Func ForwardDefaultSelector { get; set; } + } } diff --git a/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs index 789f5ede9c..b2726bac8c 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/CookieTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Diagnostics; using System.Linq; using System.Net; using System.Net.Http; @@ -11,6 +10,7 @@ using System.Security.Principal; using System.Text; using System.Threading.Tasks; using System.Xml.Linq; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Hosting; @@ -26,6 +26,416 @@ namespace Microsoft.AspNetCore.Authentication.Cookies { private TestClock _clock = new TestClock(); + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddCookie(o => o.ForwardDefault = "auth1"); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, forwardDefault.SignOutCount); + + await context.SignInAsync(new ClaimsPrincipal()); + Assert.Equal(1, forwardDefault.SignInCount); + } + + [Fact] + public async Task ForwardSignInWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardSignIn = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.SignInAsync(new ClaimsPrincipal()); + Assert.Equal(1, specific.SignInCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignOutCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSignOutWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.SignOutAsync(); + Assert.Equal(1, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, selector.SignOutCount); + + await context.SignInAsync(new ClaimsPrincipal()); + Assert.Equal(1, selector.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, forwardDefault.SignOutCount); + + await context.SignInAsync(new ClaimsPrincipal()); + Assert.Equal(1, forwardDefault.SignInCount); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddCookie(o => + { + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, specific.SignOutCount); + + await context.SignInAsync(new ClaimsPrincipal()); + Assert.Equal(1, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task VerifySchemeDefaults() { diff --git a/test/Microsoft.AspNetCore.Authentication.Test/FacebookTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/FacebookTests.cs index 2314b6b3c9..684482ed5b 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/FacebookTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/FacebookTests.cs @@ -5,11 +5,13 @@ using System; using System.Linq; using System.Net; using System.Net.Http; +using System.Security.Claims; using System.Text; using System.Text.Encodings.Web; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.Cookies; using Microsoft.AspNetCore.Authentication.OAuth; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Hosting; @@ -24,6 +26,401 @@ namespace Microsoft.AspNetCore.Authentication.Facebook { public class FacebookTests { + private void ConfigureDefaults(FacebookOptions o) + { + o.AppId = "whatever"; + o.AppSecret = "whatever"; + o.SignInScheme = "auth1"; + } + + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = FacebookDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddFacebook(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task VerifySignInSchemeCannotBeSetToSelf() { diff --git a/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs index 944a4827c3..8bfbaacde8 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/GoogleTests.cs @@ -10,6 +10,7 @@ using System.Text; using System.Text.Encodings.Web; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.OAuth; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.Hosting; @@ -24,6 +25,401 @@ namespace Microsoft.AspNetCore.Authentication.Google { public class GoogleTests { + private void ConfigureDefaults(GoogleOptions o) + { + o.ClientId = "whatever"; + o.ClientSecret = "whatever"; + o.SignInScheme = "auth1"; + } + + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = GoogleDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddGoogle(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task VerifySignInSchemeCannotBeSetToSelf() { @@ -1061,18 +1457,13 @@ namespace Microsoft.AspNetCore.Authentication.Google .ConfigureServices(services => { services.AddTransient(); - services.AddAuthentication("Auth") - .AddVirtualScheme("Auth", "Auth", o => - { - o.Default = TestExtensions.CookieAuthenticationScheme; - o.Challenge = GoogleDefaults.AuthenticationScheme; - }) - .AddCookie(TestExtensions.CookieAuthenticationScheme) + services.AddAuthentication(TestExtensions.CookieAuthenticationScheme) + .AddCookie(TestExtensions.CookieAuthenticationScheme, o => o.ForwardChallenge = GoogleDefaults.AuthenticationScheme) .AddGoogle(configureOptions) .AddFacebook(o => { - o.AppId = "Test AppId"; - o.AppSecret = "Test AppSecrent"; + o.ClientId = "Test ClientId"; + o.ClientSecret = "Test AppSecrent"; }); }); return new TestServer(builder); diff --git a/test/Microsoft.AspNetCore.Authentication.Test/JwtBearerTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/JwtBearerTests.cs index 97adb21054..b472a4162d 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/JwtBearerTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/JwtBearerTests.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.Linq; using System.Net; using System.Net.Http; @@ -11,14 +10,13 @@ using System.Security.Claims; using System.Text; using System.Threading.Tasks; using System.Xml.Linq; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.Testing.xunit; -using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Options; using Microsoft.IdentityModel.Tokens; using Xunit; @@ -26,6 +24,401 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer { public class JwtBearerTests { + private void ConfigureDefaults(JwtBearerOptions o) + { + } + + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.DefaultSignInScheme = "auth1"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.DefaultSignInScheme = "auth1"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.DefaultSignInScheme = "auth1"; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = JwtBearerDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddJwtBearer(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task VerifySchemeDefaults() { diff --git a/test/Microsoft.AspNetCore.Authentication.Test/MicrosoftAccountTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/MicrosoftAccountTests.cs index b2854e344e..480241d35b 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/MicrosoftAccountTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/MicrosoftAccountTests.cs @@ -27,6 +27,401 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount { public class MicrosoftAccountTests { + private void ConfigureDefaults(MicrosoftAccountOptions o) + { + o.ClientId = "whatever"; + o.ClientSecret = "whatever"; + o.SignInScheme = "auth1"; + } + + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = MicrosoftAccountDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddMicrosoftAccount(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task VerifySignInSchemeCannotBeSetToSelf() { diff --git a/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs index 65d865b941..9279f145b9 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/OAuthTests.cs @@ -4,20 +4,416 @@ using System; using System.Collections.Generic; using System.Net; +using System.Security.Claims; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.Cookies; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Net.Http.Headers; using Xunit; namespace Microsoft.AspNetCore.Authentication.OAuth { public class OAuthTests { + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("auth1", "auth1"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.SignInScheme = "auth1"; + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.SignInScheme = "auth1"; + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.SignInScheme = "auth1"; + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.DefaultSignInScheme = "auth1"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.DefaultSignInScheme = "auth1"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.DefaultSignInScheme = "auth1"; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = "default"; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddOAuth("default", o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + + [Fact] public async Task VerifySignInSchemeCannotBeSetToSelf() { @@ -131,12 +527,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth "Weblie", opt => { - opt.ClientId = "Test Id"; - opt.ClientSecret = "secret"; - opt.SignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; - opt.AuthorizationEndpoint = "https://example.com/provider/login"; - opt.TokenEndpoint = "https://example.com/provider/token"; - opt.CallbackPath = "/oauth-callback"; + ConfigureDefaults(opt); }), async ctx => { @@ -162,12 +553,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth "Weblie", opt => { - opt.ClientId = "Test Id"; - opt.ClientSecret = "secret"; - opt.SignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; - opt.AuthorizationEndpoint = "https://example.com/provider/login"; - opt.TokenEndpoint = "https://example.com/provider/token"; - opt.CallbackPath = "/oauth-callback"; + ConfigureDefaults(opt); opt.CorrelationCookie.Path = "/"; }), async ctx => @@ -186,6 +572,16 @@ namespace Microsoft.AspNetCore.Authentication.OAuth Assert.Contains("path=/", correlation); } + private void ConfigureDefaults(OAuthOptions o) + { + o.ClientId = "Test Id"; + o.ClientSecret = "secret"; + o.SignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; + o.AuthorizationEndpoint = "https://example.com/provider/login"; + o.TokenEndpoint = "https://example.com/provider/token"; + o.CallbackPath = "/oauth-callback"; + } + [Fact] public async Task RemoteAuthenticationFailed_OAuthError_IncludesProperties() { diff --git a/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectConfigurationTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectConfigurationTests.cs index 69ba758292..ed368c1ef7 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectConfigurationTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/OpenIdConnect/OpenIdConnectConfigurationTests.cs @@ -3,10 +3,13 @@ using System; using System.Net; +using System.Security.Claims; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.OpenIdConnect; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -15,6 +18,421 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect { public class OpenIdConnectConfigurationTests { + private void ConfigureDefaults(OpenIdConnectOptions o) + { + o.Authority = TestServerBuilder.DefaultAuthority; + o.ClientId = "Test Id"; + o.ClientSecret = "Test Secret"; + o.SignInScheme = "auth1"; + } + + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, forwardDefault.SignOutCount); + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.SignOutAsync(); + Assert.Equal(1, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, selector.SignOutCount); + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, forwardDefault.SignOutCount); + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = OpenIdConnectDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddOpenIdConnect(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await context.SignOutAsync(); + Assert.Equal(1, specific.SignOutCount); + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task MetadataAddressIsGeneratedFromAuthorityWhenMissing() { diff --git a/test/Microsoft.AspNetCore.Authentication.Test/TestHandlers.cs b/test/Microsoft.AspNetCore.Authentication.Test/TestHandlers.cs new file mode 100644 index 0000000000..cd9fe9fb1a --- /dev/null +++ b/test/Microsoft.AspNetCore.Authentication.Test/TestHandlers.cs @@ -0,0 +1,115 @@ +// Copyright (c) .NET Foundation. All rights reserved. See License.txt in the project root for license information. + +using System.Security.Claims; +using System.Text.Encodings.Web; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace Microsoft.AspNetCore.Authentication.Tests +{ + public class TestAuthHandler : AuthenticationHandler, IAuthenticationSignInHandler + { + public TestAuthHandler(IOptionsMonitor options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock) : base(options, logger, encoder, clock) + { } + + public int SignInCount { get; set; } + public int SignOutCount { get; set; } + public int ForbidCount { get; set; } + public int ChallengeCount { get; set; } + public int AuthenticateCount { get; set; } + + protected override Task HandleChallengeAsync(AuthenticationProperties properties) + { + ChallengeCount++; + return Task.CompletedTask; + } + + protected override Task HandleForbiddenAsync(AuthenticationProperties properties) + { + ForbidCount++; + return Task.CompletedTask; + } + + protected override Task HandleAuthenticateAsync() + { + AuthenticateCount++; + var principal = new ClaimsPrincipal(); + var id = new ClaimsIdentity(); + id.AddClaim(new Claim(ClaimTypes.NameIdentifier, Scheme.Name, ClaimValueTypes.String, Scheme.Name)); + principal.AddIdentity(id); + return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(principal, new AuthenticationProperties(), Scheme.Name))); + } + + public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties) + { + SignInCount++; + return Task.CompletedTask; + } + + public Task SignOutAsync(AuthenticationProperties properties) + { + SignOutCount++; + return Task.CompletedTask; + } + } + + public class TestHandler : IAuthenticationSignInHandler + { + public AuthenticationScheme Scheme { get; set; } + public int SignInCount { get; set; } + public int SignOutCount { get; set; } + public int ForbidCount { get; set; } + public int ChallengeCount { get; set; } + public int AuthenticateCount { get; set; } + + public Task AuthenticateAsync() + { + AuthenticateCount++; + var principal = new ClaimsPrincipal(); + var id = new ClaimsIdentity(); + id.AddClaim(new Claim(ClaimTypes.NameIdentifier, Scheme.Name, ClaimValueTypes.String, Scheme.Name)); + principal.AddIdentity(id); + return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(principal, new AuthenticationProperties(), Scheme.Name))); + } + + public Task ChallengeAsync(AuthenticationProperties properties) + { + ChallengeCount++; + return Task.CompletedTask; + } + + public Task ForbidAsync(AuthenticationProperties properties) + { + ForbidCount++; + return Task.CompletedTask; + } + + public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) + { + Scheme = scheme; + return Task.CompletedTask; + } + + public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties) + { + SignInCount++; + return Task.CompletedTask; + } + + public Task SignOutAsync(AuthenticationProperties properties) + { + SignOutCount++; + return Task.CompletedTask; + } + } + + public class TestHandler2 : TestHandler + { + } + + public class TestHandler3 : TestHandler + { + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs b/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs index 2a63757b9a..c1937d136c 100644 --- a/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Test/TwitterTests.cs @@ -7,6 +7,7 @@ using System.Net.Http; using System.Security.Claims; using System.Text; using System.Threading.Tasks; +using Microsoft.AspNetCore.Authentication.Tests; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; @@ -19,6 +20,401 @@ namespace Microsoft.AspNetCore.Authentication.Twitter { public class TwitterTests { + private void ConfigureDefaults(TwitterOptions o) + { + o.ConsumerKey = "whatever"; + o.ConsumerSecret = "whatever"; + o.SignInScheme = "auth1"; + } + + [Fact] + public async Task CanForwardDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + }); + + var forwardDefault = new TestHandler(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignInThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + } + + [Fact] + public async Task ForwardSignOutThrows() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardSignOut = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + } + + [Fact] + public async Task ForwardForbidWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ForbidAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(1, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardAuthenticateWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardAuthenticate = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(1, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardChallengeWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("specific", "specific"); + o.AddScheme("auth1", "auth1"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardChallenge = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.ChallengeAsync(); + Assert.Equal(0, specific.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(1, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + } + + [Fact] + public async Task ForwardSelectorWinsOverDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, selector.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, selector.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, selector.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task NullForwardSelectorUsesDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => null; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, forwardDefault.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, forwardDefault.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, forwardDefault.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + Assert.Equal(0, specific.AuthenticateCount); + Assert.Equal(0, specific.ForbidCount); + Assert.Equal(0, specific.ChallengeCount); + Assert.Equal(0, specific.SignInCount); + Assert.Equal(0, specific.SignOutCount); + } + + [Fact] + public async Task SpecificForwardWinsOverSelectorAndDefault() + { + var services = new ServiceCollection().AddLogging(); + services.AddAuthentication(o => + { + o.DefaultScheme = TwitterDefaults.AuthenticationScheme; + o.AddScheme("auth1", "auth1"); + o.AddScheme("selector", "selector"); + o.AddScheme("specific", "specific"); + }) + .AddTwitter(o => + { + ConfigureDefaults(o); + o.ForwardDefault = "auth1"; + o.ForwardDefaultSelector = _ => "selector"; + o.ForwardAuthenticate = "specific"; + o.ForwardChallenge = "specific"; + o.ForwardSignIn = "specific"; + o.ForwardSignOut = "specific"; + o.ForwardForbid = "specific"; + }); + + var specific = new TestHandler(); + services.AddSingleton(specific); + var forwardDefault = new TestHandler2(); + services.AddSingleton(forwardDefault); + var selector = new TestHandler3(); + services.AddSingleton(selector); + + var sp = services.BuildServiceProvider(); + var context = new DefaultHttpContext(); + context.RequestServices = sp; + + await context.AuthenticateAsync(); + Assert.Equal(1, specific.AuthenticateCount); + + await context.ForbidAsync(); + Assert.Equal(1, specific.ForbidCount); + + await context.ChallengeAsync(); + Assert.Equal(1, specific.ChallengeCount); + + await Assert.ThrowsAsync(() => context.SignOutAsync()); + await Assert.ThrowsAsync(() => context.SignInAsync(new ClaimsPrincipal())); + + Assert.Equal(0, forwardDefault.AuthenticateCount); + Assert.Equal(0, forwardDefault.ForbidCount); + Assert.Equal(0, forwardDefault.ChallengeCount); + Assert.Equal(0, forwardDefault.SignInCount); + Assert.Equal(0, forwardDefault.SignOutCount); + Assert.Equal(0, selector.AuthenticateCount); + Assert.Equal(0, selector.ForbidCount); + Assert.Equal(0, selector.ChallengeCount); + Assert.Equal(0, selector.SignInCount); + Assert.Equal(0, selector.SignOutCount); + } + [Fact] public async Task VerifySignInSchemeCannotBeSetToSelf() {