Adds MinimumReadThreshold to StreamPipeReader. (#4372)

This commit is contained in:
Justin Kotalik 2018-12-12 13:09:15 -05:00 committed by GitHub
parent 12966c63a6
commit 226f2c0c2c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 181 additions and 55 deletions

View File

@ -20,6 +20,7 @@ namespace Microsoft.AspNetCore.Http
public class StreamPipeReader : PipeReader
{
private readonly int _minimumSegmentSize;
private readonly int _minimumReadThreshold;
private readonly Stream _readingStream;
private readonly MemoryPool<byte> _pool;
@ -35,11 +36,51 @@ namespace Microsoft.AspNetCore.Http
private bool _examinedEverything;
private object _lock = new object();
/// <summary>
/// Creates a new StreamPipeReader.
/// </summary>
/// <param name="readingStream">The stream to read from.</param>
public StreamPipeReader(Stream readingStream)
: this(readingStream, StreamPipeReaderOptions.DefaultOptions)
{
}
/// <summary>
/// Creates a new StreamPipeReader.
/// </summary>
/// <param name="readingStream">The stream to read from.</param>
/// <param name="options">The options to use.</param>
public StreamPipeReader(Stream readingStream, StreamPipeReaderOptions options)
{
_readingStream = readingStream ?? throw new ArgumentNullException(nameof(readingStream));
if (options == null)
{
throw new ArgumentNullException(nameof(options));
}
if (options.MinimumReadThreshold <= 0)
{
throw new ArgumentOutOfRangeException(nameof(options.MinimumReadThreshold));
}
_minimumSegmentSize = options.MinimumSegmentSize;
_minimumReadThreshold = Math.Min(options.MinimumReadThreshold, options.MinimumSegmentSize);
_pool = options.MemoryPool;
}
/// <inheritdoc />
public override void AdvanceTo(SequencePosition consumed)
{
AdvanceTo(consumed, consumed);
}
private CancellationTokenSource InternalTokenSource
{
get
{
lock(_lock)
lock (_lock)
{
if (_internalTokenSource == null)
{
@ -52,34 +93,6 @@ namespace Microsoft.AspNetCore.Http
{
_internalTokenSource = value;
}
}
/// <summary>
/// Creates a new StreamPipeReader.
/// </summary>
/// <param name="readingStream">The stream to read from.</param>
public StreamPipeReader(Stream readingStream) : this(readingStream, minimumSegmentSize: 4096)
{
}
/// <summary>
/// Creates a new StreamPipeReader.
/// </summary>
/// <param name="readingStream">The stream to read from.</param>
/// <param name="minimumSegmentSize">The minimum segment size to return from ReadAsync.</param>
/// <param name="pool"></param>
public StreamPipeReader(Stream readingStream, int minimumSegmentSize, MemoryPool<byte> pool = null)
{
_minimumSegmentSize = minimumSegmentSize;
_readingStream = readingStream;
_pool = pool ?? MemoryPool<byte>.Shared;
}
/// <inheritdoc />
public override void AdvanceTo(SequencePosition consumed)
{
AdvanceTo(consumed, consumed);
}
/// <inheritdoc />
@ -309,7 +322,7 @@ namespace Microsoft.AspNetCore.Http
_readHead.SetMemory(_pool.Rent(GetSegmentSize()));
_readTail = _readHead;
}
else if (_readTail.WritableBytes == 0)
else if (_readTail.WritableBytes < _minimumReadThreshold)
{
CreateNewTailSegment();
}

View File

@ -0,0 +1,31 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.AspNetCore.Http
{
public class StreamPipeReaderOptions
{
public static StreamPipeReaderOptions DefaultOptions = new StreamPipeReaderOptions();
public const int DefaultMinimumSegmentSize = 4096;
public const int DefaultMinimumReadThreshold = 256;
public StreamPipeReaderOptions()
{
}
public StreamPipeReaderOptions(int minimumSegmentSize, int minimumReadThreshold, MemoryPool<byte> memoryPool)
{
MinimumSegmentSize = minimumSegmentSize;
MinimumReadThreshold = minimumReadThreshold;
MemoryPool = memoryPool;
}
public int MinimumSegmentSize { get; set; } = DefaultMinimumSegmentSize;
public int MinimumReadThreshold { get; set; } = DefaultMinimumReadThreshold;
public MemoryPool<byte> MemoryPool { get; set; } = MemoryPool<byte>.Shared;
}
}

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;

View File

@ -23,7 +23,7 @@ namespace Microsoft.AspNetCore.Http.Tests
{
MemoryStream = new MemoryStream();
Writer = new StreamPipeWriter(MemoryStream, MinimumSegmentSize, new TestMemoryPool());
Reader = new StreamPipeReader(MemoryStream, MinimumSegmentSize, new TestMemoryPool());
Reader = new StreamPipeReader(MemoryStream, new StreamPipeReaderOptions(MinimumSegmentSize, minimumReadThreshold: 256, new TestMemoryPool()));
}
public void Dispose()
@ -49,6 +49,13 @@ namespace Microsoft.AspNetCore.Http.Tests
MemoryStream.Write(data, 0, data.Length);
}
public void Append(byte[] data)
{
var originalPosition = MemoryStream.Position;
MemoryStream.Write(data, 0, data.Length);
MemoryStream.Position = originalPosition;
}
public byte[] ReadWithoutFlush()
{
MemoryStream.Position = 0;

View File

@ -58,7 +58,7 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ReadWithAdvance()
{
Write(new byte[10000]);
WriteByteArray(9000);
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.End);
@ -71,8 +71,9 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ReadWithAdvanceDifferentSegmentSize()
{
Reader = new StreamPipeReader(MemoryStream, 4095, new TestMemoryPool());
Write(new byte[10000]);
CreateReader(minimumSegmentSize: 4095);
WriteByteArray(9000);
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.End);
@ -85,8 +86,9 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ReadWithAdvanceSmallSegments()
{
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
Write(new byte[128]);
CreateReader();
WriteByteArray(128);
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.End);
@ -251,8 +253,9 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task AdvanceMultipleSegments()
{
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
Write(new byte[128]);
CreateReader();
WriteByteArray(128);
var result = await Reader.ReadAsync();
Assert.Equal(16, result.Buffer.Length);
@ -269,8 +272,9 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task AdvanceMultipleSegmentsEdgeCase()
{
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
Write(new byte[128]);
CreateReader();
WriteByteArray(128);
var result = await Reader.ReadAsync();
Reader.AdvanceTo(result.Buffer.Start, result.Buffer.End);
@ -288,7 +292,7 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task CompleteReaderWithoutAdvanceDoesNotThrow()
{
Write(new byte[100]);
WriteByteArray(100);
await Reader.ReadAsync();
Reader.Complete();
}
@ -296,7 +300,7 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task AdvanceAfterCompleteThrows()
{
Write(new byte[100]);
WriteByteArray(100);
var buffer = (await Reader.ReadAsync()).Buffer;
Reader.Complete();
@ -309,7 +313,7 @@ namespace Microsoft.AspNetCore.Http.Tests
public async Task ReadBetweenBlocks()
{
var blockSize = 16;
Reader = new StreamPipeReader(MemoryStream, blockSize, new TestMemoryPool());
CreateReader();
WriteWithoutPosition(Enumerable.Repeat((byte)'a', blockSize - 5).ToArray());
Write(Encoding.ASCII.GetBytes("Hello World"));
@ -364,7 +368,7 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public void ReadAsyncWithDataReadyReturnsTaskWithValue()
{
Write(new byte[20]);
WriteByteArray(20);
var task = Reader.ReadAsync();
Assert.True(IsTaskWithResult(task));
}
@ -380,8 +384,9 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task AdvancePastMinReadSizeReadAsyncReturnsMoreData()
{
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
Write(new byte[32]);
CreateReader();
WriteByteArray(32);
var result = await Reader.ReadAsync();
Assert.Equal(16, result.Buffer.Length);
@ -393,7 +398,7 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ExamineEverythingResetsAfterSuccessfulRead()
{
Write(Encoding.ASCII.GetBytes(new string('a', 10000)));
WriteByteArray(10000);
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
@ -408,10 +413,10 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ReadMultipleTimesAdvanceFreesAppropriately()
{
var blockSize = 16;
var pool = new TestMemoryPool();
Reader = new StreamPipeReader(MemoryStream, blockSize, pool);
Write(Encoding.ASCII.GetBytes(new string('a', 10000)));
CreateReader(memoryPool: pool);
WriteByteArray(2000);
for (var i = 0; i < 99; i++)
{
@ -428,8 +433,9 @@ namespace Microsoft.AspNetCore.Http.Tests
public async Task AsyncReadWorks()
{
MemoryStream = new AsyncStream();
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
Write(Encoding.ASCII.GetBytes(new string('a', 10000)));
CreateReader();
WriteByteArray(2000);
for (var i = 0; i < 99; i++)
{
@ -445,7 +451,8 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ConsumePartialBufferWorks()
{
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
CreateReader();
Write(Encoding.ASCII.GetBytes(new string('a', 8)));
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.GetPosition(4), readResult.Buffer.End);
@ -460,7 +467,8 @@ namespace Microsoft.AspNetCore.Http.Tests
[Fact]
public async Task ConsumePartialBufferBetweenMultipleSegmentsWorks()
{
Reader = new StreamPipeReader(MemoryStream, 16, new TestMemoryPool());
CreateReader();
Write(Encoding.ASCII.GetBytes(new string('a', 8)));
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.GetPosition(4), readResult.Buffer.End);
@ -477,11 +485,78 @@ namespace Microsoft.AspNetCore.Http.Tests
Reader.AdvanceTo(readResult.Buffer.End);
}
[Fact]
public async Task SetMinimumReadThresholdSegmentAdvancesCorrectly()
{
CreateReader(minimumReadThreshold: 8);
WriteByteArray(9);
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
AppendByteArray(9);
readResult = await Reader.ReadAsync();
foreach (var segment in readResult.Buffer)
{
Assert.Equal(9, segment.Length);
}
Assert.False(readResult.Buffer.IsSingleSegment);
}
[Fact]
public async Task SetMinimumReadThresholdToMiminumSegmentSizeOnlyGetNewBlockWhenDataIsWritten()
{
CreateReader(minimumReadThreshold: 16);
WriteByteArray(0);
var readResult = await Reader.ReadAsync();
Reader.AdvanceTo(readResult.Buffer.Start, readResult.Buffer.End);
WriteByteArray(16);
readResult = await Reader.ReadAsync();
Assert.Equal(16, readResult.Buffer.Length);
Assert.True(readResult.Buffer.IsSingleSegment);
}
[Fact]
public void SetMinimumReadThresholdOfZeroThrows()
{
Assert.Throws<ArgumentOutOfRangeException>(() => new StreamPipeReader(MemoryStream,
new StreamPipeReaderOptions(minimumSegmentSize: 4096, minimumReadThreshold: 0, new TestMemoryPool())));
}
[Fact]
public void SetOptionsToNullThrows()
{
Assert.Throws<ArgumentNullException>(() => new StreamPipeReader(MemoryStream, null));
}
private void CreateReader(int minimumSegmentSize = 16, int minimumReadThreshold = 4, MemoryPool<byte> memoryPool = null)
{
Reader = new StreamPipeReader(MemoryStream,
new StreamPipeReaderOptions(
minimumSegmentSize,
minimumReadThreshold,
memoryPool ?? new TestMemoryPool()));
}
private bool IsTaskWithResult<T>(ValueTask<T> task)
{
return task == new ValueTask<T>(task.Result);
}
private void WriteByteArray(int size)
{
Write(new byte[size]);
}
private void AppendByteArray(int size)
{
Append(new byte[size]);
}
private class AsyncStream : MemoryStream
{
private static byte[] bytes = Encoding.ASCII.GetBytes("Hello World");