From 964b671288d95ee659d2119f3c3d9831c70b81db Mon Sep 17 00:00:00 2001 From: Javier Calvarro Nelson Date: Mon, 14 Aug 2017 13:35:06 -0700 Subject: [PATCH] [Fixes #1012] Make it possible to override services when using UseStartup. * Add IStartupConfigureServicesFilter to wrap ConfigureServices. * Add IStartupConfigureContainerFilter to wrap ConfigureContainer. * Make StartupLoader build a thunk for configure services that resolves all instances of IStartupConfigureServicesFilter and IStartupConfigureContainerFilter and wraps invocations to ConfigureServices and ConfigureContainer respectively. * Refactor building the ConfigureServices callback into a private builder class due to the increased complexity in the process. --- .../IStartupConfigureContainerFilter.cs | 12 + .../IStartupConfigureServicesFilter.cs | 13 + .../Internal/ConfigureContainerBuilder.cs | 10 +- .../Internal/ConfigureServicesBuilder.cs | 9 + .../Internal/StartupLoader.cs | 205 ++++++++++++-- .../WebHostBuilderExtensions.cs | 94 +++++++ .../Internal/MyContainer.cs | 2 + .../StartupManagerTests.cs | 256 ++++++++++++++++++ .../TestServerTests.cs | 43 ++- 9 files changed, 615 insertions(+), 29 deletions(-) create mode 100644 src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureContainerFilter.cs create mode 100644 src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureServicesFilter.cs create mode 100644 src/Microsoft.AspNetCore.TestHost/WebHostBuilderExtensions.cs diff --git a/src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureContainerFilter.cs b/src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureContainerFilter.cs new file mode 100644 index 0000000000..443a066efa --- /dev/null +++ b/src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureContainerFilter.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Hosting +{ + public interface IStartupConfigureContainerFilter + { + Action ConfigureContainer(Action container); + } +} diff --git a/src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureServicesFilter.cs b/src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureServicesFilter.cs new file mode 100644 index 0000000000..a7e50f7572 --- /dev/null +++ b/src/Microsoft.AspNetCore.Hosting.Abstractions/IStartupConfigureServicesFilter.cs @@ -0,0 +1,13 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Hosting +{ + public interface IStartupConfigureServicesFilter + { + Action ConfigureServices(Action next); + } +} diff --git a/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureContainerBuilder.cs b/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureContainerBuilder.cs index 8eb9603e3e..ed8d0fd06e 100644 --- a/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureContainerBuilder.cs +++ b/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureContainerBuilder.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Linq; using System.Reflection; namespace Microsoft.AspNetCore.Hosting.Internal @@ -16,6 +15,8 @@ namespace Microsoft.AspNetCore.Hosting.Internal public MethodInfo MethodInfo { get; } + public Func, Action> ConfigureContainerFilters { get; set; } + public Action Build(object instance) => container => Invoke(instance, container); public Type GetContainerType() @@ -30,6 +31,13 @@ namespace Microsoft.AspNetCore.Hosting.Internal } private void Invoke(object instance, object container) + { + ConfigureContainerFilters(StartupConfigureContainer)(container); + + void StartupConfigureContainer(object containerBuilder) => InvokeCore(instance, containerBuilder); + } + + private void InvokeCore(object instance, object container) { if (MethodInfo == null) { diff --git a/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureServicesBuilder.cs b/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureServicesBuilder.cs index d8f92fa953..4206d0d62a 100644 --- a/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureServicesBuilder.cs +++ b/src/Microsoft.AspNetCore.Hosting/Internal/ConfigureServicesBuilder.cs @@ -17,9 +17,18 @@ namespace Microsoft.AspNetCore.Hosting.Internal public MethodInfo MethodInfo { get; } + public Func, Func> StartupServiceFilters { get; set; } + public Func Build(object instance) => services => Invoke(instance, services); private IServiceProvider Invoke(object instance, IServiceCollection services) + { + return StartupServiceFilters(Startup)(services); + + IServiceProvider Startup(IServiceCollection serviceCollection) => InvokeCore(instance, serviceCollection); + } + + private IServiceProvider InvokeCore(object instance, IServiceCollection services) { if (MethodInfo == null) { diff --git a/src/Microsoft.AspNetCore.Hosting/Internal/StartupLoader.cs b/src/Microsoft.AspNetCore.Hosting/Internal/StartupLoader.cs index e373ed0d64..879f382f5c 100644 --- a/src/Microsoft.AspNetCore.Hosting/Internal/StartupLoader.cs +++ b/src/Microsoft.AspNetCore.Hosting/Internal/StartupLoader.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Globalization; using System.Linq; using System.Reflection; @@ -11,9 +12,35 @@ namespace Microsoft.AspNetCore.Hosting.Internal { public class StartupLoader { + // Creates an instance with the actions to run for configuring the application services and the + // request pipeline of the application. + // When using convention based startup, the process for initializing the services is as follows: + // The host looks for a method with the signature ConfigureServices( services). + // If it can't find one, it looks for a method with the signature ConfigureServices( services). + // When the configure services method is void returning, the host builds a services configuration function that runs all the + // instances registered on the host, along with the ConfigureServices method following a decorator pattern. + // Additionally to the ConfigureServices method, the Startup class can define a ConfigureContainer<TContainerBuilder>(TContainerBuilder builder) + // method that further configures services into the container. If the ConfigureContainer method is defined, the services configuration function + // creates a TContainerBuilder and runs all the + // instances registered on the host, along with the ConfigureContainer method following a decorator pattern. + // For example: + // StartupFilter1 + // StartupFilter2 + // ConfigureServices + // StartupFilter2 + // StartupFilter1 + // ConfigureContainerFilter1 + // ConfigureContainerFilter2 + // ConfigureContainer + // ConfigureContainerFilter2 + // ConfigureContainerFilter1 + // + // If the Startup class ConfigureServices returns an and there is at least an registered we + // throw as the filters can't be applied. public static StartupMethods LoadMethods(IServiceProvider hostingServiceProvider, Type startupType, string environmentName) { var configureMethod = FindConfigureDelegate(startupType, environmentName); + var servicesMethod = FindConfigureServicesDelegate(startupType, environmentName); var configureContainerMethod = FindConfigureContainerDelegate(startupType, environmentName); @@ -23,44 +50,170 @@ namespace Microsoft.AspNetCore.Hosting.Internal instance = ActivatorUtilities.GetServiceOrCreateInstance(hostingServiceProvider, startupType); } - var configureServicesCallback = servicesMethod.Build(instance); - var configureContainerCallback = configureContainerMethod.Build(instance); + // The type of the TContainerBuilder. If there is no ConfigureContainer method we can just use object as it's not + // going to be used for anything. + var type = configureContainerMethod.MethodInfo != null ? configureContainerMethod.GetContainerType() : typeof(object); - Func configureServices = services => + var builder = (ConfigureServicesDelegateBuilder) Activator.CreateInstance( + typeof(ConfigureServicesDelegateBuilder<>).MakeGenericType(type), + hostingServiceProvider, + servicesMethod, + configureContainerMethod, + instance); + + return new StartupMethods(instance, configureMethod.Build(instance), builder.Build()); + } + + private abstract class ConfigureServicesDelegateBuilder + { + public abstract Func Build(); + } + + private class ConfigureServicesDelegateBuilder : ConfigureServicesDelegateBuilder + { + public ConfigureServicesDelegateBuilder( + IServiceProvider hostingServiceProvider, + ConfigureServicesBuilder configureServicesBuilder, + ConfigureContainerBuilder configureContainerBuilder, + object instance) { - // Call ConfigureServices, if that returned an IServiceProvider, we're done - IServiceProvider applicationServiceProvider = configureServicesCallback.Invoke(services); + HostingServiceProvider = hostingServiceProvider; + ConfigureServicesBuilder = configureServicesBuilder; + ConfigureContainerBuilder = configureContainerBuilder; + Instance = instance; + } - if (applicationServiceProvider != null) + public IServiceProvider HostingServiceProvider { get; } + public ConfigureServicesBuilder ConfigureServicesBuilder { get; } + public ConfigureContainerBuilder ConfigureContainerBuilder { get; } + public object Instance { get; } + + public override Func Build() + { + ConfigureServicesBuilder.StartupServiceFilters = BuildStartupServicesFilterPipeline; + var configureServicesCallback = ConfigureServicesBuilder.Build(Instance); + + ConfigureContainerBuilder.ConfigureContainerFilters = ConfigureContainerPipeline; + var configureContainerCallback = ConfigureContainerBuilder.Build(Instance); + + return ConfigureServices(configureServicesCallback, configureContainerCallback); + + Action ConfigureContainerPipeline(Action action) { - return applicationServiceProvider; - } + return Target; - // If there's a ConfigureContainer method - if (configureContainerMethod.MethodInfo != null) + // The ConfigureContainer pipeline needs an Action as source, so we just adapt the + // signature with this function. + void Source(TContainerBuilder containerBuilder) => + action(containerBuilder); + + // The ConfigureContainerBuilder.ConfigureContainerFilters expects an Action as value, but our pipeline + // produces an Action given a source, so we wrap it on an Action that internally casts + // the object containerBuilder to TContainerBuilder to match the expected signature of our ConfigureContainer pipeline. + void Target(object containerBuilder) => + BuildStartupConfigureContainerFiltersPipeline(Source)((TContainerBuilder)containerBuilder); + } + } + + Func ConfigureServices( + Func configureServicesCallback, + Action configureContainerCallback) + { + return ConfigureServicesWithContainerConfiguration; + + IServiceProvider ConfigureServicesWithContainerConfiguration(IServiceCollection services) { - // We have a ConfigureContainer method, get the IServiceProviderFactory - var serviceProviderFactoryType = typeof(IServiceProviderFactory<>).MakeGenericType(configureContainerMethod.GetContainerType()); - var serviceProviderFactory = hostingServiceProvider.GetRequiredService(serviceProviderFactoryType); - // var builder = serviceProviderFactory.CreateBuilder(services); - var builder = serviceProviderFactoryType.GetMethod(nameof(DefaultServiceProviderFactory.CreateBuilder)).Invoke(serviceProviderFactory, new object[] { services }); - configureContainerCallback.Invoke(builder); - // applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(builder); - applicationServiceProvider = (IServiceProvider)serviceProviderFactoryType.GetMethod(nameof(DefaultServiceProviderFactory.CreateServiceProvider)).Invoke(serviceProviderFactory, new object[] { builder }); + // Call ConfigureServices, if that returned an IServiceProvider, we're done + IServiceProvider applicationServiceProvider = configureServicesCallback.Invoke(services); + + if (applicationServiceProvider != null) + { + return applicationServiceProvider; + } + + // If there's a ConfigureContainer method + if (ConfigureContainerBuilder.MethodInfo != null) + { + var serviceProviderFactory = HostingServiceProvider.GetRequiredService>(); + var builder = serviceProviderFactory.CreateBuilder(services); + configureContainerCallback(builder); + applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(builder); + } + else + { + // Get the default factory + var serviceProviderFactory = HostingServiceProvider.GetRequiredService>(); + + // Don't bother calling CreateBuilder since it just returns the default service collection + applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(services); + } + + return applicationServiceProvider ?? services.BuildServiceProvider(); } - else + } + + private Func BuildStartupServicesFilterPipeline(Func startup) + { + return RunPipeline; + + IServiceProvider RunPipeline(IServiceCollection services) { - // Get the default factory - var serviceProviderFactory = hostingServiceProvider.GetRequiredService>(); + var filters = HostingServiceProvider.GetRequiredService>().Reverse().ToArray(); - // Don't bother calling CreateBuilder since it just returns the default service collection - applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(services); + // If there are no filters just run startup (makes IServiceProvider ConfigureServices(IServiceCollection services) work. + if (filters.Length == 0) + { + return startup(services); + } + + Action pipeline = InvokeStartup; + for (int i = 0; i < filters.Length; i++) + { + pipeline = filters[i].ConfigureServices(pipeline); + } + + pipeline(services); + + // We return null so that the host here builds the container (same result as void ConfigureServices(IServiceCollection services); + return null; + + void InvokeStartup(IServiceCollection serviceCollection) + { + var result = startup(serviceCollection); + if (filters.Length > 0 && result != null) + { + // public IServiceProvider ConfigureServices(IServiceCollection serviceCollection) is not compatible with IStartupServicesFilter; + var message = $"A ConfigureServices method that returns an {nameof(IServiceProvider)} is " + + $"not compatible with the use of one or more {nameof(IStartupConfigureServicesFilter)}. " + + $"Use a void returning ConfigureServices method instead or a ConfigureContainer method."; + throw new InvalidOperationException(message); + }; + } } + } - return applicationServiceProvider ?? services.BuildServiceProvider(); - }; + private Action BuildStartupConfigureContainerFiltersPipeline(Action configureContainer) + { + return RunPipeline; - return new StartupMethods(instance, configureMethod.Build(instance), configureServices); + void RunPipeline(TContainerBuilder containerBuilder) + { + var filters = HostingServiceProvider + .GetRequiredService>>() + .Reverse() + .ToArray(); + + Action pipeline = InvokeConfigureContainer; + for (int i = 0; i < filters.Length; i++) + { + pipeline = filters[i].ConfigureContainer(pipeline); + } + + pipeline(containerBuilder); + + void InvokeConfigureContainer(TContainerBuilder builder) => configureContainer(builder); + } + } } public static Type FindStartupType(string startupAssemblyName, string environmentName) diff --git a/src/Microsoft.AspNetCore.TestHost/WebHostBuilderExtensions.cs b/src/Microsoft.AspNetCore.TestHost/WebHostBuilderExtensions.cs new file mode 100644 index 0000000000..55b156b08b --- /dev/null +++ b/src/Microsoft.AspNetCore.TestHost/WebHostBuilderExtensions.cs @@ -0,0 +1,94 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.TestHost +{ + public static class WebHostBuilderExtensions + { + public static IWebHostBuilder ConfigureTestServices(this IWebHostBuilder webHostBuilder, Action servicesConfiguration) + { + if (webHostBuilder == null) + { + throw new ArgumentNullException(nameof(webHostBuilder)); + } + + if (servicesConfiguration == null) + { + throw new ArgumentNullException(nameof(servicesConfiguration)); + } + + webHostBuilder.ConfigureServices( + s => s.AddSingleton( + new ConfigureTestServicesStartupConfigureServicesFilter(servicesConfiguration))); + + return webHostBuilder; + } + + public static IWebHostBuilder ConfigureTestContainer(this IWebHostBuilder webHostBuilder, Action servicesConfiguration) + { + if (webHostBuilder == null) + { + throw new ArgumentNullException(nameof(webHostBuilder)); + } + + if (servicesConfiguration == null) + { + throw new ArgumentNullException(nameof(servicesConfiguration)); + } + + webHostBuilder.ConfigureServices( + s => s.AddSingleton>( + new ConfigureTestServicesStartupConfigureContainerFilter(servicesConfiguration))); + + return webHostBuilder; + } + + private class ConfigureTestServicesStartupConfigureServicesFilter : IStartupConfigureServicesFilter + { + private readonly Action _servicesConfiguration; + + public ConfigureTestServicesStartupConfigureServicesFilter(Action servicesConfiguration) + { + if (servicesConfiguration == null) + { + throw new ArgumentNullException(nameof(servicesConfiguration)); + } + + _servicesConfiguration = servicesConfiguration; + } + + public Action ConfigureServices(Action next) => + serviceCollection => + { + next(serviceCollection); + _servicesConfiguration(serviceCollection); + }; + } + + private class ConfigureTestServicesStartupConfigureContainerFilter : IStartupConfigureContainerFilter + { + private readonly Action _servicesConfiguration; + + public ConfigureTestServicesStartupConfigureContainerFilter(Action containerConfiguration) + { + if (containerConfiguration == null) + { + throw new ArgumentNullException(nameof(containerConfiguration)); + } + + _servicesConfiguration = containerConfiguration; + } + + public Action ConfigureContainer(Action next) => + containerBuilder => + { + next(containerBuilder); + _servicesConfiguration(containerBuilder); + }; + } + } +} diff --git a/test/Microsoft.AspNetCore.Hosting.Tests/Internal/MyContainer.cs b/test/Microsoft.AspNetCore.Hosting.Tests/Internal/MyContainer.cs index 97a5d63431..5fbffaad1b 100644 --- a/test/Microsoft.AspNetCore.Hosting.Tests/Internal/MyContainer.cs +++ b/test/Microsoft.AspNetCore.Hosting.Tests/Internal/MyContainer.cs @@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.Hosting.Tests.Internal public bool FancyMethodCalled { get; private set; } + public IServiceCollection Services => _services; + public string Environment { get; set; } public object GetService(Type serviceType) diff --git a/test/Microsoft.AspNetCore.Hosting.Tests/StartupManagerTests.cs b/test/Microsoft.AspNetCore.Hosting.Tests/StartupManagerTests.cs index acc23c3bc2..75d5ccb98c 100644 --- a/test/Microsoft.AspNetCore.Hosting.Tests/StartupManagerTests.cs +++ b/test/Microsoft.AspNetCore.Hosting.Tests/StartupManagerTests.cs @@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Hosting.Fakes; using Microsoft.AspNetCore.Hosting.Internal; using Microsoft.AspNetCore.Hosting.Tests.Internal; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; using Xunit; @@ -17,6 +18,261 @@ namespace Microsoft.AspNetCore.Hosting.Tests { public class StartupManagerTests { + [Fact] + public void ConventionalStartupClass_StartupServiceFilters_WrapsConfigureServicesMethod() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, DefaultServiceProviderFactory>(); + serviceCollection.AddSingleton(new TestStartupServicesFilter(1, overrideAfterService: true)); + serviceCollection.AddSingleton(new TestStartupServicesFilter(2, overrideAfterService: true)); + var services = serviceCollection.BuildServiceProvider(); + + var type = typeof(VoidReturningStartupServicesFiltersStartup); + var startup = StartupLoader.LoadMethods(services, type, ""); + + var applicationServices = startup.ConfigureServicesDelegate(serviceCollection); + var before = applicationServices.GetRequiredService(); + var after = applicationServices.GetRequiredService(); + + Assert.Equal("StartupServicesFilter Before 1", before.Message); + Assert.Equal("StartupServicesFilter After 1", after.Message); + } + + [Fact] + public void ConventionalStartupClass_StartupServiceFilters_MultipleStartupServiceFiltersRun() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, DefaultServiceProviderFactory>(); + serviceCollection.AddSingleton(new TestStartupServicesFilter(1, overrideAfterService: false)); + serviceCollection.AddSingleton(new TestStartupServicesFilter(2, overrideAfterService: true)); + var services = serviceCollection.BuildServiceProvider(); + + var type = typeof(VoidReturningStartupServicesFiltersStartup); + var startup = StartupLoader.LoadMethods(services, type, ""); + + var applicationServices = startup.ConfigureServicesDelegate(serviceCollection); + var before = applicationServices.GetRequiredService(); + var after = applicationServices.GetRequiredService(); + + Assert.Equal("StartupServicesFilter Before 1", before.Message); + Assert.Equal("StartupServicesFilter After 2", after.Message); + } + + [Fact] + public void ConventionalStartupClass_StartupServicesFilters_ThrowsIfStartupBuildsTheContainerAsync() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, DefaultServiceProviderFactory>(); + serviceCollection.AddSingleton(new TestStartupServicesFilter(1, overrideAfterService: false)); + var services = serviceCollection.BuildServiceProvider(); + + var type = typeof(IServiceProviderReturningStartupServicesFiltersStartup); + var startup = StartupLoader.LoadMethods(services, type, ""); + + var expectedMessage = $"A ConfigureServices method that returns an {nameof(IServiceProvider)} is " + + $"not compatible with the use of one or more {nameof(IStartupConfigureServicesFilter)}. " + + $"Use a void returning ConfigureServices method instead or a ConfigureContainer method."; + + var exception = Assert.Throws(() => startup.ConfigureServicesDelegate(serviceCollection)); + + Assert.Equal(expectedMessage, exception.Message); + } + + [Fact] + public void ConventionalStartupClass_ConfigureContainerFilters_WrapInRegistrationOrder() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, MyContainerFactory>(); + serviceCollection.AddSingleton>(new TestConfigureContainerFilter(1, overrideAfterService: true)); + serviceCollection.AddSingleton>(new TestConfigureContainerFilter(2, overrideAfterService: true)); + var services = serviceCollection.BuildServiceProvider(); + + var type = typeof(ConfigureContainerStartupServicesFiltersStartup); + var startup = StartupLoader.LoadMethods(services, type, ""); + + var applicationServices = startup.ConfigureServicesDelegate(serviceCollection); + var before = applicationServices.GetRequiredService(); + var after = applicationServices.GetRequiredService(); + + Assert.Equal("ConfigureContainerFilter Before 1", before.Message); + Assert.Equal("ConfigureContainerFilter After 1", after.Message); + } + + [Fact] + public void ConventionalStartupClass_ConfigureContainerFilters_RunsAllFilters() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, MyContainerFactory>(); + serviceCollection.AddSingleton>(new TestConfigureContainerFilter(1, overrideAfterService: false)); + serviceCollection.AddSingleton>(new TestConfigureContainerFilter(2, overrideAfterService: true)); + var services = serviceCollection.BuildServiceProvider(); + + var type = typeof(ConfigureContainerStartupServicesFiltersStartup); + var startup = StartupLoader.LoadMethods(services, type, ""); + + var applicationServices = startup.ConfigureServicesDelegate(serviceCollection); + var before = applicationServices.GetRequiredService(); + var after = applicationServices.GetRequiredService(); + + Assert.Equal("ConfigureContainerFilter Before 1", before.Message); + Assert.Equal("ConfigureContainerFilter After 2", after.Message); + } + + [Fact] + public void ConventionalStartupClass_ConfigureContainerFilters_RunAfterConfigureServicesFilters() + { + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton, MyContainerFactory>(); + serviceCollection.AddSingleton(new TestStartupServicesFilter(1, overrideAfterService: false)); + serviceCollection.AddSingleton>(new TestConfigureContainerFilter(2, overrideAfterService: true)); + var services = serviceCollection.BuildServiceProvider(); + + var type = typeof(ConfigureServicesAndConfigureContainerStartup); + var startup = StartupLoader.LoadMethods(services, type, ""); + + var applicationServices = startup.ConfigureServicesDelegate(serviceCollection); + var before = applicationServices.GetRequiredService(); + var after = applicationServices.GetRequiredService(); + + Assert.Equal("StartupServicesFilter Before 1", before.Message); + Assert.Equal("ConfigureContainerFilter After 2", after.Message); + } + + public class ConfigureContainerStartupServicesFiltersStartup + { + public void ConfigureContainer(MyContainer services) + { + services.Services.TryAddSingleton(new ServiceBefore { Message = "Configure container" }); + services.Services.TryAddSingleton(new ServiceAfter { Message = "Configure container" }); + } + + public void Configure(IApplicationBuilder builder) + { + } + } + + public class ConfigureServicesAndConfigureContainerStartup + { + public void ConfigureServices(IServiceCollection services) + { + services.TryAddSingleton(new ServiceBefore { Message = "Configure services" }); + services.TryAddSingleton(new ServiceAfter { Message = "Configure services" }); + } + + public void ConfigureContainer(MyContainer services) + { + services.Services.TryAddSingleton(new ServiceBefore { Message = "Configure container" }); + services.Services.TryAddSingleton(new ServiceAfter { Message = "Configure container" }); + } + + public void Configure(IApplicationBuilder builder) + { + } + } + + public class TestConfigureContainerFilter : IStartupConfigureContainerFilter + { + public TestConfigureContainerFilter(object additionalData, bool overrideAfterService) + { + AdditionalData = additionalData; + OverrideAfterService = overrideAfterService; + } + + public object AdditionalData { get; } + public bool OverrideAfterService { get; } + + public Action ConfigureContainer(Action next) + { + return services => + { + services.Services.TryAddSingleton(new ServiceBefore { Message = $"ConfigureContainerFilter Before {AdditionalData}" }); + + next(services); + + // Ensures we can always override. + if (OverrideAfterService) + { + services.Services.AddSingleton(new ServiceAfter { Message = $"ConfigureContainerFilter After {AdditionalData}" }); + } + else + { + services.Services.TryAddSingleton(new ServiceAfter { Message = $"ConfigureContainerFilter After {AdditionalData}" }); + } + }; + } + } + + public class IServiceProviderReturningStartupServicesFiltersStartup + { + public IServiceProvider ConfigureServices(IServiceCollection services) + { + services.TryAddSingleton(new ServiceBefore { Message = "Configure services" }); + services.TryAddSingleton(new ServiceAfter { Message = "Configure services" }); + + return services.BuildServiceProvider(); + } + + public void Configure(IApplicationBuilder builder) + { + } + } + + public class TestStartupServicesFilter : IStartupConfigureServicesFilter + { + public TestStartupServicesFilter(object additionalData, bool overrideAfterService) + { + AdditionalData = additionalData; + OverrideAfterService = overrideAfterService; + } + + public object AdditionalData { get; } + public bool OverrideAfterService { get; } + + public Action ConfigureServices(Action next) + { + return services => + { + services.TryAddSingleton(new ServiceBefore { Message = $"StartupServicesFilter Before {AdditionalData}" }); + + next(services); + + // Ensures we can always override. + if (OverrideAfterService) + { + services.AddSingleton(new ServiceAfter { Message = $"StartupServicesFilter After {AdditionalData}" }); + } + else + { + services.TryAddSingleton(new ServiceAfter { Message = $"StartupServicesFilter After {AdditionalData}" }); + } + }; + } + } + + public class VoidReturningStartupServicesFiltersStartup + { + public void ConfigureServices(IServiceCollection services) + { + services.TryAddSingleton(new ServiceBefore { Message = "Configure services" }); + services.TryAddSingleton(new ServiceAfter { Message = "Configure services" }); + } + + public void Configure(IApplicationBuilder builder) + { + } + } + + + public class ServiceBefore + { + public string Message { get; set; } + } + + public class ServiceAfter + { + public string Message { get; set; } + } + [Fact] public void StartupClassMayHaveHostingServicesInjected() { diff --git a/test/Microsoft.AspNetCore.TestHost.Tests/TestServerTests.cs b/test/Microsoft.AspNetCore.TestHost.Tests/TestServerTests.cs index c89af26114..6ad118f7fe 100644 --- a/test/Microsoft.AspNetCore.TestHost.Tests/TestServerTests.cs +++ b/test/Microsoft.AspNetCore.TestHost.Tests/TestServerTests.cs @@ -30,7 +30,6 @@ namespace Microsoft.AspNetCore.TestHost new TestServer(new WebHostBuilder().Configure(app => { })); } - [Fact] public void DoesNotCaptureStartupErrorsByDefault() { @@ -43,6 +42,46 @@ namespace Microsoft.AspNetCore.TestHost Assert.Throws(() => new TestServer(builder)); } + [Fact] + public async Task ServicesCanBeOverridenForTestingAsync() + { + var builder = new WebHostBuilder() + .ConfigureServices(s => s.AddSingleton,ThirdPartyContainerServiceProviderFactory>()) + .UseStartup() + .ConfigureTestServices(services => services.AddSingleton(new SimpleService { Message = "OverridesConfigureServices" })) + .ConfigureTestContainer(container => container.Services.AddSingleton(new TestService { Message = "OverridesConfigureContainer" })); + + var host = new TestServer(builder); + + var response = await host.CreateClient().GetStringAsync("/"); + + Assert.Equal("OverridesConfigureServices, OverridesConfigureContainer", response); + } + + public class ThirdPartyContainerStartup + { + public void ConfigureServices(IServiceCollection services) => + services.AddSingleton(new SimpleService { Message = "ConfigureServices" }); + + public void ConfigureContainer(ThirdPartyContainer container) => + container.Services.AddSingleton(new TestService { Message = "ConfigureContainer" }); + + public void Configure(IApplicationBuilder app) => + app.Use((ctx, next) => ctx.Response.WriteAsync( + $"{ctx.RequestServices.GetRequiredService().Message}, {ctx.RequestServices.GetRequiredService().Message}")); + } + + public class ThirdPartyContainer + { + public IServiceCollection Services { get; set; } + } + + public class ThirdPartyContainerServiceProviderFactory : IServiceProviderFactory + { + public ThirdPartyContainer CreateBuilder(IServiceCollection services) => new ThirdPartyContainer { Services = services }; + + public IServiceProvider CreateServiceProvider(ThirdPartyContainer containerBuilder) => containerBuilder.Services.BuildServiceProvider(); + } [Fact] public void CaptureStartupErrorsSettingPreserved() @@ -153,7 +192,7 @@ namespace Microsoft.AspNetCore.TestHost Assert.Throws(() => new TestServer(builder, null)); } - public class TestService { } + public class TestService { public string Message { get; set; } } public class TestRequestServiceMiddleware {