// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.Entity;

namespace Microsoft.AspNet.Identity.EntityFramework
{
    /// <summary>
    ///     User store bound to the default <see cref="User" /> entity.
    /// </summary>
    public class UserStore : UserStore<User>
    {
        /// <summary>
        ///     Creates a store over the given context.
        /// </summary>
        /// <param name="context">Context handed to the base store.</param>
        public UserStore(DbContext context) : base(context) { }
    }

    /// <summary>
    ///     User store for a custom user type, using <see cref="IdentityRole" /> and a plain <see cref="DbContext" />.
    /// </summary>
    /// <typeparam name="TUser">User entity type.</typeparam>
    public class UserStore<TUser> : UserStore<TUser, IdentityRole, DbContext> where TUser : User
    {
        /// <summary>
        ///     Creates a store over the given context.
        /// </summary>
        /// <param name="context">Context handed to the base store.</param>
        public UserStore(DbContext context) : base(context) { }
    }

    /// <summary>
    ///     User store that reads and writes users, roles, claims and logins through a <typeparamref name="TContext" />.
    /// </summary>
    /// <remarks>
    ///     Most setters only mutate the passed-in user object; changes are written by
    ///     <see cref="CreateAsync" />, <see cref="UpdateAsync" /> and <see cref="DeleteAsync" />
    ///     (and only when <see cref="AutoSaveChanges" /> is true).
    ///     NOTE(review): the generic type arguments in this file were missing from the reviewed copy and were
    ///     restored from the <c>where</c> constraints and usage - confirm them against the entity and interface
    ///     declarations (IdentityUserRole / IdentityUserClaim / IdentityUserLogin, IUser*Store).
    /// </remarks>
    /// <typeparam name="TUser">User entity type.</typeparam>
    /// <typeparam name="TRole">Role entity type.</typeparam>
    /// <typeparam name="TContext">Context type used for persistence.</typeparam>
    public class UserStore<TUser, TRole, TContext> :
        IUserLoginStore<TUser>,
        IUserRoleStore<TUser>,
        IUserClaimStore<TUser>,
        IUserPasswordStore<TUser>,
        IUserSecurityStampStore<TUser>,
        IUserEmailStore<TUser>,
        IUserLockoutStore<TUser>,
        IUserPhoneNumberStore<TUser>,
        IQueryableUserStore<TUser>,
        IUserTwoFactorStore<TUser>
        where TUser : User
        where TRole : IdentityRole
        where TContext : DbContext
    {
        private bool _disposed;

        /// <summary>
        ///     Creates a store over the given context with <see cref="AutoSaveChanges" /> enabled.
        /// </summary>
        /// <param name="context">Context used for all data access.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context" /> is null.</exception>
        public UserStore(TContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            Context = context;
            AutoSaveChanges = true;
        }

        /// <summary>
        ///     Context this store operates on.
        /// </summary>
        public TContext Context { get; private set; }

        /// <summary>
        ///     If true will call SaveChanges after CreateAsync/UpdateAsync/DeleteAsync
        /// </summary>
        public bool AutoSaveChanges { get; set; }

        /// <summary>
        ///     Saves the context when <see cref="AutoSaveChanges" /> is set; otherwise returns a completed task.
        /// </summary>
        private Task SaveChanges(CancellationToken cancellationToken)
        {
            return AutoSaveChanges ? Context.SaveChangesAsync(cancellationToken) : Task.FromResult(0);
        }

        /// <summary>
        ///     Returns the first user matching <paramref name="filter" />, or the default value when none match.
        /// </summary>
        /// <param name="filter">Predicate applied to <see cref="Users" />.</param>
        /// <param name="cancellationToken">Currently unused: the query runs synchronously.</param>
        protected virtual Task<TUser> GetUserAggregate(Expression<Func<TUser, bool>> filter,
            CancellationToken cancellationToken = default(CancellationToken))
        {
            // TODO: return Users.FirstOrDefaultAsync(filter, cancellationToken), including
            // u.Roles, u.Claims and u.Logins, once async queries/Include are available.
            return Task.FromResult(Users.FirstOrDefault(filter));
        }

        /// <summary>
        ///     Returns the user's id converted to a string with the invariant culture.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public Task<string> GetUserIdAsync(TUser user, CancellationToken cancellationToken = new CancellationToken())
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(Convert.ToString(user.Id, CultureInfo.InvariantCulture));
        }

        /// <summary>
        ///     Returns the user's name.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public Task<string> GetUserNameAsync(TUser user, CancellationToken cancellationToken = new CancellationToken())
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.UserName);
        }

        /// <summary>
        ///     Sets the user's name on the in-memory object (nothing is saved here).
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public Task SetUserNameAsync(TUser user, string userName, CancellationToken cancellationToken = new CancellationToken())
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.UserName = userName;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Adds the user to the context, then saves if <see cref="AutoSaveChanges" /> is set.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual async Task CreateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            await Context.AddAsync(user, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        ///     Marks the user as updated in the context, then saves if <see cref="AutoSaveChanges" /> is set.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual async Task UpdateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            await Context.UpdateAsync(user, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        ///     Deletes the user from the context, then saves if <see cref="AutoSaveChanges" /> is set.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual async Task DeleteAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            Context.Delete(user);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        ///     Find a user by id
        /// </summary>
        /// <param name="userId">Id to match against the user's Id.</param>
        /// <param name="cancellationToken">Checked once on entry.</param>
        /// <returns>The matching user, or the default value when none is found.</returns>
        public virtual Task<TUser> FindByIdAsync(string userId, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            return GetUserAggregate(u => u.Id.Equals(userId), cancellationToken);
        }

        /// <summary>
        ///     Find a user by name
        /// </summary>
        /// <param name="userName">Name to match; both sides are upper-cased before comparing.</param>
        /// <param name="cancellationToken">Checked once on entry.</param>
        /// <returns>The matching user, or the default value when none is found.</returns>
        public virtual Task<TUser> FindByNameAsync(string userName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            return GetUserAggregate(u => u.UserName.ToUpper() == userName.ToUpper(), cancellationToken);
        }

        /// <summary>
        ///     Queryable set of users from the context.
        /// </summary>
        public IQueryable<TUser> Users
        {
            get { return Context.Set<TUser>(); }
        }

        /// <summary>
        ///     Set the password hash for a user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetPasswordHashAsync(TUser user, string passwordHash, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PasswordHash = passwordHash;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Get the password hash for a user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<string> GetPasswordHashAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PasswordHash);
        }

        /// <summary>
        ///     Returns true if the user has a password set
        /// </summary>
        /// <returns>True when the user's PasswordHash is not null.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<bool> HasPasswordAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            // Same guards as every other accessor: previously a null user surfaced as a NullReferenceException.
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PasswordHash != null);
        }

        /// <summary>
        ///     Add a user to a role
        /// </summary>
        /// <param name="user">User to add.</param>
        /// <param name="roleName">Role name; matched after upper-casing both sides.</param>
        /// <param name="cancellationToken">Checked once on entry.</param>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="roleName" /> is null, empty or whitespace.</exception>
        /// <exception cref="InvalidOperationException">No role with that name exists.</exception>
        public virtual Task AddToRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (String.IsNullOrWhiteSpace(roleName))
            {
                throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, "roleName");
            }
            var roleEntity = Roles.SingleOrDefault(r => r.Name.ToUpper() == roleName.ToUpper());
            if (roleEntity == null)
            {
                throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture,
                    Resources.RoleNotFound, roleName));
            }
            var ur = new IdentityUserRole { UserId = user.Id, RoleId = roleEntity.Id };
            // TODO: rely on fixup? Until then the link is added to the set and to both navigation collections.
            UserRoles.Add(ur);
            user.Roles.Add(ur);
            roleEntity.Users.Add(ur);
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Remove a user from a role
        /// </summary>
        /// <remarks>Does nothing when the role or the user/role link does not exist.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="roleName" /> is null, empty or whitespace.</exception>
        public virtual Task RemoveFromRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (String.IsNullOrWhiteSpace(roleName))
            {
                throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, "roleName");
            }
            var roleEntity = Roles.SingleOrDefault(r => r.Name.ToUpper() == roleName.ToUpper());
            if (roleEntity != null)
            {
                var userRole = UserRoles.FirstOrDefault(r => roleEntity.Id.Equals(r.RoleId) && r.UserId == user.Id);
                if (userRole != null)
                {
                    UserRoles.Remove(userRole);
                    user.Roles.Remove(userRole);
                }
            }
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Get the names of the roles a user is a member of
        /// </summary>
        /// <remarks>Joins the user's in-memory Roles collection with the role set; runs synchronously.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual async Task<IList<string>> GetRolesAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            // TODO: query UserRoles (filtered by user id) and use ToListAsync instead of user.Roles.
            var query = from userRole in user.Roles
                        join role in Roles on userRole.RoleId equals role.Id
                        select role.Name;
            return query.ToList();
        }

        /// <summary>
        ///     Returns true if the user is in the named role
        /// </summary>
        /// <remarks>Checks the user's in-memory Roles collection; returns false when the role does not exist.</remarks>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="roleName" /> is null, empty or whitespace.</exception>
        public virtual async Task<bool> IsInRoleAsync(TUser user, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (String.IsNullOrWhiteSpace(roleName))
            {
                throw new ArgumentException(Resources.ValueCannotBeNullOrEmpty, "roleName");
            }
            // TODO: use Roles.SingleOrDefaultAsync / UserRoles.AnyAsync (filtering on role id and user id).
            var role = Roles.SingleOrDefault(r => r.Name.ToUpper() == roleName.ToUpper());
            if (role != null)
            {
                var roleId = role.Id;
                return user.Roles.Any(ur => ur.RoleId.Equals(roleId));
            }
            return false;
        }

        /// <summary>
        ///     Throws <see cref="ObjectDisposedException" /> once <see cref="Dispose" /> has been called.
        /// </summary>
        private void ThrowIfDisposed()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(GetType().Name);
            }
        }

        /// <summary>
        ///     Dispose the store
        /// </summary>
        /// <remarks>Only flags the store as disposed; the context itself is not disposed here.</remarks>
        public void Dispose()
        {
            _disposed = true;
        }

        private DbSet<TRole> Roles
        {
            get { return Context.Set<TRole>(); }
        }

        private DbSet<IdentityUserClaim> UserClaims
        {
            get { return Context.Set<IdentityUserClaim>(); }
        }

        private DbSet<IdentityUserRole> UserRoles
        {
            get { return Context.Set<IdentityUserRole>(); }
        }

        private DbSet<IdentityUserLogin> UserLogins
        {
            get { return Context.Set<IdentityUserLogin>(); }
        }

        /// <summary>
        ///     Returns the claims stored for the user.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public Task<IList<Claim>> GetClaimsAsync(TUser user, CancellationToken cancellationToken = new CancellationToken())
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            IList<Claim> result = UserClaims
                .Where(uc => uc.UserId == user.Id)
                .Select(c => new Claim(c.ClaimType, c.ClaimValue))
                .ToList();
            return Task.FromResult(result);
        }

        /// <summary>
        ///     Adds a claim row for the user to the claim set.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> or <paramref name="claim" /> is null.</exception>
        public Task AddClaimAsync(TUser user, Claim claim, CancellationToken cancellationToken = new CancellationToken())
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claim == null)
            {
                throw new ArgumentNullException("claim");
            }
            UserClaims.Add(new IdentityUserClaim { UserId = user.Id, ClaimType = claim.Type, ClaimValue = claim.Value });
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Removes every claim row of this user that matches the claim's type and value.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> or <paramref name="claim" /> is null.</exception>
        public Task RemoveClaimAsync(TUser user, Claim claim, CancellationToken cancellationToken = new CancellationToken())
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (claim == null)
            {
                throw new ArgumentNullException("claim");
            }
            // Scope the match to this user: without the UserId filter, identical claims held by
            // other users were removed as well.
            var claims = UserClaims
                .Where(uc => uc.UserId == user.Id && uc.ClaimValue == claim.Value && uc.ClaimType == claim.Type)
                .ToList();
            foreach (var c in claims)
            {
                UserClaims.Remove(c);
            }
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Adds an external login for the user to the login set and to the user's Logins collection.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> or <paramref name="login" /> is null.</exception>
        public virtual async Task AddLoginAsync(TUser user, UserLoginInfo login, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (login == null)
            {
                throw new ArgumentNullException("login");
            }
            var l = new IdentityUserLogin
            {
                UserId = user.Id,
                ProviderKey = login.ProviderKey,
                LoginProvider = login.LoginProvider
            };
            // TODO: fixup so we don't have to update both
            UserLogins.Add(l);
            user.Logins.Add(l);
        }

        /// <summary>
        ///     Removes the user's login matching the given provider and key, if one exists.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> or <paramref name="login" /> is null.</exception>
        public virtual Task RemoveLoginAsync(TUser user, UserLoginInfo login, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            if (login == null)
            {
                throw new ArgumentNullException("login");
            }
            var provider = login.LoginProvider;
            var key = login.ProviderKey;
            var userId = user.Id;
            // todo: ensure logins loaded
            var entry = UserLogins.SingleOrDefault(l => l.UserId == userId && l.LoginProvider == provider && l.ProviderKey == key);
            if (entry != null)
            {
                UserLogins.Remove(entry);
                user.Logins.Remove(entry);
            }
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Returns the external logins held in the user's Logins collection.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<IList<UserLoginInfo>> GetLoginsAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            // todo: ensure logins loaded
            IList<UserLoginInfo> result = user.Logins
                .Select(l => new UserLoginInfo(l.LoginProvider, l.ProviderKey))
                .ToList();
            return Task.FromResult(result);
        }

        /// <summary>
        ///     Finds the user that owns the given external login.
        /// </summary>
        /// <returns>The owning user, or null when no login matches.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="login" /> is null.</exception>
        public virtual async Task<TUser> FindByLoginAsync(UserLoginInfo login, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (login == null)
            {
                throw new ArgumentNullException("login");
            }
            // todo: ensure logins loaded
            var provider = login.LoginProvider;
            var key = login.ProviderKey;
            // TODO: use FirstOrDefaultAsync
            var userLogin = UserLogins.FirstOrDefault(l => l.LoginProvider == provider && l.ProviderKey == key);
            if (userLogin != null)
            {
                return await GetUserAggregate(u => u.Id.Equals(userLogin.UserId), cancellationToken);
            }
            return null;
        }

        /// <summary>
        ///     Returns whether the user email is confirmed
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<bool> GetEmailConfirmedAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.EmailConfirmed);
        }

        /// <summary>
        ///     Set EmailConfirmed on the user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetEmailConfirmedAsync(TUser user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.EmailConfirmed = confirmed;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Set the user email
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetEmailAsync(TUser user, string email, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.Email = email;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Get the user's email
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<string> GetEmailAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.Email);
        }

        /// <summary>
        ///     Find a user by email
        /// </summary>
        /// <remarks>The comparison is an exact <c>==</c> match (no case folding).</remarks>
        /// <returns>The matching user, or the default value when none is found.</returns>
        public virtual Task<TUser> FindByEmailAsync(string email, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            // todo: ToUpper blows up with Null Ref, so u.Email.ToUpper() == email.ToUpper() is not used here.
            return GetUserAggregate(u => u.Email == email, cancellationToken);
        }

        /// <summary>
        ///     Returns the DateTimeOffset that represents the end of a user's lockout, any time in the past should be considered
        ///     not locked out.
        /// </summary>
        /// <returns>
        ///     The stored LockoutEnd value treated as UTC, or a default <see cref="DateTimeOffset" /> when none is set.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<DateTimeOffset> GetLockoutEndDateAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.LockoutEnd.HasValue
                ? new DateTimeOffset(DateTime.SpecifyKind(user.LockoutEnd.Value, DateTimeKind.Utc))
                : new DateTimeOffset());
        }

        /// <summary>
        ///     Locks a user out until the specified end date (set to a past date, to unlock a user)
        /// </summary>
        /// <remarks>
        ///     <see cref="DateTimeOffset.MinValue" /> is stored as null; any other value is stored as its UTC date/time.
        /// </remarks>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetLockoutEndDateAsync(TUser user, DateTimeOffset lockoutEnd, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.LockoutEnd = lockoutEnd == DateTimeOffset.MinValue ? (DateTime?)null : lockoutEnd.UtcDateTime;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Used to record when an attempt to access the user has failed
        /// </summary>
        /// <returns>The failed-access count after incrementing.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<int> IncrementAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.AccessFailedCount++;
            return Task.FromResult(user.AccessFailedCount);
        }

        /// <summary>
        ///     Used to reset the account access count, typically after the account is successfully accessed
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task ResetAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.AccessFailedCount = 0;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Returns the current number of failed access attempts. This number usually will be reset whenever the password is
        ///     verified or the account is locked out.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<int> GetAccessFailedCountAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.AccessFailedCount);
        }

        /// <summary>
        ///     Returns whether the user can be locked out.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<bool> GetLockoutEnabledAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.LockoutEnabled);
        }

        /// <summary>
        ///     Sets whether the user can be locked out.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetLockoutEnabledAsync(TUser user, bool enabled, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.LockoutEnabled = enabled;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Set the user's phone number
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetPhoneNumberAsync(TUser user, string phoneNumber, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PhoneNumber = phoneNumber;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Get a user's phone number
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<string> GetPhoneNumberAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PhoneNumber);
        }

        /// <summary>
        ///     Returns whether the user phoneNumber is confirmed
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<bool> GetPhoneNumberConfirmedAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.PhoneNumberConfirmed);
        }

        /// <summary>
        ///     Set PhoneNumberConfirmed on the user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetPhoneNumberConfirmedAsync(TUser user, bool confirmed, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.PhoneNumberConfirmed = confirmed;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Set the security stamp for the user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetSecurityStampAsync(TUser user, string stamp, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.SecurityStamp = stamp;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Get the security stamp for a user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<string> GetSecurityStampAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.SecurityStamp);
        }

        /// <summary>
        ///     Set whether two factor authentication is enabled for the user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task SetTwoFactorEnabledAsync(TUser user, bool enabled, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            user.TwoFactorEnabled = enabled;
            return Task.FromResult(0);
        }

        /// <summary>
        ///     Gets whether two factor authentication is enabled for the user
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="user" /> is null.</exception>
        public virtual Task<bool> GetTwoFactorEnabledAsync(TUser user, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (user == null)
            {
                throw new ArgumentNullException("user");
            }
            return Task.FromResult(user.TwoFactorEnabled);
        }
    }
}