diff --git a/src/Http/WebUtilities/src/FileBufferingReadStream.cs b/src/Http/WebUtilities/src/FileBufferingReadStream.cs index f03fd29edd..4ad9564157 100644 --- a/src/Http/WebUtilities/src/FileBufferingReadStream.cs +++ b/src/Http/WebUtilities/src/FileBufferingReadStream.cs @@ -363,6 +363,11 @@ namespace Microsoft.AspNetCore.WebUtilities public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) { + // Set a minimum buffer size of 4K since the base Stream implementation has weird behavior when the stream is + // seekable *and* the length is 0 (it passes in a buffer size of 1). + // See https://github.com/dotnet/runtime/blob/222415c56c9ea73530444768c0e68413eb374f5d/src/libraries/System.Private.CoreLib/src/System/IO/Stream.cs#L164-L184 + bufferSize = Math.Max(4096, bufferSize); + // If we're completed buffered then copy from the underlying source if (_completelyBuffered) { @@ -372,7 +377,7 @@ namespace Microsoft.AspNetCore.WebUtilities async Task CopyToAsyncImpl() { // At least a 4K buffer - byte[] buffer = _bytePool.Rent(Math.Min(bufferSize, 4096)); + byte[] buffer = _bytePool.Rent(bufferSize); try { while (true) diff --git a/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs b/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs index e00292acc1..ce52c78a28 100644 --- a/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs +++ b/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs @@ -5,7 +5,6 @@ using System; using System.Buffers; using System.IO; using System.Linq; -using System.Text; using System.Threading.Tasks; using Moq; using Xunit; @@ -353,39 +352,49 @@ namespace Microsoft.AspNetCore.WebUtilities [Fact] public async Task CopyToAsyncWorks() { - var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray(); + // 4K is the lower bound on buffer sizes + var bufferSize = 4096; + var mostExpectedWrites = 8; + var data = Enumerable.Range(0, bufferSize * mostExpectedWrites).Select(b => (byte)b).ToArray(); var inner = new MemoryStream(data); using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory()); - var withoutBufferMs = new MemoryStream(); + var withoutBufferMs = new NumberOfWritesMemoryStream(); await stream.CopyToAsync(withoutBufferMs); - var withBufferMs = new MemoryStream(); + var withBufferMs = new NumberOfWritesMemoryStream(); stream.Position = 0; await stream.CopyToAsync(withBufferMs); Assert.Equal(data, withoutBufferMs.ToArray()); + Assert.Equal(mostExpectedWrites, withoutBufferMs.NumberOfWrites); Assert.Equal(data, withBufferMs.ToArray()); + Assert.InRange(withBufferMs.NumberOfWrites, 1, mostExpectedWrites); } [Fact] public async Task CopyToAsyncWorksWithFileThreshold() { - var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray(); + // 4K is the lower bound on buffer sizes + var bufferSize = 4096; + var mostExpectedWrites = 8; + var data = Enumerable.Range(0, bufferSize * mostExpectedWrites).Select(b => (byte)b).Reverse().ToArray(); var inner = new MemoryStream(data); using var stream = new FileBufferingReadStream(inner, 100, bufferLimit: null, GetCurrentDirectory()); - var withoutBufferMs = new MemoryStream(); + var withoutBufferMs = new NumberOfWritesMemoryStream(); await stream.CopyToAsync(withoutBufferMs); - var withBufferMs = new MemoryStream(); + var withBufferMs = new NumberOfWritesMemoryStream(); stream.Position = 0; await stream.CopyToAsync(withBufferMs); Assert.Equal(data, withoutBufferMs.ToArray()); + Assert.Equal(mostExpectedWrites, withoutBufferMs.NumberOfWrites); Assert.Equal(data, withBufferMs.ToArray()); + Assert.InRange(withBufferMs.NumberOfWrites, 1, mostExpectedWrites); } [Fact] @@ -486,5 +495,22 @@ namespace Microsoft.AspNetCore.WebUtilities { return AppContext.BaseDirectory; } + + private class NumberOfWritesMemoryStream : MemoryStream + { + public int NumberOfWrites { get; set; } + + public override void Write(byte[] buffer, int offset, int count) + { + NumberOfWrites++; + base.Write(buffer, offset, count); + } + + public override void Write(ReadOnlySpan source) + { + NumberOfWrites++; + base.Write(source); + } + } } }