Add some caching to CORS (#9972)

* Add some caching to CORS
This commit is contained in:
Pranav K 2019-05-13 11:52:16 -07:00 committed by GitHub
commit 3d79b9aa58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 32 deletions

View File

@ -125,11 +125,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
return _next(context);
}
return InvokeCore(context, corsPolicyProvider);
}
private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
{
// CORS policy resolution rules:
//
// 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run
@ -157,11 +152,10 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
{
// If this is a preflight request, and we disallow CORS, complete the request
context.Response.StatusCode = StatusCodes.Status204NoContent;
return;
return Task.CompletedTask;
}
await _next(context);
return;
return _next(context);
}
var corsPolicy = _policy;
@ -182,14 +176,30 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
if (corsPolicy == null)
{
// Resolve policy by name if the local policy is not being used
corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName);
var policyTask = corsPolicyProvider.GetPolicyAsync(context, policyName);
if (!policyTask.IsCompletedSuccessfully)
{
return InvokeCoreAwaited(context, policyTask);
}
corsPolicy = policyTask.GetAwaiter().GetResult();
}
return EvaluateAndApplyPolicy(context, corsPolicy);
async Task InvokeCoreAwaited(HttpContext context, Task<CorsPolicy> policyTask)
{
var corsPolicy = await policyTask;
await EvaluateAndApplyPolicy(context, corsPolicy);
}
}
private Task EvaluateAndApplyPolicy(HttpContext context, CorsPolicy corsPolicy)
{
if (corsPolicy == null)
{
Logger?.NoCorsPolicyFound();
await _next(context);
return;
Logger.NoCorsPolicyFound();
return _next(context);
}
var corsResult = CorsService.EvaluatePolicy(context, corsPolicy);
@ -200,12 +210,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
// Since there is a policy which was identified,
// always respond to preflight requests.
context.Response.StatusCode = StatusCodes.Status204NoContent;
return;
return Task.CompletedTask;
}
else
{
context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult));
await _next(context);
return _next(context);
}
}
@ -218,7 +228,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
}
catch (Exception exception)
{
middleware.Logger?.FailedToSetCorsHeaders(exception);
middleware.Logger.FailedToSetCorsHeaders(exception);
}
return Task.CompletedTask;
}

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Cors.Infrastructure
{
@ -12,22 +13,18 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
public class CorsOptions
{
private string _defaultPolicyName = "__DefaultCorsPolicy";
private IDictionary<string, CorsPolicy> PolicyMap { get; } = new Dictionary<string, CorsPolicy>();
// DefaultCorsPolicyProvider returns a Task<CorsPolicy>. We'll cache the value to be returned alongside
// the actual policy instance to have a separate lookup.
internal IDictionary<string, (CorsPolicy policy, Task<CorsPolicy> policyTask)> PolicyMap { get; }
= new Dictionary<string, (CorsPolicy, Task<CorsPolicy>)>(StringComparer.Ordinal);
public string DefaultPolicyName
{
get
{
return _defaultPolicyName;
}
get => _defaultPolicyName;
set
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}
_defaultPolicyName = value;
_defaultPolicyName = value ?? throw new ArgumentNullException(nameof(value));
}
}
@ -76,7 +73,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(policy));
}
PolicyMap[name] = policy;
PolicyMap[name] = (policy, Task.FromResult(policy));
}
/// <summary>
@ -98,7 +95,9 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
var policyBuilder = new CorsPolicyBuilder();
configurePolicy(policyBuilder);
PolicyMap[name] = policyBuilder.Build();
var policy = policyBuilder.Build();
PolicyMap[name] = (policy, Task.FromResult(policy));
}
/// <summary>
@ -113,7 +112,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(name));
}
return PolicyMap.ContainsKey(name) ? PolicyMap[name] : null;
if (PolicyMap.TryGetValue(name, out var result))
{
return result.policy;
}
return null;
}
}
}
}

View File

@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
/// <inheritdoc />
public class DefaultCorsPolicyProvider : ICorsPolicyProvider
{
private static readonly Task<CorsPolicy> NullResult = Task.FromResult<CorsPolicy>(null);
private readonly CorsOptions _options;
/// <summary>
@ -30,7 +31,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(context));
}
return Task.FromResult(_options.GetPolicy(policyName ?? _options.DefaultPolicyName));
policyName ??= _options.DefaultPolicyName;
if (_options.PolicyMap.TryGetValue(policyName, out var result))
{
return result.policyTask;
}
return NullResult;
}
}
}
}

View File

@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq;
using Xunit;
@ -564,6 +565,50 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
}
}
[Fact]
public async Task Invoke_WithCustomPolicyProviderThatReturnsAsynchronously_Works()
{
// Arrange
var corsService = new CorsService(Options.Create(new CorsOptions()), NullLoggerFactory.Instance);
var mockProvider = new Mock<ICorsPolicyProvider>();
var loggerFactory = NullLoggerFactory.Instance;
var policy = new CorsPolicyBuilder()
.WithOrigins(OriginUrl)
.WithHeaders("AllowedHeader")
.Build();
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
.ReturnsAsync(policy, TimeSpan.FromMilliseconds(10));
var middleware = new CorsMiddleware(
Mock.Of<RequestDelegate>(),
corsService,
loggerFactory,
"DefaultPolicyName");
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "OPTIONS";
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { OriginUrl });
httpContext.Request.Headers.Add(CorsConstants.AccessControlRequestMethod, new[] { "PUT" });
// Act
await middleware.Invoke(httpContext, mockProvider.Object);
// Assert
var response = httpContext.Response;
Assert.Collection(
response.Headers.OrderBy(o => o.Key),
kvp =>
{
Assert.Equal(CorsConstants.AccessControlAllowHeaders, kvp.Key);
Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
},
kvp =>
{
Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
Assert.Equal(OriginUrl, Assert.Single(kvp.Value));
});
}
[Fact]
public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
{