diff --git a/src/Microsoft.AspNetCore.Authentication.Core/AuthenticationSchemeProvider.cs b/src/Microsoft.AspNetCore.Authentication.Core/AuthenticationSchemeProvider.cs index f5ec8e1598..050118d3c4 100644 --- a/src/Microsoft.AspNetCore.Authentication.Core/AuthenticationSchemeProvider.cs +++ b/src/Microsoft.AspNetCore.Authentication.Core/AuthenticationSchemeProvider.cs @@ -15,13 +15,28 @@ namespace Microsoft.AspNetCore.Authentication public class AuthenticationSchemeProvider : IAuthenticationSchemeProvider { /// - /// Constructor. + /// Creates an instance of + /// using the specified , /// /// The options. public AuthenticationSchemeProvider(IOptions options) + : this(options, new Dictionary(StringComparer.Ordinal)) + { + } + + /// + /// Creates an instance of + /// using the specified and . + /// + /// The options. + /// The dictionary used to store authentication schemes. + protected AuthenticationSchemeProvider(IOptions options, IDictionary schemes) { _options = options.Value; + _schemes = schemes ?? throw new ArgumentNullException(nameof(schemes)); + _requestHandlers = new List(); + foreach (var builder in _options.Schemes) { var scheme = builder.Build(); @@ -32,8 +47,8 @@ namespace Microsoft.AspNetCore.Authentication private readonly AuthenticationOptions _options; private readonly object _lock = new object(); - private IDictionary _map = new Dictionary(StringComparer.Ordinal); - private List _requestHandlers = new List(); + private readonly IDictionary _schemes; + private readonly List _requestHandlers; private Task GetDefaultSchemeAsync() => _options.DefaultScheme != null @@ -101,7 +116,7 @@ namespace Microsoft.AspNetCore.Authentication /// The name of the authenticationScheme. /// The scheme or null if not found. public virtual Task GetSchemeAsync(string name) - => Task.FromResult(_map.ContainsKey(name) ? _map[name] : null); + => Task.FromResult(_schemes.ContainsKey(name) ? _schemes[name] : null); /// /// Returns the schemes in priority order for request handling. @@ -116,13 +131,13 @@ namespace Microsoft.AspNetCore.Authentication /// The scheme. public virtual void AddScheme(AuthenticationScheme scheme) { - if (_map.ContainsKey(scheme.Name)) + if (_schemes.ContainsKey(scheme.Name)) { throw new InvalidOperationException("Scheme already exists: " + scheme.Name); } lock (_lock) { - if (_map.ContainsKey(scheme.Name)) + if (_schemes.ContainsKey(scheme.Name)) { throw new InvalidOperationException("Scheme already exists: " + scheme.Name); } @@ -130,7 +145,7 @@ namespace Microsoft.AspNetCore.Authentication { _requestHandlers.Add(scheme); } - _map[scheme.Name] = scheme; + _schemes[scheme.Name] = scheme; } } @@ -140,22 +155,22 @@ namespace Microsoft.AspNetCore.Authentication /// The name of the authenticationScheme being removed. public virtual void RemoveScheme(string name) { - if (!_map.ContainsKey(name)) + if (!_schemes.ContainsKey(name)) { return; } lock (_lock) { - if (_map.ContainsKey(name)) + if (_schemes.ContainsKey(name)) { - var scheme = _map[name]; + var scheme = _schemes[name]; _requestHandlers.Remove(scheme); - _map.Remove(name); + _schemes.Remove(name); } } } public virtual Task> GetAllSchemesAsync() - => Task.FromResult>(_map.Values); + => Task.FromResult>(_schemes.Values); } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Authentication.Core.Test/AuthenticationSchemeProviderTests.cs b/test/Microsoft.AspNetCore.Authentication.Core.Test/AuthenticationSchemeProviderTests.cs index 4fa0ea8782..82602000aa 100644 --- a/test/Microsoft.AspNetCore.Authentication.Core.Test/AuthenticationSchemeProviderTests.cs +++ b/test/Microsoft.AspNetCore.Authentication.Core.Test/AuthenticationSchemeProviderTests.cs @@ -3,10 +3,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Security.Claims; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; using Xunit; namespace Microsoft.AspNetCore.Authentication @@ -117,6 +119,39 @@ namespace Microsoft.AspNetCore.Authentication Assert.NotNull(await provider.GetDefaultSignOutSchemeAsync()); } + [Fact] + public void SchemeRegistrationIsCaseSensitive() + { + var services = new ServiceCollection().AddOptions().AddAuthenticationCore(o => + { + o.AddScheme("signin", "whatever"); + o.AddScheme("signin", "whatever"); + }).BuildServiceProvider(); + + var error = Assert.Throws(() => services.GetRequiredService()); + + Assert.Contains("Scheme already exists: signin", error.Message); + } + + [Fact] + public async Task LookupUsesProvidedStringComparer() + { + var services = new ServiceCollection().AddOptions() + .AddSingleton() + .AddAuthenticationCore(o => o.AddScheme("signin", "whatever")) + .BuildServiceProvider(); + + var provider = services.GetRequiredService(); + + var a = await provider.GetSchemeAsync("signin"); + var b = await provider.GetSchemeAsync("SignIn"); + var c = await provider.GetSchemeAsync("SIGNIN"); + + Assert.NotNull(a); + Assert.Same(a, b); + Assert.Same(b, c); + } + private class Handler : IAuthenticationHandler { public Task AuthenticateAsync() @@ -160,5 +195,13 @@ namespace Microsoft.AspNetCore.Authentication throw new NotImplementedException(); } } + + private class IgnoreCaseSchemeProvider : AuthenticationSchemeProvider + { + public IgnoreCaseSchemeProvider(IOptions options) + : base(options, new Dictionary(StringComparer.OrdinalIgnoreCase)) + { + } + } } }