diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs b/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs index 03451d62a1..2dcbcd8fe4 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/Builder/MvcApplicationBuilderExtensions.cs @@ -85,6 +85,9 @@ namespace Microsoft.AspNetCore.Builder "ConfigureServices(...)")); } + var middlewarePipelineBuilder = app.ApplicationServices.GetRequiredService(); + middlewarePipelineBuilder.ApplicationBuilder = app.New(); + var routes = new RouteBuilder(app) { DefaultHandler = app.ApplicationServices.GetRequiredService(), diff --git a/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs b/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs index 80f72cf352..22e6a5deb9 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/DependencyInjection/MvcCoreServiceCollectionExtensions.cs @@ -222,6 +222,13 @@ namespace Microsoft.Extensions.DependencyInjection // services.TryAddSingleton(); // Only one per app services.TryAddTransient(); // Many per app + + // + // Middleware pipeline filter related + // + services.TryAddSingleton(); + // This maintains a cache of middleware pipelines, so it needs to be a singleton + services.TryAddSingleton(); } private static void ConfigureDefaultServices(IServiceCollection services) diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Filters/MiddlewareFilterAttribute.cs b/src/Microsoft.AspNetCore.Mvc.Core/Filters/MiddlewareFilterAttribute.cs new file mode 100644 index 0000000000..f88fd1a4c0 --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.Core/Filters/MiddlewareFilterAttribute.cs @@ -0,0 +1,54 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.Internal; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Mvc +{ + /// + /// Executes a middleware pipeline provided the by the . + /// The middleware pipeline will be treated as an async resource filter. + /// + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] + public class MiddlewareFilterAttribute : Attribute, IFilterFactory, IOrderedFilter + { + /// + /// Instantiates a new instance of . + /// + /// A type which configures a middleware pipeline. + public MiddlewareFilterAttribute(Type configurationType) + { + if (configurationType == null) + { + throw new ArgumentNullException(nameof(configurationType)); + } + + ConfigurationType = configurationType; + } + + public Type ConfigurationType { get; } + + /// + public int Order { get; set; } + + /// + public bool IsReusable { get; } = true; + + /// + public IFilterMetadata CreateInstance(IServiceProvider serviceProvider) + { + if (serviceProvider == null) + { + throw new ArgumentNullException(nameof(serviceProvider)); + } + + var middlewarePipelineService = serviceProvider.GetRequiredService(); + var pipeline = middlewarePipelineService.GetPipeline(ConfigurationType); + + return new MiddlewareFilter(pipeline); + } + } +} diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/IMiddlewareFilterFeature.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/IMiddlewareFilterFeature.cs new file mode 100644 index 0000000000..43af488b8a --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/IMiddlewareFilterFeature.cs @@ -0,0 +1,19 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + /// + /// A feature in which is used to capture the + /// currently executing context of a resource filter. This feature is used in the final middleware + /// of a middleware filter's pipeline to keep the request flow through the rest of the MVC layers. + /// + public interface IMiddlewareFilterFeature + { + ResourceExecutingContext ResourceExecutingContext { get; } + + ResourceExecutionDelegate ResourceExecutionDelegate { get; } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilter.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilter.cs new file mode 100644 index 0000000000..d346acc27e --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilter.cs @@ -0,0 +1,46 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + /// + /// A filter which executes a user configured middleware pipeline. + /// + public class MiddlewareFilter : IAsyncResourceFilter + { + private readonly RequestDelegate _middlewarePipeline; + + public MiddlewareFilter(RequestDelegate middlewarePipeline) + { + if (middlewarePipeline == null) + { + throw new ArgumentNullException(nameof(middlewarePipeline)); + } + + _middlewarePipeline = middlewarePipeline; + } + + public Task OnResourceExecutionAsync(ResourceExecutingContext context, ResourceExecutionDelegate next) + { + var httpContext = context.HttpContext; + + // Capture the current context into the feature. This will later be used in the end middleware to continue + // the execution flow to later MVC layers. + // Example: + // this filter -> user-middleware1 -> user-middleware2 -> the-end-middleware -> resouce filters or model binding + var feature = new MiddlewareFilterFeature() + { + ResourceExecutionDelegate = next, + ResourceExecutingContext = context + }; + httpContext.Features.Set(feature); + + return _middlewarePipeline(httpContext); + } + } +} diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterBuilder.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterBuilder.cs new file mode 100644 index 0000000000..04a620620a --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterBuilder.cs @@ -0,0 +1,88 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Core; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + /// + /// Builds a middleware pipeline after receiving the pipeline from a pipeline provider + /// + public class MiddlewareFilterBuilder + { + // 'GetOrAdd' call on the dictionary is not thread safe and we might end up creating the pipeline more + // once. To prevent this Lazy<> is used. In the worst case multiple Lazy<> objects are created for multiple + // threads but only one of the objects succeeds in creating a pipeline. + private readonly ConcurrentDictionary> _pipelinesCache + = new ConcurrentDictionary>(); + private readonly MiddlewareFilterConfigurationProvider _configurationProvider; + + public IApplicationBuilder ApplicationBuilder { get; set; } + + public MiddlewareFilterBuilder(MiddlewareFilterConfigurationProvider configurationProvider) + { + _configurationProvider = configurationProvider; + } + + public RequestDelegate GetPipeline(Type configurationType) + { + // Build the pipeline only once. This is similar to how middlewares registered in Startup are constructed. + + var requestDelegate = _pipelinesCache.GetOrAdd( + configurationType, + key => new Lazy(() => BuildPipeline(key))); + + return requestDelegate.Value; + } + + private RequestDelegate BuildPipeline(Type middlewarePipelineProviderType) + { + if (ApplicationBuilder == null) + { + throw new InvalidOperationException( + Resources.FormatMiddlewareFilterBuilder_NullApplicationBuilder(nameof(ApplicationBuilder))); + } + + var nestedAppBuilder = ApplicationBuilder.New(); + + // Get the 'Configure' method from the user provided type. + var configureDelegate = _configurationProvider.CreateConfigureDelegate(middlewarePipelineProviderType); + configureDelegate(nestedAppBuilder); + + // The middleware resource filter, after receiving the request executes the user configured middleware + // pipeline. Since we want execution of the request to continue to later MVC layers (resource filters + // or model binding), add a middleware at the end of the user provided pipeline which make sure to continue + // this flow. + // Example: + // middleware filter -> user-middleware1 -> user-middleware2 -> end-middleware -> resouce filters or model binding + nestedAppBuilder.Run(async (httpContext) => + { + var feature = httpContext.Features.Get(); + if (feature == null) + { + throw new InvalidOperationException( + Resources.FormatMiddlewareFilterBuilder_NoMiddlewareFeature(nameof(IMiddlewareFilterFeature))); + } + + var resourceExecutionDelegate = feature.ResourceExecutionDelegate; + + var resourceExecutedContext = await resourceExecutionDelegate(); + + // Ideally we want the experience of a middleware pipeline to behave the same as if it was registered, + // in Startup. In this scenario an exception thrown in a middelware later in the pipeline gets propagated + // back to earlier middleware. + // So check if a later resource filter threw an exception and propagate that back to the middleware pipeline. + if (!resourceExecutedContext.ExceptionHandled && resourceExecutedContext.Exception != null) + { + throw resourceExecutedContext.Exception; + } + }); + + return nestedAppBuilder.Build(); + } + } +} diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterConfigurationProvider.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterConfigurationProvider.cs new file mode 100644 index 0000000000..49d8e52d95 --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterConfigurationProvider.cs @@ -0,0 +1,117 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Linq; +using System.Reflection; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Mvc.Core; +using Microsoft.Extensions.DependencyInjection; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + /// + /// Calls into user provided 'Configure' methods for configuring a middleware pipeline. The semantics of finding + /// the 'Configure' methods is similar to the application Startup class. + /// + public class MiddlewareFilterConfigurationProvider + { + public Action CreateConfigureDelegate(Type configurationType) + { + if (configurationType == null) + { + throw new ArgumentNullException(nameof(configurationType)); + } + + var instance = Activator.CreateInstance(configurationType); + var configureDelegateBuilder = GetConfigureDelegateBuilder(configurationType); + return configureDelegateBuilder.Build(instance); + } + + private static ConfigureBuilder GetConfigureDelegateBuilder(Type startupType) + { + var configureMethod = FindMethod(startupType, typeof(void)); + return new ConfigureBuilder(configureMethod); + } + + private static MethodInfo FindMethod(Type startupType, Type returnType = null) + { + var methodName = "Configure"; + + var methods = startupType.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static); + var selectedMethods = methods.Where(method => method.Name.Equals(methodName)).ToList(); + if (selectedMethods.Count > 1) + { + throw new InvalidOperationException( + Resources.FormatMiddewareFilter_ConfigureMethodOverload(methodName)); + } + + var methodInfo = selectedMethods.FirstOrDefault(); + if (methodInfo == null) + { + throw new InvalidOperationException( + Resources.FormatMiddewareFilter_NoConfigureMethod( + methodName, + startupType.FullName)); + } + + if (returnType != null && methodInfo.ReturnType != returnType) + { + throw new InvalidOperationException( + Resources.FormatMiddlewareFilter_InvalidConfigureReturnType( + methodInfo.Name, + startupType.FullName, + returnType.Name)); + } + return methodInfo; + } + + private class ConfigureBuilder + { + public ConfigureBuilder(MethodInfo configure) + { + MethodInfo = configure; + } + + public MethodInfo MethodInfo { get; } + + public Action Build(object instance) + { + return (applicationBuilder) => Invoke(instance, applicationBuilder); + } + + private void Invoke(object instance, IApplicationBuilder builder) + { + var serviceProvider = builder.ApplicationServices; + var parameterInfos = MethodInfo.GetParameters(); + var parameters = new object[parameterInfos.Length]; + for (var index = 0; index < parameterInfos.Length; index++) + { + var parameterInfo = parameterInfos[index]; + if (parameterInfo.ParameterType == typeof(IApplicationBuilder)) + { + parameters[index] = builder; + } + else + { + try + { + parameters[index] = serviceProvider.GetRequiredService(parameterInfo.ParameterType); + } + catch (Exception ex) + { + throw new InvalidOperationException( + Resources.FormatMiddlewareFilter_ServiceResolutionFail( + parameterInfo.ParameterType.FullName, + parameterInfo.Name, + MethodInfo.Name, + MethodInfo.DeclaringType.FullName), + ex); + } + } + } + MethodInfo.Invoke(instance, parameters); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterFeature.cs b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterFeature.cs new file mode 100644 index 0000000000..d33378d6f7 --- /dev/null +++ b/src/Microsoft.AspNetCore.Mvc.Core/Internal/MiddlewareFilterFeature.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Mvc.Filters; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + public class MiddlewareFilterFeature : IMiddlewareFilterFeature + { + public ResourceExecutingContext ResourceExecutingContext { get; set; } + + public ResourceExecutionDelegate ResourceExecutionDelegate { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Properties/Resources.Designer.cs b/src/Microsoft.AspNetCore.Mvc.Core/Properties/Resources.Designer.cs index 418c37e1bf..b5cd370b82 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Properties/Resources.Designer.cs +++ b/src/Microsoft.AspNetCore.Mvc.Core/Properties/Resources.Designer.cs @@ -1210,6 +1210,102 @@ namespace Microsoft.AspNetCore.Mvc.Core return string.Format(CultureInfo.CurrentCulture, GetString("OutputFormattersAreRequired"), p0, p1, p2); } + /// + /// Having multiple overloads of method '{0}' is not supported. + /// + internal static string MiddewareFilter_ConfigureMethodOverload + { + get { return GetString("MiddewareFilter_ConfigureMethodOverload"); } + } + + /// + /// Having multiple overloads of method '{0}' is not supported. + /// + internal static string FormatMiddewareFilter_ConfigureMethodOverload(object p0) + { + return string.Format(CultureInfo.CurrentCulture, GetString("MiddewareFilter_ConfigureMethodOverload"), p0); + } + + /// + /// A public method named '{0}' could not be found in the '{1}' type. + /// + internal static string MiddewareFilter_NoConfigureMethod + { + get { return GetString("MiddewareFilter_NoConfigureMethod"); } + } + + /// + /// A public method named '{0}' could not be found in the '{1}' type. + /// + internal static string FormatMiddewareFilter_NoConfigureMethod(object p0, object p1) + { + return string.Format(CultureInfo.CurrentCulture, GetString("MiddewareFilter_NoConfigureMethod"), p0, p1); + } + + /// + /// Could not find '{0}' in the feature list. + /// + internal static string MiddlewareFilterBuilder_NoMiddlewareFeature + { + get { return GetString("MiddlewareFilterBuilder_NoMiddlewareFeature"); } + } + + /// + /// Could not find '{0}' in the feature list. + /// + internal static string FormatMiddlewareFilterBuilder_NoMiddlewareFeature(object p0) + { + return string.Format(CultureInfo.CurrentCulture, GetString("MiddlewareFilterBuilder_NoMiddlewareFeature"), p0); + } + + /// + /// '{0}' property cannot be null. + /// + internal static string MiddlewareFilterBuilder_NullApplicationBuilder + { + get { return GetString("MiddlewareFilterBuilder_NullApplicationBuilder"); } + } + + /// + /// '{0}' property cannot be null. + /// + internal static string FormatMiddlewareFilterBuilder_NullApplicationBuilder(object p0) + { + return string.Format(CultureInfo.CurrentCulture, GetString("MiddlewareFilterBuilder_NullApplicationBuilder"), p0); + } + + /// + /// The '{0}' method in the type '{1}' must have a return type of '{2}'. + /// + internal static string MiddlewareFilter_InvalidConfigureReturnType + { + get { return GetString("MiddlewareFilter_InvalidConfigureReturnType"); } + } + + /// + /// The '{0}' method in the type '{1}' must have a return type of '{2}'. + /// + internal static string FormatMiddlewareFilter_InvalidConfigureReturnType(object p0, object p1, object p2) + { + return string.Format(CultureInfo.CurrentCulture, GetString("MiddlewareFilter_InvalidConfigureReturnType"), p0, p1, p2); + } + + /// + /// Could not resolve a service of type '{0}' for the parameter '{1}' of method '{2}' on type '{3}'. + /// + internal static string MiddlewareFilter_ServiceResolutionFail + { + get { return GetString("MiddlewareFilter_ServiceResolutionFail"); } + } + + /// + /// Could not resolve a service of type '{0}' for the parameter '{1}' of method '{2}' on type '{3}'. + /// + internal static string FormatMiddlewareFilter_ServiceResolutionFail(object p0, object p1, object p2, object p3) + { + return string.Format(CultureInfo.CurrentCulture, GetString("MiddlewareFilter_ServiceResolutionFail"), p0, p1, p2, p3); + } + private static string GetString(string name, params string[] formatterNames) { var value = _resourceManager.GetString(name); diff --git a/src/Microsoft.AspNetCore.Mvc.Core/Resources.resx b/src/Microsoft.AspNetCore.Mvc.Core/Resources.resx index 732bba5387..8bb5bc4735 100644 --- a/src/Microsoft.AspNetCore.Mvc.Core/Resources.resx +++ b/src/Microsoft.AspNetCore.Mvc.Core/Resources.resx @@ -352,4 +352,22 @@ '{0}.{1}' must not be empty. At least one '{2}' is required to format a response. + + Multiple overloads of method '{0}' are not supported. + + + A public method named '{0}' could not be found in the '{1}' type. + + + Could not find '{0}' in the feature list. + + + The '{0}' property cannot be null. + + + The '{0}' method in the type '{1}' must have a return type of '{2}'. + + + Could not resolve a service of type '{0}' for the parameter '{1}' of method '{2}' on type '{3}'. + \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Builder/MvcApplicationBuilderExtensionsTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Builder/MvcApplicationBuilderExtensionsTest.cs index c4016e9ef2..8470b7ea25 100644 --- a/test/Microsoft.AspNetCore.Mvc.Core.Test/Builder/MvcApplicationBuilderExtensionsTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Builder/MvcApplicationBuilderExtensionsTest.cs @@ -1,7 +1,7 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; using Moq; using Xunit; diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Filters/MiddlewareFilterAttributeTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Filters/MiddlewareFilterAttributeTest.cs new file mode 100644 index 0000000000..bedfd14f4a --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Filters/MiddlewareFilterAttributeTest.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + public class MiddlewareFilterAttributeTest + { + [Fact] + public void CreatesMiddlewareFilter_WithConfiguredPipeline() + { + // Arrange + var middlewareFilterAttribute = new MiddlewareFilterAttribute(typeof(Pipeline1)); + var services = new ServiceCollection(); + services.AddSingleton(new MiddlewareFilterBuilder(new MiddlewareFilterConfigurationProvider())); + var serviceProvider = services.BuildServiceProvider(); + var filterBuilderService = serviceProvider.GetRequiredService(); + filterBuilderService.ApplicationBuilder = new ApplicationBuilder(serviceProvider); + var configureCallCount = 0; + Pipeline1.ConfigurePipeline = (ab) => + { + configureCallCount++; + ab.Use((httpCtxt, next) => + { + return next(); + }); + }; + + // Act + var filter = middlewareFilterAttribute.CreateInstance(serviceProvider); + + // Assert + var middlewareFilter = Assert.IsType(filter); + Assert.NotNull(middlewareFilter); + Assert.Equal(1, configureCallCount); + } + + private class Pipeline1 + { + public static Action ConfigurePipeline { get; set; } + + public void Configure(IApplicationBuilder appBuilder) + { + ConfigurePipeline(appBuilder); + } + } + + private class Pipeline2 + { + public static Action ConfigurePipeline { get; set; } + + public void Configure(IApplicationBuilder appBuilder) + { + ConfigurePipeline(appBuilder); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterBuilderTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterBuilderTest.cs new file mode 100644 index 0000000000..9238a6e437 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterBuilderTest.cs @@ -0,0 +1,157 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + public class MiddlewareFilterBuilderTest + { + [Fact] + public void GetPipeline_CallsInto_Configure() + { + // Arrange + var services = new ServiceCollection(); + var appBuilder = new ApplicationBuilder(services.BuildServiceProvider()); + var pipelineBuilderService = new MiddlewareFilterBuilder(new MiddlewareFilterConfigurationProvider()); + pipelineBuilderService.ApplicationBuilder = appBuilder; + var configureCount = 0; + Pipeline1.ConfigurePipeline = (ab) => + { + configureCount++; + }; + + // Act + var pipeline = pipelineBuilderService.GetPipeline(typeof(Pipeline1)); + + // Assert + Assert.NotNull(pipeline); + Assert.Equal(1, configureCount); + } + + [Fact] + public void GetPipeline_CallsIntoConfigure_OnlyOnce_ForTheSamePipelineType() + { + // Arrange + var services = new ServiceCollection(); + var appBuilder = new ApplicationBuilder(services.BuildServiceProvider()); + var pipelineBuilderService = new MiddlewareFilterBuilder(new MiddlewareFilterConfigurationProvider()); + pipelineBuilderService.ApplicationBuilder = appBuilder; + var configureCount = 0; + Pipeline1.ConfigurePipeline = (ab) => + { + configureCount++; + }; + + // Act + var pipeline1 = pipelineBuilderService.GetPipeline(typeof(Pipeline1)); + + // Assert + Assert.NotNull(pipeline1); + Assert.Equal(1, configureCount); + + // Act + var pipeline2 = pipelineBuilderService.GetPipeline(typeof(Pipeline1)); + + // Assert + Assert.NotNull(pipeline2); + Assert.Same(pipeline1, pipeline2); + Assert.Equal(1, configureCount); + } + + [Fact] + public async Task EndMiddleware_ThrowsException_WhenMiddleFeature_NotAvailable() + { + // Arrange + var services = new ServiceCollection(); + var appBuilder = new ApplicationBuilder(services.BuildServiceProvider()); + var pipelineBuilderService = new MiddlewareFilterBuilder(new MiddlewareFilterConfigurationProvider()); + pipelineBuilderService.ApplicationBuilder = appBuilder; + Pipeline1.ConfigurePipeline = (ab) => + { + ab.Use((httpContext, next) => + { + return next(); + }); + }; + + // Act + var pipeline = pipelineBuilderService.GetPipeline(typeof(Pipeline1)); + + // Assert + Assert.NotNull(pipeline); + var exception = await Assert.ThrowsAsync(() => pipeline(new DefaultHttpContext())); + Assert.Equal( + "Could not find 'IMiddlewareFilterFeature' in the feature list.", + exception.Message); + } + + [Fact] + public async Task EndMiddleware_PropagatesBackException_ToEarlierMiddleware() + { + // Arrange + var services = new ServiceCollection(); + var appBuilder = new ApplicationBuilder(services.BuildServiceProvider()); + var pipelineBuilderService = new MiddlewareFilterBuilder(new MiddlewareFilterConfigurationProvider()); + pipelineBuilderService.ApplicationBuilder = appBuilder; + Pipeline1.ConfigurePipeline = (ab) => + { + ab.Use((httpCtxt, next) => + { + return next(); + }); + }; + var middlewareFilterFeature = new MiddlewareFilterFeature(); + middlewareFilterFeature.ResourceExecutionDelegate = () => + { + var context = new ResourceExecutedContext( + new ActionContext(new DefaultHttpContext(), new RouteData(), new ActionDescriptor(), new ModelStateDictionary()), + new List()); + context.Exception = new InvalidOperationException("Error!!!"); + return Task.FromResult(context); + }; + var features = new FeatureCollection(); + features.Set(middlewareFilterFeature); + var httpContext = new DefaultHttpContext(features); + + // Act + var pipeline = pipelineBuilderService.GetPipeline(typeof(Pipeline1)); + + // Assert + Assert.NotNull(pipeline); + var exception = await Assert.ThrowsAsync(() => pipeline(httpContext)); + Assert.Equal("Error!!!", exception.Message); + } + private class Pipeline1 + { + public static Action ConfigurePipeline { get; set; } + + public void Configure(IApplicationBuilder appBuilder) + { + ConfigurePipeline(appBuilder); + } + } + + private class Pipeline2 + { + public static Action ConfigurePipeline { get; set; } + + public void Configure(IApplicationBuilder appBuilder) + { + ConfigurePipeline(appBuilder); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterConfigurationProviderTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterConfigurationProviderTest.cs new file mode 100644 index 0000000000..d5e2d77287 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterConfigurationProviderTest.cs @@ -0,0 +1,178 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Castle.Core.Logging; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + public class MiddlewareFilterConfigurationProviderTest + { + [Fact] + public void ValidConfigure_DoesNotThrow() + { + // Arrange + var provider = new MiddlewareFilterConfigurationProvider(); + + // Act + var configureDelegate = provider.CreateConfigureDelegate(typeof(ValidConfigure_WithNoEnvironment)); + + // Assert + Assert.NotNull(configureDelegate); + } + + [Fact] + public void ValidConfigure_AndAdditionalServices_DoesNotThrow() + { + // Arrange + var loggerFactory = Mock.Of(); + var services = new ServiceCollection(); + services.AddSingleton(loggerFactory); + services.AddSingleton(Mock.Of()); + var applicationBuilder = GetApplicationBuilder(services); + var provider = new MiddlewareFilterConfigurationProvider(); + + // Act + var configureDelegate = provider.CreateConfigureDelegate(typeof(ValidConfigure_WithNoEnvironment_AdditionalServices)); + + // Assert + Assert.NotNull(configureDelegate); + } + + [Fact] + public void InvalidType_NoConfigure_Throws() + { + // Arrange + var type = typeof(InvalidType_NoConfigure); + var provider = new MiddlewareFilterConfigurationProvider(); + var expected = $"A public method named 'Configure' could not be found in the '{type.FullName}' type."; + + // Act & Assert + var exception = Assert.Throws(() => + { + provider.CreateConfigureDelegate(type); + }); + Assert.Equal(expected, exception.Message); + } + + [Fact] + public void InvalidType_NoPublicConfigure_Throws() + { + // Arrange + var type = typeof(InvalidType_NoPublic_Configure); + var provider = new MiddlewareFilterConfigurationProvider(); + var expected = $"A public method named 'Configure' could not be found in the '{type.FullName}' type."; + + // Act & Assert + var exception = Assert.Throws(() => + { + provider.CreateConfigureDelegate(type); + }); + Assert.Equal(expected, exception.Message); + } + + private IApplicationBuilder GetApplicationBuilder(ServiceCollection services = null) + { + if (services == null) + { + services = new ServiceCollection(); + } + var serviceProvider = services.BuildServiceProvider(); + + var applicationBuilder = new Mock(); + applicationBuilder + .SetupGet(a => a.ApplicationServices) + .Returns(serviceProvider); + + return applicationBuilder.Object; + } + + private class ValidConfigure_WithNoEnvironment + { + public void Configure(IApplicationBuilder appBuilder) { } + } + + private class ValidConfigure_WithNoEnvironment_AdditionalServices + { + public void Configure( + IApplicationBuilder appBuilder, + IHostingEnvironment hostingEnvironment, + ILoggerFactory loggerFactory) + { + if (hostingEnvironment == null) + { + throw new ArgumentNullException(nameof(hostingEnvironment)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + } + } + + private class ValidConfigure_WithEnvironment + { + public void ConfigureProduction(IApplicationBuilder appBuilder) { } + } + + private class ValidConfigure_WithEnvironment_AdditionalServices + { + public void ConfigureProduction( + IApplicationBuilder appBuilder, + IHostingEnvironment hostingEnvironment, + ILoggerFactory loggerFactory) + { + if (hostingEnvironment == null) + { + throw new ArgumentNullException(nameof(hostingEnvironment)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + } + } + + private class MultipleConfigureWithEnvironments + { + public void ConfigureDevelopment(IApplicationBuilder appBuilder) + { + + } + + public void ConfigureProduction(IApplicationBuilder appBuilder) + { + + } + } + + private class InvalidConfigure_NoParameters + { + public void Configure() + { + + } + } + + private class InvalidType_NoConfigure + { + public void Foo(IApplicationBuilder appBuilder) + { + + } + } + + private class InvalidType_NoPublic_Configure + { + private void Configure(IApplicationBuilder appBuilder) + { + + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterTest.cs b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterTest.cs new file mode 100644 index 0000000000..68b63d1153 --- /dev/null +++ b/test/Microsoft.AspNetCore.Mvc.Core.Test/Internal/MiddlewareFilterTest.cs @@ -0,0 +1,516 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Diagnostics; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc.Abstractions; +using Microsoft.AspNetCore.Mvc.Controllers; +using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; +using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; +using Microsoft.Extensions.Options; +using Moq; +using Xunit; + +namespace Microsoft.AspNetCore.Mvc.Internal +{ + public class MiddlewareFilterTest + { + private readonly TestController _controller = new TestController(); + + [Fact] + public async Task MiddlewareFilter_SetsMiddlewareFilterFeature_OnExecution() + { + // Arrange + RequestDelegate requestDelegate = (context) => Task.FromResult(true); + var middlwareFilter = new MiddlewareFilter(requestDelegate); + var httpContext = new DefaultHttpContext(); + var resourceExecutingContext = GetResourceExecutingContext(httpContext); + var resourceExecutionDelegate = GetResourceExecutionDelegate(httpContext); + + // Act + await middlwareFilter.OnResourceExecutionAsync(resourceExecutingContext, resourceExecutionDelegate); + + // Assert + var feature = resourceExecutingContext.HttpContext.Features.Get(); + Assert.NotNull(feature); + Assert.Same(resourceExecutingContext, feature.ResourceExecutingContext); + Assert.Same(resourceExecutionDelegate, feature.ResourceExecutionDelegate); + } + + [Fact] + public async Task OnMiddlewareShortCircuit_DoesNotExecute_RestOfFilterPipeline() + { + // Arrange + var expectedHeader = "h1"; + Pipeline1.ConfigurePipeline = (appBuilder) => + { + appBuilder.Use((httpContext, next) => + { + httpContext.Response.Headers.Add(expectedHeader, ""); + return Task.FromResult(true); // short circuit the request + }); + }; + var resourceFilter1 = new TestResourceFilter(TestResourceFilterAction.Passthrough); + var middlewareResourceFilter = new MiddlewareFilter(GetMiddlewarePipeline(typeof(Pipeline1))); + var exceptionThrowingResourceFilter = new TestResourceFilter(TestResourceFilterAction.ThrowException); + + var invoker = CreateInvoker( + new IFilterMetadata[] + { + resourceFilter1, + middlewareResourceFilter, + exceptionThrowingResourceFilter, + }, + actionThrows: true); // The action won't run + + // Act + await invoker.InvokeAsync(); + + // Assert + var resourceExecutedContext = resourceFilter1.ResourceExecutedContext; + Assert.True(resourceExecutedContext.HttpContext.Response.Headers.ContainsKey(expectedHeader)); + Assert.True(resourceExecutedContext.Canceled); + Assert.False(invoker.ControllerFactory.CreateCalled); + } + + // Example: Middleware filters are applied at Global, Controller & Action level + [Fact] + public async Task Multiple_MiddlewareFilters_ConcatsTheMiddlewarePipelines() + { + // Arrange + var expectedHeader = "h1"; + var expectedHeaderValue = "pipeline1-pipeline2"; + Pipeline1.ConfigurePipeline = (appBuilder) => + { + appBuilder.Use((httpContext, next) => + { + httpContext.Response.Headers["h1"] = "pipeline1"; + return next(); + }); + }; + Pipeline2.ConfigurePipeline = (appBuilder) => + { + appBuilder.Use((httpContext, next) => + { + httpContext.Response.Headers["h1"] = httpContext.Response.Headers["h1"] + "-pipeline2"; + return Task.FromResult(true); // short circuits the request + }); + }; + var resourceFilter1 = new TestResourceFilter(TestResourceFilterAction.Passthrough); + var middlewareResourceFilter1 = new MiddlewareFilter(GetMiddlewarePipeline(typeof(Pipeline1))); + var middlewareResourceFilter2 = new MiddlewareFilter(GetMiddlewarePipeline(typeof(Pipeline2))); + var exceptionThrowingResourceFilter = new TestResourceFilter(TestResourceFilterAction.ThrowException); + + var invoker = CreateInvoker( + new IFilterMetadata[] + { + resourceFilter1, // This filter will pass through + middlewareResourceFilter1, // This filter will pass through + middlewareResourceFilter2, // This filter will short circuit + exceptionThrowingResourceFilter, // This shouldn't run + }, + actionThrows: true); // The action won't run + + // Act + await invoker.InvokeAsync(); + + // Assert + var resourceExecutedContext = resourceFilter1.ResourceExecutedContext; + var response = resourceExecutedContext.HttpContext.Response; + Assert.True(response.Headers.ContainsKey(expectedHeader)); + Assert.Equal(expectedHeaderValue, response.Headers[expectedHeader]); + Assert.True(resourceExecutedContext.Canceled); + Assert.False(invoker.ControllerFactory.CreateCalled); + } + + [Fact] + public async Task UnhandledException_InMiddleware_PropagatesBackToInvoker() + { + // Arrange + var expectedMessage = "Error!!!"; + Pipeline1.ConfigurePipeline = (appBuilder) => + { + appBuilder.Use((httpContext, next) => + { + throw new InvalidOperationException(expectedMessage); + }); + }; + var resourceFilter1 = new TestResourceFilter(TestResourceFilterAction.Passthrough); + var middlewareResourceFilter = new MiddlewareFilter(GetMiddlewarePipeline(typeof(Pipeline1))); + var exceptionThrowingResourceFilter = new TestResourceFilter(TestResourceFilterAction.ThrowException); + + var invoker = CreateInvoker( + new IFilterMetadata[] + { + resourceFilter1, + middlewareResourceFilter, + exceptionThrowingResourceFilter, // This shouldn't run + }, + actionThrows: true); // The action won't run + + // Act + var exception = await Assert.ThrowsAsync(async () => await invoker.InvokeAsync()); + + // Assert + Assert.Equal(expectedMessage, exception.Message); + } + + [Fact] + public async Task ExceptionThrownInMiddleware_CanBeHandled_ByEarlierMiddleware() + { + // Arrange + var expectedMessage = "Error!!!"; + Pipeline1.ConfigurePipeline = (appBuilder) => + { + appBuilder.Use(async (httpContext, next) => + { + try + { + await next(); + } + catch + { + httpContext.Response.StatusCode = 500; + httpContext.Response.Headers.Add("Error", "Error!!!!"); + } + }); + }; + Pipeline2.ConfigurePipeline = (appBuilder) => + { + appBuilder.Use((httpContext, next) => + { + throw new InvalidOperationException(expectedMessage); + }); + }; + var resourceFilter1 = new TestResourceFilter(TestResourceFilterAction.Passthrough); + var middlewareResourceFilter1 = new MiddlewareFilter(GetMiddlewarePipeline(typeof(Pipeline1))); + var middlewareResourceFilter2 = new MiddlewareFilter(GetMiddlewarePipeline(typeof(Pipeline2))); + var exceptionThrowingResourceFilter = new TestResourceFilter(TestResourceFilterAction.ThrowException); + + var invoker = CreateInvoker( + new IFilterMetadata[] + { + resourceFilter1, + middlewareResourceFilter1, + middlewareResourceFilter2, + exceptionThrowingResourceFilter, // This shouldn't run + }, + actionThrows: true); // The action won't run + + // Act + var exception = await Assert.ThrowsAsync(async () => await invoker.InvokeAsync()); + + // Assert + var resourceExecutedContext = resourceFilter1.ResourceExecutedContext; + var response = resourceExecutedContext.HttpContext.Response; + Assert.Equal(500, response.StatusCode); + Assert.True(response.Headers.ContainsKey("Error")); + Assert.False(invoker.ControllerFactory.CreateCalled); + } + + private ResourceExecutingContext GetResourceExecutingContext(HttpContext httpContext) + { + return new ResourceExecutingContext( + new ActionContext(httpContext, new RouteData(), new ActionDescriptor(), new ModelStateDictionary()), + new List(), + new List()); + } + + private ResourceExecutionDelegate GetResourceExecutionDelegate(HttpContext httpContext) + { + return new ResourceExecutionDelegate( + () => Task.FromResult(new ResourceExecutedContext(new ActionContext(), new List()))); + } + + private TestControllerActionInvoker CreateInvoker( + IFilterMetadata[] filters, + bool actionThrows = false) + { + var actionDescriptor = new ControllerActionDescriptor() + { + FilterDescriptors = new List(), + Parameters = new List(), + }; + + if (actionThrows) + { + actionDescriptor.MethodInfo = typeof(ControllerActionInvokerTest).GetMethod( + nameof(ControllerActionInvokerTest.ThrowingActionMethod)); + } + else + { + actionDescriptor.MethodInfo = typeof(ControllerActionInvokerTest).GetMethod( + nameof(ControllerActionInvokerTest.ActionMethod)); + } + actionDescriptor.ControllerTypeInfo = typeof(ControllerActionInvokerTest).GetTypeInfo(); + + return CreateInvoker(filters, actionDescriptor, _controller); + } + + private TestControllerActionInvoker CreateInvoker( + IFilterMetadata[] filters, + ControllerActionDescriptor actionDescriptor, + object controller) + { + var httpContext = GetHttpContext(); + httpContext.Response.Body = new MemoryStream(); + + var options = new MvcOptions(); + var optionsAccessor = new Mock>(); + optionsAccessor + .SetupGet(o => o.Value) + .Returns(options); + + var actionContext = new ActionContext(httpContext, new RouteData(), actionDescriptor); + + var filterProvider = new Mock(MockBehavior.Strict); + filterProvider + .Setup(fp => fp.OnProvidersExecuting(It.IsAny())) + .Callback(context => + { + foreach (var filterMetadata in filters) + { + context.Results.Add(new FilterItem(new FilterDescriptor(filterMetadata, FilterScope.Action)) + { + Filter = filterMetadata, + }); + } + }); + + filterProvider + .Setup(fp => fp.OnProvidersExecuted(It.IsAny())) + .Verifiable(); + + filterProvider + .SetupGet(fp => fp.Order) + .Returns(-1000); + + var diagnosticSource = new DiagnosticListener("Microsoft.AspNetCore"); + diagnosticSource.SubscribeWithAdapter(new TestDiagnosticListener()); + + var invoker = new TestControllerActionInvoker( + new[] { filterProvider.Object }, + new MockControllerFactory(controller ?? this), + new TestControllerArgumentBinder(actionParameters: null), + new NullLoggerFactory().CreateLogger(), + diagnosticSource, + actionContext, + new List(), + maxAllowedErrorsInModelState: 200); + return invoker; + } + + private class Pipeline1 + { + public static Action ConfigurePipeline { get; set; } + + public void Configure(IApplicationBuilder appBuilder) + { + ConfigurePipeline(appBuilder); + } + } + + private class Pipeline2 + { + public static Action ConfigurePipeline { get; set; } + + public void Configure(IApplicationBuilder appBuilder) + { + ConfigurePipeline(appBuilder); + } + } + + private static HttpContext GetHttpContext() + { + var services = CreateServices(); + + var httpContext = new DefaultHttpContext(); + httpContext.RequestServices = services.BuildServiceProvider(); + + return httpContext; + } + + private RequestDelegate GetMiddlewarePipeline(Type middlewarePipelineProviderType) + { + var applicationServices = new ServiceCollection(); + var applicationBuilder = new ApplicationBuilder(applicationServices.BuildServiceProvider()); + var middlewareFilterBuilderService = new MiddlewareFilterBuilder( + new MiddlewareFilterConfigurationProvider()); + middlewareFilterBuilderService.ApplicationBuilder = applicationBuilder; + return middlewareFilterBuilderService.GetPipeline(middlewarePipelineProviderType); + } + + private static IServiceCollection CreateServices() + { + var services = new ServiceCollection(); + + services.AddSingleton(NullLoggerFactory.Instance); + return services; + } + + private class MockControllerFactory : IControllerFactory + { + private object _controller; + + public MockControllerFactory(object controller) + { + _controller = controller; + } + + public bool CreateCalled { get; private set; } + + public bool ReleaseCalled { get; private set; } + + public ControllerContext ControllerContext { get; private set; } + + public object CreateController(ControllerContext context) + { + ControllerContext = context; + CreateCalled = true; + return _controller; + } + + public void ReleaseController(ControllerContext context, object controller) + { + Assert.NotNull(controller); + Assert.Same(_controller, controller); + ReleaseCalled = true; + } + + public void Verify() + { + if (CreateCalled && !ReleaseCalled) + { + Assert.False(true, "ReleaseController should have been called."); + } + } + } + + private static ControllerActionInvokerCache CreateFilterCache(IFilterProvider[] filterProviders = null) + { + var services = new ServiceCollection().BuildServiceProvider(); + var descriptorProvider = new ActionDescriptorCollectionProvider(services); + return new ControllerActionInvokerCache( + descriptorProvider, + filterProviders.AsEnumerable() ?? new List()); + } + + private class TestControllerActionInvoker : ControllerActionInvoker + { + public TestControllerActionInvoker( + IFilterProvider[] filterProviders, + MockControllerFactory controllerFactory, + IControllerArgumentBinder argumentBinder, + ILogger logger, + DiagnosticSource diagnosticSource, + ActionContext actionContext, + IReadOnlyList valueProviderFactories, + int maxAllowedErrorsInModelState) + : base( + CreateFilterCache(filterProviders), + controllerFactory, + argumentBinder, + logger, + diagnosticSource, + actionContext, + valueProviderFactories, + maxAllowedErrorsInModelState) + { + ControllerFactory = controllerFactory; + } + + public MockControllerFactory ControllerFactory { get; } + + public async override Task InvokeAsync() + { + await base.InvokeAsync(); + + // Make sure that the controller was disposed in every test that creates ones. + ControllerFactory.Verify(); + } + } + + private class TestControllerArgumentBinder : IControllerArgumentBinder + { + private readonly IDictionary _actionParameters; + public TestControllerArgumentBinder(IDictionary actionParameters) + { + _actionParameters = actionParameters; + } + + public Task BindArgumentsAsync( + ControllerContext controllerContext, + object controller, + IDictionary arguments) + { + foreach (var entry in _actionParameters) + { + arguments.Add(entry.Key, entry.Value); + } + + return TaskCache.CompletedTask; + } + } + + private sealed class TestController + { + } + + + private enum TestResourceFilterAction + { + ShortCircuit, + ThrowException, + Passthrough + } + + private class TestResourceFilter : IAsyncResourceFilter + { + private readonly TestResourceFilterAction _action; + public TestResourceFilter(TestResourceFilterAction action) + { + _action = action; + } + + public ResourceExecutedContext ResourceExecutedContext { get; private set; } + + public async Task OnResourceExecutionAsync(ResourceExecutingContext context, ResourceExecutionDelegate next) + { + if (_action == TestResourceFilterAction.ThrowException) + { + throw new NotImplementedException("This filter should not have been run!"); + + } + else if (_action == TestResourceFilterAction.Passthrough) + { + ResourceExecutedContext = await next(); + } + else + { + context.Result = new TestActionResult(); + } + } + } + + public class TestActionResult : IActionResult + { + public Task ExecuteResultAsync(ActionContext context) + { + context.HttpContext.Response.StatusCode = 200; + return context.HttpContext.Response.WriteAsync("Shortcircuited"); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/FiltersTest.cs b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/FiltersTest.cs index ecb7c631fb..bbc00f9c4c 100644 --- a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/FiltersTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/FiltersTest.cs @@ -536,5 +536,23 @@ namespace Microsoft.AspNetCore.Mvc.FunctionalTests Assert.Equal("text/plain", response.Content.Headers.ContentType.MediaType); Assert.Equal("Data:10", await response.Content.ReadAsStringAsync()); } + + [Theory] + [InlineData("en-US", "en-US")] + [InlineData("fr", "fr")] + [InlineData("ab-cd", "en-US")] + public async Task MiddlewareFilter_LocalizationMiddlewareRegistration_UsesRouteDataToFindCulture( + string culture, + string expected) + { + // Arrange & Act + var response = await Client.GetAsync($"http://localhost/{culture}/MiddlewareFilterTest/CultureFromRouteData"); + + // Assert + Assert.Equal(HttpStatusCode.OK, response.StatusCode); + Assert.Equal( + $"CurrentCulture:{expected},CurrentUICulture:{expected}", + await response.Content.ReadAsStringAsync()); + } } } \ No newline at end of file diff --git a/test/WebSites/FiltersWebSite/Controllers/MiddlewareFilterTestController.cs b/test/WebSites/FiltersWebSite/Controllers/MiddlewareFilterTestController.cs new file mode 100644 index 0000000000..0b00c23b1c --- /dev/null +++ b/test/WebSites/FiltersWebSite/Controllers/MiddlewareFilterTestController.cs @@ -0,0 +1,25 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Globalization; +using System.Linq; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace FiltersWebSite +{ + public class MiddlewareFilterTestController : Controller + { + [Route("{culture}/[controller]/[action]")] + [MiddlewareFilter(typeof(LocalizationPipeline))] + public IActionResult CultureFromRouteData() + { + return Content($"CurrentCulture:{CultureInfo.CurrentCulture.Name},CurrentUICulture:{CultureInfo.CurrentUICulture.Name}"); + } + } +} diff --git a/test/WebSites/FiltersWebSite/Filters/TestResourceFilter.cs b/test/WebSites/FiltersWebSite/Filters/TestResourceFilter.cs new file mode 100644 index 0000000000..6d7f8e0715 --- /dev/null +++ b/test/WebSites/FiltersWebSite/Filters/TestResourceFilter.cs @@ -0,0 +1,95 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Mvc; +using Microsoft.AspNetCore.Mvc.Filters; + +namespace FiltersWebSite +{ + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] + public class TestSyncResourceFilter : Attribute, IResourceFilter, IOrderedFilter + { + public enum Action + { + PassThrough, + ThrowException, + Shortcircuit + } + + public readonly string ExceptionMessage = $"Error!! in {nameof(TestSyncResourceFilter)}"; + + public readonly string ShortcircuitMessage = $"Shortcircuited by {nameof(TestSyncResourceFilter)}"; + + public Action FilterAction { get; set; } + + public int Order { get; set; } + + public void OnResourceExecuted(ResourceExecutedContext context) + { + } + + public void OnResourceExecuting(ResourceExecutingContext context) + { + if (FilterAction == Action.PassThrough) + { + return; + } + else if (FilterAction == Action.ThrowException) + { + throw new InvalidOperationException(ExceptionMessage); + } + else + { + context.Result = new ContentResult() + { + Content = ShortcircuitMessage, + StatusCode = 400, + ContentType = "text/abcd" + }; + } + } + } + + [AttributeUsage(AttributeTargets.Class | AttributeTargets.Method, AllowMultiple = true, Inherited = true)] + public class TestAsyncResourceFilter : Attribute, IAsyncResourceFilter, IOrderedFilter + { + public enum Action + { + PassThrough, + ThrowException, + Shortcircuit + } + + public readonly string ExceptionMessage = $"Error!! in {nameof(TestAsyncResourceFilter)}"; + + public readonly string ShortcircuitMessage = $"Shortcircuited by {nameof(TestAsyncResourceFilter)}"; + + public Action FilterAction { get; set; } + + public int Order { get; set; } + + public Task OnResourceExecutionAsync(ResourceExecutingContext context, ResourceExecutionDelegate next) + { + if (FilterAction == Action.PassThrough) + { + return next(); + } + else if (FilterAction == Action.ThrowException) + { + throw new InvalidOperationException(ExceptionMessage); + } + else + { + context.Result = new ContentResult() + { + Content = ShortcircuitMessage, + StatusCode = 400, + ContentType = "text/abcd" + }; + return Task.FromResult(true); + } + } + } +} diff --git a/test/WebSites/FiltersWebSite/LocalizationPipeline.cs b/test/WebSites/FiltersWebSite/LocalizationPipeline.cs new file mode 100644 index 0000000000..d61e355c66 --- /dev/null +++ b/test/WebSites/FiltersWebSite/LocalizationPipeline.cs @@ -0,0 +1,32 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Globalization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Localization; +using Microsoft.AspNetCore.Mvc.Localization; + +namespace FiltersWebSite +{ + public class LocalizationPipeline + { + public void Configure(IApplicationBuilder applicationBuilder) + { + var supportedCultures = new[] + { + new CultureInfo("en-US"), + new CultureInfo("fr") + }; + + var options = new RequestLocalizationOptions() + { + DefaultRequestCulture = new RequestCulture(culture: "en-US", uiCulture: "en-US"), + SupportedCultures = supportedCultures, + SupportedUICultures = supportedCultures + }; + options.RequestCultureProviders = new[] { new RouteDataRequestCultureProvider() { Options = options } }; + + applicationBuilder.UseRequestLocalization(options); + } + } +} diff --git a/test/WebSites/FiltersWebSite/RouteDataRequestCultureProvider.cs b/test/WebSites/FiltersWebSite/RouteDataRequestCultureProvider.cs new file mode 100644 index 0000000000..0f005d5498 --- /dev/null +++ b/test/WebSites/FiltersWebSite/RouteDataRequestCultureProvider.cs @@ -0,0 +1,72 @@ +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing; +using Microsoft.AspNetCore.Localization; + +namespace FiltersWebSite +{ + /// + /// Determines the culture information for a request via values in the route data. + /// + public class RouteDataRequestCultureProvider : RequestCultureProvider + { + /// + /// The key that contains the culture name. + /// Defaults to "culture". + /// + public string RouteDataStringKey { get; set; } = "culture"; + + /// + /// The key that contains the UI culture name. If not specified or no value is found, + /// will be used. + /// Defaults to "ui-culture". + /// + public string UIRouteDataStringKey { get; set; } = "ui-culture"; + + /// + public override Task DetermineProviderCultureResult(HttpContext httpContext) + { + if (httpContext == null) + { + throw new ArgumentNullException(nameof(httpContext)); + } + + string culture = null; + string uiCulture = null; + + if (!string.IsNullOrWhiteSpace(RouteDataStringKey)) + { + culture = httpContext.GetRouteValue(RouteDataStringKey) as string; + } + + if (!string.IsNullOrWhiteSpace(UIRouteDataStringKey)) + { + uiCulture = httpContext.GetRouteValue(UIRouteDataStringKey) as string; + } + + if (culture == null && uiCulture == null) + { + // No values specified for either so no match + return Task.FromResult((ProviderCultureResult)null); + } + + if (culture != null && uiCulture == null) + { + // Value for culture but not for UI culture so default to culture value for both + uiCulture = culture; + } + + if (culture == null && uiCulture != null) + { + // Value for UI culture but not for culture so default to UI culture value for both + culture = uiCulture; + } + + var providerResultCulture = new ProviderCultureResult(culture, uiCulture); + + return Task.FromResult(providerResultCulture); + } + } + +}