diff --git a/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs b/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs index eb773f5abf..fb3df360e6 100644 --- a/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs +++ b/src/Microsoft.AspNetCore.Cors/Infrastructure/CorsMiddleware.cs @@ -3,7 +3,10 @@ using System; using System.Threading.Tasks; +using Microsoft.AspNetCore.Cors.Internal; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Primitives; namespace Microsoft.AspNetCore.Cors.Infrastructure @@ -18,6 +21,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure private readonly ICorsPolicyProvider _corsPolicyProvider; private readonly CorsPolicy _policy; private readonly string _corsPolicyName; + private readonly ILogger _logger; /// /// Instantiates a new . @@ -25,11 +29,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// The next middleware in the pipeline. /// An instance of . /// A policy provider which can get an . + [Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory")] public CorsMiddleware( RequestDelegate next, ICorsService corsService, ICorsPolicyProvider policyProvider) - : this(next, corsService, policyProvider, policyName: null) + : this(next, corsService, policyProvider, NullLoggerFactory.Instance, policyName: null) { } @@ -40,11 +45,61 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// An instance of . /// A policy provider which can get an . /// An optional name of the policy to be fetched. + [Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory")] public CorsMiddleware( RequestDelegate next, ICorsService corsService, ICorsPolicyProvider policyProvider, string policyName) + : this(next, corsService, policyProvider, NullLoggerFactory.Instance, policyName) + { + } + + /// + /// Instantiates a new . + /// + /// The next middleware in the pipeline. + /// An instance of . + /// An instance of the which can be applied. + [Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory")] + public CorsMiddleware( + RequestDelegate next, + ICorsService corsService, + CorsPolicy policy) + : this(next, corsService, policy, NullLoggerFactory.Instance) + { + } + + /// + /// Instantiates a new . + /// + /// The next middleware in the pipeline. + /// An instance of . + /// A policy provider which can get an . + /// An instance of . + public CorsMiddleware( + RequestDelegate next, + ICorsService corsService, + ICorsPolicyProvider policyProvider, + ILoggerFactory loggerFactory) + : this(next, corsService, policyProvider, loggerFactory, policyName: null) + { + } + + /// + /// Instantiates a new . + /// + /// The next middleware in the pipeline. + /// An instance of . + /// A policy provider which can get an . + /// An instance of . + /// An optional name of the policy to be fetched. + public CorsMiddleware( + RequestDelegate next, + ICorsService corsService, + ICorsPolicyProvider policyProvider, + ILoggerFactory loggerFactory, + string policyName) { if (next == null) { @@ -61,10 +116,16 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure throw new ArgumentNullException(nameof(policyProvider)); } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + _next = next; _corsService = corsService; _corsPolicyProvider = policyProvider; _corsPolicyName = policyName; + _logger = loggerFactory.CreateLogger(); } /// @@ -73,10 +134,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure /// The next middleware in the pipeline. /// An instance of . /// An instance of the which can be applied. + /// An instance of . public CorsMiddleware( - RequestDelegate next, - ICorsService corsService, - CorsPolicy policy) + RequestDelegate next, + ICorsService corsService, + CorsPolicy policy, + ILoggerFactory loggerFactory) { if (next == null) { @@ -93,9 +156,15 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure throw new ArgumentNullException(nameof(policy)); } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + _next = next; _corsService = corsService; _policy = policy; + _logger = loggerFactory.CreateLogger(); } /// @@ -106,9 +175,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure var corsPolicy = _policy ?? await _corsPolicyProvider?.GetPolicyAsync(context, _corsPolicyName); if (corsPolicy != null) { - var corsResult = _corsService.EvaluatePolicy(context, corsPolicy); - _corsService.ApplyResult(corsResult, context.Response); - var accessControlRequestMethod = context.Request.Headers[CorsConstants.AccessControlRequestMethod]; if (string.Equals( @@ -117,15 +183,39 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure StringComparison.OrdinalIgnoreCase) && !StringValues.IsNullOrEmpty(accessControlRequestMethod)) { + ApplyCorsHeaders(context, corsPolicy); + // Since there is a policy which was identified, // always respond to preflight requests. context.Response.StatusCode = StatusCodes.Status204NoContent; return; } + else + { + context.Response.OnStarting(state => + { + var (httpContext, policy) = (Tuple)state; + try + { + ApplyCorsHeaders(httpContext, policy); + } + catch (Exception exception) + { + _logger.FailedToSetCorsHeaders(exception); + } + return Task.CompletedTask; + }, Tuple.Create(context, corsPolicy)); + } } } await _next(context); } + + private void ApplyCorsHeaders(HttpContext context, CorsPolicy corsPolicy) + { + var corsResult = _corsService.EvaluatePolicy(context, corsPolicy); + _corsService.ApplyResult(corsResult, context.Response); + } } -} \ No newline at end of file +} diff --git a/src/Microsoft.AspNetCore.Cors/Internal/CORSLoggerExtensions.cs b/src/Microsoft.AspNetCore.Cors/Internal/CORSLoggerExtensions.cs index 727d19a4ea..2cf30d7525 100644 --- a/src/Microsoft.AspNetCore.Cors/Internal/CORSLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Cors/Internal/CORSLoggerExtensions.cs @@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.Cors.Internal private static readonly Action _originNotAllowed; private static readonly Action _accessControlMethodNotAllowed; private static readonly Action _requestHeaderNotAllowed; + private static readonly Action _failedToSetCorsHeaders; static CORSLoggerExtensions() { @@ -58,6 +59,11 @@ namespace Microsoft.AspNetCore.Cors.Internal LogLevel.Information, 8, "Request header '{requestHeader}' not allowed in CORS policy."); + + _failedToSetCorsHeaders = LoggerMessage.Define( + LogLevel.Warning, + 9, + "Failed to apply CORS Response headers."); } public static void IsPreflightRequest(this ILogger logger) @@ -99,5 +105,10 @@ namespace Microsoft.AspNetCore.Cors.Internal { _requestHeaderNotAllowed(logger, requestHeader, null); } + + public static void FailedToSetCorsHeaders(this ILogger logger, Exception exception) + { + _failedToSetCorsHeaders(logger, exception); + } } } diff --git a/test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs b/test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs index d74c020eae..44c6d5c1be 100644 --- a/test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs +++ b/test/Microsoft.AspNetCore.Cors.Test/CorsMiddlewareTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Linq; using System.Net; using System.Threading.Tasks; @@ -9,6 +10,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Moq; using Xunit; @@ -246,6 +248,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure // Arrange var corsService = Mock.Of(); var mockProvider = new Mock(); + var loggerFactory = Mock.Of(); mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult(null)) .Verifiable(); @@ -254,6 +257,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure Mock.Of(), corsService, mockProvider.Object, + loggerFactory, policyName: null); var httpContext = new DefaultHttpContext(); @@ -274,6 +278,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure // Arrange var corsService = Mock.Of(); var mockProvider = new Mock(); + var loggerFactory = Mock.Of(); mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny(), It.IsAny())) .Returns(Task.FromResult(null)) .Verifiable(); @@ -282,6 +287,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure Mock.Of(), corsService, mockProvider.Object, + loggerFactory, policyName: null); var httpContext = new DefaultHttpContext(); @@ -349,5 +355,121 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure Assert.Equal("PUT", response.Headers.GetValues(CorsConstants.AccessControlAllowMethods).FirstOrDefault()); } } + + [Fact] + public async Task CorsRequest_SetsResponseHeaders() + { + // Arrange + var hostBuilder = new WebHostBuilder() + .Configure(app => + { + app.UseCors(builder => + builder.WithOrigins("http://localhost:5001") + .WithMethods("PUT") + .WithHeaders("Header1") + .WithExposedHeaders("AllowedHeader")); + app.Run(async context => + { + context.Response.Headers.Add("Test", "Should-Appear"); + await context.Response.WriteAsync("Cross origin response"); + }); + }) + .ConfigureServices(services => services.AddCors()); + + using (var server = new TestServer(hostBuilder)) + { + // Act + // Actual request. + var response = await server.CreateRequest("/") + .AddHeader(CorsConstants.Origin, "http://localhost:5001") + .SendAsync("PUT"); + + // Assert + response.EnsureSuccessStatusCode(); + Assert.Collection( + response.Headers.OrderBy(o => o.Key), + kvp => + { + Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key); + Assert.Equal("http://localhost:5001", Assert.Single(kvp.Value)); + }, + kvp => + { + Assert.Equal(CorsConstants.AccessControlExposeHeaders, kvp.Key); + Assert.Equal("AllowedHeader", Assert.Single(kvp.Value)); + }, + kvp => + { + Assert.Equal("Test", kvp.Key); + Assert.Equal("Should-Appear", Assert.Single(kvp.Value)); + }); + + Assert.Equal("Cross origin response", await response.Content.ReadAsStringAsync()); + } + } + + [Fact] + public async Task CorsRequest_SetsResponseHeader_IfExceptionHandlerClearsResponse() + { + // Arrange + var exceptionSeen = true; + var hostBuilder = new WebHostBuilder() + .Configure(app => + { + // Simulate ExceptionHandler middleware + app.Use(async (context, next) => + { + try + { + await next(); + } + catch (Exception) + { + exceptionSeen = true; + context.Response.Clear(); + context.Response.StatusCode = 500; + } + }); + + app.UseCors(builder => + builder.WithOrigins("http://localhost:5001") + .WithMethods("PUT") + .WithHeaders("Header1") + .WithExposedHeaders("AllowedHeader")); + + app.Run(context => + { + context.Response.Headers.Add("Test", "Should-Not-Exist"); + throw new Exception("Runtime error"); + }); + }) + .ConfigureServices(services => services.AddCors()); + + using (var server = new TestServer(hostBuilder)) + { + // Act + // Actual request. + var response = await server.CreateRequest("/") + .AddHeader(CorsConstants.Origin, "http://localhost:5001") + .SendAsync("PUT"); + + // Assert + Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode); + Assert.True(exceptionSeen, "We expect exception middleware to have executed"); + + Assert.Collection( + response.Headers.OrderBy(o => o.Key), + kvp => + { + Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key); + Assert.Equal("http://localhost:5001", Assert.Single(kvp.Value)); + }, + kvp => + { + Assert.Equal(CorsConstants.AccessControlExposeHeaders, kvp.Key); + Assert.Equal("AllowedHeader", Assert.Single(kvp.Value)); + }); + } + } } -} \ No newline at end of file +}