diff --git a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainer.cs b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainer.cs deleted file mode 100644 index a4ab07d546..0000000000 --- a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainer.cs +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using Microsoft.AspNet.Http; -using Microsoft.Framework.DependencyInjection; - -namespace Microsoft.AspNet.Hosting.Internal -{ - public class RequestServicesContainer : IDisposable - { - public RequestServicesContainer( - HttpContext context, - IServiceScopeFactory scopeFactory, - IServiceProvider appServiceProvider) - { - if (scopeFactory == null) - { - throw new ArgumentNullException(nameof(scopeFactory)); - } - if (context == null) - { - throw new ArgumentNullException(nameof(context)); - } - - Context = context; - PriorAppServices = context.ApplicationServices; - PriorRequestServices = context.RequestServices; - - // Begin the scope - Scope = scopeFactory.CreateScope(); - - Context.ApplicationServices = appServiceProvider; - Context.RequestServices = Scope.ServiceProvider; - } - - private HttpContext Context { get; set; } - private IServiceProvider PriorAppServices { get; set; } - private IServiceProvider PriorRequestServices { get; set; } - private IServiceScope Scope { get; set; } - - // CONSIDER: this could be an extension method on HttpContext instead - public static RequestServicesContainer EnsureRequestServices(HttpContext httpContext, IServiceProvider services) - { - // All done if we already have a request services - if (httpContext.RequestServices != null) - { - return null; - } - - var serviceProvider = httpContext.ApplicationServices ?? services; - if (serviceProvider == null) - { - throw new InvalidOperationException("TODO: services and httpContext.ApplicationServices are both null!"); - } - - // Matches constructor of RequestContainer - var rootServiceProvider = serviceProvider.GetRequiredService(); - var rootServiceScopeFactory = serviceProvider.GetRequiredService(); - - // Pre Scope setup - var priorApplicationServices = serviceProvider; - var priorRequestServices = serviceProvider; - - var appServiceProvider = rootServiceProvider; - var appServiceScopeFactory = rootServiceScopeFactory; - - if (priorApplicationServices != null && - priorApplicationServices != appServiceProvider) - { - appServiceProvider = priorApplicationServices; - appServiceScopeFactory = priorApplicationServices.GetRequiredService(); - } - - // Creates the scope and does the service swaps - return new RequestServicesContainer(httpContext, appServiceScopeFactory, appServiceProvider); - } - -#region IDisposable Support - private bool disposedValue = false; // To detect redundant calls - - protected virtual void Dispose(bool disposing) - { - if (!disposedValue) - { - if (disposing) - { - Context.RequestServices = PriorRequestServices; - Context.ApplicationServices = PriorAppServices; - } - - if (Scope != null) - { - Scope.Dispose(); - Scope = null; - } - - Context = null; - PriorAppServices = null; - PriorRequestServices = null; - - disposedValue = true; - } - } - - // This code added to correctly implement the disposable pattern. - public void Dispose() - { - // Do not change this code. Put cleanup code in Dispose(bool disposing) above. - Dispose(true); - } -#endregion - } -} \ No newline at end of file diff --git a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs index 0a0d169dcb..9c740ba7f1 100644 --- a/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs +++ b/src/Microsoft.AspNet.Hosting/Internal/RequestServicesContainerMiddleware.cs @@ -5,6 +5,8 @@ using System; using System.Threading.Tasks; using Microsoft.AspNet.Builder; using Microsoft.AspNet.Http; +using Microsoft.Framework.DependencyInjection; +using Microsoft.Framework.Internal; namespace Microsoft.AspNet.Hosting.Internal { @@ -13,18 +15,41 @@ namespace Microsoft.AspNet.Hosting.Internal private readonly RequestDelegate _next; private readonly IServiceProvider _services; - public RequestServicesContainerMiddleware(RequestDelegate next, IServiceProvider services) + public RequestServicesContainerMiddleware([NotNull] RequestDelegate next, [NotNull] IServiceProvider services) { _services = services; _next = next; } - public async Task Invoke(HttpContext httpContext) + public async Task Invoke([NotNull] HttpContext httpContext) { - using (var container = RequestServicesContainer.EnsureRequestServices(httpContext, _services)) + // All done if there request services is set + if (httpContext.RequestServices != null) { await _next.Invoke(httpContext); + return; + } + + var priorApplicationServices = httpContext.ApplicationServices; + var serviceProvider = priorApplicationServices ?? _services; + var scopeFactory = serviceProvider.GetRequiredService(); + + try + { + // Creates the scope and temporarily swap services + using (var scope = scopeFactory.CreateScope()) + { + httpContext.ApplicationServices = serviceProvider; + httpContext.RequestServices = scope.ServiceProvider; + + await _next.Invoke(httpContext); + } + } + finally + { + httpContext.RequestServices = null; + httpContext.ApplicationServices = priorApplicationServices; } } } -} +} \ No newline at end of file diff --git a/test/Microsoft.AspNet.Hosting.Tests/HostingEngineTests.cs b/test/Microsoft.AspNet.Hosting.Tests/HostingEngineTests.cs index de7782c4dc..b9f0d6a442 100644 --- a/test/Microsoft.AspNet.Hosting.Tests/HostingEngineTests.cs +++ b/test/Microsoft.AspNet.Hosting.Tests/HostingEngineTests.cs @@ -8,6 +8,7 @@ using System.Threading.Tasks; using Microsoft.AspNet.Builder; using Microsoft.AspNet.FeatureModel; using Microsoft.AspNet.Hosting.Builder; +using Microsoft.AspNet.Hosting.Fakes; using Microsoft.AspNet.Hosting.Internal; using Microsoft.AspNet.Hosting.Server; using Microsoft.AspNet.Hosting.Startup; diff --git a/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs b/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs index 23ae59567b..d1dbcb93e1 100644 --- a/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs +++ b/test/Microsoft.AspNet.TestHost.Tests/TestServerTests.cs @@ -8,6 +8,7 @@ using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNet.Builder; using Microsoft.AspNet.Hosting; +using Microsoft.AspNet.Hosting.Startup; using Microsoft.AspNet.Http; using Microsoft.Framework.Configuration; using Microsoft.Framework.DependencyInjection; @@ -55,10 +56,60 @@ namespace Microsoft.AspNet.TestHost Assert.Equal("RequestServices:True", result); } + public class TestService { } + + public class TestRequestServiceMiddleware + { + private RequestDelegate _next; + + public TestRequestServiceMiddleware(RequestDelegate next) + { + _next = next; + } + + public Task Invoke(HttpContext httpContext) + { + var services = new ServiceCollection(); + services.AddTransient(); + httpContext.RequestServices = services.BuildServiceProvider(); + + return _next.Invoke(httpContext); + } + } + + public class RequestServicesFilter : IStartupFilter + { + public Action Configure(IApplicationBuilder app, Action next) + { + return builder => + { + app.UseMiddleware(); + next(builder); + }; + } + } + + [Fact] + public async Task ExistingRequestServicesWillNotBeReplaced() + { + var server = TestServer.Create(app => + { + app.Run(context => + { + var service = context.RequestServices.GetService(); + return context.Response.WriteAsync("Found:" + (service != null)); + }); + }, + services => services.AddTransient()); + string result = await server.CreateClient().GetStringAsync("/path"); + Assert.Equal("Found:True", result); + } + + [Fact] public async Task CanAccessLogger() { - TestServer server = TestServer.Create(app => + var server = TestServer.Create(app => { app.Run(context => {