diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs index 8b6109e6c2..12aff921a2 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionContext.cs @@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http; +using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Http.Connections { @@ -28,6 +29,7 @@ namespace Microsoft.AspNetCore.Http.Connections { private readonly object _heartbeatLock = new object(); private List<(Action handler, object state)> _heartbeatHandlers; + private readonly ILogger _logger; // This tcs exists so that multiple calls to DisposeAsync all wait asynchronously // on the same task @@ -38,7 +40,8 @@ namespace Microsoft.AspNetCore.Http.Connections /// The caller is expected to set the and pipes manually. /// /// - public HttpConnectionContext(string id) + /// + public HttpConnectionContext(string id, ILogger logger) { ConnectionId = id; LastSeenUtc = DateTime.UtcNow; @@ -47,6 +50,8 @@ namespace Microsoft.AspNetCore.Http.Connections SupportedFormats = TransferFormat.Binary | TransferFormat.Text; ActiveFormat = TransferFormat.Text; + _logger = logger; + // PERF: This type could just implement IFeatureCollection Features = new FeatureCollection(); Features.Set(this); @@ -59,8 +64,8 @@ namespace Microsoft.AspNetCore.Http.Connections Features.Set(this); } - public HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application) - : this(id) + public HttpConnectionContext(string id, IDuplexPipe transport, IDuplexPipe application, ILogger logger = null) + : this(id, logger) { Transport = transport; Application = application; @@ -68,6 +73,8 @@ namespace Microsoft.AspNetCore.Http.Connections public CancellationTokenSource Cancellation { get; set; } + public HttpTransportType TransportType { get; set; } + public SemaphoreSlim Lock { get; } = new SemaphoreSlim(1, 1); public Task TransportTask { get; set; } @@ -124,7 +131,7 @@ namespace Microsoft.AspNetCore.Http.Connections } } - public async Task DisposeAsync() + public async Task DisposeAsync(bool closeGracefully = false) { var disposeTask = Task.CompletedTask; @@ -140,30 +147,12 @@ namespace Microsoft.AspNetCore.Http.Connections { Status = ConnectionStatus.Disposed; - // If the application task is faulted, propagate the error to the transport - if (ApplicationTask?.IsFaulted == true) - { - Transport?.Output.Complete(ApplicationTask.Exception.InnerException); - } - else - { - Transport?.Output.Complete(); - } - - // If the transport task is faulted, propagate the error to the application - if (TransportTask?.IsFaulted == true) - { - Application?.Output.Complete(TransportTask.Exception.InnerException); - } - else - { - Application?.Output.Complete(); - } + Log.DisposingConnection(_logger, ConnectionId); var applicationTask = ApplicationTask ?? Task.CompletedTask; var transportTask = TransportTask ?? Task.CompletedTask; - disposeTask = WaitOnTasks(applicationTask, transportTask); + disposeTask = WaitOnTasks(applicationTask, transportTask, closeGracefully); } } finally @@ -171,25 +160,88 @@ namespace Microsoft.AspNetCore.Http.Connections Lock.Release(); } - try - { - await disposeTask; - } - finally - { - // REVIEW: Should we move this to the read loops? - - // Complete the reading side of the pipes - Application?.Input.Complete(); - Transport?.Input.Complete(); - } + await disposeTask; } - private async Task WaitOnTasks(Task applicationTask, Task transportTask) + private async Task WaitOnTasks(Task applicationTask, Task transportTask, bool closeGracefully) { try { - await Task.WhenAll(applicationTask, transportTask); + // Closing gracefully means we're only going to close the finished sides of the pipe + // If the application finishes, that means it's done with the transport pipe + // If the transport finishes, that means it's done with the application pipe + if (closeGracefully) + { + // Wait for either to finish + var result = await Task.WhenAny(applicationTask, transportTask); + + // If the application is complete, complete the transport pipe (it's the pipe to the transport) + if (result == applicationTask) + { + Transport?.Output.Complete(applicationTask.Exception?.InnerException); + Transport?.Input.Complete(); + + try + { + Log.WaitingForTransport(_logger, TransportType); + + // Transports are written by us and are well behaved, wait for them to drain + await transportTask; + } + finally + { + Log.TransportComplete(_logger, TransportType); + + // Now complete the application + Application?.Output.Complete(); + Application?.Input.Complete(); + } + } + else + { + // If the transport is complete, complete the application pipes + Application?.Output.Complete(transportTask.Exception?.InnerException); + Application?.Input.Complete(); + + try + { + // A poorly written application *could* in theory hang forever and it'll show up as a memory leak + Log.WaitingForApplication(_logger); + + await applicationTask; + } + finally + { + Log.ApplicationComplete(_logger); + + Transport?.Output.Complete(); + Transport?.Input.Complete(); + } + } + } + else + { + Log.ShuttingDownTransportAndApplication(_logger, TransportType); + + // Shutdown both sides and wait for nothing + Transport?.Output.Complete(applicationTask.Exception?.InnerException); + Application?.Output.Complete(transportTask.Exception?.InnerException); + + try + { + Log.WaitingForTransportAndApplication(_logger, TransportType); + // A poorly written application *could* in theory hang forever and it'll show up as a memory leak + await Task.WhenAll(applicationTask, transportTask); + } + finally + { + Log.TransportAndApplicationComplete(_logger, TransportType); + + // Close the reading side after both sides run + Application?.Input.Complete(); + Transport?.Input.Complete(); + } + } // Notify all waiters that we're done disposing _disposeTcs.TrySetResult(null); @@ -214,5 +266,111 @@ namespace Microsoft.AspNetCore.Http.Connections Active, Disposed } + + private static class Log + { + private static readonly Action _disposingConnection = + LoggerMessage.Define(LogLevel.Trace, new EventId(1, "DisposingConnection"), "Disposing connection {TransportConnectionId}."); + + private static readonly Action _waitingForApplication = + LoggerMessage.Define(LogLevel.Trace, new EventId(2, "WaitingForApplication"), "Waiting for application to complete."); + + private static readonly Action _applicationComplete = + LoggerMessage.Define(LogLevel.Trace, new EventId(3, "ApplicationComplete"), "Application complete."); + + private static readonly Action _waitingForTransport = + LoggerMessage.Define(LogLevel.Trace, new EventId(4, "WaitingForTransport"), "Waiting for {TransportType} transport to complete."); + + private static readonly Action _transportComplete = + LoggerMessage.Define(LogLevel.Trace, new EventId(5, "TransportComplete"), "{TransportType} transport complete."); + + private static readonly Action _shuttingDownTransportAndApplication = + LoggerMessage.Define(LogLevel.Trace, new EventId(6, "ShuttingDownTransportAndApplication"), "Shutting down both the application and the {TransportType} transport."); + + private static readonly Action _waitingForTransportAndApplication = + LoggerMessage.Define(LogLevel.Trace, new EventId(7, "WaitingForTransportAndApplication"), "Waiting for both the application and {TransportType} transport to complete."); + + private static readonly Action _transportAndApplicationComplete = + LoggerMessage.Define(LogLevel.Trace, new EventId(8, "TransportAndApplicationComplete"), "The application and {TransportType} transport are both complete."); + + public static void DisposingConnection(ILogger logger, string connectionId) + { + if (logger == null) + { + return; + } + + _disposingConnection(logger, connectionId, null); + } + + public static void WaitingForApplication(ILogger logger) + { + if (logger == null) + { + return; + } + + _waitingForApplication(logger, null); + } + + public static void ApplicationComplete(ILogger logger) + { + if (logger == null) + { + return; + } + + _applicationComplete(logger, null); + } + + public static void WaitingForTransport(ILogger logger, HttpTransportType transportType) + { + if (logger == null) + { + return; + } + + _waitingForTransport(logger, transportType, null); + } + + public static void TransportComplete(ILogger logger, HttpTransportType transportType) + { + if (logger == null) + { + return; + } + + _transportComplete(logger, transportType, null); + } + public static void ShuttingDownTransportAndApplication(ILogger logger, HttpTransportType transportType) + { + if (logger == null) + { + return; + } + + _shuttingDownTransportAndApplication(logger, transportType, null); + } + + public static void WaitingForTransportAndApplication(ILogger logger, HttpTransportType transportType) + { + if (logger == null) + { + return; + } + + _waitingForTransportAndApplication(logger, transportType, null); + } + + public static void TransportAndApplicationComplete(ILogger logger, HttpTransportType transportType) + { + if (logger == null) + { + return; + } + + _transportAndApplicationComplete(logger, transportType, null); + } + } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs index 239b793db2..fc8072be4c 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs @@ -221,8 +221,6 @@ namespace Microsoft.AspNetCore.Http.Connections { Log.EstablishedConnection(_logger); - connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; - connection.ApplicationTask = ExecuteApplication(connectionDelegate, connection); } else @@ -270,7 +268,8 @@ namespace Microsoft.AspNetCore.Http.Connections if (context.Response.StatusCode == StatusCodes.Status204NoContent) { // We should be able to safely dispose because there's no more data being written - await _manager.DisposeAndRemoveAsync(connection); + // We don't need to wait for close here since we've already waited for both sides + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: false); // Don't poll again if we've removed the connection completely pollAgain = false; @@ -355,15 +354,15 @@ namespace Microsoft.AspNetCore.Http.Connections // Wait for any of them to end await Task.WhenAny(connection.ApplicationTask, connection.TransportTask); - await _manager.DisposeAndRemoveAsync(connection); + await _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); } - private async Task ExecuteApplication(ConnectionDelegate connectionDelegate, ConnectionContext connection) + private async Task ExecuteApplication(ConnectionDelegate connectionDelegate, HttpConnectionContext connection) { // Verify some initialization invariants // We want to be positive that the IConnectionInherentKeepAliveFeature is initialized before invoking the application, if the long polling transport is in use. - Debug.Assert(connection.Items[ConnectionMetadataNames.Transport] != null, "Transport has not been initialized yet"); - Debug.Assert((HttpTransportType?)connection.Items[ConnectionMetadataNames.Transport] != HttpTransportType.LongPolling || + Debug.Assert(connection.TransportType != HttpTransportType.None, "Transport has not been initialized yet"); + Debug.Assert(connection.TransportType != HttpTransportType.LongPolling || connection.Features.Get() != null, "Long-polling transport is in use but IConnectionInherentKeepAliveFeature as not configured"); // Jump onto the thread pool thread so blocking user code doesn't block the setup of the @@ -440,8 +439,7 @@ namespace Microsoft.AspNetCore.Http.Connections context.Response.ContentType = "text/plain"; - var transport = (HttpTransportType?)connection.Items[ConnectionMetadataNames.Transport]; - if (transport == HttpTransportType.WebSockets) + if (connection.TransportType == HttpTransportType.WebSockets) { Log.PostNotAllowedForWebSockets(_logger); context.Response.StatusCode = StatusCodes.Status405MethodNotAllowed; @@ -457,6 +455,16 @@ namespace Microsoft.AspNetCore.Http.Connections try { + if (connection.Status == HttpConnectionContext.ConnectionStatus.Disposed) + { + Log.ConnectionDisposed(_logger, connection.ConnectionId); + + // The connection was disposed + context.Response.StatusCode = StatusCodes.Status404NotFound; + context.Response.ContentType = "text/plain"; + return; + } + await context.Request.Body.CopyToAsync(pipeWriterStream); } finally @@ -481,17 +489,16 @@ namespace Microsoft.AspNetCore.Http.Connections // Set the IHttpConnectionFeature now that we can access it. connection.Features.Set(context.Features.Get()); - var transport = (HttpTransportType?)connection.Items[ConnectionMetadataNames.Transport]; - - if (transport == null) + if (connection.TransportType == HttpTransportType.None) { + connection.TransportType = transportType; connection.Items[ConnectionMetadataNames.Transport] = transportType; } - else if (transport != transportType) + else if (connection.TransportType != transportType) { context.Response.ContentType = "text/plain"; context.Response.StatusCode = StatusCodes.Status400BadRequest; - Log.CannotChangeTransport(_logger, transport.Value, transportType); + Log.CannotChangeTransport(_logger, connection.TransportType, transportType); await context.Response.WriteAsync("Cannot change transports mid-connection"); return false; } diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs index bba5e582c8..2b42146ac4 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionManager.cs @@ -29,12 +29,14 @@ namespace Microsoft.AspNetCore.Http.Connections private readonly ConcurrentDictionary _connections = new ConcurrentDictionary(); private Timer _timer; private readonly ILogger _logger; + private readonly ILogger _connectionLogger; private object _executionLock = new object(); private bool _disposed; - public HttpConnectionManager(ILogger logger, IApplicationLifetime appLifetime) + public HttpConnectionManager(ILoggerFactory loggerFactory, IApplicationLifetime appLifetime) { - _logger = logger; + _logger = loggerFactory.CreateLogger(); + _connectionLogger = loggerFactory.CreateLogger(); appLifetime.ApplicationStarted.Register(() => Start()); appLifetime.ApplicationStopping.Register(() => CloseConnections()); } @@ -82,7 +84,7 @@ namespace Microsoft.AspNetCore.Http.Connections Log.CreatedNewConnection(_logger, id); var connectionTimer = HttpConnectionsEventSource.Log.ConnectionStart(id); - var connection = new HttpConnectionContext(id); + var connection = new HttpConnectionContext(id, _connectionLogger); var pair = DuplexPipe.CreateConnectionPair(transportPipeOptions, appPipeOptions); connection.Transport = pair.Application; connection.Application = pair.Transport; @@ -135,7 +137,7 @@ namespace Microsoft.AspNetCore.Http.Connections } // Pause the timer while we're running - _timer.Change(Timeout.Infinite, Timeout.Infinite); + _timer?.Change(Timeout.Infinite, Timeout.Infinite); // Time the scan so we know if it gets slower than 1sec var timer = ValueStopwatch.StartNew(); @@ -169,7 +171,11 @@ namespace Microsoft.AspNetCore.Http.Connections { Log.ConnectionTimedOut(_logger, connection.ConnectionId); HttpConnectionsEventSource.Log.ConnectionTimedOut(connection.ConnectionId); - var ignore = DisposeAndRemoveAsync(connection); + + // This is most likely a long polling connection. The transport here ends because + // a poll completed and has been inactive for > 5 seconds so we wait for the + // application to finish gracefully + _ = DisposeAndRemoveAsync(connection, closeGracefully: true); } else { @@ -184,7 +190,7 @@ namespace Microsoft.AspNetCore.Http.Connections Log.ScannedConnections(_logger, elapsed); // Resume once we finished processing all connections - _timer.Change(_heartbeatTickRate, _heartbeatTickRate); + _timer?.Change(_heartbeatTickRate, _heartbeatTickRate); } finally { @@ -209,20 +215,23 @@ namespace Microsoft.AspNetCore.Http.Connections var tasks = new List(); + // REVIEW: In the future we can consider a hybrid where we first try to wait for shutdown + // for a certain time frame then after some grace period we shutdown more aggressively foreach (var c in _connections) { - tasks.Add(DisposeAndRemoveAsync(c.Value.Connection)); + // We're shutting down so don't wait for closing the application + tasks.Add(DisposeAndRemoveAsync(c.Value.Connection, closeGracefully: false)); } Task.WaitAll(tasks.ToArray(), TimeSpan.FromSeconds(5)); } } - public async Task DisposeAndRemoveAsync(HttpConnectionContext connection) + public async Task DisposeAndRemoveAsync(HttpConnectionContext connection, bool closeGracefully) { try { - await connection.DisposeAsync(); + await connection.DisposeAsync(closeGracefully); } catch (IOException ex) { diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index f7ab74e33d..4dce9362a8 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -106,6 +106,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var pipeOptions = new PipeOptions(pauseWriterThreshold: 8, resumeWriterThreshold: 4); var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.TransportType = transportType; connection.Items[ConnectionMetadataNames.Transport] = transportType; using (var requestBody = new MemoryStream()) @@ -263,6 +264,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var manager = CreateConnectionManager(loggerFactory); var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.WebSockets; connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.WebSockets; using (var strm = new MemoryStream()) @@ -292,6 +294,169 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Fact] + public async Task PostReturns404IfConnectionDisposed() + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; + await connection.DisposeAsync(closeGracefully: false); + + using (var strm = new MemoryStream()) + { + var context = new DefaultHttpContext(); + context.Response.Body = strm; + + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + + Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode); + } + } + } + + [Theory] + [InlineData(HttpTransportType.ServerSentEvents)] + [InlineData(HttpTransportType.WebSockets)] + public async Task TransportEndingGracefullyWaitsOnApplication(HttpTransportType transportType) + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = transportType; + connection.Items[ConnectionMetadataNames.Transport] = transportType; + + using (var strm = new MemoryStream()) + { + var context = new DefaultHttpContext(); + SetTransport(context, transportType); + var cts = new CancellationTokenSource(); + context.Response.Body = strm; + context.RequestAborted = cts.Token; + + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + context.Request.Path = "/foo"; + context.Request.Method = "GET"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.Use(next => + { + return async connectionContext => + { + // Ensure both sides of the pipe are ok + var result = await connectionContext.Transport.Input.ReadAsync(); + Assert.True(result.IsCompleted); + await connectionContext.Transport.Output.WriteAsync(result.Buffer.First); + }; + }); + + var app = builder.Build(); + var task = dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + + // Pretend the transport closed because the client disconnected + if (context.WebSockets.IsWebSocketRequest) + { + var ws = (TestWebSocketConnectionFeature)context.Features.Get(); + await ws.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", default); + } + else + { + cts.Cancel(); + } + + await task.OrTimeout(); + + await connection.ApplicationTask.OrTimeout(); + } + } + } + + [Fact] + public async Task TransportEndingGracefullyWaitsOnApplicationLongPolling() + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; + + using (var strm = new MemoryStream()) + { + var context = new DefaultHttpContext(); + SetTransport(context, HttpTransportType.LongPolling); + var cts = new CancellationTokenSource(); + context.Response.Body = strm; + context.RequestAborted = cts.Token; + + var services = new ServiceCollection(); + services.AddSingleton(); + services.AddOptions(); + context.Request.Path = "/foo"; + context.Request.Method = "GET"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.Use(next => + { + return async connectionContext => + { + // Ensure both sides of the pipe are ok + var result = await connectionContext.Transport.Input.ReadAsync(); + Assert.True(result.IsCompleted); + await connectionContext.Transport.Output.WriteAsync(result.Buffer.First); + }; + }); + + var app = builder.Build(); + var task = dispatcher.ExecuteAsync(context, new HttpConnectionOptions(), app); + + // Pretend the transport closed because the client disconnected + cts.Cancel(); + + await task.OrTimeout(); + + // We've been gone longer than the expiration time + connection.LastSeenUtc = DateTime.UtcNow.Subtract(TimeSpan.FromSeconds(10)); + + // The application is still running here because the poll is only killed + // by the heartbeat so we pretend to do a scan and this should force the application task to complete + manager.Scan(); + + // The application task should complete gracefully + await connection.ApplicationTask.OrTimeout(); + } + } + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)] @@ -302,6 +467,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var manager = CreateConnectionManager(loggerFactory); var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = transportType; connection.Items[ConnectionMetadataNames.Transport] = transportType; using (var requestBody = new MemoryStream()) @@ -348,6 +514,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var manager = CreateConnectionManager(loggerFactory); var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = transportType; connection.Items[ConnectionMetadataNames.Transport] = transportType; // Allow a maximum of one caller to use code at one time @@ -433,6 +600,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var manager = CreateConnectionManager(loggerFactory); var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; using (var requestBody = new MemoryStream()) using (var responseBody = new MemoryStream()) @@ -631,6 +800,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -658,6 +829,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = MakeRequest("/foo", connection); @@ -684,6 +857,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -710,6 +885,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -735,6 +912,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.WebSockets; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.WebSockets; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -764,6 +943,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = transportType; + connection.Items[ConnectionMetadataNames.Transport] = transportType; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -806,6 +987,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -843,6 +1026,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = transportType; + connection.Items[ConnectionMetadataNames.Transport] = transportType; connection.Status = HttpConnectionContext.ConnectionStatus.Disposed; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -870,6 +1055,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -904,6 +1091,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.ServerSentEvents; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.ServerSentEvents; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -938,6 +1127,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -971,6 +1162,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -1012,6 +1205,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = transportType; + connection.Items[ConnectionMetadataNames.Transport] = transportType; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -1043,6 +1238,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -1088,6 +1285,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -1135,6 +1334,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = new DefaultHttpContext(); context.Features.Set(new ResponseFeature()); @@ -1191,6 +1392,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = new DefaultHttpContext(); context.Features.Set(new ResponseFeature()); @@ -1272,6 +1475,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = new DefaultHttpContext(); context.Features.Set(new ResponseFeature()); @@ -1329,6 +1534,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); var context = new DefaultHttpContext(); var services = new ServiceCollection(); @@ -1382,6 +1589,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + connection.Items[ConnectionMetadataNames.Transport] = HttpTransportType.LongPolling; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); @@ -1477,6 +1686,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { var manager = CreateConnectionManager(loggerFactory); var connection = manager.CreateConnection(); + connection.TransportType = transportType; + connection.Items[ConnectionMetadataNames.Transport] = transportType; var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); using (var strm = new MemoryStream()) { @@ -1547,7 +1758,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests private static HttpConnectionManager CreateConnectionManager(ILoggerFactory loggerFactory) { - return new HttpConnectionManager(new Logger(loggerFactory ?? new LoggerFactory()), new EmptyApplicationLifetime()); + return new HttpConnectionManager(loggerFactory ?? new LoggerFactory(), new EmptyApplicationLifetime()); } private string GetContentAsString(Stream body) diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs index 4ecfec5e02..809be8f1c1 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionManagerTests.cs @@ -28,6 +28,83 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.NotNull(connection.Application); } + [Theory] + [InlineData(ConnectionStates.ClosedUngracefully | ConnectionStates.ApplicationNotFaulted | ConnectionStates.TransportNotFaulted)] + [InlineData(ConnectionStates.ClosedUngracefully | ConnectionStates.ApplicationNotFaulted | ConnectionStates.TransportFaulted)] + [InlineData(ConnectionStates.ClosedUngracefully | ConnectionStates.ApplicationFaulted | ConnectionStates.TransportFaulted)] + [InlineData(ConnectionStates.ClosedUngracefully | ConnectionStates.ApplicationFaulted | ConnectionStates.TransportNotFaulted)] + + [InlineData(ConnectionStates.CloseGracefully | ConnectionStates.ApplicationNotFaulted | ConnectionStates.TransportNotFaulted)] + [InlineData(ConnectionStates.CloseGracefully | ConnectionStates.ApplicationNotFaulted | ConnectionStates.TransportFaulted)] + [InlineData(ConnectionStates.CloseGracefully | ConnectionStates.ApplicationFaulted | ConnectionStates.TransportFaulted)] + [InlineData(ConnectionStates.CloseGracefully | ConnectionStates.ApplicationFaulted | ConnectionStates.TransportNotFaulted)] + public async Task DisposingConnectionsClosesBothSidesOfThePipe(ConnectionStates states) + { + var closeGracefully = (states & ConnectionStates.CloseGracefully) != 0; + var applicationFaulted = (states & ConnectionStates.ApplicationFaulted) != 0; + var transportFaulted = (states & ConnectionStates.TransportFaulted) != 0; + + var connectionManager = CreateConnectionManager(); + var connection = connectionManager.CreateConnection(); + + if (applicationFaulted) + { + // If the application is faulted then we want to make sure the transport task only completes after + // the application completes + connection.ApplicationTask = Task.FromException(new Exception("Application failed")); + connection.TransportTask = Task.Run(async () => + { + // Wait for the application to end + var result = await connection.Application.Input.ReadAsync(); + connection.Application.Input.AdvanceTo(result.Buffer.End); + + if (transportFaulted) + { + throw new Exception("Transport failed"); + } + }); + + } + else if (transportFaulted) + { + // If the transport is faulted then we want to make sure the transport task only completes after + // the application completes + connection.TransportTask = Task.FromException(new Exception("Application failed")); + connection.ApplicationTask = Task.Run(async () => + { + // Wait for the application to end + var result = await connection.Transport.Input.ReadAsync(); + connection.Transport.Input.AdvanceTo(result.Buffer.End); + }); + } + else + { + connection.ApplicationTask = Task.CompletedTask; + connection.TransportTask = Task.CompletedTask; + } + + var applicationInputTcs = new TaskCompletionSource(); + var applicationOutputTcs = new TaskCompletionSource(); + var transportInputTcs = new TaskCompletionSource(); + var transportOutputTcs = new TaskCompletionSource(); + + connection.Transport.Input.OnWriterCompleted((_, __) => transportInputTcs.TrySetResult(null), null); + connection.Transport.Output.OnReaderCompleted((_, __) => transportOutputTcs.TrySetResult(null), null); + connection.Application.Input.OnWriterCompleted((_, __) => applicationInputTcs.TrySetResult(null), null); + connection.Application.Output.OnReaderCompleted((_, __) => applicationOutputTcs.TrySetResult(null), null); + + try + { + await connection.DisposeAsync(closeGracefully); + } + catch + { + // Ignore the exception that bubbles out of the failing task + } + + await Task.WhenAll(applicationInputTcs.Task, applicationOutputTcs.Task, transportInputTcs.Task, transportOutputTcs.Task).OrTimeout(); + } + [Fact] public void NewConnectionsCanBeRetrieved() { @@ -242,7 +319,18 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests private static HttpConnectionManager CreateConnectionManager(IApplicationLifetime lifetime = null) { lifetime = lifetime ?? new EmptyApplicationLifetime(); - return new HttpConnectionManager(new Logger(new LoggerFactory()), lifetime); + return new HttpConnectionManager(new LoggerFactory(), lifetime); + } + + [Flags] + public enum ConnectionStates + { + ClosedUngracefully = 1, + ApplicationNotFaulted = 2, + TransportNotFaulted = 4, + ApplicationFaulted = 8, + TransportFaulted = 16, + CloseGracefully = 32 } } }