Copy HttpContext properties for long polling transport (#1684)
- The long polling transport simulates a persistent connection over multiple http requests. In order to expose common http request properties, we need to copy them to a fake http context on the first poll and set that as the HttpContext exposed via the IHttpContextFeature.
This commit is contained in:
parent
b5c46f35b3
commit
f1a3775247
|
|
@ -11,7 +11,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
public static HttpContext GetHttpContext(this ConnectionContext connection)
|
||||
{
|
||||
return connection.Features.Get<IHttpContextFeature>().HttpContext;
|
||||
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
|
||||
}
|
||||
|
||||
public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Diagnostics;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
|
|
@ -10,6 +11,7 @@ using System.Threading;
|
|||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Http.Features;
|
||||
using Microsoft.AspNetCore.Http.Internal;
|
||||
using Microsoft.AspNetCore.Protocols;
|
||||
using Microsoft.AspNetCore.Protocols.Features;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
|
|
@ -276,8 +278,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;
|
||||
|
||||
connection.SetHttpContext(null);
|
||||
|
||||
// Dispose the cancellation token
|
||||
connection.Cancellation.Dispose();
|
||||
|
||||
|
|
@ -500,15 +500,35 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return false;
|
||||
}
|
||||
|
||||
// Setup the connection state from the http context
|
||||
connection.User = context.User;
|
||||
|
||||
// Configure transport-specific features.
|
||||
if (transportType == TransportType.LongPolling)
|
||||
{
|
||||
connection.Features.Set<IConnectionInherentKeepAliveFeature>(new ConnectionInherentKeepAliveFeature(options.LongPolling.PollTimeout));
|
||||
}
|
||||
|
||||
// Setup the connection state from the http context
|
||||
connection.User = context.User;
|
||||
connection.SetHttpContext(context);
|
||||
// For long polling, the requests come and go but the connection is still alive.
|
||||
// To make the IHttpContextFeature work well, we make a copy of the relevant properties
|
||||
// to a new HttpContext. This means that it's impossible to affect the context
|
||||
// with subsequent requests.
|
||||
var existing = connection.GetHttpContext();
|
||||
if (existing == null)
|
||||
{
|
||||
var httpContext = CloneHttpContext(context);
|
||||
connection.SetHttpContext(httpContext);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Set the request trace identifier to the current http request handling the poll
|
||||
existing.TraceIdentifier = context.TraceIdentifier;
|
||||
existing.User = context.User;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
connection.SetHttpContext(context);
|
||||
}
|
||||
|
||||
// Set the Connection ID on the logging scope so that logs from now on will have the
|
||||
// Connection ID metadata set.
|
||||
|
|
@ -517,6 +537,65 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
return true;
|
||||
}
|
||||
|
||||
private static HttpContext CloneHttpContext(HttpContext context)
|
||||
{
|
||||
// The reason we're copying the base features instead of the HttpContext properties is
|
||||
// so that we can get all of the logic built into DefaultHttpContext to extract higher level
|
||||
// structure from the low level properties
|
||||
var existingRequestFeature = context.Features.Get<IHttpRequestFeature>();
|
||||
|
||||
var requestFeature = new HttpRequestFeature();
|
||||
requestFeature.Protocol = existingRequestFeature.Protocol;
|
||||
requestFeature.Method = existingRequestFeature.Method;
|
||||
requestFeature.Scheme = existingRequestFeature.Scheme;
|
||||
requestFeature.Path = existingRequestFeature.Path;
|
||||
requestFeature.PathBase = existingRequestFeature.PathBase;
|
||||
requestFeature.QueryString = existingRequestFeature.QueryString;
|
||||
requestFeature.RawTarget = existingRequestFeature.RawTarget;
|
||||
var requestHeaders = new Dictionary<string, StringValues>(existingRequestFeature.Headers.Count);
|
||||
foreach (var header in existingRequestFeature.Headers)
|
||||
{
|
||||
requestHeaders[header.Key] = header.Value;
|
||||
}
|
||||
requestFeature.Headers = new HeaderDictionary(requestHeaders);
|
||||
|
||||
var existingConnectionFeature = context.Features.Get<IHttpConnectionFeature>();
|
||||
var connectionFeature = new HttpConnectionFeature();
|
||||
|
||||
if (existingConnectionFeature != null)
|
||||
{
|
||||
connectionFeature.ConnectionId = existingConnectionFeature.ConnectionId;
|
||||
connectionFeature.LocalIpAddress = existingConnectionFeature.LocalIpAddress;
|
||||
connectionFeature.LocalPort = existingConnectionFeature.LocalPort;
|
||||
connectionFeature.RemoteIpAddress = existingConnectionFeature.RemoteIpAddress;
|
||||
connectionFeature.RemotePort = existingConnectionFeature.RemotePort;
|
||||
}
|
||||
|
||||
// The response is a dud, you can't do anything with it anyways
|
||||
var responseFeature = new HttpResponseFeature();
|
||||
|
||||
var features = new FeatureCollection();
|
||||
features.Set<IHttpRequestFeature>(requestFeature);
|
||||
features.Set<IHttpResponseFeature>(responseFeature);
|
||||
features.Set<IHttpConnectionFeature>(connectionFeature);
|
||||
|
||||
// REVIEW: We could strategically look at adding other features but it might be better
|
||||
// if we expose a callback that would allow the user to preserve HttpContext properties.
|
||||
|
||||
var newHttpContext = new DefaultHttpContext(features);
|
||||
newHttpContext.TraceIdentifier = context.TraceIdentifier;
|
||||
newHttpContext.User = context.User;
|
||||
|
||||
// Making request services function property could be tricky and expensive as it would require
|
||||
// DI scope per connection. It would also mean that services resolved in middleware leading up to here
|
||||
// wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider
|
||||
newHttpContext.RequestServices = EmptyServiceProvider.Instance;
|
||||
|
||||
// REVIEW: This extends the lifetime of anything that got put into HttpContext.Items
|
||||
newHttpContext.Items = new Dictionary<object, object>(context.Items);
|
||||
return newHttpContext;
|
||||
}
|
||||
|
||||
private async Task<DefaultConnectionContext> GetConnectionAsync(HttpContext context, HttpSocketOptions options)
|
||||
{
|
||||
var connectionId = GetConnectionId(context);
|
||||
|
|
@ -580,5 +659,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
return connection;
|
||||
}
|
||||
|
||||
private class EmptyServiceProvider : IServiceProvider
|
||||
{
|
||||
public static EmptyServiceProvider Instance { get; } = new EmptyServiceProvider();
|
||||
public object GetService(Type serviceType) => null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@
|
|||
<ItemGroup>
|
||||
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(MicrosoftAspNetCoreAuthorizationPolicyPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(MicrosoftAspNetCoreHostingAbstractionsPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Http" Version="$(MicrosoftAspNetCoreHttpPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(MicrosoftAspNetCoreRoutingPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="$(MicrosoftAspNetCoreWebSocketsPackageVersion)" />
|
||||
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" PrivateAssets="All" Version="$(MicrosoftExtensionsSecurityHelperSourcesPackageVersion)" />
|
||||
|
|
|
|||
|
|
@ -692,7 +692,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
}
|
||||
}
|
||||
|
||||
[Theory(Skip = "HttpContext + Long Polling fails. Issue logged - https://github.com/aspnet/SignalR/issues/1644")]
|
||||
[Theory]
|
||||
[MemberData(nameof(TransportTypes))]
|
||||
public async Task ClientCanSendHeaders(TransportType transportType)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -540,7 +540,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
client.Dispose();
|
||||
|
||||
// Ensure the client channel is empty
|
||||
Assert.Null(client.TryRead());
|
||||
var message = client.TryRead();
|
||||
switch (message)
|
||||
{
|
||||
case CloseMessage close:
|
||||
break;
|
||||
default:
|
||||
Assert.Null(message);
|
||||
break;
|
||||
}
|
||||
|
||||
await endPointTask.OrTimeout();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,8 @@ using System.Buffers;
|
|||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.IO.Pipelines;
|
||||
using System.Linq;
|
||||
using System.Net;
|
||||
using System.Net.WebSockets;
|
||||
using System.Security.Claims;
|
||||
using System.Text;
|
||||
|
|
@ -337,6 +339,96 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
|
||||
{
|
||||
using (StartLog(out var loggerFactory, LogLevel.Debug))
|
||||
{
|
||||
var manager = CreateConnectionManager(loggerFactory);
|
||||
var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory);
|
||||
var connection = manager.CreateConnection();
|
||||
|
||||
using (var requestBody = new MemoryStream())
|
||||
using (var responseBody = new MemoryStream())
|
||||
{
|
||||
var context = new DefaultHttpContext();
|
||||
context.Request.Body = requestBody;
|
||||
context.Response.Body = responseBody;
|
||||
|
||||
var services = new ServiceCollection();
|
||||
services.AddSingleton<HttpContextEndPoint>();
|
||||
services.AddOptions();
|
||||
|
||||
// Setup state on the HttpContext
|
||||
context.Request.Path = "/foo";
|
||||
context.Request.Method = "GET";
|
||||
var values = new Dictionary<string, StringValues>();
|
||||
values["id"] = connection.ConnectionId;
|
||||
values["another"] = "value";
|
||||
var qs = new QueryCollection(values);
|
||||
context.Request.Query = qs;
|
||||
context.Request.Headers["header1"] = "h1";
|
||||
context.Request.Headers["header2"] = "h2";
|
||||
context.Request.Headers["header3"] = "h3";
|
||||
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") }));
|
||||
context.TraceIdentifier = "requestid";
|
||||
context.Connection.Id = "connectionid";
|
||||
context.Connection.LocalIpAddress = IPAddress.Loopback;
|
||||
context.Connection.LocalPort = 4563;
|
||||
context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
|
||||
context.Connection.RemotePort = 43456;
|
||||
|
||||
var builder = new ConnectionBuilder(services.BuildServiceProvider());
|
||||
builder.UseEndPoint<HttpContextEndPoint>();
|
||||
var app = builder.Build();
|
||||
|
||||
// Start a poll
|
||||
var task = dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
|
||||
|
||||
// Send to the application
|
||||
var buffer = Encoding.UTF8.GetBytes("Hello World");
|
||||
await connection.Application.Output.WriteAsync(buffer);
|
||||
|
||||
// The poll request should end
|
||||
await task;
|
||||
|
||||
// Make sure the actual response isn't affected
|
||||
Assert.Equal("application/octet-stream", context.Response.ContentType);
|
||||
|
||||
// Now do a new send again without the poll (that request should have ended)
|
||||
await connection.Application.Output.WriteAsync(buffer);
|
||||
|
||||
connection.Application.Output.Complete();
|
||||
|
||||
// Wait for the endpoint to end
|
||||
await connection.ApplicationTask;
|
||||
|
||||
var connectionHttpContext = connection.GetHttpContext();
|
||||
Assert.NotNull(connectionHttpContext);
|
||||
|
||||
Assert.Equal(2, connectionHttpContext.Request.Query.Count);
|
||||
Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]);
|
||||
Assert.Equal("value", connectionHttpContext.Request.Query["another"]);
|
||||
|
||||
Assert.Equal(3, connectionHttpContext.Request.Headers.Count);
|
||||
Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]);
|
||||
Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]);
|
||||
Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]);
|
||||
Assert.Equal("requestid", connectionHttpContext.TraceIdentifier);
|
||||
Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value);
|
||||
Assert.Equal("connectionid", connectionHttpContext.Connection.Id);
|
||||
Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress);
|
||||
Assert.Equal(4563, connectionHttpContext.Connection.LocalPort);
|
||||
Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress);
|
||||
Assert.Equal(43456, connectionHttpContext.Connection.RemotePort);
|
||||
Assert.NotNull(connectionHttpContext.RequestServices);
|
||||
Assert.Equal(Stream.Null, connectionHttpContext.Response.Body);
|
||||
Assert.NotNull(connectionHttpContext.Response.Headers);
|
||||
Assert.Equal("application/xml", connectionHttpContext.Response.ContentType);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(TransportType.ServerSentEvents)]
|
||||
[InlineData(TransportType.LongPolling)]
|
||||
|
|
@ -713,7 +805,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
await task;
|
||||
|
||||
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status);
|
||||
Assert.Null(connection.GetHttpContext());
|
||||
Assert.NotNull(connection.GetHttpContext());
|
||||
|
||||
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
|
||||
}
|
||||
|
|
@ -1418,6 +1510,39 @@ namespace Microsoft.AspNetCore.Sockets.Tests
|
|||
}
|
||||
}
|
||||
|
||||
public class HttpContextEndPoint : EndPoint
|
||||
{
|
||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
var result = await connection.Transport.Input.ReadAsync();
|
||||
|
||||
try
|
||||
{
|
||||
if (result.IsCompleted)
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
||||
// Make sure we have an http context
|
||||
var context = connection.GetHttpContext();
|
||||
Assert.NotNull(context);
|
||||
|
||||
// Setting the response headers should have no effect
|
||||
context.Response.ContentType = "application/xml";
|
||||
|
||||
// Echo the results
|
||||
await connection.Transport.Output.WriteAsync(result.Buffer.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
connection.Transport.Input.AdvanceTo(result.Buffer.End);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public class TestEndPoint : EndPoint
|
||||
{
|
||||
public override async Task OnConnectedAsync(ConnectionContext connection)
|
||||
|
|
|
|||
Loading…
Reference in New Issue