Prevent null refs in some simple cases in CachedExpressionCompiler

Fixes #6928
This commit is contained in:
Pranav K 2018-05-04 09:05:16 -07:00
parent dacbb41478
commit 8e31319215
No known key found for this signature in database
GPG Key ID: 1963DA6D96C3057A
8 changed files with 1638 additions and 37 deletions

View File

@ -242,7 +242,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal
try
{
func = CachedExpressionCompiler.Process(lambda);
func = CachedExpressionCompiler.Process(lambda) ?? lambda.Compile();
}
catch (InvalidOperationException ex)
{
@ -259,8 +259,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal
public static bool IsSingleArgumentIndexer(Expression expression)
{
var methodExpression = expression as MethodCallExpression;
if (methodExpression == null || methodExpression.Arguments.Count != 1)
if (!(expression is MethodCallExpression methodExpression) || methodExpression.Arguments.Count != 1)
{
return false;
}

View File

@ -82,11 +82,17 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal
object modelAccessor(object container)
{
var compiledExpression = CachedExpressionCompiler.Process(expression);
Debug.Assert(compiledExpression != null);
var model = (TModel)container;
var cachedFunc = CachedExpressionCompiler.Process(expression);
if (cachedFunc != null)
{
return cachedFunc(model);
}
var func = expression.Compile();
try
{
return compiledExpression((TModel)container);
return func(model);
}
catch (NullReferenceException)
{

View File

@ -11,12 +11,16 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
{
internal static class CachedExpressionCompiler
{
// This is the entry point to the cached expression compilation system. The system
// will try to turn the expression into an actual delegate as quickly as possible,
// relying on cache lookups and other techniques to save time if appropriate.
// If the provided expression is particularly obscure and the system doesn't know
// how to handle it, we'll just compile the expression as normal.
public static Func<TModel, TResult> Process<TModel, TResult>(
private static readonly Expression NullExpression = Expression.Constant(value: null);
/// <remarks>
/// This is the entry point to the expression compilation system. The system
/// a) Will rewrite the expression to avoid null refs when any part of the expression tree is evaluated to null
/// b) Attempt to cache the result, or an intermediate part of the result.
/// If the provided expression is particularly obscure and the system doesn't know how to handle it, it will
/// return null.
/// </remarks>
public static Func<TModel, object> Process<TModel, TResult>(
Expression<Func<TModel, TResult>> expression)
{
if (expression == null)
@ -29,15 +33,18 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
private static class Compiler<TModel, TResult>
{
private static Func<TModel, TResult> _identityFunc;
private static Func<TModel, object> _identityFunc;
private static readonly ConcurrentDictionary<MemberInfo, Func<TModel, TResult>> _simpleMemberAccessCache =
new ConcurrentDictionary<MemberInfo, Func<TModel, TResult>>();
private static readonly ConcurrentDictionary<MemberInfo, Func<TModel, object>> _simpleMemberAccessCache =
new ConcurrentDictionary<MemberInfo, Func<TModel, object>>();
private static readonly ConcurrentDictionary<MemberExpressionCacheKey, Func<TModel, object>> _chainedMemberAccessCache =
new ConcurrentDictionary<MemberExpressionCacheKey, Func<TModel, object>>(MemberExpressionCacheKeyComparer.Instance);
private static readonly ConcurrentDictionary<MemberInfo, Func<object, TResult>> _constMemberAccessCache =
new ConcurrentDictionary<MemberInfo, Func<object, TResult>>();
public static Func<TModel, TResult> Compile(Expression<Func<TModel, TResult>> expression)
public static Func<TModel, object> Compile(Expression<Func<TModel, TResult>> expression)
{
Debug.Assert(expression != null);
@ -55,77 +62,120 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
case MemberExpression memberExpression when memberExpression.Expression is ConstantExpression constantExpression:
return CompileCapturedConstant(memberExpression, constantExpression);
// model => StaticMember
// model => ModelType.StaticMember
case MemberExpression memberExpression when memberExpression.Expression == null:
return CompileFromStaticMemberAccess(expression, memberExpression);
// model => model.Member
case MemberExpression memberExpression when memberExpression.Expression == expression.Parameters[0]:
return CompileFromMemberAccess(expression, memberExpression);
return CompileFromSimpleMemberAccess(expression, memberExpression);
// model => model.Member1.Member2
case MemberExpression memberExpression when IsChainedPropertyAccessor(memberExpression):
return CompileForChainedMemberAccess(expression, memberExpression);
default:
return CompileSlow(expression);
return null;
}
bool IsChainedPropertyAccessor(MemberExpression memberExpression)
{
while (memberExpression.Expression != null)
{
if (memberExpression.Expression is MemberExpression leftExpression)
{
memberExpression = leftExpression;
continue;
}
else if (memberExpression.Expression == expression.Parameters[0])
{
return true;
}
break;
}
return false;
}
}
private static Func<TModel, TResult> CompileFromConstLookup(
private static Func<TModel, object> CompileFromConstLookup(
ConstantExpression constantExpression)
{
// model => {const}
var constantValue = (TResult)constantExpression.Value;
var constantValue = constantExpression.Value;
return _ => constantValue;
}
private static Func<TModel, TResult> CompileFromIdentityFunc(
private static Func<TModel, object> CompileFromIdentityFunc(
Expression<Func<TModel, TResult>> expression)
{
// model => model
// Don't need to lock, as all identity funcs are identical.
if (_identityFunc == null)
{
_identityFunc = expression.Compile();
var identityFuncCore = expression.Compile();
_identityFunc = model => identityFuncCore(model);
}
return _identityFunc;
}
private static Func<TModel, TResult> CompileFromMemberAccess(
private static Func<TModel, object> CompileFromStaticMemberAccess(
Expression<Func<TModel, TResult>> expression,
MemberExpression memberExpression)
{
// model => model.Member
// model => ModelType.StaticMember
if (_simpleMemberAccessCache.TryGetValue(memberExpression.Member, out var result))
{
return result;
}
result = expression.Compile();
var func = expression.Compile();
result = model => func(model);
result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, result);
return result;
}
private static Func<TModel, TResult> CompileFromStaticMemberAccess(
private static Func<TModel, object> CompileFromSimpleMemberAccess(
Expression<Func<TModel, TResult>> expression,
MemberExpression memberExpression)
{
// model => model.StaticMember
// Input: () => m.Member
// Output: () => (m == null) ? null : m.Member
if (_simpleMemberAccessCache.TryGetValue(memberExpression.Member, out var result))
{
return result;
}
result = expression.Compile();
result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, result);
result = _simpleMemberAccessCache.GetOrAdd(memberExpression.Member, Rewrite(expression, memberExpression));
return result;
}
private static Func<TModel, TResult> CompileCapturedConstant(MemberExpression memberExpression, ConstantExpression constantExpression)
private static Func<TModel, object> CompileForChainedMemberAccess(
Expression<Func<TModel, TResult>> expression,
MemberExpression memberExpression)
{
// model => {const}.Member (captured local variable)
// Input: () => m.Member1.Member2
// Output: () => (m == null || m.Member1 == null) ? null : m.Member1.Member2
var key = new MemberExpressionCacheKey(typeof(TModel), memberExpression);
if (_chainedMemberAccessCache.TryGetValue(key, out var result))
{
return result;
}
var cacheableKey = key.MakeCacheable();
result = _chainedMemberAccessCache.GetOrAdd(cacheableKey, Rewrite(expression, memberExpression));
return result;
}
private static Func<TModel, object> CompileCapturedConstant(MemberExpression memberExpression, ConstantExpression constantExpression)
{
// model => {const} (captured local variable)
if (!_constMemberAccessCache.TryGetValue(memberExpression.Member, out var result))
{
// rewrite as capturedLocal => ((TDeclaringType)capturedLocal).Member
// rewrite as capturedLocal => ((TDeclaringType)capturedLocal)
var parameterExpression = Expression.Parameter(typeof(object), "capturedLocal");
var castExpression =
Expression.Convert(parameterExpression, memberExpression.Member.DeclaringType);
@ -142,10 +192,74 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures
return _ => result(capturedLocal);
}
private static Func<TModel, TResult> CompileSlow(Expression<Func<TModel, TResult>> expression)
private static Func<TModel, object> Rewrite(
Expression<Func<TModel, TResult>> expression,
MemberExpression memberExpression)
{
// fallback compilation system - just compile the expression directly
return expression.Compile();
Expression combinedNullTest = null;
var currentExpression = memberExpression;
while (currentExpression != null)
{
AddNullCheck(currentExpression.Expression, ref combinedNullTest);
if (currentExpression.Expression is MemberExpression leftExpression)
{
currentExpression = leftExpression;
}
else
{
break;
}
}
var body = expression.Body;
// Cast the entire expression to object in case Member is a value type. This is required for us to be able to
// express the null conditional statement m == null ? null : (object)m.IntValue
if (body.Type.IsValueType)
{
body = Expression.Convert(body, typeof(object));
}
if (combinedNullTest != null)
{
Debug.Assert(combinedNullTest.Type == typeof(bool));
body = Expression.Condition(
combinedNullTest,
Expression.Constant(value: null, body.Type),
body);
}
var rewrittenExpression = Expression.Lambda<Func<TModel, object>>(body, expression.Parameters);
return rewrittenExpression.Compile();
}
private static void AddNullCheck(Expression invokingExpression, ref Expression combinedNullTest)
{
var type = invokingExpression.Type;
var isNullableValueType = type.IsValueType && Nullable.GetUnderlyingType(type) != null;
if (type.IsValueType && !isNullableValueType)
{
// struct.Member where struct is not nullable. Do nothing.
return;
}
// NullableStruct.Member or Class.Member
// type is Nullable ? (value == null) : object.ReferenceEquals(value, null)
var nullTest = isNullableValueType ?
Expression.Equal(invokingExpression, NullExpression) :
Expression.ReferenceEqual(invokingExpression, NullExpression);
if (combinedNullTest == null)
{
combinedNullTest = nullTest;
}
else
{
// m == null || m.Member == null
combinedNullTest = Expression.OrElse(nullTest, combinedNullTest);
}
}
}
}

View File

@ -0,0 +1,90 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
{
internal struct MemberExpressionCacheKey
{
public MemberExpressionCacheKey(Type modelType, MemberExpression memberExpression)
{
ModelType = modelType;
MemberExpression = memberExpression;
Members = null;
}
public MemberExpressionCacheKey(Type modelType, MemberInfo[] members)
{
ModelType = modelType;
Members = members;
MemberExpression = null;
}
// We want to avoid caching a MemberExpression since it has references to other instances in the expression tree.
// We instead store it as a series of MemberInfo items that comprise of the MemberExpression going from right-most
// expression to left.
public MemberExpressionCacheKey MakeCacheable()
{
var members = new List<MemberInfo>();
foreach (var member in this)
{
members.Add(member);
}
return new MemberExpressionCacheKey(ModelType, members.ToArray());
}
public MemberExpression MemberExpression { get; }
public Type ModelType { get; }
public MemberInfo[] Members { get; }
public Enumerator GetEnumerator() => new Enumerator(ref this);
public struct Enumerator
{
private readonly MemberInfo[] _members;
private int _index;
private MemberExpression _memberExpression;
public Enumerator(ref MemberExpressionCacheKey key)
{
Current = null;
_members = key.Members;
_memberExpression = key.MemberExpression;
_index = -1;
}
public MemberInfo Current { get; private set; }
public bool MoveNext()
{
if (_members != null)
{
_index++;
if (_index >= _members.Length)
{
return false;
}
Current = _members[_index];
return true;
}
if (_memberExpression == null)
{
return false;
}
Current = _memberExpression.Member;
_memberExpression = _memberExpression.Expression as MemberExpression;
return true;
}
}
}
}

View File

@ -0,0 +1,53 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
{
internal class MemberExpressionCacheKeyComparer : IEqualityComparer<MemberExpressionCacheKey>
{
public static readonly MemberExpressionCacheKeyComparer Instance = new MemberExpressionCacheKeyComparer();
public bool Equals(MemberExpressionCacheKey x, MemberExpressionCacheKey y)
{
if (x.ModelType != y.ModelType)
{
return false;
}
var xEnumerator = x.GetEnumerator();
var yEnumerator = y.GetEnumerator();
while (xEnumerator.MoveNext())
{
if (!yEnumerator.MoveNext())
{
return false;
}
// Current is a MemberInfo instance which has a good comparer.
if (xEnumerator.Current != yEnumerator.Current)
{
return false;
}
}
return !yEnumerator.MoveNext();
}
public int GetHashCode(MemberExpressionCacheKey obj)
{
var hashCodeCombiner = new HashCodeCombiner();
hashCodeCombiner.Add(obj.ModelType);
foreach (var member in obj)
{
hashCodeCombiner.Add(member);
}
return hashCodeCombiner.CombinedHash;
}
}
}

View File

@ -0,0 +1,216 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq.Expressions;
using Xunit;
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
{
public class MemberExpressionCacheKeyComparerTest
{
private readonly MemberExpressionCacheKeyComparer Comparer = MemberExpressionCacheKeyComparer.Instance;
[Fact]
public void Equals_ReturnsTrue_ForTheSameExpression()
{
// Arrange
var key = GetKey(m => m.Value);
// Act & Assert
VerifyEquals(key, key);
}
[Fact]
public void Equals_ReturnsTrue_ForDifferentInstances_OfSameExpression()
{
// Arrange
var key1 = GetKey(m => m.Value);
var key2 = GetKey(m => m.Value);
// Act & Assert
VerifyEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithReferenceTypes()
{
// Arrange
var key1 = GetKey(m => m.TestModel2.Name);
var key2 = GetKey(m => m.TestModel2.Name);
// Act & Assert
VerifyEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithNullableValueTypes()
{
// Arrange
var key1 = GetKey(m => m.NullableDateTime.Value.TimeOfDay);
var key2 = GetKey(m => m.NullableDateTime.Value.TimeOfDay);
// Act & Assert
VerifyEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsTrue_ForChainedMemberAccessExpressionsWithValueTypes()
{
// Arrange
var key1 = GetKey(m => m.DateTime.Year);
var key2 = GetKey(m => m.DateTime.Year);
// Act & Assert
VerifyEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_ForDifferentExpression()
{
// Arrange
var key1 = GetKey(m => m.Value);
var key2 = GetKey(m => m.TestModel2.Name);
// Act & Assert
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_ForChainedExpressions()
{
// Arrange
var key1 = GetKey(m => m.TestModel2.Id);
var key2 = GetKey(m => m.TestModel2.Name);
// Act & Assert
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_ForChainedExpressions_WithValueTypes()
{
// Arrange
var key1 = GetKey(m => m.DateTime.Ticks);
var key2 = GetKey(m => m.DateTime.Year);
// Act & Assert
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_ForChainedExpressions_DifferingByNullable()
{
// Arrange
var key1 = GetKey(m => m.DateTime.Ticks);
var key2 = GetKey(m => m.NullableDateTime.Value.Ticks);
// Act & Assert
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_WhenOneExpressionIsSubsetOfOther()
{
// Arrange
var key1 = GetKey(m => m.TestModel2);
var key2 = GetKey(m => m.TestModel2.Name);
// Act & Assert
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughNullableProperty()
{
// Arrange
var key1 = GetKey(m => m.NullableDateTime.Value.Year);
var key2 = GetKey(m => m.DateTime.Year);
// Act
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughDifferentModels()
{
// Arrange
var key1 = GetKey<TestModel2, int>(m => m.Id);
var key2 = GetKey(m => m.TestModel2.Id);
// Act
VerifyNotEquals(key1, key2);
}
[Fact]
public void Equals_ReturnsFalse_WhenMemberIsAccessedThroughConstantExpression()
{
// Arrange
var testModel = new TestModel2 { Id = 1 };
var key1 = GetKey(m => testModel.Id);
var key2 = GetKey<TestModel2, int>(m => m.Id);
// Act
VerifyNotEquals(key1, key2);
}
private void VerifyEquals(MemberExpressionCacheKey key1, MemberExpressionCacheKey key2)
{
Assert.Equal(key1, key2, Comparer);
var hashCode1 = Comparer.GetHashCode(key1);
var hashCode2 = Comparer.GetHashCode(key2);
Assert.Equal(hashCode1, hashCode2);
var cachedKey1 = key1.MakeCacheable();
Assert.Equal(key1, cachedKey1, Comparer);
Assert.Equal(cachedKey1, key1, Comparer);
var cachedKeyHashCode1 = Comparer.GetHashCode(cachedKey1);
Assert.Equal(hashCode1, cachedKeyHashCode1);
}
private void VerifyNotEquals(MemberExpressionCacheKey key1, MemberExpressionCacheKey key2)
{
var hashCode1 = Comparer.GetHashCode(key1);
var hashCode2 = Comparer.GetHashCode(key2);
Assert.NotEqual(hashCode1, hashCode2);
Assert.NotEqual(key1, key2, Comparer);
var cachedKey1 = key1.MakeCacheable();
Assert.NotEqual(key2, cachedKey1, Comparer);
var cachedKeyHashCode1 = Comparer.GetHashCode(cachedKey1);
Assert.NotEqual(cachedKeyHashCode1, hashCode2);
}
private static MemberExpressionCacheKey GetKey<TResult>(Expression<Func<TestModel, TResult>> expresssion)
=> GetKey<TestModel, TResult>(expresssion);
private static MemberExpressionCacheKey GetKey<TModel, TResult>(Expression<Func<TModel, TResult>> expresssion)
{
var memberExpression = Assert.IsAssignableFrom<MemberExpression>(expresssion.Body);
return new MemberExpressionCacheKey(typeof(TModel), memberExpression);
}
public class TestModel
{
public string Value { get; set; }
public TestModel2 TestModel2 { get; set; }
public DateTime DateTime { get; set; }
public DateTime? NullableDateTime { get; set; }
}
public class TestModel2
{
public string Name { get; set; }
public int Id { get; set; }
}
}
}

View File

@ -0,0 +1,111 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using Xunit;
namespace Microsoft.AspNetCore.Mvc.ViewFeatures
{
public class MemberExpressionCacheKeyTest
{
[Fact]
public void GetEnumerator_ReturnsMembers()
{
// Arrange
var expected = new[]
{
typeof(TestModel3).GetProperty(nameof(TestModel3.Value)),
typeof(TestModel2).GetProperty(nameof(TestModel2.TestModel3)),
typeof(TestModel).GetProperty(nameof(TestModel.TestModel2)),
};
var key = GetKey(m => m.TestModel2.TestModel3.Value);
// Act
var actual = GetMembers(key);
// Assert
Assert.Equal(expected, actual);
}
[Fact]
public void GetEnumerator_WithNullableType_ReturnsMembers()
{
// Arrange
var expected = new[]
{
typeof(DateTime).GetProperty(nameof(DateTime.Ticks)),
typeof(DateTime?).GetProperty(nameof(Nullable<DateTime>.Value)),
typeof(TestModel).GetProperty(nameof(TestModel.NullableDateTime)),
};
var key = GetKey(m => m.NullableDateTime.Value.Ticks);
// Act
var actual = GetMembers(key);
// Assert
Assert.Equal(expected, actual);
}
[Fact]
public void GetEnumerator_WithValueType_ReturnsMembers()
{
// Arrange
var expected = new[]
{
typeof(DateTime).GetProperty(nameof(DateTime.Ticks)),
typeof(TestModel).GetProperty(nameof(TestModel.DateTime)),
};
var key = GetKey(m => m.DateTime.Ticks);
// Act
var actual = GetMembers(key);
// Assert
Assert.Equal(expected, actual);
}
private static MemberExpressionCacheKey GetKey<TResult>(Expression<Func<TestModel, TResult>> expresssion)
{
var memberExpression = Assert.IsAssignableFrom<MemberExpression>(expresssion.Body);
return new MemberExpressionCacheKey(typeof(TestModel), memberExpression);
}
private static IList<MemberInfo> GetMembers(MemberExpressionCacheKey key)
{
var members = new List<MemberInfo>();
foreach (var member in key)
{
members.Add(member);
}
return members;
}
public class TestModel
{
public TestModel2 TestModel2 { get; set; }
public DateTime DateTime { get; set; }
public DateTime? NullableDateTime { get; set; }
}
public class TestModel2
{
public string Name { get; set; }
public TestModel3 TestModel3 { get; set; }
}
public class TestModel3
{
public string Value { get; set; }
}
}
}