Use policy names in EndPointOptions (#340)

This commit is contained in:
BrennanConroy 2017-04-11 12:35:31 -07:00 committed by GitHub
parent be88d2918e
commit 9993fd96da
5 changed files with 325 additions and 48 deletions

View File

@ -44,7 +44,6 @@
<environment names="Development">
<script src="~/lib/jquery/dist/jquery.js"></script>
<script src="~/lib/bootstrap/dist/js/bootstrap.js"></script>
<script src="~/js/site.js" asp-append-version="true"></script>
</environment>
<environment names="Staging,Production">
<script src="https://ajax.aspnetcdn.com/ajax/jquery/jquery-2.2.0.min.js"

View File

@ -1,14 +1,13 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Authorization;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class EndPointOptions<TEndPoint> where TEndPoint : EndPoint
{
public AuthorizationPolicy Policy { get; set; }
public IList<string> AuthorizationPolicyNames { get; } = new List<string>();
public TransportType Transports { get; set; } = TransportType.All;

View File

@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Sockets
{
var options = context.RequestServices.GetRequiredService<IOptions<EndPointOptions<TEndPoint>>>().Value;
// TODO: Authorize attribute on EndPoint
if (!await AuthorizeHelper.AuthorizeAsync(context, options.Policy))
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationPolicyNames))
{
return;
}

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
@ -12,53 +13,62 @@ namespace Microsoft.AspNetCore.Sockets.Internal
{
public static class AuthorizeHelper
{
public static async Task<bool> AuthorizeAsync(HttpContext context, AuthorizationPolicy policy)
public static async Task<bool> AuthorizeAsync(HttpContext context, IList<string> policies)
{
if (policy != null)
if (policies.Count == 0)
{
if (policy.AuthenticationSchemes != null && policy.AuthenticationSchemes.Count > 0)
{
ClaimsPrincipal newPrincipal = null;
foreach (var scheme in policy.AuthenticationSchemes)
{
var result = await context.Authentication.AuthenticateAsync(scheme);
if (result != null)
{
newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result);
}
}
if (newPrincipal == null)
{
newPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
}
context.User = newPrincipal;
}
var authService = context.RequestServices.GetRequiredService<IAuthorizationService>();
if (await authService.AuthorizeAsync(context.User, context, policy))
{
return true;
}
// Challenge
if (policy.AuthenticationSchemes != null && policy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in policy.AuthenticationSchemes)
{
await context.Authentication.ChallengeAsync(scheme, properties: null);
}
}
else
{
await context.Authentication.ChallengeAsync(properties: null);
}
return false;
return true;
}
return true;
var policyProvider = context.RequestServices.GetRequiredService<IAuthorizationPolicyProvider>();
var authorizeData = new List<IAuthorizeData>();
foreach (var policy in policies)
{
authorizeData.Add(new AuthorizeAttribute(policy));
}
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, authorizeData);
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
{
ClaimsPrincipal newPrincipal = null;
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
var result = await context.Authentication.AuthenticateAsync(scheme);
if (result != null)
{
newPrincipal = SecurityHelper.MergeUserPrincipal(newPrincipal, result);
}
}
if (newPrincipal == null)
{
newPrincipal = new ClaimsPrincipal(new ClaimsIdentity());
}
context.User = newPrincipal;
}
var authService = context.RequestServices.GetRequiredService<IAuthorizationService>();
if (await authService.AuthorizeAsync(context.User, context, authorizePolicy))
{
return true;
}
// Challenge
if (authorizePolicy.AuthenticationSchemes != null && authorizePolicy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.Authentication.ChallengeAsync(scheme, properties: null);
}
}
else
{
await context.Authentication.ChallengeAsync(properties: null);
}
return false;
}
}
}

View File

@ -4,10 +4,12 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Security.Claims;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features.Authentication;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
@ -452,6 +454,273 @@ namespace Microsoft.AspNetCore.Sockets.Tests
Assert.Equal(MessageType.Close, messages[3].Type);
}
[Fact]
public async Task UnauthorizedConnectionFailsToStartEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
[Fact]
public async Task AuthorizedConnectionCanConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
}
[Fact]
public async Task AllPoliciesRequiredForAuthorizedEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
options.AuthorizationPolicyNames.Add("secondPolicy");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
options.AddPolicy("secondPolicy", policy => policy.RequireClaim(ClaimTypes.StreetAddress));
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// partialy "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would hang if EndPoint was running
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// fully "authorize" user
context.User.AddIdentity(new ClaimsIdentity(new[] { new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") }));
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
}
[Fact]
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var endPointTask = dispatcher.ExecuteAsync<BlockingEndPoint>("", context);
await state.Connection.Transport.Output.WriteAsync(new Message(Encoding.UTF8.GetBytes("Hello, World"), MessageType.Text)).OrTimeout();
await endPointTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("T12:T:Hello, World;", GetContentAsString(context.Response.Body));
}
[Fact]
public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint()
{
var manager = CreateConnectionManager();
var state = manager.CreateConnection();
var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory());
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddEndPoint<BlockingEndPoint>(options =>
{
options.AuthorizationPolicyNames.Add("test");
});
services.AddAuthorization(options =>
{
options.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
context.RequestServices = services.BuildServiceProvider();
context.Request.Path = "/poll";
var values = new Dictionary<string, StringValues>();
values["id"] = state.Connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var authFeature = new HttpAuthenticationFeature();
authFeature.Handler = new TestAuthenticationHandler(context, acceptScheme: false);
context.Features.Set<IHttpAuthenticationFeature>(authFeature);
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would block if EndPoint was executed
await dispatcher.ExecuteAsync<BlockingEndPoint>("", context).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
private class TestAuthenticationHandler : IAuthenticationHandler
{
private readonly HttpContext HttpContext;
private readonly bool _acceptScheme;
public TestAuthenticationHandler(HttpContext context, bool acceptScheme = true)
{
HttpContext = context;
_acceptScheme = acceptScheme;
}
public Task AuthenticateAsync(AuthenticateContext context)
{
if (_acceptScheme)
{
context.Authenticated(HttpContext.User, context.Properties, context.Description);
}
else
{
context.NotAuthenticated();
}
return Task.CompletedTask;
}
public Task ChallengeAsync(ChallengeContext context)
{
HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
context.Accept();
return Task.CompletedTask;
}
public void GetDescriptions(DescribeSchemesContext context)
{
throw new NotImplementedException();
}
public Task SignInAsync(SignInContext context)
{
throw new NotImplementedException();
}
public Task SignOutAsync(SignOutContext context)
{
throw new NotImplementedException();
}
}
private static async Task CheckTransportSupported(TransportType supportedTransports, TransportType transportType, int status)
{
var path = "";