Refactor CORS support out of MVC Core

This commit is contained in:
Javier Calvarro Nelson 2017-06-28 14:24:49 -07:00
parent 2ef26486dd
commit f2a8c1cea7
8 changed files with 475 additions and 21 deletions

View File

@ -15,10 +15,6 @@ namespace Microsoft.AspNetCore.Mvc.Internal
private readonly IReadOnlyList<string> _httpMethods;
private readonly string OriginHeader = "Origin";
private readonly string AccessControlRequestMethod = "Access-Control-Request-Method";
private readonly string PreflightHttpMethod = "OPTIONS";
// Empty collection means any method will be accepted.
public HttpMethodActionConstraint(IEnumerable<string> httpMethods)
{
@ -46,7 +42,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal
public int Order => HttpMethodConstraintOrder;
public bool Accept(ActionConstraintContext context)
public virtual bool Accept(ActionConstraintContext context)
{
if (context == null)
{
@ -61,18 +57,6 @@ namespace Microsoft.AspNetCore.Mvc.Internal
var request = context.RouteContext.HttpContext.Request;
var method = request.Method;
// Perf: Check http method before accessing the Headers collection.
if (string.Equals(method, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) &&
request.Headers.ContainsKey(OriginHeader))
{
// Update the http method if it is preflight request.
var accessControlRequestMethod = request.Headers[AccessControlRequestMethod];
if (!StringValues.IsNullOrEmpty(accessControlRequestMethod))
{
method = accessControlRequestMethod;
}
}
for (var i = 0; i < _httpMethods.Count; i++)
{
var supportedMethod = _httpMethods[i];

View File

@ -5,6 +5,8 @@ using System;
using System.Linq;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Mvc.ApplicationModels;
using Microsoft.AspNetCore.Mvc.Internal;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
@ -28,6 +30,9 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
throw new ArgumentNullException(nameof(context));
}
var isCorsEnabledGlobally = context.Result.Filters.OfType<ICorsAuthorizationFilter>().Any() ||
context.Result.Filters.OfType<CorsAuthorizationFilterFactory>().Any();
foreach (var controllerModel in context.Result.Controllers)
{
var enableCors = controllerModel.Attributes.OfType<IEnableCorsAttribute>().FirstOrDefault();
@ -42,6 +47,8 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
controllerModel.Filters.Add(new DisableCorsAuthorizationFilter());
}
var corsOnController = enableCors != null || disableCors != null || controllerModel.Filters.OfType<ICorsAuthorizationFilter>().Any();
foreach (var actionModel in controllerModel.Actions)
{
enableCors = actionModel.Attributes.OfType<IEnableCorsAttribute>().FirstOrDefault();
@ -55,6 +62,28 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
actionModel.Filters.Add(new DisableCorsAuthorizationFilter());
}
var corsOnAction = enableCors != null || disableCors != null || actionModel.Filters.OfType<ICorsAuthorizationFilter>().Any();
if (isCorsEnabledGlobally || corsOnController || corsOnAction)
{
UpdateHttpMethodActionConstraint(actionModel);
}
}
}
}
private static void UpdateHttpMethodActionConstraint(ActionModel actionModel)
{
for (var i = 0; i < actionModel.Selectors.Count; i++)
{
var selectorModel = actionModel.Selectors[i];
for (var j = 0; j < selectorModel.ActionConstraints.Count; j++)
{
if (selectorModel.ActionConstraints[j] is HttpMethodActionConstraint httpConstraint)
{
selectorModel.ActionConstraints[j] = new CorsHttpMethodActionConstraint(httpConstraint);
}
}
}
}

View File

@ -0,0 +1,58 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using Microsoft.AspNetCore.Mvc.ActionConstraints;
using Microsoft.Extensions.Primitives;
using Microsoft.AspNetCore.Mvc.Internal;
namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
public class CorsHttpMethodActionConstraint : HttpMethodActionConstraint
{
private readonly string OriginHeader = "Origin";
private readonly string AccessControlRequestMethod = "Access-Control-Request-Method";
private readonly string PreflightHttpMethod = "OPTIONS";
public CorsHttpMethodActionConstraint(HttpMethodActionConstraint constraint)
: base(constraint.HttpMethods)
{
}
public override bool Accept(ActionConstraintContext context)
{
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}
var methods = (ReadOnlyCollection<string>)HttpMethods;
if (methods.Count == 0)
{
return true;
}
var request = context.RouteContext.HttpContext.Request;
if (request.Headers.ContainsKey(OriginHeader) &&
string.Equals(request.Method, PreflightHttpMethod, StringComparison.OrdinalIgnoreCase) &&
request.Headers.TryGetValue(AccessControlRequestMethod, out var accessControlRequestMethod) &&
!StringValues.IsNullOrEmpty(accessControlRequestMethod))
{
for (var i = 0; i < methods.Count; i++)
{
var supportedMethod = methods[i];
if (string.Equals(supportedMethod, accessControlRequestMethod, StringComparison.OrdinalIgnoreCase))
{
return true;
}
}
return false;
}
return base.Accept(context);
}
}
}

View File

@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal
[Theory]
[MemberData(nameof(AcceptCaseInsensitiveData))]
public void HttpMethodActionConstraint_Accept_Preflight_CaseInsensitive(IEnumerable<string> httpMethods, string accessControlMethod)
public void HttpMethodActionConstraint_IgnoresPreflightRequests(IEnumerable<string> httpMethods, string accessControlMethod)
{
// Arrange
var constraint = new HttpMethodActionConstraint(httpMethods);
@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal
var result = constraint.Accept(context);
// Assert
Assert.True(result, "Request should have been accepted.");
Assert.False(result, "Request should have been rejected.");
}
[Theory]

View File

@ -1,18 +1,24 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Threading.Tasks;
using System.Reflection;
using Microsoft.AspNetCore.Cors;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Mvc.ApplicationModels;
using Microsoft.AspNetCore.Mvc.Cors.Internal;
using Microsoft.AspNetCore.Mvc.Filters;
using Microsoft.AspNetCore.Mvc.Internal;
using Microsoft.Extensions.Options;
using Moq;
using Xunit;
namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
public class CorsApplicationModelProviderTest
{
[Fact]
public void CreateControllerModel_EnableCorsAttributeAddsCorsAuthorizationFilterFactory()
{
@ -29,6 +35,10 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
// Assert
var model = Assert.Single(context.Result.Controllers);
Assert.Single(model.Filters, f => f is CorsAuthorizationFilterFactory);
var action = Assert.Single(model.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
@ -47,6 +57,31 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
// Assert
var model = Assert.Single(context.Result.Controllers);
Assert.Single(model.Filters, f => f is DisableCorsAuthorizationFilter);
var action = Assert.Single(model.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
public void CreateControllerModel_CustomCorsFilter_ReplacesHttpConstraints()
{
// Arrange
var corsProvider = new CorsApplicationModelProvider();
var defaultProvider = new DefaultApplicationModelProvider(new TestOptionsManager<MvcOptions>());
var context = new ApplicationModelProviderContext(new[] { typeof(CustomCorsFilterController).GetTypeInfo() });
defaultProvider.OnProvidersExecuting(context);
// Act
corsProvider.OnProvidersExecuting(context);
// Assert
var controller = Assert.Single(context.Result.Controllers);
var action = Assert.Single(controller.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
@ -66,6 +101,9 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
var controller = Assert.Single(context.Result.Controllers);
var action = Assert.Single(controller.Actions);
Assert.Single(action.Filters, f => f is CorsAuthorizationFilterFactory);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
@ -85,19 +123,133 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
var controller = Assert.Single(context.Result.Controllers);
var action = Assert.Single(controller.Actions);
Assert.True(action.Filters.Any(f => f is DisableCorsAuthorizationFilter));
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
public void BuildActionModel_CustomCorsAuthorizationFilterOnAction_ReplacesHttpConstraints()
{
// Arrange
var corsProvider = new CorsApplicationModelProvider();
var defaultProvider = new DefaultApplicationModelProvider(new TestOptionsManager<MvcOptions>());
var context = new ApplicationModelProviderContext(new[] { typeof(CustomCorsFilterOnActionController).GetTypeInfo() });
defaultProvider.OnProvidersExecuting(context);
// Act
corsProvider.OnProvidersExecuting(context);
// Assert
var controller = Assert.Single(context.Result.Controllers);
var action = Assert.Single(controller.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
public void CreateControllerModel_EnableCorsGloballyReplacesHttpMethodConstraints()
{
// Arrange
var corsProvider = new CorsApplicationModelProvider();
var defaultProvider = new DefaultApplicationModelProvider(new TestOptionsManager<MvcOptions>());
var context = new ApplicationModelProviderContext(new[] { typeof(RegularController).GetTypeInfo() });
context.Result.Filters.Add(new CorsAuthorizationFilter(Mock.Of<ICorsService>(), Mock.Of<ICorsPolicyProvider>()));
defaultProvider.OnProvidersExecuting(context);
// Act
corsProvider.OnProvidersExecuting(context);
// Assert
var model = Assert.Single(context.Result.Controllers);
var action = Assert.Single(model.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
public void CreateControllerModel_DisableCorsGloballyReplacesHttpMethodConstraints()
{
// Arrange
var corsProvider = new CorsApplicationModelProvider();
var defaultProvider = new DefaultApplicationModelProvider(new TestOptionsManager<MvcOptions>());
var context = new ApplicationModelProviderContext(new[] { typeof(RegularController).GetTypeInfo() });
context.Result.Filters.Add(new DisableCorsAuthorizationFilter());
defaultProvider.OnProvidersExecuting(context);
// Act
corsProvider.OnProvidersExecuting(context);
// Assert
var model = Assert.Single(context.Result.Controllers);
var action = Assert.Single(model.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
public void CreateControllerModel_CustomCorsFilterGloballyReplacesHttpMethodConstraints()
{
// Arrange
var corsProvider = new CorsApplicationModelProvider();
var defaultProvider = new DefaultApplicationModelProvider(new TestOptionsManager<MvcOptions>());
var context = new ApplicationModelProviderContext(new[] { typeof(RegularController).GetTypeInfo() });
context.Result.Filters.Add(new CustomCorsFilterAttribute());
defaultProvider.OnProvidersExecuting(context);
// Act
corsProvider.OnProvidersExecuting(context);
// Assert
var model = Assert.Single(context.Result.Controllers);
var action = Assert.Single(model.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsType<CorsHttpMethodActionConstraint>(constraint);
}
[Fact]
public void CreateControllerModel_CorsNotInUseDoesNotOverrideHttpConstraints()
{
// Arrange
var corsProvider = new CorsApplicationModelProvider();
var defaultProvider = new DefaultApplicationModelProvider(new TestOptionsManager<MvcOptions>());
var context = new ApplicationModelProviderContext(new[] { typeof(RegularController).GetTypeInfo() });
defaultProvider.OnProvidersExecuting(context);
// Act
corsProvider.OnProvidersExecuting(context);
// Assert
var model = Assert.Single(context.Result.Controllers);
var action = Assert.Single(model.Actions);
var selector = Assert.Single(action.Selectors);
var constraint = Assert.Single(selector.ActionConstraints, c => c is HttpMethodActionConstraint);
Assert.IsNotType<CorsHttpMethodActionConstraint>(constraint);
}
private class EnableCorsController
{
[EnableCors("policy")]
public void Action()
[HttpGet]
public IActionResult Action()
{
return null;
}
}
private class DisableCorsActionController
{
[DisableCors]
[HttpGet]
public void Action()
{
}
@ -106,11 +258,60 @@ namespace Microsoft.AspNetCore.Mvc.Cors.Internal
[EnableCors("policy")]
public class CorsController
{
[HttpGet]
public IActionResult Action()
{
return null;
}
}
[DisableCors]
public class DisableCorsController
{
[HttpOptions]
public IActionResult Action()
{
return null;
}
}
public class RegularController
{
[HttpPost]
public IActionResult Action()
{
return null;
}
}
[CustomCorsFilter]
public class CustomCorsFilterController
{
[HttpPost]
public IActionResult Action()
{
return null;
}
}
public class CustomCorsFilterOnActionController
{
[HttpPost]
[CustomCorsFilter]
public IActionResult Action()
{
return null;
}
}
public class CustomCorsFilterAttribute : Attribute, ICorsAuthorizationFilter
{
public int Order { get; } = 1000;
public Task OnAuthorizationAsync(AuthorizationFilterContext context)
{
return Task.FromResult(0);
}
}
}
}

View File

@ -0,0 +1,108 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc.Abstractions;
using Microsoft.AspNetCore.Mvc.ActionConstraints;
using Microsoft.AspNetCore.Mvc.Internal;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.Primitives;
using Xunit;
namespace Microsoft.AspNetCore.Mvc.Cors.Internal
{
public class CorsHttpMethodActionConstraintTest
{
public static TheoryData AcceptCaseInsensitiveData =
new TheoryData<IEnumerable<string>, string>
{
{ new string[] { "get", "Get", "GET", "GEt"}, "gEt" },
{ new string[] { "POST", "PoSt", "GEt"}, "GET" },
{ new string[] { "get" }, "get" },
{ new string[] { "post" }, "POST" },
{ new string[] { "gEt" }, "get" },
{ new string[] { "get", "PoST" }, "pOSt" }
};
[Theory]
[MemberData(nameof(AcceptCaseInsensitiveData))]
public void HttpMethodActionConstraint_Accept_Preflight_CaseInsensitive(IEnumerable<string> httpMethods, string accessControlMethod)
{
// Arrange
var constraint = new CorsHttpMethodActionConstraint(new HttpMethodActionConstraint(httpMethods)) as IActionConstraint;
var context = CreateActionConstraintContext(constraint);
context.RouteContext = CreateRouteContext("oPtIoNs", accessControlMethod);
// Act
var result = constraint.Accept(context);
// Assert
Assert.True(result, "Request should have been accepted.");
}
[Fact]
public void HttpMethodActionConstraint_RejectsOptionsRequest_WithoutAccessControlMethod()
{
// Arrange
var constraint = new CorsHttpMethodActionConstraint(new HttpMethodActionConstraint(new[] { "GET", "Post" })) as IActionConstraint;
var context = CreateActionConstraintContext(constraint);
context.RouteContext = CreateRouteContext("oPtIoNs", accessControlMethod: "");
// Act
var result = constraint.Accept(context);
// Assert
Assert.False(result, "Request should have been rejected.");
}
[Theory]
[MemberData(nameof(AcceptCaseInsensitiveData))]
public void HttpMethodActionConstraint_Accept_CaseInsensitive(IEnumerable<string> httpMethods, string expectedMethod)
{
// Arrange
var constraint = new CorsHttpMethodActionConstraint(new HttpMethodActionConstraint(httpMethods)) as IActionConstraint;
var context = CreateActionConstraintContext(constraint);
context.RouteContext = CreateRouteContext(expectedMethod);
// Act
var result = constraint.Accept(context);
// Assert
Assert.True(result, "Request should have been accepted.");
}
private static ActionConstraintContext CreateActionConstraintContext(IActionConstraint constraint)
{
var context = new ActionConstraintContext();
var actionSelectorCandidate = new ActionSelectorCandidate(new ActionDescriptor(), new List<IActionConstraint> { constraint });
context.Candidates = new List<ActionSelectorCandidate> { actionSelectorCandidate };
context.CurrentCandidate = context.Candidates[0];
return context;
}
private static RouteContext CreateRouteContext(string requestedMethod, string accessControlMethod = null)
{
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = requestedMethod;
if (accessControlMethod != null)
{
httpContext.Request.Headers.Add("Origin", StringValues.Empty);
if (accessControlMethod != string.Empty)
{
httpContext.Request.Headers.Add("Access-Control-Request-Method", accessControlMethod);
}
}
var routeContext = new RouteContext(httpContext);
routeContext.RouteData = new RouteData();
return routeContext;
}
}
}

View File

@ -43,6 +43,55 @@ namespace Microsoft.AspNetCore.Mvc.FunctionalTests
Assert.Equal(new[] { "*" }, header.Value.ToArray());
}
[Fact]
public async Task OptionsRequest_NonPreflight_ExecutesOptionsAction()
{
// Arrange
var request = new HttpRequestMessage(new HttpMethod("OPTIONS"), "http://localhost/NonCors/GetOptions");
// Act
var response = await Client.SendAsync(request);
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Equal("[\"Create\",\"Update\",\"Delete\"]", content);
Assert.Empty(response.Headers);
}
[Fact]
public async Task PreflightRequestOnNonCorsEnabledController_ExecutesOptionsAction()
{
// Arrange
var request = new HttpRequestMessage(new HttpMethod("OPTIONS"), "http://localhost/NonCors/GetOptions");
request.Headers.Add(CorsConstants.Origin, "http://example.com");
request.Headers.Add(CorsConstants.AccessControlRequestMethod, "POST");
// Act
var response = await Client.SendAsync(request);
// Assert
Assert.Equal(HttpStatusCode.OK, response.StatusCode);
var content = await response.Content.ReadAsStringAsync();
Assert.Equal("[\"Create\",\"Update\",\"Delete\"]", content);
Assert.Empty(response.Headers);
}
[Fact]
public async Task PreflightRequestOnNonCorsEnabledController_DoesNotMatchTheAction()
{
// Arrange
var request = new HttpRequestMessage(new HttpMethod("OPTIONS"), "http://localhost/NonCors/Post");
request.Headers.Add(CorsConstants.Origin, "http://example.com");
request.Headers.Add(CorsConstants.AccessControlRequestMethod, "POST");
// Act
var response = await Client.SendAsync(request);
// Assert
Assert.Equal(HttpStatusCode.NotFound, response.StatusCode);
}
[Theory]
[InlineData("GET")]
[InlineData("HEAD")]

View File

@ -0,0 +1,25 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using Microsoft.AspNetCore.Cors;
using Microsoft.AspNetCore.Mvc;
namespace CorsWebSite
{
[Route("NonCors/[action]")]
public class CustomerController : Controller
{
[HttpOptions]
public IEnumerable<string> GetOptions()
{
return new[] { "Create", "Update", "Delete" };
}
[HttpPost]
public IEnumerable<string> Post()
{
return new[] { "customer1", "customer2", "customer3" };
}
}
}