Add extension method to get HttpContext on HubConnectionContext (#849)

This commit is contained in:
BrennanConroy 2017-09-11 16:55:32 -07:00 committed by GitHub
parent 62bbe943e8
commit 393ab6a4f0
2 changed files with 69 additions and 0 deletions

View File

@ -0,0 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Sockets.Http.Features;
namespace Microsoft.AspNetCore.SignalR
{
public static class HttpConnectionContextExtensions
{
public static HttpContext GetHttpContext(this HubConnectionContext connection)
{
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
}
}
}

View File

@ -8,8 +8,10 @@ using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Moq;
using Newtonsoft.Json;
@ -852,6 +854,52 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task CanGetHttpContextFromHubConnectionContext()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
var httpContext = new DefaultHttpContext();
client.Connection.SetHttpContext(httpContext);
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
await client.Connected.OrTimeout();
var result = (await client.InvokeAsync(nameof(MethodHub.HasHttpContext)).OrTimeout()).Result;
Assert.True((bool)result);
client.Dispose();
await endPointLifetime.OrTimeout();
}
}
[Fact]
public async Task GetHttpContextFromHubConnectionContextHandlesNull()
{
var serviceProvider = CreateServiceProvider();
var endPoint = serviceProvider.GetService<HubEndPoint<MethodHub>>();
using (var client = new TestClient())
{
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
await client.Connected.OrTimeout();
var result = (await client.InvokeAsync(nameof(MethodHub.HasHttpContext)).OrTimeout()).Result;
Assert.False((bool)result);
client.Dispose();
await endPointLifetime.OrTimeout();
}
}
private static void AssertHubMessage(HubMessage expected, HubMessage actual)
{
// We aren't testing InvocationIds here
@ -1166,6 +1214,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
return Clients.AllExcept(excludedIds).InvokeAsync("Send", message);
}
public bool HasHttpContext()
{
return Context.Connection.GetHttpContext() != null;
}
}
private class InheritedHub : BaseHub