diff --git a/src/Common/PipeWriterStream.cs b/src/Common/PipeWriterStream.cs index 1545acd9b1..8c39ca46fb 100644 --- a/src/Common/PipeWriterStream.cs +++ b/src/Common/PipeWriterStream.cs @@ -83,7 +83,16 @@ namespace System.IO.Pipelines return default; - async ValueTask WriteSlowAsync(ValueTask flushTask) => await flushTask; + async ValueTask WriteSlowAsync(ValueTask flushTask) + { + var flushResult = await flushTask; + + // Cancellation can be triggered by PipeWriter.CancelPendingFlush + if (flushResult.IsCanceled) + { + throw new OperationCanceledException(); + } + } } public void Reset() diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs index d428722223..0ea9f1c394 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionContext.cs @@ -178,7 +178,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal public async Task DisposeAsync(bool closeGracefully = false) { - var disposeTask = Task.CompletedTask; + Task disposeTask; await StateLock.WaitAsync(); try @@ -267,6 +267,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { Log.ShuttingDownTransportAndApplication(_logger, TransportType); + // Cancel any pending flushes from back pressure + Application?.Output.CancelPendingFlush(); + // Shutdown both sides and wait for nothing Transport?.Output.Complete(applicationTask.Exception?.InnerException); Application?.Output.Complete(transportTask.Exception?.InnerException); diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs index 6b16d8533a..cf983c804b 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.Log.cs @@ -11,10 +11,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private static class Log { private static readonly Action _connectionDisposed = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "ConnectionDisposed"), "Connection Id {TransportConnectionId} was disposed."); + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "ConnectionDisposed"), "Connection {TransportConnectionId} was disposed."); private static readonly Action _connectionAlreadyActive = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "ConnectionAlreadyActive"), "Connection Id {TransportConnectionId} is already active via {RequestId}."); + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "ConnectionAlreadyActive"), "Connection {TransportConnectionId} is already active via {RequestId}."); private static readonly Action _pollCanceled = LoggerMessage.Define(LogLevel.Trace, new EventId(3, "PollCanceled"), "Previous poll canceled for {TransportConnectionId} on {RequestId}."); @@ -46,6 +46,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal private static readonly Action _terminatingConnection = LoggerMessage.Define(LogLevel.Trace, new EventId(12, "TerminatingConection"), "Terminating Long Polling connection due to a DELETE request."); + private static readonly Action _connectionDisposedWhileWriteInProgress = + LoggerMessage.Define(LogLevel.Debug, new EventId(13, "ConnectionDisposedWhileWriteInProgress"), "Connection {TransportConnectionId} was disposed while a write was in progress."); + public static void ConnectionDisposed(ILogger logger, string connectionId) { _connectionDisposed(logger, connectionId, null); @@ -105,6 +108,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { _terminatingConnection(logger, null); } + + public static void ConnectionDisposedWhileWriteInProgress(ILogger logger, string connectionId, Exception ex) + { + _connectionDisposedWhileWriteInProgress(logger, connectionId, ex); + } } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs index 55bae0f94a..4d4bd93c1d 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs @@ -479,12 +479,40 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal return; } - await context.Request.Body.CopyToAsync(connection.ApplicationStream, bufferSize); + try + { + try + { + await context.Request.Body.CopyToAsync(connection.ApplicationStream, bufferSize); + } + catch (InvalidOperationException ex) + { + // PipeWriter will throw an error if it is written to while dispose is in progress and the writer has been completed + // Dispose isn't taking WriteLock because it could be held because of backpressure, and calling CancelPendingFlush + // then taking the lock introduces a race condition that could lead to a deadlock + Log.ConnectionDisposedWhileWriteInProgress(_logger, connection.ConnectionId, ex); - Log.ReceivedBytes(_logger, connection.ApplicationStream.Length); + context.Response.StatusCode = StatusCodes.Status404NotFound; + context.Response.ContentType = "text/plain"; + return; + } + catch (OperationCanceledException) + { + // CancelPendingFlush has canceled pending writes caused by backpresure + Log.ConnectionDisposed(_logger, connection.ConnectionId); - // Clear the amount of read bytes so logging is accurate - connection.ApplicationStream.Reset(); + context.Response.StatusCode = StatusCodes.Status404NotFound; + context.Response.ContentType = "text/plain"; + return; + } + + Log.ReceivedBytes(_logger, connection.ApplicationStream.Length); + } + finally + { + // Clear the amount of read bytes so logging is accurate + connection.ApplicationStream.Reset(); + } } finally { diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index a1681f1e9c..39d19723bd 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -1797,6 +1797,134 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + private class ControllableMemoryStream : MemoryStream + { + private readonly SyncPoint _syncPoint; + + public ControllableMemoryStream(SyncPoint syncPoint) + { + _syncPoint = syncPoint; + } + + public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + await _syncPoint.WaitToContinue(); + + await base.CopyToAsync(destination, bufferSize, cancellationToken); + } + } + + [Fact] + public async Task WriteThatIsDisposedBeforeCompleteReturns404() + { + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var pipeOptions = new PipeOptions(pauseWriterThreshold: 13, resumeWriterThreshold: 10); + var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + SyncPoint streamCopySyncPoint = new SyncPoint(); + + using (var responseBody = new MemoryStream()) + using (var requestBody = new ControllableMemoryStream(streamCopySyncPoint)) + { + var context = new DefaultHttpContext(); + context.Request.Body = requestBody; + context.Response.Body = responseBody; + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + var buffer = Encoding.UTF8.GetBytes("Hello, world"); + requestBody.Write(buffer, 0, buffer.Length); + requestBody.Seek(0, SeekOrigin.Begin); + + // Write + var sendTask = dispatcher.ExecuteAsync(context, options, app); + + // Wait on the sync point inside ApplicationStream.CopyToAsync + await streamCopySyncPoint.WaitForSyncPoint(); + + // Start disposing. This will close the output and cause the write to error + var disposeTask = connection.DisposeAsync().OrTimeout(); + + // Continue writing on a completed writer + streamCopySyncPoint.Continue(); + + await sendTask.OrTimeout(); + await disposeTask.OrTimeout(); + + // Ensure response status is correctly set + Assert.Equal(404, context.Response.StatusCode); + } + } + } + + [Fact] + public async Task CanDisposeWhileWriteLockIsBlockedOnBackpressureAndResponseReturns404() + { + using (StartVerifiableLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var pipeOptions = new PipeOptions(pauseWriterThreshold: 13, resumeWriterThreshold: 10); + var connection = manager.CreateConnection(pipeOptions, pipeOptions); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionDispatcherOptions(); + + using (var responseBody = new MemoryStream()) + using (var requestBody = new MemoryStream()) + { + var context = new DefaultHttpContext(); + context.Request.Body = requestBody; + context.Response.Body = responseBody; + context.Request.Path = "/foo"; + context.Request.Method = "POST"; + var values = new Dictionary(); + values["id"] = connection.ConnectionId; + var qs = new QueryCollection(values); + context.Request.Query = qs; + var buffer = Encoding.UTF8.GetBytes("Hello, world"); + requestBody.Write(buffer, 0, buffer.Length); + requestBody.Seek(0, SeekOrigin.Begin); + + // Write some data to the pipe to fill it up and make the next write wait + await connection.ApplicationStream.WriteAsync(buffer, 0, buffer.Length).OrTimeout(); + + // Write. This will take the WriteLock and block because of back pressure + var sendTask = dispatcher.ExecuteAsync(context, options, app); + + // Start disposing. This will take the StateLock and attempt to take the WriteLock + // Dispose will cancel pending flush and should unblock WriteLock + await connection.DisposeAsync().OrTimeout(); + + // Sends were unblocked + await sendTask.OrTimeout(); + + Assert.Equal(404, context.Response.StatusCode); + } + } + } + [Fact] public async Task LongPollingCanPollIfWritePipeHasBackpressure() {