Added support for Startup.ConfigureContainer

- Startup.ConfigureContainer allows users to configure a 3rd party DI
container in a first class way in the Startup class. 3rd party containers
plug in via IServiceProviderFactory<TContainerBuilder> configured in
IWebHostBuilder.ConfigureServices.
- Added tests
This commit is contained in:
David Fowler 2016-08-07 02:46:51 -07:00
parent e3b5686d96
commit 0a7cf6b5a0
10 changed files with 417 additions and 37 deletions

View File

@ -0,0 +1,44 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Reflection;
namespace Microsoft.AspNetCore.Hosting.Internal
{
public class ConfigureContainerBuilder
{
public ConfigureContainerBuilder(MethodInfo configureContainerMethod)
{
MethodInfo = configureContainerMethod;
}
public MethodInfo MethodInfo { get; }
public Action<object> Build(object instance) => container => Invoke(instance, container);
public Type GetContainerType()
{
var parameters = MethodInfo.GetParameters();
if (parameters.Length != 1)
{
// REVIEW: This might be a breaking change
throw new InvalidOperationException($"The {MethodInfo.Name} method must take only one parameter.");
}
return parameters[0].ParameterType;
}
private void Invoke(object instance, object container)
{
if (MethodInfo == null)
{
return;
}
var arguments = new object[1] { container };
MethodInfo.Invoke(instance, arguments);
}
}
}

View File

@ -12,19 +12,6 @@ namespace Microsoft.AspNetCore.Hosting.Internal
{
public ConfigureServicesBuilder(MethodInfo configureServices)
{
if (configureServices == null)
{
throw new ArgumentNullException(nameof(configureServices));
}
// Only support IServiceCollection parameters
var parameters = configureServices.GetParameters();
if (parameters.Length > 1 ||
parameters.Any(p => p.ParameterType != typeof(IServiceCollection)))
{
throw new InvalidOperationException("The ConfigureServices method must either be parameterless or take only one parameter of type IServiceCollection.");
}
MethodInfo = configureServices;
}
@ -32,22 +19,29 @@ namespace Microsoft.AspNetCore.Hosting.Internal
public Func<IServiceCollection, IServiceProvider> Build(object instance) => services => Invoke(instance, services);
private IServiceProvider Invoke(object instance, IServiceCollection exportServices)
private IServiceProvider Invoke(object instance, IServiceCollection services)
{
if (exportServices == null)
if (MethodInfo == null)
{
throw new ArgumentNullException(nameof(exportServices));
return null;
}
var parameters = new object[MethodInfo.GetParameters().Length];
// Only support IServiceCollection parameters
var parameters = MethodInfo.GetParameters();
if (parameters.Length > 1 ||
parameters.Any(p => p.ParameterType != typeof(IServiceCollection)))
{
throw new InvalidOperationException("The ConfigureServices method must either be parameterless or take only one parameter of type IServiceCollection.");
}
var arguments = new object[MethodInfo.GetParameters().Length];
// Ctor ensures we have at most one IServiceCollection parameter
if (parameters.Length > 0)
{
parameters[0] = exportServices;
arguments[0] = services;
}
return MethodInfo.Invoke(instance, parameters) as IServiceProvider ?? exportServices.BuildServiceProvider();
return MethodInfo.Invoke(instance, arguments) as IServiceProvider;
}
}
}

View File

@ -11,18 +11,56 @@ namespace Microsoft.AspNetCore.Hosting.Internal
{
public class StartupLoader
{
public static StartupMethods LoadMethods(IServiceProvider services, Type startupType, string environmentName)
public static StartupMethods LoadMethods(IServiceProvider hostingServiceProvider, Type startupType, string environmentName)
{
var configureMethod = FindConfigureDelegate(startupType, environmentName);
var servicesMethod = FindConfigureServicesDelegate(startupType, environmentName);
var configureContainerMethod = FindConfigureContainerDelegate(startupType, environmentName);
object instance = null;
if (!configureMethod.MethodInfo.IsStatic || (servicesMethod != null && !servicesMethod.MethodInfo.IsStatic))
{
instance = ActivatorUtilities.GetServiceOrCreateInstance(services, startupType);
instance = ActivatorUtilities.GetServiceOrCreateInstance(hostingServiceProvider, startupType);
}
return new StartupMethods(configureMethod.Build(instance), servicesMethod?.Build(instance));
var configureServicesCallback = servicesMethod.Build(instance);
var configureContainerCallback = configureContainerMethod.Build(instance);
Func<IServiceCollection, IServiceProvider> configureServices = services =>
{
// Call ConfigureServices, if that returned an IServiceProvider, we're done
IServiceProvider applicationServiceProvider = configureServicesCallback.Invoke(services);
if (applicationServiceProvider != null)
{
return applicationServiceProvider;
}
// If there's a ConfigureContainer method
if (configureContainerMethod.MethodInfo != null)
{
// We have a ConfigureContainer method, get the IServiceProviderFactory<TContainerBuilder>
var serviceProviderFactoryType = typeof(IServiceProviderFactory<>).MakeGenericType(configureContainerMethod.GetContainerType());
var serviceProviderFactory = hostingServiceProvider.GetRequiredService(serviceProviderFactoryType);
// var builder = serviceProviderFactory.CreateBuilder(services);
var builder = serviceProviderFactoryType.GetMethod(nameof(DefaultServiceProviderFactory.CreateBuilder)).Invoke(serviceProviderFactory, new object[] { services });
configureContainerCallback.Invoke(builder);
// applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(builder);
applicationServiceProvider = (IServiceProvider)serviceProviderFactoryType.GetMethod(nameof(DefaultServiceProviderFactory.CreateServiceProvider)).Invoke(serviceProviderFactory, new object[] { builder });
}
else
{
// Get the default factory
var serviceProviderFactory = hostingServiceProvider.GetRequiredService<IServiceProviderFactory<IServiceCollection>>();
// Don't bother calling CreateBuilder since it just returns the default service collection
applicationServiceProvider = serviceProviderFactory.CreateServiceProvider(services);
}
return applicationServiceProvider ?? services.BuildServiceProvider();
};
return new StartupMethods(configureMethod.Build(instance), configureServices);
}
public static Type FindStartupType(string startupAssemblyName, string environmentName)
@ -30,8 +68,8 @@ namespace Microsoft.AspNetCore.Hosting.Internal
if (string.IsNullOrEmpty(startupAssemblyName))
{
throw new ArgumentException(
string.Format("A startup method, startup type or startup assembly is required. If specifying an assembly, '{0}' cannot be null or empty.",
nameof(startupAssemblyName)),
string.Format("A startup method, startup type or startup assembly is required. If specifying an assembly, '{0}' cannot be null or empty.",
nameof(startupAssemblyName)),
nameof(startupAssemblyName));
}
@ -83,11 +121,17 @@ namespace Microsoft.AspNetCore.Hosting.Internal
return new ConfigureBuilder(configureMethod);
}
private static ConfigureContainerBuilder FindConfigureContainerDelegate(Type startupType, string environmentName)
{
var configureMethod = FindMethod(startupType, "Configure{0}Container", environmentName, typeof(void), required: false);
return new ConfigureContainerBuilder(configureMethod);
}
private static ConfigureServicesBuilder FindConfigureServicesDelegate(Type startupType, string environmentName)
{
var servicesMethod = FindMethod(startupType, "Configure{0}Services", environmentName, typeof(IServiceProvider), required: false)
?? FindMethod(startupType, "Configure{0}Services", environmentName, typeof(void), required: false);
return servicesMethod == null ? null : new ConfigureServicesBuilder(servicesMethod);
return new ConfigureServicesBuilder(servicesMethod);
}
private static MethodInfo FindMethod(Type startupType, string methodName, string environmentName, Type returnType = null, bool required = true)

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.DependencyInjection;
@ -9,17 +10,13 @@ namespace Microsoft.AspNetCore.Hosting.Internal
{
public class StartupMethods
{
internal static Func<IServiceCollection, IServiceProvider> DefaultBuildServiceProvider = s => s.BuildServiceProvider();
public StartupMethods(Action<IApplicationBuilder> configure)
: this(configure, configureServices: null)
{
}
public StartupMethods(Action<IApplicationBuilder> configure, Func<IServiceCollection, IServiceProvider> configureServices)
{
Debug.Assert(configure != null);
Debug.Assert(configureServices != null);
ConfigureDelegate = configure;
ConfigureServicesDelegate = configureServices ?? DefaultBuildServiceProvider;
ConfigureServicesDelegate = configureServices;
}
public Func<IServiceCollection, IServiceProvider> ConfigureServicesDelegate { get; }

View File

@ -16,4 +16,31 @@ namespace Microsoft.AspNetCore.Hosting
return services.BuildServiceProvider();
}
}
public abstract class StartupBase<TContainerBuilder> : IStartup
{
private readonly IServiceProviderFactory<TContainerBuilder> _factory;
public StartupBase(IServiceProviderFactory<TContainerBuilder> factory)
{
_factory = factory;
}
public abstract void Configure(IApplicationBuilder app);
public virtual void ConfigureServices(IServiceCollection services)
{
}
IServiceProvider IStartup.ConfigureServices(IServiceCollection services)
{
ConfigureServices(services);
var builder = _factory.CreateBuilder(services);
ConfigureContainer(builder);
return _factory.CreateServiceProvider(builder);
}
public virtual void ConfigureContainer(TContainerBuilder containerBuilder) { }
}
}

View File

@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Hosting.Internal;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Extensions.PlatformAbstractions;
@ -200,6 +201,7 @@ namespace Microsoft.AspNetCore.Hosting
// Conjure up a RequestServices
services.AddTransient<IStartupFilter, AutoRequestServicesStartupFilter>();
services.AddTransient<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
// Ensure object pooling is available everywhere.
services.AddSingleton<ObjectPoolProvider, DefaultObjectPoolProvider>();

View File

@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Hosting.Tests.Internal
{
public class MyBadContainerFactory : IServiceProviderFactory<MyContainer>
{
public MyContainer CreateBuilder(IServiceCollection services)
{
var container = new MyContainer();
container.Populate(services);
return container;
}
public IServiceProvider CreateServiceProvider(MyContainer containerBuilder)
{
containerBuilder.Build();
return null;
}
}
}

View File

@ -0,0 +1,38 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Hosting.Tests.Internal
{
public class MyContainer : IServiceProvider
{
private IServiceProvider _inner;
private IServiceCollection _services;
public bool FancyMethodCalled { get; private set; }
public string Environment { get; set; }
public object GetService(Type serviceType)
{
return _inner.GetService(serviceType);
}
public void Populate(IServiceCollection services)
{
_services = services;
}
public void Build()
{
_inner = _services.BuildServiceProvider();
}
public void MyFancyContainerMethod()
{
FancyMethodCalled = true;
}
}
}

View File

@ -0,0 +1,24 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Hosting.Tests.Internal
{
public class MyContainerFactory : IServiceProviderFactory<MyContainer>
{
public MyContainer CreateBuilder(IServiceCollection services)
{
var container = new MyContainer();
container.Populate(services);
return container;
}
public IServiceProvider CreateServiceProvider(MyContainer containerBuilder)
{
containerBuilder.Build();
return containerBuilder;
}
}
}

View File

@ -8,6 +8,7 @@ using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Builder.Internal;
using Microsoft.AspNetCore.Hosting.Fakes;
using Microsoft.AspNetCore.Hosting.Internal;
using Microsoft.AspNetCore.Hosting.Tests.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Xunit;
@ -22,6 +23,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupClassMayHaveHostingServicesInjected()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IFakeStartupCallback>(this);
var services = serviceCollection.BuildServiceProvider();
@ -46,8 +48,9 @@ namespace Microsoft.AspNetCore.Hosting.Tests
[InlineData("BaseClass")]
public void StartupClassAddsConfigureServicesToApplicationServices(string environment)
{
var services = new ServiceCollection().BuildServiceProvider();
var services = new ServiceCollection()
.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>()
.BuildServiceProvider();
var type = StartupLoader.FindStartupType("Microsoft.AspNetCore.Hosting.Tests", environment);
var startup = StartupLoader.LoadMethods(services, type, environment);
@ -65,6 +68,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupWithNoConfigureThrows()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IFakeStartupCallback>(this);
var services = serviceCollection.BuildServiceProvider();
var type = StartupLoader.FindStartupType("Microsoft.AspNetCore.Hosting.Tests", "Boom");
@ -76,6 +80,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupWithTwoConfiguresThrows()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IFakeStartupCallback>(this);
var services = serviceCollection.BuildServiceProvider();
@ -84,11 +89,12 @@ namespace Microsoft.AspNetCore.Hosting.Tests
var ex = Assert.Throws<InvalidOperationException>(() => StartupLoader.LoadMethods(services, type, "TwoConfigures"));
Assert.Equal("Having multiple overloads of method 'Configure' is not supported.", ex.Message);
}
[Fact]
public void StartupWithPrivateConfiguresThrows()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IFakeStartupCallback>(this);
var services = serviceCollection.BuildServiceProvider();
@ -103,6 +109,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupWithTwoConfigureServicesThrows()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
serviceCollection.AddSingleton<IFakeStartupCallback>(this);
var services = serviceCollection.BuildServiceProvider();
@ -116,6 +123,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupClassCanHandleConfigureServicesThatReturnsNull()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
var services = serviceCollection.BuildServiceProvider();
var type = StartupLoader.FindStartupType("Microsoft.AspNetCore.Hosting.Tests", "WithNullConfigureServices");
@ -132,6 +140,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupClassWithConfigureServicesShouldMakeServiceAvailableInConfigure()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
var services = serviceCollection.BuildServiceProvider();
var type = StartupLoader.FindStartupType("Microsoft.AspNetCore.Hosting.Tests", "WithConfigureServices");
@ -149,6 +158,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupLoaderCanLoadByType()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
var services = serviceCollection.BuildServiceProvider();
var hostingEnv = new HostingEnvironment();
@ -166,6 +176,7 @@ namespace Microsoft.AspNetCore.Hosting.Tests
public void StartupLoaderCanLoadByTypeWithEnvironment()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<IServiceCollection>, DefaultServiceProviderFactory>();
var services = serviceCollection.BuildServiceProvider();
var startup = StartupLoader.LoadMethods(services, typeof(TestStartup), "No");
@ -177,6 +188,181 @@ namespace Microsoft.AspNetCore.Hosting.Tests
Assert.IsAssignableFrom(typeof(InvalidOperationException), ex.InnerException);
}
[Fact]
public void CustomProviderFactoryCallsConfigureContainer()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
var services = serviceCollection.BuildServiceProvider();
var startup = StartupLoader.LoadMethods(services, typeof(MyContainerStartup), "Development");
var app = new ApplicationBuilder(services);
app.ApplicationServices = startup.ConfigureServicesDelegate(serviceCollection);
Assert.IsType(typeof(MyContainer), app.ApplicationServices);
Assert.True(((MyContainer)app.ApplicationServices).FancyMethodCalled);
}
[Fact]
public void CustomServiceProviderFactoryStartupBaseClassCallsConfigureContainer()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
var services = serviceCollection.BuildServiceProvider();
var startup = StartupLoader.LoadMethods(services, typeof(MyContainerStartupBaseClass), "Development");
var app = new ApplicationBuilder(services);
app.ApplicationServices = startup.ConfigureServicesDelegate(serviceCollection);
Assert.IsType(typeof(MyContainer), app.ApplicationServices);
Assert.True(((MyContainer)app.ApplicationServices).FancyMethodCalled);
}
[Fact]
public void CustomServiceProviderFactoryEnvironmentBasedConfigureContainer()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
var services = serviceCollection.BuildServiceProvider();
var startup = StartupLoader.LoadMethods(services, typeof(MyContainerStartupEnvironmentBased), "Production");
var app = new ApplicationBuilder(services);
app.ApplicationServices = startup.ConfigureServicesDelegate(serviceCollection);
Assert.IsType(typeof(MyContainer), app.ApplicationServices);
Assert.Equal(((MyContainer)app.ApplicationServices).Environment, "Production");
}
[Fact]
public void CustomServiceProviderFactoryThrowsIfNotRegisteredWithConfigureContainerMethod()
{
var serviceCollection = new ServiceCollection();
var services = serviceCollection.BuildServiceProvider();
var startup = StartupLoader.LoadMethods(services, typeof(MyContainerStartup), "Development");
Assert.Throws<InvalidOperationException>(() => startup.ConfigureServicesDelegate(serviceCollection));
}
[Fact]
public void CustomServiceProviderFactoryThrowsIfNotRegisteredWithConfigureContainerMethodStartupBase()
{
var serviceCollection = new ServiceCollection();
var services = serviceCollection.BuildServiceProvider();
Assert.Throws<InvalidOperationException>(() => StartupLoader.LoadMethods(services, typeof(MyContainerStartupBaseClass), "Development"));
}
[Fact]
public void CustomServiceProviderFactoryFailsWithOverloadsInStartup()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyContainerFactory>();
var services = serviceCollection.BuildServiceProvider();
Assert.Throws<InvalidOperationException>(() => StartupLoader.LoadMethods(services, typeof(MyContainerStartupWithOverloads), "Development"));
}
[Fact]
public void BadServiceProviderFactoryFailsThatReturnsNullServiceProviderOverriddenByDefault()
{
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<IServiceProviderFactory<MyContainer>, MyBadContainerFactory>();
var services = serviceCollection.BuildServiceProvider();
var startup = StartupLoader.LoadMethods(services, typeof(MyContainerStartup), "Development");
var app = new ApplicationBuilder(services);
app.ApplicationServices = startup.ConfigureServicesDelegate(serviceCollection);
Assert.NotNull(app.ApplicationServices);
Assert.IsNotType(typeof(MyContainer), app.ApplicationServices);
}
public class MyContainerStartupWithOverloads
{
public void ConfigureServices(IServiceCollection services)
{
}
public void ConfigureContainer(MyContainer container)
{
container.MyFancyContainerMethod();
}
public void ConfigureContainer(IServiceCollection services)
{
}
public void Configure(IApplicationBuilder app)
{
}
}
public class MyContainerStartupEnvironmentBased
{
public void ConfigureServices(IServiceCollection services)
{
}
public void ConfigureDevelopmentContainer(MyContainer container)
{
container.Environment = "Development";
}
public void ConfigureProductionContainer(MyContainer container)
{
container.Environment = "Production";
}
public void Configure(IApplicationBuilder app)
{
}
}
public class MyContainerStartup
{
public void ConfigureServices(IServiceCollection services)
{
}
public void ConfigureContainer(MyContainer container)
{
container.MyFancyContainerMethod();
}
public void Configure(IApplicationBuilder app)
{
}
}
public class MyContainerStartupBaseClass : StartupBase<MyContainer>
{
public MyContainerStartupBaseClass(IServiceProviderFactory<MyContainer> factory) : base(factory)
{
}
public override void Configure(IApplicationBuilder app)
{
}
public override void ConfigureContainer(MyContainer containerBuilder)
{
containerBuilder.MyFancyContainerMethod();
}
}
public class SimpleService
{
public SimpleService()