AuthorizeHelper will no-op if endpoint routing is used (#10471)

This commit is contained in:
Brennan 2019-06-02 22:35:45 -07:00 committed by GitHub
parent 25672336f9
commit f0df10f211
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 761 additions and 598 deletions

View File

@ -9,10 +9,12 @@ namespace Microsoft.AspNetCore.Analyzers.TestFiles.CompilationFeatureDetectorTes
{ {
public void Configure(IApplicationBuilder app) public void Configure(IApplicationBuilder app)
{ {
#pragma warning disable CS0618 // Type or member is obsolete
app.UseSignalR(routes => app.UseSignalR(routes =>
{ {
}); });
#pragma warning restore CS0618 // Type or member is obsolete
} }
} }
} }

View File

@ -1,18 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.IdentityModel.Tokens;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{ {
@ -33,11 +25,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public void Configure(IApplicationBuilder app) public void Configure(IApplicationBuilder app)
{ {
app.UseRouting();
app.UseAuthentication(); app.UseAuthentication();
app.UseSignalR(routes => app.UseEndpoints(endpoints =>
{ {
routes.MapHub<VersionHub>("/version"); endpoints.MapHub<VersionHub>("/version");
}); });
} }
} }

View File

@ -12,6 +12,7 @@ namespace Microsoft.AspNetCore.Builder
} }
public static partial class ConnectionsAppBuilderExtensions public static partial class ConnectionsAppBuilderExtensions
{ {
[System.ObsoleteAttribute("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapConnections or MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseConnections(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action<Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder> configure) { throw null; } public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseConnections(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action<Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder> configure) { throw null; }
} }
} }
@ -28,6 +29,7 @@ namespace Microsoft.AspNetCore.Http.Connections
public ConnectionOptionsSetup() { } public ConnectionOptionsSetup() { }
public void Configure(Microsoft.AspNetCore.Http.Connections.ConnectionOptions options) { } public void Configure(Microsoft.AspNetCore.Http.Connections.ConnectionOptions options) { }
} }
[System.ObsoleteAttribute("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapConnection and MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public partial class ConnectionsRouteBuilder public partial class ConnectionsRouteBuilder
{ {
internal ConnectionsRouteBuilder() { } internal ConnectionsRouteBuilder() { }

View File

@ -3,8 +3,6 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Linq;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.Http.Connections.Internal;
@ -48,13 +46,6 @@ namespace Microsoft.AspNetCore.Builder
public static IEndpointConventionBuilder MapConnectionHandler<TConnectionHandler>(this IEndpointRouteBuilder endpoints, string pattern, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler public static IEndpointConventionBuilder MapConnectionHandler<TConnectionHandler>(this IEndpointRouteBuilder endpoints, string pattern, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler
{ {
var options = new HttpConnectionDispatcherOptions(); var options = new HttpConnectionDispatcherOptions();
// REVIEW: WE should consider removing this and instead just relying on the
// AuthorizationMiddleware
var attributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true);
foreach (var attribute in attributes.OfType<AuthorizeAttribute>())
{
options.AuthorizationData.Add(attribute);
}
configureOptions?.Invoke(options); configureOptions?.Invoke(options);
var conventionBuilder = endpoints.MapConnections(pattern, options, b => var conventionBuilder = endpoints.MapConnections(pattern, options, b =>
@ -62,6 +53,7 @@ namespace Microsoft.AspNetCore.Builder
b.UseConnectionHandler<TConnectionHandler>(); b.UseConnectionHandler<TConnectionHandler>();
}); });
var attributes = typeof(TConnectionHandler).GetCustomAttributes(inherit: true);
conventionBuilder.Add(e => conventionBuilder.Add(e =>
{ {
// Add all attributes on the ConnectionHandler has metadata (this will allow for things like) // Add all attributes on the ConnectionHandler has metadata (this will allow for things like)
@ -93,7 +85,7 @@ namespace Microsoft.AspNetCore.Builder
var connectionDelegate = connectionBuilder.Build(); var connectionDelegate = connectionBuilder.Build();
// REVIEW: Consider expanding the internals of the dispatcher as endpoint routes instead of // REVIEW: Consider expanding the internals of the dispatcher as endpoint routes instead of
// using if statemants we can let the matcher handle // using if statements we can let the matcher handle
var conventionBuilders = new List<IEndpointConventionBuilder>(); var conventionBuilders = new List<IEndpointConventionBuilder>();

View File

@ -3,9 +3,6 @@
using System; using System;
using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.Routing;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Builder namespace Microsoft.AspNetCore.Builder
{ {
@ -16,10 +13,15 @@ namespace Microsoft.AspNetCore.Builder
{ {
/// <summary> /// <summary>
/// Adds support for ASP.NET Core Connection Handlers to the <see cref="IApplicationBuilder"/> request execution pipeline. /// Adds support for ASP.NET Core Connection Handlers to the <see cref="IApplicationBuilder"/> request execution pipeline.
/// <para>
/// This method is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapConnections or MapConnectionHandler&#60;TConnectionHandler&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary> /// </summary>
/// <param name="app">The <see cref="IApplicationBuilder"/>.</param> /// <param name="app">The <see cref="IApplicationBuilder"/>.</param>
/// <param name="configure">A callback to configure connection routes.</param> /// <param name="configure">A callback to configure connection routes.</param>
/// <returns>The same instance of the <see cref="IApplicationBuilder"/> for chaining.</returns> /// <returns>The same instance of the <see cref="IApplicationBuilder"/> for chaining.</returns>
[Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapConnections or MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static IApplicationBuilder UseConnections(this IApplicationBuilder app, Action<ConnectionsRouteBuilder> configure) public static IApplicationBuilder UseConnections(this IApplicationBuilder app, Action<ConnectionsRouteBuilder> configure)
{ {
if (configure == null) if (configure == null)
@ -27,14 +29,13 @@ namespace Microsoft.AspNetCore.Builder
throw new ArgumentNullException(nameof(configure)); throw new ArgumentNullException(nameof(configure));
} }
var dispatcher = app.ApplicationServices.GetRequiredService<HttpConnectionDispatcher>();
var routes = new RouteBuilder(app);
configure(new ConnectionsRouteBuilder(routes, dispatcher));
app.UseWebSockets(); app.UseWebSockets();
app.UseRouter(routes.Build()); app.UseRouting();
app.UseAuthorization();
app.UseEndpoints(endpoints =>
{
configure(new ConnectionsRouteBuilder(endpoints));
});
return app; return app;
} }
} }

View File

@ -2,26 +2,27 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Reflection; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Http.Connections namespace Microsoft.AspNetCore.Http.Connections
{ {
/// <summary> /// <summary>
/// Maps routes to ASP.NET Core Connection Handlers. /// Maps routes to ASP.NET Core Connection Handlers.
/// <para>
/// This class is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapConnection and MapConnectionHandler&#60;TConnectionHandler&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary> /// </summary>
[Obsolete("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapConnection and MapConnectionHandler<TConnectionHandler> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public class ConnectionsRouteBuilder public class ConnectionsRouteBuilder
{ {
private readonly HttpConnectionDispatcher _dispatcher; private readonly IEndpointRouteBuilder _endpoints;
private readonly RouteBuilder _routes;
internal ConnectionsRouteBuilder(RouteBuilder routes, HttpConnectionDispatcher dispatcher) internal ConnectionsRouteBuilder(IEndpointRouteBuilder endpoints)
{ {
_routes = routes; _endpoints = endpoints;
_dispatcher = dispatcher;
} }
/// <summary> /// <summary>
@ -38,24 +39,16 @@ namespace Microsoft.AspNetCore.Http.Connections
/// <param name="path">The request path.</param> /// <param name="path">The request path.</param>
/// <param name="options">Options used to configure the connection.</param> /// <param name="options">Options used to configure the connection.</param>
/// <param name="configure">A callback to configure the connection.</param> /// <param name="configure">A callback to configure the connection.</param>
public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action<IConnectionBuilder> configure) public void MapConnections(PathString path, HttpConnectionDispatcherOptions options, Action<IConnectionBuilder> configure) =>
{ _endpoints.MapConnections(path, options, configure);
var connectionBuilder = new ConnectionBuilder(_routes.ServiceProvider);
configure(connectionBuilder);
var socket = connectionBuilder.Build();
_routes.MapRoute(path, c => _dispatcher.ExecuteAsync(c, options, socket));
_routes.MapRoute(path + "/negotiate", c => _dispatcher.ExecuteNegotiateAsync(c, options));
}
/// <summary> /// <summary>
/// Maps incoming requests with the specified path to the provided connection pipeline. /// Maps incoming requests with the specified path to the provided connection pipeline.
/// </summary> /// </summary>
/// <typeparam name="TConnectionHandler">The <see cref="ConnectionHandler"/> type.</typeparam> /// <typeparam name="TConnectionHandler">The <see cref="ConnectionHandler"/> type.</typeparam>
/// <param name="path">The request path.</param> /// <param name="path">The request path.</param>
public void MapConnectionHandler<TConnectionHandler>(PathString path) where TConnectionHandler : ConnectionHandler public void MapConnectionHandler<TConnectionHandler>(PathString path) where TConnectionHandler : ConnectionHandler =>
{
MapConnectionHandler<TConnectionHandler>(path, configureOptions: null); MapConnectionHandler<TConnectionHandler>(path, configureOptions: null);
}
/// <summary> /// <summary>
/// Maps incoming requests with the specified path to the provided connection pipeline. /// Maps incoming requests with the specified path to the provided connection pipeline.
@ -63,20 +56,7 @@ namespace Microsoft.AspNetCore.Http.Connections
/// <typeparam name="TConnectionHandler">The <see cref="ConnectionHandler"/> type.</typeparam> /// <typeparam name="TConnectionHandler">The <see cref="ConnectionHandler"/> type.</typeparam>
/// <param name="path">The request path.</param> /// <param name="path">The request path.</param>
/// <param name="configureOptions">A callback to configure dispatcher options.</param> /// <param name="configureOptions">A callback to configure dispatcher options.</param>
public void MapConnectionHandler<TConnectionHandler>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler public void MapConnectionHandler<TConnectionHandler>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where TConnectionHandler : ConnectionHandler =>
{ _endpoints.MapConnectionHandler<TConnectionHandler>(path, configureOptions);
var authorizeAttributes = typeof(TConnectionHandler).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
var options = new HttpConnectionDispatcherOptions();
foreach (var attribute in authorizeAttributes)
{
options.AuthorizationData.Add(attribute);
}
configureOptions?.Invoke(options);
MapConnections(path, options, builder =>
{
builder.UseConnectionHandler<TConnectionHandler>();
});
}
} }
} }

View File

@ -1,69 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Authorization.Policy;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Http.Connections.Internal
{
internal static class AuthorizeHelper
{
public static async Task<bool> AuthorizeAsync(HttpContext context, IList<IAuthorizeData> policies)
{
if (policies.Count == 0)
{
return true;
}
var policyProvider = context.RequestServices.GetRequiredService<IAuthorizationPolicyProvider>();
var authorizePolicy = await AuthorizationPolicy.CombineAsync(policyProvider, policies);
var policyEvaluator = context.RequestServices.GetRequiredService<IPolicyEvaluator>();
// This will set context.User if required
var authenticateResult = await policyEvaluator.AuthenticateAsync(authorizePolicy, context);
var authorizeResult = await policyEvaluator.AuthorizeAsync(authorizePolicy, authenticateResult, context, resource: null);
if (authorizeResult.Succeeded)
{
return true;
}
else if (authorizeResult.Challenged)
{
if (authorizePolicy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.ChallengeAsync(scheme);
}
}
else
{
await context.ChallengeAsync();
}
return false;
}
else if (authorizeResult.Forbidden)
{
if (authorizePolicy.AuthenticationSchemes.Count > 0)
{
foreach (var scheme in authorizePolicy.AuthenticationSchemes)
{
await context.ForbidAsync(scheme);
}
}
else
{
await context.ForbidAsync();
}
return false;
}
return false;
}
}
}

View File

@ -61,11 +61,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
var logScope = new ConnectionLogScope(GetConnectionId(context)); var logScope = new ConnectionLogScope(GetConnectionId(context));
using (_logger.BeginScope(logScope)) using (_logger.BeginScope(logScope))
{ {
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationData))
{
return;
}
if (HttpMethods.IsPost(context.Request.Method)) if (HttpMethods.IsPost(context.Request.Method))
{ {
// POST /{path} // POST /{path}
@ -95,11 +90,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
var logScope = new ConnectionLogScope(connectionId: string.Empty); var logScope = new ConnectionLogScope(connectionId: string.Empty);
using (_logger.BeginScope(logScope)) using (_logger.BeginScope(logScope))
{ {
if (!await AuthorizeHelper.AuthorizeAsync(context, options.AuthorizationData))
{
return;
}
if (HttpMethods.IsPost(context.Request.Method)) if (HttpMethods.IsPost(context.Request.Method))
{ {
// POST /{path}/negotiate // POST /{path}/negotiate

View File

@ -14,6 +14,7 @@
<Compile Include="$(SignalRSharedSourceRoot)WebSocketExtensions.cs" Link="WebSocketExtensions.cs" /> <Compile Include="$(SignalRSharedSourceRoot)WebSocketExtensions.cs" Link="WebSocketExtensions.cs" />
<Compile Include="$(SignalRSharedSourceRoot)StreamExtensions.cs" Link="StreamExtensions.cs" /> <Compile Include="$(SignalRSharedSourceRoot)StreamExtensions.cs" Link="StreamExtensions.cs" />
<Compile Include="$(SignalRSharedSourceRoot)DuplexPipe.cs" Link="DuplexPipe.cs" /> <Compile Include="$(SignalRSharedSourceRoot)DuplexPipe.cs" Link="DuplexPipe.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TaskCache.cs" Link="Internal\TaskCache.cs" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@ -1347,358 +1347,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
} }
} }
[Fact]
public async Task UnauthorizedConnectionFailsToStartEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// would get stuck if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
}
[Fact]
public async Task AuthenticatedUserWithoutPermissionCausesForbidden()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy => policy.RequireClaim(ClaimTypes.NameIdentifier));
});
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
services.AddLogging();
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
context.User = new ClaimsPrincipal(new ClaimsIdentity("authenticated"));
// would get stuck if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status403Forbidden, context.Response.StatusCode);
}
}
[Fact]
public async Task AuthorizedConnectionCanConnectToEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
}
[Fact]
public async Task AllPoliciesRequiredForAuthorizedEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
});
o.AddPolicy("secondPolicy", policy =>
{
policy.RequireClaim(ClaimTypes.StreetAddress);
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
options.AuthorizationData.Add(new AuthorizeAttribute("secondPolicy"));
// partially "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would get stuck if EndPoint was running
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
// reset HttpContext
context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
// fully "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[]
{
new Claim(ClaimTypes.NameIdentifier, "name"),
new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW")
}));
// First poll
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(connectionHandlerTask.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
}
[Fact]
public async Task AuthorizedConnectionWithAcceptedSchemesCanConnectToEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(TestAuthenticationHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// Initial poll
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(connectionHandlerTask.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
}
[Fact]
public async Task AuthorizedConnectionWithRejectedSchemesFailsToConnectToEndPoint()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = new DefaultHttpContext();
var services = new ServiceCollection();
services.AddOptions();
services.AddSingleton<TestConnectionHandler>();
services.AddAuthorization(o =>
{
o.AddPolicy("test", policy =>
{
policy.RequireClaim(ClaimTypes.NameIdentifier);
policy.AddAuthenticationSchemes("Default");
});
});
services.AddLogging();
services.AddAuthenticationCore(o =>
{
o.DefaultScheme = "Default";
o.AddScheme("Default", a => a.HandlerType = typeof(RejectHandler));
});
var sp = services.BuildServiceProvider();
context.Request.Path = "/foo";
context.Request.Method = "GET";
context.RequestServices = sp;
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
var builder = new ConnectionBuilder(sp);
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.AuthorizationData.Add(new AuthorizeAttribute("test"));
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// would block if EndPoint was executed
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.Equal(StatusCodes.Status401Unauthorized, context.Response.StatusCode);
}
}
[Fact] [Fact]
public async Task SetsInherentKeepAliveFeatureOnFirstLongPollingRequest() public async Task SetsInherentKeepAliveFeatureOnFirstLongPollingRequest()
{ {
@ -2119,50 +1767,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
} }
} }
private class RejectHandler : TestAuthenticationHandler
{
protected override bool ShouldAccept => false;
}
private class TestAuthenticationHandler : IAuthenticationHandler
{
private HttpContext HttpContext;
private AuthenticationScheme _scheme;
protected virtual bool ShouldAccept => true;
public Task<AuthenticateResult> AuthenticateAsync()
{
if (ShouldAccept)
{
return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(HttpContext.User, _scheme.Name)));
}
else
{
return Task.FromResult(AuthenticateResult.NoResult());
}
}
public Task ChallengeAsync(AuthenticationProperties properties)
{
HttpContext.Response.StatusCode = StatusCodes.Status401Unauthorized;
return Task.CompletedTask;
}
public Task ForbidAsync(AuthenticationProperties properties)
{
HttpContext.Response.StatusCode = StatusCodes.Status403Forbidden;
return Task.CompletedTask;
}
public Task InitializeAsync(AuthenticationScheme scheme, HttpContext context)
{
HttpContext = context;
_scheme = scheme;
return Task.CompletedTask;
}
}
private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory) private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory)
{ {
var manager = CreateConnectionManager(loggerFactory); var manager = CreateConnectionManager(loggerFactory);

View File

@ -4,11 +4,14 @@
using System; using System;
using System.Linq; using System.Linq;
using System.Net.WebSockets; using System.Net.WebSockets;
using System.Reflection;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Cors;
using Microsoft.AspNetCore.Cors.Infrastructure;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Hosting.Server.Features; using Microsoft.AspNetCore.Hosting.Server.Features;
using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing;
@ -34,39 +37,113 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
public void MapConnectionHandlerFindsAuthAttributeOnEndPoint() public void MapConnectionHandlerFindsAuthAttributeOnEndPoint()
{ {
var authCount = 0; var authCount = 0;
using (var builder = BuildWebHost<AuthConnectionHandler>("/auth", using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
options => authCount += options.AuthorizationData.Count)) options => authCount += options.AuthorizationData.Count))
{ {
builder.Start(); host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
});
} }
Assert.Equal(1, authCount); Assert.Equal(0, authCount);
} }
[Fact] [Fact]
public void MapConnectionHandlerFindsAuthAttributeOnInheritedEndPoint() public void MapConnectionHandlerFindsAuthAttributeOnInheritedEndPoint()
{ {
var authCount = 0; var authCount = 0;
using (var builder = BuildWebHost<InheritedAuthConnectionHandler>("/auth", using (var host = BuildWebHost<InheritedAuthConnectionHandler>("/auth",
options => authCount += options.AuthorizationData.Count)) options => authCount += options.AuthorizationData.Count))
{ {
builder.Start(); host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
});
} }
Assert.Equal(1, authCount); Assert.Equal(0, authCount);
} }
[Fact] [Fact]
public void MapConnectionHandlerFindsAuthAttributesOnDoubleAuthEndPoint() public void MapConnectionHandlerFindsAuthAttributesOnDoubleAuthEndPoint()
{ {
var authCount = 0; var authCount = 0;
using (var builder = BuildWebHost<DoubleAuthConnectionHandler>("/auth", using (var host = BuildWebHost<DoubleAuthConnectionHandler>("/auth",
options => authCount += options.AuthorizationData.Count)) options => authCount += options.AuthorizationData.Count))
{ {
builder.Start(); host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
} }
Assert.Equal(2, authCount); Assert.Equal(0, authCount);
}
[Fact]
public void MapConnectionHandlerFindsAttributesFromEndPointAndOptions()
{
var authCount = 0;
using (var host = BuildWebHost<AuthConnectionHandler>("/auth",
options =>
{
authCount += options.AuthorizationData.Count;
options.AuthorizationData.Add(new AuthorizeAttribute());
}))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/auth/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/auth", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(0, authCount);
} }
[Fact] [Fact]
@ -82,12 +159,50 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>(); var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /) // We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count); Assert.Collection(dataSource.Endpoints,
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>()); endpoint =>
Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>()); {
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Single(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>());
});
} }
Assert.Equal(1, authCount); Assert.Equal(0, authCount);
}
[Fact]
public void MapConnectionHandlerEndPointRoutingFindsAttributesFromOptions()
{
var authCount = 0;
using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapConnectionHandler<AuthConnectionHandler>("/path", options =>
{
authCount += options.AuthorizationData.Count;
options.AuthorizationData.Add(new AuthorizeAttribute());
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(0, authCount);
} }
[Fact] [Fact]
@ -106,9 +221,27 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>(); var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /) // We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count); Assert.Collection(dataSource.Endpoints,
Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>()?.Policy); endpoint =>
Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>()?.Policy); {
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
});
} }
} }
@ -126,9 +259,45 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>(); var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /) // We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count); Assert.Collection(dataSource.Endpoints,
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<NegotiateMetadata>()); endpoint =>
Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata<NegotiateMetadata>()); {
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.NotNull(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
});
}
}
[Fact]
public void MapConnectionHandlerEndPointRoutingAppliesCorsMetadata()
{
void ConfigureRoutes(IEndpointRouteBuilder endpoints)
{
endpoints.MapConnectionHandler<CorsConnectionHandler>("/path");
}
using (var host = BuildWebHostWithEndPointRouting(ConfigureRoutes))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.NotNull(endpoint.Metadata.GetMetadata<IEnableCorsAttribute>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.NotNull(endpoint.Metadata.GetMetadata<IEnableCorsAttribute>());
});
} }
} }
@ -177,6 +346,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
} }
} }
[EnableCors]
private class CorsConnectionHandler : ConnectionHandler
{
public override Task OnConnectedAsync(ConnectionContext connection)
{
throw new NotImplementedException();
}
}
private class InheritedAuthConnectionHandler : AuthConnectionHandler private class InheritedAuthConnectionHandler : AuthConnectionHandler
{ {
public override Task OnConnectedAsync(ConnectionContext connection) public override Task OnConnectedAsync(ConnectionContext connection)
@ -227,10 +405,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}) })
.Configure(app => .Configure(app =>
{ {
#pragma warning disable CS0618 // Type or member is obsolete
app.UseConnections(routes => app.UseConnections(routes =>
{ {
routes.MapConnectionHandler<TConnectionHandler>(path, configureOptions); routes.MapConnectionHandler<TConnectionHandler>(path, configureOptions);
}); });
#pragma warning restore CS0618 // Type or member is obsolete
}) })
.ConfigureLogging(factory => .ConfigureLogging(factory =>
{ {

View File

@ -15,6 +15,7 @@
<ItemGroup> <ItemGroup>
<Reference Include="Microsoft.AspNetCore.Authentication.Core" /> <Reference Include="Microsoft.AspNetCore.Authentication.Core" />
<Reference Include="Microsoft.AspNetCore.Cors" />
<Reference Include="Microsoft.AspNetCore.Http.Connections" /> <Reference Include="Microsoft.AspNetCore.Http.Connections" />
<Reference Include="Microsoft.AspNetCore.Http" /> <Reference Include="Microsoft.AspNetCore.Http" />
<Reference Include="Newtonsoft.Json" /> <Reference Include="Newtonsoft.Json" />

View File

@ -3,11 +3,11 @@
using System.Threading.Tasks; using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal namespace Microsoft.AspNetCore.Internal
{ {
internal static class TaskCache internal static class TaskCache
{ {
public static readonly Task<bool> True = Task.FromResult(true); public static readonly Task<bool> True = Task.FromResult(true);
public static readonly Task<bool> False = Task.FromResult(false); public static readonly Task<bool> False = Task.FromResult(false);
} }
} }

View File

@ -11,6 +11,7 @@ using System.Threading;
using System.Threading.Channels; using System.Threading.Channels;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal; using Microsoft.Extensions.Internal;

View File

@ -12,6 +12,7 @@
<Compile Include="$(SharedSourceRoot)ClosedGenericMatcher\*.cs" /> <Compile Include="$(SharedSourceRoot)ClosedGenericMatcher\*.cs" />
<Compile Include="$(SharedSourceRoot)ObjectMethodExecutor\*.cs" /> <Compile Include="$(SharedSourceRoot)ObjectMethodExecutor\*.cs" />
<Compile Include="$(SignalRSharedSourceRoot)AsyncEnumerableAdapters.cs" Link="Internal\AsyncEnumerableAdapters.cs" /> <Compile Include="$(SignalRSharedSourceRoot)AsyncEnumerableAdapters.cs" Link="Internal\AsyncEnumerableAdapters.cs" />
<Compile Include="$(SignalRSharedSourceRoot)TaskCache.cs" Link="Internal\TaskCache.cs" />
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>

View File

@ -10,6 +10,7 @@ namespace Microsoft.AspNetCore.Builder
} }
public static partial class SignalRAppBuilderExtensions public static partial class SignalRAppBuilderExtensions
{ {
[System.ObsoleteAttribute("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseSignalR(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action<Microsoft.AspNetCore.SignalR.HubRouteBuilder> configure) { throw null; } public static Microsoft.AspNetCore.Builder.IApplicationBuilder UseSignalR(this Microsoft.AspNetCore.Builder.IApplicationBuilder app, System.Action<Microsoft.AspNetCore.SignalR.HubRouteBuilder> configure) { throw null; }
} }
} }
@ -25,6 +26,7 @@ namespace Microsoft.AspNetCore.SignalR
internal HubEndpointConventionBuilder() { } internal HubEndpointConventionBuilder() { }
public void Add(System.Action<Microsoft.AspNetCore.Builder.EndpointBuilder> convention) { } public void Add(System.Action<Microsoft.AspNetCore.Builder.EndpointBuilder> convention) { }
} }
[System.ObsoleteAttribute("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public partial class HubRouteBuilder public partial class HubRouteBuilder
{ {
public HubRouteBuilder(Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder routes) { } public HubRouteBuilder(Microsoft.AspNetCore.Http.Connections.ConnectionsRouteBuilder routes) { }

View File

@ -2,8 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System; using System;
using System.Linq;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.Routing;
using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR;
@ -44,14 +42,6 @@ namespace Microsoft.AspNetCore.Builder
} }
var options = new HttpConnectionDispatcherOptions(); var options = new HttpConnectionDispatcherOptions();
// REVIEW: WE should consider removing this and instead just relying on the
// AuthorizationMiddleware
var attributes = typeof(THub).GetCustomAttributes(inherit: true);
foreach (var attribute in attributes.OfType<AuthorizeAttribute>())
{
options.AuthorizationData.Add(attribute);
}
configureOptions?.Invoke(options); configureOptions?.Invoke(options);
var conventionBuilder = endpoints.MapConnections(pattern, options, b => var conventionBuilder = endpoints.MapConnections(pattern, options, b =>
@ -59,9 +49,10 @@ namespace Microsoft.AspNetCore.Builder
b.UseHub<THub>(); b.UseHub<THub>();
}); });
var attributes = typeof(THub).GetCustomAttributes(inherit: true);
conventionBuilder.Add(e => conventionBuilder.Add(e =>
{ {
// Add all attributes on the Hub has metadata (this will allow for things like) // Add all attributes on the Hub as metadata (this will allow for things like)
// auth attributes and cors attributes to work seamlessly // auth attributes and cors attributes to work seamlessly
foreach (var item in attributes) foreach (var item in attributes)
{ {

View File

@ -4,17 +4,25 @@
using System; using System;
using System.Reflection; using System.Reflection;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.SignalR namespace Microsoft.AspNetCore.SignalR
{ {
/// <summary> /// <summary>
/// Maps incoming requests to <see cref="Hub"/> types. /// Maps incoming requests to <see cref="Hub"/> types.
/// <para>
/// This class is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapHub&#60;THub&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary> /// </summary>
[Obsolete("This class is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public class HubRouteBuilder public class HubRouteBuilder
{ {
private readonly ConnectionsRouteBuilder _routes; private readonly ConnectionsRouteBuilder _routes;
private readonly IEndpointRouteBuilder _endpoints;
/// <summary> /// <summary>
/// Initializes a new instance of the <see cref="HubRouteBuilder"/> class. /// Initializes a new instance of the <see cref="HubRouteBuilder"/> class.
@ -25,6 +33,11 @@ namespace Microsoft.AspNetCore.SignalR
_routes = routes; _routes = routes;
} }
internal HubRouteBuilder(IEndpointRouteBuilder endpoints)
{
_endpoints = endpoints;
}
/// <summary> /// <summary>
/// Maps incoming requests with the specified path to the specified <see cref="Hub"/> type. /// Maps incoming requests with the specified path to the specified <see cref="Hub"/> type.
/// </summary> /// </summary>
@ -43,6 +56,14 @@ namespace Microsoft.AspNetCore.SignalR
/// <param name="configureOptions">A callback to configure dispatcher options.</param> /// <param name="configureOptions">A callback to configure dispatcher options.</param>
public void MapHub<THub>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where THub : Hub public void MapHub<THub>(PathString path, Action<HttpConnectionDispatcherOptions> configureOptions) where THub : Hub
{ {
// This will be null if someone is manually using the HubRouteBuilder(ConnectionsRouteBuilder routes) constructor
// SignalR itself will only use the IEndpointRouteBuilder overload
if (_endpoints != null)
{
_endpoints.MapHub<THub>(path, configureOptions);
return;
}
// find auth attributes // find auth attributes
var authorizeAttributes = typeof(THub).GetCustomAttributes<AuthorizeAttribute>(inherit: true); var authorizeAttributes = typeof(THub).GetCustomAttributes<AuthorizeAttribute>(inherit: true);
var options = new HttpConnectionDispatcherOptions(); var options = new HttpConnectionDispatcherOptions();

View File

@ -14,10 +14,15 @@ namespace Microsoft.AspNetCore.Builder
{ {
/// <summary> /// <summary>
/// Adds SignalR to the <see cref="IApplicationBuilder"/> request execution pipeline. /// Adds SignalR to the <see cref="IApplicationBuilder"/> request execution pipeline.
/// <para>
/// This method is obsolete and will be removed in a future version.
/// The recommended alternative is to use MapHub&#60;THub&#62; inside Microsoft.AspNetCore.Builder.UseEndpoints(...).
/// </para>
/// </summary> /// </summary>
/// <param name="app">The <see cref="IApplicationBuilder"/>.</param> /// <param name="app">The <see cref="IApplicationBuilder"/>.</param>
/// <param name="configure">A callback to configure hub routes.</param> /// <param name="configure">A callback to configure hub routes.</param>
/// <returns>The same instance of the <see cref="IApplicationBuilder"/> for chaining.</returns> /// <returns>The same instance of the <see cref="IApplicationBuilder"/> for chaining.</returns>
[Obsolete("This method is obsolete and will be removed in a future version. The recommended alternative is to use MapHub<THub> inside Microsoft.AspNetCore.Builder.UseEndpoints(...).")]
public static IApplicationBuilder UseSignalR(this IApplicationBuilder app, Action<HubRouteBuilder> configure) public static IApplicationBuilder UseSignalR(this IApplicationBuilder app, Action<HubRouteBuilder> configure)
{ {
var marker = app.ApplicationServices.GetService<SignalRMarkerService>(); var marker = app.ApplicationServices.GetService<SignalRMarkerService>();
@ -27,9 +32,12 @@ namespace Microsoft.AspNetCore.Builder
"'IServiceCollection.AddSignalR' inside the call to 'ConfigureServices(...)' in the application startup code."); "'IServiceCollection.AddSignalR' inside the call to 'ConfigureServices(...)' in the application startup code.");
} }
app.UseConnections(routes => app.UseWebSockets();
app.UseRouting();
app.UseAuthorization();
app.UseEndpoints(endpoints =>
{ {
configure(new HubRouteBuilder(routes)); configure(new HubRouteBuilder(endpoints));
}); });
return app; return app;

View File

@ -0,0 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Authorization;
namespace Microsoft.AspNetCore.SignalR.Tests
{
[Authorize]
class AuthHub : Hub
{
}
}

View File

@ -446,6 +446,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
} }
} }
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task AuthorizedConnectionCanConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
string token;
using (var client = new HttpClient())
{
client.BaseAddress = new Uri(server.Url);
var response = await client.GetAsync("generatetoken?user=bob");
token = await response.Content.ReadAsStringAsync();
}
var url = server.Url + "/auth";
var connection = new HttpConnection(new HttpConnectionOptions()
{
AccessTokenProvider = () => Task.FromResult(token),
Url = new Uri(url),
Transports = HttpTransportType.ServerSentEvents
}, LoggerFactory);
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync(TransferFormat.Text).OrTimeout();
logger.LogInformation("Connected to {url}", url);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[ConditionalFact] [ConditionalFact]
[WebSocketsSupportedCondition] [WebSocketsSupportedCondition]
public async Task ServerClosesConnectionWithErrorIfHubCannotBeCreated_WebSocket() public async Task ServerClosesConnectionWithErrorIfHubCannotBeCreated_WebSocket()
@ -532,6 +578,178 @@ namespace Microsoft.AspNetCore.SignalR.Tests
} }
} }
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UnauthorizedHubConnectionDoesNotConnectWithEndpoints()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
var url = server.Url + "/authHubEndpoints";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling)
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
Assert.True(false);
}
catch (Exception ex)
{
Assert.Equal("Response status code does not indicate success: 401 (Unauthorized).", ex.Message);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task UnauthorizedHubConnectionDoesNotConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
var url = server.Url + "/authHub";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling)
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
Assert.True(false);
}
catch (Exception ex)
{
Assert.Equal("Response status code does not indicate success: 401 (Unauthorized).", ex.Message);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task AuthorizedHubConnectionCanConnectWithEndpoints()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
string token;
using (var client = new HttpClient())
{
client.BaseAddress = new Uri(server.Url);
var response = await client.GetAsync("generatetoken?user=bob");
token = await response.Content.ReadAsStringAsync();
}
var url = server.Url + "/authHubEndpoints";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling, o =>
{
o.AccessTokenProvider = () => Task.FromResult(token);
})
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
logger.LogInformation("Connected to {url}", url);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
[Fact]
[LogLevel(LogLevel.Trace)]
public async Task AuthorizedHubConnectionCanConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartServer<Startup>(out var server, ExpectedErrors))
{
var logger = LoggerFactory.CreateLogger<EndToEndTests>();
string token;
using (var client = new HttpClient())
{
client.BaseAddress = new Uri(server.Url);
var response = await client.GetAsync("generatetoken?user=bob");
token = await response.Content.ReadAsStringAsync();
}
var url = server.Url + "/authHub";
var connection = new HubConnectionBuilder()
.WithLoggerFactory(LoggerFactory)
.WithUrl(url, HttpTransportType.LongPolling, o =>
{
o.AccessTokenProvider = () => Task.FromResult(token);
})
.Build();
try
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();
logger.LogInformation("Connected to {url}", url);
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
}
}
// Serves a fake transport that lets us verify fallback behavior // Serves a fake transport that lets us verify fallback behavior
private class TestTransportFactory : ITransportFactory private class TestTransportFactory : ITransportFactory
{ {

View File

@ -1,7 +1,4 @@
using System; using System;
using System.Collections;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting;
@ -42,10 +39,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var ex = Assert.Throws<InvalidOperationException>(() => var ex = Assert.Throws<InvalidOperationException>(() =>
{ {
#pragma warning disable CS0618 // Type or member is obsolete
app.UseSignalR(routes => app.UseSignalR(routes =>
{ {
routes.MapHub<AuthHub>("/overloads"); routes.MapHub<AuthHub>("/overloads");
}); });
#pragma warning restore CS0618 // Type or member is obsolete
}); });
Assert.Equal("Unable to find the required services. Please add all the required services by calling " + Assert.Equal("Unable to find the required services. Please add all the required services by calling " +
@ -109,9 +108,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}))) })))
{ {
host.Start(); host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
} }
Assert.Equal(1, authCount); Assert.Equal(0, authCount);
} }
[Fact] [Fact]
@ -124,9 +137,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}))) })))
{ {
host.Start(); host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
} }
Assert.Equal(1, authCount); Assert.Equal(0, authCount);
} }
[Fact] [Fact]
@ -139,9 +166,23 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}))) })))
{ {
host.Start(); host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
} }
Assert.Equal(2, authCount); Assert.Equal(0, authCount);
} }
[Fact] [Fact]
@ -157,12 +198,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>(); var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /) // We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count); Assert.Collection(dataSource.Endpoints,
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>()); endpoint =>
Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>()); {
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(1, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
} }
Assert.Equal(1, authCount); Assert.Equal(0, authCount);
}
[Fact]
public void MapHubEndPointRoutingFindsAttributesOnHubAndFromOptions()
{
var authCount = 0;
HttpConnectionDispatcherOptions configuredOptions = null;
using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapHub<AuthHub>("/path", options =>
{
authCount += options.AuthorizationData.Count;
options.AuthorizationData.Add(new AuthorizeAttribute());
configuredOptions = options;
})))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(2, endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>().Count);
});
}
Assert.Equal(0, authCount);
} }
[Fact] [Fact]
@ -181,9 +262,27 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>(); var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /) // We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count); Assert.Collection(dataSource.Endpoints,
Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata<IAuthorizeData>()?.Policy); endpoint =>
Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata<IAuthorizeData>()?.Policy); {
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Collection(endpoint.Metadata.GetOrderedMetadata<IAuthorizeData>(),
auth => { },
auth =>
{
Assert.Equal("Foo", auth?.Policy);
});
});
} }
} }
@ -202,11 +301,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var dataSource = host.Services.GetRequiredService<EndpointDataSource>(); var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /) // We register 2 endpoints (/negotiate and /)
Assert.Equal(2, dataSource.Endpoints.Count); Assert.Collection(dataSource.Endpoints,
Assert.Equal(typeof(AuthHub), dataSource.Endpoints[0].Metadata.GetMetadata<HubMetadata>()?.HubType); endpoint =>
Assert.Equal(typeof(AuthHub), dataSource.Endpoints[1].Metadata.GetMetadata<HubMetadata>()?.HubType); {
Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata<NegotiateMetadata>()); Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata<NegotiateMetadata>()); Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.NotNull(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
});
}
}
[Fact]
public void MapHubAppliesHubMetadata()
{
#pragma warning disable CS0618 // Type or member is obsolete
void ConfigureRoutes(HubRouteBuilder routes)
#pragma warning restore CS0618 // Type or member is obsolete
{
// This "Foo" policy should override the default auth attribute
routes.MapHub<AuthHub>("/path");
}
using (var host = BuildWebHost(ConfigureRoutes))
{
host.Start();
var dataSource = host.Services.GetRequiredService<EndpointDataSource>();
// We register 2 endpoints (/negotiate and /)
Assert.Collection(dataSource.Endpoints,
endpoint =>
{
Assert.Equal("/path/negotiate", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.NotNull(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
},
endpoint =>
{
Assert.Equal("/path", endpoint.DisplayName);
Assert.Equal(typeof(AuthHub), endpoint.Metadata.GetMetadata<HubMetadata>()?.HubType);
Assert.Null(endpoint.Metadata.GetMetadata<NegotiateMetadata>());
});
} }
} }
@ -252,6 +392,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
.Build(); .Build();
} }
#pragma warning disable CS0618 // Type or member is obsolete
private IWebHost BuildWebHost(Action<HubRouteBuilder> configure) private IWebHost BuildWebHost(Action<HubRouteBuilder> configure)
{ {
return new WebHostBuilder() return new WebHostBuilder()
@ -267,5 +408,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
.UseUrls("http://127.0.0.1:0") .UseUrls("http://127.0.0.1:0")
.Build(); .Build();
} }
#pragma warning restore CS0618 // Type or member is obsolete
} }
} }

View File

@ -13,8 +13,8 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<Reference Include="Microsoft.AspNetCore.Authentication.Cookies" />
<Reference Include="Microsoft.AspNetCore.Authentication" /> <Reference Include="Microsoft.AspNetCore.Authentication" />
<Reference Include="Microsoft.AspNetCore.Authentication.JwtBearer" />
<Reference Include="Microsoft.AspNetCore.Http.Abstractions" /> <Reference Include="Microsoft.AspNetCore.Http.Abstractions" />
<Reference Include="Microsoft.AspNetCore.SignalR.Client" /> <Reference Include="Microsoft.AspNetCore.SignalR.Client" />
<Reference Include="Microsoft.AspNetCore.SignalR.Specification.Tests" /> <Reference Include="Microsoft.AspNetCore.SignalR.Specification.Tests" />

View File

@ -1,14 +1,24 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Authentication.Cookies; using System;
using System.IdentityModel.Tokens.Jwt;
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication.JwtBearer;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.IdentityModel.Tokens;
namespace Microsoft.AspNetCore.SignalR.Tests namespace Microsoft.AspNetCore.SignalR.Tests
{ {
public class Startup public class Startup
{ {
private readonly SymmetricSecurityKey SecurityKey = new SymmetricSecurityKey(Guid.NewGuid().ToByteArray());
private readonly JwtSecurityTokenHandler JwtTokenHandler = new JwtSecurityTokenHandler();
public void ConfigureServices(IServiceCollection services) public void ConfigureServices(IServiceCollection services)
{ {
services.AddConnections(); services.AddConnections();
@ -19,11 +29,40 @@ namespace Microsoft.AspNetCore.SignalR.Tests
services.AddAuthentication(options => services.AddAuthentication(options =>
{ {
options.DefaultAuthenticateScheme = CookieAuthenticationDefaults.AuthenticationScheme; options.DefaultAuthenticateScheme = JwtBearerDefaults.AuthenticationScheme;
options.DefaultChallengeScheme = CookieAuthenticationDefaults.AuthenticationScheme; options.DefaultChallengeScheme = JwtBearerDefaults.AuthenticationScheme;
}).AddCookie(); }).AddJwtBearer(options =>
{
options.TokenValidationParameters =
new TokenValidationParameters
{
LifetimeValidator = (before, expires, token, parameters) => expires > DateTime.UtcNow,
ValidateAudience = false,
ValidateIssuer = false,
ValidateActor = false,
ValidateLifetime = true,
IssuerSigningKey = SecurityKey
};
options.Events = new JwtBearerEvents
{
OnMessageReceived = context =>
{
var accessToken = context.Request.Query["access_token"];
if (!string.IsNullOrEmpty(accessToken) &&
(context.HttpContext.WebSockets.IsWebSocketRequest || context.Request.Headers["Accept"] == "text/event-stream"))
{
context.Token = context.Request.Query["access_token"];
}
return Task.CompletedTask;
}
};
});
services.AddAuthorization(); services.AddAuthorization();
services.AddSingleton<IAuthorizationHandler, TestAuthHandler>();
} }
public void Configure(IApplicationBuilder app) public void Configure(IApplicationBuilder app)
@ -32,15 +71,37 @@ namespace Microsoft.AspNetCore.SignalR.Tests
app.UseAuthentication(); app.UseAuthentication();
app.UseAuthorization(); app.UseAuthorization();
// Legacy routing, runs different code path for mapping hubs
#pragma warning disable CS0618 // Type or member is obsolete
app.UseSignalR(routes =>
{
routes.MapHub<AuthHub>("/authHub");
});
#pragma warning restore CS0618 // Type or member is obsolete
app.UseEndpoints(endpoints => app.UseEndpoints(endpoints =>
{ {
endpoints.MapHub<UncreatableHub>("/uncreatable"); endpoints.MapHub<UncreatableHub>("/uncreatable");
endpoints.MapHub<AuthHub>("/authHubEndpoints");
endpoints.MapConnectionHandler<EchoConnectionHandler>("/echo"); endpoints.MapConnectionHandler<EchoConnectionHandler>("/echo");
endpoints.MapConnectionHandler<WriteThenCloseConnectionHandler>("/echoAndClose"); endpoints.MapConnectionHandler<WriteThenCloseConnectionHandler>("/echoAndClose");
endpoints.MapConnectionHandler<HttpHeaderConnectionHandler>("/httpheader"); endpoints.MapConnectionHandler<HttpHeaderConnectionHandler>("/httpheader");
endpoints.MapConnectionHandler<AuthConnectionHandler>("/auth"); endpoints.MapConnectionHandler<AuthConnectionHandler>("/auth");
endpoints.MapGet("/generatetoken", context =>
{
return context.Response.WriteAsync(GenerateToken(context));
});
}); });
} }
private string GenerateToken(HttpContext httpContext)
{
var claims = new[] { new Claim(ClaimTypes.NameIdentifier, httpContext.Request.Query["user"]) };
var credentials = new SigningCredentials(SecurityKey, SecurityAlgorithms.HmacSha256);
var token = new JwtSecurityToken("SignalRTestServer", "SignalRTests", claims, expires: DateTime.UtcNow.AddMinutes(1), signingCredentials: credentials);
return JwtTokenHandler.WriteToken(token);
}
} }
} }

View File

@ -0,0 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Security.Claims;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public class TestAuthHandler : IAuthorizationHandler
{
public Task HandleAsync(AuthorizationHandlerContext context)
{
foreach (var req in context.Requirements)
{
context.Succeed(req);
}
var hasClaim = context.User.HasClaim(o => o.Type == ClaimTypes.NameIdentifier && !string.IsNullOrEmpty(o.Value));
if (!hasClaim)
{
context.Fail();
}
return Task.CompletedTask;
}
}
}

View File

@ -1,9 +1,9 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Tests namespace Microsoft.AspNetCore.SignalR.Tests
{ {
public class UncreatableHub: Hub public class UncreatableHub : Hub
{ {
public UncreatableHub(object obj) public UncreatableHub(object obj)
{ {