diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs index 85a4da5fbe..374abf7adf 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs @@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http MemoryPool memoryPool) { // Allow appending more data to the PipeWriter when a flush is pending. - _pipeWriter = new ConcurrentPipeWriter(pipeWriter, memoryPool); + _pipeWriter = new ConcurrentPipeWriter(pipeWriter, memoryPool, _contextLock); _connectionId = connectionId; _connectionContext = connectionContext; _log = log; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs index 1f2b057392..294c524272 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2FrameWriter.cs @@ -56,7 +56,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 IKestrelTrace log) { // Allow appending more data to the PipeWriter when a flush is pending. - _outputWriter = new ConcurrentPipeWriter(outputPipeWriter, memoryPool); + _outputWriter = new ConcurrentPipeWriter(outputPipeWriter, memoryPool, _writeLock); _connectionContext = connectionContext; _http2Connection = http2Connection; _connectionOutputFlowControl = connectionOutputFlowControl; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs index 46b1143929..6726f1fed6 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs @@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 var pipe = CreateDataPipe(pool); - _pipeWriter = new ConcurrentPipeWriter(pipe.Writer, pool); + _pipeWriter = new ConcurrentPipeWriter(pipe.Writer, pool, _dataWriterLock); _pipeReader = pipe.Reader; // No need to pass in timeoutControl here, since no minDataRates are passed to the TimingPipeFlusher. diff --git a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/PipeWriterHelpers/ConcurrentPipeWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/PipeWriterHelpers/ConcurrentPipeWriter.cs index 5a35bafb85..c7e894253b 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Infrastructure/PipeWriterHelpers/ConcurrentPipeWriter.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Infrastructure/PipeWriterHelpers/ConcurrentPipeWriter.cs @@ -23,7 +23,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure.PipeW private static readonly Exception _successfullyCompletedSentinel = new Exception(); - private readonly object _sync = new object(); + private readonly object _sync; private readonly PipeWriter _innerPipeWriter; private readonly MemoryPool _pool; private readonly BufferSegmentStack _bufferSegmentPool = new BufferSegmentStack(InitialSegmentPoolSize); @@ -51,97 +51,86 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure.PipeW private bool _aborted; private Exception _completeException; - public ConcurrentPipeWriter(PipeWriter innerPipeWriter, MemoryPool pool) + public ConcurrentPipeWriter(PipeWriter innerPipeWriter, MemoryPool pool, object sync) { _innerPipeWriter = innerPipeWriter; _pool = pool; + _sync = sync; } public override Memory GetMemory(int sizeHint = 0) { - lock (_sync) + if (_currentFlushTcs == null && _head == null) { - if (_currentFlushTcs == null && _head == null) - { - return _innerPipeWriter.GetMemory(sizeHint); - } - - AllocateMemoryUnsynchronized(sizeHint); - return _tailMemory; + return _innerPipeWriter.GetMemory(sizeHint); } + + AllocateMemoryUnsynchronized(sizeHint); + return _tailMemory; } public override Span GetSpan(int sizeHint = 0) { - lock (_sync) + if (_currentFlushTcs == null && _head == null) { - if (_currentFlushTcs == null && _head == null) - { - return _innerPipeWriter.GetSpan(sizeHint); - } - - AllocateMemoryUnsynchronized(sizeHint); - return _tailMemory.Span; + return _innerPipeWriter.GetSpan(sizeHint); } + + AllocateMemoryUnsynchronized(sizeHint); + return _tailMemory.Span; } public override void Advance(int bytes) { - lock (_sync) + if (_currentFlushTcs == null && _head == null) { - if (_currentFlushTcs == null && _head == null) - { - _innerPipeWriter.Advance(bytes); - return; - } - - if ((uint)bytes > (uint)_tailMemory.Length) - { - ThrowArgumentOutOfRangeException(nameof(bytes)); - } - - _tailBytesBuffered += bytes; - _bytesBuffered += bytes; - _tailMemory = _tailMemory.Slice(bytes); - _bufferedWritePending = false; + _innerPipeWriter.Advance(bytes); + return; } + + if ((uint)bytes > (uint)_tailMemory.Length) + { + ThrowArgumentOutOfRangeException(nameof(bytes)); + } + + _tailBytesBuffered += bytes; + _bytesBuffered += bytes; + _tailMemory = _tailMemory.Slice(bytes); + _bufferedWritePending = false; } public override ValueTask FlushAsync(CancellationToken cancellationToken = default) { - lock (_sync) + if (_currentFlushTcs != null) + { + return new ValueTask(_currentFlushTcs.Task); + } + + if (_bytesBuffered > 0) + { + CopyAndReturnSegmentsUnsynchronized(); + } + + var flushTask = _innerPipeWriter.FlushAsync(cancellationToken); + + if (flushTask.IsCompletedSuccessfully) { if (_currentFlushTcs != null) { - return new ValueTask(_currentFlushTcs.Task); + CompleteFlushUnsynchronized(flushTask.GetAwaiter().GetResult(), null); } - if (_bytesBuffered > 0) - { - CopyAndReturnSegmentsUnsynchronized(); - } - - var flushTask = _innerPipeWriter.FlushAsync(cancellationToken); - - if (flushTask.IsCompletedSuccessfully) - { - if (_currentFlushTcs != null) - { - CompleteFlushUnsynchronized(flushTask.GetAwaiter().GetResult(), null); - } - - return flushTask; - } - - // Use a TCS instead of something resettable so it can be awaited by multiple awaiters. - _currentFlushTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); - var result = new ValueTask(_currentFlushTcs.Task); - - // FlushAsyncAwaited clears the TCS prior to completing. Make sure to construct the ValueTask - // from the TCS before calling FlushAsyncAwaited in case FlushAsyncAwaited completes inline. - _ = FlushAsyncAwaited(flushTask, cancellationToken); - return result; + return flushTask; } + + // Use a TCS instead of something resettable so it can be awaited by multiple awaiters. + _currentFlushTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var result = new ValueTask(_currentFlushTcs.Task); + + // FlushAsyncAwaited clears the TCS prior to completing. Make sure to construct the ValueTask + // from the TCS before calling FlushAsyncAwaited in case FlushAsyncAwaited completes inline. + _ = FlushAsyncAwaited(flushTask, cancellationToken); + return result; } private async Task FlushAsyncAwaited(ValueTask flushTask, CancellationToken cancellationToken) @@ -199,40 +188,34 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure.PipeW public override void Complete(Exception exception = null) { - lock (_sync) + // We store the complete exception or s sentinel exception instance in a field if a flush was ongoing. + // We call the inner Complete() method after the flush loop ended. + + // To simply ensure everything gets returned after the PipeWriter is left in some unknown state (say GetMemory() was + // called but not Advance(), or there's a flush pending), but you don't want to complete the inner pipe, just call Abort(). + _completeException = exception ?? _successfullyCompletedSentinel; + + if (_currentFlushTcs == null) { - // We store the complete exception or s sentinel exception instance in a field if a flush was ongoing. - // We call the inner Complete() method after the flush loop ended. - - // To simply ensure everything gets returned after the PipeWriter is left in some unknown state (say GetMemory() was - // called but not Advance(), or there's a flush pending), but you don't want to complete the inner pipe, just call Abort(). - _completeException = exception ?? _successfullyCompletedSentinel; - - if (_currentFlushTcs == null) + if (_bytesBuffered > 0) { - if (_bytesBuffered > 0) - { - CopyAndReturnSegmentsUnsynchronized(); - } - - CleanupSegmentsUnsynchronized(); - - _innerPipeWriter.Complete(exception); + CopyAndReturnSegmentsUnsynchronized(); } + + CleanupSegmentsUnsynchronized(); + + _innerPipeWriter.Complete(exception); } } public void Abort() { - lock (_sync) - { - _aborted = true; + _aborted = true; - // If we're flushing, the cleanup will happen after the flush. - if (_currentFlushTcs == null) - { - CleanupSegmentsUnsynchronized(); - } + // If we're flushing, the cleanup will happen after the flush. + if (_currentFlushTcs == null) + { + CleanupSegmentsUnsynchronized(); } } diff --git a/src/Servers/Kestrel/Core/test/ConcurrentPipeWriterTests.cs b/src/Servers/Kestrel/Core/test/ConcurrentPipeWriterTests.cs index eb3ab8eead..bb3402bd73 100644 --- a/src/Servers/Kestrel/Core/test/ConcurrentPipeWriterTests.cs +++ b/src/Servers/Kestrel/Core/test/ConcurrentPipeWriterTests.cs @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; var mockPipeWriter = new MockPipeWriter(pipeWriterFlushTcsArray); - var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool); + var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool, new object()); var memory = concurrentPipeWriter.GetMemory(); Assert.Equal(1, mockPipeWriter.GetMemoryCallCount); @@ -72,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; var mockPipeWriter = new MockPipeWriter(pipeWriterFlushTcsArray); - var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool); + var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool, new object()); var memory = concurrentPipeWriter.GetMemory(); Assert.Equal(1, mockPipeWriter.GetMemoryCallCount); @@ -152,7 +152,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; var mockPipeWriter = new MockPipeWriter(pipeWriterFlushTcsArray); - var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool); + var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool, new object()); var memory = concurrentPipeWriter.GetMemory(); Assert.Equal(1, mockPipeWriter.GetMemoryCallCount); @@ -218,7 +218,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; var mockPipeWriter = new MockPipeWriter(pipeWriterFlushTcsArray); - var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool); + var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool, new object()); var memory = concurrentPipeWriter.GetMemory(); Assert.Equal(1, mockPipeWriter.GetMemoryCallCount); @@ -273,7 +273,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests }; var mockPipeWriter = new MockPipeWriter(pipeWriterFlushTcsArray); - var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool); + var concurrentPipeWriter = new ConcurrentPipeWriter(mockPipeWriter, diagnosticPool, new object()); var memory = concurrentPipeWriter.GetMemory(); Assert.Equal(1, mockPipeWriter.GetMemoryCallCount);