Use policy names in EndPointOptions (#340)
This commit is contained in:
parent
be88d2918e
commit
9993fd96da
|
|
@ -44,7 +44,6 @@
|
|||
<environment names="Development">
|
||||
<script src="~/lib/jquery/dist/jquery.js"></script>
|
||||
<script src="~/lib/bootstrap/dist/js/bootstrap.js"></script>
|
||||
<script src="~/js/site.js" asp-append-version="true"></script>
|
||||
</environment>
|
||||
<environment names="Staging,Production">
|
||||
<script src="https://ajax.aspnetcdn.com/ajax/jquery/jquery-2.2.0.min.js"
|
||||
|
|
|
|||
|
|
@ -1,14 +1,13 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class EndPointOptions<TEndPoint> where TEndPoint : EndPoint
|
||||
{
|
||||
public AuthorizationPolicy Policy { get; set; }
|
||||
public IList<string> AuthorizationPolicyNames { get; } = new List<string>();
|
||||
|
||||
public TransportType Transports { get; set; } = TransportType.All;
|
||||
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
|
||||
// TODO: Authorize attribute on EndPoint
|
||||
if (!await AuthorizeHelper.AuthorizeAsync(context, options.Policy))
|
||||
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames))
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System.Collections.Generic;
|
||||
using System.Security.Claims;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Authorization;
|
||||
|
|
@ -12,53 +13,62 @@ namespace Microsoft.AspNetCore.Sockets.Internal
|
|||
{
|
||||
public static class AuthorizeHelper
|
||||
{
|
||||
public static async Task<bool> AuthorizeAsync(HttpContext context, AuthorizationPolicy policy)
|
||||
public static async Task<bool> AuthorizeAsync(HttpContext context, IList<string> policies)
|
||||
{
|
||||
if (policy != null)
|
||||
if (policies.Count == 0)
|
||||
{
|
||||
if (policy.AuthenticationSchemes != null && policy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
ClaimsPrincipal newPrincipal = null;
|
||||
foreach (var scheme in policy.AuthenticationSchemes)
|
||||
{
|
||||
var result = await context.Authentication.AuthenticateAsync(scheme);
|
||||
if (result != null)
|
||||
{
|
||||
newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result);
|
||||
}
|
||||
}
|
||||
|
||||
if (newPrincipal == null)
|
||||
{
|
||||
newPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
|
||||
}
|
||||
|
||||
context.User = newPrincipal;
|
||||
}
|
||||
|
||||
var authService = context.RequestServices.GetRequiredService<IAuthorizationService>();
|
||||
if (await authService.AuthorizeAsync(context.User, context, policy))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Challenge
|
||||
if (policy.AuthenticationSchemes != null && policy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
foreach (var scheme in policy.AuthenticationSchemes)
|
||||
{
|
||||
await context.Authentication.ChallengeAsync(scheme, properties: null);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
await context.Authentication.ChallengeAsync(properties: null);
|
||||
}
|
||||
|
||||
return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
return true;
|
||||
var policyProvider = context.RequestServices.GetRequiredService<IAuthorizationPolicyProvider>();
|
||||
var authorizeData = new List<IAuthorizeData>();
|
||||
|
||||
foreach (var policy in policies)
|
||||
{
|
||||
authorizeData.Add(new AuthorizeAttribute(policy));
|
||||
}
|
||||
|
||||
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData);
|
||||
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
ClaimsPrincipal newPrincipal = null;
|
||||
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
|
||||
{
|
||||
var result = await context.Authentication.AuthenticateAsync(scheme);
|
||||
if (result != null)
|
||||
{
|
||||
newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result);
|
||||
}
|
||||
}
|
||||
|
||||
if (newPrincipal == null)
|
||||
{
|
||||
newPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
|
||||
}
|
||||
|
||||
context.User = newPrincipal;
|
||||
}
|
||||
|
||||
var authService = context.RequestServices.GetRequiredService<IAuthorizationService>();
|
||||
if (await authService.AuthorizeAsync(context.User, context, authorizePolicy))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// Challenge
|
||||
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
|
||||
{
|
||||
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
|
||||
{
|
||||
await context.Authentication.ChallengeAsync(scheme, properties: null);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
await context.Authentication.ChallengeAsync(properties: null);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Security.Claims;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Features.Authentication;
|
||||
using Microsoft.AspNetCore.Http.Internal;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
|
|
@ -452,6 +454,273 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
Assert.Equal(MessageType.Close, messages[3].Type);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task UnauthorizedConnectionFailsToStartEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<BlockingEndPoint>(options =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/poll";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
// would hang if EndPoint was running
|
||||
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionCanConnectToEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<BlockingEndPoint>(options =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/poll";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
// "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AllPoliciesRequiredForAuthorizedEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<BlockingEndPoint>(options =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
options.AuthorizationPolicyNames.Add("secondPolicy");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
|
||||
options.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/poll";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
// partialy "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
// would hang if EndPoint was running
|
||||
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
|
||||
// fully "authorize" user
|
||||
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
||||
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<BlockingEndPoint>(options =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/poll";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
// "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
|
||||
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint()
|
||||
{
|
||||
var manager = CreateConnectionManager();
|
||||
var state = manager.CreateConnection();
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
|
||||
var context = new DefaultHttpContext();
|
||||
var services = new ServiceCollection();
|
||||
services.AddOptions();
|
||||
services.AddEndPoint<BlockingEndPoint>(options =>
|
||||
{
|
||||
options.AuthorizationPolicyNames.Add("test");
|
||||
});
|
||||
services.AddAuthorization(options =>
|
||||
{
|
||||
options.AddPolicy("test", policy =>
|
||||
{
|
||||
policy.RequireClaim(ClaimTypes.NameIdentifier);
|
||||
policy.AddAuthenticationSchemes("Default");
|
||||
});
|
||||
});
|
||||
services.AddLogging();
|
||||
|
||||
context.RequestServices = services.BuildServiceProvider();
|
||||
context.Request.Path = "/poll";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = state.Connection.ConnectionId;
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Response.Body = new MemoryStream();
|
||||
var authFeature = new HttpAuthenticationFeature();
|
||||
authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false);
|
||||
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
|
||||
|
||||
// "authorize" user
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
|
||||
|
||||
// would block if EndPoint was executed
|
||||
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
|
||||
|
||||
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
|
||||
}
|
||||
|
||||
private class TestAuthenticationHandler : IAuthenticationHandler
|
||||
{
|
||||
private readonly HttpContext HttpContext;
|
||||
private readonly bool _acceptScheme;
|
||||
|
||||
public TestAuthenticationHandler(HttpContext context, bool acceptScheme = true)
|
||||
{
|
||||
HttpContext = context;
|
||||
_acceptScheme = acceptScheme;
|
||||
}
|
||||
|
||||
public Task AuthenticateAsync(AuthenticateContext context)
|
||||
{
|
||||
if (_acceptScheme)
|
||||
{
|
||||
context.Authenticated(HttpContext.User, context.Properties, context.Description);
|
||||
}
|
||||
else
|
||||
{
|
||||
context.NotAuthenticated();
|
||||
}
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public Task ChallengeAsync(ChallengeContext context)
|
||||
{
|
||||
HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
|
||||
context.Accept();
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public void GetDescriptions(DescribeSchemesContext context)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Task SignInAsync(SignInContext context)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public Task SignOutAsync(SignOutContext context)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
}
|
||||
|
||||
private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status)
|
||||
{
|
||||
var path = "";
|
||||
|
|
|
|||
Loading…
Reference in New Issue