[Fixes #1012] Make it possible to override services when using UseStartup.

* Add IStartupConfigureServicesFilter to wrap ConfigureServices.
* Add IStartupConfigureContainerFilter<TContainerBuilder> to wrap
  ConfigureContainer.
* Make StartupLoader build a thunk for configure services that
  resolves all instances of IStartupConfigureServicesFilter and
  IStartupConfigureContainerFilter<TContainerBuilder> and wraps
  invocations to ConfigureServices and ConfigureContainer respectively.
* Refactor building the ConfigureServices callback into a private
  builder class due to the increased complexity in the process.
This commit is contained in:
Javier Calvarro Nelson 2017-08-14 13:35:06 -07:00
parent 1ea0647ae2
commit 964b671288
9 changed files with 615 additions and 29 deletions

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
namespace Microsoft.AspNetCore.Hosting
{
public interface IStartupConfigureContainerFilter<TContainerBuilder>
{
Action<TContainerBuilder> ConfigureContainer(Action<TContainerBuilder> container);
}
}

View File

@ -0,0 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Hosting
{
public interface IStartupConfigureServicesFilter
{
Action<IServiceCollection> ConfigureServices(Action<IServiceCollection> next);
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Reflection;
namespace Microsoft.AspNetCore.Hosting.Internal
@ -16,6 +15,8 @@ namespace Microsoft.AspNetCore.Hosting.Internal
public MethodInfo MethodInfo { get; }
public Func<Action<object>, Action<object>> ConfigureContainerFilters { get; set; }
public Action<object> Build(object instance) => container => Invoke(instance, container);
public Type GetContainerType()
@ -30,6 +31,13 @@ namespace Microsoft.AspNetCore.Hosting.Internal
}
private void Invoke(object instance, object container)
{
ConfigureContainerFilters(StartupConfigureContainer)(container);
void StartupConfigureContainer(object containerBuilder) => InvokeCore(instance, containerBuilder);
}
private void InvokeCore(object instance, object container)
{
if (MethodInfo == null)
{

View File

@ -17,9 +17,18 @@ namespace Microsoft.AspNetCore.Hosting.Internal
public MethodInfo MethodInfo { get; }
public Func<Func<IServiceCollection, IServiceProvider>, Func<IServiceCollection, IServiceProvider>> StartupServiceFilters { get; set; }
public Func<IServiceCollection, IServiceProvider> Build(object instance) => services => Invoke(instance, services);
private IServiceProvider Invoke(object instance, IServiceCollection services)
{
return StartupServiceFilters(Startup)(services);
IServiceProvider Startup(IServiceCollection serviceCollection) => InvokeCore(instance, serviceCollection);
}
private IServiceProvider InvokeCore(object instance, IServiceCollection services)
{
if (MethodInfo == null)
{

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Linq;
using System.Reflection;
@ -11,9 +12,35 @@ namespace Microsoft.AspNetCore.Hosting.Internal
{
public class StartupLoader
{
// Creates an <see cref="StartupMethods"/> instance with the actions to run for configuring the application services and the
// request pipeline of the application.
// When using convention based startup, the process for initializing the services is as follows:
// The host looks for a method with the signature <see cref="IServiceProvider"/> ConfigureServices(<see cref="IServiceCollection"/> services).
// If it can't find one, it looks for a method with the signature <see cref="void"/> ConfigureServices(<see cref="IServiceCollection"/> services).
// When the configure services method is void returning, the host builds a services configuration function that runs all the <see cref="IStartupConfigureServicesFilter"/>
// instances registered on the host, along with the ConfigureServices method following a decorator pattern.
// Additionally to the ConfigureServices method, the Startup class can define a <see cref="void"/> ConfigureContainer&lt;TContainerBuilder&gt;(TContainerBuilder builder)
// method that further configures services into the container. If the ConfigureContainer method is defined, the services configuration function
// creates a TContainerBuilder <see cref="IServiceProviderFactory{TContainerBuilder}"/> and runs all the <see cref="IStartupConfigureContainerFilter{TContainerBuilder}"/>
// instances registered on the host, along with the ConfigureContainer method following a decorator pattern.
// For example:
// StartupFilter1
// StartupFilter2
// ConfigureServices
// StartupFilter2
// StartupFilter1
// ConfigureContainerFilter1
// ConfigureContainerFilter2
// ConfigureContainer
// ConfigureContainerFilter2
// ConfigureContainerFilter1
//
// If the Startup class ConfigureServices returns an <see cref="IServiceProvider"/> and there is at least an <see cref="IStartupConfigureServicesFilter"/> registered we
// throw as the filters can't be applied.
public static StartupMethods LoadMethods(IServiceProvider hostingServiceProvider, Type startupType, string environmentName)
{
var configureMethod = FindConfigureDelegate(startupType, environmentName);
var servicesMethod = FindConfigureServicesDelegate(startupType, environmentName);
var configureContainerMethod = FindConfigureContainerDelegate(startupType, environmentName);
@ -23,44 +50,170 @@ namespace Microsoft.AspNetCore.Hosting.Internal
instance = ActivatorUtilities.GetServiceOrCreateInstance(hostingServiceProvider, startupType);
}
var configureServicesCallback = servicesMethod.Build(instance);
var configureContainerCallback = configureContainerMethod.Build(instance);
// The type of the TContainerBuilder. If there is no ConfigureContainer method we can just use object as it's not
// going to be used for anything.
var type = configureContainerMethod.MethodInfo != null ? configureContainerMethod.GetContainerType() : typeof(object);
Func<IServiceCollection, IServiceProvider> configureServices = services =>
var builder = (ConfigureServicesDelegateBuilder) Activator.CreateInstance(
typeof(ConfigureServicesDelegateBuilder<>).MakeGenericType(type),
hostingServiceProvider,
servicesMethod,
configureContainerMethod,
instance);
return new StartupMethods(instance, configureMethod.Build(instance), builder.Build());
}
private abstract class ConfigureServicesDelegateBuilder
{
public abstract Func<IServiceCollection, IServiceProvider> Build();
}
private class ConfigureServicesDelegateBuilder<TContainerBuilder> : ConfigureServicesDelegateBuilder
{
public ConfigureServicesDelegateBuilder(
IServiceProvider hostingServiceProvider,
ConfigureServicesBuilder configureServicesBuilder,
ConfigureContainerBuilder configureContainerBuilder,
object instance)
{
// Call ConfigureServices, if that returned an IServiceProvider, we're done
IServiceProvider applicationServiceProvider = configureServicesCallback.Invoke(services);
HostingServiceProvider = hostingServiceProvider;
ConfigureServicesBuilder = configureServicesBuilder;
ConfigureContainerBuilder = configureContainerBuilder;
Instance = instance;
}
if (applicationServiceProvider != null)
public IServiceProvider HostingServiceProvider { get; }
public ConfigureServicesBuilder ConfigureServicesBuilder { get; }
public ConfigureContainerBuilder ConfigureContainerBuilder { get; }
public object Instance { get; }
public override Func<IServiceCollection, IServiceProvider> Build()
{
ConfigureServicesBuilder.StartupServiceFilters = BuildStartupServicesFilterPipeline;
var configureServicesCallback = ConfigureServicesBuilder.Build(Instance);
ConfigureContainerBuilder.ConfigureContainerFilters = ConfigureContainerPipeline;
var configureContainerCallback = ConfigureContainerBuilder.Build(Instance);
return ConfigureServices(configureServicesCallback, configureContainerCallback);
Action<object> ConfigureContainerPipeline(Action<object> action)
{
return applicationServiceProvider;
}
return Target;
// If there's a ConfigureContainer method
if (configureContainerMethod.MethodInfo != null)
// The ConfigureContainer pipeline needs an Action<TContainerBuilder> as source, so we just adapt the
// signature with this function.
void Source(TContainerBuilder containerBuilder) =>
action(containerBuilder);
// The ConfigureContainerBuilder.ConfigureContainerFilters expects an Action<object> as value, but our pipeline
// produces an Action<TContainerBuilder> given a source, so we wrap it on an Action<object> that internally casts
// the object containerBuilder to TContainerBuilder to match the expected signature of our ConfigureContainer pipeline.
void Target(object containerBuilder) =>
BuildStartupConfigureContainerFiltersPipeline(Source)((TContainerBuilder)containerBuilder);
}
}
Func<IServiceCollection, IServiceProvider> ConfigureServices(
Func<IServiceCollection, IServiceProvider> configureServicesCallback,
Action<object> configureContainerCallback)
{
return ConfigureServicesWithContainerConfiguration;
IServiceProvider ConfigureServicesWithContainerConfiguration(IServiceCollection services)
{
// We have a ConfigureContainer method, get the IServiceProviderFactory<TContainerBuilder>
var serviceProviderFactoryType = typeof(IServiceProviderFactory<>).MakeGenericType(configureContainerMethod.GetContainerType());
var serviceProviderFactory = hostingServiceProvider.GetRequiredService(serviceProviderFactoryType);
// var builder = serviceProviderFactory.CreateBuilder(services);
var builder = serviceProviderFactoryType.GetMethod(nameof(DefaultServiceProviderFactory.CreateBuilder)).Invoke(serviceProviderFactory, new object[] { services });
configureContainerCallback.Invoke(builder);
// applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(builder);
applicationServiceProvider = (IServiceProvider)serviceProviderFactoryType.GetMethod(nameof(DefaultServiceProviderFactory.CreateServiceProvider)).Invoke(serviceProviderFactory, new object[] { builder });
// Call ConfigureServices, if that returned an IServiceProvider, we're done
IServiceProvider applicationServiceProvider = configureServicesCallback.Invoke(services);
if (applicationServiceProvider != null)
{
return applicationServiceProvider;
}
// If there's a ConfigureContainer method
if (ConfigureContainerBuilder.MethodInfo != null)
{
var serviceProviderFactory = HostingServiceProvider.GetRequiredService<IServiceProviderFactory<TContainerBuilder>>();
var builder = serviceProviderFactory.CreateBuilder(services);
configureContainerCallback(builder);
applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(builder);
}
else
{
// Get the default factory
var serviceProviderFactory = HostingServiceProvider.GetRequiredService<IServiceProviderFactory<IServiceCollection>>();
// Don't bother calling CreateBuilder since it just returns the default service collection
applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(services);
}
return applicationServiceProvider ?? services.BuildServiceProvider();
}
else
}
private Func<IServiceCollection, IServiceProvider> BuildStartupServicesFilterPipeline(Func<IServiceCollection, IServiceProvider> startup)
{
return RunPipeline;
IServiceProvider RunPipeline(IServiceCollection services)
{
// Get the default factory
var serviceProviderFactory = hostingServiceProvider.GetRequiredService<IServiceProviderFactory<IServiceCollection>>();
var filters = HostingServiceProvider.GetRequiredService<IEnumerable<IStartupConfigureServicesFilter>>().Reverse().ToArray();
// Don't bother calling CreateBuilder since it just returns the default service collection
applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(services);
// If there are no filters just run startup (makes IServiceProvider ConfigureServices(IServiceCollection services) work.
if (filters.Length == 0)
{
return startup(services);
}
Action<IServiceCollection> pipeline = InvokeStartup;
for (int i = 0; i < filters.Length; i++)
{
pipeline = filters[i].ConfigureServices(pipeline);
}
pipeline(services);
// We return null so that the host here builds the container (same result as void ConfigureServices(IServiceCollection services);
return null;
void InvokeStartup(IServiceCollection serviceCollection)
{
var result = startup(serviceCollection);
if (filters.Length > 0 && result != null)
{
// public IServiceProvider ConfigureServices(IServiceCollection serviceCollection) is not compatible with IStartupServicesFilter;
var message = $"A ConfigureServices method that returns an {nameof(IServiceProvider)} is " +
$"not compatible with the use of one or more {nameof(IStartupConfigureServicesFilter)}. " +
$"Use a void returning ConfigureServices method instead or a ConfigureContainer method.";
throw new InvalidOperationException(message);
};
}
}
}
return applicationServiceProvider ?? services.BuildServiceProvider();
};
private Action<TContainerBuilder> BuildStartupConfigureContainerFiltersPipeline(Action<TContainerBuilder> configureContainer)
{
return RunPipeline;
return new StartupMethods(instance, configureMethod.Build(instance), configureServices);
void RunPipeline(TContainerBuilder containerBuilder)
{
var filters = HostingServiceProvider
.GetRequiredService<IEnumerable<IStartupConfigureContainerFilter<TContainerBuilder>>>()
.Reverse()
.ToArray();
Action<TContainerBuilder> pipeline = InvokeConfigureContainer;
for (int i = 0; i < filters.Length; i++)
{
pipeline = filters[i].ConfigureContainer(pipeline);
}
pipeline(containerBuilder);
void InvokeConfigureContainer(TContainerBuilder builder) => configureContainer(builder);
}
}
}
public static Type FindStartupType(string startupAssemblyName, string environmentName)

View File

@ -0,0 +1,94 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Hosting;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.TestHost
{
public static class WebHostBuilderExtensions
{
public static IWebHostBuilder ConfigureTestServices(this IWebHostBuilder webHostBuilder, Action<IServiceCollection> servicesConfiguration)
{
if (webHostBuilder == null)
{
throw new ArgumentNullException(nameof(webHostBuilder));
}
if (servicesConfiguration == null)
{
throw new ArgumentNullException(nameof(servicesConfiguration));
}
webHostBuilder.ConfigureServices(
s => s.AddSingleton<IStartupConfigureServicesFilter>(
new ConfigureTestServicesStartupConfigureServicesFilter(servicesConfiguration)));
return webHostBuilder;
}
public static IWebHostBuilder ConfigureTestContainer<TContainer>(this IWebHostBuilder webHostBuilder, Action<TContainer> servicesConfiguration)
{
if (webHostBuilder == null)
{
throw new ArgumentNullException(nameof(webHostBuilder));
}
if (servicesConfiguration == null)
{
throw new ArgumentNullException(nameof(servicesConfiguration));
}
webHostBuilder.ConfigureServices(
s => s.AddSingleton<IStartupConfigureContainerFilter<TContainer>>(
new ConfigureTestServicesStartupConfigureContainerFilter<TContainer>(servicesConfiguration)));
return webHostBuilder;
}
private class ConfigureTestServicesStartupConfigureServicesFilter : IStartupConfigureServicesFilter
{
private readonly Action<IServiceCollection> _servicesConfiguration;
public ConfigureTestServicesStartupConfigureServicesFilter(Action<IServiceCollection> servicesConfiguration)
{
if (servicesConfiguration == null)
{
throw new ArgumentNullException(nameof(servicesConfiguration));
}
_servicesConfiguration = servicesConfiguration;
}
public Action<IServiceCollection> ConfigureServices(Action<IServiceCollection> next) =>
serviceCollection =>
{
next(serviceCollection);
_servicesConfiguration(serviceCollection);
};
}
private class ConfigureTestServicesStartupConfigureContainerFilter<TContainer> : IStartupConfigureContainerFilter<TContainer>
{
private readonly Action<TContainer> _servicesConfiguration;
public ConfigureTestServicesStartupConfigureContainerFilter(Action<TContainer> containerConfiguration)
{
if (containerConfiguration == null)
{
throw new ArgumentNullException(nameof(containerConfiguration));
}
_servicesConfiguration = containerConfiguration;
}
public Action<TContainer> ConfigureContainer(Action<TContainer> next) =>
containerBuilder =>
{
next(containerBuilder);
_servicesConfiguration(containerBuilder);
};
}
}
}

View File

@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.Hosting.Tests.Internal
public bool FancyMethodCalled { get; private set; }
public IServiceCollection Services => _services;
public string Environment { get; set; }
public object GetService(Type serviceType)

View File

@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Hosting.Fakes;
using Microsoft.AspNetCore.Hosting.Internal;
using Microsoft.AspNetCore.Hosting.Tests.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options;
using Xunit;
@ -17,6 +18,261 @@ namespace Microsoft.AspNetCore.Hosting.Tests
{
public class StartupManagerTests
{
[Fact]
public void ConventionalStartupClass_StartupServiceFilters_WrapsConfigureServicesMethod()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IStartupConfigureServicesFilter>(new TestStartupServicesFilter(1, overrideAfterService: true));
serviceCollection.AddSingleton<IStartupConfigureServicesFilter>(new TestStartupServicesFilter(2, overrideAfterService: true));
var services = serviceCollection.BuildServiceProvider();
var type = typeof(VoidReturningStartupServicesFiltersStartup);
var startup = StartupLoader.LoadMethods(services, type, "");
var applicationServices = startup.ConfigureServicesDelegate(serviceCollection);
var before = applicationServices.GetRequiredService<ServiceBefore>();
var after = applicationServices.GetRequiredService<ServiceAfter>();
Assert.Equal("StartupServicesFilter Before 1", before.Message);
Assert.Equal("StartupServicesFilter After 1", after.Message);
}
[Fact]
public void ConventionalStartupClass_StartupServiceFilters_MultipleStartupServiceFiltersRun()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IStartupConfigureServicesFilter>(new TestStartupServicesFilter(1, overrideAfterService: false));
serviceCollection.AddSingleton<IStartupConfigureServicesFilter>(new TestStartupServicesFilter(2, overrideAfterService: true));
var services = serviceCollection.BuildServiceProvider();
var type = typeof(VoidReturningStartupServicesFiltersStartup);
var startup = StartupLoader.LoadMethods(services, type, "");
var applicationServices = startup.ConfigureServicesDelegate(serviceCollection);
var before = applicationServices.GetRequiredService<ServiceBefore>();
var after = applicationServices.GetRequiredService<ServiceAfter>();
Assert.Equal("StartupServicesFilter Before 1", before.Message);
Assert.Equal("StartupServicesFilter After 2", after.Message);
}
[Fact]
public void ConventionalStartupClass_StartupServicesFilters_ThrowsIfStartupBuildsTheContainerAsync()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IStartupConfigureServicesFilter>(new TestStartupServicesFilter(1, overrideAfterService: false));
var services = serviceCollection.BuildServiceProvider();
var type = typeof(IServiceProviderReturningStartupServicesFiltersStartup);
var startup = StartupLoader.LoadMethods(services, type, "");
var expectedMessage = $"A ConfigureServices method that returns an {nameof(IServiceProvider)} is " +
$"not compatible with the use of one or more {nameof(IStartupConfigureServicesFilter)}. " +
$"Use a void returning ConfigureServices method instead or a ConfigureContainer method.";
var exception = Assert.Throws<InvalidOperationException>(() => startup.ConfigureServicesDelegate(serviceCollection));
Assert.Equal(expectedMessage, exception.Message);
}
[Fact]
public void ConventionalStartupClass_ConfigureContainerFilters_WrapInRegistrationOrder()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
serviceCollection.AddSingleton<IStartupConfigureContainerFilter<MyContainer>>(new TestConfigureContainerFilter(1, overrideAfterService: true));
serviceCollection.AddSingleton<IStartupConfigureContainerFilter<MyContainer>>(new TestConfigureContainerFilter(2, overrideAfterService: true));
var services = serviceCollection.BuildServiceProvider();
var type = typeof(ConfigureContainerStartupServicesFiltersStartup);
var startup = StartupLoader.LoadMethods(services, type, "");
var applicationServices = startup.ConfigureServicesDelegate(serviceCollection);
var before = applicationServices.GetRequiredService<ServiceBefore>();
var after = applicationServices.GetRequiredService<ServiceAfter>();
Assert.Equal("ConfigureContainerFilter Before 1", before.Message);
Assert.Equal("ConfigureContainerFilter After 1", after.Message);
}
[Fact]
public void ConventionalStartupClass_ConfigureContainerFilters_RunsAllFilters()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
serviceCollection.AddSingleton<IStartupConfigureContainerFilter<MyContainer>>(new TestConfigureContainerFilter(1, overrideAfterService: false));
serviceCollection.AddSingleton<IStartupConfigureContainerFilter<MyContainer>>(new TestConfigureContainerFilter(2, overrideAfterService: true));
var services = serviceCollection.BuildServiceProvider();
var type = typeof(ConfigureContainerStartupServicesFiltersStartup);
var startup = StartupLoader.LoadMethods(services, type, "");
var applicationServices = startup.ConfigureServicesDelegate(serviceCollection);
var before = applicationServices.GetRequiredService<ServiceBefore>();
var after = applicationServices.GetRequiredService<ServiceAfter>();
Assert.Equal("ConfigureContainerFilter Before 1", before.Message);
Assert.Equal("ConfigureContainerFilter After 2", after.Message);
}
[Fact]
public void ConventionalStartupClass_ConfigureContainerFilters_RunAfterConfigureServicesFilters()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
serviceCollection.AddSingleton<IStartupConfigureServicesFilter>(new TestStartupServicesFilter(1, overrideAfterService: false));
serviceCollection.AddSingleton<IStartupConfigureContainerFilter<MyContainer>>(new TestConfigureContainerFilter(2, overrideAfterService: true));
var services = serviceCollection.BuildServiceProvider();
var type = typeof(ConfigureServicesAndConfigureContainerStartup);
var startup = StartupLoader.LoadMethods(services, type, "");
var applicationServices = startup.ConfigureServicesDelegate(serviceCollection);
var before = applicationServices.GetRequiredService<ServiceBefore>();
var after = applicationServices.GetRequiredService<ServiceAfter>();
Assert.Equal("StartupServicesFilter Before 1", before.Message);
Assert.Equal("ConfigureContainerFilter After 2", after.Message);
}
public class ConfigureContainerStartupServicesFiltersStartup
{
public void ConfigureContainer(MyContainer services)
{
services.Services.TryAddSingleton(new ServiceBefore { Message = "Configure container" });
services.Services.TryAddSingleton(new ServiceAfter { Message = "Configure container" });
}
public void Configure(IApplicationBuilder builder)
{
}
}
public class ConfigureServicesAndConfigureContainerStartup
{
public void ConfigureServices(IServiceCollection services)
{
services.TryAddSingleton(new ServiceBefore { Message = "Configure services" });
services.TryAddSingleton(new ServiceAfter { Message = "Configure services" });
}
public void ConfigureContainer(MyContainer services)
{
services.Services.TryAddSingleton(new ServiceBefore { Message = "Configure container" });
services.Services.TryAddSingleton(new ServiceAfter { Message = "Configure container" });
}
public void Configure(IApplicationBuilder builder)
{
}
}
public class TestConfigureContainerFilter : IStartupConfigureContainerFilter<MyContainer>
{
public TestConfigureContainerFilter(object additionalData, bool overrideAfterService)
{
AdditionalData = additionalData;
OverrideAfterService = overrideAfterService;
}
public object AdditionalData { get; }
public bool OverrideAfterService { get; }
public Action<MyContainer> ConfigureContainer(Action<MyContainer> next)
{
return services =>
{
services.Services.TryAddSingleton(new ServiceBefore { Message = $"ConfigureContainerFilter Before {AdditionalData}" });
next(services);
// Ensures we can always override.
if (OverrideAfterService)
{
services.Services.AddSingleton(new ServiceAfter { Message = $"ConfigureContainerFilter After {AdditionalData}" });
}
else
{
services.Services.TryAddSingleton(new ServiceAfter { Message = $"ConfigureContainerFilter After {AdditionalData}" });
}
};
}
}
public class IServiceProviderReturningStartupServicesFiltersStartup
{
public IServiceProvider ConfigureServices(IServiceCollection services)
{
services.TryAddSingleton(new ServiceBefore { Message = "Configure services" });
services.TryAddSingleton(new ServiceAfter { Message = "Configure services" });
return services.BuildServiceProvider();
}
public void Configure(IApplicationBuilder builder)
{
}
}
public class TestStartupServicesFilter : IStartupConfigureServicesFilter
{
public TestStartupServicesFilter(object additionalData, bool overrideAfterService)
{
AdditionalData = additionalData;
OverrideAfterService = overrideAfterService;
}
public object AdditionalData { get; }
public bool OverrideAfterService { get; }
public Action<IServiceCollection> ConfigureServices(Action<IServiceCollection> next)
{
return services =>
{
services.TryAddSingleton(new ServiceBefore { Message = $"StartupServicesFilter Before {AdditionalData}" });
next(services);
// Ensures we can always override.
if (OverrideAfterService)
{
services.AddSingleton(new ServiceAfter { Message = $"StartupServicesFilter After {AdditionalData}" });
}
else
{
services.TryAddSingleton(new ServiceAfter { Message = $"StartupServicesFilter After {AdditionalData}" });
}
};
}
}
public class VoidReturningStartupServicesFiltersStartup
{
public void ConfigureServices(IServiceCollection services)
{
services.TryAddSingleton(new ServiceBefore { Message = "Configure services" });
services.TryAddSingleton(new ServiceAfter { Message = "Configure services" });
}
public void Configure(IApplicationBuilder builder)
{
}
}
public class ServiceBefore
{
public string Message { get; set; }
}
public class ServiceAfter
{
public string Message { get; set; }
}
[Fact]
public void StartupClassMayHaveHostingServicesInjected()
{

View File

@ -30,7 +30,6 @@ namespace Microsoft.AspNetCore.TestHost
new TestServer(new WebHostBuilder().Configure(app => { }));
}
[Fact]
public void DoesNotCaptureStartupErrorsByDefault()
{
@ -43,6 +42,46 @@ namespace Microsoft.AspNetCore.TestHost
Assert.Throws<InvalidOperationException>(() => new TestServer(builder));
}
[Fact]
public async Task ServicesCanBeOverridenForTestingAsync()
{
var builder = new WebHostBuilder()
.ConfigureServices(s => s.AddSingleton<IServiceProviderFactory<ThirdPartyContainer>,ThirdPartyContainerServiceProviderFactory>())
.UseStartup<ThirdPartyContainerStartup>()
.ConfigureTestServices(services => services.AddSingleton(new SimpleService { Message = "OverridesConfigureServices" }))
.ConfigureTestContainer<ThirdPartyContainer>(container => container.Services.AddSingleton(new TestService { Message = "OverridesConfigureContainer" }));
var host = new TestServer(builder);
var response = await host.CreateClient().GetStringAsync("/");
Assert.Equal("OverridesConfigureServices, OverridesConfigureContainer", response);
}
public class ThirdPartyContainerStartup
{
public void ConfigureServices(IServiceCollection services) =>
services.AddSingleton(new SimpleService { Message = "ConfigureServices" });
public void ConfigureContainer(ThirdPartyContainer container) =>
container.Services.AddSingleton(new TestService { Message = "ConfigureContainer" });
public void Configure(IApplicationBuilder app) =>
app.Use((ctx, next) => ctx.Response.WriteAsync(
$"{ctx.RequestServices.GetRequiredService<SimpleService>().Message}, {ctx.RequestServices.GetRequiredService<TestService>().Message}"));
}
public class ThirdPartyContainer
{
public IServiceCollection Services { get; set; }
}
public class ThirdPartyContainerServiceProviderFactory : IServiceProviderFactory<ThirdPartyContainer>
{
public ThirdPartyContainer CreateBuilder(IServiceCollection services) => new ThirdPartyContainer { Services = services };
public IServiceProvider CreateServiceProvider(ThirdPartyContainer containerBuilder) => containerBuilder.Services.BuildServiceProvider();
}
[Fact]
public void CaptureStartupErrorsSettingPreserved()
@ -153,7 +192,7 @@ namespace Microsoft.AspNetCore.TestHost
Assert.Throws<ArgumentNullException>(() => new TestServer(builder, null));
}
public class TestService { }
public class TestService { public string Message { get; set; } }
public class TestRequestServiceMiddleware
{