commit
3d79b9aa58
|
|
@ -125,11 +125,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
return _next(context);
|
||||
}
|
||||
|
||||
return InvokeCore(context, corsPolicyProvider);
|
||||
}
|
||||
|
||||
private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider)
|
||||
{
|
||||
// CORS policy resolution rules:
|
||||
//
|
||||
// 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run
|
||||
|
|
@ -157,11 +152,10 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
{
|
||||
// If this is a preflight request, and we disallow CORS, complete the request
|
||||
context.Response.StatusCode = StatusCodes.Status204NoContent;
|
||||
return;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
await _next(context);
|
||||
return;
|
||||
return _next(context);
|
||||
}
|
||||
|
||||
var corsPolicy = _policy;
|
||||
|
|
@ -182,14 +176,30 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
if (corsPolicy == null)
|
||||
{
|
||||
// Resolve policy by name if the local policy is not being used
|
||||
corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName);
|
||||
var policyTask = corsPolicyProvider.GetPolicyAsync(context, policyName);
|
||||
if (!policyTask.IsCompletedSuccessfully)
|
||||
{
|
||||
return InvokeCoreAwaited(context, policyTask);
|
||||
}
|
||||
|
||||
corsPolicy = policyTask.GetAwaiter().GetResult();
|
||||
}
|
||||
|
||||
return EvaluateAndApplyPolicy(context, corsPolicy);
|
||||
|
||||
async Task InvokeCoreAwaited(HttpContext context, Task<CorsPolicy> policyTask)
|
||||
{
|
||||
var corsPolicy = await policyTask;
|
||||
await EvaluateAndApplyPolicy(context, corsPolicy);
|
||||
}
|
||||
}
|
||||
|
||||
private Task EvaluateAndApplyPolicy(HttpContext context, CorsPolicy corsPolicy)
|
||||
{
|
||||
if (corsPolicy == null)
|
||||
{
|
||||
Logger?.NoCorsPolicyFound();
|
||||
await _next(context);
|
||||
return;
|
||||
Logger.NoCorsPolicyFound();
|
||||
return _next(context);
|
||||
}
|
||||
|
||||
var corsResult = CorsService.EvaluatePolicy(context, corsPolicy);
|
||||
|
|
@ -200,12 +210,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
// Since there is a policy which was identified,
|
||||
// always respond to preflight requests.
|
||||
context.Response.StatusCode = StatusCodes.Status204NoContent;
|
||||
return;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
else
|
||||
{
|
||||
context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult));
|
||||
await _next(context);
|
||||
return _next(context);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -218,7 +228,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
}
|
||||
catch (Exception exception)
|
||||
{
|
||||
middleware.Logger?.FailedToSetCorsHeaders(exception);
|
||||
middleware.Logger.FailedToSetCorsHeaders(exception);
|
||||
}
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@
|
|||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace Microsoft.AspNetCore.Cors.Infrastructure
|
||||
{
|
||||
|
|
@ -12,22 +13,18 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
public class CorsOptions
|
||||
{
|
||||
private string _defaultPolicyName = "__DefaultCorsPolicy";
|
||||
private IDictionary<string, CorsPolicy> PolicyMap { get; } = new Dictionary<string, CorsPolicy>();
|
||||
|
||||
// DefaultCorsPolicyProvider returns a Task<CorsPolicy>. We'll cache the value to be returned alongside
|
||||
// the actual policy instance to have a separate lookup.
|
||||
internal IDictionary<string, (CorsPolicy policy, Task<CorsPolicy> policyTask)> PolicyMap { get; }
|
||||
= new Dictionary<string, (CorsPolicy, Task<CorsPolicy>)>(StringComparer.Ordinal);
|
||||
|
||||
public string DefaultPolicyName
|
||||
{
|
||||
get
|
||||
{
|
||||
return _defaultPolicyName;
|
||||
}
|
||||
get => _defaultPolicyName;
|
||||
set
|
||||
{
|
||||
if (value == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(value));
|
||||
}
|
||||
|
||||
_defaultPolicyName = value;
|
||||
_defaultPolicyName = value ?? throw new ArgumentNullException(nameof(value));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -76,7 +73,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
throw new ArgumentNullException(nameof(policy));
|
||||
}
|
||||
|
||||
PolicyMap[name] = policy;
|
||||
PolicyMap[name] = (policy, Task.FromResult(policy));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -98,7 +95,9 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
|
||||
var policyBuilder = new CorsPolicyBuilder();
|
||||
configurePolicy(policyBuilder);
|
||||
PolicyMap[name] = policyBuilder.Build();
|
||||
var policy = policyBuilder.Build();
|
||||
|
||||
PolicyMap[name] = (policy, Task.FromResult(policy));
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -113,7 +112,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
throw new ArgumentNullException(nameof(name));
|
||||
}
|
||||
|
||||
return PolicyMap.ContainsKey(name) ? PolicyMap[name] : null;
|
||||
if (PolicyMap.TryGetValue(name, out var result))
|
||||
{
|
||||
return result.policy;
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
/// <inheritdoc />
|
||||
public class DefaultCorsPolicyProvider : ICorsPolicyProvider
|
||||
{
|
||||
private static readonly Task<CorsPolicy> NullResult = Task.FromResult<CorsPolicy>(null);
|
||||
private readonly CorsOptions _options;
|
||||
|
||||
/// <summary>
|
||||
|
|
@ -30,7 +31,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
throw new ArgumentNullException(nameof(context));
|
||||
}
|
||||
|
||||
return Task.FromResult(_options.GetPolicy(policyName ?? _options.DefaultPolicyName));
|
||||
policyName ??= _options.DefaultPolicyName;
|
||||
if (_options.PolicyMap.TryGetValue(policyName, out var result))
|
||||
{
|
||||
return result.policyTask;
|
||||
}
|
||||
|
||||
return NullResult;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Http;
|
|||
using Microsoft.AspNetCore.TestHost;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging.Abstractions;
|
||||
using Microsoft.Extensions.Options;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
|
||||
|
|
@ -564,6 +565,50 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_WithCustomPolicyProviderThatReturnsAsynchronously_Works()
|
||||
{
|
||||
// Arrange
|
||||
var corsService = new CorsService(Options.Create(new CorsOptions()), NullLoggerFactory.Instance);
|
||||
var mockProvider = new Mock<ICorsPolicyProvider>();
|
||||
var loggerFactory = NullLoggerFactory.Instance;
|
||||
var policy = new CorsPolicyBuilder()
|
||||
.WithOrigins(OriginUrl)
|
||||
.WithHeaders("AllowedHeader")
|
||||
.Build();
|
||||
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
|
||||
.ReturnsAsync(policy, TimeSpan.FromMilliseconds(10));
|
||||
|
||||
var middleware = new CorsMiddleware(
|
||||
Mock.Of<RequestDelegate>(),
|
||||
corsService,
|
||||
loggerFactory,
|
||||
"DefaultPolicyName");
|
||||
|
||||
var httpContext = new DefaultHttpContext();
|
||||
httpContext.Request.Method = "OPTIONS";
|
||||
httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { OriginUrl });
|
||||
httpContext.Request.Headers.Add(CorsConstants.AccessControlRequestMethod, new[] { "PUT" });
|
||||
|
||||
// Act
|
||||
await middleware.Invoke(httpContext, mockProvider.Object);
|
||||
|
||||
// Assert
|
||||
var response = httpContext.Response;
|
||||
Assert.Collection(
|
||||
response.Headers.OrderBy(o => o.Key),
|
||||
kvp =>
|
||||
{
|
||||
Assert.Equal(CorsConstants.AccessControlAllowHeaders, kvp.Key);
|
||||
Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
|
||||
},
|
||||
kvp =>
|
||||
{
|
||||
Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
|
||||
Assert.Equal(OriginUrl, Assert.Single(kvp.Value));
|
||||
});
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task Invoke_HasEndpointWithNoMetadata_RunsCors()
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue