React to Auth + switch to PolicyEvaluator

This commit is contained in:
Hao Kung 2017-05-25 18:22:51 -07:00
parent 3f36fa5986
commit 42739b064f
4 changed files with 116 additions and 124 deletions

View File

@ -2,12 +2,12 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Authorization.Policy;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.Sockets.Internal
{
@ -29,45 +29,47 @@ namespace Microsoft.AspNetCore.Sockets.Internal
}
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData);
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
{
ClaimsPrincipal newPrincipal = null;
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
var result = await context.Authentication.AuthenticateAsync(scheme);
if (result != null)
{
newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result);
}
}
if (newPrincipal == null)
{
newPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
}
var policyEvaluator = context.RequestServices.GetRequiredService<IPolicyEvaluator>();
context.User = newPrincipal;
}
// This will set context.User if required
var authenticateResult = await policyEvaluator.AuthenticateAsync(authorizePolicy, context);
var authService = context.RequestServices.GetRequiredService<IAuthorizationService>();
if (await authService.AuthorizeAsync(context.User, context, authorizePolicy))
var authorizeResult = await policyEvaluator.AuthorizeAsync(authorizePolicy, authenticateResult, context);
if (authorizeResult.Succeeded)
{
return true;
}
// Challenge
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
else if (authorizeResult.Challenged)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
if (authorizePolicy.AuthenticationSchemes.Count > 0)
{
await context.Authentication.ChallengeAsync(scheme, properties: null);
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.ChallengeAsync(scheme);
}
}
else
{
await context.ChallengeAsync();
}
return false;
}
else
else if (authorizeResult.Forbidden)
{
await context.Authentication.ChallengeAsync(properties: null);
if (authorizePolicy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.ForbidAsync(scheme);
}
}
else
{
await context.ForbidAsync();
}
return false;
}
return false;
}
}

View File

@ -15,7 +15,7 @@
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj" />
<PackageReference Include="Microsoft.AspNetCore.Authorization" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />

View File

@ -11,6 +11,7 @@ namespace Microsoft.Extensions.DependencyInjection
public static IServiceCollection AddSockets(this IServiceCollection services)
{
services.AddRouting();
services.AddAuthorizationPolicyEvaluator();
services.TryAddSingleton<HttpConnectionDispatcher>();
return services.AddSocketsCore();
}

View File

@ -9,7 +9,7 @@ using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features.Authentication;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
@ -638,10 +638,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
services.AddAuthorizationPolicyEvaluator();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
@ -651,9 +653,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
@ -668,7 +667,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
[Fact]
public async Task AuthorizedConnectionCanConnectToEndPoint()
public async Task AuthenticatedUserWithoutPermissionCausesForbidden()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
@ -677,13 +676,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
services.AddAuthorizationPolicyEvaluator();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
@ -693,10 +691,50 @@ namespace Microsoft.AspNetCore.Sockets.Tests
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated"));
// would hang if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode);
}
[Fact]
public async Task AuthorizedConnectionCanConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
services.AddAuthorizationPolicyEvaluator();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
});
services.AddLogging();
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
@ -716,61 +754,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
}
[Fact]
public async Task AllPoliciesRequiredForAuthorizedEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<TestEndPoint>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
});
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
var app = builder.Build();
var options = new HttpSocketOptions();
options.AuthorizationPolicyNames.Add("test");
options.AuthorizationPolicyNames.Add("secondPolicy");
// partialy "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would hang if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// fully "authorize" user
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
}
[Fact]
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
{
@ -789,7 +773,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
policy.AddAuthenticationSchemes("Default");
});
});
services.AddAuthorizationPolicyEvaluator();
services.AddLogging();
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
@ -799,9 +785,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
@ -839,7 +822,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
policy.AddAuthenticationSchemes("Default");
});
});
services.AddAuthorizationPolicyEvaluator();
services.AddLogging();
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler)));
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
@ -849,9 +834,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
var builder = new SocketBuilder(sp);
builder.UseEndPoint<TestEndPoint>();
@ -868,48 +850,55 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
private class RejectHandler : TestAuthenticationHandler
{
protected override bool ShouldAccept => false;
}
private class TestAuthenticationHandler : IAuthenticationHandler
{
private readonly HttpContext HttpContext;
private readonly bool _acceptScheme;
private HttpContext HttpContext;
private AuthenticationScheme _scheme;
public TestAuthenticationHandler(HttpContext context, bool acceptScheme = true)
{
HttpContext = context;
_acceptScheme = acceptScheme;
}
protected virtual bool ShouldAccept { get => true; }
public Task AuthenticateAsync(AuthenticateContext context)
public Task<AuthenticateResult> AuthenticateAsync()
{
if (_acceptScheme)
if (ShouldAccept)
{
context.Authenticated(HttpContext.User, context.Properties, context.Description);
return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(HttpContext.User, _scheme.Name)));
}
else
{
context.NotAuthenticated();
return Task.FromResult(AuthenticateResult.None());
}
return Task.CompletedTask;
}
public Task ChallengeAsync(ChallengeContext context)
public Task ChallengeAsync(AuthenticationProperties properties)
{
HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
context.Accept();
return Task.CompletedTask;
}
public void GetDescriptions(DescribeSchemesContext context)
public Task ForbidAsync(AuthenticationProperties properties)
{
HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden;
return Task.CompletedTask;
}
public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context)
{
HttpContext = context;
_scheme = scheme;
return Task.CompletedTask;
}
public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties)
{
throw new NotImplementedException();
}
public Task SignInAsync(SignInContext context)
{
throw new NotImplementedException();
}
public Task SignOutAsync(SignOutContext context)
public Task SignOutAsync(AuthenticationProperties properties)
{
throw new NotImplementedException();
}