diff --git a/src/Microsoft.AspNetCore.Routing/EndpointRoutingMiddleware.cs b/src/Microsoft.AspNetCore.Routing/EndpointRoutingMiddleware.cs index 0b82dc3438..60efa46586 100644 --- a/src/Microsoft.AspNetCore.Routing/EndpointRoutingMiddleware.cs +++ b/src/Microsoft.AspNetCore.Routing/EndpointRoutingMiddleware.cs @@ -95,23 +95,51 @@ namespace Microsoft.AspNetCore.Routing // blocking operation if you have a low core count and enough work to do. private Task InitializeAsync() { - if (_initializationTask != null) + var initializationTask = _initializationTask; + if (initializationTask != null) { - return _initializationTask; + return initializationTask; } - var initializationTask = new TaskCompletionSource(); - if (Interlocked.CompareExchange>( - ref _initializationTask, - initializationTask.Task, - null) == null) + return InitializeCoreAsync(); + } + + private Task InitializeCoreAsync() + { + var initialization = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var initializationTask = Interlocked.CompareExchange(ref _initializationTask, initialization.Task, null); + if (initializationTask != null) + { + // This thread lost the race, join the existing task. + return initializationTask; + } + + // This thread won the race, do the initialization. + try { - // This thread won the race, do the initialization. var matcher = _matcherFactory.CreateMatcher(_endpointDataSource); - initializationTask.SetResult(matcher); - } - return _initializationTask; + // Now replace the initialization task with one created with the default execution context. + // This is important because capturing the execution context will leak memory in ASP.NET Core. + using (ExecutionContext.SuppressFlow()) + { + _initializationTask = Task.FromResult(matcher); + } + + // Complete the task, this will unblock any requests that came in while initializing. + initialization.SetResult(matcher); + return initialization.Task; + } + catch (Exception ex) + { + // Allow initialization to occur again. Since DataSources can change, it's possible + // for the developer to correct the data causing the failure. + _initializationTask = null; + + // Complete the task, this will throw for any requests that came in while initializing. + initialization.SetException(ex); + return initialization.Task; + } } private static class Log diff --git a/test/Microsoft.AspNetCore.Routing.Tests/EndpointRoutingMiddlewareTest.cs b/test/Microsoft.AspNetCore.Routing.Tests/EndpointRoutingMiddlewareTest.cs index dad2fdf0ff..48861da09b 100644 --- a/test/Microsoft.AspNetCore.Routing.Tests/EndpointRoutingMiddlewareTest.cs +++ b/test/Microsoft.AspNetCore.Routing.Tests/EndpointRoutingMiddlewareTest.cs @@ -2,14 +2,17 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Routing.Matching; using Microsoft.AspNetCore.Routing.TestObjects; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Logging.Testing; using Microsoft.Extensions.Options; +using Moq; using Xunit; namespace Microsoft.AspNetCore.Routing @@ -103,6 +106,29 @@ namespace Microsoft.AspNetCore.Routing Assert.Equal("testValue", routeValuesFeature.RouteValues["testKey"]); } + [Fact] + public async Task Invoke_InitializationFailure_AllowsReinitialization() + { + // Arrange + var httpContext = CreateHttpContext(); + + var matcherFactory = new Mock(); + matcherFactory + .Setup(f => f.CreateMatcher(It.IsAny())) + .Throws(new InvalidTimeZoneException()) + .Verifiable(); + + var middleware = CreateMiddleware(matcherFactory: matcherFactory.Object); + + // Act + await Assert.ThrowsAsync(async () => await middleware.Invoke(httpContext)); + await Assert.ThrowsAsync(async () => await middleware.Invoke(httpContext)); + + // Assert + matcherFactory + .Verify(f => f.CreateMatcher(It.IsAny()), Times.Exactly(2)); + } + private HttpContext CreateHttpContext() { var context = new EndpointSelectorContext(); @@ -116,14 +142,16 @@ namespace Microsoft.AspNetCore.Routing return httpContext; } - private EndpointRoutingMiddleware CreateMiddleware(Logger logger = null) + private EndpointRoutingMiddleware CreateMiddleware( + Logger logger = null, + MatcherFactory matcherFactory = null) { RequestDelegate next = (c) => Task.FromResult(null); logger = logger ?? new Logger(NullLoggerFactory.Instance); + matcherFactory = matcherFactory ?? new TestMatcherFactory(true); var options = Options.Create(new EndpointOptions()); - var matcherFactory = new TestMatcherFactory(true); var middleware = new EndpointRoutingMiddleware( matcherFactory, new CompositeEndpointDataSource(Array.Empty()),