// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using Microsoft.AspNetCore.Mvc.ActionConstraints;
using Microsoft.Extensions.Primitives;

namespace Microsoft.AspNetCore.Mvc.Internal
{
    /// <summary>
    /// An <see cref="IActionConstraint"/> that accepts a request only when its HTTP method matches
    /// one of a configured set of methods (compared case-insensitively). For a CORS preflight
    /// request (an <c>OPTIONS</c> request carrying an <c>Origin</c> header), a non-empty
    /// <c>Access-Control-Request-Method</c> header value is matched instead of the request method.
    /// </summary>
    public class HttpMethodActionConstraint : IActionConstraint
    {
        /// <summary>
        /// The <see cref="Order"/> reported by every <see cref="HttpMethodActionConstraint"/>.
        /// </summary>
        public static readonly int HttpMethodConstraintOrder = 100;

        // Fixed header/method names: compile-time constants need no per-instance storage.
        private const string OriginHeader = "Origin";
        private const string AccessControlRequestMethod = "Access-Control-Request-Method";
        private const string PreflightHttpMethod = "OPTIONS";

        private readonly IReadOnlyList<string> _httpMethods;

        /// <summary>
        /// Creates a constraint that accepts the given HTTP methods.
        /// </summary>
        /// <param name="httpMethods">
        /// The accepted HTTP methods. An empty collection means any method will be accepted.
        /// </param>
        /// <exception cref="ArgumentNullException"><paramref name="httpMethods"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException">
        /// An element of <paramref name="httpMethods"/> is <c>null</c> or empty.
        /// </exception>
        public HttpMethodActionConstraint(IEnumerable<string> httpMethods)
        {
            if (httpMethods == null)
            {
                throw new ArgumentNullException(nameof(httpMethods));
            }

            var methods = new List<string>();
            foreach (var method in httpMethods)
            {
                if (string.IsNullOrEmpty(method))
                {
                    throw new ArgumentException("httpMethod cannot be null or empty", nameof(httpMethods));
                }

                methods.Add(method);
            }

            // Copy into a private list so later changes to the caller's collection cannot affect matching.
            _httpMethods = new ReadOnlyCollection<string>(methods);
        }

        /// <summary>
        /// The HTTP methods this constraint accepts; empty means any method is accepted.
        /// </summary>
        public IEnumerable<string> HttpMethods => _httpMethods;

        /// <summary>
        /// The constraint order; always <see cref="HttpMethodConstraintOrder"/>.
        /// </summary>
        public int Order => HttpMethodConstraintOrder;

        /// <summary>
        /// Determines whether the current request's HTTP method is one of the accepted methods.
        /// </summary>
        /// <param name="context">The constraint context providing access to the current request.</param>
        /// <returns>
        /// <c>true</c> if no methods were configured, or if the effective request method matches one of
        /// the configured methods ignoring case; otherwise <c>false</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
        public bool Accept(ActionConstraintContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }

            if (_httpMethods.Count == 0)
            {
                return true;
            }

            var request = context.RouteContext.HttpContext.Request;
            var method = request.Method;

            // Perf: check the method string first so the Headers collection is only read for OPTIONS requests.
            if (string.Equals(method, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) &&
                request.Headers.ContainsKey(OriginHeader))
            {
                // CORS preflight: match against the method the client intends to use, not OPTIONS itself.
                var accessControlRequestMethod = request.Headers[AccessControlRequestMethod];
                if (!StringValues.IsNullOrEmpty(accessControlRequestMethod))
                {
                    method = accessControlRequestMethod;
                }
            }

            for (var i = 0; i < _httpMethods.Count; i++)
            {
                var supportedMethod = _httpMethods[i];
                if (string.Equals(supportedMethod, method, StringComparison.OrdinalIgnoreCase))
                {
                    return true;
                }
            }

            return false;
        }
    }
}