diff --git a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj
index 64ac863d88..eb316c5cea 100644
--- a/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj
+++ b/src/Microsoft.AspNetCore.SignalR/Microsoft.AspNetCore.SignalR.csproj
@@ -10,7 +10,7 @@
aspnetcore;signalr
false
-
+
diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs
index 568ab565fd..11574c5306 100644
--- a/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs
+++ b/src/Microsoft.AspNetCore.SignalR/SignalRSocketBuilderExtensions.cs
@@ -1,6 +1,6 @@
-using System;
-using System.Collections.Generic;
-using System.Text;
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
@@ -11,10 +11,7 @@ namespace Microsoft.AspNetCore.SignalR
public static ISocketBuilder UseHub(this ISocketBuilder socketBuilder) where THub : Hub
{
var endpoint = socketBuilder.ApplicationServices.GetRequiredService>();
- return socketBuilder.Run(connection =>
- {
- return endpoint.OnConnectedAsync(connection);
- });
+ return socketBuilder.Run(connection => endpoint.OnConnectedAsync(connection));
}
}
}
diff --git a/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs b/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs
index f96b7a983b..d1176d9adb 100644
--- a/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs
+++ b/src/Microsoft.AspNetCore.Sockets.Http/SocketRouteBuilder.cs
@@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
+using System.Reflection;
+using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Routing;
namespace Microsoft.AspNetCore.Sockets
@@ -30,7 +32,20 @@ namespace Microsoft.AspNetCore.Sockets
public void MapEndPoint(string path) where TEndPoint : EndPoint
{
- MapSocket(path, builder =>
+ MapEndPoint(path, socketOptions: null);
+ }
+
+ public void MapEndPoint(string path, Action socketOptions) where TEndPoint : EndPoint
+ {
+ var authorizeAttributes = typeof(TEndPoint).GetCustomAttributes(inherit: true);
+ var options = new HttpSocketOptions();
+ foreach (var attribute in authorizeAttributes)
+ {
+ options.AuthorizationData.Add(attribute);
+ }
+ socketOptions?.Invoke(options);
+
+ MapSocket(path, options, builder =>
{
builder.UseEndPoint();
});
diff --git a/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs b/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs
index 29be9b8a8b..f0d6d3387b 100644
--- a/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs
+++ b/src/Microsoft.AspNetCore.Sockets/SocketBuilderExtensions.cs
@@ -1,9 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-using System;
-using System.Collections.Generic;
-using System.Text;
using Microsoft.Extensions.DependencyInjection;
namespace Microsoft.AspNetCore.Sockets
@@ -12,12 +9,9 @@ namespace Microsoft.AspNetCore.Sockets
{
public static ISocketBuilder UseEndPoint(this ISocketBuilder socketBuilder) where TEndPoint : EndPoint
{
+ var endpoint = socketBuilder.ApplicationServices.GetRequiredService();
// This is a terminal middleware, so there's no need to use the 'next' parameter
- return socketBuilder.Run(connection =>
- {
- var endpoint = socketBuilder.ApplicationServices.GetRequiredService();
- return endpoint.OnConnectedAsync(connection);
- });
+ return socketBuilder.Run(connection => endpoint.OnConnectedAsync(connection));
}
}
}
diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs
new file mode 100644
index 0000000000..1ab74f0d23
--- /dev/null
+++ b/test/Microsoft.AspNetCore.Sockets.Tests/MapEndPointTests.cs
@@ -0,0 +1,116 @@
+// Copyright (c) .NET Foundation. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Threading.Tasks;
+using Microsoft.AspNetCore.Authorization;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.AspNetCore.Hosting;
+using Microsoft.Extensions.DependencyInjection;
+using Xunit;
+
+namespace Microsoft.AspNetCore.Sockets.Tests
+{
+ public class MapEndPointTests
+ {
+ [Fact]
+ public void MapEndPointFindsAuthAttributeOnEndPoint()
+ {
+ var authCount = 0;
+ var builder = new WebHostBuilder()
+ .UseKestrel()
+ .ConfigureServices(services =>
+ {
+ services.AddSockets();
+ services.AddEndPoint();
+ })
+ .Configure(app =>
+ {
+ app.UseSockets(routes =>
+ {
+ routes.MapEndPoint("auth", httpSocketOptions =>
+ {
+ authCount += httpSocketOptions.AuthorizationData.Count;
+ });
+ });
+ })
+ .Build();
+
+ Assert.Equal(1, authCount);
+ }
+
+ [Fact]
+ public void MapEndPointFindsAuthAttributeOnInheritedEndPoint()
+ {
+ var authCount = 0;
+ var builder = new WebHostBuilder()
+ .UseKestrel()
+ .ConfigureServices(services =>
+ {
+ services.AddSockets();
+ services.AddEndPoint();
+ })
+ .Configure(app =>
+ {
+ app.UseSockets(routes =>
+ {
+ routes.MapEndPoint("auth", httpSocketOptions =>
+ {
+ authCount += httpSocketOptions.AuthorizationData.Count;
+ });
+ });
+ })
+ .Build();
+
+ Assert.Equal(1, authCount);
+ }
+
+ [Fact]
+ public void MapEndPointFindsAuthAttributesOnDoubleAuthEndPoint()
+ {
+ var authCount = 0;
+ var builder = new WebHostBuilder()
+ .UseKestrel()
+ .ConfigureServices(services =>
+ {
+ services.AddSockets();
+ services.AddEndPoint();
+ })
+ .Configure(app =>
+ {
+ app.UseSockets(routes =>
+ {
+ routes.MapEndPoint("auth", httpSocketOptions =>
+ {
+ authCount += httpSocketOptions.AuthorizationData.Count;
+ });
+ });
+ })
+ .Build();
+
+ Assert.Equal(2, authCount);
+ }
+
+ private class InheritedAuthEndPoint : AuthEndPoint
+ {
+ public override Task OnConnectedAsync(ConnectionContext connection)
+ {
+ throw new NotImplementedException();
+ }
+ }
+
+ [Authorize]
+ private class DoubleAuthEndPoint : AuthEndPoint
+ {
+ }
+
+ [Authorize]
+ private class AuthEndPoint : EndPoint
+ {
+ public override Task OnConnectedAsync(ConnectionContext connection)
+ {
+ throw new NotImplementedException();
+ }
+ }
+ }
+}