Enable and fix Async/Cancellation tests

This commit is contained in:
Hao Kung 2014-10-08 16:27:29 -07:00
parent cb3948b86f
commit 2c9f43a160
10 changed files with 207 additions and 34 deletions

View File

@ -3,10 +3,6 @@
using System; using System;
using Microsoft.AspNet.Identity; using Microsoft.AspNet.Identity;
using Microsoft.Framework.OptionsModel;
using Microsoft.Framework.DependencyInjection;
using Microsoft.AspNet.Security.Cookies;
using Microsoft.Framework.ConfigurationModel;
namespace Microsoft.AspNet.Builder namespace Microsoft.AspNet.Builder
{ {

View File

@ -1,11 +1,11 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.Framework.OptionsModel;
using System; using System;
using System.Security.Claims; using System.Security.Claims;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.Framework.OptionsModel;
namespace Microsoft.AspNet.Identity namespace Microsoft.AspNet.Identity
{ {

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Security.Claims; using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNet.Security.Cookies; using Microsoft.AspNet.Security.Cookies;
@ -9,6 +10,7 @@ namespace Microsoft.AspNet.Identity
{ {
public interface ISecurityStampValidator public interface ISecurityStampValidator
{ {
Task Validate(CookieValidateIdentityContext context, ClaimsIdentity identity); Task ValidateAsync(CookieValidateIdentityContext context, ClaimsIdentity identity,
CancellationToken cancellationToken = default(CancellationToken));
} }
} }

View File

@ -2,12 +2,12 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using Microsoft.AspNet.Identity;
using Microsoft.Framework.ConfigurationModel;
using Microsoft.AspNet.Security.DataProtection;
using Microsoft.AspNet.Security.Cookies;
using Microsoft.AspNet.Http; using Microsoft.AspNet.Http;
using Microsoft.AspNet.Identity;
using Microsoft.AspNet.Security; using Microsoft.AspNet.Security;
using Microsoft.AspNet.Security.Cookies;
using Microsoft.AspNet.Security.DataProtection;
using Microsoft.Framework.ConfigurationModel;
namespace Microsoft.Framework.DependencyInjection namespace Microsoft.Framework.DependencyInjection
{ {

View File

@ -4,6 +4,7 @@
using System; using System;
using System.Security.Claims; using System.Security.Claims;
using System.Security.Principal; using System.Security.Principal;
using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNet.Security.Cookies; using Microsoft.AspNet.Security.Cookies;
using Microsoft.Framework.DependencyInjection; using Microsoft.Framework.DependencyInjection;
@ -18,11 +19,12 @@ namespace Microsoft.AspNet.Identity
/// ClaimsIdentity /// ClaimsIdentity
/// </summary> /// </summary>
/// <returns></returns> /// <returns></returns>
public virtual async Task Validate(CookieValidateIdentityContext context, ClaimsIdentity identity) public virtual async Task ValidateAsync(CookieValidateIdentityContext context, ClaimsIdentity identity,
CancellationToken cancellationToken = default(CancellationToken))
{ {
var manager = context.HttpContext.RequestServices.GetRequiredService<SignInManager<TUser>>(); var manager = context.HttpContext.RequestServices.GetRequiredService<SignInManager<TUser>>();
var userId = identity.GetUserId(); var userId = identity.GetUserId();
var user = await manager.ValidateSecurityStampAsync(identity, userId); var user = await manager.ValidateSecurityStampAsync(identity, userId, cancellationToken);
if (user != null) if (user != null)
{ {
var isPersistent = false; var isPersistent = false;
@ -30,7 +32,7 @@ namespace Microsoft.AspNet.Identity
{ {
isPersistent = context.Properties.IsPersistent; isPersistent = context.Properties.IsPersistent;
} }
await manager.SignInAsync(user, isPersistent); await manager.SignInAsync(user, isPersistent, authenticationMethod: null, cancellationToken: cancellationToken);
} }
else else
{ {
@ -71,7 +73,7 @@ namespace Microsoft.AspNet.Identity
if (validate) if (validate)
{ {
var validator = context.HttpContext.RequestServices.GetRequiredService<ISecurityStampValidator>(); var validator = context.HttpContext.RequestServices.GetRequiredService<ISecurityStampValidator>();
return validator.Validate(context, context.Identity); return validator.ValidateAsync(context, context.Identity);
} }
return Task.FromResult(0); return Task.FromResult(0);
} }

View File

@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims; using System.Security.Claims;
using System.Security.Principal; using System.Security.Principal;
using System.Threading; using System.Threading;
@ -10,8 +12,6 @@ using Microsoft.AspNet.Http;
using Microsoft.AspNet.Http.Security; using Microsoft.AspNet.Http.Security;
using Microsoft.Framework.DependencyInjection; using Microsoft.Framework.DependencyInjection;
using Microsoft.Framework.OptionsModel; using Microsoft.Framework.OptionsModel;
using System.Collections.Generic;
using System.Linq;
namespace Microsoft.AspNet.Identity namespace Microsoft.AspNet.Identity
{ {
@ -214,7 +214,7 @@ namespace Microsoft.AspNet.Identity
Context.Response.SignIn(new AuthenticationProperties { IsPersistent = true }, rememberBrowserIdentity); Context.Response.SignIn(new AuthenticationProperties { IsPersistent = true }, rememberBrowserIdentity);
} }
public virtual Task ForgetTwoFactorClientAsync() public virtual Task ForgetTwoFactorClientAsync(CancellationToken cancellationToken = default(CancellationToken))
{ {
Context.Response.SignOut(IdentityOptions.TwoFactorRememberMeCookieAuthenticationType); Context.Response.SignOut(IdentityOptions.TwoFactorRememberMeCookieAuthenticationType);
return Task.FromResult(0); return Task.FromResult(0);

View File

@ -42,7 +42,7 @@ namespace Microsoft.AspNet.Identity
{ {
throw new ArgumentNullException("manager"); throw new ArgumentNullException("manager");
} }
var token = await manager.CreateSecurityTokenAsync(user); var token = await manager.CreateSecurityTokenAsync(user, cancellationToken);
var modifier = await GetUserModifierAsync(purpose, manager, user); var modifier = await GetUserModifierAsync(purpose, manager, user);
return Rfc6238AuthenticationService.GenerateCode(token, modifier).ToString("D6", CultureInfo.InvariantCulture); return Rfc6238AuthenticationService.GenerateCode(token, modifier).ToString("D6", CultureInfo.InvariantCulture);
} }
@ -63,11 +63,11 @@ namespace Microsoft.AspNet.Identity
throw new ArgumentNullException("manager"); throw new ArgumentNullException("manager");
} }
int code; int code;
if (!Int32.TryParse(token, out code)) if (!int.TryParse(token, out code))
{ {
return false; return false;
} }
var securityToken = await manager.CreateSecurityTokenAsync(user); var securityToken = await manager.CreateSecurityTokenAsync(user, cancellationToken);
var modifier = await GetUserModifierAsync(purpose, manager, user); var modifier = await GetUserModifierAsync(purpose, manager, user);
return securityToken != null && Rfc6238AuthenticationService.ValidateCode(securityToken, code, modifier); return securityToken != null && Rfc6238AuthenticationService.ValidateCode(securityToken, code, modifier);
} }

View File

@ -61,7 +61,8 @@ namespace Microsoft.AspNet.Identity
UserNameNormalizer = userNameNormalizer; UserNameNormalizer = userNameNormalizer;
// TODO: Email/Sms/Token services // TODO: Email/Sms/Token services
if (tokenProviders != null) { if (tokenProviders != null)
{
foreach (var tokenProvider in tokenProviders) foreach (var tokenProvider in tokenProviders)
{ {
RegisterTokenProvider(tokenProvider); RegisterTokenProvider(tokenProvider);
@ -310,7 +311,7 @@ namespace Microsoft.AspNet.Identity
{ {
await GetUserLockoutStore().SetLockoutEnabledAsync(user, true, cancellationToken); await GetUserLockoutStore().SetLockoutEnabledAsync(user, true, cancellationToken);
} }
await UpdateNormalizedUserName(user, cancellationToken); await UpdateNormalizedUserNameAsync(user, cancellationToken);
await Store.CreateAsync(user, cancellationToken); await Store.CreateAsync(user, cancellationToken);
return IdentityResult.Success; return IdentityResult.Success;
} }
@ -334,7 +335,7 @@ namespace Microsoft.AspNet.Identity
{ {
return result; return result;
} }
await UpdateNormalizedUserName(user, cancellationToken); await UpdateNormalizedUserNameAsync(user, cancellationToken);
await Store.UpdateAsync(user, cancellationToken); await Store.UpdateAsync(user, cancellationToken);
return IdentityResult.Success; return IdentityResult.Success;
} }
@ -443,7 +444,7 @@ namespace Microsoft.AspNet.Identity
/// <param name="user"></param> /// <param name="user"></param>
/// <param name="cancellationToken"></param> /// <param name="cancellationToken"></param>
/// <returns></returns> /// <returns></returns>
public virtual async Task UpdateNormalizedUserName(TUser user, public virtual async Task UpdateNormalizedUserNameAsync(TUser user,
CancellationToken cancellationToken = default(CancellationToken)) CancellationToken cancellationToken = default(CancellationToken))
{ {
var userName = await GetUserNameAsync(user, cancellationToken); var userName = await GetUserNameAsync(user, cancellationToken);
@ -489,7 +490,7 @@ namespace Microsoft.AspNet.Identity
private async Task UpdateUserName(TUser user, string userName, CancellationToken cancellationToken) private async Task UpdateUserName(TUser user, string userName, CancellationToken cancellationToken)
{ {
await Store.SetUserNameAsync(user, userName, cancellationToken); await Store.SetUserNameAsync(user, userName, cancellationToken);
await UpdateNormalizedUserName(user, cancellationToken); await UpdateNormalizedUserNameAsync(user, cancellationToken);
} }
/// <summary> /// <summary>
@ -1362,7 +1363,7 @@ namespace Microsoft.AspNet.Identity
{ {
throw new ArgumentNullException("user"); throw new ArgumentNullException("user");
} }
if (!await VerifyChangePhoneNumberTokenAsync(user, token, phoneNumber)) if (!await VerifyChangePhoneNumberTokenAsync(user, token, phoneNumber, cancellationToken))
{ {
return IdentityResult.Failed(Resources.InvalidToken); return IdentityResult.Failed(Resources.InvalidToken);
} }
@ -1392,10 +1393,10 @@ namespace Microsoft.AspNet.Identity
// Two factor APIS // Two factor APIS
internal async Task<SecurityToken> CreateSecurityTokenAsync(TUser user) internal async Task<SecurityToken> CreateSecurityTokenAsync(TUser user, CancellationToken cancellationToken)
{ {
return return
new SecurityToken(Encoding.Unicode.GetBytes(await GetSecurityStampAsync(user))); new SecurityToken(Encoding.Unicode.GetBytes(await GetSecurityStampAsync(user, cancellationToken)));
} }
/// <summary> /// <summary>
@ -1404,12 +1405,13 @@ namespace Microsoft.AspNet.Identity
/// <param name="user"></param> /// <param name="user"></param>
/// <param name="phoneNumber"></param> /// <param name="phoneNumber"></param>
/// <returns></returns> /// <returns></returns>
public virtual async Task<string> GenerateChangePhoneNumberTokenAsync(TUser user, string phoneNumber) public virtual async Task<string> GenerateChangePhoneNumberTokenAsync(TUser user, string phoneNumber,
CancellationToken cancellationToken = default(CancellationToken))
{ {
ThrowIfDisposed(); ThrowIfDisposed();
return return Rfc6238AuthenticationService.GenerateCode(
Rfc6238AuthenticationService.GenerateCode(await CreateSecurityTokenAsync(user), phoneNumber) await CreateSecurityTokenAsync(user, cancellationToken), phoneNumber)
.ToString(CultureInfo.InvariantCulture); .ToString(CultureInfo.InvariantCulture);
} }
/// <summary> /// <summary>
@ -1419,10 +1421,11 @@ namespace Microsoft.AspNet.Identity
/// <param name="token"></param> /// <param name="token"></param>
/// <param name="phoneNumber"></param> /// <param name="phoneNumber"></param>
/// <returns></returns> /// <returns></returns>
public virtual async Task<bool> VerifyChangePhoneNumberTokenAsync(TUser user, string token, string phoneNumber) public virtual async Task<bool> VerifyChangePhoneNumberTokenAsync(TUser user, string token, string phoneNumber,
CancellationToken cancellationToken = default(CancellationToken))
{ {
ThrowIfDisposed(); ThrowIfDisposed();
var securityToken = await CreateSecurityTokenAsync(user); var securityToken = await CreateSecurityTokenAsync(user, cancellationToken);
int code; int code;
if (securityToken != null && Int32.TryParse(token, out code)) if (securityToken != null && Int32.TryParse(token, out code))
{ {

View File

@ -0,0 +1,17 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Linq;
using System.Reflection;
using Xunit;
namespace Microsoft.AspNet.Identity.Test
{
public class ApiConsistencyTest : ApiConsistencyTestBase
{
protected override Assembly TargetAssembly
{
get { return typeof(IdentityOptions).Assembly; }
}
}
}

View File

@ -0,0 +1,153 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNet.Identity.Test
{
public abstract class ApiConsistencyTestBase
{
protected const BindingFlags PublicInstance
= BindingFlags.Instance | BindingFlags.Public;
//protected const BindingFlags AnyInstance
// = BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic;
//[Fact]
//public void Public_inheritable_apis_should_be_virtual()
//{
// var nonVirtualMethods
// = (from type in GetAllTypes(TargetAssembly.GetTypes())
// where type.IsVisible
// && !type.IsSealed
// && type.GetConstructors(AnyInstance).Any(c => c.IsPublic || c.IsFamily || c.IsFamilyOrAssembly)
// && type.Namespace != null
// && !type.Namespace.EndsWith(".Compiled")
// from method in type.GetMethods(PublicInstance)
// where GetBasestTypeInAssembly(method.DeclaringType) == type
// && !(method.IsVirtual && !method.IsFinal)
// select type.Name + "." + method.Name)
// .ToList();
// Assert.False(
// nonVirtualMethods.Any(),
// "\r\n-- Missing virtual APIs --\r\n" + string.Join("\r\n", nonVirtualMethods));
//}
//[Fact]
//public void Public_api_arguments_should_have_not_null_annotation()
//{
// var parametersMissingAttribute
// = (from type in GetAllTypes(TargetAssembly.GetTypes())
// where type.IsVisible && !typeof(Delegate).IsAssignableFrom(type)
// let interfaceMappings = type.GetInterfaces().Select(type.GetInterfaceMap)
// let events = type.GetEvents()
// from method in type.GetMethods(PublicInstance | BindingFlags.Static)
// .Concat<MethodBase>(type.GetConstructors())
// where GetBasestTypeInAssembly(method.DeclaringType) == type
// where type.IsInterface || !interfaceMappings.Any(im => im.TargetMethods.Contains(method))
// where !events.Any(e => e.AddMethod == method || e.RemoveMethod == method)
// from parameter in method.GetParameters()
// where !parameter.ParameterType.IsValueType
// && !parameter.GetCustomAttributes()
// .Any(
// a => a.GetType().Name == "NotNullAttribute"
// || a.GetType().Name == "CanBeNullAttribute")
// select type.Name + "." + method.Name + "[" + parameter.Name + "]")
// .ToList();
// Assert.False(
// parametersMissingAttribute.Any(),
// "\r\n-- Missing NotNull annotations --\r\n" + string.Join("\r\n", parametersMissingAttribute));
//}
[Fact]
public void Async_methods_should_have_overload_with_cancellation_token_and_end_with_async_suffix()
{
var asyncMethods
= (from type in GetAllTypes(TargetAssembly.GetTypes())
where type.IsVisible
from method in type.GetMethods(PublicInstance/* | BindingFlags.Static*/)
where GetBasestTypeInAssembly(method.DeclaringType) == type
where typeof(Task).IsAssignableFrom(method.ReturnType)
select method).ToList();
var asyncMethodsWithToken
= (from method in asyncMethods
where method.GetParameters().Any(pi => pi.ParameterType == typeof(CancellationToken))
select method).ToList();
var asyncMethodsWithoutToken
= (from method in asyncMethods
where method.GetParameters().All(pi => pi.ParameterType != typeof(CancellationToken))
select method).ToList();
var missingOverloads
= (from methodWithoutToken in asyncMethodsWithoutToken
where !asyncMethodsWithToken
.Any(methodWithToken => methodWithoutToken.Name == methodWithToken.Name
&& methodWithoutToken.ReflectedType == methodWithToken.ReflectedType)
// ReSharper disable once PossibleNullReferenceException
select methodWithoutToken.DeclaringType.Name + "." + methodWithoutToken.Name)
.Except(GetCancellationTokenExceptions())
.ToList();
Assert.False(
missingOverloads.Any(),
"\r\n-- Missing async overloads --\r\n" + string.Join("\r\n", missingOverloads));
var missingSuffixMethods
= asyncMethods
.Where(method => !method.Name.EndsWith("Async"))
.Select(method => method.DeclaringType.Name + "." + method.Name)
.Except(GetAsyncSuffixExceptions())
.ToList();
Assert.False(
missingSuffixMethods.Any(),
"\r\n-- Missing async suffix --\r\n" + string.Join("\r\n", missingSuffixMethods));
}
protected virtual IEnumerable<string> GetCancellationTokenExceptions()
{
return Enumerable.Empty<string>();
}
protected virtual IEnumerable<string> GetAsyncSuffixExceptions()
{
return Enumerable.Empty<string>();
}
protected abstract Assembly TargetAssembly { get; }
protected virtual IEnumerable<Type> GetAllTypes(IEnumerable<Type> types)
{
foreach (var type in types)
{
yield return type;
foreach (var nestedType in GetAllTypes(type.GetNestedTypes()))
{
yield return nestedType;
}
}
}
protected Type GetBasestTypeInAssembly(Type type)
{
while (type.BaseType != null
&& type.BaseType.Assembly == type.Assembly)
{
type = type.BaseType;
}
return type;
}
}
}