Switch to IServiceProvidersFeature for RequestServices

This commit is contained in:
Hao Kung 2015-09-23 13:33:07 -07:00
parent 285da613e4
commit 49520a2a73
3 changed files with 182 additions and 18 deletions

View File

@ -0,0 +1,70 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNet.Http.Features.Internal;
using Microsoft.Framework.DependencyInjection;
namespace Microsoft.AspNet.Hosting.Internal
{
public class RequestServicesFeature : IServiceProvidersFeature, IDisposable
{
private IServiceProvider _appServices;
private IServiceProvider _requestServices;
private IServiceScope _scope;
private bool _requestServicesSet;
public RequestServicesFeature(IServiceProvider applicationServices)
{
if (applicationServices == null)
{
throw new ArgumentNullException(nameof(applicationServices));
}
ApplicationServices = applicationServices;
}
public IServiceProvider ApplicationServices
{
get
{
return _appServices;
}
set
{
if (value == null)
{
throw new ArgumentNullException(nameof(value));
}
_appServices = value;
}
}
public IServiceProvider RequestServices
{
get
{
if (!_requestServicesSet)
{
_scope = ApplicationServices.GetRequiredService<IServiceScopeFactory>().CreateScope();
_requestServices = _scope.ServiceProvider;
_requestServicesSet = true;
}
return _requestServices;
}
set
{
_requestServicesSet = true;
RequestServices = value;
}
}
public void Dispose()
{
_scope?.Dispose();
_scope = null;
_requestServices = null;
}
}
}

View File

@ -5,6 +5,8 @@ using System;
using System.Threading.Tasks;
using Microsoft.AspNet.Builder;
using Microsoft.AspNet.Http;
using Microsoft.AspNet.Http.Features;
using Microsoft.AspNet.Http.Features.Internal;
using Microsoft.Framework.DependencyInjection;
namespace Microsoft.AspNet.Hosting.Internal
@ -20,7 +22,6 @@ namespace Microsoft.AspNet.Hosting.Internal
{
throw new ArgumentNullException(nameof(next));
}
if (services == null)
{
throw new ArgumentNullException(nameof(services));
@ -37,32 +38,26 @@ namespace Microsoft.AspNet.Hosting.Internal
throw new ArgumentNullException(nameof(httpContext));
}
// All done if there request services is set
if (httpContext.RequestServices != null)
var existingFeature = httpContext.Features.Get<IServiceProvidersFeature>();
// All done if request services is set
if (existingFeature?.RequestServices != null)
{
await _next.Invoke(httpContext);
return;
}
var priorApplicationServices = httpContext.ApplicationServices;
var serviceProvider = priorApplicationServices ?? _services;
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
try
using (var feature = new RequestServicesFeature(_services))
{
// Creates the scope and temporarily swap services
using (var scope = scopeFactory.CreateScope())
try
{
httpContext.ApplicationServices = serviceProvider;
httpContext.RequestServices = scope.ServiceProvider;
httpContext.Features.Set<IServiceProvidersFeature>(feature);
await _next.Invoke(httpContext);
}
}
finally
{
httpContext.RequestServices = null;
httpContext.ApplicationServices = priorApplicationServices;
finally
{
httpContext.Features.Set(existingFeature);
}
}
}
}

View File

@ -11,6 +11,8 @@ using Microsoft.AspNet.Builder;
using Microsoft.AspNet.Hosting;
using Microsoft.AspNet.Hosting.Startup;
using Microsoft.AspNet.Http;
using Microsoft.AspNet.Http.Features;
using Microsoft.AspNet.Http.Features.Internal;
using Microsoft.Framework.Configuration;
using Microsoft.Framework.DependencyInjection;
using Microsoft.Framework.Logging;
@ -133,6 +135,103 @@ namespace Microsoft.AspNet.TestHost
Assert.Equal("Found:True", result);
}
[Fact]
public async Task SettingApplicationServicesOnFeatureToNullThrows()
{
var server = TestServer.Create(app =>
{
app.Run(context =>
{
var feature = context.Features.Get<IServiceProvidersFeature>();
Assert.Throws<ArgumentNullException>(() => feature.ApplicationServices = null);
return context.Response.WriteAsync("Success");
});
});
string result = await server.CreateClient().GetStringAsync("/path");
Assert.Equal("Success", result);
}
public class ReplaceServiceProvidersFeatureFilter : IStartupFilter, IServiceProvidersFeature
{
public ReplaceServiceProvidersFeatureFilter(IServiceProvider appServices, IServiceProvider requestServices)
{
ApplicationServices = appServices;
RequestServices = requestServices;
}
public IServiceProvider ApplicationServices { get; set; }
public IServiceProvider RequestServices { get; set; }
public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)
{
return app =>
{
app.Use(async (context, nxt) =>
{
context.Features.Set<IServiceProvidersFeature>(this);
await nxt();
});
next(app);
};
}
}
[Fact]
public async Task ExistingServiceProviderFeatureWillNotBeReplaced()
{
var appServices = new ServiceCollection().BuildServiceProvider();
var server = TestServer.Create(app =>
{
app.Run(context =>
{
Assert.Equal(appServices, context.ApplicationServices);
Assert.Equal(appServices, context.RequestServices);
return context.Response.WriteAsync("Success");
});
},
services => services.AddInstance<IStartupFilter>(new ReplaceServiceProvidersFeatureFilter(appServices, appServices)));
var result = await server.CreateClient().GetStringAsync("/path");
Assert.Equal("Success", result);
}
public class NullServiceProvidersFeatureFilter : IStartupFilter, IServiceProvidersFeature
{
public IServiceProvider ApplicationServices { get; set; }
public IServiceProvider RequestServices { get; set; }
public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)
{
return app =>
{
app.Use(async (context, nxt) =>
{
context.Features.Set<IServiceProvidersFeature>(this);
await nxt();
});
next(app);
};
}
}
[Fact]
public async Task WillReplaceServiceProviderFeatureWithNullRequestServices()
{
var server = TestServer.Create(app =>
{
app.Run(context =>
{
Assert.NotNull(context.ApplicationServices);
Assert.NotNull(context.RequestServices);
return context.Response.WriteAsync("Success");
});
},
services => services.AddTransient<IStartupFilter, NullServiceProvidersFeatureFilter>());
var result = await server.CreateClient().GetStringAsync("/path");
Assert.Equal("Success", result);
}
public class EnsureApplicationServicesFilter : IStartupFilter
{
public Action<IApplicationBuilder> Configure(Action<IApplicationBuilder> next)