Check for non-null RequestServices (#1378)

This commit is contained in:
David Fowler 2018-04-13 09:45:38 -07:00 committed by GitHub
parent 5fd1f9e0e5
commit a8c0970cde
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 129 additions and 3 deletions

View File

@ -34,12 +34,16 @@ namespace Microsoft.AspNetCore.Hosting.Internal
{
Debug.Assert(httpContext != null);
// local cache for virtual disptach result
var features = httpContext.Features;
var servicesFeature = features.Get<IServiceProvidersFeature>();
var servicesFeature = new RequestServicesFeature(httpContext, _scopeFactory);
// All done if RequestServices is set
if (servicesFeature?.RequestServices != null)
{
return _next.Invoke(httpContext);
}
features.Set<IServiceProvidersFeature>(servicesFeature);
features.Set<IServiceProvidersFeature>(new RequestServicesFeature(httpContext, _scopeFactory));
return _next.Invoke(httpContext);
}
}

View File

@ -0,0 +1,122 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Builder.Internal;
using Microsoft.AspNetCore.Hosting.Internal;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.DependencyInjection;
using Xunit;
namespace Microsoft.AspNetCore.Hosting.Tests
{
public class RequestServicesContainerMiddlewareTests
{
[Fact]
public async Task RequestServicesAreSet()
{
var serviceProvider = new ServiceCollection()
.BuildServiceProvider();
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
var middleware = new RequestServicesContainerMiddleware(
ctx => Task.CompletedTask,
scopeFactory);
var context = new DefaultHttpContext();
await middleware.Invoke(context);
Assert.NotNull(context.RequestServices);
}
[Fact]
public async Task RequestServicesAreNotOverwrittenIfAlreadySet()
{
var serviceProvider = new ServiceCollection()
.BuildServiceProvider();
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
var middleware = new RequestServicesContainerMiddleware(
ctx => Task.CompletedTask,
scopeFactory);
var context = new DefaultHttpContext();
context.RequestServices = serviceProvider;
await middleware.Invoke(context);
Assert.Same(serviceProvider, context.RequestServices);
}
[Fact]
public async Task RequestServicesAreDisposedOnCompleted()
{
var serviceProvider = new ServiceCollection()
.AddTransient<DisposableThing>()
.BuildServiceProvider();
var scopeFactory = serviceProvider.GetRequiredService<IServiceScopeFactory>();
DisposableThing instance = null;
var middleware = new RequestServicesContainerMiddleware(
ctx =>
{
instance = ctx.RequestServices.GetRequiredService<DisposableThing>();
return Task.CompletedTask;
},
scopeFactory);
var context = new DefaultHttpContext();
var responseFeature = new TestHttpResponseFeature();
context.Features.Set<IHttpResponseFeature>(responseFeature);
await middleware.Invoke(context);
Assert.NotNull(context.RequestServices);
Assert.Single(responseFeature.CompletedCallbacks);
var callback = responseFeature.CompletedCallbacks[0];
await callback.callback(callback.state);
Assert.Null(context.RequestServices);
Assert.True(instance.Disposed);
}
private class DisposableThing : IDisposable
{
public bool Disposed { get; set; }
public void Dispose()
{
Disposed = true;
}
}
private class TestHttpResponseFeature : IHttpResponseFeature
{
public List<(Func<object, Task> callback, object state)> CompletedCallbacks = new List<(Func<object, Task> callback, object state)>();
public int StatusCode { get; set; }
public string ReasonPhrase { get; set; }
public IHeaderDictionary Headers { get; set; } = new HeaderDictionary();
public Stream Body { get; set; }
public bool HasStarted => false;
public void OnCompleted(Func<object, Task> callback, object state)
{
CompletedCallbacks.Add((callback, state));
}
public void OnStarting(Func<object, Task> callback, object state)
{
}
}
}
}