diff --git a/src/Microsoft.AspNet.Hosting/ConfigureHostingEnvironment.cs b/src/Microsoft.AspNet.Hosting/ConfigureHostingEnvironment.cs index a714c2cfef..892eaab372 100644 --- a/src/Microsoft.AspNet.Hosting/ConfigureHostingEnvironment.cs +++ b/src/Microsoft.AspNet.Hosting/ConfigureHostingEnvironment.cs @@ -17,7 +17,7 @@ namespace Microsoft.AspNet.Hosting public void Configure(IHostingEnvironment hostingEnv) { - hostingEnv.EnvironmentName = _config.Get(EnvironmentKey) ?? hostingEnv.EnvironmentName; + hostingEnv.EnvironmentName = _config?.Get(EnvironmentKey) ?? hostingEnv.EnvironmentName; } } } \ No newline at end of file diff --git a/src/Microsoft.AspNet.Hosting/HostingServices.cs b/src/Microsoft.AspNet.Hosting/HostingServices.cs index 1a5ac71846..108e7ae881 100644 --- a/src/Microsoft.AspNet.Hosting/HostingServices.cs +++ b/src/Microsoft.AspNet.Hosting/HostingServices.cs @@ -14,44 +14,66 @@ namespace Microsoft.AspNet.Hosting { public static class HostingServices { - private static IServiceCollection Import(IServiceProvider fallbackProvider) + private static IServiceCollection Import(IServiceProvider fallbackProvider, Action configureHostServices) { var services = new ServiceCollection(); + if (configureHostServices != null) + { + configureHostServices(services); + } var manifest = fallbackProvider.GetRequiredService(); foreach (var service in manifest.Services) { services.AddTransient(service, sp => fallbackProvider.GetService(service)); } + services.AddSingleton(sp => new HostingManifest(services)); return services; } - public static IServiceCollection Create(IConfiguration configuration = null) + public static IServiceCollection Create() { - return Create(CallContextServiceLocator.Locator.ServiceProvider, configuration); + return Create(CallContextServiceLocator.Locator.ServiceProvider, configureHostServices: null, configuration: null); } - public static IServiceCollection Create(IServiceProvider fallbackServices, IConfiguration configuration = null) + public static IServiceCollection Create(IServiceProvider fallbackServices) { - configuration = configuration ?? new Configuration(); - var services = Import(fallbackServices); + return Create(fallbackServices, configureHostServices: null, configuration: null); + } + + public static IServiceCollection Create(IServiceProvider fallbackServices, Action configureHostServices) + { + return Create(fallbackServices, configureHostServices, configuration: null); + } + + public static IServiceCollection Create(Action configureHostServices, IConfiguration configuration) + { + return Create(CallContextServiceLocator.Locator.ServiceProvider, configureHostServices, configuration); + } + + public static IServiceCollection Create(IServiceProvider fallbackServices, IConfiguration configuration) + { + return Create(CallContextServiceLocator.Locator.ServiceProvider, configureHostServices: null, configuration: configuration); + } + + public static IServiceCollection Create(IServiceProvider fallbackServices, Action configureHostServices, IConfiguration configuration) + { + var services = Import(fallbackServices, configureHostServices); services.AddHosting(configuration); - services.AddSingleton(sp => new HostingManifest(fallbackServices)); return services; } // Manifest exposes the fallback manifest in addition to ITypeActivator, IHostingEnvironment, and ILoggerFactory private class HostingManifest : IServiceManifest { - public HostingManifest(IServiceProvider fallback) + public HostingManifest(IServiceCollection hostServices) { - var manifest = fallback.GetRequiredService(); Services = new Type[] { typeof(ITypeActivator), typeof(IHostingEnvironment), typeof(ILoggerFactory), typeof(IHttpContextAccessor), typeof(IApplicationLifetime) - }.Concat(manifest.Services).Distinct(); + }.Concat(hostServices.Select(s => s.ServiceType)).Distinct(); } public IEnumerable Services { get; private set; } diff --git a/src/Microsoft.AspNet.Hosting/HostingServicesCollectionExtensions.cs b/src/Microsoft.AspNet.Hosting/HostingServicesCollectionExtensions.cs index 4ee17b7f65..8510ab43a5 100644 --- a/src/Microsoft.AspNet.Hosting/HostingServicesCollectionExtensions.cs +++ b/src/Microsoft.AspNet.Hosting/HostingServicesCollectionExtensions.cs @@ -11,7 +11,12 @@ namespace Microsoft.Framework.DependencyInjection { public static class HostingServicesExtensions { - public static IServiceCollection AddHosting(this IServiceCollection services, IConfiguration configuration = null) + public static IServiceCollection AddHosting(this IServiceCollection services) + { + return services.AddHosting(configuration: null); + } + + public static IServiceCollection AddHosting(this IServiceCollection services, IConfiguration configuration) { services.TryAdd(ServiceDescriptor.Transient()); services.TryAdd(ServiceDescriptor.Transient()); diff --git a/src/Microsoft.AspNet.TestHost/TestServer.cs b/src/Microsoft.AspNet.TestHost/TestServer.cs index b7350e7c13..2a6517a4a2 100644 --- a/src/Microsoft.AspNet.TestHost/TestServer.cs +++ b/src/Microsoft.AspNet.TestHost/TestServer.cs @@ -47,12 +47,22 @@ namespace Microsoft.AspNet.TestHost public static TestServer Create(Action app) { - return Create(CallContextServiceLocator.Locator.ServiceProvider, app); + return Create(CallContextServiceLocator.Locator.ServiceProvider, app, configureHostServices: null); + } + + public static TestServer Create(Action app, Action configureHostServices) + { + return Create(CallContextServiceLocator.Locator.ServiceProvider, app, configureHostServices); } public static TestServer Create(IServiceProvider serviceProvider, Action app) { - var appServices = HostingServices.Create(serviceProvider).BuildServiceProvider(); + return Create(serviceProvider, app, configureHostServices: null); + } + + public static TestServer Create(IServiceProvider serviceProvider, Action app, Action configureHostServices) + { + var appServices = HostingServices.Create(serviceProvider, configureHostServices).BuildServiceProvider(); var config = new Configuration(); return new TestServer(config, appServices, app); } diff --git a/test/Microsoft.AspNet.Hosting.Tests/HostingServicesFacts.cs b/test/Microsoft.AspNet.Hosting.Tests/HostingServicesFacts.cs index 1ccae6b90c..ebb9b1a892 100644 --- a/test/Microsoft.AspNet.Hosting.Tests/HostingServicesFacts.cs +++ b/test/Microsoft.AspNet.Hosting.Tests/HostingServicesFacts.cs @@ -52,6 +52,73 @@ namespace Microsoft.AspNet.Hosting.Tests Assert.Null(provider.GetService()); // Make sure we don't leak non manifest services } + [Fact] + public void CreateCanAddAdditionalServices() + { + // Arrange + var fallbackServices = new ServiceCollection(); + fallbackServices.AddTransient(); + fallbackServices.AddTransient(); // Don't register in manifest + + fallbackServices.AddInstance(new ServiceManifest( + new Type[] { + typeof(IFakeService), + })); + + var instance = new FakeService(); + var factoryInstance = new FakeFactoryService(instance); + + var services = HostingServices.Create(fallbackServices.BuildServiceProvider(), + additionalHostServices => + { + additionalHostServices.AddSingleton(); + additionalHostServices.AddInstance(instance); + additionalHostServices.AddSingleton(serviceProvider => factoryInstance); + }); + + // Act + var provider = services.BuildServiceProvider(); + var singleton = provider.GetRequiredService(); + var transient = provider.GetRequiredService(); + var factory = provider.GetRequiredService(); + var manifest = provider.GetRequiredService(); + + // Assert + Assert.Same(singleton, provider.GetRequiredService()); + Assert.NotSame(transient, provider.GetRequiredService()); + Assert.Same(instance, provider.GetRequiredService()); + Assert.Same(factoryInstance, factory); + Assert.Same(factory.FakeService, instance); + Assert.Null(provider.GetService()); + Assert.Null(provider.GetService()); // Make sure we don't leak non manifest services + Assert.Contains(typeof(IFakeSingletonService), manifest.Services); + Assert.Contains(typeof(IFakeServiceInstance), manifest.Services); + Assert.Contains(typeof(IFactoryService), manifest.Services); + } + + [Fact] + public void CreateAdditionalServicesDoNotOverrideFallback() + { + // Arrange + var fallbackServices = new ServiceCollection(); + fallbackServices.AddTransient(); + + fallbackServices.AddInstance(new ServiceManifest( + new Type[] { + typeof(IFakeService), + })); + + var services = HostingServices.Create(fallbackServices.BuildServiceProvider(), + additionalHostServices => additionalHostServices.AddSingleton()); + + // Act + var provider = services.BuildServiceProvider(); + var stillTransient = provider.GetRequiredService(); + + // Assert + Assert.NotSame(stillTransient, provider.GetRequiredService()); + } + [Fact] public void CanHideImportedServices() { diff --git a/test/Microsoft.AspNet.Hosting.Tests/Microsoft.AspNet.Hosting.Tests.kproj b/test/Microsoft.AspNet.Hosting.Tests/Microsoft.AspNet.Hosting.Tests.kproj index 766519fb73..59c5545a1f 100644 --- a/test/Microsoft.AspNet.Hosting.Tests/Microsoft.AspNet.Hosting.Tests.kproj +++ b/test/Microsoft.AspNet.Hosting.Tests/Microsoft.AspNet.Hosting.Tests.kproj @@ -12,7 +12,7 @@ 2.0 - 23533 + 18007 \ No newline at end of file diff --git a/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs b/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs index c9ef3bdf96..899d3be45a 100644 --- a/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs +++ b/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs @@ -10,6 +10,7 @@ using Microsoft.AspNet.Builder; using Microsoft.AspNet.Hosting; using Microsoft.AspNet.Http; using Microsoft.Framework.DependencyInjection; +using Microsoft.Framework.Logging; using Xunit; namespace Microsoft.AspNet.TestHost @@ -36,10 +37,25 @@ namespace Microsoft.AspNet.TestHost Assert.Throws(() => TestServer.Create(services, new Startup().Configuration)); } + [Fact] + public async Task CanAccessLogger() + { + TestServer server = TestServer.Create(app => + { + app.Run(context => + { + var logger = app.ApplicationServices.GetRequiredService>(); + return context.Response.WriteAsync("FoundLogger:" + (logger != null)); + }); + }); + + string result = await server.CreateClient().GetStringAsync("/path"); + Assert.Equal("FoundLogger:True", result); + } + [Fact] public async Task CanAccessHttpContext() { - var services = new ServiceCollection().BuildServiceProvider(); TestServer server = TestServer.Create(app => { app.Run(context => @@ -53,6 +69,35 @@ namespace Microsoft.AspNet.TestHost Assert.Equal("HasContext:True", result); } + public class ContextHolder + { + public ContextHolder(IHttpContextAccessor accessor) + { + Accessor = accessor; + } + + public IHttpContextAccessor Accessor { get; set; } + } + + [Fact] + public async Task CanAddNewHostServices() + { + TestServer server = TestServer.Create(app => + { + var a = app.ApplicationServices.GetRequiredService(); + + app.Run(context => + { + var b = app.ApplicationServices.GetRequiredService(); + var accessor = app.ApplicationServices.GetRequiredService(); + return context.Response.WriteAsync("HasContext:" + (accessor.Accessor.HttpContext != null)); + }); + }, newHostServices => newHostServices.AddSingleton()); + + string result = await server.CreateClient().GetStringAsync("/path"); + Assert.Equal("HasContext:True", result); + } + [Fact] public async Task CreateInvokesApp() {