commit
3d79b9aa58
|
|
@ -125,11 +125,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
return _next(context);
|
return _next(context);
|
||||||
}
|
}
|
||||||
|
|
||||||
return InvokeCore(context, corsPolicyProvider);
|
|
||||||
}
|
|
||||||
|
|
||||||
private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
|
|
||||||
{
|
|
||||||
// CORS policy resolution rules:
|
// CORS policy resolution rules:
|
||||||
//
|
//
|
||||||
// 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run
|
// 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run
|
||||||
|
|
@ -157,11 +152,10 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
{
|
{
|
||||||
// If this is a preflight request, and we disallow CORS, complete the request
|
// If this is a preflight request, and we disallow CORS, complete the request
|
||||||
context.Response.StatusCode = StatusCodes.Status204NoContent;
|
context.Response.StatusCode = StatusCodes.Status204NoContent;
|
||||||
return;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
||||||
await _next(context);
|
return _next(context);
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var corsPolicy = _policy;
|
var corsPolicy = _policy;
|
||||||
|
|
@ -182,14 +176,30 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
if (corsPolicy == null)
|
if (corsPolicy == null)
|
||||||
{
|
{
|
||||||
// Resolve policy by name if the local policy is not being used
|
// Resolve policy by name if the local policy is not being used
|
||||||
corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName);
|
var policyTask = corsPolicyProvider.GetPolicyAsync(context, policyName);
|
||||||
|
if (!policyTask.IsCompletedSuccessfully)
|
||||||
|
{
|
||||||
|
return InvokeCoreAwaited(context, policyTask);
|
||||||
|
}
|
||||||
|
|
||||||
|
corsPolicy = policyTask.GetAwaiter().GetResult();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return EvaluateAndApplyPolicy(context, corsPolicy);
|
||||||
|
|
||||||
|
async Task InvokeCoreAwaited(HttpContext context, Task<CorsPolicy> policyTask)
|
||||||
|
{
|
||||||
|
var corsPolicy = await policyTask;
|
||||||
|
await EvaluateAndApplyPolicy(context, corsPolicy);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private Task EvaluateAndApplyPolicy(HttpContext context, CorsPolicy corsPolicy)
|
||||||
|
{
|
||||||
if (corsPolicy == null)
|
if (corsPolicy == null)
|
||||||
{
|
{
|
||||||
Logger?.NoCorsPolicyFound();
|
Logger.NoCorsPolicyFound();
|
||||||
await _next(context);
|
return _next(context);
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var corsResult = CorsService.EvaluatePolicy(context, corsPolicy);
|
var corsResult = CorsService.EvaluatePolicy(context, corsPolicy);
|
||||||
|
|
@ -200,12 +210,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
// Since there is a policy which was identified,
|
// Since there is a policy which was identified,
|
||||||
// always respond to preflight requests.
|
// always respond to preflight requests.
|
||||||
context.Response.StatusCode = StatusCodes.Status204NoContent;
|
context.Response.StatusCode = StatusCodes.Status204NoContent;
|
||||||
return;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult));
|
context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult));
|
||||||
await _next(context);
|
return _next(context);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -218,7 +228,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
}
|
}
|
||||||
catch (Exception exception)
|
catch (Exception exception)
|
||||||
{
|
{
|
||||||
middleware.Logger?.FailedToSetCorsHeaders(exception);
|
middleware.Logger.FailedToSetCorsHeaders(exception);
|
||||||
}
|
}
|
||||||
return Task.CompletedTask;
|
return Task.CompletedTask;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@
|
||||||
|
|
||||||
using System;
|
using System;
|
||||||
using System.Collections.Generic;
|
using System.Collections.Generic;
|
||||||
|
using System.Threading.Tasks;
|
||||||
|
|
||||||
namespace Microsoft.AspNetCore.Cors.Infrastructure
|
namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
{
|
{
|
||||||
|
|
@ -12,22 +13,18 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
public class CorsOptions
|
public class CorsOptions
|
||||||
{
|
{
|
||||||
private string _defaultPolicyName = "__DefaultCorsPolicy";
|
private string _defaultPolicyName = "__DefaultCorsPolicy";
|
||||||
private IDictionary<string, CorsPolicy> PolicyMap { get; } = new Dictionary<string, CorsPolicy>();
|
|
||||||
|
// DefaultCorsPolicyProvider returns a Task<CorsPolicy>. We'll cache the value to be returned alongside
|
||||||
|
// the actual policy instance to have a separate lookup.
|
||||||
|
internal IDictionary<string, (CorsPolicy policy, Task<CorsPolicy> policyTask)> PolicyMap { get; }
|
||||||
|
= new Dictionary<string, (CorsPolicy, Task<CorsPolicy>)>(StringComparer.Ordinal);
|
||||||
|
|
||||||
public string DefaultPolicyName
|
public string DefaultPolicyName
|
||||||
{
|
{
|
||||||
get
|
get => _defaultPolicyName;
|
||||||
{
|
|
||||||
return _defaultPolicyName;
|
|
||||||
}
|
|
||||||
set
|
set
|
||||||
{
|
{
|
||||||
if (value == null)
|
_defaultPolicyName = value ?? throw new ArgumentNullException(nameof(value));
|
||||||
{
|
|
||||||
throw new ArgumentNullException(nameof(value));
|
|
||||||
}
|
|
||||||
|
|
||||||
_defaultPolicyName = value;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -76,7 +73,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
throw new ArgumentNullException(nameof(policy));
|
throw new ArgumentNullException(nameof(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
PolicyMap[name] = policy;
|
PolicyMap[name] = (policy, Task.FromResult(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -98,7 +95,9 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
|
|
||||||
var policyBuilder = new CorsPolicyBuilder();
|
var policyBuilder = new CorsPolicyBuilder();
|
||||||
configurePolicy(policyBuilder);
|
configurePolicy(policyBuilder);
|
||||||
PolicyMap[name] = policyBuilder.Build();
|
var policy = policyBuilder.Build();
|
||||||
|
|
||||||
|
PolicyMap[name] = (policy, Task.FromResult(policy));
|
||||||
}
|
}
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -113,7 +112,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
throw new ArgumentNullException(nameof(name));
|
throw new ArgumentNullException(nameof(name));
|
||||||
}
|
}
|
||||||
|
|
||||||
return PolicyMap.ContainsKey(name) ? PolicyMap[name] : null;
|
if (PolicyMap.TryGetValue(name, out var result))
|
||||||
|
{
|
||||||
|
return result.policy;
|
||||||
|
}
|
||||||
|
|
||||||
|
return null;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
/// <inheritdoc />
|
/// <inheritdoc />
|
||||||
public class DefaultCorsPolicyProvider : ICorsPolicyProvider
|
public class DefaultCorsPolicyProvider : ICorsPolicyProvider
|
||||||
{
|
{
|
||||||
|
private static readonly Task<CorsPolicy> NullResult = Task.FromResult<CorsPolicy>(null);
|
||||||
private readonly CorsOptions _options;
|
private readonly CorsOptions _options;
|
||||||
|
|
||||||
/// <summary>
|
/// <summary>
|
||||||
|
|
@ -30,7 +31,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
throw new ArgumentNullException(nameof(context));
|
throw new ArgumentNullException(nameof(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
return Task.FromResult(_options.GetPolicy(policyName ?? _options.DefaultPolicyName));
|
policyName ??= _options.DefaultPolicyName;
|
||||||
|
if (_options.PolicyMap.TryGetValue(policyName, out var result))
|
||||||
|
{
|
||||||
|
return result.policyTask;
|
||||||
|
}
|
||||||
|
|
||||||
|
return NullResult;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Http;
|
||||||
using Microsoft.AspNetCore.TestHost;
|
using Microsoft.AspNetCore.TestHost;
|
||||||
using Microsoft.Extensions.DependencyInjection;
|
using Microsoft.Extensions.DependencyInjection;
|
||||||
using Microsoft.Extensions.Logging.Abstractions;
|
using Microsoft.Extensions.Logging.Abstractions;
|
||||||
|
using Microsoft.Extensions.Options;
|
||||||
using Moq;
|
using Moq;
|
||||||
using Xunit;
|
using Xunit;
|
||||||
|
|
||||||
|
|
@ -564,6 +565,50 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[Fact]
|
||||||
|
public async Task Invoke_WithCustomPolicyProviderThatReturnsAsynchronously_Works()
|
||||||
|
{
|
||||||
|
// Arrange
|
||||||
|
var corsService = new CorsService(Options.Create(new CorsOptions()), NullLoggerFactory.Instance);
|
||||||
|
var mockProvider = new Mock<ICorsPolicyProvider>();
|
||||||
|
var loggerFactory = NullLoggerFactory.Instance;
|
||||||
|
var policy = new CorsPolicyBuilder()
|
||||||
|
.WithOrigins(OriginUrl)
|
||||||
|
.WithHeaders("AllowedHeader")
|
||||||
|
.Build();
|
||||||
|
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
|
||||||
|
.ReturnsAsync(policy, TimeSpan.FromMilliseconds(10));
|
||||||
|
|
||||||
|
var middleware = new CorsMiddleware(
|
||||||
|
Mock.Of<RequestDelegate>(),
|
||||||
|
corsService,
|
||||||
|
loggerFactory,
|
||||||
|
"DefaultPolicyName");
|
||||||
|
|
||||||
|
var httpContext = new DefaultHttpContext();
|
||||||
|
httpContext.Request.Method = "OPTIONS";
|
||||||
|
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { OriginUrl });
|
||||||
|
httpContext.Request.Headers.Add(CorsConstants.AccessControlRequestMethod, new[] { "PUT" });
|
||||||
|
|
||||||
|
// Act
|
||||||
|
await middleware.Invoke(httpContext, mockProvider.Object);
|
||||||
|
|
||||||
|
// Assert
|
||||||
|
var response = httpContext.Response;
|
||||||
|
Assert.Collection(
|
||||||
|
response.Headers.OrderBy(o => o.Key),
|
||||||
|
kvp =>
|
||||||
|
{
|
||||||
|
Assert.Equal(CorsConstants.AccessControlAllowHeaders, kvp.Key);
|
||||||
|
Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
|
||||||
|
},
|
||||||
|
kvp =>
|
||||||
|
{
|
||||||
|
Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
|
||||||
|
Assert.Equal(OriginUrl, Assert.Single(kvp.Value));
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
[Fact]
|
[Fact]
|
||||||
public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
|
public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
|
||||||
{
|
{
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue