diff --git a/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs b/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs index c54e4e85c3..d376370cc6 100644 --- a/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs +++ b/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs @@ -125,11 +125,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure return _next(context); } - return InvokeCore(context, corsPolicyProvider); - } - - private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider) - { // CORS policy resolution rules: // // 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run @@ -157,11 +152,10 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure { // If this is a preflight request, and we disallow CORS, complete the request context.Response.StatusCode = StatusCodes.Status204NoContent; - return; + return Task.CompletedTask; } - await _next(context); - return; + return _next(context); } var corsPolicy = _policy; @@ -182,14 +176,30 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure if (corsPolicy == null) { // Resolve policy by name if the local policy is not being used - corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName); + var policyTask = corsPolicyProvider.GetPolicyAsync(context, policyName); + if (!policyTask.IsCompletedSuccessfully) + { + return InvokeCoreAwaited(context, policyTask); + } + + corsPolicy = policyTask.GetAwaiter().GetResult(); } + return EvaluateAndApplyPolicy(context, corsPolicy); + + async Task InvokeCoreAwaited(HttpContext context, Task policyTask) + { + var corsPolicy = await policyTask; + await EvaluateAndApplyPolicy(context, corsPolicy); + } + } + + private Task EvaluateAndApplyPolicy(HttpContext context, CorsPolicy corsPolicy) + { if (corsPolicy == null) { - Logger?.NoCorsPolicyFound(); - await _next(context); - return; + Logger.NoCorsPolicyFound(); + return _next(context); } var corsResult = CorsService.EvaluatePolicy(context, corsPolicy); @@ -200,12 +210,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure // Since there is a policy which was identified, // always respond to preflight requests. context.Response.StatusCode = StatusCodes.Status204NoContent; - return; + return Task.CompletedTask; } else { context.Response.OnStarting(OnResponseStartingDelegate, Tuple.Create(this, context, corsResult)); - await _next(context); + return _next(context); } } @@ -218,7 +228,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure } catch (Exception exception) { - middleware.Logger?.FailedToSetCorsHeaders(exception); + middleware.Logger.FailedToSetCorsHeaders(exception); } return Task.CompletedTask; } diff --git a/src/Middleware/CORS/src/Infrastructure/CorsOptions.cs b/src/Middleware/CORS/src/Infrastructure/CorsOptions.cs index 92e1f775ba..df32568b07 100644 --- a/src/Middleware/CORS/src/Infrastructure/CorsOptions.cs +++ b/src/Middleware/CORS/src/Infrastructure/CorsOptions.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Cors.Infrastructure { @@ -12,22 +13,18 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure public class CorsOptions { private string _defaultPolicyName = "__DefaultCorsPolicy"; - private IDictionary PolicyMap { get; } = new Dictionary(); + + // DefaultCorsPolicyProvider returns a Task. We'll cache the value to be returned alongside + // the actual policy instance to have a separate lookup. + internal IDictionary policyTask)> PolicyMap { get; } + = new Dictionary)>(StringComparer.Ordinal); public string DefaultPolicyName { - get - { - return _defaultPolicyName; - } + get => _defaultPolicyName; set { - if (value == null) - { - throw new ArgumentNullException(nameof(value)); - } - - _defaultPolicyName = value; + _defaultPolicyName = value ?? throw new ArgumentNullException(nameof(value)); } } @@ -76,7 +73,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure throw new ArgumentNullException(nameof(policy)); } - PolicyMap[name] = policy; + PolicyMap[name] = (policy, Task.FromResult(policy)); } /// @@ -98,7 +95,9 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure var policyBuilder = new CorsPolicyBuilder(); configurePolicy(policyBuilder); - PolicyMap[name] = policyBuilder.Build(); + var policy = policyBuilder.Build(); + + PolicyMap[name] = (policy, Task.FromResult(policy)); } /// @@ -113,7 +112,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure throw new ArgumentNullException(nameof(name)); } - return PolicyMap.ContainsKey(name) ? PolicyMap[name] : null; + if (PolicyMap.TryGetValue(name, out var result)) + { + return result.policy; + } + + return null; } } -} \ No newline at end of file +} diff --git a/src/Middleware/CORS/src/Infrastructure/DefaultCorsPolicyProvider.cs b/src/Middleware/CORS/src/Infrastructure/DefaultCorsPolicyProvider.cs index 4841d60f75..e9e5ca9154 100644 --- a/src/Middleware/CORS/src/Infrastructure/DefaultCorsPolicyProvider.cs +++ b/src/Middleware/CORS/src/Infrastructure/DefaultCorsPolicyProvider.cs @@ -11,6 +11,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// public class DefaultCorsPolicyProvider : ICorsPolicyProvider { + private static readonly Task NullResult = Task.FromResult(null); private readonly CorsOptions _options; /// @@ -30,7 +31,13 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure throw new ArgumentNullException(nameof(context)); } - return Task.FromResult(_options.GetPolicy(policyName ?? _options.DefaultPolicyName)); + policyName ??= _options.DefaultPolicyName; + if (_options.PolicyMap.TryGetValue(policyName, out var result)) + { + return result.policyTask; + } + + return NullResult; } } -} \ No newline at end of file +} diff --git a/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs b/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs index 1a470da10d..ebd44aa2d0 100644 --- a/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs +++ b/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs @@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; using Moq; using Xunit; @@ -564,6 +565,50 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure } } + [Fact] + public async Task Invoke_WithCustomPolicyProviderThatReturnsAsynchronously_Works() + { + // Arrange + var corsService = new CorsService(Options.Create(new CorsOptions()), NullLoggerFactory.Instance); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + var policy = new CorsPolicyBuilder() + .WithOrigins(OriginUrl) + .WithHeaders("AllowedHeader") + .Build(); + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(policy, TimeSpan.FromMilliseconds(10)); + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + loggerFactory, + "DefaultPolicyName"); + + var httpContext = new DefaultHttpContext(); + httpContext.Request.Method = "OPTIONS"; + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { OriginUrl }); + httpContext.Request.Headers.Add(CorsConstants.AccessControlRequestMethod, new[] { "PUT" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + var response = httpContext.Response; + Assert.Collection( + response.Headers.OrderBy(o => o.Key), + kvp => + { + Assert.Equal(CorsConstants.AccessControlAllowHeaders, kvp.Key); + Assert.Equal("AllowedHeader", Assert.Single(kvp.Value)); + }, + kvp => + { + Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key); + Assert.Equal(OriginUrl, Assert.Single(kvp.Value)); + }); + } + [Fact] public async Task Invoke_HasEndpointWithNoMetadata_RunsCors() {