No CORS headers sent if Exception is thrown

- Normally headers are added however if a controller throws an exception then CORS headers not be present.

Addresses aspnet/Home#3220
This commit is contained in:
Daniel Little 2018-06-13 08:03:03 +10:00 committed by Pranav K
parent 32ad46006e
commit 554855cab3
3 changed files with 232 additions and 9 deletions

View File

@ -3,7 +3,10 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Cors.Internal;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.Cors.Infrastructure
@ -18,6 +21,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
private readonly ICorsPolicyProvider _corsPolicyProvider;
private readonly CorsPolicy _policy;
private readonly string _corsPolicyName;
private readonly ILogger _logger;
/// <summary>
/// Instantiates a new <see cref="CorsMiddleware"/>.
@ -25,11 +29,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policyProvider">A policy provider which can get an <see cref="CorsPolicy"/>.</param>
[Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory")]
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
ICorsPolicyProvider policyProvider)
: this(next, corsService, policyProvider, policyName: null)
: this(next, corsService, policyProvider, NullLoggerFactory.Instance, policyName: null)
{
}
@ -40,11 +45,61 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policyProvider">A policy provider which can get an <see cref="CorsPolicy"/>.</param>
/// <param name="policyName">An optional name of the policy to be fetched.</param>
[Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory")]
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
ICorsPolicyProvider policyProvider,
string policyName)
: this(next, corsService, policyProvider, NullLoggerFactory.Instance, policyName)
{
}
/// <summary>
/// Instantiates a new <see cref="CorsMiddleware"/>.
/// </summary>
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policy">An instance of the <see cref="CorsPolicy"/> which can be applied.</param>
[Obsolete("This constructor has been replaced with an equivalent constructor which requires an ILoggerFactory")]
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
CorsPolicy policy)
: this(next, corsService, policy, NullLoggerFactory.Instance)
{
}
/// <summary>
/// Instantiates a new <see cref="CorsMiddleware"/>.
/// </summary>
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policyProvider">A policy provider which can get an <see cref="CorsPolicy"/>.</param>
/// <param name="loggerFactory">An instance of <see cref="ILoggerFactory"/>.</param>
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
ICorsPolicyProvider policyProvider,
ILoggerFactory loggerFactory)
: this(next, corsService, policyProvider, loggerFactory, policyName: null)
{
}
/// <summary>
/// Instantiates a new <see cref="CorsMiddleware"/>.
/// </summary>
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policyProvider">A policy provider which can get an <see cref="CorsPolicy"/>.</param>
/// <param name="loggerFactory">An instance of <see cref="ILoggerFactory"/>.</param>
/// <param name="policyName">An optional name of the policy to be fetched.</param>
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
ICorsPolicyProvider policyProvider,
ILoggerFactory loggerFactory,
string policyName)
{
if (next == null)
{
@ -61,10 +116,16 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(policyProvider));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
}
_next = next;
_corsService = corsService;
_corsPolicyProvider = policyProvider;
_corsPolicyName = policyName;
_logger = loggerFactory.CreateLogger<CorsMiddleware>();
}
/// <summary>
@ -73,10 +134,12 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
/// <param name="next">The next middleware in the pipeline.</param>
/// <param name="corsService">An instance of <see cref="ICorsService"/>.</param>
/// <param name="policy">An instance of the <see cref="CorsPolicy"/> which can be applied.</param>
/// <param name="loggerFactory">An instance of <see cref="ILoggerFactory"/>.</param>
public CorsMiddleware(
RequestDelegate next,
ICorsService corsService,
CorsPolicy policy)
RequestDelegate next,
ICorsService corsService,
CorsPolicy policy,
ILoggerFactory loggerFactory)
{
if (next == null)
{
@ -93,9 +156,15 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
throw new ArgumentNullException(nameof(policy));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
}
_next = next;
_corsService = corsService;
_policy = policy;
_logger = loggerFactory.CreateLogger<CorsMiddleware>();
}
/// <inheritdoc />
@ -106,9 +175,6 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
var corsPolicy = _policy ?? await _corsPolicyProvider?.GetPolicyAsync(context, _corsPolicyName);
if (corsPolicy != null)
{
var corsResult = _corsService.EvaluatePolicy(context, corsPolicy);
_corsService.ApplyResult(corsResult, context.Response);
var accessControlRequestMethod =
context.Request.Headers[CorsConstants.AccessControlRequestMethod];
if (string.Equals(
@ -117,15 +183,39 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
StringComparison.OrdinalIgnoreCase) &&
!StringValues.IsNullOrEmpty(accessControlRequestMethod))
{
ApplyCorsHeaders(context, corsPolicy);
// Since there is a policy which was identified,
// always respond to preflight requests.
context.Response.StatusCode = StatusCodes.Status204NoContent;
return;
}
else
{
context.Response.OnStarting(state =>
{
var (httpContext, policy) = (Tuple<HttpContext, CorsPolicy>)state;
try
{
ApplyCorsHeaders(httpContext, policy);
}
catch (Exception exception)
{
_logger.FailedToSetCorsHeaders(exception);
}
return Task.CompletedTask;
}, Tuple.Create(context, corsPolicy));
}
}
}
await _next(context);
}
private void ApplyCorsHeaders(HttpContext context, CorsPolicy corsPolicy)
{
var corsResult = _corsService.EvaluatePolicy(context, corsPolicy);
_corsService.ApplyResult(corsResult, context.Response);
}
}
}
}

View File

@ -16,6 +16,7 @@ namespace Microsoft.AspNetCore.Cors.Internal
private static readonly Action<ILogger, string, Exception> _originNotAllowed;
private static readonly Action<ILogger, string, Exception> _accessControlMethodNotAllowed;
private static readonly Action<ILogger, string, Exception> _requestHeaderNotAllowed;
private static readonly Action<ILogger, Exception> _failedToSetCorsHeaders;
static CORSLoggerExtensions()
{
@ -58,6 +59,11 @@ namespace Microsoft.AspNetCore.Cors.Internal
LogLevel.Information,
8,
"Request header '{requestHeader}' not allowed in CORS policy.");
_failedToSetCorsHeaders = LoggerMessage.Define(
LogLevel.Warning,
9,
"Failed to apply CORS Response headers.");
}
public static void IsPreflightRequest(this ILogger logger)
@ -99,5 +105,10 @@ namespace Microsoft.AspNetCore.Cors.Internal
{
_requestHeaderNotAllowed(logger, requestHeader, null);
}
public static void FailedToSetCorsHeaders(this ILogger logger, Exception exception)
{
_failedToSetCorsHeaders(logger, exception);
}
}
}

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Net;
using System.Threading.Tasks;
@ -9,6 +10,7 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Moq;
using Xunit;
@ -246,6 +248,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
// Arrange
var corsService = Mock.Of<ICorsService>();
var mockProvider = new Mock<ICorsPolicyProvider>();
var loggerFactory = Mock.Of<ILoggerFactory>();
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
.Returns(Task.FromResult<CorsPolicy>(null))
.Verifiable();
@ -254,6 +257,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
Mock.Of<RequestDelegate>(),
corsService,
mockProvider.Object,
loggerFactory,
policyName: null);
var httpContext = new DefaultHttpContext();
@ -274,6 +278,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
// Arrange
var corsService = Mock.Of<ICorsService>();
var mockProvider = new Mock<ICorsPolicyProvider>();
var loggerFactory = Mock.Of<ILoggerFactory>();
mockProvider.Setup(o => o.GetPolicyAsync(It.IsAny<HttpContext>(), It.IsAny<string>()))
.Returns(Task.FromResult<CorsPolicy>(null))
.Verifiable();
@ -282,6 +287,7 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
Mock.Of<RequestDelegate>(),
corsService,
mockProvider.Object,
loggerFactory,
policyName: null);
var httpContext = new DefaultHttpContext();
@ -349,5 +355,121 @@ namespace Microsoft.AspNetCore.Cors.Infrastructure
Assert.Equal("PUT", response.Headers.GetValues(CorsConstants.AccessControlAllowMethods).FirstOrDefault());
}
}
[Fact]
public async Task CorsRequest_SetsResponseHeaders()
{
// Arrange
var hostBuilder = new WebHostBuilder()
.Configure(app =>
{
app.UseCors(builder =>
builder.WithOrigins("http://localhost:5001")
.WithMethods("PUT")
.WithHeaders("Header1")
.WithExposedHeaders("AllowedHeader"));
app.Run(async context =>
{
context.Response.Headers.Add("Test", "Should-Appear");
await context.Response.WriteAsync("Cross origin response");
});
})
.ConfigureServices(services => services.AddCors());
using (var server = new TestServer(hostBuilder))
{
// Act
// Actual request.
var response = await server.CreateRequest("/")
.AddHeader(CorsConstants.Origin, "http://localhost:5001")
.SendAsync("PUT");
// Assert
response.EnsureSuccessStatusCode();
Assert.Collection(
response.Headers.OrderBy(o => o.Key),
kvp =>
{
Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
Assert.Equal("http://localhost:5001", Assert.Single(kvp.Value));
},
kvp =>
{
Assert.Equal(CorsConstants.AccessControlExposeHeaders, kvp.Key);
Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
},
kvp =>
{
Assert.Equal("Test", kvp.Key);
Assert.Equal("Should-Appear", Assert.Single(kvp.Value));
});
Assert.Equal("Cross origin response", await response.Content.ReadAsStringAsync());
}
}
[Fact]
public async Task CorsRequest_SetsResponseHeader_IfExceptionHandlerClearsResponse()
{
// Arrange
var exceptionSeen = true;
var hostBuilder = new WebHostBuilder()
.Configure(app =>
{
// Simulate ExceptionHandler middleware
app.Use(async (context, next) =>
{
try
{
await next();
}
catch (Exception)
{
exceptionSeen = true;
context.Response.Clear();
context.Response.StatusCode = 500;
}
});
app.UseCors(builder =>
builder.WithOrigins("http://localhost:5001")
.WithMethods("PUT")
.WithHeaders("Header1")
.WithExposedHeaders("AllowedHeader"));
app.Run(context =>
{
context.Response.Headers.Add("Test", "Should-Not-Exist");
throw new Exception("Runtime error");
});
})
.ConfigureServices(services => services.AddCors());
using (var server = new TestServer(hostBuilder))
{
// Act
// Actual request.
var response = await server.CreateRequest("/")
.AddHeader(CorsConstants.Origin, "http://localhost:5001")
.SendAsync("PUT");
// Assert
Assert.Equal(HttpStatusCode.InternalServerError, response.StatusCode);
Assert.True(exceptionSeen, "We expect exception middleware to have executed");
Assert.Collection(
response.Headers.OrderBy(o => o.Key),
kvp =>
{
Assert.Equal(CorsConstants.AccessControlAllowOrigin, kvp.Key);
Assert.Equal("http://localhost:5001", Assert.Single(kvp.Value));
},
kvp =>
{
Assert.Equal(CorsConstants.AccessControlExposeHeaders, kvp.Key);
Assert.Equal("AllowedHeader", Assert.Single(kvp.Value));
});
}
}
}
}
}