From 73e4d55d7bef3d3fbd90f7309bd0034a3d78cf4f Mon Sep 17 00:00:00 2001 From: James Newton-King Date: Sat, 14 Jul 2018 18:20:42 +1200 Subject: [PATCH] Check dispatcher services registered (#610) --- .../DispatcherApplicationBuilderExtensions.cs | 20 ++++ .../DispatcherServiceCollectionExtensions.cs | 4 + .../RoutingServiceCollectionExtensions.cs | 1 - .../Internal/DispatcherMarkerService.cs | 15 +++ .../Internal/RoutingMarkerService.cs | 2 +- .../DispatcherApplicationBuilderExtensions.cs | 103 ++++++++++++++++++ 6 files changed, 143 insertions(+), 2 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Routing/Internal/DispatcherMarkerService.cs create mode 100644 test/Microsoft.AspNetCore.Routing.Tests/DispatcherApplicationBuilderExtensions.cs diff --git a/src/Microsoft.AspNetCore.Routing/Builder/DispatcherApplicationBuilderExtensions.cs b/src/Microsoft.AspNetCore.Routing/Builder/DispatcherApplicationBuilderExtensions.cs index c3cc3df56c..df0528799e 100644 --- a/src/Microsoft.AspNetCore.Routing/Builder/DispatcherApplicationBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Routing/Builder/DispatcherApplicationBuilderExtensions.cs @@ -1,7 +1,10 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Routing.Internal; +using Microsoft.Extensions.DependencyInjection; namespace Microsoft.AspNetCore.Builder { @@ -9,12 +12,29 @@ namespace Microsoft.AspNetCore.Builder { public static IApplicationBuilder UseDispatcher(this IApplicationBuilder builder) { + VerifyDispatcherIsRegistered(builder); + return builder.UseMiddleware(); } public static IApplicationBuilder UseEndpoint(this IApplicationBuilder builder) { + VerifyDispatcherIsRegistered(builder); + return builder.UseMiddleware(); } + + private static void VerifyDispatcherIsRegistered(IApplicationBuilder app) + { + // Verify if AddDispatcher was done before calling UseDispatcher/UseEndpoint + // We use the DispatcherMarkerService to make sure if all the services were added. + if (app.ApplicationServices.GetService(typeof(DispatcherMarkerService)) == null) + { + throw new InvalidOperationException(Resources.FormatUnableToFindServices( + nameof(IServiceCollection), + nameof(DispatcherServiceCollectionExtensions.AddDispatcher), + "ConfigureServices(...)")); + } + } } } diff --git a/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs index 565fe50cdf..afd9a4a4b0 100644 --- a/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Routing/DependencyInjection/DispatcherServiceCollectionExtensions.cs @@ -6,6 +6,7 @@ using System.Reflection; using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing.EndpointConstraints; using Microsoft.AspNetCore.Routing.EndpointFinders; +using Microsoft.AspNetCore.Routing.Internal; using Microsoft.AspNetCore.Routing.Matchers; using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; @@ -34,6 +35,7 @@ namespace Microsoft.Extensions.DependencyInjection // // Default matcher implementation // + services.TryAddSingleton(); services.TryAddSingleton(); // Link generation related services @@ -50,6 +52,8 @@ namespace Microsoft.Extensions.DependencyInjection services.TryAddEnumerable( ServiceDescriptor.Transient()); + services.TryAddSingleton(typeof(DispatcherMarkerService)); + return services; } diff --git a/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs index 18ededc423..675ef797ac 100644 --- a/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Routing/DependencyInjection/RoutingServiceCollectionExtensions.cs @@ -30,7 +30,6 @@ namespace Microsoft.Extensions.DependencyInjection throw new ArgumentNullException(nameof(services)); } - services.TryAddSingleton(); services.TryAddTransient(); services.TryAddSingleton>(s => { diff --git a/src/Microsoft.AspNetCore.Routing/Internal/DispatcherMarkerService.cs b/src/Microsoft.AspNetCore.Routing/Internal/DispatcherMarkerService.cs new file mode 100644 index 0000000000..f4e742b654 --- /dev/null +++ b/src/Microsoft.AspNetCore.Routing/Internal/DispatcherMarkerService.cs @@ -0,0 +1,15 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Routing.Internal +{ + /// + /// A marker class used to determine if all the dispatcher services were added + /// to the before dispatcher is configured. + /// + internal class DispatcherMarkerService + { + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Routing/Internal/RoutingMarkerService.cs b/src/Microsoft.AspNetCore.Routing/Internal/RoutingMarkerService.cs index b180294316..c7bed9df18 100644 --- a/src/Microsoft.AspNetCore.Routing/Internal/RoutingMarkerService.cs +++ b/src/Microsoft.AspNetCore.Routing/Internal/RoutingMarkerService.cs @@ -9,7 +9,7 @@ namespace Microsoft.AspNetCore.Routing.Internal /// A marker class used to determine if all the routing services were added /// to the before routing is configured. /// - public class RoutingMarkerService + internal class RoutingMarkerService { } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Routing.Tests/DispatcherApplicationBuilderExtensions.cs b/test/Microsoft.AspNetCore.Routing.Tests/DispatcherApplicationBuilderExtensions.cs new file mode 100644 index 0000000000..6398227db1 --- /dev/null +++ b/test/Microsoft.AspNetCore.Routing.Tests/DispatcherApplicationBuilderExtensions.cs @@ -0,0 +1,103 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Builder +{ + public class DispatcherApplicationBuilderExtensions + { + [Fact] + public void UseDispatcher_ServicesNotRegistered_Throws() + { + // Arrange + var app = new ApplicationBuilder(Mock.Of()); + + // Act + var ex = Assert.Throws(() => app.UseDispatcher()); + + // Assert + Assert.Equal( + "Unable to find the required services. " + + "Please add all the required services by calling 'IServiceCollection.AddDispatcher' " + + "inside the call to 'ConfigureServices(...)' in the application startup code.", + ex.Message); + } + + [Fact] + public void UseEndpoint_ServicesNotRegistered_Throws() + { + // Arrange + var app = new ApplicationBuilder(Mock.Of()); + + // Act + var ex = Assert.Throws(() => app.UseEndpoint()); + + // Assert + Assert.Equal( + "Unable to find the required services. " + + "Please add all the required services by calling 'IServiceCollection.AddDispatcher' " + + "inside the call to 'ConfigureServices(...)' in the application startup code.", + ex.Message); + } + + [Fact] + public async Task UseDispatcher_ServicesRegistered_SetsFeature() + { + // Arrange + var services = CreateServices(); + + var app = new ApplicationBuilder(services); + + app.UseDispatcher(); + + var appFunc = app.Build(); + var httpContext = new DefaultHttpContext(); + + // Act + await appFunc(httpContext); + + // Assert + Assert.NotNull(httpContext.Features.Get()); + } + + [Fact] + public async Task UseEndpoint_ServicesRegistered_SetsFeature() + { + // Arrange + var services = CreateServices(); + + var app = new ApplicationBuilder(services); + + app.UseDispatcher(); + app.UseEndpoint(); + + var appFunc = app.Build(); + var httpContext = new DefaultHttpContext(); + + // Act + await appFunc(httpContext); + + // Assert + Assert.NotNull(httpContext.Features.Get()); + } + + private IServiceProvider CreateServices() + { + var services = new ServiceCollection(); + + services.AddLogging(); + services.AddOptions(); + services.AddDispatcher(); + + return services.BuildServiceProvider(); + } + } +}