diff --git a/src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs b/src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs new file mode 100644 index 0000000000..c4ae27161a --- /dev/null +++ b/src/Microsoft.AspNet.Mvc.Razor/Compilation/ExpressionRewriter.cs @@ -0,0 +1,200 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Diagnostics; +using System.Linq.Expressions; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; + +namespace Microsoft.AspNet.Mvc.Razor.Compilation +{ + /// + /// An expression rewriter which can hoist a simple expression lambda into a private field. + /// + public class ExpressionRewriter : CSharpSyntaxRewriter + { + private static readonly string FieldNameTemplate = "__h{0}"; + + public ExpressionRewriter(SemanticModel semanticModel) + { + SemanticModel = semanticModel; + + Expressions = new List>(); + } + + // We only want to rewrite expressions for the top-level class definition. + private bool IsInsideClass { get; set; } + + private SemanticModel SemanticModel { get; } + + private List> Expressions { get; } + + public override SyntaxNode VisitClassDeclaration(ClassDeclarationSyntax node) + { + if (IsInsideClass) + { + // Avoid recursing into nested classes. + return node; + } + + Expressions.Clear(); + + IsInsideClass = true; + + // Call base first to visit all the children and populate Expressions. + var classDeclaration = (ClassDeclarationSyntax)base.VisitClassDeclaration(node); + + IsInsideClass = false; + + var memberDeclarations = new List(); + foreach (var kvp in Expressions) + { + var expression = kvp.Key; + var memberName = kvp.Value.GetFirstToken(); + + var expressionType = SemanticModel.GetTypeInfo(expression).ConvertedType; + var declaration = SyntaxFactory.FieldDeclaration( + SyntaxFactory.List(), + SyntaxFactory.TokenList( + SyntaxFactory.Token(SyntaxKind.PrivateKeyword), + SyntaxFactory.Token(SyntaxKind.StaticKeyword), + SyntaxFactory.Token(SyntaxKind.ReadOnlyKeyword)), + SyntaxFactory.VariableDeclaration( + SyntaxFactory.ParseTypeName(expressionType.ToDisplayString( + SymbolDisplayFormat.FullyQualifiedFormat)), + SyntaxFactory.SingletonSeparatedList( + SyntaxFactory.VariableDeclarator( + memberName, + SyntaxFactory.BracketedArgumentList(), + SyntaxFactory.EqualsValueClause(expression))))) + .WithTriviaFrom(expression); + memberDeclarations.Add(declaration); + } + + return classDeclaration.AddMembers(memberDeclarations.ToArray()); + } + + public override SyntaxNode VisitSimpleLambdaExpression(SimpleLambdaExpressionSyntax node) + { + Debug.Assert(IsInsideClass); + + // If this lambda is an Expression and is suitable for hoisting, we rewrite this into a field access. + // + // Before: + // public Task ExecuteAsync(...) + // { + // ... + // Html.EditorFor(m => m.Price); + // ... + // } + // + // + // After: + // private static readonly Expression> __h0 = m => m.Price; + // public Task ExecuteAsync(...) + // { + // ... + // Html.EditorFor(__h0); + // ... + // } + // + var type = SemanticModel.GetTypeInfo(node); + if (type.ConvertedType.Name != typeof(Expression).Name && + type.ConvertedType.ContainingNamespace.Name != typeof(Expression).Namespace) + { + return node; + } + + if (!node.Parent.IsKind(SyntaxKind.Argument)) + { + return node; + } + + var parameter = node.Parameter; + if (IsValidForHoisting(parameter, node.Body)) + { + // Replace with a MemberAccess + var memberName = string.Format(FieldNameTemplate, Expressions.Count); + var memberAccess = PadMemberAccess(node, SyntaxFactory.IdentifierName(memberName)); + Expressions.Add(new KeyValuePair(node, memberAccess)); + return memberAccess; + } + + return node; + } + + private static IdentifierNameSyntax PadMemberAccess( + SimpleLambdaExpressionSyntax node, + IdentifierNameSyntax memberAccess) + { + // We want to make the new span + var originalSpan = node.GetLocation().GetMappedLineSpan(); + + // Start by collecting all the trivia 'inside' the expression - we need to tack that on the end, but + // if it ends with a newline, don't include that. + var innerTrivia = SyntaxFactory.TriviaList(node.DescendantTrivia(descendIntoChildren: n => true)); + if (innerTrivia.Count > 0 && innerTrivia[innerTrivia.Count - 1].IsKind(SyntaxKind.EndOfLineTrivia)) + { + innerTrivia = innerTrivia.RemoveAt(innerTrivia.Count - 1); + } + + memberAccess = memberAccess.WithTrailingTrivia(innerTrivia); + + // If everything is all on one line, then make sure the spans are the same, to compensate + // for the expression potentially being longer than the variable name. + var lineSpan = originalSpan.EndLinePosition.Line - originalSpan.StartLinePosition.Line; + if (lineSpan == 0) + { + var padding = node.Span.Length - memberAccess.FullSpan.Length; + var trailingTrivia = + SyntaxFactory.TriviaList(memberAccess.GetTrailingTrivia()) + .Add(SyntaxFactory.Whitespace(new string(' ', padding))) + .AddRange(node.GetTrailingTrivia()); + + return + memberAccess + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(trailingTrivia); + } + else + { + // If everything isn't on the same line, we need to pad out the last line. + var padding = + originalSpan.EndLinePosition.Character - + originalSpan.StartLinePosition.Character; + + var trailingTrivia = + SyntaxFactory.TriviaList(memberAccess.GetTrailingTrivia()) + .Add(SyntaxFactory.Whitespace(new string(' ', padding))) + .AddRange(node.GetTrailingTrivia()); + + return + memberAccess + .WithLeadingTrivia(node.GetLeadingTrivia()) + .WithTrailingTrivia(trailingTrivia); + } + } + + private static bool IsValidForHoisting(ParameterSyntax parameter, CSharpSyntaxNode node) + { + if (node.IsKind(SyntaxKind.IdentifierName)) + { + var identifier = (IdentifierNameSyntax)node; + if (identifier.Identifier.Text == parameter.Identifier.Text) + { + return true; + } + } + else if (node.IsKind(SyntaxKind.SimpleMemberAccessExpression)) + { + var memberAccess = (MemberAccessExpressionSyntax)node; + var lhs = memberAccess.Expression; + return IsValidForHoisting(parameter, lhs); + } + + return false; + } + } +} diff --git a/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs b/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs index 623860788b..3efbedd709 100644 --- a/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs +++ b/src/Microsoft.AspNet.Mvc.Razor/Compilation/RoslynCompilationService.cs @@ -82,18 +82,24 @@ namespace Microsoft.AspNet.Mvc.Razor.Compilation var assemblyName = Path.GetRandomFileName(); var compilationSettings = _compilerOptionsProvider.GetCompilationSettings(_environment); - var syntaxTree = SyntaxTreeGenerator.Generate(compilationContent, - assemblyName, - compilationSettings); + var syntaxTree = SyntaxTreeGenerator.Generate( + compilationContent, + assemblyName, + compilationSettings); + var references = _applicationReferences.Value; - var compilationOptions = compilationSettings.CompilationOptions - .WithOutputKind(OutputKind.DynamicallyLinkedLibrary); + var compilationOptions = compilationSettings + .CompilationOptions + .WithOutputKind(OutputKind.DynamicallyLinkedLibrary); - var compilation = CSharpCompilation.Create(assemblyName, - options: compilationOptions, - syntaxTrees: new[] { syntaxTree }, - references: references); + var compilation = CSharpCompilation.Create( + assemblyName, + options: compilationOptions, + syntaxTrees: new[] { syntaxTree }, + references: references); + + compilation = Rewrite(compilation); using (var ms = new MemoryStream()) { @@ -140,6 +146,21 @@ namespace Microsoft.AspNet.Mvc.Razor.Compilation } } + private CSharpCompilation Rewrite(CSharpCompilation compilation) + { + var rewrittenTrees = new List(); + foreach (var tree in compilation.SyntaxTrees) + { + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + var rewriter = new ExpressionRewriter(semanticModel); + + var rewrittenTree = tree.WithRootAndOptions(rewriter.Visit(tree.GetRoot()), tree.Options); + rewrittenTrees.Add(rewrittenTree); + } + + return compilation.RemoveAllSyntaxTrees().AddSyntaxTrees(rewrittenTrees); + } + // Internal for unit testing internal CompilationResult GetCompilationFailedResult( string relativePath, diff --git a/test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs b/test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs new file mode 100644 index 0000000000..c8e3e6918e --- /dev/null +++ b/test/Microsoft.AspNet.Mvc.Razor.Test/Compilation/ExpressionRewriterTest.cs @@ -0,0 +1,491 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.Dnx.Compilation; +using Microsoft.Dnx.Compilation.CSharp; +using Microsoft.Dnx.Runtime; +using Microsoft.Dnx.Runtime.Infrastructure; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNet.Mvc.Razor.Compilation +{ + public class ExpressionRewriterTest + { + [Fact] + public void ExpressionRewriter_CanRewriteExpression_IdentityExpression() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + var argument = Assert.IsType(Assert.Single(arguments.Arguments).Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + [Fact] + public void ExpressionRewriter_CanRewriteExpression_MemberAccessExpression() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.Name); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x.Name", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + var argument = Assert.IsType(Assert.Single(arguments.Arguments).Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + [Fact] + public void ExpressionRewriter_CanRewriteExpression_ChainedMemberAccessExpression() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.Name.Length); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x.Name.Length", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + var argument = Assert.IsType(Assert.Single(arguments.Arguments).Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + [Fact] + public void ExpressionRewriter_CannotRewriteExpression_MethodCall() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.GetHashCode()); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + Assert.Empty(FindFields(result)); + } + + [Fact] + public void ExpressionRewriter_CannotRewriteExpression_NonArgument() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + Expression> expr = x => x.GetHashCode(); + CalledWithExpression(expr); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + Assert.Empty(FindFields(result)); + } + + [Fact] + public void ExpressionRewriter_CannotRewriteExpression_NestedClass() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + private class Nested + { + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + Expression> expr = x => x.GetHashCode(); + CalledWithExpression(expr); + } + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + Assert.Empty(FindFields(result)); + } + + [Fact] + public void ExpressionRewriter_CanRewriteExpression_AdditionalArguments() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(int x, Expression> expression, string name) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(5, x => x, ""Billy""); + } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var fields = FindFields(result); + + var field = Assert.Single(fields); + Assert.Collection( + field.Modifiers, + m => Assert.Equal("private", m.ToString()), + m => Assert.Equal("static", m.ToString()), + m => Assert.Equal("readonly", m.ToString())); + + var declaration = field.Declaration; + Assert.Equal( + "global::System.Linq.Expressions.Expression>", + declaration.Type.ToString()); + + var variable = Assert.Single(declaration.Variables); + Assert.Equal("__h0", variable.Identifier.ToString()); + Assert.Equal("x => x", variable.Initializer.Value.ToString()); + + var arguments = FindArguments(result); + Assert.Equal(3, arguments.Arguments.Count); + var argument = Assert.IsType(arguments.Arguments[1].Expression); + Assert.Equal("__h0", argument.Identifier.ToString()); + } + + // When we rewrite the expression, we want to maintain the original span as much as possible. + [Fact] + public void ExpressionRewriter_CanRewriteExpression_SimpleFormatting() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression(x => x.Name.Length); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + + var originalArguments = FindArguments(tree.GetRoot()); + var originalSpan = originalArguments.GetLocation().GetMappedLineSpan(); + + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var arguments = FindArguments(result); + Assert.Equal(originalSpan, arguments.GetLocation().GetMappedLineSpan()); + } + + // When we rewrite the expression, we want to maintain the original span as much as possible. + [Fact] + public void ExpressionRewriter_CanRewriteExpression_ComplexFormatting() + { + // Arrange + var source = @" +using System; +using System.Linq.Expressions; +public class Program +{ + public static void CalledWithExpression(int z, Expression> expression) + { + } + + public static void Main(string[] args) + { + CalledWithExpression( + 17, + x => + x.Name. + Length + ); + } +} + +public class Person +{ + public string Name { get; set; } +} +"; + + var tree = CSharpSyntaxTree.ParseText(source); + + var originalArguments = FindArguments(tree.GetRoot()); + var originalSpan = originalArguments.GetLocation().GetMappedLineSpan(); + + var compilation = Compile(tree); + var semanticModel = compilation.GetSemanticModel(tree, ignoreAccessibility: true); + + var rewriter = new ExpressionRewriter(semanticModel); + + // Act + var result = rewriter.Visit(tree.GetRoot()); + + // Assert + var arguments = FindArguments(result); + Assert.Equal(originalSpan, arguments.GetLocation().GetMappedLineSpan()); + } + + public ArgumentListSyntax FindArguments(SyntaxNode node) + { + return node + .DescendantNodes(n => true) + .Where(n => n.IsKind(SyntaxKind.ArgumentList)) + .Cast() + .Single(); + } + + public IEnumerable FindFields(SyntaxNode node) + { + return node + .DescendantNodes(n => true) + .Where(n => n.IsKind(SyntaxKind.FieldDeclaration)) + .Cast(); + } + + private CSharpCompilation Compile(SyntaxTree tree) + { + var compilation = CSharpCompilation.Create( + "Test.Assembly", + new[] { tree }, + GetReferences()); + + var diagnostics = compilation.GetDiagnostics(); + if (diagnostics.Length > 0) + { + Assert.False(true, string.Join(Environment.NewLine, diagnostics)); + } + + return compilation; + } + + private IEnumerable GetReferences() + { + var services = CallContextServiceLocator.Locator.ServiceProvider; + var libraryExporter = services.GetRequiredService(); + var environment = services.GetRequiredService(); + + var references = new List(); + + var libraryExports = libraryExporter.GetAllExports(environment.ApplicationName); + foreach (var export in libraryExports.MetadataReferences) + { + references.Add(export.ConvertMetadataReference(MetadataReferenceExtensions.CreateAssemblyMetadata)); + } + + return references; + } + } +}