[SignalR] Implement IConnectionLifetimeFeature (#20604)

This commit is contained in:
Brennan 2020-06-23 22:14:12 -07:00 committed by GitHub
parent 0d8d4e709c
commit cc15b1bb43
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 239 additions and 10 deletions

View File

@ -26,7 +26,6 @@ namespace Microsoft.AspNetCore.Connections
public DefaultConnectionContext() :
this(Guid.NewGuid().ToString())
{
ConnectionClosed = _connectionClosedTokenSource.Token;
}
/// <summary>
@ -45,6 +44,8 @@ namespace Microsoft.AspNetCore.Connections
Features.Set<IConnectionTransportFeature>(this);
Features.Set<IConnectionLifetimeFeature>(this);
Features.Set<IConnectionEndPointFeature>(this);
ConnectionClosed = _connectionClosedTokenSource.Token;
}
public DefaultConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application)

View File

@ -29,7 +29,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
ITransferFormatFeature,
IHttpContextFeature,
IHttpTransportFeature,
IConnectionInherentKeepAliveFeature
IConnectionInherentKeepAliveFeature,
IConnectionLifetimeFeature
{
private static long _tenSeconds = TimeSpan.FromSeconds(10).Ticks;
@ -41,6 +42,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
private PipeWriterStream _applicationStream;
private IDuplexPipe _application;
private IDictionary<object, object> _items;
private CancellationTokenSource _connectionClosedTokenSource;
private CancellationTokenSource _sendCts;
private bool _activeSend;
@ -82,6 +84,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
Features.Set<IHttpContextFeature>(this);
Features.Set<IHttpTransportFeature>(this);
Features.Set<IConnectionInherentKeepAliveFeature>(this);
Features.Set<IConnectionLifetimeFeature>(this);
_connectionClosedTokenSource = new CancellationTokenSource();
ConnectionClosed = _connectionClosedTokenSource.Token;
}
public CancellationTokenSource Cancellation { get; set; }
@ -170,6 +176,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
public HttpContext HttpContext { get; set; }
public override CancellationToken ConnectionClosed { get; set; }
public override void Abort()
{
ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource);
HttpContext?.Abort();
}
public void OnHeartbeat(Action<object> action, object state)
{
lock (_heartbeatLock)
@ -305,6 +320,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
// Now complete the application
Application?.Output.Complete();
Application?.Input.Complete();
// Trigger ConnectionClosed
ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource);
}
}
else
@ -313,6 +331,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
Application?.Output.Complete(transportTask.Exception?.InnerException);
Application?.Input.Complete();
// Trigger ConnectionClosed
ThreadPool.UnsafeQueueUserWorkItem(cts => ((CancellationTokenSource)cts).Cancel(), _connectionClosedTokenSource);
try
{
// A poorly written application *could* in theory get stuck forever and it'll show up as a memory leak

View File

@ -961,7 +961,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
[Fact]
public async Task SynchronusExceptionEndsConnection()
public async Task SynchronousExceptionEndsConnection()
{
bool ExpectedErrors(WriteContext writeContext)
{
@ -2269,6 +2269,173 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
}
[Fact]
public async Task LongPollingConnectionClosingTriggersConnectionClosedToken()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var pipeOptions = new PipeOptions(pauseWriterThreshold: 2, resumeWriterThreshold: 1);
var connection = manager.CreateConnection(pipeOptions, pipeOptions);
connection.TransportType = HttpTransportType.LongPolling;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var pollTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(pollTask.IsCompleted);
// Now send the second poll
pollTask = dispatcher.ExecuteAsync(context, options, app);
// Issue the delete request and make sure the poll completes
var deleteContext = new DefaultHttpContext();
deleteContext.Request.Path = "/foo";
deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}");
deleteContext.Request.Method = "DELETE";
Assert.False(pollTask.IsCompleted);
await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout();
await pollTask.OrTimeout();
// Verify that transport shuts down
await connection.TransportTask.OrTimeout();
// Verify the response from the DELETE request
Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode);
Assert.Equal("text/plain", deleteContext.Response.ContentType);
Assert.Equal(HttpConnectionStatus.Disposed, connection.Status);
// Verify the connection not removed because application is hanging
Assert.True(manager.TryGetConnection(connection.ConnectionId, out _));
Assert.True(connection.ConnectionClosed.IsCancellationRequested);
}
}
[Fact]
public async Task SSEConnectionClosingTriggersConnectionClosedToken()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.ServerSentEvents;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, connection.TransportType);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
_ = dispatcher.ExecuteAsync(context, options, app);
// Close the SSE connection
connection.Transport.Output.Complete();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
connection.ConnectionClosed.Register(() => tcs.SetResult(null));
await tcs.Task.OrTimeout();
}
}
[Fact]
public async Task WebSocketConnectionClosingTriggersConnectionClosedToken()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.WebSockets;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
SetTransport(context, HttpTransportType.WebSockets);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
_ = dispatcher.ExecuteAsync(context, options, app);
var websocket = (TestWebSocketConnectionFeature)context.Features.Get<IHttpWebSocketFeature>();
await websocket.Accepted.OrTimeout();
await websocket.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", cancellationToken: default).OrTimeout();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
connection.ConnectionClosed.Register(() => tcs.SetResult(null));
await tcs.Task.OrTimeout();
}
}
public class CustomHttpRequestLifetimeFeature : IHttpRequestLifetimeFeature
{
public CancellationToken RequestAborted { get; set; }
private CancellationTokenSource _cts;
public CustomHttpRequestLifetimeFeature()
{
_cts = new CancellationTokenSource();
RequestAborted = _cts.Token;
}
public void Abort()
{
_cts.Cancel();
}
}
[Fact]
public async Task AbortingConnectionAbortsHttpContextAndTriggersConnectionClosedToken()
{
using (StartVerifiableLog())
{
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.ServerSentEvents;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context = MakeRequest("/foo", connection);
var lifetimeFeature = new CustomHttpRequestLifetimeFeature();
context.Features.Set<IHttpRequestLifetimeFeature>(lifetimeFeature);
SetTransport(context, connection.TransportType);
var services = new ServiceCollection();
services.AddSingleton<NeverEndingConnectionHandler>();
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<NeverEndingConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
_ = dispatcher.ExecuteAsync(context, options, app);
connection.Abort();
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
connection.ConnectionClosed.Register(() => tcs.SetResult(null));
await tcs.Task.OrTimeout();
tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
lifetimeFeature.RequestAborted.Register(() => tcs.SetResult(null));
await tcs.Task.OrTimeout();
}
}
private static async Task CheckTransportSupported(HttpTransportType supportedTransports, HttpTransportType transportType, int status, ILoggerFactory loggerFactory)
{
var manager = CreateConnectionManager(loggerFactory);

View File

@ -37,6 +37,7 @@ namespace Microsoft.AspNetCore.SignalR
private readonly SemaphoreSlim _writeLock = new SemaphoreSlim(1);
private readonly object _receiveMessageTimeoutLock = new object();
private readonly ISystemClock _systemClock;
private readonly CancellationTokenRegistration _closedRegistration;
private StreamTracker _streamTracker;
private long _lastSendTimeStamp;
@ -66,6 +67,7 @@ namespace Microsoft.AspNetCore.SignalR
_connectionContext = connectionContext;
_logger = loggerFactory.CreateLogger<HubConnectionContext>();
ConnectionAborted = _connectionAbortedTokenSource.Token;
_closedRegistration = connectionContext.ConnectionClosed.Register((state) => ((HubConnectionContext)state).Abort(), this);
HubCallerContext = new DefaultHubCallerContext(this);
@ -624,12 +626,6 @@ namespace Microsoft.AspNetCore.SignalR
finally
{
_ = InnerAbortConnection(connection);
// Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist
if (connection._streamTracker != null)
{
connection._streamTracker.CompleteAll(new OperationCanceledException("The underlying connection was closed."));
}
}
static async Task InnerAbortConnection(HubConnectionContext connection)
@ -670,6 +666,17 @@ namespace Microsoft.AspNetCore.SignalR
}
}
internal void Cleanup()
{
_closedRegistration.Dispose();
// Use _streamTracker to avoid lazy init from StreamTracker getter if it doesn't exist
if (_streamTracker != null)
{
_streamTracker.CompleteAll(new OperationCanceledException("The underlying connection was closed."));
}
}
private static class Log
{
// Category: HubConnectionContext

View File

@ -139,6 +139,8 @@ namespace Microsoft.AspNetCore.SignalR
}
finally
{
connectionContext.Cleanup();
Log.ConnectedEnding(_logger);
await _lifetimeManager.OnDisconnectedAsync(connectionContext);
}

View File

@ -221,7 +221,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
public async Task<int> StreamingSum(ChannelReader<int> source)
{
var total = 0;
@ -322,6 +321,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
tcs.TrySetResult(42);
}
}
public async Task BlockingMethod()
{
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
Context.ConnectionAborted.Register(state => ((TaskCompletionSource<object>)state).SetResult(null), tcs);
await tcs.Task;
}
}
public abstract class TestHub : Hub

View File

@ -948,6 +948,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.True(hasErrorLog);
}
[Fact]
public async Task HubMethodListeningToConnectionAbortedClosesOnConnectionContextAbort()
{
using (StartVerifiableLog())
{
var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(MethodHub), loggerFactory: LoggerFactory);
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
var invokeTask = client.InvokeAsync(nameof(MethodHub.BlockingMethod));
client.Connection.Abort();
// If this completes then the server has completed the connection
await connectionHandlerTask.OrTimeout();
// Nothing written to connection because it was closed
Assert.False(invokeTask.IsCompleted);
}
}
}
[Fact]
public async Task DetailedExceptionEvenWhenNotExplicitlySet()
{