diff --git a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs index 9b9db9128c..cbf3c89743 100644 --- a/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs +++ b/src/SignalR/common/Http.Connections/ref/Microsoft.AspNetCore.Http.Connections.netcoreapp3.0.cs @@ -55,6 +55,10 @@ namespace Microsoft.AspNetCore.Http.Connections public LongPollingOptions() { } public System.TimeSpan PollTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } + public partial class NegotiateMetadata + { + public NegotiateMetadata() { } + } public partial class WebSocketOptions { public WebSocketOptions() { } diff --git a/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs b/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs index 1fd4f7ca23..bc111a2ed4 100644 --- a/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs +++ b/src/SignalR/common/Http.Connections/src/ConnectionEndpointRouteBuilderExtensions.cs @@ -105,6 +105,8 @@ namespace Microsoft.AspNetCore.Builder var negotiateBuilder = endpoints.Map(pattern + "/negotiate", negotiateHandler); conventionBuilders.Add(negotiateBuilder); + // Add the negotiate metadata so this endpoint can be identified + negotiateBuilder.WithMetadata(new NegotiateMetadata()); // build the execute handler part of the protocol app = endpoints.CreateApplicationBuilder(); diff --git a/src/SignalR/common/Http.Connections/src/NegotiateMetadata.cs b/src/SignalR/common/Http.Connections/src/NegotiateMetadata.cs new file mode 100644 index 0000000000..047b1779a3 --- /dev/null +++ b/src/SignalR/common/Http.Connections/src/NegotiateMetadata.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Text; + +namespace Microsoft.AspNetCore.Http.Connections +{ + /// + /// Metadata to identify the /negotiate endpoint for HTTP connections + /// + public class NegotiateMetadata + { + } +} diff --git a/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs b/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs index d62c4f953d..4541f5e3f1 100644 --- a/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs +++ b/src/SignalR/common/Http.Connections/test/MapConnectionHandlerTests.cs @@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Hosting.Server.Features; +using Microsoft.AspNetCore.Routing; using Microsoft.AspNetCore.SignalR.Tests; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.DependencyInjection; @@ -68,6 +69,69 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(2, authCount); } + [Fact] + public void MapConnectionHandlerEndPointRoutingFindsAttributesOnHub() + { + var authCount = 0; + using (var host = BuildWebHostWithEndPointRouting(routes => routes.MapConnectionHandler("/path", options => + { + authCount += options.AuthorizationData.Count; + }))) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Equal(2, dataSource.Endpoints.Count); + Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); + Assert.NotNull(dataSource.Endpoints[1].Metadata.GetMetadata()); + } + + Assert.Equal(1, authCount); + } + + [Fact] + public void MapConnectionHandlerEndPointRoutingAppliesAttributesBeforeConventions() + { + void ConfigureRoutes(IEndpointRouteBuilder endpoints) + { + // This "Foo" policy should override the default auth attribute + endpoints.MapConnectionHandler("/path") + .RequireAuthorization(new AuthorizeAttribute("Foo")); + } + + using (var host = BuildWebHostWithEndPointRouting(ConfigureRoutes)) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Equal(2, dataSource.Endpoints.Count); + Assert.Equal("Foo", dataSource.Endpoints[0].Metadata.GetMetadata()?.Policy); + Assert.Equal("Foo", dataSource.Endpoints[1].Metadata.GetMetadata()?.Policy); + } + } + + [Fact] + public void MapConnectionHandlerEndPointRoutingAppliesNegotiateMetadata() + { + void ConfigureRoutes(IEndpointRouteBuilder endpoints) + { + endpoints.MapConnectionHandler("/path"); + } + + using (var host = BuildWebHostWithEndPointRouting(ConfigureRoutes)) + { + host.Start(); + + var dataSource = host.Services.GetRequiredService(); + // We register 2 endpoints (/negotiate and /) + Assert.Equal(2, dataSource.Endpoints.Count); + Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); + Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata()); + } + } + [ConditionalFact] [WebSocketsSupportedCondition] public async Task MapConnectionHandlerWithWebSocketSubProtocolSetsProtocol() @@ -135,6 +199,23 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + private IWebHost BuildWebHostWithEndPointRouting(Action configure) + { + return new WebHostBuilder() + .UseKestrel() + .ConfigureServices(services => + { + services.AddConnections(); + }) + .Configure(app => + { + app.UseRouting(); + app.UseEndpoints(endpoints => configure(endpoints)); + }) + .UseUrls("http://127.0.0.1:0") + .Build(); + } + private IWebHost BuildWebHost(string path, Action configureOptions) where TConnectionHandler : ConnectionHandler { return new WebHostBuilder() diff --git a/src/SignalR/server/SignalR/test/MapSignalRTests.cs b/src/SignalR/server/SignalR/test/MapSignalRTests.cs index 84e85957c5..40ea60382f 100644 --- a/src/SignalR/server/SignalR/test/MapSignalRTests.cs +++ b/src/SignalR/server/SignalR/test/MapSignalRTests.cs @@ -5,6 +5,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Builder; using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Routing; using Microsoft.Extensions.DependencyInjection; using Xunit; @@ -204,6 +205,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(2, dataSource.Endpoints.Count); Assert.Equal(typeof(AuthHub), dataSource.Endpoints[0].Metadata.GetMetadata()?.HubType); Assert.Equal(typeof(AuthHub), dataSource.Endpoints[1].Metadata.GetMetadata()?.HubType); + Assert.NotNull(dataSource.Endpoints[0].Metadata.GetMetadata()); + Assert.Null(dataSource.Endpoints[1].Metadata.GetMetadata()); } }