// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;

namespace Microsoft.AspNetCore.Authentication
{
    /// <summary>
    /// Extension methods for storing authentication tokens in <see cref="AuthenticationProperties"/>.
    /// </summary>
    /// <remarks>
    /// Each token value is stored in <c>Items</c> under the key <c>".Token." + name</c>, and the list of
    /// stored token names is kept under the key <c>".TokenNames"</c> as a single ';'-separated string.
    /// </remarks>
    public static class AuthenticationTokenExtensions
    {
        // Key of the item holding the ';'-separated list of stored token names.
        private const string TokenNamesKey = ".TokenNames";

        // Prefix prepended to a token name to form the key of the item holding its value.
        private const string TokenKeyPrefix = ".Token.";

        /// <summary>
        /// Stores a set of authentication tokens, after removing any old tokens.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> properties.</param>
        /// <param name="tokens">The tokens to store.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="properties"/> or <paramref name="tokens"/> is <c>null</c>.
        /// </exception>
        public static void StoreTokens(this AuthenticationProperties properties, IEnumerable<AuthenticationToken> tokens)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }
            if (tokens == null)
            {
                throw new ArgumentNullException(nameof(tokens));
            }

            // Clear old tokens first
            RemoveStoredTokens(properties);

            var tokenNames = new List<string>();
            foreach (var token in tokens)
            {
                // REVIEW: should probably check that there are no ; in the token name and throw or encode,
                // because the names are joined with ';' below and split on ';' when read back.
                tokenNames.Add(token.Name);
                properties.Items[TokenKeyPrefix + token.Name] = token.Value;
            }

            // The name list is only written when at least one token was stored.
            if (tokenNames.Count > 0)
            {
                properties.Items[TokenNamesKey] = string.Join(";", tokenNames);
            }
        }

        /// <summary>
        /// Returns the value of a token.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> properties.</param>
        /// <param name="tokenName">The token name.</param>
        /// <returns>The token value, or <c>null</c> when no item exists for <paramref name="tokenName"/>.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="properties"/> or <paramref name="tokenName"/> is <c>null</c>.
        /// </exception>
        public static string GetTokenValue(this AuthenticationProperties properties, string tokenName)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }
            if (tokenName == null)
            {
                throw new ArgumentNullException(nameof(tokenName));
            }

            string tokenValue;
            return properties.Items.TryGetValue(TokenKeyPrefix + tokenName, out tokenValue)
                ? tokenValue
                : null;
        }

        /// <summary>
        /// Replaces the value of a token that is already stored.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> properties.</param>
        /// <param name="tokenName">The token name.</param>
        /// <param name="tokenValue">The new token value.</param>
        /// <returns>
        /// <c>true</c> if an item for <paramref name="tokenName"/> existed and was updated;
        /// <c>false</c> if no such item exists, in which case nothing is changed.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="properties"/> or <paramref name="tokenName"/> is <c>null</c>.
        /// </exception>
        public static bool UpdateTokenValue(this AuthenticationProperties properties, string tokenName, string tokenValue)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }
            if (tokenName == null)
            {
                throw new ArgumentNullException(nameof(tokenName));
            }

            var tokenKey = TokenKeyPrefix + tokenName;
            if (!properties.Items.ContainsKey(tokenKey))
            {
                return false;
            }

            properties.Items[tokenKey] = tokenValue;
            return true;
        }

        /// <summary>
        /// Returns all of the AuthenticationTokens contained in the properties.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> properties.</param>
        /// <returns>
        /// The authentication tokens. Names that are listed but have no stored value (or a <c>null</c> value)
        /// are omitted; the result is empty when no token names are stored.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="properties"/> is <c>null</c>.</exception>
        public static IEnumerable<AuthenticationToken> GetTokens(this AuthenticationProperties properties)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            var tokens = new List<AuthenticationToken>();

            string joinedNames;
            if (properties.Items.TryGetValue(TokenNamesKey, out joinedNames) && joinedNames != null)
            {
                foreach (var name in joinedNames.Split(';'))
                {
                    var value = properties.GetTokenValue(name);
                    if (value != null)
                    {
                        tokens.Add(new AuthenticationToken { Name = name, Value = value });
                    }
                }
            }

            return tokens;
        }

        /// <summary>
        /// Extension method for getting the value of an authentication token.
        /// </summary>
        /// <param name="auth">The <see cref="IAuthenticationService"/>.</param>
        /// <param name="context">The <see cref="HttpContext"/> context.</param>
        /// <param name="scheme">The name of the authentication scheme.</param>
        /// <param name="tokenName">The name of the token.</param>
        /// <returns>
        /// The value of the token, or <c>null</c> when the authentication result or its properties are
        /// <c>null</c>, or when no such token is stored.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="auth"/>, <paramref name="scheme"/> or <paramref name="tokenName"/> is <c>null</c>.
        /// </exception>
        public static async Task<string> GetTokenAsync(this IAuthenticationService auth, HttpContext context, string scheme, string tokenName)
        {
            if (auth == null)
            {
                throw new ArgumentNullException(nameof(auth));
            }
            if (scheme == null)
            {
                throw new ArgumentNullException(nameof(scheme));
            }
            if (tokenName == null)
            {
                throw new ArgumentNullException(nameof(tokenName));
            }

            var result = await auth.AuthenticateAsync(context, scheme);
            return result?.Properties?.GetTokenValue(tokenName);
        }

        // Removes the value item of every name in the stored name list, then the name list itself.
        // Works from the name list rather than GetTokens so entries whose value is null are removed too.
        private static void RemoveStoredTokens(AuthenticationProperties properties)
        {
            string joinedNames;
            if (properties.Items.TryGetValue(TokenNamesKey, out joinedNames) && joinedNames != null)
            {
                foreach (var name in joinedNames.Split(';'))
                {
                    properties.Items.Remove(TokenKeyPrefix + name);
                }
            }

            properties.Items.Remove(TokenNamesKey);
        }
    }
}