From f1a377524710ad2912f30e6a5649af12e2e8f1b7 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 22 Mar 2018 15:24:35 -0700 Subject: [PATCH] Copy HttpContext properties for long polling transport (#1684) - The long polling transport simulates a persistent connection over multiple http requests. In order to expose common http request properties, we need to copy them to a fake http context on the first poll and set that as the HttpContext exposed via the IHttpContextFeature. --- .../HttpConnectionContextExtensions.cs | 2 +- .../HttpConnectionDispatcher.cs | 97 ++++++++++++- .../Microsoft.AspNetCore.Sockets.Http.csproj | 1 + .../HubConnectionTests.cs | 2 +- .../HubEndpointTests.cs | 10 +- .../HttpConnectionDispatcherTests.cs | 127 +++++++++++++++++- 6 files changed, 229 insertions(+), 10 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs index 0118376e34..67a4d968d6 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionContextExtensions.cs @@ -11,7 +11,7 @@ namespace Microsoft.AspNetCore.Sockets { public static HttpContext GetHttpContext(this ConnectionContext connection) { - return connection.Features.Get().HttpContext; + return connection.Features.Get()?.HttpContext; } public static void SetHttpContext(this ConnectionContext connection, HttpContext httpContext) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index f09cad93a2..5b08db7a9e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.IO.Pipelines; @@ -10,6 +11,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Http.Internal; using Microsoft.AspNetCore.Protocols; using Microsoft.AspNetCore.Protocols.Features; using Microsoft.AspNetCore.Sockets.Internal; @@ -276,8 +278,6 @@ namespace Microsoft.AspNetCore.Sockets connection.Status = DefaultConnectionContext.ConnectionStatus.Inactive; - connection.SetHttpContext(null); - // Dispose the cancellation token connection.Cancellation.Dispose(); @@ -500,15 +500,35 @@ namespace Microsoft.AspNetCore.Sockets return false; } + // Setup the connection state from the http context + connection.User = context.User; + // Configure transport-specific features. if (transportType == TransportType.LongPolling) { connection.Features.Set(new ConnectionInherentKeepAliveFeature(options.LongPolling.PollTimeout)); - } - // Setup the connection state from the http context - connection.User = context.User; - connection.SetHttpContext(context); + // For long polling, the requests come and go but the connection is still alive. + // To make the IHttpContextFeature work well, we make a copy of the relevant properties + // to a new HttpContext. This means that it's impossible to affect the context + // with subsequent requests. + var existing = connection.GetHttpContext(); + if (existing == null) + { + var httpContext = CloneHttpContext(context); + connection.SetHttpContext(httpContext); + } + else + { + // Set the request trace identifier to the current http request handling the poll + existing.TraceIdentifier = context.TraceIdentifier; + existing.User = context.User; + } + } + else + { + connection.SetHttpContext(context); + } // Set the Connection ID on the logging scope so that logs from now on will have the // Connection ID metadata set. @@ -517,6 +537,65 @@ namespace Microsoft.AspNetCore.Sockets return true; } + private static HttpContext CloneHttpContext(HttpContext context) + { + // The reason we're copying the base features instead of the HttpContext properties is + // so that we can get all of the logic built into DefaultHttpContext to extract higher level + // structure from the low level properties + var existingRequestFeature = context.Features.Get(); + + var requestFeature = new HttpRequestFeature(); + requestFeature.Protocol = existingRequestFeature.Protocol; + requestFeature.Method = existingRequestFeature.Method; + requestFeature.Scheme = existingRequestFeature.Scheme; + requestFeature.Path = existingRequestFeature.Path; + requestFeature.PathBase = existingRequestFeature.PathBase; + requestFeature.QueryString = existingRequestFeature.QueryString; + requestFeature.RawTarget = existingRequestFeature.RawTarget; + var requestHeaders = new Dictionary(existingRequestFeature.Headers.Count); + foreach (var header in existingRequestFeature.Headers) + { + requestHeaders[header.Key] = header.Value; + } + requestFeature.Headers = new HeaderDictionary(requestHeaders); + + var existingConnectionFeature = context.Features.Get(); + var connectionFeature = new HttpConnectionFeature(); + + if (existingConnectionFeature != null) + { + connectionFeature.ConnectionId = existingConnectionFeature.ConnectionId; + connectionFeature.LocalIpAddress = existingConnectionFeature.LocalIpAddress; + connectionFeature.LocalPort = existingConnectionFeature.LocalPort; + connectionFeature.RemoteIpAddress = existingConnectionFeature.RemoteIpAddress; + connectionFeature.RemotePort = existingConnectionFeature.RemotePort; + } + + // The response is a dud, you can't do anything with it anyways + var responseFeature = new HttpResponseFeature(); + + var features = new FeatureCollection(); + features.Set(requestFeature); + features.Set(responseFeature); + features.Set(connectionFeature); + + // REVIEW: We could strategically look at adding other features but it might be better + // if we expose a callback that would allow the user to preserve HttpContext properties. + + var newHttpContext = new DefaultHttpContext(features); + newHttpContext.TraceIdentifier = context.TraceIdentifier; + newHttpContext.User = context.User; + + // Making request services function property could be tricky and expensive as it would require + // DI scope per connection. It would also mean that services resolved in middleware leading up to here + // wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider + newHttpContext.RequestServices = EmptyServiceProvider.Instance; + + // REVIEW: This extends the lifetime of anything that got put into HttpContext.Items + newHttpContext.Items = new Dictionary(context.Items); + return newHttpContext; + } + private async Task GetConnectionAsync(HttpContext context, HttpSocketOptions options) { var connectionId = GetConnectionId(context); @@ -580,5 +659,11 @@ namespace Microsoft.AspNetCore.Sockets return connection; } + + private class EmptyServiceProvider : IServiceProvider + { + public static EmptyServiceProvider Instance { get; } = new EmptyServiceProvider(); + public object GetService(Type serviceType) => null; + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj index 6a97d4da6e..6c12890f74 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj @@ -19,6 +19,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 5e03d870ce..c9ef03fa5c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -692,7 +692,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } - [Theory(Skip = "HttpContext + Long Polling fails. Issue logged - https://github.com/aspnet/SignalR/issues/1644")] + [Theory] [MemberData(nameof(TransportTypes))] public async Task ClientCanSendHeaders(TransportType transportType) { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 588432ea5e..698c87215c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -540,7 +540,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); // Ensure the client channel is empty - Assert.Null(client.TryRead()); + var message = client.TryRead(); + switch (message) + { + case CloseMessage close: + break; + default: + Assert.Null(message); + break; + } await endPointTask.OrTimeout(); } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index c6cb016cb8..9492d14947 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -6,6 +6,8 @@ using System.Buffers; using System.Collections.Generic; using System.IO; using System.IO.Pipelines; +using System.Linq; +using System.Net; using System.Net.WebSockets; using System.Security.Claims; using System.Text; @@ -337,6 +339,96 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + [Fact] + public async Task HttpContextFeatureForLongpollingWorksBetweenPolls() + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + + using (var requestBody = new MemoryStream()) + using (var responseBody = new MemoryStream()) + { + var context = new DefaultHttpContext(); + context.Request.Body = requestBody; + context.Response.Body = responseBody; + + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + + // Setup state on the HttpContext + context.Request.Path = "/foo"; + context.Request.Method = "GET"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + values["another"] = "value"; + var qs = new QueryCollection(values); + context.Request.Query = qs; + context.Request.Headers["header1"] = "h1"; + context.Request.Headers["header2"] = "h2"; + context.Request.Headers["header3"] = "h3"; + context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") })); + context.TraceIdentifier = "requestid"; + context.Connection.Id = "connectionid"; + context.Connection.LocalIpAddress = IPAddress.Loopback; + context.Connection.LocalPort = 4563; + context.Connection.RemoteIpAddress = IPAddress.IPv6Any; + context.Connection.RemotePort = 43456; + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + + // Start a poll + var task = dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); + + // Send to the application + var buffer = Encoding.UTF8.GetBytes("Hello World"); + await connection.Application.Output.WriteAsync(buffer); + + // The poll request should end + await task; + + // Make sure the actual response isn't affected + Assert.Equal("application/octet-stream", context.Response.ContentType); + + // Now do a new send again without the poll (that request should have ended) + await connection.Application.Output.WriteAsync(buffer); + + connection.Application.Output.Complete(); + + // Wait for the endpoint to end + await connection.ApplicationTask; + + var connectionHttpContext = connection.GetHttpContext(); + Assert.NotNull(connectionHttpContext); + + Assert.Equal(2, connectionHttpContext.Request.Query.Count); + Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]); + Assert.Equal("value", connectionHttpContext.Request.Query["another"]); + + Assert.Equal(3, connectionHttpContext.Request.Headers.Count); + Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]); + Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]); + Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]); + Assert.Equal("requestid", connectionHttpContext.TraceIdentifier); + Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value); + Assert.Equal("connectionid", connectionHttpContext.Connection.Id); + Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress); + Assert.Equal(4563, connectionHttpContext.Connection.LocalPort); + Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress); + Assert.Equal(43456, connectionHttpContext.Connection.RemotePort); + Assert.NotNull(connectionHttpContext.RequestServices); + Assert.Equal(Stream.Null, connectionHttpContext.Response.Body); + Assert.NotNull(connectionHttpContext.Response.Headers); + Assert.Equal("application/xml", connectionHttpContext.Response.ContentType); + } + } + } + [Theory] [InlineData(TransportType.ServerSentEvents)] [InlineData(TransportType.LongPolling)] @@ -713,7 +805,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests await task; Assert.Equal(DefaultConnectionContext.ConnectionStatus.Inactive, connection.Status); - Assert.Null(connection.GetHttpContext()); + Assert.NotNull(connection.GetHttpContext()); Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); } @@ -1418,6 +1510,39 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } + public class HttpContextEndPoint : EndPoint + { + public override async Task OnConnectedAsync(ConnectionContext connection) + { + while (true) + { + var result = await connection.Transport.Input.ReadAsync(); + + try + { + if (result.IsCompleted) + { + break; + } + + // Make sure we have an http context + var context = connection.GetHttpContext(); + Assert.NotNull(context); + + // Setting the response headers should have no effect + context.Response.ContentType = "application/xml"; + + // Echo the results + await connection.Transport.Output.WriteAsync(result.Buffer.ToArray()); + } + finally + { + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } + } + } + } + public class TestEndPoint : EndPoint { public override async Task OnConnectedAsync(ConnectionContext connection)