diff --git a/.gitignore b/.gitignore index d5717b3f3f..3d7e16e84a 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ project.lock.json /.vs/ .vscode/ global.json +BenchmarkDotNet.Artifacts/ diff --git a/benchmarks/Microsoft.AspNetCore.Http.Performance/StreamPipeWriterBenchmark.cs b/benchmarks/Microsoft.AspNetCore.Http.Performance/StreamPipeWriterBenchmark.cs new file mode 100644 index 0000000000..705cb0d8af --- /dev/null +++ b/benchmarks/Microsoft.AspNetCore.Http.Performance/StreamPipeWriterBenchmark.cs @@ -0,0 +1,89 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using BenchmarkDotNet.Attributes; + +namespace Microsoft.AspNetCore.Http +{ + public class StreamPipeWriterBenchmark + { + private Stream _memoryStream; + private StreamPipeWriter _pipeWriter; + private static byte[] _helloWorldBytes = Encoding.ASCII.GetBytes("Hello World"); + private static byte[] _largeWrite = Encoding.ASCII.GetBytes(new string('a', 50000)); + + [IterationSetup] + public void Setup() + { + _memoryStream = new NoopStream(); + _pipeWriter = new StreamPipeWriter(_memoryStream); + } + + [Benchmark] + public async Task WriteHelloWorld() + { + await _pipeWriter.WriteAsync(_helloWorldBytes); + } + + [Benchmark] + public async Task WriteHelloWorldLargeWrite() + { + await _pipeWriter.WriteAsync(_largeWrite); + } + + public class NoopStream : Stream + { + public override bool CanRead => false; + + public override bool CanSeek => throw new System.NotImplementedException(); + + public override bool CanWrite => true; + + public override long Length => throw new System.NotImplementedException(); + + public override long Position { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new System.NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new System.NotImplementedException(); + } + + public override void SetLength(long value) + { + } + + public override void Write(byte[] buffer, int offset, int count) + { + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + + public override ValueTask WriteAsync(ReadOnlyMemory buffer, CancellationToken cancellationToken = default(CancellationToken)) + { + return default(ValueTask); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return Task.CompletedTask; + } + } + } +} diff --git a/build/dependencies.props b/build/dependencies.props index 2aa50e3e13..c3991cb407 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -24,6 +24,7 @@ 4.9.0 2.0.3 4.6.0-preview1-26907-04 + 4.6.0-preview1-26907-04 4.6.0-preview1-26907-04 0.10.0 2.3.1 diff --git a/src/Microsoft.AspNetCore.Http/Microsoft.AspNetCore.Http.csproj b/src/Microsoft.AspNetCore.Http/Microsoft.AspNetCore.Http.csproj index 162315a7a6..94080281b3 100644 --- a/src/Microsoft.AspNetCore.Http/Microsoft.AspNetCore.Http.csproj +++ b/src/Microsoft.AspNetCore.Http/Microsoft.AspNetCore.Http.csproj @@ -2,7 +2,7 @@ ASP.NET Core default HTTP feature implementations. - netstandard2.0 + netstandard2.0;netcoreapp2.2 $(NoWarn);CS1591 true true @@ -19,6 +19,7 @@ + diff --git a/src/Microsoft.AspNetCore.Http/StreamPipeWriter.cs b/src/Microsoft.AspNetCore.Http/StreamPipeWriter.cs new file mode 100644 index 0000000000..f232aa97cf --- /dev/null +++ b/src/Microsoft.AspNetCore.Http/StreamPipeWriter.cs @@ -0,0 +1,320 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Runtime.ExceptionServices; +using System.Runtime.InteropServices; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Http +{ + /// + /// Implements PipeWriter using a underlying stream. + /// + public class StreamPipeWriter : PipeWriter, IDisposable + { + private readonly int _minimumSegmentSize; + private readonly Stream _writingStream; + private int _bytesWritten; + + private List _completedSegments; + private Memory _currentSegment; + private IMemoryOwner _currentSegmentOwner; + private MemoryPool _pool; + private int _position; + + private CancellationTokenSource _internalTokenSource; + private bool _isCompleted; + private ExceptionDispatchInfo _exceptionInfo; + private object _lockObject = new object(); + + private CancellationTokenSource InternalTokenSource + { + get + { + lock (_lockObject) + { + if (_internalTokenSource == null) + { + _internalTokenSource = new CancellationTokenSource(); + } + return _internalTokenSource; + } + } + } + + /// + /// Creates a new StreamPipeWrapper + /// + /// The stream to write to + public StreamPipeWriter(Stream writingStream) : this(writingStream, 4096) + { + } + + public StreamPipeWriter(Stream writingStream, int minimumSegmentSize, MemoryPool pool = null) + { + _minimumSegmentSize = minimumSegmentSize; + _writingStream = writingStream; + _pool = pool ?? MemoryPool.Shared; + } + + /// + public override void Advance(int count) + { + if (_currentSegment.IsEmpty) // TODO confirm this + { + throw new InvalidOperationException("No writing operation. Make sure GetMemory() was called."); + } + + if (count >= 0) + { + if (_currentSegment.Length < _position + count) + { + throw new InvalidOperationException("Can't advance past buffer size."); + } + _bytesWritten += count; + _position += count; + } + } + + /// + public override Memory GetMemory(int sizeHint = 0) + { + EnsureCapacity(sizeHint); + + return _currentSegment; + } + + /// + public override Span GetSpan(int sizeHint = 0) + { + EnsureCapacity(sizeHint); + + return _currentSegment.Span.Slice(_position); + } + + /// + public override void CancelPendingFlush() + { + Cancel(); + } + + /// + public override void Complete(Exception exception = null) + { + if (_isCompleted) + { + return; + } + + _isCompleted = true; + if (exception != null) + { + _exceptionInfo = ExceptionDispatchInfo.Capture(exception); + } + + _internalTokenSource?.Dispose(); + + if (_completedSegments != null) + { + foreach (var segment in _completedSegments) + { + segment.Return(); + } + } + + _currentSegmentOwner?.Dispose(); + } + + /// + public override void OnReaderCompleted(Action callback, object state) + { + throw new NotSupportedException("OnReaderCompleted isn't supported in StreamPipeWrapper."); + } + + /// + public override ValueTask FlushAsync(CancellationToken cancellationToken = default) + { + if (_bytesWritten == 0) + { + return new ValueTask(new FlushResult(isCanceled: false, IsCompletedOrThrow())); + } + + return FlushAsyncInternal(cancellationToken); + } + + private void Cancel() + { + InternalTokenSource.Cancel(); + } + + private async ValueTask FlushAsyncInternal(CancellationToken cancellationToken = default) + { + // Write all completed segments and whatever remains in the current segment + // and flush the result. + CancellationTokenRegistration reg = new CancellationTokenRegistration(); + if (cancellationToken.CanBeCanceled) + { + reg = cancellationToken.Register(state => ((StreamPipeWriter)state).Cancel(), this); + } + using (reg) + { + var localToken = InternalTokenSource.Token; + try + { + if (_completedSegments != null && _completedSegments.Count > 0) + { + var count = _completedSegments.Count; + for (var i = 0; i < count; i++) + { + var segment = _completedSegments[0]; +#if NETCOREAPP2_2 + await _writingStream.WriteAsync(segment.Buffer.Slice(0, segment.Length), localToken); +#elif NETSTANDARD2_0 + MemoryMarshal.TryGetArray(segment.Buffer, out var arraySegment); + await _writingStream.WriteAsync(arraySegment.Array, 0, segment.Length, localToken); +#else +#error Target frameworks need to be updated. +#endif + _bytesWritten -= segment.Length; + segment.Return(); + _completedSegments.RemoveAt(0); + } + } + + if (!_currentSegment.IsEmpty) + { +#if NETCOREAPP2_2 + await _writingStream.WriteAsync(_currentSegment.Slice(0, _position), localToken); +#elif NETSTANDARD2_0 + MemoryMarshal.TryGetArray(_currentSegment, out var arraySegment); + await _writingStream.WriteAsync(arraySegment.Array, 0, _position, localToken); +#else +#error Target frameworks need to be updated. +#endif + _bytesWritten -= _position; + _position = 0; + } + + await _writingStream.FlushAsync(localToken); + + return new FlushResult(isCanceled: false, IsCompletedOrThrow()); + } + catch (OperationCanceledException) + { + // Remove the cancellation token such that the next time Flush is called + // A new CTS is created. + lock (_lockObject) + { + _internalTokenSource = null; + } + + if (cancellationToken.IsCancellationRequested) + { + throw; + } + + // Catch any cancellation and translate it into setting isCanceled = true + return new FlushResult(isCanceled: true, IsCompletedOrThrow()); + } + } + } + + private void EnsureCapacity(int sizeHint) + { + // This does the Right Thing. It only subtracts _position from the current segment length if it's non-null. + // If _currentSegment is null, it returns 0. + var remainingSize = _currentSegment.Length - _position; + + // If the sizeHint is 0, any capacity will do + // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment. + if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint)) + { + // We have capacity in the current segment + return; + } + + AddSegment(sizeHint); + } + + private void AddSegment(int sizeHint = 0) + { + if (_currentSegment.Length != 0) + { + // We're adding a segment to the list + if (_completedSegments == null) + { + _completedSegments = new List(); + } + + // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when + // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to + // ignore any empty space in it. + _completedSegments.Add(new CompletedBuffer(_currentSegmentOwner, _position)); + } + + // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment. + _currentSegmentOwner = _pool.Rent(Math.Max(_minimumSegmentSize, sizeHint)); + _currentSegment = _currentSegmentOwner.Memory; + _position = 0; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private bool IsCompletedOrThrow() + { + if (!_isCompleted) + { + return false; + } + + if (_exceptionInfo != null) + { + ThrowLatchedException(); + } + + return true; + } + + [MethodImpl(MethodImplOptions.NoInlining)] + private void ThrowLatchedException() + { + _exceptionInfo.Throw(); + } + + public void Dispose() + { + Complete(); + } + + /// + /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. + /// + private readonly struct CompletedBuffer + { + public Memory Buffer { get; } + public int Length { get; } + + public ReadOnlySpan Span => Buffer.Span; + + private readonly IMemoryOwner _memoryOwner; + + public CompletedBuffer(IMemoryOwner buffer, int length) + { + Buffer = buffer.Memory; + Length = length; + _memoryOwner = buffer; + } + + public void Return() + { + _memoryOwner.Dispose(); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Tests/FlushResultCancellationTests.cs b/test/Microsoft.AspNetCore.Http.Tests/FlushResultCancellationTests.cs new file mode 100644 index 0000000000..f4ab7cb96f --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/FlushResultCancellationTests.cs @@ -0,0 +1,68 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Runtime.CompilerServices; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Http.Tests +{ + public class FlushResultCancellationTests : PipeTest + { + [Fact] + public void FlushAsyncCancellationDeadlock() + { + var cts = new CancellationTokenSource(); + var cts2 = new CancellationTokenSource(); + + PipeWriter buffer = Writer.WriteEmpty(MaximumSizeHigh); + + var e = new ManualResetEventSlim(); + + ValueTaskAwaiter awaiter = buffer.FlushAsync(cts.Token).GetAwaiter(); + awaiter.OnCompleted( + () => { + // We are on cancellation thread and need to wait until another FlushAsync call + // takes pipe state lock + e.Wait(); + + // Make sure we had enough time to reach _cancellationTokenRegistration.Dispose + Thread.Sleep(100); + + // Try to take pipe state lock + buffer.FlushAsync(); + }); + + // Start a thread that would run cancellation callbacks + Task cancellationTask = Task.Run(() => cts.Cancel()); + // Start a thread that would call FlushAsync with different token + // and block on _cancellationTokenRegistration.Dispose + Task blockingTask = Task.Run( + () => { + e.Set(); + buffer.FlushAsync(cts2.Token); + }); + + bool completed = Task.WhenAll(cancellationTask, blockingTask).Wait(TimeSpan.FromSeconds(10)); + Assert.True(completed); + } + + [Fact] + public async Task FlushAsyncWithNewCancellationTokenNotAffectedByPrevious() + { + var cancellationTokenSource1 = new CancellationTokenSource(); + PipeWriter buffer = Writer.WriteEmpty(10); + await buffer.FlushAsync(cancellationTokenSource1.Token); + + cancellationTokenSource1.Cancel(); + + var cancellationTokenSource2 = new CancellationTokenSource(); + buffer = Writer.WriteEmpty(10); + + await buffer.FlushAsync(cancellationTokenSource2.Token); + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Tests/Microsoft.AspNetCore.Http.Tests.csproj b/test/Microsoft.AspNetCore.Http.Tests/Microsoft.AspNetCore.Http.Tests.csproj index aa428320cd..a8ee8f19fc 100644 --- a/test/Microsoft.AspNetCore.Http.Tests/Microsoft.AspNetCore.Http.Tests.csproj +++ b/test/Microsoft.AspNetCore.Http.Tests/Microsoft.AspNetCore.Http.Tests.csproj @@ -2,8 +2,9 @@ $(StandardTestTfms) + true - + diff --git a/test/Microsoft.AspNetCore.Http.Tests/PipeTest.cs b/test/Microsoft.AspNetCore.Http.Tests/PipeTest.cs new file mode 100644 index 0000000000..2e94e3a267 --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/PipeTest.cs @@ -0,0 +1,43 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.IO.Pipelines; + +namespace Microsoft.AspNetCore.Http.Tests +{ + public abstract class PipeTest : IDisposable + { + protected const int MaximumSizeHigh = 65; + + public MemoryStream MemoryStream { get; set; } + + public PipeWriter Writer { get; set; } + + protected PipeTest() + { + MemoryStream = new MemoryStream(); + Writer = new StreamPipeWriter(MemoryStream, 4096, new TestMemoryPool()); + } + + public void Dispose() + { + Writer.Complete(); + } + + public byte[] Read() + { + Writer.FlushAsync().GetAwaiter().GetResult(); + return ReadWithoutFlush(); + } + + public byte[] ReadWithoutFlush() + { + MemoryStream.Position = 0; + var buffer = new byte[MemoryStream.Length]; + var result = MemoryStream.Read(buffer, 0, (int)MemoryStream.Length); + return buffer; + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Tests/PipeWriterTests.cs b/test/Microsoft.AspNetCore.Http.Tests/PipeWriterTests.cs new file mode 100644 index 0000000000..0cc6dc012f --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/PipeWriterTests.cs @@ -0,0 +1,221 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.Collections.Generic; +using System.IO; +using System.IO.Pipelines; +using System.Linq; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Http.Tests +{ + public class PipeWriterTests : PipeTest + { + + [Theory] + [InlineData(3, -1, 0)] + [InlineData(3, 0, -1)] + [InlineData(3, 0, 4)] + [InlineData(3, 4, 0)] + [InlineData(3, -1, -1)] + [InlineData(3, 4, 4)] + public void ThrowsForInvalidParameters(int arrayLength, int offset, int length) + { + var array = new byte[arrayLength]; + for (var i = 0; i < array.Length; i++) + { + array[i] = (byte)(i + 1); + } + + Writer.Write(new Span(array, 0, 0)); + Writer.Write(new Span(array, array.Length, 0)); + + try + { + Writer.Write(new Span(array, offset, length)); + Assert.True(false); + } + catch (Exception ex) + { + Assert.True(ex is ArgumentOutOfRangeException); + } + + Writer.Write(new Span(array, 0, array.Length)); + Assert.Equal(array, Read()); + } + + [Theory] + [InlineData(0, 3)] + [InlineData(1, 2)] + [InlineData(2, 1)] + [InlineData(1, 1)] + public void CanWriteWithOffsetAndLength(int offset, int length) + { + var array = new byte[] { 1, 2, 3 }; + + Writer.Write(new Span(array, offset, length)); + + Assert.Equal(array.Skip(offset).Take(length).ToArray(), Read()); + } + + [Fact] + public void CanWriteIntoHeadlessBuffer() + { + + Writer.Write(new byte[] { 1, 2, 3 }); + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Fact] + public void CanGetNewMemoryWhenSizeTooLarge() + { + var memory = Writer.GetMemory(0); + + var memoryLarge = Writer.GetMemory(10000); + + Assert.NotEqual(memory, memoryLarge); + } + + [Fact] + public void CanGetSameMemoryWhenNoAdvance() + { + var memory = Writer.GetMemory(0); + + var secondMemory = Writer.GetMemory(0); + + Assert.Equal(memory, secondMemory); + } + + [Fact] + public void CanGetNewSpanWhenNoAdvanceWhenSizeTooLarge() + { + var span = Writer.GetSpan(0); + + var secondSpan = Writer.GetSpan(8000); + + Assert.False(span.SequenceEqual(secondSpan)); + } + + [Fact] + public void CanGetSameSpanWhenNoAdvance() + { + var span = Writer.GetSpan(0); + + var secondSpan = Writer.GetSpan(0); + + Assert.True(span.SequenceEqual(secondSpan)); + } + + [Theory] + [InlineData(16, 32, 32)] + [InlineData(16, 16, 16)] + [InlineData(64, 32, 64)] + [InlineData(40, 32, 64)] // memory sizes are powers of 2. + public void CheckMinimumSegmentSizeWithGetMemory(int minimumSegmentSize, int getMemorySize, int expectedSize) + { + var writer = new StreamPipeWriter(new MemoryStream(), minimumSegmentSize); + var memory = writer.GetMemory(getMemorySize); + + Assert.Equal(expectedSize, memory.Length); + } + + [Fact] + public void CanWriteMultipleTimes() + { + + Writer.Write(new byte[] { 1 }); + Writer.Write(new byte[] { 2 }); + Writer.Write(new byte[] { 3 }); + + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Fact] + public void CanWriteOverTheBlockLength() + { + Memory memory = Writer.GetMemory(); + + IEnumerable source = Enumerable.Range(0, memory.Length).Select(i => (byte)i); + byte[] expectedBytes = source.Concat(source).Concat(source).ToArray(); + + Writer.Write(expectedBytes); + + Assert.Equal(expectedBytes, Read()); + } + + [Fact] + public void EnsureAllocatesSpan() + { + var span = Writer.GetSpan(10); + + Assert.True(span.Length >= 10); + // 0 byte Flush would not complete the reader so we complete. + Writer.Complete(); + Assert.Equal(new byte[] { }, Read()); + } + + [Fact] + public void SlicesSpanAndAdvancesAfterWrite() + { + int initialLength = Writer.GetSpan(3).Length; + + + Writer.Write(new byte[] { 1, 2, 3 }); + Span span = Writer.GetSpan(); + + Assert.Equal(initialLength - 3, span.Length); + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Theory] + [InlineData(5)] + [InlineData(50)] + [InlineData(500)] + [InlineData(5000)] + [InlineData(50000)] + public async Task WriteLargeDataBinary(int length) + { + var data = new byte[length]; + new Random(length).NextBytes(data); + PipeWriter output = Writer; + output.Write(data); + await output.FlushAsync(); + + var input = Read(); + Assert.Equal(data, input.ToArray()); + } + + [Fact] + public async Task CanWriteNothingToBuffer() + { + Writer.GetMemory(0); + Writer.Advance(0); // doing nothing, the hard way + await Writer.FlushAsync(); + } + + [Fact] + public void EmptyWriteDoesNotThrow() + { + Writer.Write(new byte[0]); + } + + [Fact] + public void ThrowsOnAdvanceOverMemorySize() + { + Memory buffer = Writer.GetMemory(1); + var exception = Assert.Throws(() => Writer.Advance(buffer.Length + 1)); + Assert.Equal("Can't advance past buffer size.", exception.Message); + } + + [Fact] + public void ThrowsOnAdvanceWithNoMemory() + { + PipeWriter buffer = Writer; + var exception = Assert.Throws(() => buffer.Advance(1)); + Assert.Equal("No writing operation. Make sure GetMemory() was called.", exception.Message); + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Tests/StreamPipeWriterTests.cs b/test/Microsoft.AspNetCore.Http.Tests/StreamPipeWriterTests.cs new file mode 100644 index 0000000000..76d3b34fae --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/StreamPipeWriterTests.cs @@ -0,0 +1,380 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Buffers; +using System.IO; +using System.IO.Pipelines; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Http.Tests +{ + public class StreamPipeWriterTests : PipeTest + { + [Fact] + public async Task CanWriteAsyncMultipleTimesIntoSameBlock() + { + + await Writer.WriteAsync(new byte[] { 1 }); + await Writer.WriteAsync(new byte[] { 2 }); + await Writer.WriteAsync(new byte[] { 3 }); + + Assert.Equal(new byte[] { 1, 2, 3 }, Read()); + } + + [Theory] + [InlineData(100, 1000)] + [InlineData(100, 8000)] + [InlineData(100, 10000)] + [InlineData(8000, 100)] + [InlineData(8000, 8000)] + public async Task CanAdvanceWithPartialConsumptionOfFirstSegment(int firstWriteLength, int secondWriteLength) + { + await Writer.WriteAsync(Encoding.ASCII.GetBytes("a")); + + var expectedLength = firstWriteLength + secondWriteLength + 1; + + var memory = Writer.GetMemory(firstWriteLength); + Writer.Advance(firstWriteLength); + + memory = Writer.GetMemory(secondWriteLength); + Writer.Advance(secondWriteLength); + + await Writer.FlushAsync(); + + Assert.Equal(expectedLength, Read().Length); + } + + [Fact] + public async Task ThrowsOnCompleteAndWrite() + { + Writer.Complete(new InvalidOperationException("Whoops")); + var exception = await Assert.ThrowsAsync(async () => await Writer.FlushAsync()); + + Assert.Equal("Whoops", exception.Message); + } + + [Fact] + public async Task WriteCanBeCancelledViaProvidedCancellationToken() + { + var pipeWriter = new StreamPipeWriter(new HangingStream()); + var cts = new CancellationTokenSource(1); + await Assert.ThrowsAsync(async () => await pipeWriter.WriteAsync(Encoding.ASCII.GetBytes("data"), cts.Token)); + } + + [Fact] + public async Task WriteCanBeCanceledViaCancelPendingFlushWhenFlushIsAsync() + { + var pipeWriter = new StreamPipeWriter(new HangingStream()); + FlushResult flushResult = new FlushResult(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var task = Task.Run(async () => + { + try + { + var writingTask = pipeWriter.WriteAsync(Encoding.ASCII.GetBytes("data")); + tcs.SetResult(0); + flushResult = await writingTask; + } + catch (Exception ex) + { + Console.WriteLine(ex.Message); + throw ex; + } + }); + + await tcs.Task; + + pipeWriter.CancelPendingFlush(); + + await task; + + Assert.True(flushResult.IsCanceled); + } + + [Fact] + public void FlushAsyncCompletedAfterPreCancellation() + { + PipeWriter writableBuffer = Writer.WriteEmpty(1); + + Writer.CancelPendingFlush(); + + ValueTask flushAsync = writableBuffer.FlushAsync(); + + Assert.True(flushAsync.IsCompleted); + + FlushResult flushResult = flushAsync.GetAwaiter().GetResult(); + + Assert.True(flushResult.IsCanceled); + + flushAsync = writableBuffer.FlushAsync(); + + Assert.True(flushAsync.IsCompleted); + } + + [Fact] + public void FlushAsyncReturnsCanceledIfCanceledBeforeFlush() + { + CheckCanceledFlush(); + } + + [Fact] + public void FlushAsyncReturnsCanceledIfCanceledBeforeFlushMultipleTimes() + { + for (var i = 0; i < 10; i++) + { + CheckCanceledFlush(); + } + } + + [Fact] + public async Task FlushAsyncReturnsCanceledInterleaved() + { + for (var i = 0; i < 5; i++) + { + CheckCanceledFlush(); + await CheckWriteIsNotCanceled(); + } + } + + [Fact] + public async Task CancelPendingFlushBetweenWritesAllDataIsPreserved() + { + MemoryStream = new SingleWriteStream(); + Writer = new StreamPipeWriter(MemoryStream); + FlushResult flushResult = new FlushResult(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var task = Task.Run(async () => + { + try + { + await Writer.WriteAsync(Encoding.ASCII.GetBytes("data")); + + var writingTask = Writer.WriteAsync(Encoding.ASCII.GetBytes(" data")); + tcs.SetResult(0); + flushResult = await writingTask; + } + catch (Exception ex) + { + Console.WriteLine(ex.Message); + throw ex; + } + }); + + await tcs.Task; + + Writer.CancelPendingFlush(); + + await task; + + Assert.True(flushResult.IsCanceled); + + await Writer.WriteAsync(Encoding.ASCII.GetBytes(" more data")); + Assert.Equal(Encoding.ASCII.GetBytes("data data more data"), Read()); + } + + [Fact] + public async Task CancelPendingFlushAfterAllWritesAllDataIsPreserved() + { + MemoryStream = new CannotFlushStream(); + Writer = new StreamPipeWriter(MemoryStream); + FlushResult flushResult = new FlushResult(); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + var task = Task.Run(async () => + { + try + { + // Create two Segments + // First one will succeed to write, other one will hang. + var writingTask = Writer.WriteAsync(Encoding.ASCII.GetBytes("data")); + tcs.SetResult(0); + flushResult = await writingTask; + } + catch (Exception ex) + { + Console.WriteLine(ex.Message); + throw ex; + } + }); + + await tcs.Task; + + Writer.CancelPendingFlush(); + + await task; + + Assert.True(flushResult.IsCanceled); + } + + [Fact] + public async Task CancelPendingFlushLostOfCancellationsNoDataLost() + { + var writeSize = 16; + var singleWriteStream = new SingleWriteStream(); + MemoryStream = singleWriteStream; + Writer = new StreamPipeWriter(MemoryStream, minimumSegmentSize: writeSize); + + for (var i = 0; i < 10; i++) + { + FlushResult flushResult = new FlushResult(); + var expectedData = Encoding.ASCII.GetBytes(new string('a', writeSize)); + + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + // TaskCreationOptions.RunAsync + + var task = Task.Run(async () => + { + try + { + // Create two Segments + // First one will succeed to write, other one will hang. + for (var j = 0; j < 2; j++) + { + Writer.Write(expectedData); + } + + var flushTask = Writer.FlushAsync(); + tcs.SetResult(0); + flushResult = await flushTask; + } + catch (Exception ex) + { + Console.WriteLine(ex.Message); + throw ex; + } + }); + + await tcs.Task; + + Writer.CancelPendingFlush(); + + await task; + + Assert.True(flushResult.IsCanceled); + } + + // Only half of the data was written because every other flush failed. + Assert.Equal(16 * 10, ReadWithoutFlush().Length); + + // Start allowing all writes to make read succeed. + singleWriteStream.AllowAllWrites = true; + + Assert.Equal(16 * 10 * 2, Read().Length); + } + + private async Task CheckWriteIsNotCanceled() + { + var flushResult = await Writer.WriteAsync(Encoding.ASCII.GetBytes("data")); + Assert.False(flushResult.IsCanceled); + } + + private void CheckCanceledFlush() + { + PipeWriter writableBuffer = Writer.WriteEmpty(MaximumSizeHigh); + + Writer.CancelPendingFlush(); + + ValueTask flushAsync = writableBuffer.FlushAsync(); + + Assert.True(flushAsync.IsCompleted); + FlushResult flushResult = flushAsync.GetAwaiter().GetResult(); + Assert.True(flushResult.IsCanceled); + } + } + + internal class HangingStream : MemoryStream + { + + public HangingStream() + { + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await Task.Delay(30000, cancellationToken); + } + + public override async Task FlushAsync(CancellationToken cancellationToken) + { + await Task.Delay(30000, cancellationToken); + } + + public override async Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await Task.Delay(30000, cancellationToken); + return 0; + } + } + + internal class SingleWriteStream : MemoryStream + { + private bool _shouldNextWriteFail; + + public bool AllowAllWrites { get; set; } + + +#if NETCOREAPP2_2 + public override async ValueTask WriteAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) + { + try + { + if (_shouldNextWriteFail && !AllowAllWrites) + { + await Task.Delay(30000, cancellationToken); + } + else + { + await base.WriteAsync(source, cancellationToken); + } + } + finally + { + _shouldNextWriteFail = !_shouldNextWriteFail; + } + } +#endif + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + try + { + if (_shouldNextWriteFail && !AllowAllWrites) + { + await Task.Delay(30000, cancellationToken); + } + await base.WriteAsync(buffer, offset, count, cancellationToken); + } + finally + { + _shouldNextWriteFail = !_shouldNextWriteFail; + } + } + } + + internal class CannotFlushStream : MemoryStream + { + public override async Task FlushAsync(CancellationToken cancellationToken) + { + await Task.Delay(30000, cancellationToken); + } + } + + internal static class TestWriterExtensions + { + public static PipeWriter WriteEmpty(this PipeWriter Writer, int count) + { + Writer.GetSpan(count).Slice(0, count).Fill(0); + Writer.Advance(count); + return Writer; + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Tests/TestMemoryPool.cs b/test/Microsoft.AspNetCore.Http.Tests/TestMemoryPool.cs new file mode 100644 index 0000000000..c5dd647dd1 --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/TestMemoryPool.cs @@ -0,0 +1,139 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System; +using System.Buffers; +using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Microsoft.AspNetCore.Http.Tests +{ + public class TestMemoryPool : MemoryPool + { + private MemoryPool _pool = Shared; + + private bool _disposed; + + public override IMemoryOwner Rent(int minBufferSize = -1) + { + CheckDisposed(); + return new PooledMemory(_pool.Rent(minBufferSize), this); + } + + protected override void Dispose(bool disposing) + { + _disposed = true; + } + + public override int MaxBufferSize => 4096; + + internal void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(nameof(TestMemoryPool)); + } + } + + private class PooledMemory : MemoryManager + { + private IMemoryOwner _owner; + + private readonly TestMemoryPool _pool; + + private int _referenceCount; + + private bool _returned; + + private string _leaser; + + public PooledMemory(IMemoryOwner owner, TestMemoryPool pool) + { + _owner = owner; + _pool = pool; + _leaser = Environment.StackTrace; + _referenceCount = 1; + } + + ~PooledMemory() + { + Debug.Assert(_returned, "Block being garbage collected instead of returned to pool" + Environment.NewLine + _leaser); + } + + protected override void Dispose(bool disposing) + { + _pool.CheckDisposed(); + } + + public override MemoryHandle Pin(int elementIndex = 0) + { + _pool.CheckDisposed(); + Interlocked.Increment(ref _referenceCount); + + if (!MemoryMarshal.TryGetArray(_owner.Memory, out ArraySegment segment)) + { + throw new InvalidOperationException(); + } + + unsafe + { + try + { + if ((uint)elementIndex > (uint)segment.Count) + { + throw new ArgumentOutOfRangeException(nameof(elementIndex)); + } + + GCHandle handle = GCHandle.Alloc(segment.Array, GCHandleType.Pinned); + + return new MemoryHandle(Unsafe.Add(((void*)handle.AddrOfPinnedObject()), elementIndex + segment.Offset), handle, this); + } + catch + { + Unpin(); + throw; + } + } + } + + public override void Unpin() + { + _pool.CheckDisposed(); + + int newRefCount = Interlocked.Decrement(ref _referenceCount); + + if (newRefCount < 0) + throw new InvalidOperationException(); + + if (newRefCount == 0) + { + _returned = true; + } + } + + protected override bool TryGetArray(out ArraySegment segment) + { + _pool.CheckDisposed(); + return MemoryMarshal.TryGetArray(_owner.Memory, out segment); + } + + public override Memory Memory + { + get + { + _pool.CheckDisposed(); + return _owner.Memory; + } + } + + public override Span GetSpan() + { + _pool.CheckDisposed(); + return _owner.Memory.Span; + } + } + } +} \ No newline at end of file