// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Linq.Expressions;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Data.Entity;

namespace Microsoft.AspNet.Identity.EntityFramework
{
    /// <summary>
    /// Role store over a plain <see cref="DbContext"/>; forwards to
    /// <see cref="RoleStore{TRole, TContext}"/> with <c>TContext = DbContext</c>.
    /// </summary>
    /// <typeparam name="TRole">The role entity type.</typeparam>
    public class RoleStore<TRole> : RoleStore<TRole, DbContext>
        where TRole : IdentityRole
    {
        /// <summary>Creates a store over <paramref name="context"/>.</summary>
        /// <param name="context">The context used for all queries and writes.</param>
        public RoleStore(DbContext context) : base(context) { }
    }

    /// <summary>
    /// Role store over a specific context type; forwards to
    /// <see cref="RoleStore{TRole, TContext, TKey}"/> with <c>TKey = string</c>.
    /// </summary>
    /// <typeparam name="TRole">The role entity type.</typeparam>
    /// <typeparam name="TContext">The context type.</typeparam>
    public class RoleStore<TRole, TContext> : RoleStore<TRole, TContext, string>
        where TRole : IdentityRole
        where TContext : DbContext
    {
        /// <summary>Creates a store over <paramref name="context"/>.</summary>
        /// <param name="context">The context used for all queries and writes.</param>
        public RoleStore(TContext context) : base(context) { }
    }

    /// <summary>
    /// Role store backed by an Entity Framework <see cref="DbContext"/>, implementing
    /// <see cref="IQueryableRoleStore{TRole}"/> and <see cref="IRoleClaimStore{TRole}"/>.
    /// </summary>
    /// <typeparam name="TRole">The role entity type.</typeparam>
    /// <typeparam name="TContext">The context type.</typeparam>
    /// <typeparam name="TKey">The type of the role's primary key.</typeparam>
    public class RoleStore<TRole, TContext, TKey> :
        IQueryableRoleStore<TRole>,
        IRoleClaimStore<TRole>
        where TRole : IdentityRole<TKey>
        where TKey : IEquatable<TKey>
        where TContext : DbContext
    {
        /// <summary>Creates a store over <paramref name="context"/>.</summary>
        /// <param name="context">The context used for all queries and writes.</param>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
        public RoleStore(TContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException("context");
            }
            Context = context;
        }

        private bool _disposed;

        /// <summary>The context this store reads from and writes to.</summary>
        public TContext Context { get; private set; }

        /// <summary>
        /// If true will call SaveChanges after CreateAsync/UpdateAsync/DeleteAsync
        /// </summary>
        public bool AutoSaveChanges { get; set; } = true;

        /// <summary>
        /// Calls <c>Context.SaveChangesAsync</c> when <see cref="AutoSaveChanges"/> is true;
        /// otherwise does nothing.
        /// </summary>
        private async Task SaveChanges(CancellationToken cancellationToken)
        {
            if (AutoSaveChanges)
            {
                await Context.SaveChangesAsync(cancellationToken);
            }
        }

        /// <summary>
        /// Returns the first role in <see cref="Roles"/> matching <paramref name="filter"/>,
        /// or the default value (null) when none matches.
        /// </summary>
        /// <param name="filter">Predicate evaluated by the query provider.</param>
        /// <param name="cancellationToken">Token used to cancel the query.</param>
        public virtual Task<TRole> GetRoleAggregate(Expression<Func<TRole, bool>> filter,
            CancellationToken cancellationToken = default(CancellationToken))
        {
            // Forward the token: previously it was accepted here but never passed to the query.
            return Roles.FirstOrDefaultAsync(filter, cancellationToken);
        }

        /// <summary>
        /// Adds <paramref name="role"/> to the context, then saves if <see cref="AutoSaveChanges"/> is true.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The store has been disposed.</exception>
        public async virtual Task CreateAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            await Context.AddAsync(role, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Marks <paramref name="role"/> as updated in the context, then saves if
        /// <see cref="AutoSaveChanges"/> is true.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The store has been disposed.</exception>
        public async virtual Task UpdateAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            await Context.UpdateAsync(role, cancellationToken);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Marks <paramref name="role"/> for deletion in the context, then saves if
        /// <see cref="AutoSaveChanges"/> is true.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The store has been disposed.</exception>
        public async virtual Task DeleteAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            Context.Delete(role);
            await SaveChanges(cancellationToken);
        }

        /// <summary>
        /// Returns the role's Id converted to a string via <see cref="ConvertIdToString"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        public Task<string> GetRoleIdAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            return Task.FromResult(ConvertIdToString(role.Id));
        }

        /// <summary>Returns <c>role.Name</c>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        public Task<string> GetRoleNameAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            return Task.FromResult(role.Name);
        }

        /// <summary>
        /// Sets <c>role.Name</c> in memory only; nothing is saved by this method.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        public Task SetRoleNameAsync(TRole role, string roleName, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            role.Name = roleName;
            return Task.FromResult(0);
        }

        /// <summary>
        /// Converts a string id to <typeparamref name="TKey"/>; a null id yields <c>default(TKey)</c>.
        /// </summary>
        /// <param name="id">The string form of the key.</param>
        public virtual TKey ConvertIdFromString(string id)
        {
            if (id == null)
            {
                return default(TKey);
            }
            // Ids are machine-readable, so convert culture-independently.
            return (TKey)Convert.ChangeType(id, typeof(TKey), CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Converts <paramref name="id"/> to a string; returns null when it equals <c>default(TKey)</c>.
        /// </summary>
        /// <param name="id">The key to convert.</param>
        public virtual string ConvertIdToString(TKey id)
        {
            // EqualityComparer tolerates a null reference-type key (e.g. string),
            // where id.Equals(...) would throw NullReferenceException.
            if (EqualityComparer<TKey>.Default.Equals(id, default(TKey)))
            {
                return null;
            }
            // Invariant culture keeps this symmetric with ConvertIdFromString.
            return Convert.ToString(id, CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Find a role by id
        /// </summary>
        /// <param name="id">String form of the role id; converted via <see cref="ConvertIdFromString"/>.</param>
        /// <param name="cancellationToken">Token used to cancel the query.</param>
        /// <returns>The matching role, or null if none matches.</returns>
        public virtual Task<TRole> FindByIdAsync(string id, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            var roleId = ConvertIdFromString(id);
            return GetRoleAggregate(u => u.Id.Equals(roleId), cancellationToken);
        }

        /// <summary>
        /// Find a role by name
        /// </summary>
        /// <param name="name">The role name; both sides are compared after <c>ToUpper()</c>.</param>
        /// <param name="cancellationToken">Token used to cancel the query.</param>
        /// <returns>The matching role, or null if none matches.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="name"/> is null.</exception>
        public virtual Task<TRole> FindByNameAsync(string name, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (name == null)
            {
                // Previously surfaced as an opaque NullReferenceException from name.ToUpper().
                throw new ArgumentNullException("name");
            }
            return GetRoleAggregate(u => u.Name.ToUpper() == name.ToUpper(), cancellationToken);
        }

        /// <summary>Throws <see cref="ObjectDisposedException"/> once <see cref="Dispose"/> has been called.</summary>
        private void ThrowIfDisposed()
        {
            if (_disposed)
            {
                throw new ObjectDisposedException(GetType().Name);
            }
        }

        /// <summary>
        /// Dispose the store. This only marks the store as disposed; it does not dispose <see cref="Context"/>.
        /// </summary>
        public void Dispose()
        {
            _disposed = true;
        }

        /// <summary>Returns the claims stored for <paramref name="role"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The store has been disposed.</exception>
        public async Task<IList<Claim>> GetClaimsAsync(TRole role, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            return await RoleClaims
                .Where(rc => rc.RoleId.Equals(role.Id))
                .Select(c => new Claim(c.ClaimType, c.ClaimValue))
                .ToListAsync(cancellationToken);
        }

        /// <summary>
        /// Adds a claim row for <paramref name="role"/> to the context. This method does not call
        /// SaveChanges, regardless of <see cref="AutoSaveChanges"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> or <paramref name="claim"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The store has been disposed.</exception>
        public Task AddClaimAsync(TRole role, Claim claim, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            if (claim == null)
            {
                throw new ArgumentNullException("claim");
            }
            var roleClaim = new IdentityRoleClaim<TKey>
            {
                RoleId = role.Id,
                ClaimType = claim.Type,
                ClaimValue = claim.Value
            };
            // Same context API CreateAsync uses, so the token is honored.
            return Context.AddAsync(roleClaim, cancellationToken);
        }

        /// <summary>
        /// Removes every claim of <paramref name="role"/> whose type and value match <paramref name="claim"/>.
        /// This method does not call SaveChanges, regardless of <see cref="AutoSaveChanges"/>.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="role"/> or <paramref name="claim"/> is null.</exception>
        /// <exception cref="ObjectDisposedException">The store has been disposed.</exception>
        public async Task RemoveClaimAsync(TRole role, Claim claim, CancellationToken cancellationToken = default(CancellationToken))
        {
            cancellationToken.ThrowIfCancellationRequested();
            ThrowIfDisposed();
            if (role == null)
            {
                throw new ArgumentNullException("role");
            }
            if (claim == null)
            {
                throw new ArgumentNullException("claim");
            }
            // Restrict to this role: without the RoleId filter, matching claims of *all* roles were removed.
            var claims = await RoleClaims
                .Where(rc => rc.RoleId.Equals(role.Id) && rc.ClaimValue == claim.Value && rc.ClaimType == claim.Type)
                .ToListAsync(cancellationToken);
            foreach (var c in claims)
            {
                RoleClaims.Remove(c);
            }
        }

        /// <summary>Queryable view of all roles in the context.</summary>
        public IQueryable<TRole> Roles
        {
            get { return Context.Set<TRole>(); }
        }

        /// <summary>The context's set of role-claim rows.</summary>
        private DbSet<IdentityRoleClaim<TKey>> RoleClaims
        {
            get { return Context.Set<IdentityRoleClaim<TKey>>(); }
        }
    }
}