diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcEndpointDataSource.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcEndpointDataSource.cs index 2d181e9754..a76dbe8c93 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcEndpointDataSource.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MvcEndpointDataSource.cs @@ -19,15 +19,13 @@ namespace Microsoft.AspNetCore.Mvc.Internal { internal class MvcEndpointDataSource : EndpointDataSource { + private readonly object _lock = new object(); private readonly IActionDescriptorCollectionProvider _actions; private readonly MvcEndpointInvokerFactory _invokerFactory; private readonly IServiceProvider _serviceProvider; private readonly IActionDescriptorChangeProvider[] _actionDescriptorChangeProviders; - private readonly List _endpoints; - private readonly object _lock = new object(); - private IChangeToken _changeToken; - private bool _initialized; + private List _endpoints; public MvcEndpointDataSource( IActionDescriptorCollectionProvider actions, @@ -60,12 +58,17 @@ namespace Microsoft.AspNetCore.Mvc.Internal _serviceProvider = serviceProvider; _actionDescriptorChangeProviders = actionDescriptorChangeProviders.ToArray(); - _endpoints = new List(); ConventionalEndpointInfos = new List(); + + Extensions.Primitives.ChangeToken.OnChange( + GetCompositeChangeToken, + UpdateEndpoints); } - private void InitializeEndpoints() + private List CreateEndpoints() { + List endpoints = new List(); + foreach (var action in _actions.ActionDescriptors.Items) { if (action.AttributeRouteInfo == null) @@ -117,7 +120,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal endpointInfo.Defaults, ++conventionalRouteOrder, endpointInfo); - _endpoints.Add(subEndpoint); + endpoints.Add(subEndpoint); } var segment = newEndpointTemplate.Segments[i]; @@ -142,7 +145,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal endpointInfo.Defaults, ++conventionalRouteOrder, endpointInfo); - _endpoints.Add(endpoint); + endpoints.Add(endpoint); } } } @@ -155,9 +158,11 @@ namespace Microsoft.AspNetCore.Mvc.Internal nonInlineDefaults: null, action.AttributeRouteInfo.Order, action.AttributeRouteInfo); - _endpoints.Add(endpoint); + endpoints.Add(endpoint); } } + + return endpoints; } private bool IsMvcParameter(string name) @@ -392,36 +397,36 @@ namespace Microsoft.AspNetCore.Mvc.Internal return new CompositeChangeToken(changeTokens); } - public override IChangeToken ChangeToken - { - get - { - if (_changeToken == null) - { - _changeToken = GetCompositeChangeToken(); - } - - return _changeToken; - } - } + public override IChangeToken ChangeToken => GetCompositeChangeToken(); public override IReadOnlyList Endpoints { get { - if (!_initialized) + // Want to initialize endpoints once and then cache while ensuring a null collection is never returned + // Local copy for thread safety + double check locking + var localEndpoints = _endpoints; + if (localEndpoints == null) { lock (_lock) { - if (!_initialized) + localEndpoints = _endpoints; + if (localEndpoints == null) { - InitializeEndpoints(); - _initialized = true; + _endpoints = localEndpoints = CreateEndpoints(); } } } - return _endpoints; + return localEndpoints; + } + } + + private void UpdateEndpoints() + { + lock (_lock) + { + _endpoints = CreateEndpoints(); } } diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MvcEndpointDataSourceTests.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MvcEndpointDataSourceTests.cs index 635169ee32..983af123c6 100644 --- a/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MvcEndpointDataSourceTests.cs +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MvcEndpointDataSourceTests.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Http; @@ -183,7 +184,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal [InlineData("{controller}/{action=TestAction}/{id?}/{*catchAll}", new[] { "TestController", "TestController/TestAction/{id?}/{*catchAll}" })] //[InlineData("{controller}/{action}.{ext?}", new[] { "TestController/TestAction.{ext?}" })] //[InlineData("{controller}/{action=TestAction}.{ext?}", new[] { "TestController", "TestController/TestAction.{ext?}" })] - public void InitializeEndpoints_SingleAction(string endpointInfoRoute, string[] finalEndpointTemplates) + public void Endpoints_SingleAction(string endpointInfoRoute, string[] finalEndpointTemplates) { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -210,7 +211,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal [InlineData("{area=TestArea}/{controller}/{action=TestAction}/{id?}", new[] { "TestArea/TestController", "TestArea/TestController/TestAction/{id?}" })] [InlineData("{area=TestArea}/{controller=TestController}/{action=TestAction}/{id?}", new[] { "", "TestArea", "TestArea/TestController", "TestArea/TestController/TestAction/{id?}" })] [InlineData("{area:exists}/{controller}/{action}/{id?}", new[] { "TestArea/TestController/TestAction/{id?}" })] - public void InitializeEndpoints_AreaSingleAction(string endpointInfoRoute, string[] finalEndpointTemplates) + public void Endpoints_AreaSingleAction(string endpointInfoRoute, string[] finalEndpointTemplates) { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -231,7 +232,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal } [Fact] - public void InitializeEndpoints_SingleAction_WithActionDefault() + public void Endpoints_SingleAction_WithActionDefault() { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -252,7 +253,93 @@ namespace Microsoft.AspNetCore.Mvc.Internal } [Fact] - public void InitializeEndpoints_MultipleActions_WithActionConstraint() + public void Endpoints_CalledMultipleTimes_ReturnsSameInstance() + { + // Arrange + var actionDescriptorCollectionProviderMock = new Mock(); + actionDescriptorCollectionProviderMock + .Setup(m => m.ActionDescriptors) + .Returns(new ActionDescriptorCollection(new[] + { + CreateActionDescriptor(new { controller = "TestController", action = "TestAction" }) + }, version: 0)); + + var dataSource = CreateMvcEndpointDataSource(actionDescriptorCollectionProviderMock.Object); + dataSource.ConventionalEndpointInfos.Add(CreateEndpointInfo( + string.Empty, + "{controller}/{action}", + new RouteValueDictionary(new { action = "TestAction" }))); + + // Act + var endpoints1 = dataSource.Endpoints; + var endpoints2 = dataSource.Endpoints; + + // Assert + Assert.Collection(endpoints1, + (e) => Assert.Equal("TestController", Assert.IsType(e).Template), + (e) => Assert.Equal("TestController/TestAction", Assert.IsType(e).Template)); + Assert.Same(endpoints1, endpoints2); + + actionDescriptorCollectionProviderMock.VerifyGet(m => m.ActionDescriptors, Times.Once); + } + + [Fact] + public void Endpoints_ChangeTokenTriggered_EndpointsRecreated() + { + // Arrange + var actionDescriptorCollectionProviderMock = new Mock(); + actionDescriptorCollectionProviderMock + .Setup(m => m.ActionDescriptors) + .Returns(new ActionDescriptorCollection(new[] + { + CreateActionDescriptor(new { controller = "TestController", action = "TestAction" }) + }, version: 0)); + + CancellationTokenSource cts = null; + + var changeProviderMock = new Mock(); + changeProviderMock.Setup(m => m.GetChangeToken()).Returns(() => + { + cts = new CancellationTokenSource(); + var changeToken = new CancellationChangeToken(cts.Token); + + return changeToken; + }); + + var dataSource = CreateMvcEndpointDataSource( + actionDescriptorCollectionProviderMock.Object, + actionDescriptorChangeProviders: new[] { changeProviderMock.Object }); + dataSource.ConventionalEndpointInfos.Add(CreateEndpointInfo( + string.Empty, + "{controller}/{action}", + new RouteValueDictionary(new { action = "TestAction" }))); + + // Act + var endpoints = dataSource.Endpoints; + + Assert.Collection(endpoints, + (e) => Assert.Equal("TestController", Assert.IsType(e).Template), + (e) => Assert.Equal("TestController/TestAction", Assert.IsType(e).Template)); + + actionDescriptorCollectionProviderMock + .Setup(m => m.ActionDescriptors) + .Returns(new ActionDescriptorCollection(new[] + { + CreateActionDescriptor(new { controller = "NewTestController", action = "NewTestAction" }) + }, version: 1)); + + cts.Cancel(); + + // Assert + var newEndpoints = dataSource.Endpoints; + + Assert.NotSame(endpoints, newEndpoints); + Assert.Collection(newEndpoints, + (e) => Assert.Equal("NewTestController/NewTestAction", Assert.IsType(e).Template)); + } + + [Fact] + public void Endpoints_MultipleActions_WithActionConstraint() { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -277,7 +364,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal [Theory] [InlineData("{controller}/{action}", new[] { "TestController1/TestAction1", "TestController1/TestAction2", "TestController1/TestAction3", "TestController2/TestAction1" })] [InlineData("{controller}/{action:regex((TestAction1|TestAction2))}", new[] { "TestController1/TestAction1", "TestController1/TestAction2", "TestController2/TestAction1" })] - public void InitializeEndpoints_MultipleActions(string endpointInfoRoute, string[] finalEndpointTemplates) + public void Endpoints_MultipleActions(string endpointInfoRoute, string[] finalEndpointTemplates) { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -302,7 +389,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal } [Fact] - public void ConventionalRoute_WithNoRouteName_DoesNotAddRouteNameMetadata() + public void Endpoints_ConventionalRoute_WithNoRouteName_DoesNotAddRouteNameMetadata() { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -322,7 +409,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal } [Fact] - public void CanCreateMultipleEndpoints_WithSameRouteName() + public void Endpoints_CanCreateMultipleEndpoints_WithSameRouteName() { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -357,7 +444,7 @@ namespace Microsoft.AspNetCore.Mvc.Internal } [Fact] - public void InitializeEndpoints_ConventionalRoutes_StaticallyDefinedOrder_IsMaintained() + public void Endpoints_ConventionalRoutes_StaticallyDefinedOrder_IsMaintained() { // Arrange var actionDescriptorCollection = GetActionDescriptorCollection( @@ -619,11 +706,11 @@ namespace Microsoft.AspNetCore.Mvc.Internal actionDescriptors.Add(CreateActionDescriptor(requiredValue)); } - var actionDescriptorCollectionProvider = new Mock(); - actionDescriptorCollectionProvider + var actionDescriptorCollectionProviderMock = new Mock(); + actionDescriptorCollectionProviderMock .Setup(m => m.ActionDescriptors) .Returns(new ActionDescriptorCollection(actionDescriptors, version: 0)); - return actionDescriptorCollectionProvider.Object; + return actionDescriptorCollectionProviderMock.Object; } private ActionDescriptor CreateActionDescriptor(string controller, string action, string area = null) diff --git a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesDispatchingTest.cs b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesDispatchingTest.cs new file mode 100644 index 0000000000..f725d9ed92 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesDispatchingTest.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Mvc.FunctionalTests +{ + public class RequestServicesDispatchingTest : RequestServicesTestBase + { + public RequestServicesDispatchingTest(MvcTestFixture fixture) + : base(fixture) + { + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTest.cs b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTest.cs index 29fd69112c..f62521bd6c 100644 --- a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTest.cs @@ -1,93 +1,13 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Net; -using System.Net.Http; -using System.Threading.Tasks; -using Xunit; - namespace Microsoft.AspNetCore.Mvc.FunctionalTests { - // Each of these tests makes two requests, because we want each test to verify that the data is - // PER-REQUEST and does not linger around to impact the next request. - public class RequestServicesTest : IClassFixture> + public class RequestServicesTest : RequestServicesTestBase { public RequestServicesTest(MvcTestFixture fixture) + : base(fixture) { - Client = fixture.CreateDefaultClient(); - } - - public HttpClient Client { get; } - - [Theory] - [InlineData("http://localhost/RequestScopedService/FromFilter")] - [InlineData("http://localhost/RequestScopedService/FromView")] - [InlineData("http://localhost/RequestScopedService/FromViewComponent")] - [InlineData("http://localhost/RequestScopedService/FromActionArgument")] - public async Task RequestServices(string url) - { - for (var i = 0; i < 2; i++) - { - // Arrange - var requestId = Guid.NewGuid().ToString(); - var request = new HttpRequestMessage(HttpMethod.Get, url); - request.Headers.TryAddWithoutValidation("RequestId", requestId); - - // Act - var response = await Client.SendAsync(request); - - // Assert - response.EnsureSuccessStatusCode(); - var body = (await response.Content.ReadAsStringAsync()).Trim(); - Assert.Equal(requestId, body); - } - } - - [Fact] - public async Task RequestServices_TagHelper() - { - // Arrange - var url = "http://localhost/RequestScopedService/FromTagHelper"; - - // Act & Assert - for (var i = 0; i < 2; i++) - { - var requestId = Guid.NewGuid().ToString(); - var request = new HttpRequestMessage(HttpMethod.Get, url); - request.Headers.TryAddWithoutValidation("RequestId", requestId); - - var response = await Client.SendAsync(request); - - var body = (await response.Content.ReadAsStringAsync()).Trim(); - - var expected = "" + requestId + ""; - Assert.Equal(expected, body); - } - } - - [Fact] - public async Task RequestServices_ActionConstraint() - { - // Arrange - var url = "http://localhost/RequestScopedService/FromActionConstraint"; - - // Act & Assert - var requestId1 = "b40f6ec1-8a6b-41c1-b3fe-928f581ebaf5"; - var request1 = new HttpRequestMessage(HttpMethod.Get, url); - request1.Headers.TryAddWithoutValidation("RequestId", requestId1); - - var response1 = await Client.SendAsync(request1); - - var body1 = (await response1.Content.ReadAsStringAsync()).Trim(); - Assert.Equal(requestId1, body1); - - var requestId2 = Guid.NewGuid().ToString(); - var request2 = new HttpRequestMessage(HttpMethod.Get, url); - request2.Headers.TryAddWithoutValidation("RequestId", requestId2); - - var response2 = await Client.SendAsync(request2); - Assert.Equal(HttpStatusCode.NotFound, response2.StatusCode); } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTestBase.cs b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTestBase.cs new file mode 100644 index 0000000000..425bda5ea5 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/RequestServicesTestBase.cs @@ -0,0 +1,99 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Net; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Hosting; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.FunctionalTests +{ + // Each of these tests makes two requests, because we want each test to verify that the data is + // PER-REQUEST and does not linger around to impact the next request. + public abstract class RequestServicesTestBase : IClassFixture> where TStartup : class + { + protected RequestServicesTestBase(MvcTestFixture fixture) + { + var factory = fixture.Factories.FirstOrDefault() ?? fixture.WithWebHostBuilder(ConfigureWebHostBuilder); + Client = factory.CreateDefaultClient(); + } + + private static void ConfigureWebHostBuilder(IWebHostBuilder builder) => + builder.UseStartup(); + + public HttpClient Client { get; } + + [Theory] + [InlineData("http://localhost/RequestScopedService/FromFilter")] + [InlineData("http://localhost/RequestScopedService/FromView")] + [InlineData("http://localhost/RequestScopedService/FromViewComponent")] + [InlineData("http://localhost/RequestScopedService/FromActionArgument")] + public async Task RequestServices(string url) + { + for (var i = 0; i < 2; i++) + { + // Arrange + var requestId = Guid.NewGuid().ToString(); + var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.TryAddWithoutValidation("RequestId", requestId); + + // Act + var response = await Client.SendAsync(request); + + // Assert + response.EnsureSuccessStatusCode(); + var body = (await response.Content.ReadAsStringAsync()).Trim(); + Assert.Equal(requestId, body); + } + } + + [Fact] + public async Task RequestServices_TagHelper() + { + // Arrange + var url = "http://localhost/RequestScopedService/FromTagHelper"; + + // Act & Assert + for (var i = 0; i < 2; i++) + { + var requestId = Guid.NewGuid().ToString(); + var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.TryAddWithoutValidation("RequestId", requestId); + + var response = await Client.SendAsync(request); + + var body = (await response.Content.ReadAsStringAsync()).Trim(); + + var expected = "" + requestId + ""; + Assert.Equal(expected, body); + } + } + + [Fact] + public async Task RequestServices_Constraint() + { + // Arrange + var url = "http://localhost/RequestScopedService/FromConstraint"; + + // Act & Assert + var requestId1 = "b40f6ec1-8a6b-41c1-b3fe-928f581ebaf5"; + var request1 = new HttpRequestMessage(HttpMethod.Get, url); + request1.Headers.TryAddWithoutValidation("RequestId", requestId1); + + var response1 = await Client.SendAsync(request1); + + var body1 = (await response1.Content.ReadAsStringAsync()).Trim(); + Assert.Equal(requestId1, body1); + + var requestId2 = Guid.NewGuid().ToString(); + var request2 = new HttpRequestMessage(HttpMethod.Get, url); + request2.Headers.TryAddWithoutValidation("RequestId", requestId2); + + var response2 = await Client.SendAsync(request2); + Assert.Equal(HttpStatusCode.NotFound, response2.StatusCode); + } + } +} \ No newline at end of file diff --git a/test/WebSites/BasicWebSite/Controllers/RequestScopedServiceController.cs b/test/WebSites/BasicWebSite/Controllers/RequestScopedServiceController.cs index e7a89377b3..a7bd552217 100644 --- a/test/WebSites/BasicWebSite/Controllers/RequestScopedServiceController.cs +++ b/test/WebSites/BasicWebSite/Controllers/RequestScopedServiceController.cs @@ -10,8 +10,8 @@ namespace BasicWebSite { // This only matches a specific requestId value [HttpGet] - [RequestScopedActionConstraint("b40f6ec1-8a6b-41c1-b3fe-928f581ebaf5")] - public string FromActionConstraint() + [RequestScopedConstraint("b40f6ec1-8a6b-41c1-b3fe-928f581ebaf5")] + public string FromConstraint() { return "b40f6ec1-8a6b-41c1-b3fe-928f581ebaf5"; } diff --git a/test/WebSites/BasicWebSite/RequestScopedActionConstraint.cs b/test/WebSites/BasicWebSite/RequestScopedActionConstraint.cs index 3a5be47cf4..0b351b389a 100644 --- a/test/WebSites/BasicWebSite/RequestScopedActionConstraint.cs +++ b/test/WebSites/BasicWebSite/RequestScopedActionConstraint.cs @@ -4,12 +4,13 @@ using System; using System.Collections.Concurrent; using Microsoft.AspNetCore.Mvc.ActionConstraints; +using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.Extensions.DependencyInjection; namespace BasicWebSite { // Only matches when the requestId is the same as the one passed in the constructor. - public class RequestScopedActionConstraintAttribute : Attribute, IActionConstraintFactory + public class RequestScopedConstraintAttribute : Attribute, IActionConstraintFactory, IEndpointConstraintFactory { private readonly string _requestId; private readonly Func CreateFactory = @@ -19,18 +20,28 @@ namespace BasicWebSite public bool IsReusable => false; - public RequestScopedActionConstraintAttribute(string requestId) + public RequestScopedConstraintAttribute(string requestId) { _requestId = requestId; } - public IActionConstraint CreateInstance(IServiceProvider services) + IActionConstraint IActionConstraintFactory.CreateInstance(IServiceProvider services) { - var constraintType = typeof(Constraint); - return (Constraint)ActivatorUtilities.CreateInstance(services, typeof(Constraint),new[] { _requestId }); + return CreateInstanceCore(services); } - private class Constraint : IActionConstraint + IEndpointConstraint IEndpointConstraintFactory.CreateInstance(IServiceProvider services) + { + return CreateInstanceCore(services); + } + + private Constraint CreateInstanceCore(IServiceProvider services) + { + var constraintType = typeof(Constraint); + return (Constraint)ActivatorUtilities.CreateInstance(services, typeof(Constraint), new[] { _requestId }); + } + + private class Constraint : IActionConstraint, IEndpointConstraint { private readonly RequestIdService _requestIdService; private readonly string _requestId; @@ -43,7 +54,17 @@ namespace BasicWebSite public int Order { get; private set; } - public bool Accept(ActionConstraintContext context) + bool IActionConstraint.Accept(ActionConstraintContext context) + { + return AcceptCore(); + } + + bool IEndpointConstraint.Accept(EndpointConstraintContext context) + { + return AcceptCore(); + } + + private bool AcceptCore() { return _requestId == _requestIdService.RequestId; } diff --git a/test/WebSites/BasicWebSite/StartupWithDispatching.cs b/test/WebSites/BasicWebSite/StartupWithDispatching.cs index 4c5ed0dc1e..eb122b7290 100644 --- a/test/WebSites/BasicWebSite/StartupWithDispatching.cs +++ b/test/WebSites/BasicWebSite/StartupWithDispatching.cs @@ -19,10 +19,16 @@ namespace BasicWebSite .AddXmlDataContractSerializerFormatters(); services.ConfigureBaseWebSiteAuthPolicies(); + + services.AddHttpContextAccessor(); + services.AddScoped(); } public void Configure(IApplicationBuilder app) { + // Initializes the RequestId service for each request + app.UseMiddleware(); + app.UseDispatcher(); app.UseMvcWithEndpoint(routes =>