diff --git a/src/Microsoft.AspNetCore.Owin/OwinExtensions.cs b/src/Microsoft.AspNetCore.Owin/OwinExtensions.cs index 7d6d160113..a06530cd9d 100644 --- a/src/Microsoft.AspNetCore.Owin/OwinExtensions.cs +++ b/src/Microsoft.AspNetCore.Owin/OwinExtensions.cs @@ -12,20 +12,25 @@ using Microsoft.AspNetCore.Owin; namespace Microsoft.AspNetCore.Builder { + using AddMiddleware = Action, Task>, + Func, Task> + >>; using AppFunc = Func, Task>; using CreateMiddleware = Func< Func, Task>, Func, Task> >; - using AddMiddleware = Action, Task>, - Func, Task> - >>; public static class OwinExtensions { public static AddMiddleware UseOwin(this IApplicationBuilder builder) { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + AddMiddleware add = middleware => { Func middleware1 = next1 => @@ -61,6 +66,15 @@ namespace Microsoft.AspNetCore.Builder public static IApplicationBuilder UseOwin(this IApplicationBuilder builder, Action pipeline) { + if (builder == null) + { + throw new ArgumentNullException(nameof(builder)); + } + if (pipeline == null) + { + throw new ArgumentNullException(nameof(pipeline)); + } + pipeline(builder.UseOwin()); return builder; } @@ -72,6 +86,18 @@ namespace Microsoft.AspNetCore.Builder public static IApplicationBuilder UseBuilder(this AddMiddleware app, IServiceProvider serviceProvider) { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + + // Do not set ApplicationBuilder.ApplicationServices to null. May fail later due to missing services but + // at least that results in a more useful Exception than a NRE. + if (serviceProvider == null) + { + serviceProvider = new EmptyProvider(); + } + // Adapt WebSockets by default. app(OwinWebSocketAcceptAdapter.AdaptWebSockets); var builder = new ApplicationBuilder(serviceProvider: serviceProvider); @@ -125,9 +151,26 @@ namespace Microsoft.AspNetCore.Builder public static AddMiddleware UseBuilder(this AddMiddleware app, Action pipeline, IServiceProvider serviceProvider) { + if (app == null) + { + throw new ArgumentNullException(nameof(app)); + } + if (pipeline == null) + { + throw new ArgumentNullException(nameof(pipeline)); + } + var builder = app.UseBuilder(serviceProvider); pipeline(builder); return app; } + + private class EmptyProvider : IServiceProvider + { + public object GetService(Type serviceType) + { + return null; + } + } } } diff --git a/test/Microsoft.AspNetCore.Owin.Tests/OwinExtensionTests.cs b/test/Microsoft.AspNetCore.Owin.Tests/OwinExtensionTests.cs index d894565968..66c62334bb 100644 --- a/test/Microsoft.AspNetCore.Owin.Tests/OwinExtensionTests.cs +++ b/test/Microsoft.AspNetCore.Owin.Tests/OwinExtensionTests.cs @@ -6,27 +6,30 @@ using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Builder.Internal; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Internal; using Microsoft.Extensions.DependencyInjection; using Xunit; namespace Microsoft.AspNetCore.Owin { + using AddMiddleware = Action, Task>, + Func, Task> + >>; using AppFunc = Func, Task>; using CreateMiddleware = Func< Func, Task>, Func, Task> >; - using AddMiddleware = Action, Task>, - Func, Task> - >>; public class OwinExtensionTests { static AppFunc notFound = env => new Task(() => { env["owin.ResponseStatusCode"] = 404; }); [Fact] - public void OwinConfigureServiceProviderAddsServices() + public async Task OwinConfigureServiceProviderAddsServices() { var list = new List(); AddMiddleware build = list.Add; @@ -36,21 +39,61 @@ namespace Microsoft.AspNetCore.Owin var builder = build.UseBuilder(applicationBuilder => { serviceProvider = applicationBuilder.ApplicationServices; - applicationBuilder.Run(async context => + applicationBuilder.Run(context => { fakeService = context.RequestServices.GetService(); + return Task.FromResult(0); }); - }, new ServiceCollection().AddSingleton(new FakeService()).BuildServiceProvider()); + }, + new ServiceCollection().AddSingleton(new FakeService()).BuildServiceProvider()); list.Reverse(); - list.Aggregate(notFound, (next, middleware) => middleware(next)).Invoke(new Dictionary()); + await list + .Aggregate(notFound, (next, middleware) => middleware(next)) + .Invoke(new Dictionary()); - Assert.NotNull(fakeService); + Assert.NotNull(serviceProvider); Assert.NotNull(serviceProvider.GetService()); + Assert.NotNull(fakeService); } [Fact] - public void OwinDefaultNoServices() + public async Task OwinDefaultNoServices() + { + var list = new List(); + AddMiddleware build = list.Add; + IServiceProvider expectedServiceProvider = new ServiceCollection().BuildServiceProvider(); + IServiceProvider serviceProvider = null; + FakeService fakeService = null; + bool builderExecuted = false; + bool applicationExecuted = false; + + var builder = build.UseBuilder(applicationBuilder => + { + builderExecuted = true; + serviceProvider = applicationBuilder.ApplicationServices; + applicationBuilder.Run(context => + { + applicationExecuted = true; + fakeService = context.RequestServices.GetService(); + return Task.FromResult(0); + }); + }, + expectedServiceProvider); + + list.Reverse(); + await list + .Aggregate(notFound, (next, middleware) => middleware(next)) + .Invoke(new Dictionary()); + + Assert.True(builderExecuted); + Assert.Equal(expectedServiceProvider, serviceProvider); + Assert.True(applicationExecuted); + Assert.Null(fakeService); + } + + [Fact] + public async Task OwinDefaultNullServiceProvider() { var list = new List(); AddMiddleware build = list.Add; @@ -63,25 +106,60 @@ namespace Microsoft.AspNetCore.Owin { builderExecuted = true; serviceProvider = applicationBuilder.ApplicationServices; - applicationBuilder.Run(async context => + applicationBuilder.Run(context => { applicationExecuted = true; fakeService = context.RequestServices.GetService(); + return Task.FromResult(0); }); }); list.Reverse(); - list.Aggregate(notFound, (next, middleware) => middleware(next)).Invoke(new Dictionary()); + await list + .Aggregate(notFound, (next, middleware) => middleware(next)) + .Invoke(new Dictionary()); Assert.True(builderExecuted); - Assert.Null(fakeService); + Assert.NotNull(serviceProvider); Assert.True(applicationExecuted); - Assert.Null(serviceProvider); + Assert.Null(fakeService); + } + + [Fact] + public async Task UseOwin() + { + var serviceProvider = new ServiceCollection().BuildServiceProvider(); + var builder = new ApplicationBuilder(serviceProvider); + IDictionary environment = null; + var context = new DefaultHttpContext(); + + builder.UseOwin(addToPipeline => + { + addToPipeline(next => + { + Assert.NotNull(next); + return async env => + { + environment = env; + await next(env); + }; + }); + }); + await builder.Build().Invoke(context); + + // Dictionary contains context but does not contain "websocket.Accept" or "websocket.AcceptAlt" keys. + Assert.NotNull(environment); + var value = Assert.Single( + environment, + kvp => string.Equals(typeof(HttpContext).FullName, kvp.Key, StringComparison.Ordinal)) + .Value; + Assert.Equal(context, value); + Assert.False(environment.ContainsKey("websocket.Accept")); + Assert.False(environment.ContainsKey("websocket.AcceptAlt")); } private class FakeService { - } } }