Prevent null refs in some simple cases in CachedExpressionCompiler
Fixes #6928
This commit is contained in:
parent
dacbb41478
commit
8e31319215
|
|
@ -242,7 +242,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal
|
|||
|
||||
try
|
||||
{
|
||||
func = CachedExpressionCompiler.Process(lambda);
|
||||
func = CachedExpressionCompiler.Process(lambda) ?? lambda.Compile();
|
||||
}
|
||||
catch (InvalidOperationException ex)
|
||||
{
|
||||
|
|
@ -259,8 +259,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal
|
|||
|
||||
public static bool IsSingleArgumentIndexer(Expression expression)
|
||||
{
|
||||
var methodExpression = expression as MethodCallExpression;
|
||||
if (methodExpression == null || methodExpression.Arguments.Count != 1)
|
||||
if (!(expression is MethodCallExpression methodExpression) || methodExpression.Arguments.Count != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -82,11 +82,17 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal
|
|||
|
||||
object modelAccessor(object container)
|
||||
{
|
||||
var compiledExpression = CachedExpressionCompiler.Process(expression);
|
||||
Debug.Assert(compiledExpression != null);
|
||||
var model = (TModel)container;
|
||||
var cachedFunc = CachedExpressionCompiler.Process(expression);
|
||||
if (cachedFunc != null)
|
||||
{
|
||||
return cachedFunc(model);
|
||||
}
|
||||
|
||||
var func = expression.Compile();
|
||||
try
|
||||
{
|
||||
return compiledExpression((TModel)container);
|
||||
return func(model);
|
||||
}
|
||||
catch (NullReferenceException)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -11,12 +11,16 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
|||
{
|
||||
internal static class CachedExpressionCompiler
|
||||
{
|
||||
// This is the entry point to the cached expression compilation system. The system
|
||||
// will try to turn the expression into an actual delegate as quickly as possible,
|
||||
// relying on cache lookups and other techniques to save time if appropriate.
|
||||
// If the provided expression is particularly obscure and the system doesn't know
|
||||
// how to handle it, we'll just compile the expression as normal.
|
||||
public static Func<TModel, TResult> Process<TModel, TResult>(
|
||||
private static readonly Expression NullExpression = Expression.Constant(value: null);
|
||||
|
||||
/// <remarks>
|
||||
/// This is the entry point to the expression compilation system. The system
|
||||
/// a) Will rewrite the expression to avoid null refs when any part of the expression tree is evaluated to null
|
||||
/// b) Attempt to cache the result, or an intermediate part of the result.
|
||||
/// If the provided expression is particularly obscure and the system doesn't know how to handle it, it will
|
||||
/// return null.
|
||||
/// </remarks>
|
||||
public static Func<TModel, object> Process<TModel, TResult>(
|
||||
Expression<Func<TModel, TResult>> expression)
|
||||
{
|
||||
if (expression == null)
|
||||
|
|
@ -29,15 +33,18 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
|||
|
||||
private static class Compiler<TModel, TResult>
|
||||
{
|
||||
private static Func<TModel, TResult> _identityFunc;
|
||||
private static Func<TModel, object> _identityFunc;
|
||||
|
||||
private static readonly ConcurrentDictionary<MemberInfo, Func<TModel, TResult>> _simpleMemberAccessCache =
|
||||
new ConcurrentDictionary<MemberInfo, Func<TModel, TResult>>();
|
||||
private static readonly ConcurrentDictionary<MemberInfo, Func<TModel, object>> _simpleMemberAccessCache =
|
||||
new ConcurrentDictionary<MemberInfo, Func<TModel, object>>();
|
||||
|
||||
private static readonly ConcurrentDictionary<MemberExpressionCacheKey, Func<TModel, object>> _chainedMemberAccessCache =
|
||||
new ConcurrentDictionary<MemberExpressionCacheKey, Func<TModel, object>>(MemberExpressionCacheKeyComparer.Instance);
|
||||
|
||||
private static readonly ConcurrentDictionary<MemberInfo, Func<object, TResult>> _constMemberAccessCache =
|
||||
new ConcurrentDictionary<MemberInfo, Func<object, TResult>>();
|
||||
|
||||
public static Func<TModel, TResult> Compile(Expression<Func<TModel, TResult>> expression)
|
||||
public static Func<TModel, object> Compile(Expression<Func<TModel, TResult>> expression)
|
||||
{
|
||||
Debug.Assert(expression != null);
|
||||
|
||||
|
|
@ -55,77 +62,120 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
|||
case MemberExpression memberExpression when memberExpression.Expression is ConstantExpression constantExpression:
|
||||
return CompileCapturedConstant(memberExpression, constantExpression);
|
||||
|
||||
// model => StaticMember
|
||||
// model => ModelType.StaticMember
|
||||
case MemberExpression memberExpression when memberExpression.Expression == null:
|
||||
return CompileFromStaticMemberAccess(expression, memberExpression);
|
||||
|
||||
// model => model.Member
|
||||
case MemberExpression memberExpression when memberExpression.Expression == expression.Parameters[0]:
|
||||
return CompileFromMemberAccess(expression, memberExpression);
|
||||
return CompileFromSimpleMemberAccess(expression, memberExpression);
|
||||
|
||||
// model => model.Member1.Member2
|
||||
case MemberExpression memberExpression when IsChainedPropertyAccessor(memberExpression):
|
||||
return CompileForChainedMemberAccess(expression, memberExpression);
|
||||
|
||||
default:
|
||||
return CompileSlow(expression);
|
||||
return null;
|
||||
}
|
||||
|
||||
bool IsChainedPropertyAccessor(MemberExpression memberExpression)
|
||||
{
|
||||
while (memberExpression.Expression != null)
|
||||
{
|
||||
if (memberExpression.Expression is MemberExpression leftExpression)
|
||||
{
|
||||
memberExpression = leftExpression;
|
||||
continue;
|
||||
}
|
||||
else if (memberExpression.Expression == expression.Parameters[0])
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
private static Func<TModel, TResult> CompileFromConstLookup(
|
||||
private static Func<TModel, object> CompileFromConstLookup(
|
||||
ConstantExpression constantExpression)
|
||||
{
|
||||
// model => {const}
|
||||
var constantValue = (TResult)constantExpression.Value;
|
||||
var constantValue = constantExpression.Value;
|
||||
return _ => constantValue;
|
||||
}
|
||||
|
||||
private static Func<TModel, TResult> CompileFromIdentityFunc(
|
||||
private static Func<TModel, object> CompileFromIdentityFunc(
|
||||
Expression<Func<TModel, TResult>> expression)
|
||||
{
|
||||
// model => model
|
||||
|
||||
// Don't need to lock, as all identity funcs are identical.
|
||||
if (_identityFunc == null)
|
||||
{
|
||||
_identityFunc = expression.Compile();
|
||||
var identityFuncCore = expression.Compile();
|
||||
_identityFunc = model => identityFuncCore(model);
|
||||
}
|
||||
|
||||
return _identityFunc;
|
||||
}
|
||||
|
||||
private static Func<TModel, TResult> CompileFromMemberAccess(
|
||||
private static Func<TModel, object> CompileFromStaticMemberAccess(
|
||||
Expression<Func<TModel, TResult>> expression,
|
||||
MemberExpression memberExpression)
|
||||
{
|
||||
// model => model.Member
|
||||
// model => ModelType.StaticMember
|
||||
if (_simpleMemberAccessCache.TryGetValue(memberExpression.Member, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
result = expression.Compile();
|
||||
var func = expression.Compile();
|
||||
result = model => func(model);
|
||||
result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, result);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Func<TModel, TResult> CompileFromStaticMemberAccess(
|
||||
private static Func<TModel, object> CompileFromSimpleMemberAccess(
|
||||
Expression<Func<TModel, TResult>> expression,
|
||||
MemberExpression memberExpression)
|
||||
{
|
||||
// model => model.StaticMember
|
||||
// Input: () => m.Member
|
||||
// Output: () => (m == null) ? null : m.Member
|
||||
if (_simpleMemberAccessCache.TryGetValue(memberExpression.Member, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
result = expression.Compile();
|
||||
result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, result);
|
||||
result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, Rewrite(expression, memberExpression));
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Func<TModel, TResult> CompileCapturedConstant(MemberExpression memberExpression, ConstantExpression constantExpression)
|
||||
private static Func<TModel, object> CompileForChainedMemberAccess(
|
||||
Expression<Func<TModel, TResult>> expression,
|
||||
MemberExpression memberExpression)
|
||||
{
|
||||
// model => {const}.Member (captured local variable)
|
||||
// Input: () => m.Member1.Member2
|
||||
// Output: () => (m == null || m.Member1 == null) ? null : m.Member1.Member2
|
||||
var key = new MemberExpressionCacheKey(typeof(TModel), memberExpression);
|
||||
if (_chainedMemberAccessCache.TryGetValue(key, out var result))
|
||||
{
|
||||
return result;
|
||||
}
|
||||
|
||||
var cacheableKey = key.MakeCacheable();
|
||||
result = _chainedMemberAccessCache.GetOrAdd(cacheableKey, Rewrite(expression, memberExpression));
|
||||
return result;
|
||||
}
|
||||
|
||||
private static Func<TModel, object> CompileCapturedConstant(MemberExpression memberExpression, ConstantExpression constantExpression)
|
||||
{
|
||||
// model => {const} (captured local variable)
|
||||
if (!_constMemberAccessCache.TryGetValue(memberExpression.Member, out var result))
|
||||
{
|
||||
// rewrite as capturedLocal => ((TDeclaringType)capturedLocal).Member
|
||||
// rewrite as capturedLocal => ((TDeclaringType)capturedLocal)
|
||||
var parameterExpression = Expression.Parameter(typeof(object), "capturedLocal");
|
||||
var castExpression =
|
||||
Expression.Convert(parameterExpression, memberExpression.Member.DeclaringType);
|
||||
|
|
@ -142,10 +192,74 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
|||
return _ => result(capturedLocal);
|
||||
}
|
||||
|
||||
private static Func<TModel, TResult> CompileSlow(Expression<Func<TModel, TResult>> expression)
|
||||
private static Func<TModel, object> Rewrite(
|
||||
Expression<Func<TModel, TResult>> expression,
|
||||
MemberExpression memberExpression)
|
||||
{
|
||||
// fallback compilation system - just compile the expression directly
|
||||
return expression.Compile();
|
||||
Expression combinedNullTest = null;
|
||||
var currentExpression = memberExpression;
|
||||
|
||||
while (currentExpression != null)
|
||||
{
|
||||
AddNullCheck(currentExpression.Expression, ref combinedNullTest);
|
||||
|
||||
if (currentExpression.Expression is MemberExpression leftExpression)
|
||||
{
|
||||
currentExpression = leftExpression;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
var body = expression.Body;
|
||||
|
||||
// Cast the entire expression to object in case Member is a value type. This is required for us to be able to
|
||||
// express the null conditional statement m == null ? null : (object)m.IntValue
|
||||
if (body.Type.IsValueType)
|
||||
{
|
||||
body = Expression.Convert(body, typeof(object));
|
||||
}
|
||||
|
||||
if (combinedNullTest != null)
|
||||
{
|
||||
Debug.Assert(combinedNullTest.Type == typeof(bool));
|
||||
body = Expression.Condition(
|
||||
combinedNullTest,
|
||||
Expression.Constant(value: null, body.Type),
|
||||
body);
|
||||
}
|
||||
|
||||
var rewrittenExpression = Expression.Lambda<Func<TModel, object>>(body, expression.Parameters);
|
||||
return rewrittenExpression.Compile();
|
||||
}
|
||||
|
||||
private static void AddNullCheck(Expression invokingExpression, ref Expression combinedNullTest)
|
||||
{
|
||||
var type = invokingExpression.Type;
|
||||
var isNullableValueType = type.IsValueType && Nullable.GetUnderlyingType(type) != null;
|
||||
if (type.IsValueType && !isNullableValueType)
|
||||
{
|
||||
// struct.Member where struct is not nullable. Do nothing.
|
||||
return;
|
||||
}
|
||||
|
||||
// NullableStruct.Member or Class.Member
|
||||
// type is Nullable ? (value == null) : object.ReferenceEquals(value, null)
|
||||
var nullTest = isNullableValueType ?
|
||||
Expression.Equal(invokingExpression, NullExpression) :
|
||||
Expression.ReferenceEqual(invokingExpression, NullExpression);
|
||||
|
||||
if (combinedNullTest == null)
|
||||
{
|
||||
combinedNullTest = nullTest;
|
||||
}
|
||||
else
|
||||
{
|
||||
// m == null || m.Member == null
|
||||
combinedNullTest = Expression.OrElse(nullTest, combinedNullTest);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,90 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq.Expressions;
|
||||
using System.Reflection;
|
||||
|
||||
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
||||
{
|
||||
internal struct MemberExpressionCacheKey
|
||||
{
|
||||
public MemberExpressionCacheKey(Type modelType, MemberExpression memberExpression)
|
||||
{
|
||||
ModelType = modelType;
|
||||
MemberExpression = memberExpression;
|
||||
Members = null;
|
||||
}
|
||||
|
||||
public MemberExpressionCacheKey(Type modelType, MemberInfo[] members)
|
||||
{
|
||||
ModelType = modelType;
|
||||
Members = members;
|
||||
MemberExpression = null;
|
||||
}
|
||||
|
||||
// We want to avoid caching a MemberExpression since it has references to other instances in the expression tree.
|
||||
// We instead store it as a series of MemberInfo items that comprise of the MemberExpression going from right-most
|
||||
// expression to left.
|
||||
public MemberExpressionCacheKey MakeCacheable()
|
||||
{
|
||||
var members = new List<MemberInfo>();
|
||||
foreach (var member in this)
|
||||
{
|
||||
members.Add(member);
|
||||
}
|
||||
|
||||
return new MemberExpressionCacheKey(ModelType, members.ToArray());
|
||||
}
|
||||
|
||||
public MemberExpression MemberExpression { get; }
|
||||
|
||||
public Type ModelType { get; }
|
||||
|
||||
public MemberInfo[] Members { get; }
|
||||
|
||||
public Enumerator GetEnumerator() => new Enumerator(ref this);
|
||||
|
||||
public struct Enumerator
|
||||
{
|
||||
private readonly MemberInfo[] _members;
|
||||
private int _index;
|
||||
private MemberExpression _memberExpression;
|
||||
|
||||
public Enumerator(ref MemberExpressionCacheKey key)
|
||||
{
|
||||
Current = null;
|
||||
_members = key.Members;
|
||||
_memberExpression = key.MemberExpression;
|
||||
_index = -1;
|
||||
}
|
||||
|
||||
public MemberInfo Current { get; private set; }
|
||||
|
||||
public bool MoveNext()
|
||||
{
|
||||
if (_members != null)
|
||||
{
|
||||
_index++;
|
||||
if (_index >= _members.Length)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Current = _members[_index];
|
||||
return true;
|
||||
}
|
||||
|
||||
if (_memberExpression == null)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Current = _memberExpression.Member;
|
||||
_memberExpression = _memberExpression.Expression as MemberExpression;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
||||
{
|
||||
internal class MemberExpressionCacheKeyComparer : IEqualityComparer<MemberExpressionCacheKey>
|
||||
{
|
||||
public static readonly MemberExpressionCacheKeyComparer Instance = new MemberExpressionCacheKeyComparer();
|
||||
|
||||
public bool Equals(MemberExpressionCacheKey x, MemberExpressionCacheKey y)
|
||||
{
|
||||
if (x.ModelType != y.ModelType)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var xEnumerator = x.GetEnumerator();
|
||||
var yEnumerator = y.GetEnumerator();
|
||||
|
||||
while (xEnumerator.MoveNext())
|
||||
{
|
||||
if (!yEnumerator.MoveNext())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Current is a MemberInfo instance which has a good comparer.
|
||||
if (xEnumerator.Current != yEnumerator.Current)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return !yEnumerator.MoveNext();
|
||||
}
|
||||
|
||||
public int GetHashCode(MemberExpressionCacheKey obj)
|
||||
{
|
||||
var hashCodeCombiner = new HashCodeCombiner();
|
||||
hashCodeCombiner.Add(obj.ModelType);
|
||||
|
||||
foreach (var member in obj)
|
||||
{
|
||||
hashCodeCombiner.Add(member);
|
||||
}
|
||||
|
||||
return hashCodeCombiner.CombinedHash;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,216 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Linq.Expressions;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
||||
{
|
||||
public class MemberExpressionCacheKeyComparerTest
|
||||
{
|
||||
private readonly MemberExpressionCacheKeyComparer Comparer = MemberExpressionCacheKeyComparer.Instance;
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsTrue_ForTheSameExpression()
|
||||
{
|
||||
// Arrange
|
||||
var key = GetKey(m => m.Value);
|
||||
|
||||
// Act & Assert
|
||||
VerifyEquals(key, key);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsTrue_ForDifferentInstances_OfSameExpression()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.Value);
|
||||
var key2 = GetKey(m => m.Value);
|
||||
|
||||
// Act & Assert
|
||||
VerifyEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithReferenceTypes()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.TestModel2.Name);
|
||||
var key2 = GetKey(m => m.TestModel2.Name);
|
||||
|
||||
// Act & Assert
|
||||
VerifyEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithNullableValueTypes()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.NullableDateTime.Value.TimeOfDay);
|
||||
var key2 = GetKey(m => m.NullableDateTime.Value.TimeOfDay);
|
||||
|
||||
// Act & Assert
|
||||
VerifyEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithValueTypes()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.DateTime.Year);
|
||||
var key2 = GetKey(m => m.DateTime.Year);
|
||||
|
||||
// Act & Assert
|
||||
VerifyEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_ForDifferentExpression()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.Value);
|
||||
var key2 = GetKey(m => m.TestModel2.Name);
|
||||
|
||||
// Act & Assert
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_ForChainedExpressions()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.TestModel2.Id);
|
||||
var key2 = GetKey(m => m.TestModel2.Name);
|
||||
|
||||
// Act & Assert
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_ForChainedExpressions_WithValueTypes()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.DateTime.Ticks);
|
||||
var key2 = GetKey(m => m.DateTime.Year);
|
||||
|
||||
// Act & Assert
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_ForChainedExpressions_DifferingByNullable()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.DateTime.Ticks);
|
||||
var key2 = GetKey(m => m.NullableDateTime.Value.Ticks);
|
||||
|
||||
// Act & Assert
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_WhenOneExpressionIsSubsetOfOther()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.TestModel2);
|
||||
var key2 = GetKey(m => m.TestModel2.Name);
|
||||
|
||||
// Act & Assert
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughNullableProperty()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey(m => m.NullableDateTime.Value.Year);
|
||||
var key2 = GetKey(m => m.DateTime.Year);
|
||||
|
||||
// Act
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughDifferentModels()
|
||||
{
|
||||
// Arrange
|
||||
var key1 = GetKey<TestModel2, int>(m => m.Id);
|
||||
var key2 = GetKey(m => m.TestModel2.Id);
|
||||
|
||||
// Act
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughConstantExpression()
|
||||
{
|
||||
// Arrange
|
||||
var testModel = new TestModel2 { Id = 1 };
|
||||
var key1 = GetKey(m => testModel.Id);
|
||||
var key2 = GetKey<TestModel2, int>(m => m.Id);
|
||||
|
||||
// Act
|
||||
VerifyNotEquals(key1, key2);
|
||||
}
|
||||
|
||||
private void VerifyEquals(MemberExpressionCacheKey key1, MemberExpressionCacheKey key2)
|
||||
{
|
||||
Assert.Equal(key1, key2, Comparer);
|
||||
|
||||
var hashCode1 = Comparer.GetHashCode(key1);
|
||||
var hashCode2 = Comparer.GetHashCode(key2);
|
||||
Assert.Equal(hashCode1, hashCode2);
|
||||
|
||||
var cachedKey1 = key1.MakeCacheable();
|
||||
|
||||
Assert.Equal(key1, cachedKey1, Comparer);
|
||||
Assert.Equal(cachedKey1, key1, Comparer);
|
||||
|
||||
var cachedKeyHashCode1 = Comparer.GetHashCode(cachedKey1);
|
||||
Assert.Equal(hashCode1, cachedKeyHashCode1);
|
||||
}
|
||||
|
||||
private void VerifyNotEquals(MemberExpressionCacheKey key1, MemberExpressionCacheKey key2)
|
||||
{
|
||||
var hashCode1 = Comparer.GetHashCode(key1);
|
||||
var hashCode2 = Comparer.GetHashCode(key2);
|
||||
|
||||
Assert.NotEqual(hashCode1, hashCode2);
|
||||
Assert.NotEqual(key1, key2, Comparer);
|
||||
|
||||
var cachedKey1 = key1.MakeCacheable();
|
||||
Assert.NotEqual(key2, cachedKey1, Comparer);
|
||||
|
||||
var cachedKeyHashCode1 = Comparer.GetHashCode(cachedKey1);
|
||||
Assert.NotEqual(cachedKeyHashCode1, hashCode2);
|
||||
}
|
||||
|
||||
private static MemberExpressionCacheKey GetKey<TResult>(Expression<Func<TestModel, TResult>> expresssion)
|
||||
=> GetKey<TestModel, TResult>(expresssion);
|
||||
|
||||
private static MemberExpressionCacheKey GetKey<TModel, TResult>(Expression<Func<TModel, TResult>> expresssion)
|
||||
{
|
||||
var memberExpression = Assert.IsAssignableFrom<MemberExpression>(expresssion.Body);
|
||||
return new MemberExpressionCacheKey(typeof(TModel), memberExpression);
|
||||
}
|
||||
|
||||
public class TestModel
|
||||
{
|
||||
public string Value { get; set; }
|
||||
|
||||
public TestModel2 TestModel2 { get; set; }
|
||||
|
||||
public DateTime DateTime { get; set; }
|
||||
|
||||
public DateTime? NullableDateTime { get; set; }
|
||||
}
|
||||
|
||||
public class TestModel2
|
||||
{
|
||||
public string Name { get; set; }
|
||||
|
||||
public int Id { get; set; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,111 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq.Expressions;
|
||||
using System.Reflection;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
|
||||
{
|
||||
public class MemberExpressionCacheKeyTest
|
||||
{
|
||||
[Fact]
|
||||
public void GetEnumerator_ReturnsMembers()
|
||||
{
|
||||
// Arrange
|
||||
var expected = new[]
|
||||
{
|
||||
typeof(TestModel3).GetProperty(nameof(TestModel3.Value)),
|
||||
typeof(TestModel2).GetProperty(nameof(TestModel2.TestModel3)),
|
||||
typeof(TestModel).GetProperty(nameof(TestModel.TestModel2)),
|
||||
};
|
||||
|
||||
var key = GetKey(m => m.TestModel2.TestModel3.Value);
|
||||
|
||||
// Act
|
||||
var actual = GetMembers(key);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(expected, actual);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetEnumerator_WithNullableType_ReturnsMembers()
|
||||
{
|
||||
// Arrange
|
||||
var expected = new[]
|
||||
{
|
||||
typeof(DateTime).GetProperty(nameof(DateTime.Ticks)),
|
||||
typeof(DateTime?).GetProperty(nameof(Nullable<DateTime>.Value)),
|
||||
typeof(TestModel).GetProperty(nameof(TestModel.NullableDateTime)),
|
||||
};
|
||||
|
||||
var key = GetKey(m => m.NullableDateTime.Value.Ticks);
|
||||
|
||||
// Act
|
||||
var actual = GetMembers(key);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(expected, actual);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void GetEnumerator_WithValueType_ReturnsMembers()
|
||||
{
|
||||
// Arrange
|
||||
var expected = new[]
|
||||
{
|
||||
typeof(DateTime).GetProperty(nameof(DateTime.Ticks)),
|
||||
typeof(TestModel).GetProperty(nameof(TestModel.DateTime)),
|
||||
};
|
||||
|
||||
var key = GetKey(m => m.DateTime.Ticks);
|
||||
|
||||
// Act
|
||||
var actual = GetMembers(key);
|
||||
|
||||
// Assert
|
||||
Assert.Equal(expected, actual);
|
||||
}
|
||||
|
||||
private static MemberExpressionCacheKey GetKey<TResult>(Expression<Func<TestModel, TResult>> expresssion)
|
||||
{
|
||||
var memberExpression = Assert.IsAssignableFrom<MemberExpression>(expresssion.Body);
|
||||
return new MemberExpressionCacheKey(typeof(TestModel), memberExpression);
|
||||
}
|
||||
|
||||
private static IList<MemberInfo> GetMembers(MemberExpressionCacheKey key)
|
||||
{
|
||||
var members = new List<MemberInfo>();
|
||||
foreach (var member in key)
|
||||
{
|
||||
members.Add(member);
|
||||
}
|
||||
|
||||
return members;
|
||||
}
|
||||
|
||||
public class TestModel
|
||||
{
|
||||
public TestModel2 TestModel2 { get; set; }
|
||||
|
||||
public DateTime DateTime { get; set; }
|
||||
|
||||
public DateTime? NullableDateTime { get; set; }
|
||||
}
|
||||
|
||||
public class TestModel2
|
||||
{
|
||||
public string Name { get; set; }
|
||||
|
||||
public TestModel3 TestModel3 { get; set; }
|
||||
}
|
||||
|
||||
public class TestModel3
|
||||
{
|
||||
public string Value { get; set; }
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue