Add expression rewriting to Razor

This change rewrites simple and safe Expression<Func<T, U>> expressions
into accesses to readonly fields. This allows us to cache the actual
expression and avoid repeatedly allocating and compiling it.

The rewrite is limited to cases where we know that the expression doesn't
capture, and where we support that kind of expression for evaluating
viewdata. In practice this means 'indentity' and property accessors are
allowed.
This commit is contained in:
Ryan Nowak 2015-09-20 00:26:51 -07:00
parent 6ef2fe44ca
commit 32645e93c8
3 changed files with 721 additions and 9 deletions

View File

@ -0,0 +1,200 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq.Expressions;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
namespace Microsoft.AspNet.Mvc.Razor.Compilation
{
/// <summary>
/// An expression rewriter which can hoist a simple expression lambda into a private field.
/// </summary>
public class ExpressionRewriter : CSharpSyntaxRewriter
{
private static readonly string FieldNameTemplate = "__h{0}";
public ExpressionRewriter(SemanticModel semanticModel)
{
SemanticModel = semanticModel;
Expressions = new List<KeyValuePair<SimpleLambdaExpressionSyntax, IdentifierNameSyntax>>();
}
// We only want to rewrite expressions for the top-level class definition.
private bool IsInsideClass { get; set; }
private SemanticModel SemanticModel { get; }
private List<KeyValuePair<SimpleLambdaExpressionSyntax, IdentifierNameSyntax>> Expressions { get; }
public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node)
{
if (IsInsideClass)
{
// Avoid recursing into nested classes.
return node;
}
Expressions.Clear();
IsInsideClass = true;
// Call base first to visit all the children and populate Expressions.
var classDeclaration = (ClassDeclarationSyntax)base.VisitClassDeclaration(node);
IsInsideClass = false;
var memberDeclarations = new List<MemberDeclarationSyntax>();
foreach (var kvp in Expressions)
{
var expression = kvp.Key;
var memberName = kvp.Value.GetFirstToken();
var expressionType = SemanticModel.GetTypeInfo(expression).ConvertedType;
var declaration = SyntaxFactory.FieldDeclaration(
SyntaxFactory.List<AttributeListSyntax>(),
SyntaxFactory.TokenList(
SyntaxFactory.Token(SyntaxKind.PrivateKeyword),
SyntaxFactory.Token(SyntaxKind.StaticKeyword),
SyntaxFactory.Token(SyntaxKind.ReadOnlyKeyword)),
SyntaxFactory.VariableDeclaration(
SyntaxFactory.ParseTypeName(expressionType.ToDisplayString(
SymbolDisplayFormat.FullyQualifiedFormat)),
SyntaxFactory.SingletonSeparatedList(
SyntaxFactory.VariableDeclarator(
memberName,
SyntaxFactory.BracketedArgumentList(),
SyntaxFactory.EqualsValueClause(expression)))))
.WithTriviaFrom(expression);
memberDeclarations.Add(declaration);
}
return classDeclaration.AddMembers(memberDeclarations.ToArray());
}
public override SyntaxNode VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node)
{
Debug.Assert(IsInsideClass);
// If this lambda is an Expression and is suitable for hoisting, we rewrite this into a field access.
//
// Before:
// public Task ExecuteAsync(...)
// {
// ...
// Html.EditorFor(m => m.Price);
// ...
// }
//
//
// After:
// private static readonly Expression<Func<Product, decimal>> __h0 = m => m.Price;
// public Task ExecuteAsync(...)
// {
// ...
// Html.EditorFor(__h0);
// ...
// }
//
var type = SemanticModel.GetTypeInfo(node);
if (type.ConvertedType.Name != typeof(Expression).Name &&
type.ConvertedType.ContainingNamespace.Name != typeof(Expression).Namespace)
{
return node;
}
if (!node.Parent.IsKind(SyntaxKind.Argument))
{
return node;
}
var parameter = node.Parameter;
if (IsValidForHoisting(parameter, node.Body))
{
// Replace with a MemberAccess
var memberName = string.Format(FieldNameTemplate, Expressions.Count);
var memberAccess = PadMemberAccess(node, SyntaxFactory.IdentifierName(memberName));
Expressions.Add(new KeyValuePair<SimpleLambdaExpressionSyntax, IdentifierNameSyntax>(node, memberAccess));
return memberAccess;
}
return node;
}
private static IdentifierNameSyntax PadMemberAccess(
SimpleLambdaExpressionSyntax node,
IdentifierNameSyntax memberAccess)
{
// We want to make the new span
var originalSpan = node.GetLocation().GetMappedLineSpan();
// Start by collecting all the trivia 'inside' the expression - we need to tack that on the end, but
// if it ends with a newline, don't include that.
var innerTrivia = SyntaxFactory.TriviaList(node.DescendantTrivia(descendIntoChildren: n => true));
if (innerTrivia.Count > 0 && innerTrivia[innerTrivia.Count - 1].IsKind(SyntaxKind.EndOfLineTrivia))
{
innerTrivia = innerTrivia.RemoveAt(innerTrivia.Count - 1);
}
memberAccess = memberAccess.WithTrailingTrivia(innerTrivia);
// If everything is all on one line, then make sure the spans are the same, to compensate
// for the expression potentially being longer than the variable name.
var lineSpan = originalSpan.EndLinePosition.Line - originalSpan.StartLinePosition.Line;
if (lineSpan == 0)
{
var padding = node.Span.Length - memberAccess.FullSpan.Length;
var trailingTrivia =
SyntaxFactory.TriviaList(memberAccess.GetTrailingTrivia())
.Add(SyntaxFactory.Whitespace(new string(' ', padding)))
.AddRange(node.GetTrailingTrivia());
return
memberAccess
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(trailingTrivia);
}
else
{
// If everything isn't on the same line, we need to pad out the last line.
var padding =
originalSpan.EndLinePosition.Character -
originalSpan.StartLinePosition.Character;
var trailingTrivia =
SyntaxFactory.TriviaList(memberAccess.GetTrailingTrivia())
.Add(SyntaxFactory.Whitespace(new string(' ', padding)))
.AddRange(node.GetTrailingTrivia());
return
memberAccess
.WithLeadingTrivia(node.GetLeadingTrivia())
.WithTrailingTrivia(trailingTrivia);
}
}
private static bool IsValidForHoisting(ParameterSyntax parameter, CSharpSyntaxNode node)
{
if (node.IsKind(SyntaxKind.IdentifierName))
{
var identifier = (IdentifierNameSyntax)node;
if (identifier.Identifier.Text == parameter.Identifier.Text)
{
return true;
}
}
else if (node.IsKind(SyntaxKind.SimpleMemberAccessExpression))
{
var memberAccess = (MemberAccessExpressionSyntax)node;
var lhs = memberAccess.Expression;
return IsValidForHoisting(parameter, lhs);
}
return false;
}
}
}

View File

@ -82,18 +82,24 @@ namespace Microsoft.AspNet.Mvc.Razor.Compilation
var assemblyName = Path.GetRandomFileName();
var compilationSettings = _compilerOptionsProvider.GetCompilationSettings(_environment);
var syntaxTree = SyntaxTreeGenerator.Generate(compilationContent,
assemblyName,
compilationSettings);
var syntaxTree = SyntaxTreeGenerator.Generate(
compilationContent,
assemblyName,
compilationSettings);
var references = _applicationReferences.Value;
var compilationOptions = compilationSettings.CompilationOptions
.WithOutputKind(OutputKind.DynamicallyLinkedLibrary);
var compilationOptions = compilationSettings
.CompilationOptions
.WithOutputKind(OutputKind.DynamicallyLinkedLibrary);
var compilation = CSharpCompilation.Create(assemblyName,
options: compilationOptions,
syntaxTrees: new[] { syntaxTree },
references: references);
var compilation = CSharpCompilation.Create(
assemblyName,
options: compilationOptions,
syntaxTrees: new[] { syntaxTree },
references: references);
compilation = Rewrite(compilation);
using (var ms = new MemoryStream())
{
@ -140,6 +146,21 @@ namespace Microsoft.AspNet.Mvc.Razor.Compilation
}
}
private CSharpCompilation Rewrite(CSharpCompilation compilation)
{
var rewrittenTrees = new List<SyntaxTree>();
foreach (var tree in compilation.SyntaxTrees)
{
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
var rewrittenTree = tree.WithRootAndOptions(rewriter.Visit(tree.GetRoot()), tree.Options);
rewrittenTrees.Add(rewrittenTree);
}
return compilation.RemoveAllSyntaxTrees().AddSyntaxTrees(rewrittenTrees);
}
// Internal for unit testing
internal CompilationResult GetCompilationFailedResult(
string relativePath,

View File

@ -0,0 +1,491 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.Dnx.Compilation;
using Microsoft.Dnx.Compilation.CSharp;
using Microsoft.Dnx.Runtime;
using Microsoft.Dnx.Runtime.Infrastructure;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Microsoft.AspNet.Mvc.Razor.Compilation
{
public class ExpressionRewriterTest
{
[Fact]
public void ExpressionRewriter_CanRewriteExpression_IdentityExpression()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(Expression<Func<object, object>> expression)
{
}
public static void Main(string[] args)
{
CalledWithExpression(x => x);
}
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
var fields = FindFields(result);
var field = Assert.Single(fields);
Assert.Collection(
field.Modifiers,
m => Assert.Equal("private", m.ToString()),
m => Assert.Equal("static", m.ToString()),
m => Assert.Equal("readonly", m.ToString()));
var declaration = field.Declaration;
Assert.Equal(
"global::System.Linq.Expressions.Expression<global::System.Func<object, object>>",
declaration.Type.ToString());
var variable = Assert.Single(declaration.Variables);
Assert.Equal("__h0", variable.Identifier.ToString());
Assert.Equal("x => x", variable.Initializer.Value.ToString());
var arguments = FindArguments(result);
var argument = Assert.IsType<IdentifierNameSyntax>(Assert.Single(arguments.Arguments).Expression);
Assert.Equal("__h0", argument.Identifier.ToString());
}
[Fact]
public void ExpressionRewriter_CanRewriteExpression_MemberAccessExpression()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(Expression<Func<Person, object>> expression)
{
}
public static void Main(string[] args)
{
CalledWithExpression(x => x.Name);
}
}
public class Person
{
public string Name { get; set; }
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
var fields = FindFields(result);
var field = Assert.Single(fields);
Assert.Collection(
field.Modifiers,
m => Assert.Equal("private", m.ToString()),
m => Assert.Equal("static", m.ToString()),
m => Assert.Equal("readonly", m.ToString()));
var declaration = field.Declaration;
Assert.Equal(
"global::System.Linq.Expressions.Expression<global::System.Func<global::Person, object>>",
declaration.Type.ToString());
var variable = Assert.Single(declaration.Variables);
Assert.Equal("__h0", variable.Identifier.ToString());
Assert.Equal("x => x.Name", variable.Initializer.Value.ToString());
var arguments = FindArguments(result);
var argument = Assert.IsType<IdentifierNameSyntax>(Assert.Single(arguments.Arguments).Expression);
Assert.Equal("__h0", argument.Identifier.ToString());
}
[Fact]
public void ExpressionRewriter_CanRewriteExpression_ChainedMemberAccessExpression()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(Expression<Func<Person, int>> expression)
{
}
public static void Main(string[] args)
{
CalledWithExpression(x => x.Name.Length);
}
}
public class Person
{
public string Name { get; set; }
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
var fields = FindFields(result);
var field = Assert.Single(fields);
Assert.Collection(
field.Modifiers,
m => Assert.Equal("private", m.ToString()),
m => Assert.Equal("static", m.ToString()),
m => Assert.Equal("readonly", m.ToString()));
var declaration = field.Declaration;
Assert.Equal(
"global::System.Linq.Expressions.Expression<global::System.Func<global::Person, int>>",
declaration.Type.ToString());
var variable = Assert.Single(declaration.Variables);
Assert.Equal("__h0", variable.Identifier.ToString());
Assert.Equal("x => x.Name.Length", variable.Initializer.Value.ToString());
var arguments = FindArguments(result);
var argument = Assert.IsType<IdentifierNameSyntax>(Assert.Single(arguments.Arguments).Expression);
Assert.Equal("__h0", argument.Identifier.ToString());
}
[Fact]
public void ExpressionRewriter_CannotRewriteExpression_MethodCall()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(Expression<Func<object, int>> expression)
{
}
public static void Main(string[] args)
{
CalledWithExpression(x => x.GetHashCode());
}
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
Assert.Empty(FindFields(result));
}
[Fact]
public void ExpressionRewriter_CannotRewriteExpression_NonArgument()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(Expression<Func<object, int>> expression)
{
}
public static void Main(string[] args)
{
Expression<Func<object, int>> expr = x => x.GetHashCode();
CalledWithExpression(expr);
}
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
Assert.Empty(FindFields(result));
}
[Fact]
public void ExpressionRewriter_CannotRewriteExpression_NestedClass()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
private class Nested
{
public static void CalledWithExpression(Expression<Func<object, int>> expression)
{
}
public static void Main(string[] args)
{
Expression<Func<object, int>> expr = x => x.GetHashCode();
CalledWithExpression(expr);
}
}
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
Assert.Empty(FindFields(result));
}
[Fact]
public void ExpressionRewriter_CanRewriteExpression_AdditionalArguments()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(int x, Expression<Func<object, object>> expression, string name)
{
}
public static void Main(string[] args)
{
CalledWithExpression(5, x => x, ""Billy"");
}
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
var fields = FindFields(result);
var field = Assert.Single(fields);
Assert.Collection(
field.Modifiers,
m => Assert.Equal("private", m.ToString()),
m => Assert.Equal("static", m.ToString()),
m => Assert.Equal("readonly", m.ToString()));
var declaration = field.Declaration;
Assert.Equal(
"global::System.Linq.Expressions.Expression<global::System.Func<object, object>>",
declaration.Type.ToString());
var variable = Assert.Single(declaration.Variables);
Assert.Equal("__h0", variable.Identifier.ToString());
Assert.Equal("x => x", variable.Initializer.Value.ToString());
var arguments = FindArguments(result);
Assert.Equal(3, arguments.Arguments.Count);
var argument = Assert.IsType<IdentifierNameSyntax>(arguments.Arguments[1].Expression);
Assert.Equal("__h0", argument.Identifier.ToString());
}
// When we rewrite the expression, we want to maintain the original span as much as possible.
[Fact]
public void ExpressionRewriter_CanRewriteExpression_SimpleFormatting()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(Expression<Func<Person, int>> expression)
{
}
public static void Main(string[] args)
{
CalledWithExpression(x => x.Name.Length);
}
}
public class Person
{
public string Name { get; set; }
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var originalArguments = FindArguments(tree.GetRoot());
var originalSpan = originalArguments.GetLocation().GetMappedLineSpan();
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
var arguments = FindArguments(result);
Assert.Equal(originalSpan, arguments.GetLocation().GetMappedLineSpan());
}
// When we rewrite the expression, we want to maintain the original span as much as possible.
[Fact]
public void ExpressionRewriter_CanRewriteExpression_ComplexFormatting()
{
// Arrange
var source = @"
using System;
using System.Linq.Expressions;
public class Program
{
public static void CalledWithExpression(int z, Expression<Func<Person, int>> expression)
{
}
public static void Main(string[] args)
{
CalledWithExpression(
17,
x =>
x.Name.
Length
);
}
}
public class Person
{
public string Name { get; set; }
}
";
var tree = CSharpSyntaxTree.ParseText(source);
var originalArguments = FindArguments(tree.GetRoot());
var originalSpan = originalArguments.GetLocation().GetMappedLineSpan();
var compilation = Compile(tree);
var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true);
var rewriter = new ExpressionRewriter(semanticModel);
// Act
var result = rewriter.Visit(tree.GetRoot());
// Assert
var arguments = FindArguments(result);
Assert.Equal(originalSpan, arguments.GetLocation().GetMappedLineSpan());
}
public ArgumentListSyntax FindArguments(SyntaxNode node)
{
return node
.DescendantNodes(n => true)
.Where(n => n.IsKind(SyntaxKind.ArgumentList))
.Cast<ArgumentListSyntax>()
.Single();
}
public IEnumerable<FieldDeclarationSyntax> FindFields(SyntaxNode node)
{
return node
.DescendantNodes(n => true)
.Where(n => n.IsKind(SyntaxKind.FieldDeclaration))
.Cast<FieldDeclarationSyntax>();
}
private CSharpCompilation Compile(SyntaxTree tree)
{
var compilation = CSharpCompilation.Create(
"Test.Assembly",
new[] { tree },
GetReferences());
var diagnostics = compilation.GetDiagnostics();
if (diagnostics.Length > 0)
{
Assert.False(true, string.Join(Environment.NewLine, diagnostics));
}
return compilation;
}
private IEnumerable<MetadataReference> GetReferences()
{
var services = CallContextServiceLocator.Locator.ServiceProvider;
var libraryExporter = services.GetRequiredService<ILibraryExporter>();
var environment = services.GetRequiredService<IApplicationEnvironment>();
var references = new List<MetadataReference>();
var libraryExports = libraryExporter.GetAllExports(environment.ApplicationName);
foreach (var export in libraryExports.MetadataReferences)
{
references.Add(export.ConvertMetadataReference(MetadataReferenceExtensions.CreateAssemblyMetadata));
}
return references;
}
}
}