diff --git a/src/Microsoft.AspNetCore.Authentication.Abstractions/TokenExtensions.cs b/src/Microsoft.AspNetCore.Authentication.Abstractions/TokenExtensions.cs index 1af05aab0f..497acabc23 100644 --- a/src/Microsoft.AspNetCore.Authentication.Abstractions/TokenExtensions.cs +++ b/src/Microsoft.AspNetCore.Authentication.Abstractions/TokenExtensions.cs @@ -149,10 +149,6 @@ namespace Microsoft.AspNetCore.Authentication { throw new ArgumentNullException(nameof(auth)); } - if (scheme == null) - { - throw new ArgumentNullException(nameof(scheme)); - } if (tokenName == null) { throw new ArgumentNullException(nameof(tokenName)); diff --git a/test/Microsoft.AspNetCore.Authentication.Core.Test/Microsoft.AspNetCore.Authentication.Core.Test.csproj b/test/Microsoft.AspNetCore.Authentication.Core.Test/Microsoft.AspNetCore.Authentication.Core.Test.csproj index 925f5aa079..e176bcad4c 100644 --- a/test/Microsoft.AspNetCore.Authentication.Core.Test/Microsoft.AspNetCore.Authentication.Core.Test.csproj +++ b/test/Microsoft.AspNetCore.Authentication.Core.Test/Microsoft.AspNetCore.Authentication.Core.Test.csproj @@ -8,6 +8,7 @@ + diff --git a/test/Microsoft.AspNetCore.Authentication.Core.Test/TokenExtensionTests.cs b/test/Microsoft.AspNetCore.Authentication.Core.Test/TokenExtensionTests.cs index 3e3eb9c52b..d9e050fe82 100644 --- a/test/Microsoft.AspNetCore.Authentication.Core.Test/TokenExtensionTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Core.Test/TokenExtensionTests.cs @@ -1,8 +1,13 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.Linq; +using System.Security.Claims; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; using Xunit; namespace Microsoft.AspNetCore.Authentication @@ -117,7 +122,70 @@ namespace Microsoft.AspNetCore.Authentication Assert.Null(props.GetTokenValue("ONE")); Assert.Null(props.GetTokenValue("Jigglypuff")); Assert.Equal(3, props.GetTokens().Count()); - } + + [Fact] + public async Task GetTokenWorksWithDefaultAuthenticateScheme() + { + var context = new DefaultHttpContext(); + var services = new ServiceCollection().AddOptions() + .AddAuthenticationCore(o => o.AddScheme("simple", s => s.HandlerType = typeof(SimpleAuth))); + context.RequestServices = services.BuildServiceProvider(); + + Assert.Equal("1", await context.GetTokenAsync("One")); + Assert.Equal("2", await context.GetTokenAsync("Two")); + Assert.Equal("3", await context.GetTokenAsync("Three")); + } + + [Fact] + public async Task GetTokenWorksWithExplicitScheme() + { + var context = new DefaultHttpContext(); + var services = new ServiceCollection().AddOptions() + .AddAuthenticationCore(o => o.AddScheme("simple", s => s.HandlerType = typeof(SimpleAuth))); + context.RequestServices = services.BuildServiceProvider(); + + Assert.Equal("1", await context.GetTokenAsync("simple", "One")); + Assert.Equal("2", await context.GetTokenAsync("simple", "Two")); + Assert.Equal("3", await context.GetTokenAsync("simple", "Three")); + } + + private class SimpleAuth : IAuthenticationHandler + { + public Task AuthenticateAsync() + { + var props = new AuthenticationProperties(); + var tokens = new List(); + var tok1 = new AuthenticationToken { Name = "One", Value = "1" }; + var tok2 = new AuthenticationToken { Name = "Two", Value = "2" }; + var tok3 = new AuthenticationToken { Name = "Three", Value = "3" }; + tokens.Add(tok1); + tokens.Add(tok2); + tokens.Add(tok3); + props.StoreTokens(tokens); + return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(), props, "simple"))); + } + + public Task ChallengeAsync(ChallengeContext context) + { + return Task.FromResult(0); + } + + public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) + { + return Task.FromResult(0); + } + + public Task SignInAsync(SignInContext context) + { + return Task.FromResult(0); + } + + public Task SignOutAsync(SignOutContext context) + { + return Task.FromResult(0); + } + } + } }