Check for non-null RequestServices (#1378)
This commit is contained in:
parent
5fd1f9e0e5
commit
a8c0970cde
|
|
@ -34,12 +34,16 @@ namespace Microsoft.AspNetCore.Hosting.Internal
|
|||
{
|
||||
Debug.Assert(httpContext != null);
|
||||
|
||||
// local cache for virtual disptach result
|
||||
var features = httpContext.Features;
|
||||
var servicesFeature = features.Get<IServiceProvidersFeature>();
|
||||
|
||||
var servicesFeature = new RequestServicesFeature(httpContext, _scopeFactory);
|
||||
// All done if RequestServices is set
|
||||
if (servicesFeature?.RequestServices != null)
|
||||
{
|
||||
return _next.Invoke(httpContext);
|
||||
}
|
||||
|
||||
features.Set<IServiceProvidersFeature>(servicesFeature);
|
||||
features.Set<IServiceProvidersFeature>(new RequestServicesFeature(httpContext, _scopeFactory));
|
||||
return _next.Invoke(httpContext);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,122 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Reflection;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Builder.Internal;
|
||||
using Microsoft.AspNetCore.Hosting.Internal;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.Hosting.Tests
|
||||
{
|
||||
public class RequestServicesContainerMiddlewareTests
|
||||
{
|
||||
[Fact]
|
||||
public async Task RequestServicesAreSet()
|
||||
{
|
||||
var serviceProvider = new ServiceCollection()
|
||||
.BuildServiceProvider();
|
||||
|
||||
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
|
||||
|
||||
var middleware = new RequestServicesContainerMiddleware(
|
||||
ctx => Task.CompletedTask,
|
||||
scopeFactory);
|
||||
|
||||
var context = new DefaultHttpContext();
|
||||
await middleware.Invoke(context);
|
||||
|
||||
Assert.NotNull(context.RequestServices);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RequestServicesAreNotOverwrittenIfAlreadySet()
|
||||
{
|
||||
var serviceProvider = new ServiceCollection()
|
||||
.BuildServiceProvider();
|
||||
|
||||
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
|
||||
|
||||
var middleware = new RequestServicesContainerMiddleware(
|
||||
ctx => Task.CompletedTask,
|
||||
scopeFactory);
|
||||
|
||||
var context = new DefaultHttpContext();
|
||||
context.RequestServices = serviceProvider;
|
||||
await middleware.Invoke(context);
|
||||
|
||||
Assert.Same(serviceProvider, context.RequestServices);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task RequestServicesAreDisposedOnCompleted()
|
||||
{
|
||||
var serviceProvider = new ServiceCollection()
|
||||
.AddTransient<DisposableThing>()
|
||||
.BuildServiceProvider();
|
||||
|
||||
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
|
||||
DisposableThing instance = null;
|
||||
|
||||
var middleware = new RequestServicesContainerMiddleware(
|
||||
ctx =>
|
||||
{
|
||||
instance = ctx.RequestServices.GetRequiredService<DisposableThing>();
|
||||
return Task.CompletedTask;
|
||||
},
|
||||
scopeFactory);
|
||||
|
||||
var context = new DefaultHttpContext();
|
||||
var responseFeature = new TestHttpResponseFeature();
|
||||
context.Features.Set<IHttpResponseFeature>(responseFeature);
|
||||
|
||||
await middleware.Invoke(context);
|
||||
|
||||
Assert.NotNull(context.RequestServices);
|
||||
Assert.Single(responseFeature.CompletedCallbacks);
|
||||
|
||||
var callback = responseFeature.CompletedCallbacks[0];
|
||||
await callback.callback(callback.state);
|
||||
|
||||
Assert.Null(context.RequestServices);
|
||||
Assert.True(instance.Disposed);
|
||||
}
|
||||
|
||||
private class DisposableThing : IDisposable
|
||||
{
|
||||
public bool Disposed { get; set; }
|
||||
public void Dispose()
|
||||
{
|
||||
Disposed = true;
|
||||
}
|
||||
}
|
||||
|
||||
private class TestHttpResponseFeature : IHttpResponseFeature
|
||||
{
|
||||
public List<(Func<object, Task> callback, object state)> CompletedCallbacks = new List<(Func<object, Task> callback, object state)>();
|
||||
|
||||
public int StatusCode { get; set; }
|
||||
public string ReasonPhrase { get; set; }
|
||||
public IHeaderDictionary Headers { get; set; } = new HeaderDictionary();
|
||||
public Stream Body { get; set; }
|
||||
|
||||
public bool HasStarted => false;
|
||||
|
||||
public void OnCompleted(Func<object, Task> callback, object state)
|
||||
{
|
||||
CompletedCallbacks.Add((callback, state));
|
||||
}
|
||||
|
||||
public void OnStarting(Func<object, Task> callback, object state)
|
||||
{
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue