From 32645e93c8f0ad0243c42fe55665592446f7b864 Mon Sep 17 00:00:00 2001 From: Ryan Nowak Date: Sun, 20 Sep 2015 00:26:51 -0700 Subject: [PATCH] Add expression rewriting to Razor This change rewrites simple and safe Expression> expressions into accesses to readonly fields. This allows us to cache the actual expression and avoid repeatedly allocating and compiling it. The rewrite is limited to cases where we know that the expression doesn't capture, and where we support that kind of expression for evaluating viewdata. In practice this means 'indentity' and property accessors are allowed. --- .../Compilation/ExpressionRewriter.cs | 200 +++++++ .../Compilation/RoslynCompilationService.cs | 39 +- .../Compilation/ExpressionRewriterTest.cs | 491 ++++++++++++++++++ 3 files changed, 721 insertions(+), 9 deletions(-) create mode 100644 src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs create mode 100644 test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs diff --git a/src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs b/src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs new file mode 100644 index 0000000000..c4ae27161a --- /dev/null +++ b/src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs @@ -0,0 +1,200 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq.Expressions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.AspNet.Mvc.Razor.Compilation +{ + /// + /// An expression rewriter which can hoist a simple expression lambda into a private field. + /// + public class ExpressionRewriter : CSharpSyntaxRewriter + { + private static readonly string FieldNameTemplate = "__h{0}"; + + public ExpressionRewriter(SemanticModel semanticModel) + { + SemanticModel = semanticModel; + + Expressions = new List>(); + } + + // We only want to rewrite expressions for the top-level class definition. + private bool IsInsideClass { get; set; } + + private SemanticModel SemanticModel { get; } + + private List> Expressions { get; } + + public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node) + { + if (IsInsideClass) + { + // Avoid recursing into nested classes. + return node; + } + + Expressions.Clear(); + + IsInsideClass = true; + + // Call base first to visit all the children and populate Expressions. + var classDeclaration = (ClassDeclarationSyntax)base.VisitClassDeclaration(node); + + IsInsideClass = false; + + var memberDeclarations = new List(); + foreach (var kvp in Expressions) + { + var expression = kvp.Key; + var memberName = kvp.Value.GetFirstToken(); + + var expressionType = SemanticModel.GetTypeInfo(expression).ConvertedType; + var declaration = SyntaxFactory.FieldDeclaration( + SyntaxFactory.List(), + SyntaxFactory.TokenList( + SyntaxFactory.Token(SyntaxKind.PrivateKeyword), + SyntaxFactory.Token(SyntaxKind.StaticKeyword), + SyntaxFactory.Token(SyntaxKind.ReadOnlyKeyword)), + SyntaxFactory.VariableDeclaration( + SyntaxFactory.ParseTypeName(expressionType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat)), + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.VariableDeclarator( + memberName, + SyntaxFactory.BracketedArgumentList(), + SyntaxFactory.EqualsValueClause(expression))))) + .WithTriviaFrom(expression); + memberDeclarations.Add(declaration); + } + + return classDeclaration.AddMembers(memberDeclarations.ToArray()); + } + + public override SyntaxNode VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node) + { + Debug.Assert(IsInsideClass); + + // If this lambda is an Expression and is suitable for hoisting, we rewrite this into a field access. + // + // Before: + // public Task ExecuteAsync(...) + // { + // ... + // Html.EditorFor(m => m.Price); + // ... + // } + // + // + // After: + // private static readonly Expression> __h0 = m => m.Price; + // public Task ExecuteAsync(...) + // { + // ... + // Html.EditorFor(__h0); + // ... + // } + // + var type = SemanticModel.GetTypeInfo(node); + if (type.ConvertedType.Name != typeof(Expression).Name && + type.ConvertedType.ContainingNamespace.Name != typeof(Expression).Namespace) + { + return node; + } + + if (!node.Parent.IsKind(SyntaxKind.Argument)) + { + return node; + } + + var parameter = node.Parameter; + if (IsValidForHoisting(parameter, node.Body)) + { + // Replace with a MemberAccess + var memberName = string.Format(FieldNameTemplate, Expressions.Count); + var memberAccess = PadMemberAccess(node, SyntaxFactory.IdentifierName(memberName)); + Expressions.Add(new KeyValuePair(node, memberAccess)); + return memberAccess; + } + + return node; + } + + private static IdentifierNameSyntax PadMemberAccess( + SimpleLambdaExpressionSyntax node, + IdentifierNameSyntax memberAccess) + { + // We want to make the new span + var originalSpan = node.GetLocation().GetMappedLineSpan(); + + // Start by collecting all the trivia 'inside' the expression - we need to tack that on the end, but + // if it ends with a newline, don't include that. + var innerTrivia = SyntaxFactory.TriviaList(node.DescendantTrivia(descendIntoChildren: n => true)); + if (innerTrivia.Count > 0 && innerTrivia[innerTrivia.Count - 1].IsKind(SyntaxKind.EndOfLineTrivia)) + { + innerTrivia = innerTrivia.RemoveAt(innerTrivia.Count - 1); + } + + memberAccess = memberAccess.WithTrailingTrivia(innerTrivia); + + // If everything is all on one line, then make sure the spans are the same, to compensate + // for the expression potentially being longer than the variable name. + var lineSpan = originalSpan.EndLinePosition.Line - originalSpan.StartLinePosition.Line; + if (lineSpan == 0) + { + var padding = node.Span.Length - memberAccess.FullSpan.Length; + var trailingTrivia = + SyntaxFactory.TriviaList(memberAccess.GetTrailingTrivia()) + .Add(SyntaxFactory.Whitespace(new string(' ', padding))) + .AddRange(node.GetTrailingTrivia()); + + return + memberAccess + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(trailingTrivia); + } + else + { + // If everything isn't on the same line, we need to pad out the last line. + var padding = + originalSpan.EndLinePosition.Character - + originalSpan.StartLinePosition.Character; + + var trailingTrivia = + SyntaxFactory.TriviaList(memberAccess.GetTrailingTrivia()) + .Add(SyntaxFactory.Whitespace(new string(' ', padding))) + .AddRange(node.GetTrailingTrivia()); + + return + memberAccess + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(trailingTrivia); + } + } + + private static bool IsValidForHoisting(ParameterSyntax parameter, CSharpSyntaxNode node) + { + if (node.IsKind(SyntaxKind.IdentifierName)) + { + var identifier = (IdentifierNameSyntax)node; + if (identifier.Identifier.Text == parameter.Identifier.Text) + { + return true; + } + } + else if (node.IsKind(SyntaxKind.SimpleMemberAccessExpression)) + { + var memberAccess = (MemberAccessExpressionSyntax)node; + var lhs = memberAccess.Expression; + return IsValidForHoisting(parameter, lhs); + } + + return false; + } + } +} diff --git a/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs b/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs index 623860788b..3efbedd709 100644 --- a/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs +++ b/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs @@ -82,18 +82,24 @@ namespace Microsoft.AspNet.Mvc.Razor.Compilation var assemblyName = Path.GetRandomFileName(); var compilationSettings = _compilerOptionsProvider.GetCompilationSettings(_environment); - var syntaxTree = SyntaxTreeGenerator.Generate(compilationContent, - assemblyName, - compilationSettings); + var syntaxTree = SyntaxTreeGenerator.Generate( + compilationContent, + assemblyName, + compilationSettings); + var references = _applicationReferences.Value; - var compilationOptions = compilationSettings.CompilationOptions - .WithOutputKind(OutputKind.DynamicallyLinkedLibrary); + var compilationOptions = compilationSettings + .CompilationOptions + .WithOutputKind(OutputKind.DynamicallyLinkedLibrary); - var compilation = CSharpCompilation.Create(assemblyName, - options: compilationOptions, - syntaxTrees: new[] { syntaxTree }, - references: references); + var compilation = CSharpCompilation.Create( + assemblyName, + options: compilationOptions, + syntaxTrees: new[] { syntaxTree }, + references: references); + + compilation = Rewrite(compilation); using (var ms = new MemoryStream()) { @@ -140,6 +146,21 @@ namespace Microsoft.AspNet.Mvc.Razor.Compilation } } + private CSharpCompilation Rewrite(CSharpCompilation compilation) + { + var rewrittenTrees = new List(); + foreach (var tree in compilation.SyntaxTrees) + { + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + var rewriter = new ExpressionRewriter(semanticModel); + + var rewrittenTree = tree.WithRootAndOptions(rewriter.Visit(tree.GetRoot()), tree.Options); + rewrittenTrees.Add(rewrittenTree); + } + + return compilation.RemoveAllSyntaxTrees().AddSyntaxTrees(rewrittenTrees); + } + // Internal for unit testing internal CompilationResult GetCompilationFailedResult( string relativePath, diff --git a/test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs b/test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs new file mode 100644 index 0000000000..c8e3e6918e --- /dev/null +++ b/test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs @@ -0,0 +1,491 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.Dnx.Compilation; +using Microsoft.Dnx.Compilation.CSharp; +using Microsoft.Dnx.Runtime; +using Microsoft.Dnx.Runtime.Infrastructure; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNet.Mvc.Razor.Compilation +{ + public class ExpressionRewriterTest + { + [Fact] + public void ExpressionRewriter_CanRewriteExpression_IdentityExpression() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + var argument = Assert.IsType(Assert.Single(arguments.Arguments).Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + [Fact] + public void ExpressionRewriter_CanRewriteExpression_MemberAccessExpression() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.Name); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x.Name", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + var argument = Assert.IsType(Assert.Single(arguments.Arguments).Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + [Fact] + public void ExpressionRewriter_CanRewriteExpression_ChainedMemberAccessExpression() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.Name.Length); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x.Name.Length", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + var argument = Assert.IsType(Assert.Single(arguments.Arguments).Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + [Fact] + public void ExpressionRewriter_CannotRewriteExpression_MethodCall() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.GetHashCode()); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + Assert.Empty(FindFields(result)); + } + + [Fact] + public void ExpressionRewriter_CannotRewriteExpression_NonArgument() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + Expression> expr = x => x.GetHashCode(); + CalledWithExpression(expr); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + Assert.Empty(FindFields(result)); + } + + [Fact] + public void ExpressionRewriter_CannotRewriteExpression_NestedClass() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + private class Nested + { + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + Expression> expr = x => x.GetHashCode(); + CalledWithExpression(expr); + } + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + Assert.Empty(FindFields(result)); + } + + [Fact] + public void ExpressionRewriter_CanRewriteExpression_AdditionalArguments() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(int x, Expression> expression, string name) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(5, x => x, ""Billy""); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + Assert.Equal(3, arguments.Arguments.Count); + var argument = Assert.IsType(arguments.Arguments[1].Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + // When we rewrite the expression, we want to maintain the original span as much as possible. + [Fact] + public void ExpressionRewriter_CanRewriteExpression_SimpleFormatting() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.Name.Length); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + + var originalArguments = FindArguments(tree.GetRoot()); + var originalSpan = originalArguments.GetLocation().GetMappedLineSpan(); + + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var arguments = FindArguments(result); + Assert.Equal(originalSpan, arguments.GetLocation().GetMappedLineSpan()); + } + + // When we rewrite the expression, we want to maintain the original span as much as possible. + [Fact] + public void ExpressionRewriter_CanRewriteExpression_ComplexFormatting() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(int z, Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression( + 17, + x => + x.Name. + Length + ); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + + var originalArguments = FindArguments(tree.GetRoot()); + var originalSpan = originalArguments.GetLocation().GetMappedLineSpan(); + + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var arguments = FindArguments(result); + Assert.Equal(originalSpan, arguments.GetLocation().GetMappedLineSpan()); + } + + public ArgumentListSyntax FindArguments(SyntaxNode node) + { + return node + .DescendantNodes(n => true) + .Where(n => n.IsKind(SyntaxKind.ArgumentList)) + .Cast() + .Single(); + } + + public IEnumerable FindFields(SyntaxNode node) + { + return node + .DescendantNodes(n => true) + .Where(n => n.IsKind(SyntaxKind.FieldDeclaration)) + .Cast(); + } + + private CSharpCompilation Compile(SyntaxTree tree) + { + var compilation = CSharpCompilation.Create( + "Test.Assembly", + new[] { tree }, + GetReferences()); + + var diagnostics = compilation.GetDiagnostics(); + if (diagnostics.Length > 0) + { + Assert.False(true, string.Join(Environment.NewLine, diagnostics)); + } + + return compilation; + } + + private IEnumerable GetReferences() + { + var services = CallContextServiceLocator.Locator.ServiceProvider; + var libraryExporter = services.GetRequiredService(); + var environment = services.GetRequiredService(); + + var references = new List(); + + var libraryExports = libraryExporter.GetAllExports(environment.ApplicationName); + foreach (var export in libraryExports.MetadataReferences) + { + references.Add(export.ConvertMetadataReference(MetadataReferenceExtensions.CreateAssemblyMetadata)); + } + + return references; + } + } +}