diff --git a/src/Middleware/CORS/samples/SampleDestination/Program.cs b/src/Middleware/CORS/samples/SampleDestination/Program.cs index 8cef67c306..5a3e1dd7cc 100644 --- a/src/Middleware/CORS/samples/SampleDestination/Program.cs +++ b/src/Middleware/CORS/samples/SampleDestination/Program.cs @@ -1,6 +1,7 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO; using Microsoft.AspNetCore.Hosting; using Microsoft.Extensions.Logging; @@ -16,10 +17,31 @@ namespace SampleDestination .UseUrls("http://+:9000") .UseContentRoot(Directory.GetCurrentDirectory()) .ConfigureLogging(factory => factory.AddConsole()) - .UseStartup() + .UseStartup(GetStartupType()) .Build(); host.Run(); } + + private static Type GetStartupType() + { + var startup = Environment.GetEnvironmentVariable("CORS_STARTUP"); + if (startup == null) + { + return typeof(Startup); + } + else + { + switch (startup) + { + case "Startup": + return typeof(Startup); + case "StartupWithoutEndpointRouting": + return typeof(StartupWithoutEndpointRouting); + } + } + + throw new InvalidOperationException("Could not resolve the startup type. Unexpected CORS_STARTUP environment variable."); + } } } diff --git a/src/Middleware/CORS/samples/SampleDestination/Startup.cs b/src/Middleware/CORS/samples/SampleDestination/Startup.cs index af93d605c9..cf306c6eeb 100644 --- a/src/Middleware/CORS/samples/SampleDestination/Startup.cs +++ b/src/Middleware/CORS/samples/SampleDestination/Startup.cs @@ -1,7 +1,6 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Net; using System.Text; using System.Threading.Tasks; @@ -26,66 +25,70 @@ namespace SampleDestination public void ConfigureServices(IServiceCollection services) { - services.AddCors(); - } - - public void Configure(IApplicationBuilder app, IHostingEnvironment env) - { - app.Map("/allow-origin", innerBuilder => + services.AddCors(options => { - innerBuilder.UseCors(policy => policy + options.AddPolicy("AllowOrigin", policy => policy .WithOrigins(DefaultAllowedOrigin) .AllowAnyMethod() .AllowAnyHeader()); - innerBuilder.UseMiddleware(); - }); - - app.Map("/allow-header-method", innerBuilder => - { - innerBuilder.UseCors(policy => policy + options.AddPolicy("AllowHeaderMethod", policy => policy .WithOrigins(DefaultAllowedOrigin) .WithHeaders("X-Test", "Content-Type") .WithMethods("PUT")); - innerBuilder.UseMiddleware(); - }); - - app.Map("/allow-credentials", innerBuilder => - { - innerBuilder.UseCors(policy => policy + options.AddPolicy("AllowCredentials", policy => policy .WithOrigins(DefaultAllowedOrigin) .AllowAnyHeader() .WithMethods("GET", "PUT") .AllowCredentials()); - innerBuilder.UseMiddleware(); - }); - - app.Map("/exposed-header", innerBuilder => - { - innerBuilder.UseCors(policy => policy + options.AddPolicy("ExposedHeader", policy => policy .WithOrigins(DefaultAllowedOrigin) .WithExposedHeaders("X-AllowedHeader", "Content-Length")); - innerBuilder.UseMiddleware(); - }); - - app.Map("/allow-all", innerBuilder => - { - innerBuilder.UseCors(policy => policy + options.AddPolicy("AllowAll", policy => policy .AllowAnyOrigin() .AllowAnyMethod() .AllowAnyHeader() .AllowCredentials()); - - innerBuilder.UseMiddleware(); }); + services.AddRouting(); + } + + public void Configure(IApplicationBuilder app, IHostingEnvironment env) + { + app.UseRouting(routing => + { + routing.Map("/allow-origin", HandleRequest).WithCorsPolicy("AllowOrigin"); + routing.Map("/allow-header-method", HandleRequest).WithCorsPolicy("AllowHeaderMethod"); + routing.Map("/allow-credentials", HandleRequest).WithCorsPolicy("AllowCredentials"); + routing.Map("/exposed-header", HandleRequest).WithCorsPolicy("ExposedHeader"); + routing.Map("/allow-all", HandleRequest).WithCorsPolicy("AllowAll"); + }); + + app.UseCors(); + + app.UseEndpoint(); app.Run(async (context) => { await context.Response.WriteAsync("Hello World!"); }); } + + private Task HandleRequest(HttpContext context) + { + var content = Encoding.UTF8.GetBytes("Hello world"); + + context.Response.Headers["X-AllowedHeader"] = "Test-Value"; + context.Response.Headers["X-DisallowedHeader"] = "Test-Value"; + + context.Response.ContentType = "text/plain; charset=utf-8"; + context.Response.ContentLength = content.Length; + context.Response.Body.Write(content, 0, content.Length); + + return Task.CompletedTask; + } } } diff --git a/src/Middleware/CORS/samples/SampleDestination/StartupWithoutEndpointRouting.cs b/src/Middleware/CORS/samples/SampleDestination/StartupWithoutEndpointRouting.cs new file mode 100644 index 0000000000..fcc817585e --- /dev/null +++ b/src/Middleware/CORS/samples/SampleDestination/StartupWithoutEndpointRouting.cs @@ -0,0 +1,88 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; + +namespace SampleDestination +{ + public class StartupWithoutEndpointRouting + { + private static readonly string DefaultAllowedOrigin = $"http://{Dns.GetHostName()}:9001"; + private readonly ILogger _logger; + + public StartupWithoutEndpointRouting(ILoggerFactory loggerFactory) + { + _logger = loggerFactory.CreateLogger(); + _logger.LogInformation($"Setting up CORS middleware to allow clients on {DefaultAllowedOrigin}"); + } + + public void ConfigureServices(IServiceCollection services) + { + services.AddCors(); + } + + public void Configure(IApplicationBuilder app, IHostingEnvironment env) + { + app.Map("/allow-origin", innerBuilder => + { + innerBuilder.UseCors(policy => policy + .WithOrigins(DefaultAllowedOrigin) + .AllowAnyMethod() + .AllowAnyHeader()); + + innerBuilder.UseMiddleware(); + }); + + app.Map("/allow-header-method", innerBuilder => + { + innerBuilder.UseCors(policy => policy + .WithOrigins(DefaultAllowedOrigin) + .WithHeaders("X-Test", "Content-Type") + .WithMethods("PUT")); + + innerBuilder.UseMiddleware(); + }); + + app.Map("/allow-credentials", innerBuilder => + { + innerBuilder.UseCors(policy => policy + .WithOrigins(DefaultAllowedOrigin) + .AllowAnyHeader() + .WithMethods("GET", "PUT") + .AllowCredentials()); + + innerBuilder.UseMiddleware(); + }); + + app.Map("/exposed-header", innerBuilder => + { + innerBuilder.UseCors(policy => policy + .WithOrigins(DefaultAllowedOrigin) + .WithExposedHeaders("X-AllowedHeader", "Content-Length")); + + innerBuilder.UseMiddleware(); + }); + + app.Map("/allow-all", innerBuilder => + { + innerBuilder.UseCors(policy => policy + .AllowAnyOrigin() + .AllowAnyMethod() + .AllowAnyHeader() + .AllowCredentials()); + + innerBuilder.UseMiddleware(); + }); + + app.Run(async (context) => + { + await context.Response.WriteAsync("Hello World!"); + }); + } + } +} diff --git a/src/Middleware/CORS/samples/SampleOrigin/Program.cs b/src/Middleware/CORS/samples/SampleOrigin/Program.cs index 576e25f0e3..0d7fb8ac92 100644 --- a/src/Middleware/CORS/samples/SampleOrigin/Program.cs +++ b/src/Middleware/CORS/samples/SampleOrigin/Program.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.IO; diff --git a/src/Middleware/CORS/src/CorsPolicyMetadata.cs b/src/Middleware/CORS/src/CorsPolicyMetadata.cs new file mode 100644 index 0000000000..365ff45dbd --- /dev/null +++ b/src/Middleware/CORS/src/CorsPolicyMetadata.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Cors.Infrastructure; + +namespace Microsoft.AspNetCore.Cors +{ + /// + /// Metadata that provides a CORS policy. + /// + public class CorsPolicyMetadata : ICorsPolicyMetadata + { + public CorsPolicyMetadata(CorsPolicy policy) + { + Policy = policy; + } + + /// + /// The policy which needs to be applied. + /// + public CorsPolicy Policy { get; } + } +} diff --git a/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs b/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs new file mode 100644 index 0000000000..d56807d4ee --- /dev/null +++ b/src/Middleware/CORS/src/Infrastructure/CorsEndpointConventionBuilderExtensions.cs @@ -0,0 +1,65 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Cors; +using Microsoft.AspNetCore.Cors.Infrastructure; +using Microsoft.AspNetCore.Routing; + +namespace Microsoft.AspNetCore.Builder +{ + /// + /// CORS extension methods for . + /// + public static class CorsEndpointConventionBuilderExtensions + { + /// + /// Adds a CORS policy with the specified name to the endpoint(s). + /// + /// The endpoint convention builder. + /// The CORS policy name. + /// The original convention builder parameter. + public static IEndpointConventionBuilder WithCorsPolicy(this IEndpointConventionBuilder builder, string policyName) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + builder.Apply(endpointBuilder => + { + endpointBuilder.Metadata.Add(new EnableCorsAttribute(policyName)); + }); + return builder; + } + + /// + /// Adds the specified CORS policy to the endpoint(s). + /// + /// The endpoint convention builder. + /// A delegate which can use a policy builder to build a policy. + /// The original convention builder parameter. + public static IEndpointConventionBuilder WithCorsPolicy(this IEndpointConventionBuilder builder, Action configurePolicy) + { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + + if (configurePolicy == null) + { + throw new ArgumentNullException(nameof(configurePolicy)); + } + + var policyBuilder = new CorsPolicyBuilder(); + configurePolicy(policyBuilder); + var policy = policyBuilder.Build(); + + builder.Apply(endpointBuilder => + { + endpointBuilder.Metadata.Add(new CorsPolicyMetadata(policy)); + }); + return builder; + } + } +} diff --git a/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs b/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs index 0fcd4b947b..64fa4c1138 100644 --- a/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs +++ b/src/Middleware/CORS/src/Infrastructure/CorsMiddleware.cs @@ -5,6 +5,7 @@ using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Cors.Internal; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Endpoints; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -15,6 +16,9 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// public class CorsMiddleware { + // Property key is used by MVC filters to check if CORS middleware has run + private const string CorsMiddlewareInvokedKey = "__CorsMiddlewareInvoked"; + private readonly Func OnResponseStartingDelegate = OnResponseStarting; private readonly RequestDelegate _next; private readonly CorsPolicy _policy; @@ -124,7 +128,49 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure private async Task InvokeCore(HttpContext context, ICorsPolicyProvider corsPolicyProvider) { - var corsPolicy = _policy ?? await corsPolicyProvider.GetPolicyAsync(context, _corsPolicyName); + // CORS policy resolution rules: + // + // 1. If there is an endpoint with IDisableCorsAttribute then CORS is not run + // 2. If there is an endpoint with ICorsPolicyMetadata then use its policy or if + // there is an endpoint with IEnableCorsAttribute that has a policy name then + // fetch policy by name, prioritizing it above policy on middleware + // 3. If there is no policy on middleware then use name on middleware + + // Flag to indicate to other systems, e.g. MVC, that CORS middleware was run for this request + context.Items[CorsMiddlewareInvokedKey] = true; + + var endpoint = context.GetEndpoint(); + + // Get the most significant CORS metadata for the endpoint + // For backwards compatibility reasons this is then downcast to Enable/Disable metadata + var corsMetadata = endpoint?.Metadata.GetMetadata(); + if (corsMetadata is IDisableCorsAttribute) + { + await _next(context); + return; + } + + var corsPolicy = _policy; + var policyName = _corsPolicyName; + if (corsMetadata is ICorsPolicyMetadata corsPolicyMetadata) + { + policyName = null; + corsPolicy = corsPolicyMetadata.Policy; + } + else if (corsMetadata is IEnableCorsAttribute enableCorsAttribute && + enableCorsAttribute.PolicyName != null) + { + // If a policy name has been provided on the endpoint metadata then prioritizing it above the static middleware policy + policyName = enableCorsAttribute.PolicyName; + corsPolicy = null; + } + + if (corsPolicy == null) + { + // Resolve policy by name if the local policy is not being used + corsPolicy = await corsPolicyProvider.GetPolicyAsync(context, policyName); + } + if (corsPolicy == null) { Logger?.NoCorsPolicyFound(); diff --git a/src/Middleware/CORS/src/Infrastructure/ICorsMetadata.cs b/src/Middleware/CORS/src/Infrastructure/ICorsMetadata.cs new file mode 100644 index 0000000000..c5a458093e --- /dev/null +++ b/src/Middleware/CORS/src/Infrastructure/ICorsMetadata.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Cors.Infrastructure +{ + /// + /// A marker interface which can be used to identify CORS metdata. + /// + public interface ICorsMetadata + { + } +} diff --git a/src/Middleware/CORS/src/Infrastructure/ICorsPolicyMetadata.cs b/src/Middleware/CORS/src/Infrastructure/ICorsPolicyMetadata.cs new file mode 100644 index 0000000000..aa5620c196 --- /dev/null +++ b/src/Middleware/CORS/src/Infrastructure/ICorsPolicyMetadata.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Cors.Infrastructure +{ + /// + /// An interface which can be used to identify a type which provides metadata needed for enabling CORS support. + /// + public interface ICorsPolicyMetadata : ICorsMetadata + { + /// + /// The policy which needs to be applied. + /// + CorsPolicy Policy { get; } + } +} diff --git a/src/Middleware/CORS/src/Infrastructure/IDisableCorsAttribute.cs b/src/Middleware/CORS/src/Infrastructure/IDisableCorsAttribute.cs index 1e69ba3da3..4793ea9589 100644 --- a/src/Middleware/CORS/src/Infrastructure/IDisableCorsAttribute.cs +++ b/src/Middleware/CORS/src/Infrastructure/IDisableCorsAttribute.cs @@ -6,7 +6,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// /// An interface which can be used to identify a type which provides metdata to disable cors for a resource. /// - public interface IDisableCorsAttribute + public interface IDisableCorsAttribute : ICorsMetadata { } -} \ No newline at end of file +} diff --git a/src/Middleware/CORS/src/Infrastructure/IEnableCorsAttribute.cs b/src/Middleware/CORS/src/Infrastructure/IEnableCorsAttribute.cs index c58e2a1d96..544a98ea27 100644 --- a/src/Middleware/CORS/src/Infrastructure/IEnableCorsAttribute.cs +++ b/src/Middleware/CORS/src/Infrastructure/IEnableCorsAttribute.cs @@ -6,11 +6,11 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// /// An interface which can be used to identify a type which provides metadata needed for enabling CORS support. /// - public interface IEnableCorsAttribute + public interface IEnableCorsAttribute : ICorsMetadata { /// /// The name of the policy which needs to be applied. /// string PolicyName { get; set; } } -} \ No newline at end of file +} diff --git a/src/Middleware/CORS/src/Microsoft.AspNetCore.Cors.csproj b/src/Middleware/CORS/src/Microsoft.AspNetCore.Cors.csproj index 441e1c5ba6..5fb397467c 100644 --- a/src/Middleware/CORS/src/Microsoft.AspNetCore.Cors.csproj +++ b/src/Middleware/CORS/src/Microsoft.AspNetCore.Cors.csproj @@ -1,4 +1,4 @@ - + CORS middleware and policy for ASP.NET Core to enable cross-origin resource sharing. @@ -13,6 +13,7 @@ Microsoft.AspNetCore.Cors.EnableCorsAttribute + diff --git a/src/Middleware/CORS/test/FunctionalTests/CorsMiddlewareFunctionalTest.cs b/src/Middleware/CORS/test/FunctionalTests/CorsMiddlewareFunctionalTest.cs index bd8a6afe68..55d07ee94f 100644 --- a/src/Middleware/CORS/test/FunctionalTests/CorsMiddlewareFunctionalTest.cs +++ b/src/Middleware/CORS/test/FunctionalTests/CorsMiddlewareFunctionalTest.cs @@ -25,11 +25,13 @@ namespace FunctionalTests public ITestOutputHelper Output { get; } - [Fact] - public async Task RunClientTests() + [Theory] + [InlineData("Startup")] + [InlineData("StartupWithoutEndpointRouting")] + public async Task RunClientTests(string startup) { using (StartLog(out var loggerFactory)) - using (var deploymentResult = await CreateDeployments(loggerFactory)) + using (var deploymentResult = await CreateDeployments(loggerFactory, startup)) { ProcessStartInfo processStartInfo; if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) @@ -58,7 +60,7 @@ namespace FunctionalTests } } - private static async Task CreateDeployments(ILoggerFactory loggerFactory) + private static async Task CreateDeployments(ILoggerFactory loggerFactory, string startup) { var solutionPath = TestPathUtilities.GetSolutionRootDirectory("Middleware"); @@ -78,6 +80,10 @@ namespace FunctionalTests PublishApplicationBeforeDeployment = false, ApplicationType = ApplicationType.Portable, Configuration = configuration, + EnvironmentVariables = + { + ["CORS_STARTUP"] = startup + } }; var destinationFactory = ApplicationDeployerFactory.Create(destinationParameters, loggerFactory); diff --git a/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs b/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs new file mode 100644 index 0000000000..e6692bf210 --- /dev/null +++ b/src/Middleware/CORS/test/UnitTests/CorsEndpointConventionBuilderExtensionsTests.cs @@ -0,0 +1,76 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Xunit; + +namespace Microsoft.AspNetCore.Cors.Infrastructure +{ + public class CorsEndpointConventionBuilderExtensionsTests + { + [Fact] + public void WithCorsPolicy_Name_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.WithCorsPolicy("TestPolicyName"); + + // Assert + var addCorsPolicy = Assert.Single(testConventionBuilder.Conventions); + + var endpointModel = new TestEndpointModel(); + addCorsPolicy(endpointModel); + var endpoint = endpointModel.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + Assert.Equal("TestPolicyName", metadata.PolicyName); + } + + [Fact] + public void WithCorsPolicy_Policy_MetadataAdded() + { + // Arrange + var testConventionBuilder = new TestEndpointConventionBuilder(); + + // Act + testConventionBuilder.WithCorsPolicy(builder => builder.AllowAnyOrigin()); + + // Assert + var addCorsPolicy = Assert.Single(testConventionBuilder.Conventions); + + var endpointModel = new TestEndpointModel(); + addCorsPolicy(endpointModel); + var endpoint = endpointModel.Build(); + + var metadata = endpoint.Metadata.GetMetadata(); + Assert.NotNull(metadata); + Assert.NotNull(metadata.Policy); + Assert.True(metadata.Policy.AllowAnyOrigin); + } + + private class TestEndpointModel : EndpointModel + { + public override Endpoint Build() + { + return new Endpoint(RequestDelegate, new EndpointMetadataCollection(Metadata), DisplayName); + } + } + + private class TestEndpointConventionBuilder : IEndpointConventionBuilder + { + public IList> Conventions { get; } = new List>(); + + public void Apply(Action convention) + { + Conventions.Add(convention); + } + } + } +} diff --git a/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs b/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs index 4aac15de72..f3c95c4fe2 100644 --- a/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs +++ b/src/Middleware/CORS/test/UnitTests/CorsMiddlewareTests.cs @@ -8,8 +8,10 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Endpoints; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Moq; using Xunit; @@ -563,5 +565,256 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure }); } } + + [Fact] + public async Task Invoke_HasEndpointWithNoMetadata_RunsCors() + { + // Arrange + var corsService = Mock.Of(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + loggerFactory, + "DefaultPolicyName"); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, EndpointMetadataCollection.Empty, "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), "DefaultPolicyName"), + Times.Once); + } + + [Fact] + public async Task Invoke_HasEndpointWithEnableMetadata_MiddlewareHasPolicyName_RunsCorsWithPolicyName() + { + // Arrange + var corsService = Mock.Of(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + loggerFactory, + "DefaultPolicyName"); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName")), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), "MetadataPolicyName"), + Times.Once); + } + + [Fact] + public async Task Invoke_HasEndpointWithEnableMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName() + { + // Arrange + var policy = new CorsPolicyBuilder().Build(); + var corsService = Mock.Of(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + policy, + loggerFactory); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName")), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), "MetadataPolicyName"), + Times.Once); + } + + [Fact] + public async Task Invoke_HasEndpointWithCorsPolicyMetadata_MiddlewareHasPolicy_RunsCorsWithPolicyName() + { + // Arrange + var defaultPolicy = new CorsPolicyBuilder().Build(); + var metadataPolicy = new CorsPolicyBuilder().Build(); + var mockCorsService = new Mock(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + mockCorsService.Setup(o => o.EvaluatePolicy(It.IsAny(), It.IsAny())) + .Returns(new CorsResult()) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + mockCorsService.Object, + defaultPolicy, + loggerFactory); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new CorsPolicyMetadata(metadataPolicy)), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), It.IsAny()), + Times.Never); + mockCorsService.Verify( + o => o.EvaluatePolicy(It.IsAny(), metadataPolicy), + Times.Once); + } + + [Fact] + public async Task Invoke_HasEndpointWithEnableMetadataWithNoName_RunsCorsWithStaticPolicy() + { + // Arrange + var policy = new CorsPolicyBuilder().Build(); + var mockCorsService = new Mock(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + mockCorsService.Setup(o => o.EvaluatePolicy(It.IsAny(), It.IsAny())) + .Returns(new CorsResult()) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + mockCorsService.Object, + policy, + loggerFactory); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute()), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), It.IsAny()), + Times.Never); + mockCorsService.Verify( + o => o.EvaluatePolicy(It.IsAny(), policy), + Times.Once); + } + + [Fact] + public async Task Invoke_HasEndpointWithDisableMetadata_SkipCors() + { + // Arrange + var corsService = Mock.Of(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + loggerFactory, + "DefaultPolicyName"); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new DisableCorsAttribute()), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), It.IsAny()), + Times.Never); + } + + [Fact] + public async Task Invoke_HasEndpointWithMutlipleMetadata_SkipCorsBecauseOfMetadataOrder() + { + // Arrange + var corsService = Mock.Of(); + var mockProvider = new Mock(); + var loggerFactory = NullLoggerFactory.Instance; + mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) + .Returns(Task.FromResult(null)) + .Verifiable(); + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + loggerFactory, + "DefaultPolicyName"); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName"), new DisableCorsAttribute()), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider.Object); + + // Assert + mockProvider.Verify( + o => o.GetPolicyAsync(It.IsAny(), It.IsAny()), + Times.Never); + } + + [Fact] + public async Task Invoke_InvokeFlagSet() + { + // Arrange + var corsService = Mock.Of(); + var mockProvider = Mock.Of(); + var loggerFactory = NullLoggerFactory.Instance; + + var middleware = new CorsMiddleware( + Mock.Of(), + corsService, + loggerFactory, + "DefaultPolicyName"); + + var httpContext = new DefaultHttpContext(); + httpContext.SetEndpoint(new Endpoint(c => Task.CompletedTask, new EndpointMetadataCollection(new EnableCorsAttribute("MetadataPolicyName"), new DisableCorsAttribute()), "Test endpoint")); + httpContext.Request.Headers.Add(CorsConstants.Origin, new[] { "http://example.com" }); + + // Act + await middleware.Invoke(httpContext, mockProvider); + + // Assert + Assert.Contains(httpContext.Items, item => string.Equals(item.Key as string, "__CorsMiddlewareInvoked")); + } } } diff --git a/src/Middleware/Middleware.sln b/src/Middleware/Middleware.sln index 92742c634a..72b5f0ad70 100644 --- a/src/Middleware/Middleware.sln +++ b/src/Middleware/Middleware.sln @@ -1,7 +1,7 @@  Microsoft Visual Studio Solution File, Format Version 12.00 -# Visual Studio 15 -VisualStudioVersion = 15.0.26124.0 +# Visual Studio Version 16 +VisualStudioVersion = 16.0.28407.52 MinimumVisualStudioVersion = 15.0.26124.0 Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "WebSockets", "WebSockets", "{E0D9867D-C23D-43EB-8D9C-DE0398A25432}" EndProject @@ -245,6 +245,12 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.Crypto EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Microsoft.AspNetCore.DataProtection.Abstractions", "..\DataProtection\Abstractions\src\Microsoft.AspNetCore.DataProtection.Abstractions.csproj", "{7343B4E4-C5A2-49E2-B431-4D1E6A26E424}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "FunctionalTests", "CORS\test\FunctionalTests\FunctionalTests.csproj", "{E025D98E-BD85-474A-98A9-E7F44F392F8E}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SampleDestination", "CORS\samples\SampleDestination\SampleDestination.csproj", "{52CDD110-77DD-4C4D-8C72-4570F6EF20DD}" +EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "SampleOrigin", "CORS\samples\SampleOrigin\SampleOrigin.csproj", "{198FFE3B-0346-4856-A6C9-8752D51C4EB3}" +EndProject Global GlobalSection(SolutionConfigurationPlatforms) = preSolution Debug|Any CPU = Debug|Any CPU @@ -1335,6 +1341,42 @@ Global {7343B4E4-C5A2-49E2-B431-4D1E6A26E424}.Release|x64.Build.0 = Release|Any CPU {7343B4E4-C5A2-49E2-B431-4D1E6A26E424}.Release|x86.ActiveCfg = Release|Any CPU {7343B4E4-C5A2-49E2-B431-4D1E6A26E424}.Release|x86.Build.0 = Release|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Debug|Any CPU.Build.0 = Debug|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Debug|x64.ActiveCfg = Debug|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Debug|x64.Build.0 = Debug|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Debug|x86.ActiveCfg = Debug|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Debug|x86.Build.0 = Debug|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Release|Any CPU.ActiveCfg = Release|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Release|Any CPU.Build.0 = Release|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Release|x64.ActiveCfg = Release|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Release|x64.Build.0 = Release|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Release|x86.ActiveCfg = Release|Any CPU + {E025D98E-BD85-474A-98A9-E7F44F392F8E}.Release|x86.Build.0 = Release|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Debug|Any CPU.Build.0 = Debug|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Debug|x64.ActiveCfg = Debug|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Debug|x64.Build.0 = Debug|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Debug|x86.ActiveCfg = Debug|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Debug|x86.Build.0 = Debug|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Release|Any CPU.ActiveCfg = Release|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Release|Any CPU.Build.0 = Release|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Release|x64.ActiveCfg = Release|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Release|x64.Build.0 = Release|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Release|x86.ActiveCfg = Release|Any CPU + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD}.Release|x86.Build.0 = Release|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Debug|Any CPU.Build.0 = Debug|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Debug|x64.ActiveCfg = Debug|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Debug|x64.Build.0 = Debug|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Debug|x86.ActiveCfg = Debug|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Debug|x86.Build.0 = Debug|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Release|Any CPU.ActiveCfg = Release|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Release|Any CPU.Build.0 = Release|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Release|x64.ActiveCfg = Release|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Release|x64.Build.0 = Release|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Release|x86.ActiveCfg = Release|Any CPU + {198FFE3B-0346-4856-A6C9-8752D51C4EB3}.Release|x86.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -1441,6 +1483,9 @@ Global {3AD5B221-C718-4F14-883A-4345DC90CF9C} = {ACA6DDB9-7592-47CE-A740-D15BF307E9E0} {227030D6-99AD-4C6A-AE70-1333BCBE8705} = {ACA6DDB9-7592-47CE-A740-D15BF307E9E0} {7343B4E4-C5A2-49E2-B431-4D1E6A26E424} = {ACA6DDB9-7592-47CE-A740-D15BF307E9E0} + {E025D98E-BD85-474A-98A9-E7F44F392F8E} = {4967DE1B-FEC2-4C2B-8F7F-6262D67C9434} + {52CDD110-77DD-4C4D-8C72-4570F6EF20DD} = {7CF63806-4C4F-4C48-8922-A75113975308} + {198FFE3B-0346-4856-A6C9-8752D51C4EB3} = {7CF63806-4C4F-4C48-8922-A75113975308} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {83786312-A93B-4BB4-AB06-7C6913A59AFA}