Make RequestServices work for cloned longpolling HttpContext (#23014)

This commit is contained in:
Brennan 2020-06-24 12:29:11 -07:00 committed by GitHub
parent 5155e11120
commit dad1ca68d7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 265 additions and 79 deletions

View File

@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Http.Connections.Features;
using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Http.Connections.Internal
@ -99,6 +100,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
// Used for testing only
internal Task DisposeAndRemoveTask { get; set; }
// Used for LongPolling because we need to create a scope that spans the lifetime of multiple requests on the cloned HttpContext
internal IServiceScope ServiceScope { get; set; }
public Task TransportTask { get; set; }
public Task PreviousPollTask { get; set; } = Task.CompletedTask;
@ -251,6 +255,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
(identity as IDisposable)?.Dispose();
}
}
ServiceScope?.Dispose();
}
await disposeTask;

View File

@ -14,6 +14,7 @@ using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Internal;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
@ -537,8 +538,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
var existing = connection.HttpContext;
if (existing == null)
{
var httpContext = CloneHttpContext(context);
connection.HttpContext = httpContext;
CloneHttpContext(context, connection);
}
else
{
@ -606,7 +606,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
}
}
private static HttpContext CloneHttpContext(HttpContext context)
private static void CloneHttpContext(HttpContext context, HttpConnectionContext connection)
{
// The reason we're copying the base features instead of the HttpContext properties is
// so that we can get all of the logic built into DefaultHttpContext to extract higher level
@ -660,14 +660,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
CloneUser(newHttpContext, context);
// Making request services function property could be tricky and expensive as it would require
// DI scope per connection. It would also mean that services resolved in middleware leading up to here
// wouldn't be the same instance (but maybe that's fine). For now, we just return an empty service provider
newHttpContext.RequestServices = EmptyServiceProvider.Instance;
connection.ServiceScope = context.RequestServices.CreateScope();
newHttpContext.RequestServices = connection.ServiceScope.ServiceProvider;
// REVIEW: This extends the lifetime of anything that got put into HttpContext.Items
newHttpContext.Items = new Dictionary<object, object>(context.Items);
return newHttpContext;
connection.HttpContext = newHttpContext;
}
private async Task<HttpConnectionContext> GetConnectionAsync(HttpContext context)

View File

@ -510,6 +510,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
values["negotiateVersion"] = "1";
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.RequestServices = services.BuildServiceProvider();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.Use(next =>
@ -723,6 +724,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
context.Connection.RemotePort = 43456;
context.SetEndpoint(new Endpoint(null, null, "TestName"));
context.RequestServices = services.BuildServiceProvider();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<HttpContextConnectionHandler>();
@ -942,12 +944,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, HttpTransportType.ServerSentEvents);
var services = new ServiceCollection();
services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
var app = builder.Build();
@ -976,11 +978,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
connection.TransportType = HttpTransportType.ServerSentEvents;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, HttpTransportType.ServerSentEvents);
var services = new ServiceCollection();
services.AddSingleton<SynchronusExceptionConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, HttpTransportType.ServerSentEvents);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<SynchronusExceptionConnectionHandler>();
var app = builder.Build();
@ -1004,10 +1006,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
var app = builder.Build();
@ -1035,10 +1037,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1128,9 +1130,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1161,10 +1163,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.ServerSentEvents;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, connection.TransportType);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, connection.TransportType);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1199,10 +1201,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
connection.TransportType = HttpTransportType.WebSockets;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var sync = new SyncPoint();
var context = MakeRequest("/foo", connection);
SetTransport(context, connection.TransportType, sync);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, connection.TransportType, sync);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1232,11 +1234,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, HttpTransportType.WebSockets);
var services = new ServiceCollection();
services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, HttpTransportType.WebSockets);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
var app = builder.Build();
@ -1262,14 +1264,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context1 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context1 = MakeRequest("/foo", connection, services);
var context2 = MakeRequest("/foo", connection, services);
SetTransport(context1, transportType);
SetTransport(context2, transportType);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1305,11 +1307,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context1 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context1 = MakeRequest("/foo", connection, services);
var context2 = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1369,11 +1371,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context1 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context1 = MakeRequest("/foo", connection, services);
var context2 = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1432,11 +1434,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, transportType);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, transportType);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1459,10 +1461,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1494,11 +1496,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, HttpTransportType.ServerSentEvents);
var services = new ServiceCollection();
services.AddSingleton<BlockingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, HttpTransportType.ServerSentEvents);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<BlockingConnectionHandler>();
var app = builder.Build();
@ -1529,10 +1531,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<BlockingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<BlockingConnectionHandler>();
var app = builder.Build();
@ -1577,12 +1579,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var context1 = MakeRequest("/foo", connection);
var context1 = MakeRequest("/foo", connection, services);
// This is the initial poll to make sure things are setup
var task1 = dispatcher.ExecuteAsync(context1, options, app);
Assert.True(task1.IsCompleted);
task1 = dispatcher.ExecuteAsync(context1, options, app);
var context2 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection, services);
var task2 = dispatcher.ExecuteAsync(context2, options, app);
// Task 1 should finish when request 2 arrives
@ -1615,11 +1617,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, transportType);
var services = new ServiceCollection();
services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, transportType);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
var app = builder.Build();
@ -1751,10 +1753,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1783,11 +1785,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, transportType);
var serviceCollection = new ServiceCollection();
serviceCollection.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, serviceCollection);
SetTransport(context, transportType);
var services = serviceCollection.BuildServiceProvider();
var builder = new ConnectionBuilder(services);
builder.UseConnectionHandler<TestConnectionHandler>();
@ -1824,10 +1826,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1877,10 +1879,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
@ -1926,10 +1928,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
@ -1992,7 +1994,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
};
{
var options = new HttpConnectionDispatcherOptions();
var context = MakeRequest("/foo", connection);
var context = MakeRequest("/foo", connection, new ServiceCollection());
await dispatcher.ExecuteAsync(context, options, connectionDelegate).OrTimeout();
// second poll should have data
@ -2008,7 +2010,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
waitForMessageTcs2.SetResult();
await messageTcs2.Task.OrTimeout();
context = MakeRequest("/foo", connection);
context = MakeRequest("/foo", connection, new ServiceCollection());
ms.Seek(0, SeekOrigin.Begin);
context.Response.Body = ms;
// This is the third poll which gets the final message after the app is complete
@ -2217,7 +2219,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var sendTask = dispatcher.ExecuteAsync(context, options, app).OrTimeout();
Assert.False(sendTask.IsCompleted);
var pollContext = MakeRequest("/foo", connection);
var pollContext = MakeRequest("/foo", connection, services);
// This should unblock the send that is waiting because of backpressure
// Testing deadlock regression where pipe backpressure would hold the same lock that poll would use
await dispatcher.ExecuteAsync(pollContext, options, app).OrTimeout();
@ -2253,12 +2255,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var context = MakeRequest("/foo", connection);
var context = MakeRequest("/foo", connection, services);
// Initial poll will complete immediately
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
var pollContext = MakeRequest("/foo", connection);
var pollContext = MakeRequest("/foo", connection, services);
var pollTask = dispatcher.ExecuteAsync(pollContext, options, app);
// fail LongPollingTransport ReadAsync
connection.Transport.Output.Complete(new InvalidOperationException());
@ -2281,10 +2283,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
@ -2332,10 +2334,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.ServerSentEvents;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, connection.TransportType);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, connection.TransportType);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
@ -2361,12 +2364,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
connection.TransportType = HttpTransportType.WebSockets;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, HttpTransportType.WebSockets);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
SetTransport(context, HttpTransportType.WebSockets);
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
@ -2411,13 +2413,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.ServerSentEvents;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var context = MakeRequest("/foo", connection, services);
var lifetimeFeature = new CustomHttpRequestLifetimeFeature();
context.Features.Set<IHttpRequestLifetimeFeature>(lifetimeFeature);
SetTransport(context, connection.TransportType);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
@ -2436,6 +2438,149 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
[Fact]
public async Task ServicesAvailableWithLongPolling()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var services = new ServiceCollection();
services.AddSingleton<ServiceProviderConnectionHandler>();
services.AddSingleton(new MessageWrapper() { Buffer = new ReadOnlySequence<byte>(new byte[] { 1, 2, 3 }) });
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ServiceProviderConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var context = MakeRequest("/foo", connection, services);
// Initial poll will complete immediately
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
var pollContext = MakeRequest("/foo", connection, services);
var pollTask = dispatcher.ExecuteAsync(pollContext, options, app);
await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout();
await pollTask.OrTimeout();
var memory = new Memory<byte>(new byte[10]);
pollContext.Response.Body.Position = 0;
Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout());
Assert.Equal(new byte[] { 1, 2, 3 }, memory.Slice(0, 3).ToArray());
// Connection will use the original service provider so this will have no effect
services.AddSingleton(new MessageWrapper() { Buffer = new ReadOnlySequence<byte>(new byte[] { 4, 5, 6 }) });
pollContext = MakeRequest("/foo", connection, services);
pollTask = dispatcher.ExecuteAsync(pollContext, options, app);
await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout();
await pollTask.OrTimeout();
pollContext.Response.Body.Position = 0;
Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout());
Assert.Equal(new byte[] { 1, 2, 3 }, memory.Slice(0, 3).ToArray());
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task ServicesPreserveScopeWithLongPolling()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var services = new ServiceCollection();
services.AddSingleton<ServiceProviderConnectionHandler>();
var iteration = 0;
services.AddScoped(typeof(MessageWrapper), _ =>
{
iteration++;
return new MessageWrapper() { Buffer = new ReadOnlySequence<byte>(new byte[] { (byte)(iteration + 1), (byte)(iteration + 2), (byte)(iteration + 3) }) };
});
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ServiceProviderConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var context = MakeRequest("/foo", connection, services);
// Initial poll will complete immediately
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
var pollContext = MakeRequest("/foo", connection, services);
var pollTask = dispatcher.ExecuteAsync(pollContext, options, app);
await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout();
await pollTask.OrTimeout();
var memory = new Memory<byte>(new byte[10]);
pollContext.Response.Body.Position = 0;
Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout());
Assert.Equal(new byte[] { 2, 3, 4 }, memory.Slice(0, 3).ToArray());
pollContext = MakeRequest("/foo", connection, services);
pollTask = dispatcher.ExecuteAsync(pollContext, options, app);
await connection.Application.Output.WriteAsync(new byte[] { 1 }).OrTimeout();
await pollTask.OrTimeout();
pollContext.Response.Body.Position = 0;
Assert.Equal(3, await pollContext.Response.Body.ReadAsync(memory).OrTimeout());
Assert.Equal(new byte[] { 2, 3, 4 }, memory.Slice(0, 3).ToArray());
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task DisposeLongPollingConnectionDisposesServiceScope()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var services = new ServiceCollection();
services.AddSingleton<ServiceProviderConnectionHandler>();
var iteration = 0;
services.AddScoped(typeof(MessageWrapper), _ =>
{
iteration++;
return new MessageWrapper() { Buffer = new ReadOnlySequence<byte>(new byte[] { (byte)(iteration + 1), (byte)(iteration + 2), (byte)(iteration + 3) }) };
});
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ServiceProviderConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var context = MakeRequest("/foo", connection, services);
// Initial poll will complete immediately
await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
// ServiceScope will be disposed here
await connection.DisposeAsync().OrTimeout();
Assert.Throws<ObjectDisposedException>(() => connection.ServiceScope.ServiceProvider.GetService<MessageWrapper>());
}
}
private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory)
{
var manager = CreateConnectionManager(loggerFactory);
@ -2459,6 +2604,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
values["negotiateVersion"] = "1";
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.RequestServices = services.BuildServiceProvider();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
@ -2478,7 +2624,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
private static DefaultHttpContext MakeRequest(string path, HttpConnectionContext connection, string format = null)
private static DefaultHttpContext MakeRequest(string path, HttpConnectionContext connection, IServiceCollection serviceCollection, string format = null)
{
var context = new DefaultHttpContext();
context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
@ -2494,6 +2640,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Response.Body = new MemoryStream();
context.RequestServices = serviceCollection.BuildServiceProvider();
return context;
}
@ -2632,6 +2779,35 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
public class ServiceProviderConnectionHandler : ConnectionHandler
{
public override async Task OnConnectedAsync(ConnectionContext connection)
{
while (true)
{
var result = await connection.Transport.Input.ReadAsync();
try
{
if (result.IsCompleted)
{
break;
}
var context = connection.GetHttpContext();
var message = context.RequestServices.GetService<MessageWrapper>();
// Echo the results
await connection.Transport.Output.WriteAsync(message.Buffer.ToArray());
}
finally
{
connection.Transport.Input.AdvanceTo(result.Buffer.End);
}
}
}
}
public class ResponseFeature : HttpResponseFeature
{
public override void OnCompleted(Func<object, Task> callback, object state)
@ -2642,4 +2818,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
{
}
}
public class MessageWrapper
{
public ReadOnlySequence<byte> Buffer { get; set; }
}
}