Add some caching to CORS (#9972)

* Add some caching to CORS
This commit is contained in:
Pranav K 2019-05-13 11:52:16 -07:00 committed by GitHub
commit 3d79b9aa58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 98 additions and 32 deletions

View File

@ -125,11 +125,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
return _next(context); return _next(context);
} }
return InvokeCore(context, corsPolicyProvider);
}
private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
{
// CORS policy resolution rules: // CORS policy resolution rules:
// //
// 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run // 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run
@ -157,11 +152,10 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
{ {
// If this is a preflight request, and we disallow CORS, complete the request // If this is a preflight request, and we disallow CORS, complete the request
context.Response.StatusCode = StatusCodes.Status204NoContent; context.Response.StatusCode = StatusCodes.Status204NoContent;
return; return Task.CompletedTask;
} }
await _next(context); return _next(context);
return;
} }
var corsPolicy = _policy; var corsPolicy = _policy;
@ -182,14 +176,30 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
if (corsPolicy == null) if (corsPolicy == null)
{ {
// Resolve policy by name if the local policy is not being used // Resolve policy by name if the local policy is not being used
corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName); var policyTask = corsPolicyProvider.GetPolicyAsync(context, policyName);
if (!policyTask.IsCompletedSuccessfully)
{
return InvokeCoreAwaited(context, policyTask);
}
corsPolicy = policyTask.GetAwaiter().GetResult();
} }
return EvaluateAndApplyPolicy(context, corsPolicy);
async Task InvokeCoreAwaited(HttpContext context, Task<CorsPolicy> policyTask)
{
var corsPolicy = await policyTask;
await EvaluateAndApplyPolicy(context, corsPolicy);
}
}
private Task EvaluateAndApplyPolicy(HttpContext context, CorsPolicy corsPolicy)
{
if (corsPolicy == null) if (corsPolicy == null)
{ {
Logger?.NoCorsPolicyFound(); Logger.NoCorsPolicyFound();
await _next(context); return _next(context);
return;
} }
var corsResult = CorsService.EvaluatePolicy(context, corsPolicy); var corsResult = CorsService.EvaluatePolicy(context, corsPolicy);
@ -200,12 +210,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
// Since there is a policy which was identified, // Since there is a policy which was identified,
// always respond to preflight requests. // always respond to preflight requests.
context.Response.StatusCode = StatusCodes.Status204NoContent; context.Response.StatusCode = StatusCodes.Status204NoContent;
return; return Task.CompletedTask;
} }
else else
{ {
context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult)); context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult));
await _next(context); return _next(context);
} }
} }
@ -218,7 +228,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
} }
catch (Exception exception) catch (Exception exception)
{ {
middleware.Logger?.FailedToSetCorsHeaders(exception); middleware.Logger.FailedToSetCorsHeaders(exception);
} }
return Task.CompletedTask; return Task.CompletedTask;
} }

View File

@ -3,6 +3,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Cors.Infrastructure namespace Microsoft.AspNetCore.Cors.Infrastructure
{ {
@ -12,22 +13,18 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
public class CorsOptions public class CorsOptions
{ {
private string _defaultPolicyName = "__DefaultCorsPolicy"; private string _defaultPolicyName = "__DefaultCorsPolicy";
private IDictionary<string, CorsPolicy> PolicyMap { get; } = new Dictionary<string, CorsPolicy>();
// DefaultCorsPolicyProvider returns a Task<CorsPolicy>. We'll cache the value to be returned alongside
// the actual policy instance to have a separate lookup.
internal IDictionary<string, (CorsPolicy policy, Task<CorsPolicy> policyTask)> PolicyMap { get; }
= new Dictionary<string, (CorsPolicy, Task<CorsPolicy>)>(StringComparer.Ordinal);
public string DefaultPolicyName public string DefaultPolicyName
{ {
get get => _defaultPolicyName;
{
return _defaultPolicyName;
}
set set
{ {
if (value == null) _defaultPolicyName = value ?? throw new ArgumentNullException(nameof(value));
{
throw new ArgumentNullException(nameof(value));
}
_defaultPolicyName = value;
} }
} }
@ -76,7 +73,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(policy)); throw new ArgumentNullException(nameof(policy));
} }
PolicyMap[name] = policy; PolicyMap[name] = (policy, Task.FromResult(policy));
} }
/// <summary> /// <summary>
@ -98,7 +95,9 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
var policyBuilder = new CorsPolicyBuilder(); var policyBuilder = new CorsPolicyBuilder();
configurePolicy(policyBuilder); configurePolicy(policyBuilder);
PolicyMap[name] = policyBuilder.Build(); var policy = policyBuilder.Build();
PolicyMap[name] = (policy, Task.FromResult(policy));
} }
/// <summary> /// <summary>
@ -113,7 +112,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(name)); throw new ArgumentNullException(nameof(name));
} }
return PolicyMap.ContainsKey(name) ? PolicyMap[name] : null; if (PolicyMap.TryGetValue(name, out var result))
{
return result.policy;
}
return null;
} }
} }
} }

View File

@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
/// <inheritdoc /> /// <inheritdoc />
public class DefaultCorsPolicyProvider : ICorsPolicyProvider public class DefaultCorsPolicyProvider : ICorsPolicyProvider
{ {
private static readonly Task<CorsPolicy> NullResult = Task.FromResult<CorsPolicy>(null);
private readonly CorsOptions _options; private readonly CorsOptions _options;
/// <summary> /// <summary>
@ -30,7 +31,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(context)); throw new ArgumentNullException(nameof(context));
} }
return Task.FromResult(_options.GetPolicy(policyName ?? _options.DefaultPolicyName)); policyName ??= _options.DefaultPolicyName;
if (_options.PolicyMap.TryGetValue(policyName, out var result))
{
return result.policyTask;
}
return NullResult;
} }
} }
} }

View File

@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Options;
using Moq; using Moq;
using Xunit; using Xunit;
@ -564,6 +565,50 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
} }
} }
[Fact]
public async Task Invoke_WithCustomPolicyProviderThatReturnsAsynchronously_Works()
{
// Arrange
var corsService = new CorsService(Options.Create(new CorsOptions()), NullLoggerFactory.Instance);
var mockProvider = new Mock<ICorsPolicyProvider>();
var loggerFactory = NullLoggerFactory.Instance;
var policy = new CorsPolicyBuilder()
.WithOrigins(OriginUrl)
.WithHeaders("AllowedHeader")
.Build();
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
.ReturnsAsync(policy, TimeSpan.FromMilliseconds(10));
var middleware = new CorsMiddleware(
Mock.Of<RequestDelegate>(),
corsService,
loggerFactory,
"DefaultPolicyName");
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "OPTIONS";
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { OriginUrl });
httpContext.Request.Headers.Add(CorsConstants.AccessControlRequestMethod, new[] { "PUT" });
// Act
await middleware.Invoke(httpContext, mockProvider.Object);
// Assert
var response = httpContext.Response;
Assert.Collection(
response.Headers.OrderBy(o => o.Key),
kvp =>
{
Assert.Equal(CorsConstants.AccessControlAllowHeaders, kvp.Key);
Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
},
kvp =>
{
Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
Assert.Equal(OriginUrl, Assert.Single(kvp.Value));
});
}
[Fact] [Fact]
public async Task Invoke_HasEndpointWithNoMetadata_RunsCors() public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
{ {