diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs index ec7b124f8f..5958c7ef92 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionContext.cs @@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Http.Connections.Internal @@ -99,6 +100,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal // Used for testing only internal Task DisposeAndRemoveTask { get; set; } + // Used for LongPolling because we need to create a scope that spans the lifetime of multiple requests on the cloned HttpContext + internal IServiceScope ServiceScope { get; set; } + public Task TransportTask { get; set; } public Task PreviousPollTask { get; set; } = Task.CompletedTask; @@ -251,6 +255,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal (identity as IDisposable)?.Dispose(); } } + + ServiceScope?.Dispose(); } await disposeTask; diff --git a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs index 7f6477cab0..fab15305ce 100644 --- a/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs +++ b/src/SignalR/common/Http.Connections/src/Internal/HttpConnectionDispatcher.cs @@ -14,6 +14,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections.Internal.Transports; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; +using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; @@ -537,8 +538,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal var existing = connection.HttpContext; if (existing == null) { - var httpContext = CloneHttpContext(context); - connection.HttpContext = httpContext; + CloneHttpContext(context, connection); } else { @@ -606,7 +606,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal } } - private static HttpContext CloneHttpContext(HttpContext context) + private static void CloneHttpContext(HttpContext context, HttpConnectionContext connection) { // The reason we're copying the base features instead of the HttpContext properties is // so that we can get all of the logic built into DefaultHttpContext to extract higher level @@ -660,14 +660,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal CloneUser(newHttpContext, context); - // Making request services function property could be tricky and expensive as it would require - // DI scope per connection. It would also mean that services resolved in middleware leading up to here - // wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider - newHttpContext.RequestServices = EmptyServiceProvider.Instance; + connection.ServiceScope = context.RequestServices.CreateScope(); + newHttpContext.RequestServices = connection.ServiceScope.ServiceProvider; // REVIEW: This extends the lifetime of anything that got put into HttpContext.Items newHttpContext.Items = new Dictionary(context.Items); - return newHttpContext; + + connection.HttpContext = newHttpContext; } private async Task GetConnectionAsync(HttpContext context) diff --git a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs index bed8bba4db..85418c85c3 100644 --- a/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs +++ b/src/SignalR/common/Http.Connections/test/HttpConnectionDispatcherTests.cs @@ -510,6 +510,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; + context.RequestServices = services.BuildServiceProvider(); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.Use(next => @@ -723,6 +724,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.Connection.RemoteIpAddress = IPAddress.IPv6Any; context.Connection.RemotePort = 43456; context.SetEndpoint(new Endpoint(null, null, "TestName")); + context.RequestServices = services.BuildServiceProvider(); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); @@ -942,12 +944,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); + var services = new ServiceCollection(); + services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); SetTransport(context, HttpTransportType.ServerSentEvents); - var services = new ServiceCollection(); - services.AddSingleton(); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -976,11 +978,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests connection.TransportType = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, HttpTransportType.ServerSentEvents); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.ServerSentEvents); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1004,10 +1006,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1035,10 +1037,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1128,9 +1130,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = manager.CreateConnection(); connection.TransportType = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1161,10 +1163,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = manager.CreateConnection(); connection.TransportType = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, connection.TransportType); var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, connection.TransportType); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1199,10 +1201,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests connection.TransportType = HttpTransportType.WebSockets; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); var sync = new SyncPoint(); - var context = MakeRequest("/foo", connection); - SetTransport(context, connection.TransportType, sync); var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, connection.TransportType, sync); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1232,11 +1234,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, HttpTransportType.WebSockets); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1262,14 +1264,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context1 = MakeRequest("/foo", connection); - var context2 = MakeRequest("/foo", connection); + var services = new ServiceCollection(); + services.AddSingleton(); + var context1 = MakeRequest("/foo", connection, services); + var context2 = MakeRequest("/foo", connection, services); SetTransport(context1, transportType); SetTransport(context2, transportType); - var services = new ServiceCollection(); - services.AddSingleton(); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1305,11 +1307,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context1 = MakeRequest("/foo", connection); - var context2 = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context1 = MakeRequest("/foo", connection, services); + var context2 = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1369,11 +1371,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context1 = MakeRequest("/foo", connection); - var context2 = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context1 = MakeRequest("/foo", connection, services); + var context2 = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1432,11 +1434,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, transportType); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, transportType); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1459,10 +1461,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1494,11 +1496,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, HttpTransportType.ServerSentEvents); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.ServerSentEvents); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1529,10 +1531,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1577,12 +1579,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var app = builder.Build(); var options = new HttpConnectionDispatcherOptions(); - var context1 = MakeRequest("/foo", connection); + var context1 = MakeRequest("/foo", connection, services); // This is the initial poll to make sure things are setup var task1 = dispatcher.ExecuteAsync(context1, options, app); Assert.True(task1.IsCompleted); task1 = dispatcher.ExecuteAsync(context1, options, app); - var context2 = MakeRequest("/foo", connection); + var context2 = MakeRequest("/foo", connection, services); var task2 = dispatcher.ExecuteAsync(context2, options, app); // Task 1 should finish when request 2 arrives @@ -1615,11 +1617,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, transportType); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, transportType); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1751,10 +1753,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1783,11 +1785,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, transportType); - var serviceCollection = new ServiceCollection(); serviceCollection.AddSingleton(); + var context = MakeRequest("/foo", connection, serviceCollection); + SetTransport(context, transportType); + var services = serviceCollection.BuildServiceProvider(); var builder = new ConnectionBuilder(services); builder.UseConnectionHandler(); @@ -1824,10 +1826,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1877,10 +1879,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1926,10 +1928,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -1992,7 +1994,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests }; { var options = new HttpConnectionDispatcherOptions(); - var context = MakeRequest("/foo", connection); + var context = MakeRequest("/foo", connection, new ServiceCollection()); await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout(); // second poll should have data @@ -2008,7 +2010,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests waitForMessageTcs2.SetResult(); await messageTcs2.Task.OrTimeout(); - context = MakeRequest("/foo", connection); + context = MakeRequest("/foo", connection, new ServiceCollection()); ms.Seek(0, SeekOrigin.Begin); context.Response.Body = ms; // This is the third poll which gets the final message after the app is complete @@ -2217,7 +2219,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var sendTask = dispatcher.ExecuteAsync(context, options, app).OrTimeout(); Assert.False(sendTask.IsCompleted); - var pollContext = MakeRequest("/foo", connection); + var pollContext = MakeRequest("/foo", connection, services); // This should unblock the send that is waiting because of backpressure // Testing deadlock regression where pipe backpressure would hold the same lock that poll would use await dispatcher.ExecuteAsync(pollContext, options, app).OrTimeout(); @@ -2253,12 +2255,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var app = builder.Build(); var options = new HttpConnectionDispatcherOptions(); - var context = MakeRequest("/foo", connection); + var context = MakeRequest("/foo", connection, services); // Initial poll will complete immediately await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); - var pollContext = MakeRequest("/foo", connection); + var pollContext = MakeRequest("/foo", connection, services); var pollTask = dispatcher.ExecuteAsync(pollContext, options, app); // fail LongPollingTransport ReadAsync connection.Transport.Output.Complete(new InvalidOperationException()); @@ -2281,10 +2283,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -2332,10 +2334,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = manager.CreateConnection(); connection.TransportType = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); - SetTransport(context, connection.TransportType); var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, connection.TransportType); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -2361,12 +2364,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests connection.TransportType = HttpTransportType.WebSockets; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - - var context = MakeRequest("/foo", connection); - SetTransport(context, HttpTransportType.WebSockets); - var services = new ServiceCollection(); services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); + SetTransport(context, HttpTransportType.WebSockets); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -2411,13 +2413,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = manager.CreateConnection(); connection.TransportType = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); - var context = MakeRequest("/foo", connection); + var services = new ServiceCollection(); + services.AddSingleton(); + var context = MakeRequest("/foo", connection, services); var lifetimeFeature = new CustomHttpRequestLifetimeFeature(); context.Features.Set(lifetimeFeature); SetTransport(context, connection.TransportType); - var services = new ServiceCollection(); - services.AddSingleton(); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); @@ -2436,6 +2438,149 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Fact] + public async Task ServicesAvailableWithLongPolling() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddSingleton(new MessageWrapper() { Buffer = new ReadOnlySequence(new byte[] { 1, 2, 3 }) }); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var context = MakeRequest("/foo", connection, services); + + // Initial poll will complete immediately + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + var pollContext = MakeRequest("/foo", connection, services); + var pollTask = dispatcher.ExecuteAsync(pollContext, options, app); + + await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await pollTask.OrTimeout(); + + var memory = new Memory(new byte[10]); + pollContext.Response.Body.Position = 0; + Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout()); + Assert.Equal(new byte[] { 1, 2, 3 }, memory.Slice(0, 3).ToArray()); + + // Connection will use the original service provider so this will have no effect + services.AddSingleton(new MessageWrapper() { Buffer = new ReadOnlySequence(new byte[] { 4, 5, 6 }) }); + pollContext = MakeRequest("/foo", connection, services); + pollTask = dispatcher.ExecuteAsync(pollContext, options, app); + + await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await pollTask.OrTimeout(); + + pollContext.Response.Body.Position = 0; + Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout()); + Assert.Equal(new byte[] { 1, 2, 3 }, memory.Slice(0, 3).ToArray()); + + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task ServicesPreserveScopeWithLongPolling() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + + var services = new ServiceCollection(); + services.AddSingleton(); + var iteration = 0; + services.AddScoped(typeof(MessageWrapper), _ => + { + iteration++; + return new MessageWrapper() { Buffer = new ReadOnlySequence(new byte[] { (byte)(iteration + 1), (byte)(iteration + 2), (byte)(iteration + 3) }) }; + }); + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var context = MakeRequest("/foo", connection, services); + + // Initial poll will complete immediately + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + var pollContext = MakeRequest("/foo", connection, services); + var pollTask = dispatcher.ExecuteAsync(pollContext, options, app); + + await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await pollTask.OrTimeout(); + + var memory = new Memory(new byte[10]); + pollContext.Response.Body.Position = 0; + Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout()); + Assert.Equal(new byte[] { 2, 3, 4 }, memory.Slice(0, 3).ToArray()); + + pollContext = MakeRequest("/foo", connection, services); + pollTask = dispatcher.ExecuteAsync(pollContext, options, app); + + await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout(); + await pollTask.OrTimeout(); + + pollContext.Response.Body.Position = 0; + Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout()); + Assert.Equal(new byte[] { 2, 3, 4 }, memory.Slice(0, 3).ToArray()); + + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task DisposeLongPollingConnectionDisposesServiceScope() + { + using (StartVerifiableLog()) + { + var manager = CreateConnectionManager(LoggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory); + + var services = new ServiceCollection(); + services.AddSingleton(); + var iteration = 0; + services.AddScoped(typeof(MessageWrapper), _ => + { + iteration++; + return new MessageWrapper() { Buffer = new ReadOnlySequence(new byte[] { (byte)(iteration + 1), (byte)(iteration + 2), (byte)(iteration + 3) }) }; + }); + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + var context = MakeRequest("/foo", connection, services); + + // Initial poll will complete immediately + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + // ServiceScope will be disposed here + await connection.DisposeAsync().OrTimeout(); + + Assert.Throws(() => connection.ServiceScope.ServiceProvider.GetService()); + } + } + private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory) { var manager = CreateConnectionManager(loggerFactory); @@ -2459,6 +2604,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests values["negotiateVersion"] = "1"; var qs = new QueryCollection(values); context.Request.Query = qs; + context.RequestServices = services.BuildServiceProvider(); var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); @@ -2478,7 +2624,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } - private static DefaultHttpContext MakeRequest(string path, HttpConnectionContext connection, string format = null) + private static DefaultHttpContext MakeRequest(string path, HttpConnectionContext connection, IServiceCollection serviceCollection, string format = null) { var context = new DefaultHttpContext(); context.Features.Set(new ResponseFeature()); @@ -2494,6 +2640,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var qs = new QueryCollection(values); context.Request.Query = qs; context.Response.Body = new MemoryStream(); + context.RequestServices = serviceCollection.BuildServiceProvider(); return context; } @@ -2632,6 +2779,35 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + public class ServiceProviderConnectionHandler : ConnectionHandler + { + public override async Task OnConnectedAsync(ConnectionContext connection) + { + while (true) + { + var result = await connection.Transport.Input.ReadAsync(); + + try + { + if (result.IsCompleted) + { + break; + } + + var context = connection.GetHttpContext(); + var message = context.RequestServices.GetService(); + + // Echo the results + await connection.Transport.Output.WriteAsync(message.Buffer.ToArray()); + } + finally + { + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } + } + } + } + public class ResponseFeature : HttpResponseFeature { public override void OnCompleted(Func callback, object state) @@ -2642,4 +2818,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { } } + + public class MessageWrapper + { + public ReadOnlySequence Buffer { get; set; } + } }