From 42739b064f042522cf85c17cc9bdb420435ada70 Mon Sep 17 00:00:00 2001 From: Hao Kung Date: Thu, 25 May 2017 18:22:51 -0700 Subject: [PATCH] React to Auth + switch to PolicyEvaluator --- .../Internal/AuthorizeHelper.cs | 60 +++--- .../Microsoft.AspNetCore.Sockets.Http.csproj | 2 +- .../SocketsDependencyInjectionExtensions.cs | 1 + .../HttpConnectionDispatcherTests.cs | 177 ++++++++---------- 4 files changed, 116 insertions(+), 124 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/AuthorizeHelper.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/AuthorizeHelper.cs index de9101dbdb..65dae51187 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/AuthorizeHelper.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/AuthorizeHelper.cs @@ -2,12 +2,12 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; -using System.Security.Claims; using System.Threading.Tasks; +using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Authorization.Policy; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.Sockets.Internal { @@ -29,45 +29,47 @@ namespace Microsoft.AspNetCore.Sockets.Internal } var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData); - if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0) - { - ClaimsPrincipal newPrincipal = null; - foreach (var scheme in authorizePolicy.AuthenticationSchemes) - { - var result = await context.Authentication.AuthenticateAsync(scheme); - if (result != null) - { - newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result); - } - } - if (newPrincipal == null) - { - newPrincipal = new ClaimsPrincipal(new ClaimsIdentity()); - } + var policyEvaluator = context.RequestServices.GetRequiredService(); - context.User = newPrincipal; - } + // This will set context.User if required + var authenticateResult = await policyEvaluator.AuthenticateAsync(authorizePolicy, context); - var authService = context.RequestServices.GetRequiredService(); - if (await authService.AuthorizeAsync(context.User, context, authorizePolicy)) + var authorizeResult = await policyEvaluator.AuthorizeAsync(authorizePolicy, authenticateResult, context); + if (authorizeResult.Succeeded) { return true; } - - // Challenge - if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0) + else if (authorizeResult.Challenged) { - foreach (var scheme in authorizePolicy.AuthenticationSchemes) + if (authorizePolicy.AuthenticationSchemes.Count > 0) { - await context.Authentication.ChallengeAsync(scheme, properties: null); + foreach (var scheme in authorizePolicy.AuthenticationSchemes) + { + await context.ChallengeAsync(scheme); + } } + else + { + await context.ChallengeAsync(); + } + return false; } - else + else if (authorizeResult.Forbidden) { - await context.Authentication.ChallengeAsync(properties: null); + if (authorizePolicy.AuthenticationSchemes.Count > 0) + { + foreach (var scheme in authorizePolicy.AuthenticationSchemes) + { + await context.ForbidAsync(scheme); + } + } + else + { + await context.ForbidAsync(); + } + return false; } - return false; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj index de6c20c504..5d28929d51 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj @@ -15,7 +15,7 @@ - + diff --git a/src/Microsoft.AspNetCore.Sockets.Http/SocketsDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Http/SocketsDependencyInjectionExtensions.cs index e7d451b980..8b912dcbcb 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/SocketsDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/SocketsDependencyInjectionExtensions.cs @@ -11,6 +11,7 @@ namespace Microsoft.Extensions.DependencyInjection public static IServiceCollection AddSockets(this IServiceCollection services) { services.AddRouting(); + services.AddAuthorizationPolicyEvaluator(); services.TryAddSingleton(); return services.AddSocketsCore(); } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 9e635bf156..65a82071a3 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -9,7 +9,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Features.Authentication; +using Microsoft.AspNetCore.Authentication; using Microsoft.AspNetCore.Http.Internal; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; @@ -638,10 +638,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); + services.AddAuthorizationPolicyEvaluator(); services.AddAuthorization(o => { o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); }); + services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler))); services.AddLogging(); var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; @@ -651,9 +653,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; - var authFeature = new HttpAuthenticationFeature(); - authFeature.Handler = new TestAuthenticationHandler(context); - context.Features.Set(authFeature); var builder = new SocketBuilder(sp); builder.UseEndPoint(); @@ -668,7 +667,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests } [Fact] - public async Task AuthorizedConnectionCanConnectToEndPoint() + public async Task AuthenticatedUserWithoutPermissionCausesForbidden() { var manager = CreateConnectionManager(); var state = manager.CreateConnection(); @@ -677,13 +676,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests var services = new ServiceCollection(); services.AddOptions(); services.AddEndPoint(); + services.AddAuthorizationPolicyEvaluator(); services.AddAuthorization(o => { - o.AddPolicy("test", policy => - { - policy.RequireClaim(ClaimTypes.NameIdentifier); - }); + o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); }); + services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler))); services.AddLogging(); var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; @@ -693,10 +691,50 @@ namespace Microsoft.AspNetCore.Sockets.Tests values["id"] = state.Connection.ConnectionId; var qs = new QueryCollection(values); context.Request.Query = qs; + + var builder = new SocketBuilder(sp); + builder.UseEndPoint(); + var app = builder.Build(); + var options = new HttpSocketOptions(); + options.AuthorizationPolicyNames.Add("test"); + + context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated")); + + // would hang if EndPoint was running + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode); + } + + [Fact] + public async Task AuthorizedConnectionCanConnectToEndPoint() + { + var manager = CreateConnectionManager(); + var state = manager.CreateConnection(); + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddOptions(); + services.AddEndPoint(); + services.AddAuthorizationPolicyEvaluator(); + services.AddAuthorization(o => + { + o.AddPolicy("test", policy => + { + policy.RequireClaim(ClaimTypes.NameIdentifier); + }); + }); + services.AddLogging(); + services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler))); + var sp = services.BuildServiceProvider(); + context.Request.Path = "/foo"; + context.Request.Method = "GET"; + context.RequestServices = sp; + var values = new Dictionary(); + values["id"] = state.Connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; context.Response.Body = new MemoryStream(); - var authFeature = new HttpAuthenticationFeature(); - authFeature.Handler = new TestAuthenticationHandler(context); - context.Features.Set(authFeature); var builder = new SocketBuilder(sp); builder.UseEndPoint(); @@ -716,61 +754,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body)); } - [Fact] - public async Task AllPoliciesRequiredForAuthorizedEndPoint() - { - var manager = CreateConnectionManager(); - var state = manager.CreateConnection(); - var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddOptions(); - services.AddEndPoint(); - services.AddAuthorization(o => - { - o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier)); - o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress)); - }); - services.AddLogging(); - var sp = services.BuildServiceProvider(); - context.Request.Path = "/foo"; - context.Request.Method = "GET"; - context.RequestServices = sp; - var values = new Dictionary(); - values["id"] = state.Connection.ConnectionId; - var qs = new QueryCollection(values); - context.Request.Query = qs; - context.Response.Body = new MemoryStream(); - var authFeature = new HttpAuthenticationFeature(); - authFeature.Handler = new TestAuthenticationHandler(context); - context.Features.Set(authFeature); - - var builder = new SocketBuilder(sp); - builder.UseEndPoint(); - var app = builder.Build(); - var options = new HttpSocketOptions(); - options.AuthorizationPolicyNames.Add("test"); - options.AuthorizationPolicyNames.Add("secondPolicy"); - - // partialy "authorize" user - context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); - - // would hang if EndPoint was running - await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); - - Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); - - // fully "authorize" user - context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") })); - - var endPointTask = dispatcher.ExecuteAsync(context, options, app); - await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout(); - - await endPointTask.OrTimeout(); - - Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body)); - } - + [Fact] public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint() { @@ -789,7 +773,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests policy.AddAuthenticationSchemes("Default"); }); }); + services.AddAuthorizationPolicyEvaluator(); services.AddLogging(); + services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler))); var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; @@ -799,9 +785,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var qs = new QueryCollection(values); context.Request.Query = qs; context.Response.Body = new MemoryStream(); - var authFeature = new HttpAuthenticationFeature(); - authFeature.Handler = new TestAuthenticationHandler(context); - context.Features.Set(authFeature); var builder = new SocketBuilder(sp); builder.UseEndPoint(); @@ -839,7 +822,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests policy.AddAuthenticationSchemes("Default"); }); }); + services.AddAuthorizationPolicyEvaluator(); services.AddLogging(); + services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler))); var sp = services.BuildServiceProvider(); context.Request.Path = "/foo"; context.Request.Method = "GET"; @@ -849,9 +834,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests var qs = new QueryCollection(values); context.Request.Query = qs; context.Response.Body = new MemoryStream(); - var authFeature = new HttpAuthenticationFeature(); - authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false); - context.Features.Set(authFeature); var builder = new SocketBuilder(sp); builder.UseEndPoint(); @@ -868,48 +850,55 @@ namespace Microsoft.AspNetCore.Sockets.Tests Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode); } + private class RejectHandler : TestAuthenticationHandler + { + protected override bool ShouldAccept => false; + } + private class TestAuthenticationHandler : IAuthenticationHandler { - private readonly HttpContext HttpContext; - private readonly bool _acceptScheme; + private HttpContext HttpContext; + private AuthenticationScheme _scheme; - public TestAuthenticationHandler(HttpContext context, bool acceptScheme = true) - { - HttpContext = context; - _acceptScheme = acceptScheme; - } + protected virtual bool ShouldAccept { get => true; } - public Task AuthenticateAsync(AuthenticateContext context) + public Task AuthenticateAsync() { - if (_acceptScheme) + if (ShouldAccept) { - context.Authenticated(HttpContext.User, context.Properties, context.Description); + return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(HttpContext.User, _scheme.Name))); } else { - context.NotAuthenticated(); + return Task.FromResult(AuthenticateResult.None()); } - return Task.CompletedTask; } - public Task ChallengeAsync(ChallengeContext context) + public Task ChallengeAsync(AuthenticationProperties properties) { HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized; - context.Accept(); return Task.CompletedTask; } - public void GetDescriptions(DescribeSchemesContext context) + public Task ForbidAsync(AuthenticationProperties properties) + { + HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden; + return Task.CompletedTask; + } + + public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context) + { + HttpContext = context; + _scheme = scheme; + return Task.CompletedTask; + } + + public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties) { throw new NotImplementedException(); } - public Task SignInAsync(SignInContext context) - { - throw new NotImplementedException(); - } - - public Task SignOutAsync(SignOutContext context) + public Task SignOutAsync(AuthenticationProperties properties) { throw new NotImplementedException(); }