// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
namespace Microsoft.AspNetCore.Authentication
{
    /// <summary>
    /// Extension methods for storing authentication tokens in <see cref="AuthenticationProperties"/>.
    /// </summary>
    /// <remarks>
    /// Each token value is stored in <c>properties.Items</c> under the key <c>".Token." + name</c>,
    /// and the ';'-separated list of stored token names is kept under <c>".TokenNames"</c>.
    /// </remarks>
    public static class AuthenticationTokenExtensions
    {
        // Item key holding the ';'-separated list of stored token names.
        private const string TokenNamesKey = ".TokenNames";

        // Prefix prepended to a token name to form the item key that holds the token's value.
        private const string TokenKeyPrefix = ".Token.";

        /// <summary>
        /// Stores a set of authentication tokens, after removing any old tokens.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> to store the tokens in.</param>
        /// <param name="tokens">The tokens to store.</param>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="properties"/> or <paramref name="tokens"/> is <c>null</c>.
        /// </exception>
        public static void StoreTokens(this AuthenticationProperties properties, IEnumerable<AuthenticationToken> tokens)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }
            if (tokens == null)
            {
                throw new ArgumentNullException(nameof(tokens));
            }

            // Clear old tokens first: every per-token value entry, then the name list itself.
            foreach (var oldToken in properties.GetTokens())
            {
                properties.Items.Remove(TokenKeyPrefix + oldToken.Name);
            }
            properties.Items.Remove(TokenNamesKey);

            var tokenNames = new List<string>();
            foreach (var token in tokens)
            {
                // REVIEW: should probably check that there are no ; in the token name and throw or encode
                tokenNames.Add(token.Name);
                properties.Items[TokenKeyPrefix + token.Name] = token.Value;
            }

            // Only write the name list when something was stored, so an empty set leaves no entry behind.
            if (tokenNames.Count > 0)
            {
                properties.Items[TokenNamesKey] = string.Join(";", tokenNames);
            }
        }

        /// <summary>
        /// Returns the value of a token.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> to read from.</param>
        /// <param name="tokenName">The token name.</param>
        /// <returns>The token value, or <c>null</c> if no token with that name is stored.</returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="properties"/> or <paramref name="tokenName"/> is <c>null</c>.
        /// </exception>
        public static string GetTokenValue(this AuthenticationProperties properties, string tokenName)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }
            if (tokenName == null)
            {
                throw new ArgumentNullException(nameof(tokenName));
            }

            // Single dictionary lookup instead of ContainsKey followed by the indexer.
            string tokenValue;
            return properties.Items.TryGetValue(TokenKeyPrefix + tokenName, out tokenValue)
                ? tokenValue
                : null;
        }

        /// <summary>
        /// Replaces the value of an already-stored token.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> holding the token.</param>
        /// <param name="tokenName">The token name.</param>
        /// <param name="tokenValue">The new token value.</param>
        /// <returns>
        /// <c>true</c> if the token existed and was updated; <c>false</c> if no token with that name is
        /// stored, in which case nothing is written.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="properties"/> or <paramref name="tokenName"/> is <c>null</c>.
        /// </exception>
        public static bool UpdateTokenValue(this AuthenticationProperties properties, string tokenName, string tokenValue)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }
            if (tokenName == null)
            {
                throw new ArgumentNullException(nameof(tokenName));
            }

            var tokenKey = TokenKeyPrefix + tokenName;
            // Only existing tokens are updated; new tokens must go through StoreTokens so the name list stays in sync.
            if (!properties.Items.ContainsKey(tokenKey))
            {
                return false;
            }

            properties.Items[tokenKey] = tokenValue;
            return true;
        }

        /// <summary>
        /// Returns all of the <see cref="AuthenticationToken"/>s contained in the properties.
        /// </summary>
        /// <param name="properties">The <see cref="AuthenticationProperties"/> to read from.</param>
        /// <returns>
        /// The authentication tokens, in the order their names appear in the stored name list.
        /// Names whose value entry is missing or <c>null</c> are skipped.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="properties"/> is <c>null</c>.</exception>
        public static IEnumerable<AuthenticationToken> GetTokens(this AuthenticationProperties properties)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            var tokens = new List<AuthenticationToken>();
            string tokenNames;
            // A null name list is treated as "no tokens" rather than throwing on Split.
            if (properties.Items.TryGetValue(TokenNamesKey, out tokenNames) && tokenNames != null)
            {
                foreach (var name in tokenNames.Split(';'))
                {
                    var token = properties.GetTokenValue(name);
                    if (token != null)
                    {
                        tokens.Add(new AuthenticationToken { Name = name, Value = token });
                    }
                }
            }
            return tokens;
        }

        /// <summary>
        /// Extension method for getting the value of an authentication token.
        /// </summary>
        /// <param name="auth">The <see cref="IAuthenticationService"/>.</param>
        /// <param name="context">The <see cref="HttpContext"/> context.</param>
        /// <param name="scheme">The name of the authentication scheme.</param>
        /// <param name="tokenName">The name of the token.</param>
        /// <returns>
        /// The value of the token, or <c>null</c> if authentication produced no result or properties,
        /// or the token is not present.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="auth"/>, <paramref name="scheme"/>, or <paramref name="tokenName"/> is <c>null</c>.
        /// </exception>
        public static async Task<string> GetTokenAsync(this IAuthenticationService auth, HttpContext context, string scheme, string tokenName)
        {
            if (auth == null)
            {
                throw new ArgumentNullException(nameof(auth));
            }
            if (scheme == null)
            {
                throw new ArgumentNullException(nameof(scheme));
            }
            if (tokenName == null)
            {
                throw new ArgumentNullException(nameof(tokenName));
            }

            var result = await auth.AuthenticateAsync(context, scheme);
            return result?.Properties?.GetTokenValue(tokenName);
        }
    }
}