React to Auth + switch to PolicyEvaluator
This commit is contained in:
parent
3f36fa5986
commit
42739b064f
|
|
@ -2,12 +2,12 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Security.Claims;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Authentication;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using Microsoft.AspNetCore.Authorization.Policy;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets.Internal
|
||||
{
|
||||
|
|
@ -29,45 +29,47 @@ namespace Microsoft.AspNetCore.Sockets.Internal
|
|||
}
|
||||
|
||||
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData);
|
||||
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
ClaimsPrincipal newPrincipal = null;
|
||||
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
|
||||
{
|
||||
var result = await context.Authentication.AuthenticateAsync(scheme);
|
||||
if (result != null)
|
||||
{
|
||||
newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result);
|
||||
}
|
||||
}
|
||||
|
||||
if (newPrincipal == null)
|
||||
{
|
||||
newPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
|
||||
}
|
||||
var policyEvaluator = context.RequestServices.GetRequiredService<IPolicyEvaluator>();
|
||||
|
||||
context.User = newPrincipal;
|
||||
}
|
||||
// This will set context.User if required
|
||||
var authenticateResult = await policyEvaluator.AuthenticateAsync(authorizePolicy, context);
|
||||
|
||||
var authService = context.RequestServices.GetRequiredService<IAuthorizationService>();
|
||||
if (await authService.AuthorizeAsync(context.User, context, authorizePolicy))
|
||||
var authorizeResult = await policyEvaluator.AuthorizeAsync(authorizePolicy, authenticateResult, context);
|
||||
if (authorizeResult.Succeeded)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Challenge
|
||||
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
|
||||
else if (authorizeResult.Challenged)
|
||||
{
|
||||
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
|
||||
if (authorizePolicy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
await context.Authentication.ChallengeAsync(scheme, properties: null);
|
||||
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
|
||||
{
|
||||
await context.ChallengeAsync(scheme);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
await context.ChallengeAsync();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
else
|
||||
else if (authorizeResult.Forbidden)
|
||||
{
|
||||
await context.Authentication.ChallengeAsync(properties: null);
|
||||
if (authorizePolicy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
|
||||
{
|
||||
await context.ForbidAsync(scheme);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
await context.ForbidAsync();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@
|
|||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets\Microsoft.AspNetCore.Sockets.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
|
||||
<ProjectReference Include="..\Microsoft.AspNetCore.WebSockets.Internal\Microsoft.AspNetCore.WebSockets.Internal.csproj" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Authorization" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(AspNetCoreVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" Version="$(AspNetCoreVersion)" PrivateAssets="All" />
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ namespace Microsoft.Extensions.DependencyInjection
|
|||
public static IServiceCollection AddSockets(this IServiceCollection services)
|
||||
{
|
||||
services.AddRouting();
|
||||
services.AddAuthorizationPolicyEvaluator();
|
||||
services.TryAddSingleton<HttpConnectionDispatcher>();
|
||||
return services.AddSocketsCore();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ using System.Text;
|
|||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Features.Authentication;
|
||||
using Microsoft.AspNetCore.Authentication;
|
||||
using Microsoft.AspNetCore.Http.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
|
|
@ -638,10 +638,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorizationPolicyEvaluator();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
});
|
||||
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
|
||||
services.AddLogging();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
|
|
@ -651,9 +653,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
|
|
@ -668,7 +667,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionCanConnectToEndPoint()
|
||||
public async Task AuthenticatedUserWithoutPermissionCausesForbidden()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
|
|
@ -677,13 +676,12 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorizationPolicyEvaluator();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
o.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
});
|
||||
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
});
|
||||
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
|
||||
services.AddLogging();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
|
|
@ -693,10 +691,50 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated"));
|
||||
|
||||
// would hang if EndPoint was running
|
||||
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionCanConnectToEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorizationPolicyEvaluator();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
o.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
|
|
@ -716,61 +754,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AllPoliciesRequiredForAuthorizedEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<TestEndPoint>();
|
||||
services.AddAuthorization(o =>
|
||||
{
|
||||
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
o.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
|
||||
});
|
||||
services.AddLogging();
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
context.RequestServices = sp;
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
var app = builder.Build();
|
||||
var options = new HttpSocketOptions();
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
options.AuthorizationPolicyNames.Add("secondPolicy");
|
||||
|
||||
// partialy "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
// would hang if EndPoint was running
|
||||
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
|
||||
// fully "authorize" user
|
||||
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync(context, options, app);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
||||
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
|
||||
}
|
||||
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
|
||||
{
|
||||
|
|
@ -789,7 +773,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
services.AddAuthorizationPolicyEvaluator();
|
||||
services.AddLogging();
|
||||
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler)));
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
|
|
@ -799,9 +785,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
|
|
@ -839,7 +822,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
services.AddAuthorizationPolicyEvaluator();
|
||||
services.AddLogging();
|
||||
services.AddAuthenticationCore(o => o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler)));
|
||||
var sp = services.BuildServiceProvider();
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
|
|
@ -849,9 +834,6 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
var builder = new SocketBuilder(sp);
|
||||
builder.UseEndPoint<TestEndPoint>();
|
||||
|
|
@ -868,48 +850,55 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
}
|
||||
|
||||
private class RejectHandler : TestAuthenticationHandler
|
||||
{
|
||||
protected override bool ShouldAccept => false;
|
||||
}
|
||||
|
||||
private class TestAuthenticationHandler : IAuthenticationHandler
|
||||
{
|
||||
private readonly HttpContext HttpContext;
|
||||
private readonly bool _acceptScheme;
|
||||
private HttpContext HttpContext;
|
||||
private AuthenticationScheme _scheme;
|
||||
|
||||
public TestAuthenticationHandler(HttpContext context, bool acceptScheme = true)
|
||||
{
|
||||
HttpContext = context;
|
||||
_acceptScheme = acceptScheme;
|
||||
}
|
||||
protected virtual bool ShouldAccept { get => true; }
|
||||
|
||||
public Task AuthenticateAsync(AuthenticateContext context)
|
||||
public Task<AuthenticateResult> AuthenticateAsync()
|
||||
{
|
||||
if (_acceptScheme)
|
||||
if (ShouldAccept)
|
||||
{
|
||||
context.Authenticated(HttpContext.User, context.Properties, context.Description);
|
||||
return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(HttpContext.User, _scheme.Name)));
|
||||
}
|
||||
else
|
||||
{
|
||||
context.NotAuthenticated();
|
||||
return Task.FromResult(AuthenticateResult.None());
|
||||
}
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task ChallengeAsync(ChallengeContext context)
|
||||
public Task ChallengeAsync(AuthenticationProperties properties)
|
||||
{
|
||||
HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
|
||||
context.Accept();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public void GetDescriptions(DescribeSchemesContext context)
|
||||
public Task ForbidAsync(AuthenticationProperties properties)
|
||||
{
|
||||
HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context)
|
||||
{
|
||||
HttpContext = context;
|
||||
_scheme = scheme;
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task SignInAsync(ClaimsPrincipal user, AuthenticationProperties properties)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Task SignInAsync(SignInContext context)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Task SignOutAsync(SignOutContext context)
|
||||
public Task SignOutAsync(AuthenticationProperties properties)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue