AuthorizeHelper will no-op if endpoint routing is used (#10471)

This commit is contained in:
Brennan 2019-06-02 22:35:45 -07:00 committed by GitHub
parent 25672336f9
commit f0df10f211
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 761 additions and 598 deletions

View File

@ -9,10 +9,12 @@ namespace Microsoft.AspNetCore.Analyzers.TestFiles.CompilationFeatureDetectorTes
{
public void Configure(IApplicationBuilder app)
{
#pragma warning disable CS0618 // Type or member is obsolete
app.UseSignalR(routes =>
{
});
#pragma warning restore CS0618 // Type or member is obsolete
}
}
}

View File

@ -1,18 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.IdentityModel.Tokens;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
@ -33,11 +25,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public void Configure(IApplicationBuilder app)
{
app.UseRouting();
app.UseAuthentication();
app.UseSignalR(routes =>
app.UseEndpoints(endpoints =>
{
routes.MapHub<VersionHub>("/version");
endpoints.MapHub<VersionHub>("/version");
});
}
}

View File

@ -12,6 +12,7 @@ namespace Microsoft.AspNetCore.Builder
}
public static partial class ConnectionsAppBuilderExtensions
{
[System.ObsoleteAttribute("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapConnections or MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseConnections(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action<Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder> configure) { throw null; }
}
}
@ -28,6 +29,7 @@ namespace Microsoft.AspNetCore.Http.Connections
public ConnectionOptionsSetup() { }
public void Configure(Microsoft.AspNetCore.Http.Connections.ConnectionOptions options) { }
}
[System.ObsoleteAttribute("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapConnection and MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public partial class ConnectionsRouteBuilder
{
internal ConnectionsRouteBuilder() { }

View File

@ -3,8 +3,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
@ -48,13 +46,6 @@ namespace Microsoft.AspNetCore.Builder
public static IEndpointConventionBuilder MapConnectionHandler<TConnectionHandler>(this IEndpointRouteBuilder endpoints, string pattern, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler
{
var options = new HttpConnectionDispatcherOptions();
// REVIEW: WE should consider removing this and instead just relying on the
// AuthorizationMiddleware
var attributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true);
foreach (var attribute in attributes.OfType<AuthorizeAttribute>())
{
options.AuthorizationData.Add(attribute);
}
configureOptions?.Invoke(options);
var conventionBuilder = endpoints.MapConnections(pattern, options, b =>
@ -62,6 +53,7 @@ namespace Microsoft.AspNetCore.Builder
b.UseConnectionHandler<TConnectionHandler>();
});
var attributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true);
conventionBuilder.Add(e =>
{
// Add all attributes on the ConnectionHandler has metadata (this will allow for things like)
@ -93,7 +85,7 @@ namespace Microsoft.AspNetCore.Builder
var connectionDelegate = connectionBuilder.Build();
// REVIEW: Consider expanding the internals of the dispatcher as endpoint routes instead of
// using if statemants we can let the matcher handle
// using if statements we can let the matcher handle
var conventionBuilders = new List<IEndpointConventionBuilder>();

View File

@ -3,9 +3,6 @@
using System;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Builder
{
@ -16,10 +13,15 @@ namespace Microsoft.AspNetCore.Builder
{
/// <summary>
/// Adds support for ASP.NET Core Connection Handlers to the <see cref="IApplicationBuilder"/> request execution pipeline.
/// <para>
/// This method is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapConnections or MapConnectionHandler&#60;TConnectionHandler&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary>
/// <param name="app">The <see cref="IApplicationBuilder"/>.</param>
/// <param name="configure">A callback to configure connection routes.</param>
/// <returns>The same instance of the <see cref="IApplicationBuilder"/> for chaining.</returns>
[Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapConnections or MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static IApplicationBuilder UseConnections(this IApplicationBuilder app, Action<ConnectionsRouteBuilder> configure)
{
if (configure == null)
@ -27,14 +29,13 @@ namespace Microsoft.AspNetCore.Builder
throw new ArgumentNullException(nameof(configure));
}
var dispatcher = app.ApplicationServices.GetRequiredService<HttpConnectionDispatcher>();
var routes = new RouteBuilder(app);
configure(new ConnectionsRouteBuilder(routes, dispatcher));
app.UseWebSockets();
app.UseRouter(routes.Build());
app.UseRouting();
app.UseAuthorization();
app.UseEndpoints(endpoints =>
{
configure(new ConnectionsRouteBuilder(endpoints));
});
return app;
}
}

View File

@ -2,26 +2,27 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Reflection;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Http.Connections
{
/// <summary>
/// Maps routes to ASP.NET Core Connection Handlers.
/// <para>
/// This class is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapConnection and MapConnectionHandler&#60;TConnectionHandler&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary>
[Obsolete("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapConnection and MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public class ConnectionsRouteBuilder
{
private readonly HttpConnectionDispatcher _dispatcher;
private readonly RouteBuilder _routes;
private readonly IEndpointRouteBuilder _endpoints;
internal ConnectionsRouteBuilder(RouteBuilder routes, HttpConnectionDispatcher dispatcher)
internal ConnectionsRouteBuilder(IEndpointRouteBuilder endpoints)
{
_routes = routes;
_dispatcher = dispatcher;
_endpoints = endpoints;
}
/// <summary>
@ -38,24 +39,16 @@ namespace Microsoft.AspNetCore.Http.Connections
/// <param name="path">The request path.</param>
/// <param name="options">Options used to configure the connection.</param>
/// <param name="configure">A callback to configure the connection.</param>
public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action<IConnectionBuilder> configure)
{
var connectionBuilder = new ConnectionBuilder(_routes.ServiceProvider);
configure(connectionBuilder);
var socket = connectionBuilder.Build();
_routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket));
_routes.MapRoute(path + "/negotiate", c => _dispatcher.ExecuteNegotiateAsync(c, options));
}
public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action<IConnectionBuilder> configure) =>
_endpoints.MapConnections(path, options, configure);
/// <summary>
/// Maps incoming requests with the specified path to the provided connection pipeline.
/// </summary>
/// <typeparam name="TConnectionHandler">The <see cref="ConnectionHandler"/> type.</typeparam>
/// <param name="path">The request path.</param>
public void MapConnectionHandler<TConnectionHandler>(PathString path) where TConnectionHandler : ConnectionHandler
{
public void MapConnectionHandler<TConnectionHandler>(PathString path) where TConnectionHandler : ConnectionHandler =>
MapConnectionHandler<TConnectionHandler>(path, configureOptions: null);
}
/// <summary>
/// Maps incoming requests with the specified path to the provided connection pipeline.
@ -63,20 +56,7 @@ namespace Microsoft.AspNetCore.Http.Connections
/// <typeparam name="TConnectionHandler">The <see cref="ConnectionHandler"/> type.</typeparam>
/// <param name="path">The request path.</param>
/// <param name="configureOptions">A callback to configure dispatcher options.</param>
public void MapConnectionHandler<TConnectionHandler>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler
{
var authorizeAttributes = typeof(TConnectionHandler).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
var options = new HttpConnectionDispatcherOptions();
foreach (var attribute in authorizeAttributes)
{
options.AuthorizationData.Add(attribute);
}
configureOptions?.Invoke(options);
MapConnections(path, options, builder =>
{
builder.UseConnectionHandler<TConnectionHandler>();
});
}
public void MapConnectionHandler<TConnectionHandler>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler =>
_endpoints.MapConnectionHandler<TConnectionHandler>(path, configureOptions);
}
}

View File

@ -1,69 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Authorization.Policy;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Http.Connections.Internal
{
internal static class AuthorizeHelper
{
public static async Task<bool> AuthorizeAsync(HttpContext context, IList<IAuthorizeData> policies)
{
if (policies.Count == 0)
{
return true;
}
var policyProvider = context.RequestServices.GetRequiredService<IAuthorizationPolicyProvider>();
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
var policyEvaluator = context.RequestServices.GetRequiredService<IPolicyEvaluator>();
// This will set context.User if required
var authenticateResult = await policyEvaluator.AuthenticateAsync(authorizePolicy, context);
var authorizeResult = await policyEvaluator.AuthorizeAsync(authorizePolicy, authenticateResult, context, resource: null);
if (authorizeResult.Succeeded)
{
return true;
}
else if (authorizeResult.Challenged)
{
if (authorizePolicy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.ChallengeAsync(scheme);
}
}
else
{
await context.ChallengeAsync();
}
return false;
}
else if (authorizeResult.Forbidden)
{
if (authorizePolicy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.ForbidAsync(scheme);
}
}
else
{
await context.ForbidAsync();
}
return false;
}
return false;
}
}
}

View File

@ -61,11 +61,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
var logScope = new ConnectionLogScope(GetConnectionId(context));
using (_logger.BeginScope(logScope))
{
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationData))
{
return;
}
if (HttpMethods.IsPost(context.Request.Method))
{
// POST /{path}
@ -95,11 +90,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
var logScope = new ConnectionLogScope(connectionId: string.Empty);
using (_logger.BeginScope(logScope))
{
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationData))
{
return;
}
if (HttpMethods.IsPost(context.Request.Method))
{
// POST /{path}/negotiate

View File

@ -14,6 +14,7 @@
<Compile Include="$(SignalRSharedSourceRoot)WebSocketExtensions.cs" Link="WebSocketExtensions.cs" />
<Compile Include="$(SignalRSharedSourceRoot)StreamExtensions.cs" Link="StreamExtensions.cs" />
<Compile Include="$(SignalRSharedSourceRoot)DuplexPipe.cs" Link="DuplexPipe.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TaskCache.cs" Link="Internal\TaskCache.cs" />
</ItemGroup>
<ItemGroup>

View File

@ -1347,358 +1347,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
[Fact]
public async Task UnauthorizedConnectionFailsToStartEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// would get stuck if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
}
[Fact]
public async Task AuthenticatedUserWithoutPermissionCausesForbidden()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated"));
// would get stuck if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode);
}
}
[Fact]
public async Task AuthorizedConnectionCanConnectToEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
}
[Fact]
public async Task AllPoliciesRequiredForAuthorizedEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
o.AddPolicy("secondPolicy", policy =>
{
policy.RequireClaim(ClaimTypes.StreetAddress);
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
options.AuthorizationData.Add(new AuthorizeAttribute("secondPolicy"));
// partially "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would get stuck if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// reset HttpContext
context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
// fully "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[]
{
new Claim(ClaimTypes.NameIdentifier, "name"),
new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW")
}));
// First poll
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(connectionHandlerTask.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
}
[Fact]
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// Initial poll
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(connectionHandlerTask.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
}
[Fact]
public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would block if EndPoint was executed
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
}
[Fact]
public async Task SetsInherentKeepAliveFeatureOnFirstLongPollingRequest()
{
@ -2119,50 +1767,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
private class RejectHandler : TestAuthenticationHandler
{
protected override bool ShouldAccept => false;
}
private class TestAuthenticationHandler : IAuthenticationHandler
{
private HttpContext HttpContext;
private AuthenticationScheme _scheme;
protected virtual bool ShouldAccept => true;
public Task<AuthenticateResult> AuthenticateAsync()
{
if (ShouldAccept)
{
return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(HttpContext.User, _scheme.Name)));
}
else
{
return Task.FromResult(AuthenticateResult.NoResult());
}
}
public Task ChallengeAsync(AuthenticationProperties properties)
{
HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
return Task.CompletedTask;
}
public Task ForbidAsync(AuthenticationProperties properties)
{
HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden;
return Task.CompletedTask;
}
public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context)
{
HttpContext = context;
_scheme = scheme;
return Task.CompletedTask;
}
}
private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory)
{
var manager = CreateConnectionManager(loggerFactory);

View File

@ -4,11 +4,14 @@
using System;
using System.Linq;
using System.Net.WebSockets;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Cors;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Routing;
@ -34,39 +37,113 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
public void MapConnectionHandlerFindsAuthAttributeOnEndPoint()
{
var authCount = 0;
using (var builder = BuildWebHost<AuthConnectionHandler>("/auth",
using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
options => authCount += options.AuthorizationData.Count))
{
builder.Start();
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
});
}
Assert.Equal(1, authCount);
Assert.Equal(0, authCount);
}
[Fact]
public void MapConnectionHandlerFindsAuthAttributeOnInheritedEndPoint()
{
var authCount = 0;
using (var builder = BuildWebHost<InheritedAuthConnectionHandler>("/auth",
using (var host = BuildWebHost<InheritedAuthConnectionHandler>("/auth",
options => authCount += options.AuthorizationData.Count))
{
builder.Start();
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
});
}
Assert.Equal(1, authCount);
Assert.Equal(0, authCount);
}
[Fact]
public void MapConnectionHandlerFindsAuthAttributesOnDoubleAuthEndPoint()
{
var authCount = 0;
using (var builder = BuildWebHost<DoubleAuthConnectionHandler>("/auth",
using (var host = BuildWebHost<DoubleAuthConnectionHandler>("/auth",
options => authCount += options.AuthorizationData.Count))
{
builder.Start();
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(2, authCount);
Assert.Equal(0, authCount);
}
[Fact]
public void MapConnectionHandlerFindsAttributesFromEndPointAndOptions()
{
var authCount = 0;
using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
options =>
{
authCount += options.AuthorizationData.Count;
options.AuthorizationData.Add(new AuthorizeAttribute());
}))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(0, authCount);
}
[Fact]
@ -82,12 +159,50 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count);
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>());
Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>());
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
});
}
Assert.Equal(1, authCount);
Assert.Equal(0, authCount);
}
[Fact]
public void MapConnectionHandlerEndPointRoutingFindsAttributesFromOptions()
{
var authCount = 0;
using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapConnectionHandler<AuthConnectionHandler>("/path", options =>
{
authCount += options.AuthorizationData.Count;
options.AuthorizationData.Add(new AuthorizeAttribute());
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(0, authCount);
}
[Fact]
@ -106,9 +221,27 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count);
Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>()?.Policy);
Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>()?.Policy);
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
});
}
}
@ -126,9 +259,45 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count);
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<NegotiateMetadata>());
Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata<NegotiateMetadata>());
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.NotNull(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
});
}
}
[Fact]
public void MapConnectionHandlerEndPointRoutingAppliesCorsMetadata()
{
void ConfigureRoutes(IEndpointRouteBuilder endpoints)
{
endpoints.MapConnectionHandler<CorsConnectionHandler>("/path");
}
using (var host = BuildWebHostWithEndPointRouting(ConfigureRoutes))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.NotNull(endpoint.Metadata.GetMetadata<IEnableCorsAttribute>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.NotNull(endpoint.Metadata.GetMetadata<IEnableCorsAttribute>());
});
}
}
@ -177,6 +346,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
[EnableCors]
private class CorsConnectionHandler : ConnectionHandler
{
public override Task OnConnectedAsync(ConnectionContext connection)
{
throw new NotImplementedException();
}
}
private class InheritedAuthConnectionHandler : AuthConnectionHandler
{
public override Task OnConnectedAsync(ConnectionContext connection)
@ -227,10 +405,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
})
.Configure(app =>
{
#pragma warning disable CS0618 // Type or member is obsolete
app.UseConnections(routes =>
{
routes.MapConnectionHandler<TConnectionHandler>(path, configureOptions);
});
#pragma warning restore CS0618 // Type or member is obsolete
})
.ConfigureLogging(factory =>
{

View File

@ -15,6 +15,7 @@
<ItemGroup>
<Reference Include="Microsoft.AspNetCore.Authentication.Core" />
<Reference Include="Microsoft.AspNetCore.Cors" />
<Reference Include="Microsoft.AspNetCore.Http.Connections" />
<Reference Include="Microsoft.AspNetCore.Http" />
<Reference Include="Newtonsoft.Json" />

View File

@ -3,11 +3,11 @@
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal
namespace Microsoft.AspNetCore.Internal
{
internal static class TaskCache
{
public static readonly Task<bool> True = Task.FromResult(true);
public static readonly Task<bool> False = Task.FromResult(false);
}
}
}

View File

@ -11,6 +11,7 @@ using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;

View File

@ -12,6 +12,7 @@
<Compile Include="$(SharedSourceRoot)ClosedGenericMatcher\*.cs" />
<Compile Include="$(SharedSourceRoot)ObjectMethodExecutor\*.cs" />
<Compile Include="$(SignalRSharedSourceRoot)AsyncEnumerableAdapters.cs" Link="Internal\AsyncEnumerableAdapters.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TaskCache.cs" Link="Internal\TaskCache.cs" />
</ItemGroup>
<ItemGroup>

View File

@ -10,6 +10,7 @@ namespace Microsoft.AspNetCore.Builder
}
public static partial class SignalRAppBuilderExtensions
{
[System.ObsoleteAttribute("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseSignalR(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action<Microsoft.AspNetCore.SignalR.HubRouteBuilder> configure) { throw null; }
}
}
@ -25,6 +26,7 @@ namespace Microsoft.AspNetCore.SignalR
internal HubEndpointConventionBuilder() { }
public void Add(System.Action<Microsoft.AspNetCore.Builder.EndpointBuilder> convention) { }
}
[System.ObsoleteAttribute("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public partial class HubRouteBuilder
{
public HubRouteBuilder(Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder routes) { }

View File

@ -2,8 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.SignalR;
@ -44,14 +42,6 @@ namespace Microsoft.AspNetCore.Builder
}
var options = new HttpConnectionDispatcherOptions();
// REVIEW: WE should consider removing this and instead just relying on the
// AuthorizationMiddleware
var attributes = typeof(THub).GetCustomAttributes(inherit: true);
foreach (var attribute in attributes.OfType<AuthorizeAttribute>())
{
options.AuthorizationData.Add(attribute);
}
configureOptions?.Invoke(options);
var conventionBuilder = endpoints.MapConnections(pattern, options, b =>
@ -59,9 +49,10 @@ namespace Microsoft.AspNetCore.Builder
b.UseHub<THub>();
});
var attributes = typeof(THub).GetCustomAttributes(inherit: true);
conventionBuilder.Add(e =>
{
// Add all attributes on the Hub has metadata (this will allow for things like)
// Add all attributes on the Hub as metadata (this will allow for things like)
// auth attributes and cors attributes to work seamlessly
foreach (var item in attributes)
{

View File

@ -4,17 +4,25 @@
using System;
using System.Reflection;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.SignalR
{
/// <summary>
/// Maps incoming requests to <see cref="Hub"/> types.
/// <para>
/// This class is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapHub&#60;THub&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary>
[Obsolete("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public class HubRouteBuilder
{
private readonly ConnectionsRouteBuilder _routes;
private readonly IEndpointRouteBuilder _endpoints;
/// <summary>
/// Initializes a new instance of the <see cref="HubRouteBuilder"/> class.
@ -25,6 +33,11 @@ namespace Microsoft.AspNetCore.SignalR
_routes = routes;
}
internal HubRouteBuilder(IEndpointRouteBuilder endpoints)
{
_endpoints = endpoints;
}
/// <summary>
/// Maps incoming requests with the specified path to the specified <see cref="Hub"/> type.
/// </summary>
@ -43,6 +56,14 @@ namespace Microsoft.AspNetCore.SignalR
/// <param name="configureOptions">A callback to configure dispatcher options.</param>
public void MapHub<THub>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where THub : Hub
{
// This will be null if someone is manually using the HubRouteBuilder(ConnectionsRouteBuilder routes) constructor
// SignalR itself will only use the IEndpointRouteBuilder overload
if (_endpoints != null)
{
_endpoints.MapHub<THub>(path, configureOptions);
return;
}
// find auth attributes
var authorizeAttributes = typeof(THub).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
var options = new HttpConnectionDispatcherOptions();

View File

@ -14,10 +14,15 @@ namespace Microsoft.AspNetCore.Builder
{
/// <summary>
/// Adds SignalR to the <see cref="IApplicationBuilder"/> request execution pipeline.
/// <para>
/// This method is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapHub&#60;THub&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary>
/// <param name="app">The <see cref="IApplicationBuilder"/>.</param>
/// <param name="configure">A callback to configure hub routes.</param>
/// <returns>The same instance of the <see cref="IApplicationBuilder"/> for chaining.</returns>
[Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static IApplicationBuilder UseSignalR(this IApplicationBuilder app, Action<HubRouteBuilder> configure)
{
var marker = app.ApplicationServices.GetService<SignalRMarkerService>();
@ -27,9 +32,12 @@ namespace Microsoft.AspNetCore.Builder
"'IServiceCollection.AddSignalR' inside the call to 'ConfigureServices(...)' in the application startup code.");
}
app.UseConnections(routes =>
app.UseWebSockets();
app.UseRouting();
app.UseAuthorization();
app.UseEndpoints(endpoints =>
{
configure(new HubRouteBuilder(routes));
configure(new HubRouteBuilder(endpoints));
});
return app;

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Authorization;
namespace Microsoft.AspNetCore.SignalR.Tests
{
[Authorize]
class AuthHub : Hub
{
}
}

View File

@ -446,6 +446,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task AuthorizedConnectionCanConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
string token;
using (var client = new HttpClient())
{
client.BaseAddress = new Uri(server.Url);
var response = await client.GetAsync("generatetoken?user=bob");
token = await response.Content.ReadAsStringAsync();
}
var url = server.Url + "/auth";
var connection = new HttpConnection(new HttpConnectionOptions()
{
AccessTokenProvider = () => Task.FromResult(token),
Url = new Uri(url),
Transports = HttpTransportType.ServerSentEvents
}, LoggerFactory);
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync(TransferFormat.Text).OrTimeout();
logger.LogInformation("Connected to {url}", url);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[ConditionalFact]
[WebSocketsSupportedCondition]
public async Task ServerClosesConnectionWithErrorIfHubCannotBeCreated_WebSocket()
@ -532,6 +578,178 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UnauthorizedHubConnectionDoesNotConnectWithEndpoints()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
var url = server.Url + "/authHubEndpoints";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling)
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
Assert.True(false);
}
catch (Exception ex)
{
Assert.Equal("Response status code does not indicate success: 401 (Unauthorized).", ex.Message);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UnauthorizedHubConnectionDoesNotConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
var url = server.Url + "/authHub";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling)
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
Assert.True(false);
}
catch (Exception ex)
{
Assert.Equal("Response status code does not indicate success: 401 (Unauthorized).", ex.Message);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task AuthorizedHubConnectionCanConnectWithEndpoints()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
string token;
using (var client = new HttpClient())
{
client.BaseAddress = new Uri(server.Url);
var response = await client.GetAsync("generatetoken?user=bob");
token = await response.Content.ReadAsStringAsync();
}
var url = server.Url + "/authHubEndpoints";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling, o =>
{
o.AccessTokenProvider = () => Task.FromResult(token);
})
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
logger.LogInformation("Connected to {url}", url);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task AuthorizedHubConnectionCanConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
string token;
using (var client = new HttpClient())
{
client.BaseAddress = new Uri(server.Url);
var response = await client.GetAsync("generatetoken?user=bob");
token = await response.Content.ReadAsStringAsync();
}
var url = server.Url + "/authHub";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling, o =>
{
o.AccessTokenProvider = () => Task.FromResult(token);
})
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
logger.LogInformation("Connected to {url}", url);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
// Serves a fake transport that lets us verify fallback behavior
private class TestTransportFactory : ITransportFactory
{

View File

@ -1,7 +1,4 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
@ -42,10 +39,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var ex = Assert.Throws<InvalidOperationException>(() =>
{
#pragma warning disable CS0618 // Type or member is obsolete
app.UseSignalR(routes =>
{
routes.MapHub<AuthHub>("/overloads");
});
#pragma warning restore CS0618 // Type or member is obsolete
});
Assert.Equal("Unable to find the required services. Please add all the required services by calling " +
@ -109,9 +108,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(1, authCount);
Assert.Equal(0, authCount);
}
[Fact]
@ -124,9 +137,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(1, authCount);
Assert.Equal(0, authCount);
}
[Fact]
@ -139,9 +166,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(2, authCount);
Assert.Equal(0, authCount);
}
[Fact]
@ -157,12 +198,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count);
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>());
Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>());
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(1, authCount);
Assert.Equal(0, authCount);
}
[Fact]
public void MapHubEndPointRoutingFindsAttributesOnHubAndFromOptions()
{
var authCount = 0;
HttpConnectionDispatcherOptions configuredOptions = null;
using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapHub<AuthHub>("/path", options =>
{
authCount += options.AuthorizationData.Count;
options.AuthorizationData.Add(new AuthorizeAttribute());
configuredOptions = options;
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(0, authCount);
}
[Fact]
@ -181,9 +262,27 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count);
Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>()?.Policy);
Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>()?.Policy);
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
});
}
}
@ -202,11 +301,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count);
Assert.Equal(typeof(AuthHub), dataSource.Endpoints[0].Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.Equal(typeof(AuthHub), dataSource.Endpoints[1].Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<NegotiateMetadata>());
Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata<NegotiateMetadata>());
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.NotNull(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
});
}
}
[Fact]
public void MapHubAppliesHubMetadata()
{
#pragma warning disable CS0618 // Type or member is obsolete
void ConfigureRoutes(HubRouteBuilder routes)
#pragma warning restore CS0618 // Type or member is obsolete
{
// This "Foo" policy should override the default auth attribute
routes.MapHub<AuthHub>("/path");
}
using (var host = BuildWebHost(ConfigureRoutes))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.NotNull(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
});
}
}
@ -252,6 +392,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
.Build();
}
#pragma warning disable CS0618 // Type or member is obsolete
private IWebHost BuildWebHost(Action<HubRouteBuilder> configure)
{
return new WebHostBuilder()
@ -267,5 +408,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
.UseUrls("http://127.0.0.1:0")
.Build();
}
#pragma warning restore CS0618 // Type or member is obsolete
}
}

View File

@ -13,8 +13,8 @@
</ItemGroup>
<ItemGroup>
<Reference Include="Microsoft.AspNetCore.Authentication.Cookies" />
<Reference Include="Microsoft.AspNetCore.Authentication" />
<Reference Include="Microsoft.AspNetCore.Authentication.JwtBearer" />
<Reference Include="Microsoft.AspNetCore.Http.Abstractions" />
<Reference Include="Microsoft.AspNetCore.SignalR.Client" />
<Reference Include="Microsoft.AspNetCore.SignalR.Specification.Tests" />

View File

@ -1,14 +1,24 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Authentication.Cookies;
using System;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.IdentityModel.Tokens;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class Startup
{
private readonly SymmetricSecurityKey SecurityKey = new SymmetricSecurityKey(Guid.NewGuid().ToByteArray());
private readonly JwtSecurityTokenHandler JwtTokenHandler = new JwtSecurityTokenHandler();
public void ConfigureServices(IServiceCollection services)
{
services.AddConnections();
@ -19,11 +29,40 @@ namespace Microsoft.AspNetCore.SignalR.Tests
services.AddAuthentication(options =>
{
options.DefaultAuthenticateScheme = CookieAuthenticationDefaults.AuthenticationScheme;
options.DefaultChallengeScheme = CookieAuthenticationDefaults.AuthenticationScheme;
}).AddCookie();
options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme;
options.DefaultChallengeScheme = JwtBearerDefaults.AuthenticationScheme;
}).AddJwtBearer(options =>
{
options.TokenValidationParameters =
new TokenValidationParameters
{
LifetimeValidator = (before, expires, token, parameters) => expires > DateTime.UtcNow,
ValidateAudience = false,
ValidateIssuer = false,
ValidateActor = false,
ValidateLifetime = true,
IssuerSigningKey = SecurityKey
};
options.Events = new JwtBearerEvents
{
OnMessageReceived = context =>
{
var accessToken = context.Request.Query["access_token"];
if (!string.IsNullOrEmpty(accessToken) &&
(context.HttpContext.WebSockets.IsWebSocketRequest || context.Request.Headers["Accept"] == "text/event-stream"))
{
context.Token = context.Request.Query["access_token"];
}
return Task.CompletedTask;
}
};
});
services.AddAuthorization();
services.AddSingleton<IAuthorizationHandler, TestAuthHandler>();
}
public void Configure(IApplicationBuilder app)
@ -32,15 +71,37 @@ namespace Microsoft.AspNetCore.SignalR.Tests
app.UseAuthentication();
app.UseAuthorization();
// Legacy routing, runs different code path for mapping hubs
#pragma warning disable CS0618 // Type or member is obsolete
app.UseSignalR(routes =>
{
routes.MapHub<AuthHub>("/authHub");
});
#pragma warning restore CS0618 // Type or member is obsolete
app.UseEndpoints(endpoints =>
{
endpoints.MapHub<UncreatableHub>("/uncreatable");
endpoints.MapHub<AuthHub>("/authHubEndpoints");
endpoints.MapConnectionHandler<EchoConnectionHandler>("/echo");
endpoints.MapConnectionHandler<WriteThenCloseConnectionHandler>("/echoAndClose");
endpoints.MapConnectionHandler<HttpHeaderConnectionHandler>("/httpheader");
endpoints.MapConnectionHandler<AuthConnectionHandler>("/auth");
endpoints.MapGet("/generatetoken", context =>
{
return context.Response.WriteAsync(GenerateToken(context));
});
});
}
private string GenerateToken(HttpContext httpContext)
{
var claims = new[] { new Claim(ClaimTypes.NameIdentifier, httpContext.Request.Query["user"]) };
var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256);
var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.UtcNow.AddMinutes(1), signingCredentials: credentials);
return JwtTokenHandler.WriteToken(token);
}
}
}

View File

@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class TestAuthHandler : IAuthorizationHandler
{
public Task HandleAsync(AuthorizationHandlerContext context)
{
foreach (var req in context.Requirements)
{
context.Succeed(req);
}
var hasClaim = context.User.HasClaim(o => o.Type == ClaimTypes.NameIdentifier && !string.IsNullOrEmpty(o.Value));
if (!hasClaim)
{
context.Fail();
}
return Task.CompletedTask;
}
}
}

View File

@ -1,9 +1,9 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class UncreatableHub: Hub
public class UncreatableHub : Hub
{
public UncreatableHub(object obj)
{