// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.Entity;

namespace Microsoft.AspNet.Identity.SqlServer.InMemory.Test
{
    // NOTE(review): the copy this file was restored from had every generic argument list
    // (the text between '<' and '>') stripped. The type arguments below were reconstructed
    // from how each member uses them; confirm the arity and order of the store's type
    // parameters against the IdentityUser/IdentityRole definitions.

    /// <summary>
    /// User store for <see cref="IdentityUser"/> backed by an <see cref="InMemoryContext"/>.
    /// </summary>
    public class InMemoryUserStore : InMemoryUserStore<IdentityUser>
    {
        public InMemoryUserStore(InMemoryContext context) : base(context) { }
    }

    /// <summary>
    /// User store for a custom user type backed by an <see cref="InMemoryContext"/>.
    /// </summary>
    /// <typeparam name="TUser">The user entity type.</typeparam>
    public class InMemoryUserStore<TUser> : InMemoryUserStore<TUser, InMemoryContext>
        where TUser : IdentityUser
    {
        public InMemoryUserStore(InMemoryContext context) : base(context) { }
    }

    /// <summary>
    /// User store for a custom user type and context that uses the default role, key, login,
    /// user-role and claim types.
    /// </summary>
    /// <typeparam name="TUser">The user entity type.</typeparam>
    /// <typeparam name="TContext">The context type used for persistence.</typeparam>
    public class InMemoryUserStore<TUser, TContext> :
        InMemoryUserStore<TUser, IdentityRole, string, IdentityUserLogin<string>, IdentityUserRole<string>, IdentityUserClaim<string>, TContext>
        where TUser : IdentityUser
        where TContext : DbContext
    {
        public InMemoryUserStore(TContext context) : base(context) { }
    }

    /// <summary>
    /// User store that reads and writes users, logins, claims and role memberships through a
    /// <see cref="DbContext"/>. Only CreateAsync/UpdateAsync/DeleteAsync save changes (and only
    /// when <see cref="AutoSaveChanges"/> is true); every other mutator only changes tracked state.
    /// </summary>
    public class InMemoryUserStore<TUser, TRole, TKey, TUserLogin, TUserRole, TUserClaim, TContext> :
        IUserLoginStore<TUser>,
        IUserClaimStore<TUser>,
        IUserRoleStore<TUser>,
        IUserPasswordStore<TUser>,
        IUserSecurityStampStore<TUser>,
        IQueryableUserStore<TUser>,
        IUserEmailStore<TUser>,
        IUserPhoneNumberStore<TUser>,
        IUserTwoFactorStore<TUser>,
        IUserLockoutStore<TUser>
        where TKey : IEquatable<TKey>
        where TUser : IdentityUser<TKey>
        where TRole : IdentityRole<TKey>
        where TUserLogin : IdentityUserLogin<TKey>, new()
        where TUserRole : IdentityUserRole<TKey>, new()
        where TUserClaim : IdentityUserClaim<TKey>, new()
        where TContext : DbContext
    {
        private bool _disposed;

        /// <summary>
        /// Creates a store over <paramref name="context"/> with <see cref="AutoSaveChanges"/> enabled.
        /// </summary>
        /// <param name="context">The context used for all queries and updates.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public InMemoryUserStore(TContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            Context = context;
            AutoSaveChanges = true;
        }

        /// <summary>
        /// The context this store queries and updates.
        /// </summary>
        public TContext Context { get; private set; }

        /// <summary>
        /// If true, CreateAsync, UpdateAsync and DeleteAsync call SaveChangesAsync on the context
        /// after applying their change.
        /// </summary>
        public bool AutoSaveChanges { get; set; }

        /// <summary>
        /// Saves pending context changes when <see cref="AutoSaveChanges"/> is set; otherwise completes immediately.
        /// </summary>
        private Task SaveChanges(CancellationToken cancellationToken)
        {
            return AutoSaveChanges ? Context.SaveChangesAsync(cancellationToken) : Task.FromResult(0);
        }

        /// <summary>
        /// Returns the single user matching <paramref name="filter"/>, or null when none matches.
        /// </summary>
        /// <param name="filter">Predicate applied to <see cref="Users"/>.</param>
        /// <param name="cancellationToken">Currently unused: the query runs synchronously.</param>
        protected virtual Task<TUser> GetUserAggregate(Expression<Func<TUser, bool>> filter,
            CancellationToken cancellationToken = default(CancellationToken))
        {
            // TODO: switch to Users.SingleOrDefaultAsync(filter, cancellationToken) and eagerly
            // include Roles, Claims and Logins once the provider supports it.
            return Task.FromResult(Users.SingleOrDefault(filter));
        }

        /// <summary>
        /// Returns the user's id formatted as a string with the invariant culture.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task<string> GetUserIdAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(Convert.ToString(user.Id, CultureInfo.InvariantCulture));
        }

        /// <summary>
        /// Returns the user's name.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task<string> GetUserNameAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.UserName);
        }

        /// <summary>
        /// Sets the user's name on the entity (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task SetUserNameAsync(TUser user, string userName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.UserName = userName;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's normalized name.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task<string> GetNormalizedUserNameAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.NormalizedUserName);
        }

        /// <summary>
        /// Sets the user's normalized name on the entity (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task SetNormalizedUserNameAsync(TUser user, string userName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.NormalizedUserName = userName;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Adds the user to the context and, if <see cref="AutoSaveChanges"/> is set, saves.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task CreateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            await Context.AddAsync(user, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Marks the user as updated in the context and, if <see cref="AutoSaveChanges"/> is set, saves.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task UpdateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            await Context.UpdateAsync(user, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Marks the user as deleted in the context and, if <see cref="AutoSaveChanges"/> is set, saves.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task DeleteAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            Context.Delete(user);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Converts a string user id (as produced by GetUserIdAsync) back to <typeparamref name="TKey"/>.
        /// </summary>
        public virtual TKey ConvertUserId(string userId)
        {
            // Parse with the invariant culture so ids round-trip with GetUserIdAsync, which formats with it.
            return (TKey)Convert.ChangeType(userId, typeof(TKey), CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Finds a user by id.
        /// </summary>
        /// <param name="userId">String form of the user's key; converted with <see cref="ConvertUserId"/>.</param>
        /// <param name="cancellationToken">Checked before the lookup starts.</param>
        /// <returns>The matching user, or null.</returns>
        public virtual Task<TUser> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            var id = ConvertUserId(userId);
            return GetUserAggregate(u => u.Id.Equals(id), cancellationToken);
        }

        /// <summary>
        /// Finds a user by name, comparing the upper-cased forms of both names.
        /// </summary>
        /// <param name="userName">Name to look for.</param>
        /// <param name="cancellationToken">Checked before the lookup starts.</param>
        /// <returns>The matching user, or null. Users without a name never match.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="userName"/> is null.</exception>
        public virtual Task<TUser> FindByNameAsync(string userName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (userName == null)
            {
                throw new ArgumentNullException("userName");
            }
            // Normalize the argument once rather than once per candidate row.
            var normalizedName = userName.ToUpper();
            return GetUserAggregate(u => u.UserName != null && u.UserName.ToUpper() == normalizedName, cancellationToken);
        }

        /// <summary>
        /// All users in the context.
        /// </summary>
        public IQueryable<TUser> Users
        {
            get { return Context.Set<TUser>(); }
        }

        /// <summary>
        /// Adds an external login to the user, registering it with the context's login set (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="login"/> is null.</exception>
        public async virtual Task AddLoginAsync(TUser user, UserLoginInfo login, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (login == null)
            {
                throw new ArgumentNullException("login");
            }
            var l = new TUserLogin
            {
                UserId = user.Id,
                ProviderKey = login.ProviderKey,
                LoginProvider = login.LoginProvider,
                ProviderDisplayName = login.ProviderDisplayName
            };
            await Context.Set<TUserLogin>().AddAsync(l, cancellationToken);
            user.Logins.Add(l);
        }

        /// <summary>
        /// Removes the user's login for the given provider and key, if present (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task RemoveLoginAsync(TUser user, string loginProvider, string providerKey,
            CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            var entry = user.Logins.SingleOrDefault(l => l.LoginProvider == loginProvider && l.ProviderKey == providerKey);
            if (entry != null)
            {
                user.Logins.Remove(entry);
                // NOTE(review): AddLoginAsync registers logins through Set<TUserLogin>(); confirm this
                // set type is also mapped when TUserLogin is a subclass.
                Context.Set<IdentityUserLogin<TKey>>().Remove(entry);
            }
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's external logins.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<IList<UserLoginInfo>> GetLoginsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            IList<UserLoginInfo> result = user.Logins
                .Select(l => new UserLoginInfo(l.LoginProvider, l.ProviderKey, l.ProviderDisplayName))
                .ToList();
            return Task.FromResult(result);
        }

        /// <summary>
        /// Finds the user owning the login with the given provider and key.
        /// </summary>
        /// <returns>The owning user, or null if no such login exists.</returns>
        public async virtual Task<TUser> FindByLoginAsync(string loginProvider, string providerKey,
            CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            var userLogin = await Context.Set<TUserLogin>()
                .FirstOrDefaultAsync(l => l.LoginProvider == loginProvider && l.ProviderKey == providerKey, cancellationToken);
            if (userLogin != null)
            {
                return await GetUserAggregate(u => u.Id.Equals(userLogin.UserId), cancellationToken);
            }
            return null;
        }

        /// <summary>
        /// Sets the password hash on the user entity (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetPasswordHashAsync(TUser user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PasswordHash = passwordHash;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's password hash.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetPasswordHashAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PasswordHash);
        }

        /// <summary>
        /// Returns true if the user has a password hash set.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> HasPasswordAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            // Same disposal/argument checks as every other member (previously a null user threw NullReferenceException).
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PasswordHash != null);
        }

        /// <summary>
        /// Returns the user's claims.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<IList<Claim>> GetClaimsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            IList<Claim> result = user.Claims.Select(c => new Claim(c.ClaimType, c.ClaimValue)).ToList();
            return Task.FromResult(result);
        }

        /// <summary>
        /// Adds claims to the user's claim collection (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="claims"/> is null.</exception>
        public virtual Task AddClaimsAsync(TUser user, IEnumerable<Claim> claims, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claims == null)
            {
                throw new ArgumentNullException("claims");
            }
            foreach (var claim in claims)
            {
                user.Claims.Add(new TUserClaim { UserId = user.Id, ClaimType = claim.Type, ClaimValue = claim.Value });
            }
            return Task.FromResult(0);
        }

        /// <summary>
        /// Removes every user claim whose type and value match one of <paramref name="claims"/> (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="claims"/> is null.</exception>
        public virtual Task RemoveClaimsAsync(TUser user, IEnumerable<Claim> claims, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claims == null)
            {
                throw new ArgumentNullException("claims");
            }
            foreach (var claim in claims)
            {
                // Materialize first: the user's collection is modified inside the loop.
                var matchingClaims = user.Claims.Where(uc => uc.ClaimValue == claim.Value && uc.ClaimType == claim.Type).ToList();
                foreach (var c in matchingClaims)
                {
                    user.Claims.Remove(c);
                }
            }
            // TODO: only the user's collection is updated; matching rows in the claims set, which may
            // not be loaded into user.Claims, are not removed.
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns whether the user's email is confirmed.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetEmailConfirmedAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.EmailConfirmed);
        }

        /// <summary>
        /// Sets whether the user's email is confirmed (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetEmailConfirmedAsync(TUser user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.EmailConfirmed = confirmed;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Sets the user's email (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetEmailAsync(TUser user, string email, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.Email = email;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's email.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetEmailAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.Email);
        }

        /// <summary>
        /// Finds a user by email, comparing the upper-cased forms of both addresses.
        /// </summary>
        /// <returns>The matching user, or null. Users without an email never match.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="email"/> is null.</exception>
        public virtual Task<TUser> FindByEmailAsync(string email, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (email == null)
            {
                throw new ArgumentNullException("email");
            }
            // Normalize the argument once; skip users with no email instead of dereferencing null.
            var normalizedEmail = email.ToUpper();
            return GetUserAggregate(u => u.Email != null && u.Email.ToUpper() == normalizedEmail, cancellationToken);
        }

        /// <summary>
        /// Returns the end of the user's lockout; any time in the past means the user is not locked out.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<DateTimeOffset> GetLockoutEndDateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.LockoutEnd);
        }

        /// <summary>
        /// Locks the user out until <paramref name="lockoutEnd"/>; pass a past date to unlock (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetLockoutEndDateAsync(TUser user, DateTimeOffset lockoutEnd, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.LockoutEnd = lockoutEnd;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Records a failed access attempt.
        /// </summary>
        /// <returns>The incremented failure count.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<int> IncrementAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.AccessFailedCount++;
            return Task.FromResult(user.AccessFailedCount);
        }

        /// <summary>
        /// Resets the failed access count to zero (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task ResetAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.AccessFailedCount = 0;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the current number of failed access attempts.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<int> GetAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.AccessFailedCount);
        }

        /// <summary>
        /// Returns whether the user can be locked out.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetLockoutEnabledAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.LockoutEnabled);
        }

        /// <summary>
        /// Sets whether the user can be locked out (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetLockoutEnabledAsync(TUser user, bool enabled, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.LockoutEnabled = enabled;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Sets the user's phone number (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetPhoneNumberAsync(TUser user, string phoneNumber, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PhoneNumber = phoneNumber;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's phone number.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetPhoneNumberAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PhoneNumber);
        }

        /// <summary>
        /// Returns whether the user's phone number is confirmed.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetPhoneNumberConfirmedAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PhoneNumberConfirmed);
        }

        /// <summary>
        /// Sets whether the user's phone number is confirmed (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetPhoneNumberConfirmedAsync(TUser user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PhoneNumberConfirmed = confirmed;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Adds the user to the named role, linking the membership from both the user and the role (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="roleName"/> is null.</exception>
        /// <exception cref="InvalidOperationException">No role with that name exists.</exception>
        public virtual Task AddToRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (roleName == null)
            {
                throw new ArgumentNullException("roleName");
            }
            // TODO: also reject empty/whitespace names once IdentityResources messages are available here.
            var roleEntity = FindRoleByName(roleName);
            if (roleEntity == null)
            {
                // TODO: String.Format(CultureInfo.CurrentCulture, IdentityResources.RoleNotFound, roleName)
                throw new InvalidOperationException("Role Not Found");
            }
            var ur = new TUserRole { UserId = user.Id, RoleId = roleEntity.Id };
            user.Roles.Add(ur);
            roleEntity.Users.Add(ur);
            return Task.FromResult(0);
        }

        /// <summary>
        /// Removes the user from the named role; does nothing if the role or the membership does not exist (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="roleName"/> is null.</exception>
        public virtual Task RemoveFromRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (roleName == null)
            {
                throw new ArgumentNullException("roleName");
            }
            var roleEntity = FindRoleByName(roleName);
            if (roleEntity != null)
            {
                var userRole = user.Roles.FirstOrDefault(r => roleEntity.Id.Equals(r.RoleId));
                if (userRole != null)
                {
                    user.Roles.Remove(userRole);
                    roleEntity.Users.Remove(userRole);
                }
            }
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the names of the roles the user belongs to.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<IList<string>> GetRolesAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            var query = from userRoles in user.Roles
                        join roles in Context.Set<TRole>()
                            on userRoles.RoleId equals roles.Id
                        select roles.Name;
            return Task.FromResult<IList<string>>(query.ToList());
        }

        /// <summary>
        /// Returns true if the user is a member of a role with the given name (upper-cased comparison).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="roleName"/> is null.</exception>
        public virtual Task<bool> IsInRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (roleName == null)
            {
                throw new ArgumentNullException("roleName");
            }
            var normalizedName = roleName.ToUpper();
            var any = Context.Set<TRole>()
                .Where(r => r.Name != null && r.Name.ToUpper() == normalizedName)
                .Any(r => r.Users.Any(ur => ur.UserId.Equals(user.Id)));
            return Task.FromResult(any);
        }

        /// <summary>
        /// Sets the user's security stamp (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetSecurityStampAsync(TUser user, string stamp, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.SecurityStamp = stamp;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's security stamp.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetSecurityStampAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.SecurityStamp);
        }

        /// <summary>
        /// Sets whether two-factor authentication is enabled for the user (not saved).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetTwoFactorEnabledAsync(TUser user, bool enabled, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.TwoFactorEnabled = enabled;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns whether two-factor authentication is enabled for the user.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetTwoFactorEnabledAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.TwoFactorEnabled);
        }

        /// <summary>
        /// Returns the single role whose upper-cased name equals the upper-cased <paramref name="roleName"/>,
        /// or null. Roles without a name never match. <paramref name="roleName"/> must not be null.
        /// </summary>
        private TRole FindRoleByName(string roleName)
        {
            var normalizedName = roleName.ToUpper();
            return Context.Set<TRole>().SingleOrDefault(r => r.Name != null && r.Name.ToUpper() == normalizedName);
        }

        private void ThrowIfDisposed()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(GetType().Name);
            }
        }

        /// <summary>
        /// Marks the store as disposed so later calls throw <see cref="ObjectDisposedException"/>.
        /// The <see cref="Context"/> is not disposed.
        /// </summary>
        public void Dispose()
        {
            _disposed = true;
        }
    }
}