fix #2097 by adding GetHttpContext to HubConnectionContext (#2099)

This commit is contained in:
Andrew Stanton-Nurse 2018-04-19 19:14:39 -07:00 committed by GitHub
parent 1957655653
commit ace9a0d414
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 151 additions and 22 deletions

View File

@ -0,0 +1,30 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections.Features;
namespace Microsoft.AspNetCore.SignalR
{
public static class GetHttpContextExtensions
{
public static HttpContext GetHttpContext(this HubCallerContext connection)
{
if (connection == null)
{
throw new ArgumentNullException(nameof(connection));
}
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
}
public static HttpContext GetHttpContext(this HubConnectionContext connection)
{
if (connection == null)
{
throw new ArgumentNullException(nameof(connection));
}
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
}
}
}

View File

@ -1,16 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections.Features;
namespace Microsoft.AspNetCore.SignalR
{
public static class HubCallerContextExtensions
{
public static HttpContext GetHttpContext(this HubCallerContext connection)
{
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
}
}
}

View File

@ -0,0 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
internal class HeaderUserIdProvider : IUserIdProvider
{
public static readonly string HeaderName = "Super-Insecure-UserName";
public string GetUserId(HubConnectionContext connection)
{
// Super-insecure user id provider :). Don't use this for anything real!
return connection.GetHttpContext()?.Request?.Headers?[HeaderName];
}
}
}

View File

@ -864,15 +864,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory]
[MemberData(nameof(TransportTypes))]
public async Task CheckHttpConnectionFeatures(HttpTransportType transportType)
[Fact]
public async Task CheckHttpConnectionFeatures()
{
using (StartVerifableLog(out var loggerFactory, $"{nameof(CheckHttpConnectionFeatures)}_{transportType}"))
using (StartVerifableLog(out var loggerFactory))
{
var hubConnection = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
.WithUrl(ServerFixture.Url + "/default", transportType)
.WithUrl(ServerFixture.Url + "/default")
.Build();
try
{
@ -901,6 +900,37 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Fact]
public async Task UserIdProviderCanAccessHttpContext()
{
using (StartVerifableLog(out var loggerFactory))
{
var hubConnection = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
.WithUrl(ServerFixture.Url + "/default", options =>
{
options.Headers.Add(HeaderUserIdProvider.HeaderName, "SuperAdmin");
})
.Build();
try
{
await hubConnection.StartAsync().OrTimeout();
var userIdentifier = await hubConnection.InvokeAsync<string>(nameof(TestHub.GetUserIdentifier)).OrTimeout();
Assert.Equal("SuperAdmin", userIdentifier);
}
catch (Exception ex)
{
loggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
}
}
}
[Fact]
public async Task NegotiationSkipsServerSentEventsWhenUsingBinaryProtocol()
{

View File

@ -37,6 +37,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
await Clients.Client(Context.ConnectionId).SendAsync("NoClientHandler");
}
public string GetUserIdentifier()
{
return Context.UserIdentifier;
}
public IEnumerable<string> GetHeaderValues(string[] headerNames)
{
var context = Context.GetHttpContext();

View File

@ -25,6 +25,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
options.EnableDetailedErrors = true;
})
.AddMessagePackProtocol();
services.AddSingleton<IUserIdProvider, HeaderUserIdProvider>();
services.AddAuthorization(options =>
{
options.AddPolicy(JwtBearerDefaults.AuthenticationScheme, policy =>

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;

View File

@ -1706,6 +1706,57 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task ConnectionUserIdIsAssignedByUserIdProvider()
{
var firstRequest = true;
var userIdProvider = new TestUserIdProvider(c =>
{
if (firstRequest)
{
firstRequest = false;
return "client1";
}
else
{
return "client2";
}
});
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSingleton<IUserIdProvider>(userIdProvider);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client1 = new TestClient())
using (var client2 = new TestClient())
{
var connectionHandlerTask1 = await client1.ConnectAsync(connectionHandler);
var connectionHandlerTask2 = await client2.ConnectAsync(connectionHandler);
await client1.Connected.OrTimeout();
await client2.Connected.OrTimeout();
await client2.SendInvocationAsync(nameof(MethodHub.SendToMultipleUsers), new[] { "client1" }, "Hi!").OrTimeout();
var message = (InvocationMessage)await client1.ReadAsync().OrTimeout();
Assert.Equal("Send", message.Target);
Assert.Collection(message.Arguments, arg => Assert.Equal("Hi!", arg));
client1.Dispose();
client2.Dispose();
await connectionHandlerTask1.OrTimeout();
await connectionHandlerTask2.OrTimeout();
// Read the completion, then we should have nothing left in client2's queue
Assert.IsType<CompletionMessage>(client2.TryRead());
Assert.IsType<CloseMessage>(client2.TryRead());
Assert.Null(client2.TryRead());
}
}
private class CustomFormatter : IFormatterResolver
{
public IMessagePackFormatter<T> GetFormatter<T>()
@ -2141,5 +2192,17 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
public HttpContext HttpContext { get; set; }
}
private class TestUserIdProvider : IUserIdProvider
{
private readonly Func<HubConnectionContext, string> _getUserId;
public TestUserIdProvider(Func<HubConnectionContext, string> getUserId)
{
_getUserId = getUserId;
}
public string GetUserId(HubConnectionContext connection) => _getUserId(connection);
}
}
}