From dc29e98032d4a73e7532f5dc4dfe668c113b7fcd Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Wed, 28 Jun 2017 14:52:52 -0700 Subject: [PATCH] Support Authorize attribute on EndPoints (#606) --- .../Microsoft.AspNetCore.SignalR.csproj | 2 +- .../SignalRSocketBuilderExtensions.cs | 11 +- .../SocketRouteBuilder.cs | 17 ++- .../SocketBuilderExtensions.cs | 10 +- .../MapEndPointTests.cs | 116 ++++++++++++++++++ 5 files changed, 139 insertions(+), 17 deletions(-) create mode 100644 test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs diff --git a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj index 64ac863d88..eb316c5cea 100644 --- a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj +++ b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj @@ -10,7 +10,7 @@ aspnetcore;signalr false - + diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs index 568ab565fd..11574c5306 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs @@ -1,6 +1,6 @@ -using System; -using System.Collections.Generic; -using System.Text; +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; @@ -11,10 +11,7 @@ namespace Microsoft.AspNetCore.SignalR public static ISocketBuilder UseHub(this ISocketBuilder socketBuilder) where THub : Hub { var endpoint = socketBuilder.ApplicationServices.GetRequiredService>(); - return socketBuilder.Run(connection => - { - return endpoint.OnConnectedAsync(connection); - }); + return socketBuilder.Run(connection => endpoint.OnConnectedAsync(connection)); } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs b/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs index f96b7a983b..d1176d9adb 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Reflection; +using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Routing; namespace Microsoft.AspNetCore.Sockets @@ -30,7 +32,20 @@ namespace Microsoft.AspNetCore.Sockets public void MapEndPoint(string path) where TEndPoint : EndPoint { - MapSocket(path, builder => + MapEndPoint(path, socketOptions: null); + } + + public void MapEndPoint(string path, Action socketOptions) where TEndPoint : EndPoint + { + var authorizeAttributes = typeof(TEndPoint).GetCustomAttributes(inherit: true); + var options = new HttpSocketOptions(); + foreach (var attribute in authorizeAttributes) + { + options.AuthorizationData.Add(attribute); + } + socketOptions?.Invoke(options); + + MapSocket(path, options, builder => { builder.UseEndPoint(); }); diff --git a/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs index 29be9b8a8b..f0d6d3387b 100644 --- a/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs @@ -1,9 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; -using System.Text; using Microsoft.Extensions.DependencyInjection; namespace Microsoft.AspNetCore.Sockets @@ -12,12 +9,9 @@ namespace Microsoft.AspNetCore.Sockets { public static ISocketBuilder UseEndPoint(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint { + var endpoint = socketBuilder.ApplicationServices.GetRequiredService(); // This is a terminal middleware, so there's no need to use the 'next' parameter - return socketBuilder.Run(connection => - { - var endpoint = socketBuilder.ApplicationServices.GetRequiredService(); - return endpoint.OnConnectedAsync(connection); - }); + return socketBuilder.Run(connection => endpoint.OnConnectedAsync(connection)); } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs new file mode 100644 index 0000000000..1ab74f0d23 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs @@ -0,0 +1,116 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.Extensions.DependencyInjection; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Tests +{ + public class MapEndPointTests + { + [Fact] + public void MapEndPointFindsAuthAttributeOnEndPoint() + { + var authCount = 0; + var builder = new WebHostBuilder() + .UseKestrel() + .ConfigureServices(services => + { + services.AddSockets(); + services.AddEndPoint(); + }) + .Configure(app => + { + app.UseSockets(routes => + { + routes.MapEndPoint("auth", httpSocketOptions => + { + authCount += httpSocketOptions.AuthorizationData.Count; + }); + }); + }) + .Build(); + + Assert.Equal(1, authCount); + } + + [Fact] + public void MapEndPointFindsAuthAttributeOnInheritedEndPoint() + { + var authCount = 0; + var builder = new WebHostBuilder() + .UseKestrel() + .ConfigureServices(services => + { + services.AddSockets(); + services.AddEndPoint(); + }) + .Configure(app => + { + app.UseSockets(routes => + { + routes.MapEndPoint("auth", httpSocketOptions => + { + authCount += httpSocketOptions.AuthorizationData.Count; + }); + }); + }) + .Build(); + + Assert.Equal(1, authCount); + } + + [Fact] + public void MapEndPointFindsAuthAttributesOnDoubleAuthEndPoint() + { + var authCount = 0; + var builder = new WebHostBuilder() + .UseKestrel() + .ConfigureServices(services => + { + services.AddSockets(); + services.AddEndPoint(); + }) + .Configure(app => + { + app.UseSockets(routes => + { + routes.MapEndPoint("auth", httpSocketOptions => + { + authCount += httpSocketOptions.AuthorizationData.Count; + }); + }); + }) + .Build(); + + Assert.Equal(2, authCount); + } + + private class InheritedAuthEndPoint : AuthEndPoint + { + public override Task OnConnectedAsync(ConnectionContext connection) + { + throw new NotImplementedException(); + } + } + + [Authorize] + private class DoubleAuthEndPoint : AuthEndPoint + { + } + + [Authorize] + private class AuthEndPoint : EndPoint + { + public override Task OnConnectedAsync(ConnectionContext connection) + { + throw new NotImplementedException(); + } + } + } +}