// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; using System.Linq; using Microsoft.CodeAnalysis; namespace Microsoft.AspNetCore.Mvc.Api.Analyzers { internal static class SymbolApiResponseMetadataProvider { private const string StatusCodeProperty = "StatusCode"; private const string StatusCodeConstructorParameter = "statusCode"; private static readonly IList DefaultResponseMetadatas = new[] { DeclaredApiResponseMetadata.ImplicitResponse, }; public static IList GetDeclaredResponseMetadata( in ApiControllerSymbolCache symbolCache, IMethodSymbol method) { var metadataItems = GetResponseMetadataFromMethodAttributes(symbolCache, method); if (metadataItems.Count != 0) { return metadataItems; } var conventionTypeAttributes = GetConventionTypes(symbolCache, method); metadataItems = GetResponseMetadataFromConventions(symbolCache, method, conventionTypeAttributes); if (metadataItems.Count == 0) { // If no metadata can be gleaned either through explicit attributes on the method or via a convention, // declare an implicit 200 status code. metadataItems = DefaultResponseMetadatas; } return metadataItems; } public static ITypeSymbol GetErrorResponseType( in ApiControllerSymbolCache symbolCache, IMethodSymbol method) { var errorTypeAttribute = method.GetAttributes(symbolCache.ProducesErrorResponseTypeAttribute).FirstOrDefault() ?? method.ContainingType.GetAttributes(symbolCache.ProducesErrorResponseTypeAttribute).FirstOrDefault() ?? method.ContainingAssembly.GetAttributes(symbolCache.ProducesErrorResponseTypeAttribute).FirstOrDefault(); ITypeSymbol errorType = symbolCache.ProblemDetails; if (errorTypeAttribute != null && errorTypeAttribute.ConstructorArguments.Length == 1 && errorTypeAttribute.ConstructorArguments[0].Kind == TypedConstantKind.Type && errorTypeAttribute.ConstructorArguments[0].Value is ITypeSymbol typeSymbol) { errorType = typeSymbol; } return errorType; } private static IList GetResponseMetadataFromConventions( in ApiControllerSymbolCache symbolCache, IMethodSymbol method, IReadOnlyList conventionTypes) { var conventionMethod = GetMethodFromConventionMethodAttribute(symbolCache, method); if (conventionMethod == null) { conventionMethod = MatchConventionMethod(symbolCache, method, conventionTypes); } if (conventionMethod != null) { return GetResponseMetadataFromMethodAttributes(symbolCache, conventionMethod); } return Array.Empty(); } private static IMethodSymbol GetMethodFromConventionMethodAttribute(in ApiControllerSymbolCache symbolCache, IMethodSymbol method) { var attribute = method.GetAttributes(symbolCache.ApiConventionMethodAttribute, inherit: true) .FirstOrDefault(); if (attribute == null) { return null; } if (attribute.ConstructorArguments.Length != 2) { return null; } if (attribute.ConstructorArguments[0].Kind != TypedConstantKind.Type || !(attribute.ConstructorArguments[0].Value is ITypeSymbol conventionType)) { return null; } if (attribute.ConstructorArguments[1].Kind != TypedConstantKind.Primitive || !(attribute.ConstructorArguments[1].Value is string conventionMethodName)) { return null; } var conventionMethod = conventionType.GetMembers(conventionMethodName) .FirstOrDefault(m => m.Kind == SymbolKind.Method && m.IsStatic && m.DeclaredAccessibility == Accessibility.Public); return (IMethodSymbol)conventionMethod; } private static IMethodSymbol MatchConventionMethod( in ApiControllerSymbolCache symbolCache, IMethodSymbol method, IReadOnlyList conventionTypes) { foreach (var conventionType in conventionTypes) { foreach (var conventionMethod in conventionType.GetMembers().OfType()) { if (!conventionMethod.IsStatic || conventionMethod.DeclaredAccessibility != Accessibility.Public) { continue; } if (SymbolApiConventionMatcher.IsMatch(symbolCache, method, conventionMethod)) { return conventionMethod; } } } return null; } private static IList GetResponseMetadataFromMethodAttributes(in ApiControllerSymbolCache symbolCache, IMethodSymbol methodSymbol) { var metadataItems = new List(); var responseMetadataAttributes = methodSymbol.GetAttributes(symbolCache.ProducesResponseTypeAttribute, inherit: true); foreach (var attribute in responseMetadataAttributes) { var statusCode = GetStatusCode(attribute); var metadata = DeclaredApiResponseMetadata.ForProducesResponseType(statusCode, attribute, attributeSource: methodSymbol); metadataItems.Add(metadata); } var producesDefaultResponse = methodSymbol.GetAttributes(symbolCache.ProducesDefaultResponseTypeAttribute, inherit: true).FirstOrDefault(); if (producesDefaultResponse != null) { metadataItems.Add(DeclaredApiResponseMetadata.ForProducesDefaultResponse(producesDefaultResponse, methodSymbol)); } return metadataItems; } internal static IReadOnlyList GetConventionTypes(in ApiControllerSymbolCache symbolCache, IMethodSymbol method) { var attributes = method.ContainingType.GetAttributes(symbolCache.ApiConventionTypeAttribute).ToArray(); if (attributes.Length == 0) { attributes = method.ContainingAssembly.GetAttributes(symbolCache.ApiConventionTypeAttribute).ToArray(); } var conventionTypes = new List(); for (var i = 0; i < attributes.Length; i++) { var attribute = attributes[i]; if (attribute.ConstructorArguments.Length != 1 || attribute.ConstructorArguments[0].Kind != TypedConstantKind.Type || !(attribute.ConstructorArguments[0].Value is ITypeSymbol conventionType)) { continue; } conventionTypes.Add(conventionType); } return conventionTypes; } internal static int GetStatusCode(AttributeData attribute) { const int DefaultStatusCode = 200; for (var i = 0; i < attribute.NamedArguments.Length; i++) { var namedArgument = attribute.NamedArguments[i]; var namedArgumentValue = namedArgument.Value; if (string.Equals(namedArgument.Key, StatusCodeProperty, StringComparison.Ordinal) && namedArgumentValue.Kind == TypedConstantKind.Primitive && (namedArgumentValue.Type.SpecialType & SpecialType.System_Int32) == SpecialType.System_Int32 && namedArgumentValue.Value is int statusCode) { return statusCode; } } if (attribute.AttributeConstructor == null) { return DefaultStatusCode; } var constructorParameters = attribute.AttributeConstructor.Parameters; for (var i = 0; i < constructorParameters.Length; i++) { var parameter = constructorParameters[i]; if (string.Equals(parameter.Name, StatusCodeConstructorParameter, StringComparison.Ordinal) && (parameter.Type.SpecialType & SpecialType.System_Int32) == SpecialType.System_Int32) { if (attribute.ConstructorArguments.Length < i) { return DefaultStatusCode; } var argument = attribute.ConstructorArguments[i]; if (argument.Kind == TypedConstantKind.Primitive && argument.Value is int statusCode) { return statusCode; } } } return DefaultStatusCode; } } }