diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionHelper.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionHelper.cs index 90df635a9c..4de9146bb2 100644 --- a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionHelper.cs +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionHelper.cs @@ -242,7 +242,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal try { - func = CachedExpressionCompiler.Process(lambda); + func = CachedExpressionCompiler.Process(lambda) ?? lambda.Compile(); } catch (InvalidOperationException ex) { @@ -259,8 +259,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal public static bool IsSingleArgumentIndexer(Expression expression) { - var methodExpression = expression as MethodCallExpression; - if (methodExpression == null || methodExpression.Arguments.Count != 1) + if (!(expression is MethodCallExpression methodExpression) || methodExpression.Arguments.Count != 1) { return false; } diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionMetadataProvider.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionMetadataProvider.cs index 56d7774ad7..0f8b11eacb 100644 --- a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionMetadataProvider.cs +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/ExpressionMetadataProvider.cs @@ -82,11 +82,17 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal object modelAccessor(object container) { - var compiledExpression = CachedExpressionCompiler.Process(expression); - Debug.Assert(compiledExpression != null); + var model = (TModel)container; + var cachedFunc = CachedExpressionCompiler.Process(expression); + if (cachedFunc != null) + { + return cachedFunc(model); + } + + var func = expression.Compile(); try { - return compiledExpression((TModel)container); + return func(model); } catch (NullReferenceException) { diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/CachedExpressionCompiler.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/CachedExpressionCompiler.cs index f41448d8f3..03c0b31da9 100644 --- a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/CachedExpressionCompiler.cs +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/CachedExpressionCompiler.cs @@ -11,12 +11,16 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures { internal static class CachedExpressionCompiler { - // This is the entry point to the cached expression compilation system. The system - // will try to turn the expression into an actual delegate as quickly as possible, - // relying on cache lookups and other techniques to save time if appropriate. - // If the provided expression is particularly obscure and the system doesn't know - // how to handle it, we'll just compile the expression as normal. - public static Func Process( + private static readonly Expression NullExpression = Expression.Constant(value: null); + + /// + /// This is the entry point to the expression compilation system. The system + /// a) Will rewrite the expression to avoid null refs when any part of the expression tree is evaluated to null + /// b) Attempt to cache the result, or an intermediate part of the result. + /// If the provided expression is particularly obscure and the system doesn't know how to handle it, it will + /// return null. + /// + public static Func Process( Expression> expression) { if (expression == null) @@ -29,15 +33,18 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures private static class Compiler { - private static Func _identityFunc; + private static Func _identityFunc; - private static readonly ConcurrentDictionary> _simpleMemberAccessCache = - new ConcurrentDictionary>(); + private static readonly ConcurrentDictionary> _simpleMemberAccessCache = + new ConcurrentDictionary>(); + + private static readonly ConcurrentDictionary> _chainedMemberAccessCache = + new ConcurrentDictionary>(MemberExpressionCacheKeyComparer.Instance); private static readonly ConcurrentDictionary> _constMemberAccessCache = new ConcurrentDictionary>(); - public static Func Compile(Expression> expression) + public static Func Compile(Expression> expression) { Debug.Assert(expression != null); @@ -55,77 +62,120 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures case MemberExpression memberExpression when memberExpression.Expression is ConstantExpression constantExpression: return CompileCapturedConstant(memberExpression, constantExpression); - // model => StaticMember + // model => ModelType.StaticMember case MemberExpression memberExpression when memberExpression.Expression == null: return CompileFromStaticMemberAccess(expression, memberExpression); // model => model.Member case MemberExpression memberExpression when memberExpression.Expression == expression.Parameters[0]: - return CompileFromMemberAccess(expression, memberExpression); + return CompileFromSimpleMemberAccess(expression, memberExpression); + + // model => model.Member1.Member2 + case MemberExpression memberExpression when IsChainedPropertyAccessor(memberExpression): + return CompileForChainedMemberAccess(expression, memberExpression); default: - return CompileSlow(expression); + return null; + } + + bool IsChainedPropertyAccessor(MemberExpression memberExpression) + { + while (memberExpression.Expression != null) + { + if (memberExpression.Expression is MemberExpression leftExpression) + { + memberExpression = leftExpression; + continue; + } + else if (memberExpression.Expression == expression.Parameters[0]) + { + return true; + } + + break; + } + + return false; } } - private static Func CompileFromConstLookup( + private static Func CompileFromConstLookup( ConstantExpression constantExpression) { // model => {const} - var constantValue = (TResult)constantExpression.Value; + var constantValue = constantExpression.Value; return _ => constantValue; } - private static Func CompileFromIdentityFunc( + private static Func CompileFromIdentityFunc( Expression> expression) { // model => model - // Don't need to lock, as all identity funcs are identical. if (_identityFunc == null) { - _identityFunc = expression.Compile(); + var identityFuncCore = expression.Compile(); + _identityFunc = model => identityFuncCore(model); } return _identityFunc; } - private static Func CompileFromMemberAccess( + private static Func CompileFromStaticMemberAccess( Expression> expression, MemberExpression memberExpression) { - // model => model.Member + // model => ModelType.StaticMember if (_simpleMemberAccessCache.TryGetValue(memberExpression.Member, out var result)) { return result; } - result = expression.Compile(); + var func = expression.Compile(); + result = model => func(model); result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, result); + return result; } - private static Func CompileFromStaticMemberAccess( + private static Func CompileFromSimpleMemberAccess( Expression> expression, MemberExpression memberExpression) { - // model => model.StaticMember + // Input: () => m.Member + // Output: () => (m == null) ? null : m.Member if (_simpleMemberAccessCache.TryGetValue(memberExpression.Member, out var result)) { return result; } - result = expression.Compile(); - result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, result); + result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, Rewrite(expression, memberExpression)); return result; } - private static Func CompileCapturedConstant(MemberExpression memberExpression, ConstantExpression constantExpression) + private static Func CompileForChainedMemberAccess( + Expression> expression, + MemberExpression memberExpression) { - // model => {const}.Member (captured local variable) + // Input: () => m.Member1.Member2 + // Output: () => (m == null || m.Member1 == null) ? null : m.Member1.Member2 + var key = new MemberExpressionCacheKey(typeof(TModel), memberExpression); + if (_chainedMemberAccessCache.TryGetValue(key, out var result)) + { + return result; + } + + var cacheableKey = key.MakeCacheable(); + result = _chainedMemberAccessCache.GetOrAdd(cacheableKey, Rewrite(expression, memberExpression)); + return result; + } + + private static Func CompileCapturedConstant(MemberExpression memberExpression, ConstantExpression constantExpression) + { + // model => {const} (captured local variable) if (!_constMemberAccessCache.TryGetValue(memberExpression.Member, out var result)) { - // rewrite as capturedLocal => ((TDeclaringType)capturedLocal).Member + // rewrite as capturedLocal => ((TDeclaringType)capturedLocal) var parameterExpression = Expression.Parameter(typeof(object), "capturedLocal"); var castExpression = Expression.Convert(parameterExpression, memberExpression.Member.DeclaringType); @@ -142,10 +192,74 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures return _ => result(capturedLocal); } - private static Func CompileSlow(Expression> expression) + private static Func Rewrite( + Expression> expression, + MemberExpression memberExpression) { - // fallback compilation system - just compile the expression directly - return expression.Compile(); + Expression combinedNullTest = null; + var currentExpression = memberExpression; + + while (currentExpression != null) + { + AddNullCheck(currentExpression.Expression, ref combinedNullTest); + + if (currentExpression.Expression is MemberExpression leftExpression) + { + currentExpression = leftExpression; + } + else + { + break; + } + } + + var body = expression.Body; + + // Cast the entire expression to object in case Member is a value type. This is required for us to be able to + // express the null conditional statement m == null ? null : (object)m.IntValue + if (body.Type.IsValueType) + { + body = Expression.Convert(body, typeof(object)); + } + + if (combinedNullTest != null) + { + Debug.Assert(combinedNullTest.Type == typeof(bool)); + body = Expression.Condition( + combinedNullTest, + Expression.Constant(value: null, body.Type), + body); + } + + var rewrittenExpression = Expression.Lambda>(body, expression.Parameters); + return rewrittenExpression.Compile(); + } + + private static void AddNullCheck(Expression invokingExpression, ref Expression combinedNullTest) + { + var type = invokingExpression.Type; + var isNullableValueType = type.IsValueType && Nullable.GetUnderlyingType(type) != null; + if (type.IsValueType && !isNullableValueType) + { + // struct.Member where struct is not nullable. Do nothing. + return; + } + + // NullableStruct.Member or Class.Member + // type is Nullable ? (value == null) : object.ReferenceEquals(value, null) + var nullTest = isNullableValueType ? + Expression.Equal(invokingExpression, NullExpression) : + Expression.ReferenceEqual(invokingExpression, NullExpression); + + if (combinedNullTest == null) + { + combinedNullTest = nullTest; + } + else + { + // m == null || m.Member == null + combinedNullTest = Expression.OrElse(nullTest, combinedNullTest); + } } } } diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/MemberExpressionCacheKey.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/MemberExpressionCacheKey.cs new file mode 100644 index 0000000000..1cbc80c38a --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/MemberExpressionCacheKey.cs @@ -0,0 +1,90 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; + +namespace Microsoft.AspNetCore.Mvc.ViewFeatures +{ + internal struct MemberExpressionCacheKey + { + public MemberExpressionCacheKey(Type modelType, MemberExpression memberExpression) + { + ModelType = modelType; + MemberExpression = memberExpression; + Members = null; + } + + public MemberExpressionCacheKey(Type modelType, MemberInfo[] members) + { + ModelType = modelType; + Members = members; + MemberExpression = null; + } + + // We want to avoid caching a MemberExpression since it has references to other instances in the expression tree. + // We instead store it as a series of MemberInfo items that comprise of the MemberExpression going from right-most + // expression to left. + public MemberExpressionCacheKey MakeCacheable() + { + var members = new List(); + foreach (var member in this) + { + members.Add(member); + } + + return new MemberExpressionCacheKey(ModelType, members.ToArray()); + } + + public MemberExpression MemberExpression { get; } + + public Type ModelType { get; } + + public MemberInfo[] Members { get; } + + public Enumerator GetEnumerator() => new Enumerator(ref this); + + public struct Enumerator + { + private readonly MemberInfo[] _members; + private int _index; + private MemberExpression _memberExpression; + + public Enumerator(ref MemberExpressionCacheKey key) + { + Current = null; + _members = key.Members; + _memberExpression = key.MemberExpression; + _index = -1; + } + + public MemberInfo Current { get; private set; } + + public bool MoveNext() + { + if (_members != null) + { + _index++; + if (_index >= _members.Length) + { + return false; + } + + Current = _members[_index]; + return true; + } + + if (_memberExpression == null) + { + return false; + } + + Current = _memberExpression.Member; + _memberExpression = _memberExpression.Expression as MemberExpression; + return true; + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/MemberExpressionCacheKeyComparer.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/MemberExpressionCacheKeyComparer.cs new file mode 100644 index 0000000000..5911611c89 --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/MemberExpressionCacheKeyComparer.cs @@ -0,0 +1,53 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using Microsoft.Extensions.Internal; + +namespace Microsoft.AspNetCore.Mvc.ViewFeatures +{ + internal class MemberExpressionCacheKeyComparer : IEqualityComparer + { + public static readonly MemberExpressionCacheKeyComparer Instance = new MemberExpressionCacheKeyComparer(); + + public bool Equals(MemberExpressionCacheKey x, MemberExpressionCacheKey y) + { + if (x.ModelType != y.ModelType) + { + return false; + } + + var xEnumerator = x.GetEnumerator(); + var yEnumerator = y.GetEnumerator(); + + while (xEnumerator.MoveNext()) + { + if (!yEnumerator.MoveNext()) + { + return false; + } + + // Current is a MemberInfo instance which has a good comparer. + if (xEnumerator.Current != yEnumerator.Current) + { + return false; + } + } + + return !yEnumerator.MoveNext(); + } + + public int GetHashCode(MemberExpressionCacheKey obj) + { + var hashCodeCombiner = new HashCodeCombiner(); + hashCodeCombiner.Add(obj.ModelType); + + foreach (var member in obj) + { + hashCodeCombiner.Add(member); + } + + return hashCodeCombiner.CombinedHash; + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/MemberExpressionCacheKeyComparerTest.cs b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/MemberExpressionCacheKeyComparerTest.cs new file mode 100644 index 0000000000..164b899151 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/MemberExpressionCacheKeyComparerTest.cs @@ -0,0 +1,216 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq.Expressions; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.ViewFeatures +{ + public class MemberExpressionCacheKeyComparerTest + { + private readonly MemberExpressionCacheKeyComparer Comparer = MemberExpressionCacheKeyComparer.Instance; + + [Fact] + public void Equals_ReturnsTrue_ForTheSameExpression() + { + // Arrange + var key = GetKey(m => m.Value); + + // Act & Assert + VerifyEquals(key, key); + } + + [Fact] + public void Equals_ReturnsTrue_ForDifferentInstances_OfSameExpression() + { + // Arrange + var key1 = GetKey(m => m.Value); + var key2 = GetKey(m => m.Value); + + // Act & Assert + VerifyEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithReferenceTypes() + { + // Arrange + var key1 = GetKey(m => m.TestModel2.Name); + var key2 = GetKey(m => m.TestModel2.Name); + + // Act & Assert + VerifyEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithNullableValueTypes() + { + // Arrange + var key1 = GetKey(m => m.NullableDateTime.Value.TimeOfDay); + var key2 = GetKey(m => m.NullableDateTime.Value.TimeOfDay); + + // Act & Assert + VerifyEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithValueTypes() + { + // Arrange + var key1 = GetKey(m => m.DateTime.Year); + var key2 = GetKey(m => m.DateTime.Year); + + // Act & Assert + VerifyEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_ForDifferentExpression() + { + // Arrange + var key1 = GetKey(m => m.Value); + var key2 = GetKey(m => m.TestModel2.Name); + + // Act & Assert + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_ForChainedExpressions() + { + // Arrange + var key1 = GetKey(m => m.TestModel2.Id); + var key2 = GetKey(m => m.TestModel2.Name); + + // Act & Assert + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_ForChainedExpressions_WithValueTypes() + { + // Arrange + var key1 = GetKey(m => m.DateTime.Ticks); + var key2 = GetKey(m => m.DateTime.Year); + + // Act & Assert + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_ForChainedExpressions_DifferingByNullable() + { + // Arrange + var key1 = GetKey(m => m.DateTime.Ticks); + var key2 = GetKey(m => m.NullableDateTime.Value.Ticks); + + // Act & Assert + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_WhenOneExpressionIsSubsetOfOther() + { + // Arrange + var key1 = GetKey(m => m.TestModel2); + var key2 = GetKey(m => m.TestModel2.Name); + + // Act & Assert + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughNullableProperty() + { + // Arrange + var key1 = GetKey(m => m.NullableDateTime.Value.Year); + var key2 = GetKey(m => m.DateTime.Year); + + // Act + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughDifferentModels() + { + // Arrange + var key1 = GetKey(m => m.Id); + var key2 = GetKey(m => m.TestModel2.Id); + + // Act + VerifyNotEquals(key1, key2); + } + + [Fact] + public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughConstantExpression() + { + // Arrange + var testModel = new TestModel2 { Id = 1 }; + var key1 = GetKey(m => testModel.Id); + var key2 = GetKey(m => m.Id); + + // Act + VerifyNotEquals(key1, key2); + } + + private void VerifyEquals(MemberExpressionCacheKey key1, MemberExpressionCacheKey key2) + { + Assert.Equal(key1, key2, Comparer); + + var hashCode1 = Comparer.GetHashCode(key1); + var hashCode2 = Comparer.GetHashCode(key2); + Assert.Equal(hashCode1, hashCode2); + + var cachedKey1 = key1.MakeCacheable(); + + Assert.Equal(key1, cachedKey1, Comparer); + Assert.Equal(cachedKey1, key1, Comparer); + + var cachedKeyHashCode1 = Comparer.GetHashCode(cachedKey1); + Assert.Equal(hashCode1, cachedKeyHashCode1); + } + + private void VerifyNotEquals(MemberExpressionCacheKey key1, MemberExpressionCacheKey key2) + { + var hashCode1 = Comparer.GetHashCode(key1); + var hashCode2 = Comparer.GetHashCode(key2); + + Assert.NotEqual(hashCode1, hashCode2); + Assert.NotEqual(key1, key2, Comparer); + + var cachedKey1 = key1.MakeCacheable(); + Assert.NotEqual(key2, cachedKey1, Comparer); + + var cachedKeyHashCode1 = Comparer.GetHashCode(cachedKey1); + Assert.NotEqual(cachedKeyHashCode1, hashCode2); + } + + private static MemberExpressionCacheKey GetKey(Expression> expresssion) + => GetKey(expresssion); + + private static MemberExpressionCacheKey GetKey(Expression> expresssion) + { + var memberExpression = Assert.IsAssignableFrom(expresssion.Body); + return new MemberExpressionCacheKey(typeof(TModel), memberExpression); + } + + public class TestModel + { + public string Value { get; set; } + + public TestModel2 TestModel2 { get; set; } + + public DateTime DateTime { get; set; } + + public DateTime? NullableDateTime { get; set; } + } + + public class TestModel2 + { + public string Name { get; set; } + + public int Id { get; set; } + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/MemberExpressionCacheKeyTest.cs b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/MemberExpressionCacheKeyTest.cs new file mode 100644 index 0000000000..5de4c308c3 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/MemberExpressionCacheKeyTest.cs @@ -0,0 +1,111 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq.Expressions; +using System.Reflection; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.ViewFeatures +{ + public class MemberExpressionCacheKeyTest + { + [Fact] + public void GetEnumerator_ReturnsMembers() + { + // Arrange + var expected = new[] + { + typeof(TestModel3).GetProperty(nameof(TestModel3.Value)), + typeof(TestModel2).GetProperty(nameof(TestModel2.TestModel3)), + typeof(TestModel).GetProperty(nameof(TestModel.TestModel2)), + }; + + var key = GetKey(m => m.TestModel2.TestModel3.Value); + + // Act + var actual = GetMembers(key); + + // Assert + Assert.Equal(expected, actual); + } + + [Fact] + public void GetEnumerator_WithNullableType_ReturnsMembers() + { + // Arrange + var expected = new[] + { + typeof(DateTime).GetProperty(nameof(DateTime.Ticks)), + typeof(DateTime?).GetProperty(nameof(Nullable.Value)), + typeof(TestModel).GetProperty(nameof(TestModel.NullableDateTime)), + }; + + var key = GetKey(m => m.NullableDateTime.Value.Ticks); + + // Act + var actual = GetMembers(key); + + // Assert + Assert.Equal(expected, actual); + } + + [Fact] + public void GetEnumerator_WithValueType_ReturnsMembers() + { + // Arrange + var expected = new[] + { + typeof(DateTime).GetProperty(nameof(DateTime.Ticks)), + typeof(TestModel).GetProperty(nameof(TestModel.DateTime)), + }; + + var key = GetKey(m => m.DateTime.Ticks); + + // Act + var actual = GetMembers(key); + + // Assert + Assert.Equal(expected, actual); + } + + private static MemberExpressionCacheKey GetKey(Expression> expresssion) + { + var memberExpression = Assert.IsAssignableFrom(expresssion.Body); + return new MemberExpressionCacheKey(typeof(TestModel), memberExpression); + } + + private static IList GetMembers(MemberExpressionCacheKey key) + { + var members = new List(); + foreach (var member in key) + { + members.Add(member); + } + + return members; + } + + public class TestModel + { + public TestModel2 TestModel2 { get; set; } + + public DateTime DateTime { get; set; } + + public DateTime? NullableDateTime { get; set; } + } + + public class TestModel2 + { + public string Name { get; set; } + + public TestModel3 TestModel3 { get; set; } + } + + public class TestModel3 + { + public string Value { get; set; } + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/CachedExpressionCompilerTest.cs b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/CachedExpressionCompilerTest.cs new file mode 100644 index 0000000000..a798d5594e --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/CachedExpressionCompilerTest.cs @@ -0,0 +1,1012 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq.Expressions; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.ViewFeatures +{ + public class CachedExpressionCompilerTest + { + [Fact] + public void Process_IdentityExpression() + { + // Arrange + var model = new TestModel(); + var expression = GetTestModelExpression(m => m); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Same(model, result); + } + + [Fact] + public void Process_CachesIdentityExpression() + { + // Arrange + var expression1 = GetTestModelExpression(m => m); + var expression2 = GetTestModelExpression(m => m); + + // Act + var func1 = CachedExpressionCompiler.Process(expression1); + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert + Assert.NotNull(func1); + Assert.Same(func1, func2); + } + + [Fact] + public void Process_ConstLookup() + { + // Arrange + var model = new TestModel(); + var differentModel = new DifferentModel(); + var expression = GetTestModelExpression(m => differentModel); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Same(differentModel, result); + } + + [Fact] + public void Process_ConstLookup_ReturningNull() + { + // Arrange + var model = new TestModel(); + var expression = GetTestModelExpression(m => (DifferentModel)null); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ConstLookup_WithNullModel() + { + // Arrange + var differentModel = new DifferentModel(); + var expression = GetTestModelExpression(m => differentModel); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(null); + Assert.Same(differentModel, result); + } + + [Fact] + public void Process_ConstLookup_UsingCachedValue() + { + // Arrange + var model = new TestModel(); + var differentModel = new DifferentModel(); + var expression1 = GetTestModelExpression(m => differentModel); + var expression2 = GetTestModelExpression(m => differentModel); + + // Act - 1 + var func1 = CachedExpressionCompiler.Process(expression1); + + // Assert - 1 + var result1 = func1(null); + Assert.Same(differentModel, result1); + + // Act - 2 + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert - 2 + var result2 = func1(null); + Assert.Same(differentModel, result2); + } + + [Fact] + public void Process_ConstLookup_WhenCapturedLocalChanges() + { + // Arrange + var model = new TestModel(); + var differentModel = new DifferentModel(); + var expression = GetTestModelExpression(m => differentModel); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert - 1 + var result1 = func(null); + Assert.Same(differentModel, result1); + + // Act - 2 + differentModel = new DifferentModel(); + + // Assert - 2 + var result2 = func(null); + Assert.NotSame(differentModel, result1); + Assert.Same(differentModel, result2); + } + + [Fact] + public void Process_ConstLookup_WithPrimitiveConstant() + { + // Arrange + var model = new TestModel(); + var expression = GetTestModelExpression(m => 10); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(10, result); + } + + [Fact] + public void Process_StaticFieldAccess() + { + // Arrange + var model = new TestModel(); + var expression = GetTestModelExpression(m => TestModel.StaticField); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal("StaticValue", result); + } + + [Fact] + public void Process_CachesStaticFieldAccess() + { + // Arrange + var expression1 = GetTestModelExpression(m => TestModel.StaticField); + var expression2 = GetTestModelExpression(m => TestModel.StaticField); + + // Act + var func1 = CachedExpressionCompiler.Process(expression1); + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert + Assert.NotNull(func1); + Assert.Same(func1, func2); + } + + [Fact] + public void Process_StaticPropertyAccess() + { + // Arrange + var expected = "TestValue"; + TestModel.StaticProperty = expected; + var model = new TestModel(); + var expression = GetTestModelExpression(m => TestModel.StaticProperty); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_CachesStaticPropertyAccess() + { + // Arrange + var expression1 = GetTestModelExpression(m => TestModel.StaticProperty); + var expression2 = GetTestModelExpression(m => TestModel.StaticProperty); + + // Act + var func1 = CachedExpressionCompiler.Process(expression1); + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert + Assert.NotNull(func1); + Assert.Same(func1, func2); + } + + [Fact] + public void Process_StaticPropertyAccess_WithNullModel() + { + // Arrange + var expected = "TestValue"; + TestModel.StaticProperty = expected; + var expression = GetTestModelExpression(m => TestModel.StaticProperty); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(null); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_ConstFieldLookup() + { + // Arrange + var model = new TestModel(); + var expression = GetTestModelExpression(m => DifferentModel.Constant); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(10, result); + } + + [Fact] + public void Process_ConstFieldLookup_WthNullModel() + { + // Arrange + var expression = GetTestModelExpression(m => DifferentModel.Constant); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(null); + Assert.Equal(10, result); + } + + [Fact] + public void Process_SimpleMemberAccess() + { + // Arrange + var model = new TestModel { Name = "Test" }; + var expression = GetTestModelExpression(m => m.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal("Test", result); + } + + [Fact] + public void Process_CachesSimpleMemberAccess() + { + // Arrange + var expression1 = GetTestModelExpression(m => m.Name); + var expression2 = GetTestModelExpression(m => m.Name); + + // Act + var func1 = CachedExpressionCompiler.Process(expression1); + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert + Assert.NotNull(func1); + Assert.Same(func1, func2); + } + + [Fact] + public void Process_SimpleMemberAccess_ToPrimitive() + { + // Arrange + var model = new TestModel { Age = 12 }; + var expression = GetTestModelExpression(m => m.Age); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(12, result); + } + + [Fact] + public void Process_SimpleMemberAccess_WithNullModel() + { + // Arrange + var model = (TestModel)null; + var expression = GetTestModelExpression(m => m.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_SimpleMemberAccess_ToPrimitive_WithNullModel() + { + // Arrange + var model = (TestModel)null; + var expression = GetTestModelExpression(m => m.Age); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnTypeWithBadEqualityComparer() + { + // Arrange + var model = new BadEqualityModel { Id = 7 }; + var expression = GetExpression(m => m.Id); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(7, result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnTypeWithBadEqualityComparer_WithNullModel() + { + // Arrange + var model = (BadEqualityModel)null; + var expression = GetExpression(m => m.Id); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnValueTypeWithBadEqualityComparer() + { + // Arrange + var model = new BadEqualityValueTypeModel { Id = 7 }; + var expression = GetExpression(m => m.Id); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(7, result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnTypeWithBadEqualityComparer_WithDefaultValue() + { + // Arrange + var model = (BadEqualityValueTypeModel)default; + var expression = GetExpression(m => m.Id); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(model.Id, result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnValueType() + { + // Arrange + var model = new DateTime(2000, 1, 1); + var expression = GetExpression(m => m.Year); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(2000, result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnValueType_WithDefaultValue() + { + // Arrange + var model = default(DateTime); + var expression = GetExpression(m => m.Year); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(1, result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnNullableValueType() + { + // Arrange + var model = new DateTime(2000, 1, 1); + var nullableModel = (DateTime?)model; + var expression = GetExpression(m => m.Value); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(nullableModel); + Assert.Equal(model, result); + } + + [Fact] + public void Process_SimpleMemberAccess_OnNullableValueType_WithNullValue() + { + // Arrange + var nullableModel = (DateTime?)null; + var expression = GetExpression(m => m.Value); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(nullableModel); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_ToValueType() + { + // Arrange + var dateTime = new DateTime(2000, 1, 1); + var model = new TestModel { Date = dateTime }; + var expression = GetTestModelExpression(m => m.Date.Year); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(2000, result); + } + + [Fact] + public void Process_ChainedMemberAccess_ToValueType_WithNullModel() + { + // Arrange + var model = (TestModel)null; + var expression = GetTestModelExpression(m => m.Date.Year); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_ToReferenceType() + { + // Arrange + var expected = "Test1"; + var model = new TestModel { DifferentModel = new DifferentModel { Name = expected } }; + var expression = GetTestModelExpression(m => m.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_CachesChainedMemberAccess() + { + // Arrange + var expression1 = GetTestModelExpression(m => m.DifferentModel.Name); + var expression2 = GetTestModelExpression(m => m.DifferentModel.Name); + + // Act + var func1 = CachedExpressionCompiler.Process(expression1); + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert + Assert.NotNull(func1); + Assert.Same(func1, func2); + } + + [Fact] + public void Process_CachesChainedMemberAccess_ToValueType() + { + // Arrange + var expression1 = GetTestModelExpression(m => m.Date.Year); + var expression2 = GetTestModelExpression(m => m.Date.Year); + + // Act + var func1 = CachedExpressionCompiler.Process(expression1); + var func2 = CachedExpressionCompiler.Process(expression2); + + // Assert + Assert.NotNull(func1); + Assert.Same(func1, func2); + } + + [Fact] + public void Process_ChainedMemberAccess_LongChain_WithReferenceType() + { + // Arrange + var expected = "TestVal"; + var model = new Chain0Model + { + Chain1 = new Chain1Model + { + ValueTypeModel = new ValueType1 + { + TestModel = new TestModel { DifferentModel = new DifferentModel { Name = expected } } + } + } + }; + + var expression = GetExpression(m => m.Chain1.ValueTypeModel.TestModel.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_ChainedMemberAccess_LongChain_WithNullIntermediary() + { + // Arrange + var model = new Chain0Model + { + Chain1 = new Chain1Model + { + ValueTypeModel = new ValueType1 { TestModel = null }, + } + }; + + var expression = GetExpression(m => m.Chain1.ValueTypeModel.TestModel.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_LongChain_WithNullValueTypeAccessor() + { + // Arrange + // Chain2 is a value type + var model = new Chain0Model + { + Chain1 = null + }; + + var expression = GetExpression(m => m.Chain1.ValueTypeModel.TestModel.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_LongChain_WithNullableValueType() + { + // Arrange + var expected = "TestVal"; + var model = new Chain0Model + { + Chain1 = new Chain1Model + { + NullableValueTypeModel = new ValueType1 + { + TestModel = new TestModel { DifferentModel = new DifferentModel { Name = expected } } + } + } + }; + + var expression = GetExpression(m => m.Chain1.NullableValueTypeModel.Value.TestModel.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_ChainedMemberAccess_LongChain_WithNullValuedNullableValueType() + { + // Arrange + var model = new Chain0Model + { + Chain1 = new Chain1Model + { + NullableValueTypeModel = null + } + }; + + var expression = GetExpression(m => m.Chain1.NullableValueTypeModel.Value.TestModel.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_ToReferenceType_WithNullIntermediary() + { + // Arrange + var model = new TestModel { DifferentModel = null }; + var expression = GetTestModelExpression(m => m.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_ToReferenceType_WithNullModel() + { + // Arrange + var model = (TestModel)null; + var expression = GetTestModelExpression(m => m.DifferentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_OfValueTypes_ReturningReferenceTypeMember() + { + // Arrange + var expected = "TestName"; + var model = new ValueType1 + { + ValueType2 = new ValueType2 { Name = expected }, + }; + var expression = GetExpression(m => m.ValueType2.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_ChainedMemberAccess_OfValueTypes_ReturningValueType() + { + // Arrange + var expected = new DateTime(2001, 1, 1); + var model = new ValueType1 + { + ValueType2 = new ValueType2 { Date = expected }, + }; + var expression = GetExpression(m => m.ValueType2.Date); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_ChainedMemberAccess_OfValueTypes_IncludingNullableType() + { + // Arrange + var expected = "TestName"; + var model = new ValueType1 + { + NullableValueType2 = new ValueType2 { Name = expected }, + }; + var expression = GetExpression(m => m.NullableValueType2.Value.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Equal(expected, result); + } + + [Fact] + public void Process_ChainedMemberAccess_OfValueTypes_WithNullValuedNullable() + { + // Arrange + var model = new ValueType1 { NullableValueType2 = null }; + var expression = GetExpression(m => m.NullableValueType2.Value.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_ChainedMemberAccess_OfValueTypes_WithNullValuedNullable_ReturningValueType() + { + // Arrange + var model = new ValueType1 { NullableValueType2 = null }; + var expression = GetExpression(m => m.NullableValueType2.Value.Date); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Null(result); + } + + [Fact] + public void Process_MemberAccessOnCapturedVariable_ReturnsNull() + { + // Arrange + var differentModel = new DifferentModel { Name = "Test" }; + var expression = GetTestModelExpression(m => differentModel.Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.Null(func); + } + + [Fact] + public void Process_CapturedVariable() + { + // Arrange + var differentModel = new DifferentModel(); + var model = new TestModel(); + var expression = GetTestModelExpression(m => differentModel); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Same(differentModel, result); + } + + [Fact] + public void Process_CapturedVariable_WithNullModel() + { + // Arrange + var differentModel = new DifferentModel(); + var model = (TestModel)null; + var expression = GetTestModelExpression(m => differentModel); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.NotNull(func); + var result = func(model); + Assert.Same(differentModel, result); + } + + + [Fact] + public void Process_MemberAccess_OnCapturedVariable_ReturnsNull() + { + // Arrange + var differentModel = "Hello world"; + var expression = GetTestModelExpression(m => differentModel.Length); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.Null(func); + } + + [Fact] + public void Process_ComplexChainedMemberAccess_ReturnsNull() + { + // Arrange + var expected = "SomeName"; + var model = new TestModel { DifferentModels = new[] { new DifferentModel { Name = expected } } }; + var expression = GetTestModelExpression(m => m.DifferentModels[0].Name); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.Null(func); + } + + [Fact] + public void Process_ArrayMemberAccess_ReturnsNull() + { + // Arrange + var expression = GetTestModelExpression(m => m.Sizes[1]); + + // Act + var func = CachedExpressionCompiler.Process(expression); + + // Assert + Assert.Null(func); + } + + private static Expression> GetExpression(Expression> expression) + => expression; + + private static Expression> GetTestModelExpression(Expression> expression) + => GetExpression(expression); + + public class TestModel + { + public static readonly string StaticField = "StaticValue"; + + public static string StaticProperty { get; set; } + + public int Age { get; set; } + + public string Name { get; set; } + + public DateTime Date { get; set; } + + public DifferentModel DifferentModel { get; set; } + + public int[] Sizes { get; set; } + + public DifferentModel[] DifferentModels { get; set; } + } + + public class DifferentModel + { + public const int Constant = 10; + + public string Name { get; set; } + } + + public class Chain0Model + { + public Chain1Model Chain1 { get; set; } + } + + public class Chain1Model + { + public ValueType1 ValueTypeModel { get; set; } + + public ValueType1? NullableValueTypeModel { get; set; } + } + + public struct ValueType1 + { + public TestModel TestModel { get; set; } + + public ValueType2 ValueType2 { get; set; } + + public ValueType2? NullableValueType2 { get; set; } + } + + public struct ValueType2 + { + public string Name { get; set; } + + public DateTime Date { get; set; } + } + + public class BadEqualityModel + { + public int Id { get; set; } + + public override bool Equals(object obj) + { + return this == obj; + } + + public static bool operator ==(BadEqualityModel a, object b) + { + if (a is null || b is null) + { + throw new TimeZoneNotFoundException(); + } + + return true; + } + + public static bool operator !=(BadEqualityModel a, object b) + { + return !(a == b); + } + + public override int GetHashCode() => 0; + } + + public struct BadEqualityValueTypeModel + { + public int Id { get; set; } + + public override bool Equals(object obj) + { + return this == obj; + } + + public static bool operator ==(BadEqualityValueTypeModel a, object b) + { + if (b is null) + { + throw new TimeZoneNotFoundException(); + } + + return true; + } + + public static bool operator !=(BadEqualityValueTypeModel a, object b) + { + return !(a == b); + } + + public override int GetHashCode() => 0; + } + } +}