diff --git a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs index 112685a9e5..5d280b5f43 100644 --- a/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs +++ b/src/Middleware/Diagnostics/src/ExceptionHandler/ExceptionHandlerMiddleware.cs @@ -7,6 +7,7 @@ using System.Runtime.ExceptionServices; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Microsoft.Net.Http.Headers; @@ -103,7 +104,8 @@ namespace Microsoft.AspNetCore.Diagnostics } try { - context.Response.Clear(); + ClearHttpContext(context); + var exceptionHandlerFeature = new ExceptionHandlerFeature() { Error = edi.SourceException, @@ -137,6 +139,17 @@ namespace Microsoft.AspNetCore.Diagnostics edi.Throw(); // Re-throw the original if we couldn't handle it } + private static void ClearHttpContext(HttpContext context) + { + context.Response.Clear(); + + // An endpoint may have already been set. Since we're going to re-invoke the middleware pipeline we need to reset + // the endpoint and route values to ensure things are re-calculated. + context.SetEndpoint(endpoint: null); + var routeValuesFeature = context.Features.Get(); + routeValuesFeature?.RouteValues?.Clear(); + } + private static Task ClearCacheHeaders(object state) { var headers = ((HttpResponse)state).Headers; diff --git a/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs new file mode 100644 index 0000000000..6d08be5d3f --- /dev/null +++ b/src/Middleware/Diagnostics/test/UnitTests/ExceptionHandlerMiddlewareTest.cs @@ -0,0 +1,87 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Diagnostics +{ + public class ExceptionHandlerMiddlewareTest + { + [Fact] + public async Task Invoke_ExceptionThrownResultsInClearedRouteValuesAndEndpoint() + { + // Arrange + var httpContext = CreateHttpContext(); + httpContext.SetEndpoint(new Endpoint((_) => Task.CompletedTask, new EndpointMetadataCollection(), "Test")); + httpContext.Request.RouteValues["John"] = "Doe"; + + var optionsAccessor = CreateOptionsAccessor( + exceptionHandler: context => + { + Assert.Empty(context.Request.RouteValues); + Assert.Null(context.GetEndpoint()); + return Task.CompletedTask; + }); + var middleware = CreateMiddleware(_ => throw new InvalidOperationException(), optionsAccessor); + + // Act & Assert + await middleware.Invoke(httpContext); + } + + private HttpContext CreateHttpContext() + { + var httpContext = new DefaultHttpContext + { + RequestServices = new TestServiceProvider() + }; + + return httpContext; + } + + private IOptions CreateOptionsAccessor( + RequestDelegate exceptionHandler = null, + string exceptionHandlingPath = null) + { + exceptionHandler ??= c => Task.CompletedTask; + var options = new ExceptionHandlerOptions() + { + ExceptionHandler = exceptionHandler, + ExceptionHandlingPath = exceptionHandlingPath, + }; + var optionsAccessor = Mock.Of>(o => o.Value == options); + return optionsAccessor; + } + + private ExceptionHandlerMiddleware CreateMiddleware( + RequestDelegate next, + IOptions options) + { + next ??= c => Task.CompletedTask; + var listener = new DiagnosticListener("Microsoft.AspNetCore"); + + var middleware = new ExceptionHandlerMiddleware( + next, + NullLoggerFactory.Instance, + options, + listener); + + return middleware; + } + + private class TestServiceProvider : IServiceProvider + { + public object GetService(Type serviceType) + { + throw new NotImplementedException(); + } + } + } +}