diff --git a/src/Microsoft.AspNet.Http.Extensions/UseMiddlewareExtensions.cs b/src/Microsoft.AspNet.Http.Extensions/UseMiddlewareExtensions.cs index 1e61139153..8a9de346ab 100644 --- a/src/Microsoft.AspNet.Http.Extensions/UseMiddlewareExtensions.cs +++ b/src/Microsoft.AspNet.Http.Extensions/UseMiddlewareExtensions.cs @@ -5,6 +5,8 @@ using System; using System.Linq; using System.Reflection; using Microsoft.Framework.DependencyInjection; +using System.Threading.Tasks; +using Microsoft.AspNet.Http; namespace Microsoft.AspNet.Builder { @@ -17,12 +19,41 @@ namespace Microsoft.AspNet.Builder public static IApplicationBuilder UseMiddleware(this IApplicationBuilder builder, Type middleware, params object[] args) { + var applicationServices = builder.ApplicationServices; return builder.Use(next => { - var typeActivator = builder.ApplicationServices.GetRequiredService(); + var typeActivator = applicationServices.GetRequiredService(); var instance = typeActivator.CreateInstance(builder.ApplicationServices, middleware, new[] { next }.Concat(args).ToArray()); var methodinfo = middleware.GetMethod("Invoke", BindingFlags.Instance | BindingFlags.Public); - return (RequestDelegate)methodinfo.CreateDelegate(typeof(RequestDelegate), instance); + var parameters = methodinfo.GetParameters(); + if (parameters[0].ParameterType != typeof(HttpContext)) + { + throw new Exception("TODO: Middleware Invoke method must take first argument of HttpContext"); + } + if (parameters.Length == 1) + { + return (RequestDelegate)methodinfo.CreateDelegate(typeof(RequestDelegate), instance); + } + return context => + { + var serviceProvider = context.RequestServices ?? context.ApplicationServices ?? applicationServices; + if (serviceProvider == null) + { + throw new Exception("TODO: IServiceProvider is not available"); + } + var arguments = new object[parameters.Length]; + arguments[0] = context; + for(var index = 1; index != parameters.Length; ++index) + { + var serviceType = parameters[index].ParameterType; + arguments[index] = serviceProvider.GetService(serviceType); + if (arguments[index] == null) + { + throw new Exception(string.Format("TODO: No service for type '{0}' has been registered.", serviceType)); + } + } + return (Task)methodinfo.Invoke(instance, arguments); + }; }); } } diff --git a/src/Microsoft.AspNet.Http/Extensions/RunExtensions.cs b/src/Microsoft.AspNet.Http/Extensions/RunExtensions.cs index 694dd3ec0f..4b74158aae 100644 --- a/src/Microsoft.AspNet.Http/Extensions/RunExtensions.cs +++ b/src/Microsoft.AspNet.Http/Extensions/RunExtensions.cs @@ -3,6 +3,7 @@ using System; using Microsoft.AspNet.Http; +using System.Threading.Tasks; namespace Microsoft.AspNet.Builder { @@ -12,5 +13,25 @@ namespace Microsoft.AspNet.Builder { app.Use(_ => handler); } + + public static void Run(this IApplicationBuilder app, Func handler) + { + app.Use((ctx, _, s1) => handler(ctx, s1)); + } + + public static void Run(this IApplicationBuilder app, Func handler) + { + app.Use((ctx, _, s1, s2) => handler(ctx, s1, s2)); + } + + public static void Run(this IApplicationBuilder app, Func handler) + { + app.Use((ctx, _, s1, s2, s3) => handler(ctx, s1, s2, s3)); + } + + public static void Run(this IApplicationBuilder app, Func handler) + { + app.Use((ctx, _, s1, s2, s3, s4) => handler(ctx, s1, s2, s3, s4)); + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNet.Http/Extensions/UseExtensions.cs b/src/Microsoft.AspNet.Http/Extensions/UseExtensions.cs index 038b14e03e..4be750be47 100644 --- a/src/Microsoft.AspNet.Http/Extensions/UseExtensions.cs +++ b/src/Microsoft.AspNet.Http/Extensions/UseExtensions.cs @@ -26,5 +26,125 @@ namespace Microsoft.AspNet.Builder }; }); } + + /// + /// Use middleware defined in-line + /// + /// Per-request service required by middleware + /// + /// A function that handles the request or calls the given next function. + /// + public static IApplicationBuilder Use(this IApplicationBuilder app, Func, TService1, Task> middleware) + { + var applicationServices = app.ApplicationServices; + return app.Use(next => context => + { + var serviceProvider = context.RequestServices ?? context.ApplicationServices ?? applicationServices; + if (serviceProvider == null) + { + throw new Exception("TODO: IServiceProvider is not available"); + } + return middleware( + context, + () => next(context), + GetRequiredService(serviceProvider)); + }); + } + + /// + /// Use middleware defined in-line + /// + /// Per-request service required by middleware + /// Per-request service required by middleware + /// + /// A function that handles the request or calls the given next function. + /// + public static IApplicationBuilder Use(this IApplicationBuilder app, Func, TService1, TService2, Task> middleware) + { + var applicationServices = app.ApplicationServices; + return app.Use(next => context => + { + var serviceProvider = context.RequestServices ?? context.ApplicationServices ?? applicationServices; + if (serviceProvider == null) + { + throw new Exception("TODO: IServiceProvider is not available"); + } + return middleware( + context, + () => next(context), + GetRequiredService(serviceProvider), + GetRequiredService(serviceProvider)); + }); + } + + /// + /// Use middleware defined in-line + /// + /// Per-request service required by middleware + /// Per-request service required by middleware + /// Per-request service required by middleware + /// + /// A function that handles the request or calls the given next function. + /// + public static IApplicationBuilder Use(this IApplicationBuilder app, Func, TService1, TService2, TService3, Task> middleware) + { + var applicationServices = app.ApplicationServices; + return app.Use(next => context => + { + var serviceProvider = context.RequestServices ?? context.ApplicationServices ?? applicationServices; + if (serviceProvider == null) + { + throw new Exception("TODO: IServiceProvider is not available"); + } + return middleware( + context, + () => next(context), + GetRequiredService(serviceProvider), + GetRequiredService(serviceProvider), + GetRequiredService(serviceProvider)); + }); + } + + /// + /// Use middleware defined in-line + /// + /// Per-request service required by middleware + /// Per-request service required by middleware + /// Per-request service required by middleware + /// Per-request service required by middleware + /// + /// A function that handles the request or calls the given next function. + /// + public static IApplicationBuilder Use(this IApplicationBuilder app, Func, TService1, TService2, TService3, TService4, Task> middleware) + { + var applicationServices = app.ApplicationServices; + return app.Use(next => context => + { + var serviceProvider = context.RequestServices ?? context.ApplicationServices ?? applicationServices; + if (serviceProvider == null) + { + throw new Exception("TODO: IServiceProvider is not available"); + } + return middleware( + context, + () => next(context), + GetRequiredService(serviceProvider), + GetRequiredService(serviceProvider), + GetRequiredService(serviceProvider), + GetRequiredService(serviceProvider)); + }); + } + + private static TService GetRequiredService(IServiceProvider serviceProvider) + { + var service = (TService)serviceProvider.GetService(typeof(TService)); + + if (service == null) + { + throw new Exception(string.Format("TODO: No service for type '{0}' has been registered.", typeof(TService))); + } + + return service; + } } } \ No newline at end of file diff --git a/test/Microsoft.AspNet.Http.Extensions.Tests/UseWithServicesTests.cs b/test/Microsoft.AspNet.Http.Extensions.Tests/UseWithServicesTests.cs new file mode 100644 index 0000000000..9dc54ab09e --- /dev/null +++ b/test/Microsoft.AspNet.Http.Extensions.Tests/UseWithServicesTests.cs @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Xunit; +using Microsoft.AspNet.Builder; +using Microsoft.Framework.DependencyInjection; +using Microsoft.Framework.DependencyInjection.Fallback; +using Microsoft.AspNet.PipelineCore; +using System.Collections.Generic; +using System.Threading.Tasks; + +namespace Microsoft.AspNet.Http.Extensions.Tests +{ + public class UseWithServicesTests + { + [Fact] + public async Task CallingUseThatAlsoTakesServices() + { + var builder = new ApplicationBuilder(new ServiceCollection() + .AddScoped() + .BuildServiceProvider()); + + ITestService theService = null; + builder.Use(async (ctx, next, testService) => + { + theService = testService; + await next(); + }); + + var app = builder.Build(); + await app(new DefaultHttpContext()); + + Assert.IsType(theService); + } + + [Fact] + public async Task ServicesArePerRequest() + { + var services = new ServiceCollection() + .AddScoped() + .AddTransient() + .BuildServiceProvider(); + var builder = new ApplicationBuilder(services); + + builder.Use(async (ctx, next) => + { + var serviceScopeFactory = services.GetRequiredService(); + using (var serviceScope = serviceScopeFactory.CreateScope()) + { + var priorApplicationServices = ctx.ApplicationServices; + var priorRequestServices = ctx.ApplicationServices; + ctx.ApplicationServices = services; + ctx.RequestServices = serviceScope.ServiceProvider; + try + { + await next(); + } + finally + { + ctx.ApplicationServices = priorApplicationServices; + ctx.RequestServices = priorRequestServices; + } + } + }); + + var testServicesA = new List(); + builder.Use(async (HttpContext ctx, Func next, ITestService testService) => + { + testServicesA.Add(testService); + await next(); + }); + + var testServicesB = new List(); + builder.Use(async (ctx, next, testService) => + { + testServicesB.Add(testService); + await next(); + }); + + var app = builder.Build(); + await app(new DefaultHttpContext()); + await app(new DefaultHttpContext()); + + Assert.Equal(2, testServicesA.Count); + Assert.IsType(testServicesA[0]); + Assert.IsType(testServicesA[1]); + + Assert.Equal(2, testServicesB.Count); + Assert.IsType(testServicesB[0]); + Assert.IsType(testServicesB[1]); + + Assert.Same(testServicesA[0], testServicesB[0]); + Assert.Same(testServicesA[1], testServicesB[1]); + + Assert.NotSame(testServicesA[0], testServicesA[1]); + Assert.NotSame(testServicesB[0], testServicesB[1]); + } + + [Fact] + public async Task InvokeMethodWillAllowPerRequestServices() + { + var services = new ServiceCollection() + .AddScoped() + .AddTransient() + .BuildServiceProvider(); + var builder = new ApplicationBuilder(services); + builder.UseMiddleware(); + var app = builder.Build(); + + var ctx1 = new DefaultHttpContext(); + await app(ctx1); + + var testService = ctx1.Items[typeof(ITestService)]; + Assert.IsType(testService); + } + } + + public interface ITestService + { + } + + public class TestService : ITestService + { + } + + public class TestMiddleware + { + RequestDelegate _next; + + public TestMiddleware(RequestDelegate next) + { + _next = next; + } + + public async Task Invoke(HttpContext context, ITestService testService) + { + context.Items[typeof(ITestService)] = testService; + } + } +} \ No newline at end of file