From a8c0970cde72242c40534c6c56a99f1d2ee6827a Mon Sep 17 00:00:00 2001 From: David Fowler Date: Fri, 13 Apr 2018 09:45:38 -0700 Subject: [PATCH] Check for non-null RequestServices (#1378) --- .../RequestServicesContainerMiddleware.cs | 10 +- ...RequestServicesContainerMiddlewareTests.cs | 122 ++++++++++++++++++ 2 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 test/Microsoft.AspNetCore.Hosting.Tests/RequestServicesContainerMiddlewareTests.cs diff --git a/src/Microsoft.AspNetCore.Hosting/Internal/RequestServicesContainerMiddleware.cs b/src/Microsoft.AspNetCore.Hosting/Internal/RequestServicesContainerMiddleware.cs index 22b034ee34..9fcb01aa12 100644 --- a/src/Microsoft.AspNetCore.Hosting/Internal/RequestServicesContainerMiddleware.cs +++ b/src/Microsoft.AspNetCore.Hosting/Internal/RequestServicesContainerMiddleware.cs @@ -34,12 +34,16 @@ namespace Microsoft.AspNetCore.Hosting.Internal { Debug.Assert(httpContext != null); - // local cache for virtual disptach result var features = httpContext.Features; + var servicesFeature = features.Get(); - var servicesFeature = new RequestServicesFeature(httpContext, _scopeFactory); + // All done if RequestServices is set + if (servicesFeature?.RequestServices != null) + { + return _next.Invoke(httpContext); + } - features.Set(servicesFeature); + features.Set(new RequestServicesFeature(httpContext, _scopeFactory)); return _next.Invoke(httpContext); } } diff --git a/test/Microsoft.AspNetCore.Hosting.Tests/RequestServicesContainerMiddlewareTests.cs b/test/Microsoft.AspNetCore.Hosting.Tests/RequestServicesContainerMiddlewareTests.cs new file mode 100644 index 0000000000..c153af4dc1 --- /dev/null +++ b/test/Microsoft.AspNetCore.Hosting.Tests/RequestServicesContainerMiddlewareTests.cs @@ -0,0 +1,122 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.IO; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.AspNetCore.Hosting.Internal; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Hosting.Tests +{ + public class RequestServicesContainerMiddlewareTests + { + [Fact] + public async Task RequestServicesAreSet() + { + var serviceProvider = new ServiceCollection() + .BuildServiceProvider(); + + var scopeFactory = serviceProvider.GetRequiredService(); + + var middleware = new RequestServicesContainerMiddleware( + ctx => Task.CompletedTask, + scopeFactory); + + var context = new DefaultHttpContext(); + await middleware.Invoke(context); + + Assert.NotNull(context.RequestServices); + } + + [Fact] + public async Task RequestServicesAreNotOverwrittenIfAlreadySet() + { + var serviceProvider = new ServiceCollection() + .BuildServiceProvider(); + + var scopeFactory = serviceProvider.GetRequiredService(); + + var middleware = new RequestServicesContainerMiddleware( + ctx => Task.CompletedTask, + scopeFactory); + + var context = new DefaultHttpContext(); + context.RequestServices = serviceProvider; + await middleware.Invoke(context); + + Assert.Same(serviceProvider, context.RequestServices); + } + + [Fact] + public async Task RequestServicesAreDisposedOnCompleted() + { + var serviceProvider = new ServiceCollection() + .AddTransient() + .BuildServiceProvider(); + + var scopeFactory = serviceProvider.GetRequiredService(); + DisposableThing instance = null; + + var middleware = new RequestServicesContainerMiddleware( + ctx => + { + instance = ctx.RequestServices.GetRequiredService(); + return Task.CompletedTask; + }, + scopeFactory); + + var context = new DefaultHttpContext(); + var responseFeature = new TestHttpResponseFeature(); + context.Features.Set(responseFeature); + + await middleware.Invoke(context); + + Assert.NotNull(context.RequestServices); + Assert.Single(responseFeature.CompletedCallbacks); + + var callback = responseFeature.CompletedCallbacks[0]; + await callback.callback(callback.state); + + Assert.Null(context.RequestServices); + Assert.True(instance.Disposed); + } + + private class DisposableThing : IDisposable + { + public bool Disposed { get; set; } + public void Dispose() + { + Disposed = true; + } + } + + private class TestHttpResponseFeature : IHttpResponseFeature + { + public List<(Func callback, object state)> CompletedCallbacks = new List<(Func callback, object state)>(); + + public int StatusCode { get; set; } + public string ReasonPhrase { get; set; } + public IHeaderDictionary Headers { get; set; } = new HeaderDictionary(); + public Stream Body { get; set; } + + public bool HasStarted => false; + + public void OnCompleted(Func callback, object state) + { + CompletedCallbacks.Add((callback, state)); + } + + public void OnStarting(Func callback, object state) + { + } + } + } +} \ No newline at end of file