Adds PipeWriterAdapter (#1065)

This commit is contained in:
Justin Kotalik 2018-11-16 19:18:47 -08:00 committed by GitHub
parent 49d785c934
commit 962ec07bdb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 1266 additions and 2 deletions

1
.gitignore vendored
View File

@ -30,3 +30,4 @@ project.lock.json
/.vs/
.vscode/
global.json
BenchmarkDotNet.Artifacts/

View File

@ -0,0 +1,89 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
namespace Microsoft.AspNetCore.Http
{
public class StreamPipeWriterBenchmark
{
private Stream _memoryStream;
private StreamPipeWriter _pipeWriter;
private static byte[] _helloWorldBytes = Encoding.ASCII.GetBytes("Hello World");
private static byte[] _largeWrite = Encoding.ASCII.GetBytes(new string('a', 50000));
[IterationSetup]
public void Setup()
{
_memoryStream = new NoopStream();
_pipeWriter = new StreamPipeWriter(_memoryStream);
}
[Benchmark]
public async Task WriteHelloWorld()
{
await _pipeWriter.WriteAsync(_helloWorldBytes);
}
[Benchmark]
public async Task WriteHelloWorldLargeWrite()
{
await _pipeWriter.WriteAsync(_largeWrite);
}
public class NoopStream : Stream
{
public override bool CanRead => false;
public override bool CanSeek => throw new System.NotImplementedException();
public override bool CanWrite => true;
public override long Length => throw new System.NotImplementedException();
public override long Position { get => throw new System.NotImplementedException(); set => throw new System.NotImplementedException(); }
public override void Flush()
{
}
public override int Read(byte[] buffer, int offset, int count)
{
throw new System.NotImplementedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new System.NotImplementedException();
}
public override void SetLength(long value)
{
}
public override void Write(byte[] buffer, int offset, int count)
{
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
{
return default(ValueTask);
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
return Task.CompletedTask;
}
}
}
}

View File

@ -24,6 +24,7 @@
<MoqPackageVersion>4.9.0</MoqPackageVersion>
<NETStandardLibrary20PackageVersion>2.0.3</NETStandardLibrary20PackageVersion>
<SystemBuffersPackageVersion>4.6.0-preview1-26907-04</SystemBuffersPackageVersion>
<SystemIOPipelinesPackageVersion>4.6.0-preview1-26907-04</SystemIOPipelinesPackageVersion>
<SystemTextEncodingsWebPackageVersion>4.6.0-preview1-26907-04</SystemTextEncodingsWebPackageVersion>
<XunitAnalyzersPackageVersion>0.10.0</XunitAnalyzersPackageVersion>
<XunitPackageVersion>2.3.1</XunitPackageVersion>

View File

@ -2,7 +2,7 @@
<PropertyGroup>
<Description>ASP.NET Core default HTTP feature implementations.</Description>
<TargetFramework>netstandard2.0</TargetFramework>
<TargetFrameworks>netstandard2.0;netcoreapp2.2</TargetFrameworks>
<NoWarn>$(NoWarn);CS1591</NoWarn>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
@ -19,6 +19,7 @@
<PackageReference Include="Microsoft.Extensions.CopyOnWriteDictionary.Sources" PrivateAssets="All" Version="$(MicrosoftExtensionsCopyOnWriteDictionarySourcesPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.ObjectPool" Version="$(MicrosoftExtensionsObjectPoolPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Options" Version="$(MicrosoftExtensionsOptionsPackageVersion)" />
<PackageReference Include="System.IO.Pipelines" Version="$(SystemIOPipelinesPackageVersion)" />
</ItemGroup>
</Project>

View File

@ -0,0 +1,320 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Http
{
/// <summary>
/// Implements PipeWriter using a underlying stream.
/// </summary>
public class StreamPipeWriter : PipeWriter, IDisposable
{
private readonly int _minimumSegmentSize;
private readonly Stream _writingStream;
private int _bytesWritten;
private List<CompletedBuffer> _completedSegments;
private Memory<byte> _currentSegment;
private IMemoryOwner<byte> _currentSegmentOwner;
private MemoryPool<byte> _pool;
private int _position;
private CancellationTokenSource _internalTokenSource;
private bool _isCompleted;
private ExceptionDispatchInfo _exceptionInfo;
private object _lockObject = new object();
private CancellationTokenSource InternalTokenSource
{
get
{
lock (_lockObject)
{
if (_internalTokenSource == null)
{
_internalTokenSource = new CancellationTokenSource();
}
return _internalTokenSource;
}
}
}
/// <summary>
/// Creates a new StreamPipeWrapper
/// </summary>
/// <param name="writingStream">The stream to write to</param>
public StreamPipeWriter(Stream writingStream) : this(writingStream, 4096)
{
}
public StreamPipeWriter(Stream writingStream, int minimumSegmentSize, MemoryPool<byte> pool = null)
{
_minimumSegmentSize = minimumSegmentSize;
_writingStream = writingStream;
_pool = pool ?? MemoryPool<byte>.Shared;
}
/// <inheritdoc />
public override void Advance(int count)
{
if (_currentSegment.IsEmpty) // TODO confirm this
{
throw new InvalidOperationException("No writing operation. Make sure GetMemory() was called.");
}
if (count >= 0)
{
if (_currentSegment.Length < _position + count)
{
throw new InvalidOperationException("Can't advance past buffer size.");
}
_bytesWritten += count;
_position += count;
}
}
/// <inheritdoc />
public override Memory<byte> GetMemory(int sizeHint = 0)
{
EnsureCapacity(sizeHint);
return _currentSegment;
}
/// <inheritdoc />
public override Span<byte> GetSpan(int sizeHint = 0)
{
EnsureCapacity(sizeHint);
return _currentSegment.Span.Slice(_position);
}
/// <inheritdoc />
public override void CancelPendingFlush()
{
Cancel();
}
/// <inheritdoc />
public override void Complete(Exception exception = null)
{
if (_isCompleted)
{
return;
}
_isCompleted = true;
if (exception != null)
{
_exceptionInfo = ExceptionDispatchInfo.Capture(exception);
}
_internalTokenSource?.Dispose();
if (_completedSegments != null)
{
foreach (var segment in _completedSegments)
{
segment.Return();
}
}
_currentSegmentOwner?.Dispose();
}
/// <inheritdoc />
public override void OnReaderCompleted(Action<Exception, object> callback, object state)
{
throw new NotSupportedException("OnReaderCompleted isn't supported in StreamPipeWrapper.");
}
/// <inheritdoc />
public override ValueTask<FlushResult> FlushAsync(CancellationToken cancellationToken = default)
{
if (_bytesWritten == 0)
{
return new ValueTask<FlushResult>(new FlushResult(isCanceled: false, IsCompletedOrThrow()));
}
return FlushAsyncInternal(cancellationToken);
}
private void Cancel()
{
InternalTokenSource.Cancel();
}
private async ValueTask<FlushResult> FlushAsyncInternal(CancellationToken cancellationToken = default)
{
// Write all completed segments and whatever remains in the current segment
// and flush the result.
CancellationTokenRegistration reg = new CancellationTokenRegistration();
if (cancellationToken.CanBeCanceled)
{
reg = cancellationToken.Register(state => ((StreamPipeWriter)state).Cancel(), this);
}
using (reg)
{
var localToken = InternalTokenSource.Token;
try
{
if (_completedSegments != null && _completedSegments.Count > 0)
{
var count = _completedSegments.Count;
for (var i = 0; i < count; i++)
{
var segment = _completedSegments[0];
#if NETCOREAPP2_2
await _writingStream.WriteAsync(segment.Buffer.Slice(0, segment.Length), localToken);
#elif NETSTANDARD2_0
MemoryMarshal.TryGetArray<byte>(segment.Buffer, out var arraySegment);
await _writingStream.WriteAsync(arraySegment.Array, 0, segment.Length, localToken);
#else
#error Target frameworks need to be updated.
#endif
_bytesWritten -= segment.Length;
segment.Return();
_completedSegments.RemoveAt(0);
}
}
if (!_currentSegment.IsEmpty)
{
#if NETCOREAPP2_2
await _writingStream.WriteAsync(_currentSegment.Slice(0, _position), localToken);
#elif NETSTANDARD2_0
MemoryMarshal.TryGetArray<byte>(_currentSegment, out var arraySegment);
await _writingStream.WriteAsync(arraySegment.Array, 0, _position, localToken);
#else
#error Target frameworks need to be updated.
#endif
_bytesWritten -= _position;
_position = 0;
}
await _writingStream.FlushAsync(localToken);
return new FlushResult(isCanceled: false, IsCompletedOrThrow());
}
catch (OperationCanceledException)
{
// Remove the cancellation token such that the next time Flush is called
// A new CTS is created.
lock (_lockObject)
{
_internalTokenSource = null;
}
if (cancellationToken.IsCancellationRequested)
{
throw;
}
// Catch any cancellation and translate it into setting isCanceled = true
return new FlushResult(isCanceled: true, IsCompletedOrThrow());
}
}
}
private void EnsureCapacity(int sizeHint)
{
// This does the Right Thing. It only subtracts _position from the current segment length if it's non-null.
// If _currentSegment is null, it returns 0.
var remainingSize = _currentSegment.Length - _position;
// If the sizeHint is 0, any capacity will do
// Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment.
if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint))
{
// We have capacity in the current segment
return;
}
AddSegment(sizeHint);
}
private void AddSegment(int sizeHint = 0)
{
if (_currentSegment.Length != 0)
{
// We're adding a segment to the list
if (_completedSegments == null)
{
_completedSegments = new List<CompletedBuffer>();
}
// Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when
// GetMemory was called. In that case we'll take the current segment and call it "completed", but need to
// ignore any empty space in it.
_completedSegments.Add(new CompletedBuffer(_currentSegmentOwner, _position));
}
// Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment.
_currentSegmentOwner = _pool.Rent(Math.Max(_minimumSegmentSize, sizeHint));
_currentSegment = _currentSegmentOwner.Memory;
_position = 0;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private bool IsCompletedOrThrow()
{
if (!_isCompleted)
{
return false;
}
if (_exceptionInfo != null)
{
ThrowLatchedException();
}
return true;
}
[MethodImpl(MethodImplOptions.NoInlining)]
private void ThrowLatchedException()
{
_exceptionInfo.Throw();
}
public void Dispose()
{
Complete();
}
/// <summary>
/// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it.
/// </summary>
private readonly struct CompletedBuffer
{
public Memory<byte> Buffer { get; }
public int Length { get; }
public ReadOnlySpan<byte> Span => Buffer.Span;
private readonly IMemoryOwner<byte> _memoryOwner;
public CompletedBuffer(IMemoryOwner<byte> buffer, int length)
{
Buffer = buffer.Memory;
Length = length;
_memoryOwner = buffer;
}
public void Return()
{
_memoryOwner.Dispose();
}
}
}
}

View File

@ -0,0 +1,68 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNetCore.Http.Tests
{
public class FlushResultCancellationTests : PipeTest
{
[Fact]
public void FlushAsyncCancellationDeadlock()
{
var cts = new CancellationTokenSource();
var cts2 = new CancellationTokenSource();
PipeWriter buffer = Writer.WriteEmpty(MaximumSizeHigh);
var e = new ManualResetEventSlim();
ValueTaskAwaiter<FlushResult> awaiter = buffer.FlushAsync(cts.Token).GetAwaiter();
awaiter.OnCompleted(
() => {
// We are on cancellation thread and need to wait until another FlushAsync call
// takes pipe state lock
e.Wait();
// Make sure we had enough time to reach _cancellationTokenRegistration.Dispose
Thread.Sleep(100);
// Try to take pipe state lock
buffer.FlushAsync();
});
// Start a thread that would run cancellation callbacks
Task cancellationTask = Task.Run(() => cts.Cancel());
// Start a thread that would call FlushAsync with different token
// and block on _cancellationTokenRegistration.Dispose
Task blockingTask = Task.Run(
() => {
e.Set();
buffer.FlushAsync(cts2.Token);
});
bool completed = Task.WhenAll(cancellationTask, blockingTask).Wait(TimeSpan.FromSeconds(10));
Assert.True(completed);
}
[Fact]
public async Task FlushAsyncWithNewCancellationTokenNotAffectedByPrevious()
{
var cancellationTokenSource1 = new CancellationTokenSource();
PipeWriter buffer = Writer.WriteEmpty(10);
await buffer.FlushAsync(cancellationTokenSource1.Token);
cancellationTokenSource1.Cancel();
var cancellationTokenSource2 = new CancellationTokenSource();
buffer = Writer.WriteEmpty(10);
await buffer.FlushAsync(cancellationTokenSource2.Token);
}
}
}

View File

@ -2,8 +2,9 @@
<PropertyGroup>
<TargetFrameworks>$(StandardTestTfms)</TargetFrameworks>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\..\src\Microsoft.AspNetCore.Http\Microsoft.AspNetCore.Http.csproj" />
</ItemGroup>

View File

@ -0,0 +1,43 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.IO.Pipelines;
namespace Microsoft.AspNetCore.Http.Tests
{
public abstract class PipeTest : IDisposable
{
protected const int MaximumSizeHigh = 65;
public MemoryStream MemoryStream { get; set; }
public PipeWriter Writer { get; set; }
protected PipeTest()
{
MemoryStream = new MemoryStream();
Writer = new StreamPipeWriter(MemoryStream, 4096, new TestMemoryPool());
}
public void Dispose()
{
Writer.Complete();
}
public byte[] Read()
{
Writer.FlushAsync().GetAwaiter().GetResult();
return ReadWithoutFlush();
}
public byte[] ReadWithoutFlush()
{
MemoryStream.Position = 0;
var buffer = new byte[MemoryStream.Length];
var result = MemoryStream.Read(buffer, 0, (int)MemoryStream.Length);
return buffer;
}
}
}

View File

@ -0,0 +1,221 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNetCore.Http.Tests
{
public class PipeWriterTests : PipeTest
{
[Theory]
[InlineData(3, -1, 0)]
[InlineData(3, 0, -1)]
[InlineData(3, 0, 4)]
[InlineData(3, 4, 0)]
[InlineData(3, -1, -1)]
[InlineData(3, 4, 4)]
public void ThrowsForInvalidParameters(int arrayLength, int offset, int length)
{
var array = new byte[arrayLength];
for (var i = 0; i < array.Length; i++)
{
array[i] = (byte)(i + 1);
}
Writer.Write(new Span<byte>(array, 0, 0));
Writer.Write(new Span<byte>(array, array.Length, 0));
try
{
Writer.Write(new Span<byte>(array, offset, length));
Assert.True(false);
}
catch (Exception ex)
{
Assert.True(ex is ArgumentOutOfRangeException);
}
Writer.Write(new Span<byte>(array, 0, array.Length));
Assert.Equal(array, Read());
}
[Theory]
[InlineData(0, 3)]
[InlineData(1, 2)]
[InlineData(2, 1)]
[InlineData(1, 1)]
public void CanWriteWithOffsetAndLength(int offset, int length)
{
var array = new byte[] { 1, 2, 3 };
Writer.Write(new Span<byte>(array, offset, length));
Assert.Equal(array.Skip(offset).Take(length).ToArray(), Read());
}
[Fact]
public void CanWriteIntoHeadlessBuffer()
{
Writer.Write(new byte[] { 1, 2, 3 });
Assert.Equal(new byte[] { 1, 2, 3 }, Read());
}
[Fact]
public void CanGetNewMemoryWhenSizeTooLarge()
{
var memory = Writer.GetMemory(0);
var memoryLarge = Writer.GetMemory(10000);
Assert.NotEqual(memory, memoryLarge);
}
[Fact]
public void CanGetSameMemoryWhenNoAdvance()
{
var memory = Writer.GetMemory(0);
var secondMemory = Writer.GetMemory(0);
Assert.Equal(memory, secondMemory);
}
[Fact]
public void CanGetNewSpanWhenNoAdvanceWhenSizeTooLarge()
{
var span = Writer.GetSpan(0);
var secondSpan = Writer.GetSpan(8000);
Assert.False(span.SequenceEqual(secondSpan));
}
[Fact]
public void CanGetSameSpanWhenNoAdvance()
{
var span = Writer.GetSpan(0);
var secondSpan = Writer.GetSpan(0);
Assert.True(span.SequenceEqual(secondSpan));
}
[Theory]
[InlineData(16, 32, 32)]
[InlineData(16, 16, 16)]
[InlineData(64, 32, 64)]
[InlineData(40, 32, 64)] // memory sizes are powers of 2.
public void CheckMinimumSegmentSizeWithGetMemory(int minimumSegmentSize, int getMemorySize, int expectedSize)
{
var writer = new StreamPipeWriter(new MemoryStream(), minimumSegmentSize);
var memory = writer.GetMemory(getMemorySize);
Assert.Equal(expectedSize, memory.Length);
}
[Fact]
public void CanWriteMultipleTimes()
{
Writer.Write(new byte[] { 1 });
Writer.Write(new byte[] { 2 });
Writer.Write(new byte[] { 3 });
Assert.Equal(new byte[] { 1, 2, 3 }, Read());
}
[Fact]
public void CanWriteOverTheBlockLength()
{
Memory<byte> memory = Writer.GetMemory();
IEnumerable<byte> source = Enumerable.Range(0, memory.Length).Select(i => (byte)i);
byte[] expectedBytes = source.Concat(source).Concat(source).ToArray();
Writer.Write(expectedBytes);
Assert.Equal(expectedBytes, Read());
}
[Fact]
public void EnsureAllocatesSpan()
{
var span = Writer.GetSpan(10);
Assert.True(span.Length >= 10);
// 0 byte Flush would not complete the reader so we complete.
Writer.Complete();
Assert.Equal(new byte[] { }, Read());
}
[Fact]
public void SlicesSpanAndAdvancesAfterWrite()
{
int initialLength = Writer.GetSpan(3).Length;
Writer.Write(new byte[] { 1, 2, 3 });
Span<byte> span = Writer.GetSpan();
Assert.Equal(initialLength - 3, span.Length);
Assert.Equal(new byte[] { 1, 2, 3 }, Read());
}
[Theory]
[InlineData(5)]
[InlineData(50)]
[InlineData(500)]
[InlineData(5000)]
[InlineData(50000)]
public async Task WriteLargeDataBinary(int length)
{
var data = new byte[length];
new Random(length).NextBytes(data);
PipeWriter output = Writer;
output.Write(data);
await output.FlushAsync();
var input = Read();
Assert.Equal(data, input.ToArray());
}
[Fact]
public async Task CanWriteNothingToBuffer()
{
Writer.GetMemory(0);
Writer.Advance(0); // doing nothing, the hard way
await Writer.FlushAsync();
}
[Fact]
public void EmptyWriteDoesNotThrow()
{
Writer.Write(new byte[0]);
}
[Fact]
public void ThrowsOnAdvanceOverMemorySize()
{
Memory<byte> buffer = Writer.GetMemory(1);
var exception = Assert.Throws<InvalidOperationException>(() => Writer.Advance(buffer.Length + 1));
Assert.Equal("Can't advance past buffer size.", exception.Message);
}
[Fact]
public void ThrowsOnAdvanceWithNoMemory()
{
PipeWriter buffer = Writer;
var exception = Assert.Throws<InvalidOperationException>(() => buffer.Advance(1));
Assert.Equal("No writing operation. Make sure GetMemory() was called.", exception.Message);
}
}
}

View File

@ -0,0 +1,380 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.IO;
using System.IO.Pipelines;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNetCore.Http.Tests
{
public class StreamPipeWriterTests : PipeTest
{
[Fact]
public async Task CanWriteAsyncMultipleTimesIntoSameBlock()
{
await Writer.WriteAsync(new byte[] { 1 });
await Writer.WriteAsync(new byte[] { 2 });
await Writer.WriteAsync(new byte[] { 3 });
Assert.Equal(new byte[] { 1, 2, 3 }, Read());
}
[Theory]
[InlineData(100, 1000)]
[InlineData(100, 8000)]
[InlineData(100, 10000)]
[InlineData(8000, 100)]
[InlineData(8000, 8000)]
public async Task CanAdvanceWithPartialConsumptionOfFirstSegment(int firstWriteLength, int secondWriteLength)
{
await Writer.WriteAsync(Encoding.ASCII.GetBytes("a"));
var expectedLength = firstWriteLength + secondWriteLength + 1;
var memory = Writer.GetMemory(firstWriteLength);
Writer.Advance(firstWriteLength);
memory = Writer.GetMemory(secondWriteLength);
Writer.Advance(secondWriteLength);
await Writer.FlushAsync();
Assert.Equal(expectedLength, Read().Length);
}
[Fact]
public async Task ThrowsOnCompleteAndWrite()
{
Writer.Complete(new InvalidOperationException("Whoops"));
var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await Writer.FlushAsync());
Assert.Equal("Whoops", exception.Message);
}
[Fact]
public async Task WriteCanBeCancelledViaProvidedCancellationToken()
{
var pipeWriter = new StreamPipeWriter(new HangingStream());
var cts = new CancellationTokenSource(1);
await Assert.ThrowsAsync<TaskCanceledException>(async () => await pipeWriter.WriteAsync(Encoding.ASCII.GetBytes("data"), cts.Token));
}
[Fact]
public async Task WriteCanBeCanceledViaCancelPendingFlushWhenFlushIsAsync()
{
var pipeWriter = new StreamPipeWriter(new HangingStream());
FlushResult flushResult = new FlushResult();
var tcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var task = Task.Run(async () =>
{
try
{
var writingTask = pipeWriter.WriteAsync(Encoding.ASCII.GetBytes("data"));
tcs.SetResult(0);
flushResult = await writingTask;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
throw ex;
}
});
await tcs.Task;
pipeWriter.CancelPendingFlush();
await task;
Assert.True(flushResult.IsCanceled);
}
[Fact]
public void FlushAsyncCompletedAfterPreCancellation()
{
PipeWriter writableBuffer = Writer.WriteEmpty(1);
Writer.CancelPendingFlush();
ValueTask<FlushResult> flushAsync = writableBuffer.FlushAsync();
Assert.True(flushAsync.IsCompleted);
FlushResult flushResult = flushAsync.GetAwaiter().GetResult();
Assert.True(flushResult.IsCanceled);
flushAsync = writableBuffer.FlushAsync();
Assert.True(flushAsync.IsCompleted);
}
[Fact]
public void FlushAsyncReturnsCanceledIfCanceledBeforeFlush()
{
CheckCanceledFlush();
}
[Fact]
public void FlushAsyncReturnsCanceledIfCanceledBeforeFlushMultipleTimes()
{
for (var i = 0; i < 10; i++)
{
CheckCanceledFlush();
}
}
[Fact]
public async Task FlushAsyncReturnsCanceledInterleaved()
{
for (var i = 0; i < 5; i++)
{
CheckCanceledFlush();
await CheckWriteIsNotCanceled();
}
}
[Fact]
public async Task CancelPendingFlushBetweenWritesAllDataIsPreserved()
{
MemoryStream = new SingleWriteStream();
Writer = new StreamPipeWriter(MemoryStream);
FlushResult flushResult = new FlushResult();
var tcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var task = Task.Run(async () =>
{
try
{
await Writer.WriteAsync(Encoding.ASCII.GetBytes("data"));
var writingTask = Writer.WriteAsync(Encoding.ASCII.GetBytes(" data"));
tcs.SetResult(0);
flushResult = await writingTask;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
throw ex;
}
});
await tcs.Task;
Writer.CancelPendingFlush();
await task;
Assert.True(flushResult.IsCanceled);
await Writer.WriteAsync(Encoding.ASCII.GetBytes(" more data"));
Assert.Equal(Encoding.ASCII.GetBytes("data data more data"), Read());
}
[Fact]
public async Task CancelPendingFlushAfterAllWritesAllDataIsPreserved()
{
MemoryStream = new CannotFlushStream();
Writer = new StreamPipeWriter(MemoryStream);
FlushResult flushResult = new FlushResult();
var tcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
var task = Task.Run(async () =>
{
try
{
// Create two Segments
// First one will succeed to write, other one will hang.
var writingTask = Writer.WriteAsync(Encoding.ASCII.GetBytes("data"));
tcs.SetResult(0);
flushResult = await writingTask;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
throw ex;
}
});
await tcs.Task;
Writer.CancelPendingFlush();
await task;
Assert.True(flushResult.IsCanceled);
}
[Fact]
public async Task CancelPendingFlushLostOfCancellationsNoDataLost()
{
var writeSize = 16;
var singleWriteStream = new SingleWriteStream();
MemoryStream = singleWriteStream;
Writer = new StreamPipeWriter(MemoryStream, minimumSegmentSize: writeSize);
for (var i = 0; i < 10; i++)
{
FlushResult flushResult = new FlushResult();
var expectedData = Encoding.ASCII.GetBytes(new string('a', writeSize));
var tcs = new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously);
// TaskCreationOptions.RunAsync
var task = Task.Run(async () =>
{
try
{
// Create two Segments
// First one will succeed to write, other one will hang.
for (var j = 0; j < 2; j++)
{
Writer.Write(expectedData);
}
var flushTask = Writer.FlushAsync();
tcs.SetResult(0);
flushResult = await flushTask;
}
catch (Exception ex)
{
Console.WriteLine(ex.Message);
throw ex;
}
});
await tcs.Task;
Writer.CancelPendingFlush();
await task;
Assert.True(flushResult.IsCanceled);
}
// Only half of the data was written because every other flush failed.
Assert.Equal(16 * 10, ReadWithoutFlush().Length);
// Start allowing all writes to make read succeed.
singleWriteStream.AllowAllWrites = true;
Assert.Equal(16 * 10 * 2, Read().Length);
}
private async Task CheckWriteIsNotCanceled()
{
var flushResult = await Writer.WriteAsync(Encoding.ASCII.GetBytes("data"));
Assert.False(flushResult.IsCanceled);
}
private void CheckCanceledFlush()
{
PipeWriter writableBuffer = Writer.WriteEmpty(MaximumSizeHigh);
Writer.CancelPendingFlush();
ValueTask<FlushResult> flushAsync = writableBuffer.FlushAsync();
Assert.True(flushAsync.IsCompleted);
FlushResult flushResult = flushAsync.GetAwaiter().GetResult();
Assert.True(flushResult.IsCanceled);
}
}
internal class HangingStream : MemoryStream
{
public HangingStream()
{
}
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await Task.Delay(30000, cancellationToken);
}
public override async Task FlushAsync(CancellationToken cancellationToken)
{
await Task.Delay(30000, cancellationToken);
}
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await Task.Delay(30000, cancellationToken);
return 0;
}
}
internal class SingleWriteStream : MemoryStream
{
private bool _shouldNextWriteFail;
public bool AllowAllWrites { get; set; }
#if NETCOREAPP2_2
public override async ValueTask WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
{
try
{
if (_shouldNextWriteFail && !AllowAllWrites)
{
await Task.Delay(30000, cancellationToken);
}
else
{
await base.WriteAsync(source, cancellationToken);
}
}
finally
{
_shouldNextWriteFail = !_shouldNextWriteFail;
}
}
#endif
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
try
{
if (_shouldNextWriteFail && !AllowAllWrites)
{
await Task.Delay(30000, cancellationToken);
}
await base.WriteAsync(buffer, offset, count, cancellationToken);
}
finally
{
_shouldNextWriteFail = !_shouldNextWriteFail;
}
}
}
internal class CannotFlushStream : MemoryStream
{
public override async Task FlushAsync(CancellationToken cancellationToken)
{
await Task.Delay(30000, cancellationToken);
}
}
internal static class TestWriterExtensions
{
public static PipeWriter WriteEmpty(this PipeWriter Writer, int count)
{
Writer.GetSpan(count).Slice(0, count).Fill(0);
Writer.Advance(count);
return Writer;
}
}
}

View File

@ -0,0 +1,139 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Threading;
namespace Microsoft.AspNetCore.Http.Tests
{
public class TestMemoryPool : MemoryPool<byte>
{
private MemoryPool<byte> _pool = Shared;
private bool _disposed;
public override IMemoryOwner<byte> Rent(int minBufferSize = -1)
{
CheckDisposed();
return new PooledMemory(_pool.Rent(minBufferSize), this);
}
protected override void Dispose(bool disposing)
{
_disposed = true;
}
public override int MaxBufferSize => 4096;
internal void CheckDisposed()
{
if (_disposed)
{
throw new ObjectDisposedException(nameof(TestMemoryPool));
}
}
private class PooledMemory : MemoryManager<byte>
{
private IMemoryOwner<byte> _owner;
private readonly TestMemoryPool _pool;
private int _referenceCount;
private bool _returned;
private string _leaser;
public PooledMemory(IMemoryOwner<byte> owner, TestMemoryPool pool)
{
_owner = owner;
_pool = pool;
_leaser = Environment.StackTrace;
_referenceCount = 1;
}
~PooledMemory()
{
Debug.Assert(_returned, "Block being garbage collected instead of returned to pool" + Environment.NewLine + _leaser);
}
protected override void Dispose(bool disposing)
{
_pool.CheckDisposed();
}
public override MemoryHandle Pin(int elementIndex = 0)
{
_pool.CheckDisposed();
Interlocked.Increment(ref _referenceCount);
if (!MemoryMarshal.TryGetArray(_owner.Memory, out ArraySegment<byte> segment))
{
throw new InvalidOperationException();
}
unsafe
{
try
{
if ((uint)elementIndex > (uint)segment.Count)
{
throw new ArgumentOutOfRangeException(nameof(elementIndex));
}
GCHandle handle = GCHandle.Alloc(segment.Array, GCHandleType.Pinned);
return new MemoryHandle(Unsafe.Add<byte>(((void*)handle.AddrOfPinnedObject()), elementIndex + segment.Offset), handle, this);
}
catch
{
Unpin();
throw;
}
}
}
public override void Unpin()
{
_pool.CheckDisposed();
int newRefCount = Interlocked.Decrement(ref _referenceCount);
if (newRefCount < 0)
throw new InvalidOperationException();
if (newRefCount == 0)
{
_returned = true;
}
}
protected override bool TryGetArray(out ArraySegment<byte> segment)
{
_pool.CheckDisposed();
return MemoryMarshal.TryGetArray(_owner.Memory, out segment);
}
public override Memory<byte> Memory
{
get
{
_pool.CheckDisposed();
return _owner.Memory;
}
}
public override Span<byte> GetSpan()
{
_pool.CheckDisposed();
return _owner.Memory.Span;
}
}
}
}