Copy HttpContext properties for long polling transport (#1684)

- The long polling transport simulates a persistent connection
over multiple http requests. In order to expose common http request
properties, we need to copy them to a fake http context on the first poll
and set that as the HttpContext exposed via the IHttpContextFeature.
This commit is contained in:
David Fowler 2018-03-22 15:24:35 -07:00 committed by GitHub
parent b5c46f35b3
commit f1a3775247
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 229 additions and 10 deletions

View File

@ -11,7 +11,7 @@ namespace Microsoft.AspNetCore.Sockets
{
public static HttpContext GetHttpContext(this ConnectionContext connection)
{
return connection.Features.Get<IHttpContextFeature>().HttpContext;
return connection.Features.Get<IHttpContextFeature>()?.HttpContext;
}
public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext)

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.IO.Pipelines;
@ -10,6 +11,7 @@ using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.Protocols;
using Microsoft.AspNetCore.Protocols.Features;
using Microsoft.AspNetCore.Sockets.Internal;
@ -276,8 +278,6 @@ namespace Microsoft.AspNetCore.Sockets
connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive;
connection.SetHttpContext(null);
// Dispose the cancellation token
connection.Cancellation.Dispose();
@ -500,15 +500,35 @@ namespace Microsoft.AspNetCore.Sockets
return false;
}
// Setup the connection state from the http context
connection.User = context.User;
// Configure transport-specific features.
if (transportType == TransportType.LongPolling)
{
connection.Features.Set<IConnectionInherentKeepAliveFeature>(new ConnectionInherentKeepAliveFeature(options.LongPolling.PollTimeout));
}
// Setup the connection state from the http context
connection.User = context.User;
connection.SetHttpContext(context);
// For long polling, the requests come and go but the connection is still alive.
// To make the IHttpContextFeature work well, we make a copy of the relevant properties
// to a new HttpContext. This means that it's impossible to affect the context
// with subsequent requests.
var existing = connection.GetHttpContext();
if (existing == null)
{
var httpContext = CloneHttpContext(context);
connection.SetHttpContext(httpContext);
}
else
{
// Set the request trace identifier to the current http request handling the poll
existing.TraceIdentifier = context.TraceIdentifier;
existing.User = context.User;
}
}
else
{
connection.SetHttpContext(context);
}
// Set the Connection ID on the logging scope so that logs from now on will have the
// Connection ID metadata set.
@ -517,6 +537,65 @@ namespace Microsoft.AspNetCore.Sockets
return true;
}
private static HttpContext CloneHttpContext(HttpContext context)
{
// The reason we're copying the base features instead of the HttpContext properties is
// so that we can get all of the logic built into DefaultHttpContext to extract higher level
// structure from the low level properties
var existingRequestFeature = context.Features.Get<IHttpRequestFeature>();
var requestFeature = new HttpRequestFeature();
requestFeature.Protocol = existingRequestFeature.Protocol;
requestFeature.Method = existingRequestFeature.Method;
requestFeature.Scheme = existingRequestFeature.Scheme;
requestFeature.Path = existingRequestFeature.Path;
requestFeature.PathBase = existingRequestFeature.PathBase;
requestFeature.QueryString = existingRequestFeature.QueryString;
requestFeature.RawTarget = existingRequestFeature.RawTarget;
var requestHeaders = new Dictionary<string, StringValues>(existingRequestFeature.Headers.Count);
foreach (var header in existingRequestFeature.Headers)
{
requestHeaders[header.Key] = header.Value;
}
requestFeature.Headers = new HeaderDictionary(requestHeaders);
var existingConnectionFeature = context.Features.Get<IHttpConnectionFeature>();
var connectionFeature = new HttpConnectionFeature();
if (existingConnectionFeature != null)
{
connectionFeature.ConnectionId = existingConnectionFeature.ConnectionId;
connectionFeature.LocalIpAddress = existingConnectionFeature.LocalIpAddress;
connectionFeature.LocalPort = existingConnectionFeature.LocalPort;
connectionFeature.RemoteIpAddress = existingConnectionFeature.RemoteIpAddress;
connectionFeature.RemotePort = existingConnectionFeature.RemotePort;
}
// The response is a dud, you can't do anything with it anyways
var responseFeature = new HttpResponseFeature();
var features = new FeatureCollection();
features.Set<IHttpRequestFeature>(requestFeature);
features.Set<IHttpResponseFeature>(responseFeature);
features.Set<IHttpConnectionFeature>(connectionFeature);
// REVIEW: We could strategically look at adding other features but it might be better
// if we expose a callback that would allow the user to preserve HttpContext properties.
var newHttpContext = new DefaultHttpContext(features);
newHttpContext.TraceIdentifier = context.TraceIdentifier;
newHttpContext.User = context.User;
// Making request services function property could be tricky and expensive as it would require
// DI scope per connection. It would also mean that services resolved in middleware leading up to here
// wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider
newHttpContext.RequestServices = EmptyServiceProvider.Instance;
// REVIEW: This extends the lifetime of anything that got put into HttpContext.Items
newHttpContext.Items = new Dictionary<object, object>(context.Items);
return newHttpContext;
}
private async Task<DefaultConnectionContext> GetConnectionAsync(HttpContext context, HttpSocketOptions options)
{
var connectionId = GetConnectionId(context);
@ -580,5 +659,11 @@ namespace Microsoft.AspNetCore.Sockets
return connection;
}
private class EmptyServiceProvider : IServiceProvider
{
public static EmptyServiceProvider Instance { get; } = new EmptyServiceProvider();
public object GetService(Type serviceType) => null;
}
}
}

View File

@ -19,6 +19,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Authorization.Policy" Version="$(MicrosoftAspNetCoreAuthorizationPolicyPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Hosting.Abstractions" Version="$(MicrosoftAspNetCoreHostingAbstractionsPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Http" Version="$(MicrosoftAspNetCoreHttpPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Routing" Version="$(MicrosoftAspNetCoreRoutingPackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.WebSockets" Version="$(MicrosoftAspNetCoreWebSocketsPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.SecurityHelper.Sources" PrivateAssets="All" Version="$(MicrosoftExtensionsSecurityHelperSourcesPackageVersion)" />

View File

@ -692,7 +692,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory(Skip = "HttpContext + Long Polling fails. Issue logged - https://github.com/aspnet/SignalR/issues/1644")]
[Theory]
[MemberData(nameof(TransportTypes))]
public async Task ClientCanSendHeaders(TransportType transportType)
{

View File

@ -540,7 +540,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests
client.Dispose();
// Ensure the client channel is empty
Assert.Null(client.TryRead());
var message = client.TryRead();
switch (message)
{
case CloseMessage close:
break;
default:
Assert.Null(message);
break;
}
await endPointTask.OrTimeout();
}

View File

@ -6,6 +6,8 @@ using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Net;
using System.Net.WebSockets;
using System.Security.Claims;
using System.Text;
@ -337,6 +339,96 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
[Fact]
public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
{
using (StartLog(out var loggerFactory, LogLevel.Debug))
{
var manager = CreateConnectionManager(loggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory);
var connection = manager.CreateConnection();
using (var requestBody = new MemoryStream())
using (var responseBody = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Request.Body = requestBody;
context.Response.Body = responseBody;
var services = new ServiceCollection();
services.AddSingleton<HttpContextEndPoint>();
services.AddOptions();
// Setup state on the HttpContext
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
values["another"] = "value";
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Request.Headers["header1"] = "h1";
context.Request.Headers["header2"] = "h2";
context.Request.Headers["header3"] = "h3";
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") }));
context.TraceIdentifier = "requestid";
context.Connection.Id = "connectionid";
context.Connection.LocalIpAddress = IPAddress.Loopback;
context.Connection.LocalPort = 4563;
context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
context.Connection.RemotePort = 43456;
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseEndPoint<HttpContextEndPoint>();
var app = builder.Build();
// Start a poll
var task = dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app);
// Send to the application
var buffer = Encoding.UTF8.GetBytes("Hello World");
await connection.Application.Output.WriteAsync(buffer);
// The poll request should end
await task;
// Make sure the actual response isn't affected
Assert.Equal("application/octet-stream", context.Response.ContentType);
// Now do a new send again without the poll (that request should have ended)
await connection.Application.Output.WriteAsync(buffer);
connection.Application.Output.Complete();
// Wait for the endpoint to end
await connection.ApplicationTask;
var connectionHttpContext = connection.GetHttpContext();
Assert.NotNull(connectionHttpContext);
Assert.Equal(2, connectionHttpContext.Request.Query.Count);
Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]);
Assert.Equal("value", connectionHttpContext.Request.Query["another"]);
Assert.Equal(3, connectionHttpContext.Request.Headers.Count);
Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]);
Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]);
Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]);
Assert.Equal("requestid", connectionHttpContext.TraceIdentifier);
Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value);
Assert.Equal("connectionid", connectionHttpContext.Connection.Id);
Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress);
Assert.Equal(4563, connectionHttpContext.Connection.LocalPort);
Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress);
Assert.Equal(43456, connectionHttpContext.Connection.RemotePort);
Assert.NotNull(connectionHttpContext.RequestServices);
Assert.Equal(Stream.Null, connectionHttpContext.Response.Body);
Assert.NotNull(connectionHttpContext.Response.Headers);
Assert.Equal("application/xml", connectionHttpContext.Response.ContentType);
}
}
}
[Theory]
[InlineData(TransportType.ServerSentEvents)]
[InlineData(TransportType.LongPolling)]
@ -713,7 +805,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
await task;
Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status);
Assert.Null(connection.GetHttpContext());
Assert.NotNull(connection.GetHttpContext());
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
@ -1418,6 +1510,39 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
public class HttpContextEndPoint : EndPoint
{
public override async Task OnConnectedAsync(ConnectionContext connection)
{
while (true)
{
var result = await connection.Transport.Input.ReadAsync();
try
{
if (result.IsCompleted)
{
break;
}
// Make sure we have an http context
var context = connection.GetHttpContext();
Assert.NotNull(context);
// Setting the response headers should have no effect
context.Response.ContentType = "application/xml";
// Echo the results
await connection.Transport.Output.WriteAsync(result.Buffer.ToArray());
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}
}
public class TestEndPoint : EndPoint
{
public override async Task OnConnectedAsync(ConnectionContext connection)