// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.Entity;

namespace Microsoft.AspNet.Identity.EntityFramework
{
    /// <summary>
    /// User store for <see cref="IdentityUser"/> over a plain <see cref="DbContext"/>.
    /// </summary>
    public class UserStore : UserStore<IdentityUser>
    {
        public UserStore(DbContext context) : base(context) { }
    }

    /// <summary>
    /// User store for <typeparamref name="TUser"/> using <see cref="IdentityRole"/> and a plain <see cref="DbContext"/>.
    /// </summary>
    public class UserStore<TUser> : UserStore<TUser, IdentityRole, DbContext> where TUser : IdentityUser, new()
    {
        public UserStore(DbContext context) : base(context) { }
    }

    /// <summary>
    /// User store with string keys for the given user, role and context types.
    /// </summary>
    public class UserStore<TUser, TRole, TContext> : UserStore<TUser, TRole, TContext, string>
        where TUser : IdentityUser, new()
        where TRole : IdentityRole, new()
        where TContext : DbContext
    {
        public UserStore(TContext context) : base(context) { }
    }

    /// <summary>
    /// Entity Framework implementation of the ASP.NET Identity user store interfaces.
    /// CreateAsync/UpdateAsync/DeleteAsync call SaveChanges when <see cref="AutoSaveChanges"/> is true;
    /// the role, claim and login methods only stage changes on <see cref="Context"/> and never save.
    /// </summary>
    /// <typeparam name="TUser">User entity type.</typeparam>
    /// <typeparam name="TRole">Role entity type.</typeparam>
    /// <typeparam name="TContext">Entity Framework context type.</typeparam>
    /// <typeparam name="TKey">Primary key type of users and roles.</typeparam>
    public class UserStore<TUser, TRole, TContext, TKey> :
        IUserLoginStore<TUser>,
        IUserRoleStore<TUser>,
        IUserClaimStore<TUser>,
        IUserPasswordStore<TUser>,
        IUserSecurityStampStore<TUser>,
        IUserEmailStore<TUser>,
        IUserLockoutStore<TUser>,
        IUserPhoneNumberStore<TUser>,
        IQueryableUserStore<TUser>,
        IUserTwoFactorStore<TUser>
        where TUser : IdentityUser<TKey>
        where TRole : IdentityRole<TKey>
        where TContext : DbContext
        where TKey : IEquatable<TKey>
    {
        /// <summary>
        /// Creates a store over <paramref name="context"/>.
        /// </summary>
        /// <param name="context">The context used for all queries and change tracking.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public UserStore(TContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            Context = context;
        }

        private bool _disposed;

        /// <summary>
        /// The context this store reads from and writes to.
        /// </summary>
        public TContext Context { get; private set; }

        /// <summary>
        /// If true will call SaveChanges after CreateAsync/UpdateAsync/DeleteAsync
        /// </summary>
        public bool AutoSaveChanges { get; set; } = true;

        /// <summary>
        /// Saves pending context changes when <see cref="AutoSaveChanges"/> is set; otherwise completes immediately.
        /// </summary>
        private Task SaveChanges(CancellationToken cancellationToken)
        {
            return AutoSaveChanges ? Context.SaveChangesAsync(cancellationToken) : Task.FromResult(0);
        }

        /// <summary>
        /// Returns the first user matching <paramref name="filter"/>, or null if none match.
        /// </summary>
        /// <param name="filter">Predicate applied to <see cref="Users"/>.</param>
        /// <param name="cancellationToken">Token forwarded to the query.</param>
        protected virtual Task<TUser> GetUserAggregate(Expression<Func<TUser, bool>> filter, CancellationToken cancellationToken = default(CancellationToken))
        {
            // TODO: include Roles, Claims and Logins once eager loading is available.
            return Users.FirstOrDefaultAsync(filter, cancellationToken);
        }

        /// <summary>
        /// Returns the user's id converted with <see cref="ConvertIdToString"/>.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task<string> GetUserIdAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(ConvertIdToString(user.Id));
        }

        /// <summary>
        /// Returns the user's <c>UserName</c>.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task<string> GetUserNameAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.UserName);
        }

        /// <summary>
        /// Sets the user's <c>UserName</c> in memory (not saved).
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="userName">The new user name.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task SetUserNameAsync(TUser user, string userName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.UserName = userName;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the user's <c>NormalizedUserName</c>.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task<string> GetNormalizedUserNameAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.NormalizedUserName);
        }

        /// <summary>
        /// Sets the user's <c>NormalizedUserName</c> in memory (not saved).
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="normalizedName">The normalized user name.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public Task SetNormalizedUserNameAsync(TUser user, string normalizedName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.NormalizedUserName = normalizedName;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Adds the user to the context and saves if <see cref="AutoSaveChanges"/> is set.
        /// </summary>
        /// <param name="user">The user to create.</param>
        /// <param name="cancellationToken">Forwarded to the context.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task CreateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            await Context.AddAsync(user, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Marks the user as updated and saves if <see cref="AutoSaveChanges"/> is set.
        /// </summary>
        /// <param name="user">The user to update.</param>
        /// <param name="cancellationToken">Forwarded to SaveChanges.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task UpdateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            Context.Update(user);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Removes the user from the context and saves if <see cref="AutoSaveChanges"/> is set.
        /// </summary>
        /// <param name="user">The user to delete.</param>
        /// <param name="cancellationToken">Forwarded to SaveChanges.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task DeleteAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            Context.Remove(user);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Find a user by id
        /// </summary>
        /// <param name="userId">String form of the id; converted with <see cref="ConvertIdFromString"/>.</param>
        /// <param name="cancellationToken">Forwarded to the query.</param>
        /// <returns>The matching user, or null.</returns>
        public virtual Task<TUser> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            var id = ConvertIdFromString(userId);
            return GetUserAggregate(u => u.Id.Equals(id), cancellationToken);
        }

        /// <summary>
        /// Converts a string id to <typeparamref name="TKey"/> using the invariant culture.
        /// Returns default(TKey) for a null input.
        /// </summary>
        /// <param name="id">The string id, or null.</param>
        public virtual TKey ConvertIdFromString(string id)
        {
            if (id == null)
            {
                return default(TKey);
            }
            // Invariant culture so parsing does not depend on the current thread's culture.
            return (TKey)Convert.ChangeType(id, typeof(TKey), CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Converts a <typeparamref name="TKey"/> id to its invariant-culture string form.
        /// Returns null when the id equals default(TKey), including a null reference-type id.
        /// </summary>
        /// <param name="id">The id to convert.</param>
        public virtual string ConvertIdToString(TKey id)
        {
            // EqualityComparer handles a null id (TKey may be string), where id.Equals(...) would throw.
            if (EqualityComparer<TKey>.Default.Equals(id, default(TKey)))
            {
                return null;
            }
            // Invariant culture keeps this the inverse of ConvertIdFromString.
            return Convert.ToString(id, CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Find a user by normalized name
        /// </summary>
        /// <param name="normalizedUserName">Compared with <c>NormalizedUserName</c> by equality.</param>
        /// <param name="cancellationToken">Forwarded to the query.</param>
        /// <returns>The matching user, or null.</returns>
        public virtual Task<TUser> FindByNameAsync(string normalizedUserName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            return GetUserAggregate(u => u.NormalizedUserName == normalizedUserName, cancellationToken);
        }

        /// <summary>
        /// Queryable view of the context's user set.
        /// </summary>
        public IQueryable<TUser> Users
        {
            get { return Context.Set<TUser>(); }
        }

        /// <summary>
        /// Set the password hash for a user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="passwordHash">The hash to store (not saved until the user is updated).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetPasswordHashAsync(TUser user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PasswordHash = passwordHash;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Get the password hash for a user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetPasswordHashAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PasswordHash);
        }

        /// <summary>
        /// Returns true if the user has a password set
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <returns>True when <c>PasswordHash</c> is non-null.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> HasPasswordAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            // Same guards as every other member; previously a null user surfaced as NullReferenceException.
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PasswordHash != null);
        }

        /// <summary>
        /// Looks up a single role whose name matches <paramref name="roleName"/> ignoring case, or null if none.
        /// </summary>
        private Task<TRole> FindRoleByNameAsync(string roleName, CancellationToken cancellationToken)
        {
            // NOTE(review): ToUpper() is culture-sensitive; kept as-is because the comparison runs inside
            // an EF query — confirm the provider translates ToUpperInvariant before switching.
            return Roles.SingleOrDefaultAsync(r => r.Name.ToUpper() == roleName.ToUpper(), cancellationToken);
        }

        /// <summary>
        /// Add a user to a role
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="roleName">Name of an existing role (matched ignoring case).</param>
        /// <param name="cancellationToken">Forwarded to the role lookup.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="roleName"/> is null, empty or whitespace.</exception>
        /// <exception cref="InvalidOperationException">No role with that name exists.</exception>
        public async virtual Task AddToRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (String.IsNullOrWhiteSpace(roleName))
            {
                throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, "roleName");
            }
            var roleEntity = await FindRoleByNameAsync(roleName, cancellationToken);
            if (roleEntity == null)
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.RoleNotFound, roleName));
            }
            var ur = new IdentityUserRole<TKey> { UserId = user.Id, RoleId = roleEntity.Id };
            // TODO: rely on fixup? Until then both navigation collections are updated by hand.
            await UserRoles.AddAsync(ur);
            user.Roles.Add(ur);
            roleEntity.Users.Add(ur);
        }

        /// <summary>
        /// Remove a user from a role
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="roleName">Role name (matched ignoring case); an unknown role is a no-op.</param>
        /// <param name="cancellationToken">Forwarded to the queries.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="roleName"/> is null, empty or whitespace.</exception>
        public async virtual Task RemoveFromRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (String.IsNullOrWhiteSpace(roleName))
            {
                throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, "roleName");
            }
            var roleEntity = await FindRoleByNameAsync(roleName, cancellationToken);
            if (roleEntity != null)
            {
                var userRole = await UserRoles.FirstOrDefaultAsync(r => roleEntity.Id.Equals(r.RoleId) && r.UserId.Equals(user.Id), cancellationToken);
                if (userRole != null)
                {
                    UserRoles.Remove(userRole);
                    user.Roles.Remove(userRole);
                }
            }
        }

        /// <summary>
        /// Get the names of the roles a user is a member of
        /// </summary>
        /// <param name="user">The user; its <c>Roles</c> collection is read in memory.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<IList<string>> GetRolesAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            // NOTE(review): user.Roles is an in-memory collection, so this is a LINQ-to-objects join that
            // enumerates the whole Roles set. TODO: query UserRoles on the server and use ToListAsync.
            var query = from userRole in user.Roles
                        join role in Roles on userRole.RoleId equals role.Id
                        select role.Name;
            return Task.FromResult<IList<string>>(query.ToList());
        }

        /// <summary>
        /// Returns true if the user is in the named role
        /// </summary>
        /// <param name="user">The user; membership is checked against its in-memory <c>Roles</c> collection.</param>
        /// <param name="roleName">Role name (matched ignoring case).</param>
        /// <param name="cancellationToken">Forwarded to the role lookup.</param>
        /// <returns>False when the role does not exist or the user's Roles do not reference it.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="roleName"/> is null, empty or whitespace.</exception>
        public virtual async Task<bool> IsInRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (String.IsNullOrWhiteSpace(roleName))
            {
                throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, "roleName");
            }
            var role = await FindRoleByNameAsync(roleName, cancellationToken);
            if (role != null)
            {
                var roleId = role.Id;
                // NOTE(review): presumably requires user.Roles to be loaded; TODO query UserRoles instead.
                return user.Roles.Any(ur => ur.RoleId.Equals(roleId));
            }
            return false;
        }

        private void ThrowIfDisposed()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(GetType().Name);
            }
        }

        /// <summary>
        /// Marks the store as disposed; later calls throw <see cref="ObjectDisposedException"/>.
        /// <see cref="Context"/> is not disposed.
        /// </summary>
        public void Dispose()
        {
            _disposed = true;
        }

        private DbSet<TRole> Roles { get { return Context.Set<TRole>(); } }
        private DbSet<IdentityUserClaim<TKey>> UserClaims { get { return Context.Set<IdentityUserClaim<TKey>>(); } }
        private DbSet<IdentityUserRole<TKey>> UserRoles { get { return Context.Set<IdentityUserRole<TKey>>(); } }
        private DbSet<IdentityUserLogin<TKey>> UserLogins { get { return Context.Set<IdentityUserLogin<TKey>>(); } }

        /// <summary>
        /// Loads the claim rows belonging to <paramref name="userId"/> whose type and value equal <paramref name="claim"/>.
        /// </summary>
        private Task<List<IdentityUserClaim<TKey>>> FindUserClaimsAsync(TKey userId, Claim claim, CancellationToken cancellationToken)
        {
            return UserClaims
                .Where(uc => uc.UserId.Equals(userId) && uc.ClaimValue == claim.Value && uc.ClaimType == claim.Type)
                .ToListAsync(cancellationToken);
        }

        /// <summary>
        /// Returns the claims stored for the user.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked first and forwarded to the query.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async Task<IList<Claim>> GetClaimsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            var userId = user.Id;
            return await UserClaims.Where(uc => uc.UserId.Equals(userId))
                .Select(c => new Claim(c.ClaimType, c.ClaimValue))
                .ToListAsync(cancellationToken);
        }

        /// <summary>
        /// Stages a new claim row for each claim (not saved).
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="claims">Claims to add.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="claims"/> is null.</exception>
        public async Task AddClaimsAsync(TUser user, IEnumerable<Claim> claims, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claims == null)
            {
                throw new ArgumentNullException("claims");
            }
            foreach (var claim in claims)
            {
                await UserClaims.AddAsync(new IdentityUserClaim<TKey> { UserId = user.Id, ClaimType = claim.Type, ClaimValue = claim.Value });
            }
        }

        /// <summary>
        /// Rewrites this user's claims matching <paramref name="claim"/> to <paramref name="newClaim"/> (not saved).
        /// </summary>
        /// <param name="user">The user whose claims are replaced.</param>
        /// <param name="claim">Claim to find (by type and value).</param>
        /// <param name="newClaim">Replacement type and value.</param>
        /// <param name="cancellationToken">Checked first and forwarded to the query.</param>
        /// <exception cref="ArgumentNullException">Any argument is null.</exception>
        public async Task ReplaceClaimAsync(TUser user, Claim claim, Claim newClaim, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claim == null)
            {
                throw new ArgumentNullException("claim");
            }
            if (newClaim == null)
            {
                throw new ArgumentNullException("newClaim");
            }
            // Scoped to this user: previously the same claim was rewritten for every user in the table.
            var matchedClaims = await FindUserClaimsAsync(user.Id, claim, cancellationToken);
            foreach (var matchedClaim in matchedClaims)
            {
                matchedClaim.ClaimValue = newClaim.Value;
                matchedClaim.ClaimType = newClaim.Type;
            }
        }

        /// <summary>
        /// Stages removal of this user's claims matching any of <paramref name="claims"/> (not saved).
        /// </summary>
        /// <param name="user">The user whose claims are removed.</param>
        /// <param name="claims">Claims to remove (matched by type and value).</param>
        /// <param name="cancellationToken">Checked first and forwarded to the queries.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="claims"/> is null.</exception>
        public async Task RemoveClaimsAsync(TUser user, IEnumerable<Claim> claims, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claims == null)
            {
                throw new ArgumentNullException("claims");
            }
            var userId = user.Id;
            foreach (var claim in claims)
            {
                // Scoped to this user: previously matching claims were removed from every user.
                var matchedClaims = await FindUserClaimsAsync(userId, claim, cancellationToken);
                foreach (var matchedClaim in matchedClaims)
                {
                    UserClaims.Remove(matchedClaim);
                }
            }
        }

        /// <summary>
        /// Stages a login row for the user and adds it to <c>user.Logins</c> (not saved).
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="login">Provider, key and display name to record.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> or <paramref name="login"/> is null.</exception>
        public virtual async Task AddLoginAsync(TUser user, UserLoginInfo login, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (login == null)
            {
                throw new ArgumentNullException("login");
            }
            var l = new IdentityUserLogin<TKey>
            {
                UserId = user.Id,
                ProviderKey = login.ProviderKey,
                LoginProvider = login.LoginProvider,
                ProviderDisplayName = login.ProviderDisplayName
            };
            // TODO: fixup so we don't have to update both
            await UserLogins.AddAsync(l);
            user.Logins.Add(l);
        }

        /// <summary>
        /// Stages removal of the user's login for the given provider and key, if present (not saved).
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="loginProvider">Login provider name.</param>
        /// <param name="providerKey">Provider-specific key.</param>
        /// <param name="cancellationToken">Checked first and forwarded to the query.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual async Task RemoveLoginAsync(TUser user, string loginProvider, string providerKey, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            var userId = user.Id;
            // TODO: ensure logins loaded
            var entry = await UserLogins.SingleOrDefaultAsync(
                l => l.UserId.Equals(userId) && l.LoginProvider == loginProvider && l.ProviderKey == providerKey,
                cancellationToken);
            if (entry != null)
            {
                UserLogins.Remove(entry);
                user.Logins.Remove(entry);
            }
        }

        /// <summary>
        /// Returns the logins stored for the user, queried from the login set.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked first and forwarded to the query.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public async virtual Task<IList<UserLoginInfo>> GetLoginsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            var userId = user.Id;
            return await UserLogins.Where(l => l.UserId.Equals(userId))
                .Select(l => new UserLoginInfo(l.LoginProvider, l.ProviderKey, l.ProviderDisplayName))
                .ToListAsync(cancellationToken);
        }

        /// <summary>
        /// Finds the user owning the login for the given provider and key.
        /// </summary>
        /// <param name="loginProvider">Login provider name.</param>
        /// <param name="providerKey">Provider-specific key.</param>
        /// <param name="cancellationToken">Checked first and forwarded to the queries.</param>
        /// <returns>The user, or null if no such login exists.</returns>
        public async virtual Task<TUser> FindByLoginAsync(string loginProvider, string providerKey, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            var userLogin = await UserLogins.FirstOrDefaultAsync(
                l => l.LoginProvider == loginProvider && l.ProviderKey == providerKey,
                cancellationToken);
            if (userLogin != null)
            {
                return await GetUserAggregate(u => u.Id.Equals(userLogin.UserId), cancellationToken);
            }
            return null;
        }

        /// <summary>
        /// Returns whether the user email is confirmed
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetEmailConfirmedAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.EmailConfirmed);
        }

        /// <summary>
        /// Sets <c>EmailConfirmed</c> on the user (not saved).
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="confirmed">New confirmation flag.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetEmailConfirmedAsync(TUser user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.EmailConfirmed = confirmed;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Set the user email
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="email">New email (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetEmailAsync(TUser user, string email, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.Email = email;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Get the user's email
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetEmailAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.Email);
        }

        /// <summary>
        /// Finds a user whose <c>Email</c> equals <paramref name="email"/> exactly.
        /// </summary>
        /// <param name="email">Email to match.</param>
        /// <param name="cancellationToken">Forwarded to the query.</param>
        /// <returns>The matching user, or null.</returns>
        public virtual Task<TUser> FindByEmailAsync(string email, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            // TODO: a case-insensitive ToUpper() comparison threw NullReferenceException here; exact match for now.
            return GetUserAggregate(u => u.Email == email, cancellationToken);
        }

        /// <summary>
        /// Returns the DateTimeOffset that represents the end of a user's lockout, any time in the past should be considered
        /// not locked out.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<DateTimeOffset?> GetLockoutEndDateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.LockoutEnd);
        }

        /// <summary>
        /// Locks a user out until the specified end date (set to a past date, to unlock a user)
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="lockoutEnd">New <c>LockoutEnd</c> value (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetLockoutEndDateAsync(TUser user, DateTimeOffset? lockoutEnd, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.LockoutEnd = lockoutEnd;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Increments <c>AccessFailedCount</c> in memory after a failed access attempt.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <returns>The incremented count.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<int> IncrementAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.AccessFailedCount++;
            return Task.FromResult(user.AccessFailedCount);
        }

        /// <summary>
        /// Used to reset the account access count, typically after the account is successfully accessed
        /// </summary>
        /// <param name="user">The user; <c>AccessFailedCount</c> is set to 0 in memory.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task ResetAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.AccessFailedCount = 0;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Returns the current number of failed access attempts. This number usually will be reset whenever the password is
        /// verified or the account is locked out.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<int> GetAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.AccessFailedCount);
        }

        /// <summary>
        /// Returns whether the user can be locked out.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetLockoutEnabledAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.LockoutEnabled);
        }

        /// <summary>
        /// Sets whether the user can be locked out.
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="enabled">New <c>LockoutEnabled</c> value (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetLockoutEnabledAsync(TUser user, bool enabled, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.LockoutEnabled = enabled;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Set the user's phone number
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="phoneNumber">New phone number (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetPhoneNumberAsync(TUser user, string phoneNumber, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PhoneNumber = phoneNumber;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Get a user's phone number
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetPhoneNumberAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PhoneNumber);
        }

        /// <summary>
        /// Returns whether the user phoneNumber is confirmed
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetPhoneNumberConfirmedAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PhoneNumberConfirmed);
        }

        /// <summary>
        /// Set PhoneNumberConfirmed on the user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="confirmed">New confirmation flag (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetPhoneNumberConfirmedAsync(TUser user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PhoneNumberConfirmed = confirmed;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Set the security stamp for the user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="stamp">New security stamp (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetSecurityStampAsync(TUser user, string stamp, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.SecurityStamp = stamp;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Get the security stamp for a user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<string> GetSecurityStampAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.SecurityStamp);
        }

        /// <summary>
        /// Set whether two factor authentication is enabled for the user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="enabled">New <c>TwoFactorEnabled</c> value (not saved).</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task SetTwoFactorEnabledAsync(TUser user, bool enabled, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.TwoFactorEnabled = enabled;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Gets whether two factor authentication is enabled for the user
        /// </summary>
        /// <param name="user">The user.</param>
        /// <param name="cancellationToken">Checked before doing any work.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user"/> is null.</exception>
        public virtual Task<bool> GetTwoFactorEnabledAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.TwoFactorEnabled);
        }
    }
}