diff --git a/src/Microsoft.AspNetCore.SignalR/GetHttpContextExtensions.cs b/src/Microsoft.AspNetCore.SignalR/GetHttpContextExtensions.cs new file mode 100644 index 0000000000..f2a076b7f9 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/GetHttpContextExtensions.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Http.Connections.Features; + +namespace Microsoft.AspNetCore.SignalR +{ + public static class GetHttpContextExtensions + { + public static HttpContext GetHttpContext(this HubCallerContext connection) + { + if (connection == null) + { + throw new ArgumentNullException(nameof(connection)); + } + return connection.Features.Get()?.HttpContext; + } + + public static HttpContext GetHttpContext(this HubConnectionContext connection) + { + if (connection == null) + { + throw new ArgumentNullException(nameof(connection)); + } + return connection.Features.Get()?.HttpContext; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/HubCallerContextExtensions.cs b/src/Microsoft.AspNetCore.SignalR/HubCallerContextExtensions.cs deleted file mode 100644 index ab1754b972..0000000000 --- a/src/Microsoft.AspNetCore.SignalR/HubCallerContextExtensions.cs +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using Microsoft.AspNetCore.Http; -using Microsoft.AspNetCore.Http.Connections.Features; - -namespace Microsoft.AspNetCore.SignalR -{ - public static class HubCallerContextExtensions - { - public static HttpContext GetHttpContext(this HubCallerContext connection) - { - return connection.Features.Get()?.HttpContext; - } - } -} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HeaderUserIdProvider.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HeaderUserIdProvider.cs new file mode 100644 index 0000000000..aec978c6df --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HeaderUserIdProvider.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests +{ + internal class HeaderUserIdProvider : IUserIdProvider + { + public static readonly string HeaderName = "Super-Insecure-UserName"; + + public string GetUserId(HubConnectionContext connection) + { + // Super-insecure user id provider :). Don't use this for anything real! + return connection.GetHttpContext()?.Request?.Headers?[HeaderName]; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 85cff50b24..ff6b2cbf04 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -864,15 +864,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Theory] - [MemberData(nameof(TransportTypes))] - public async Task CheckHttpConnectionFeatures(HttpTransportType transportType) + [Fact] + public async Task CheckHttpConnectionFeatures() { - using (StartVerifableLog(out var loggerFactory, $"{nameof(CheckHttpConnectionFeatures)}_{transportType}")) + using (StartVerifableLog(out var loggerFactory)) { var hubConnection = new HubConnectionBuilder() .WithLoggerFactory(loggerFactory) - .WithUrl(ServerFixture.Url + "/default", transportType) + .WithUrl(ServerFixture.Url + "/default") .Build(); try { @@ -901,6 +900,37 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task UserIdProviderCanAccessHttpContext() + { + using (StartVerifableLog(out var loggerFactory)) + { + var hubConnection = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(ServerFixture.Url + "/default", options => + { + options.Headers.Add(HeaderUserIdProvider.HeaderName, "SuperAdmin"); + }) + .Build(); + try + { + await hubConnection.StartAsync().OrTimeout(); + + var userIdentifier = await hubConnection.InvokeAsync(nameof(TestHub.GetUserIdentifier)).OrTimeout(); + Assert.Equal("SuperAdmin", userIdentifier); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + } + [Fact] public async Task NegotiationSkipsServerSentEventsWhenUsingBinaryProtocol() { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 03fc3ea45a..3e4fb69a3c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -37,6 +37,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).SendAsync("NoClientHandler"); } + public string GetUserIdentifier() + { + return Context.UserIdentifier; + } + public IEnumerable GetHeaderValues(string[] headerNames) { var context = Context.GetHttpContext(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs index 93865dec67..81a67d13a4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Startup.cs @@ -25,6 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests options.EnableDetailedErrors = true; }) .AddMessagePackProtocol(); + services.AddSingleton(); services.AddAuthorization(options => { options.AddPolicy(JwtBearerDefaults.AuthenticationScheme, policy => diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index 514d069757..2ac5a70512 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index 900df9778a..e3068ffd0e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -1706,6 +1706,57 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task ConnectionUserIdIsAssignedByUserIdProvider() + { + var firstRequest = true; + var userIdProvider = new TestUserIdProvider(c => + { + if (firstRequest) + { + firstRequest = false; + return "client1"; + } + else + { + return "client2"; + } + }); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSingleton(userIdProvider); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client1 = new TestClient()) + using (var client2 = new TestClient()) + { + var connectionHandlerTask1 = await client1.ConnectAsync(connectionHandler); + var connectionHandlerTask2 = await client2.ConnectAsync(connectionHandler); + + await client1.Connected.OrTimeout(); + await client2.Connected.OrTimeout(); + + await client2.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), new[] { "client1" }, "Hi!").OrTimeout(); + + var message = (InvocationMessage)await client1.ReadAsync().OrTimeout(); + + Assert.Equal("Send", message.Target); + Assert.Collection(message.Arguments, arg => Assert.Equal("Hi!", arg)); + + client1.Dispose(); + client2.Dispose(); + + await connectionHandlerTask1.OrTimeout(); + await connectionHandlerTask2.OrTimeout(); + + // Read the completion, then we should have nothing left in client2's queue + Assert.IsType(client2.TryRead()); + Assert.IsType(client2.TryRead()); + Assert.Null(client2.TryRead()); + } + } + private class CustomFormatter : IFormatterResolver { public IMessagePackFormatter GetFormatter() @@ -2141,5 +2192,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests { public HttpContext HttpContext { get; set; } } + + private class TestUserIdProvider : IUserIdProvider + { + private readonly Func _getUserId; + + public TestUserIdProvider(Func getUserId) + { + _getUserId = getUserId; + } + + public string GetUserId(HubConnectionContext connection) => _getUserId(connection); + } } }