diff --git a/src/Http/Routing/src/EndpointMiddleware.cs b/src/Http/Routing/src/EndpointMiddleware.cs index 00b36cef50..97e56a04e7 100644 --- a/src/Http/Routing/src/EndpointMiddleware.cs +++ b/src/Http/Routing/src/EndpointMiddleware.cs @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Routing _routeOptions = routeOptions?.Value ?? throw new ArgumentNullException(nameof(routeOptions)); } - public async Task Invoke(HttpContext httpContext) + public Task Invoke(HttpContext httpContext) { var endpoint = httpContext.Features.Get()?.Endpoint; if (endpoint?.RequestDelegate != null) @@ -39,7 +39,7 @@ namespace Microsoft.AspNetCore.Routing if (_routeOptions.SuppressCheckForUnhandledSecurityMetadata) { // User opted out of this check. - return; + return Task.CompletedTask; } if (endpoint.Metadata.GetMetadata() != null && @@ -58,17 +58,35 @@ namespace Microsoft.AspNetCore.Routing try { - await endpoint.RequestDelegate(httpContext); + var requestTask = endpoint.RequestDelegate(httpContext); + if (!requestTask.IsCompletedSuccessfully) + { + return AwaitRequestTask(endpoint, requestTask, _logger); + } + } + catch (Exception exception) + { + Log.ExecutedEndpoint(_logger, endpoint); + return Task.FromException(exception); + } + + Log.ExecutedEndpoint(_logger, endpoint); + return Task.CompletedTask; + } + + return _next(httpContext); + + static async Task AwaitRequestTask(Endpoint endpoint, Task requestTask, ILogger logger) + { + try + { + await requestTask; } finally { - Log.ExecutedEndpoint(_logger, endpoint); + Log.ExecutedEndpoint(logger, endpoint); } - - return; } - - await _next(httpContext); } private static void ThrowMissingAuthMiddlewareException(Endpoint endpoint) diff --git a/src/Http/Routing/src/EndpointRoutingMiddleware.cs b/src/Http/Routing/src/EndpointRoutingMiddleware.cs index 162ad8a21b..02b089dbca 100644 --- a/src/Http/Routing/src/EndpointRoutingMiddleware.cs +++ b/src/Http/Routing/src/EndpointRoutingMiddleware.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; @@ -53,15 +54,45 @@ namespace Microsoft.AspNetCore.Routing _endpointDataSource = new CompositeEndpointDataSource(endpointRouteBuilder.DataSources); } - public async Task Invoke(HttpContext httpContext) + public Task Invoke(HttpContext httpContext) { var feature = new EndpointSelectorContext(); // There's an inherent race condition between waiting for init and accessing the matcher // this is OK because once `_matcher` is initialized, it will not be set to null again. - var matcher = await InitializeAsync(); + var matcherTask = InitializeAsync(); + if (!matcherTask.IsCompletedSuccessfully) + { + return AwaitMatcher(this, httpContext, feature, matcherTask); + } - await matcher.MatchAsync(httpContext, feature); + var matchTask = matcherTask.Result.MatchAsync(httpContext, feature); + if (!matchTask.IsCompletedSuccessfully) + { + return AwaitMatch(this, httpContext, feature, matchTask); + } + + return SetRoutingAndContinue(httpContext, feature); + + // Awaited fallbacks for when the Tasks do not synchronously complete + static async Task AwaitMatcher(EndpointRoutingMiddleware middleware, HttpContext httpContext, EndpointSelectorContext feature, Task matcherTask) + { + var matcher = await matcherTask; + await matcher.MatchAsync(httpContext, feature); + await middleware.SetRoutingAndContinue(httpContext, feature); + } + + static async Task AwaitMatch(EndpointRoutingMiddleware middleware, HttpContext httpContext, EndpointSelectorContext feature, Task matchTask) + { + await matchTask; + await middleware.SetRoutingAndContinue(httpContext, feature); + } + + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private Task SetRoutingAndContinue(HttpContext httpContext, EndpointSelectorContext feature) + { if (feature.Endpoint != null) { // Set the endpoint feature only on success. This means we won't overwrite any @@ -75,7 +106,7 @@ namespace Microsoft.AspNetCore.Routing Log.MatchFailure(_logger); } - await _next(httpContext); + return _next(httpContext); } private static void SetFeatures(HttpContext httpContext, EndpointSelectorContext context)