Nuke RequestServicesContainer (inline instead)

This commit is contained in:
Hao Kung 2015-06-23 13:44:06 -07:00
parent 0013d44167
commit ee8baab1ed
4 changed files with 82 additions and 119 deletions

View File

@ -1,114 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNet.Http;
using Microsoft.Framework.DependencyInjection;
namespace Microsoft.AspNet.Hosting.Internal
{
public class RequestServicesContainer : IDisposable
{
public RequestServicesContainer(
HttpContext context,
IServiceScopeFactory scopeFactory,
IServiceProvider appServiceProvider)
{
if (scopeFactory == null)
{
throw new ArgumentNullException(nameof(scopeFactory));
}
if (context == null)
{
throw new ArgumentNullException(nameof(context));
}
Context = context;
PriorAppServices = context.ApplicationServices;
PriorRequestServices = context.RequestServices;
// Begin the scope
Scope = scopeFactory.CreateScope();
Context.ApplicationServices = appServiceProvider;
Context.RequestServices = Scope.ServiceProvider;
}
private HttpContext Context { get; set; }
private IServiceProvider PriorAppServices { get; set; }
private IServiceProvider PriorRequestServices { get; set; }
private IServiceScope Scope { get; set; }
// CONSIDER: this could be an extension method on HttpContext instead
public static RequestServicesContainer EnsureRequestServices(HttpContext httpContext, IServiceProvider services)
{
// All done if we already have a request services
if (httpContext.RequestServices != null)
{
return null;
}
var serviceProvider = httpContext.ApplicationServices ?? services;
if (serviceProvider == null)
{
throw new InvalidOperationException("TODO: services and httpContext.ApplicationServices are both null!");
}
// Matches constructor of RequestContainer
var rootServiceProvider = serviceProvider.GetRequiredService<IServiceProvider>();
var rootServiceScopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
// Pre Scope setup
var priorApplicationServices = serviceProvider;
var priorRequestServices = serviceProvider;
var appServiceProvider = rootServiceProvider;
var appServiceScopeFactory = rootServiceScopeFactory;
if (priorApplicationServices != null &&
priorApplicationServices != appServiceProvider)
{
appServiceProvider = priorApplicationServices;
appServiceScopeFactory = priorApplicationServices.GetRequiredService<IServiceScopeFactory>();
}
// Creates the scope and does the service swaps
return new RequestServicesContainer(httpContext, appServiceScopeFactory, appServiceProvider);
}
#region IDisposable Support
private bool disposedValue = false; // To detect redundant calls
protected virtual void Dispose(bool disposing)
{
if (!disposedValue)
{
if (disposing)
{
Context.RequestServices = PriorRequestServices;
Context.ApplicationServices = PriorAppServices;
}
if (Scope != null)
{
Scope.Dispose();
Scope = null;
}
Context = null;
PriorAppServices = null;
PriorRequestServices = null;
disposedValue = true;
}
}
// This code added to correctly implement the disposable pattern.
public void Dispose()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
Dispose(true);
}
#endregion
}
}

View File

@ -5,6 +5,8 @@ using System;
using System.Threading.Tasks;
using Microsoft.AspNet.Builder;
using Microsoft.AspNet.Http;
using Microsoft.Framework.DependencyInjection;
using Microsoft.Framework.Internal;
namespace Microsoft.AspNet.Hosting.Internal
{
@ -13,18 +15,41 @@ namespace Microsoft.AspNet.Hosting.Internal
private readonly RequestDelegate _next;
private readonly IServiceProvider _services;
public RequestServicesContainerMiddleware(RequestDelegate next, IServiceProvider services)
public RequestServicesContainerMiddleware([NotNull] RequestDelegate next, [NotNull] IServiceProvider services)
{
_services = services;
_next = next;
}
public async Task Invoke(HttpContext httpContext)
public async Task Invoke([NotNull] HttpContext httpContext)
{
using (var container = RequestServicesContainer.EnsureRequestServices(httpContext, _services))
// All done if there request services is set
if (httpContext.RequestServices != null)
{
await _next.Invoke(httpContext);
return;
}
var priorApplicationServices = httpContext.ApplicationServices;
var serviceProvider = priorApplicationServices ?? _services;
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
try
{
// Creates the scope and temporarily swap services
using (var scope = scopeFactory.CreateScope())
{
httpContext.ApplicationServices = serviceProvider;
httpContext.RequestServices = scope.ServiceProvider;
await _next.Invoke(httpContext);
}
}
finally
{
httpContext.RequestServices = null;
httpContext.ApplicationServices = priorApplicationServices;
}
}
}
}
}

View File

@ -8,6 +8,7 @@ using System.Threading.Tasks;
using Microsoft.AspNet.Builder;
using Microsoft.AspNet.FeatureModel;
using Microsoft.AspNet.Hosting.Builder;
using Microsoft.AspNet.Hosting.Fakes;
using Microsoft.AspNet.Hosting.Internal;
using Microsoft.AspNet.Hosting.Server;
using Microsoft.AspNet.Hosting.Startup;

View File

@ -8,6 +8,7 @@ using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNet.Builder;
using Microsoft.AspNet.Hosting;
using Microsoft.AspNet.Hosting.Startup;
using Microsoft.AspNet.Http;
using Microsoft.Framework.Configuration;
using Microsoft.Framework.DependencyInjection;
@ -55,10 +56,60 @@ namespace Microsoft.AspNet.TestHost
Assert.Equal("RequestServices:True", result);
}
public class TestService { }
public class TestRequestServiceMiddleware
{
private RequestDelegate _next;
public TestRequestServiceMiddleware(RequestDelegate next)
{
_next = next;
}
public Task Invoke(HttpContext httpContext)
{
var services = new ServiceCollection();
services.AddTransient<TestService>();
httpContext.RequestServices = services.BuildServiceProvider();
return _next.Invoke(httpContext);
}
}
public class RequestServicesFilter : IStartupFilter
{
public Action<IApplicationBuilder> Configure(IApplicationBuilder app, Action<IApplicationBuilder> next)
{
return builder =>
{
app.UseMiddleware<TestRequestServiceMiddleware>();
next(builder);
};
}
}
[Fact]
public async Task ExistingRequestServicesWillNotBeReplaced()
{
var server = TestServer.Create(app =>
{
app.Run(context =>
{
var service = context.RequestServices.GetService<TestService>();
return context.Response.WriteAsync("Found:" + (service != null));
});
},
services => services.AddTransient<IStartupFilter, RequestServicesFilter>());
string result = await server.CreateClient().GetStringAsync("/path");
Assert.Equal("Found:True", result);
}
[Fact]
public async Task CanAccessLogger()
{
TestServer server = TestServer.Create(app =>
var server = TestServer.Create(app =>
{
app.Run(context =>
{