Implement CopyToAsync in the FileBufferingReadStream (#24499)

* Implement CopyToAsync in the FileBufferingReadStream
- overrride Span and Memory overloads and implement array overloads in terms of those overloads.
- Implemented CopyToAsync (but not CopyTo)
- Added tests

Fixes #24032
This commit is contained in:
David Fowler 2020-08-03 14:14:42 -07:00 committed by GitHub
parent e31998c94b
commit 1f5149a663
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 268 additions and 34 deletions

View File

@ -208,39 +208,41 @@ namespace Microsoft.AspNetCore.WebUtilities
FileOptions.Asynchronous | FileOptions.DeleteOnClose | FileOptions.SequentialScan);
}
public override int Read(byte[] buffer, int offset, int count)
public override int Read(Span<byte> buffer)
{
ThrowIfDisposed();
if (_buffer.Position < _buffer.Length || _completelyBuffered)
{
// Just read from the buffer
return _buffer.Read(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position));
return _buffer.Read(buffer);
}
int read = _inner.Read(buffer, offset, count);
var read = _inner.Read(buffer);
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
{
Dispose();
throw new IOException("Buffer limit exceeded.");
}
if (_inMemory && _buffer.Length + read > _memoryThreshold)
// We're about to go over the threshold, switch to a file
if (_inMemory && _memoryThreshold - read < _buffer.Length)
{
_inMemory = false;
var oldBuffer = _buffer;
_buffer = CreateTempFile();
if (_rentedBuffer == null)
{
// Copy data from the in memory buffer to the file stream using a pooled buffer
oldBuffer.Position = 0;
var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize));
try
{
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
var copyRead = oldBuffer.Read(rentedBuffer);
while (copyRead > 0)
{
_buffer.Write(rentedBuffer, 0, copyRead);
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
_buffer.Write(rentedBuffer.AsSpan(0, copyRead));
copyRead = oldBuffer.Read(rentedBuffer);
}
}
finally
@ -250,7 +252,7 @@ namespace Microsoft.AspNetCore.WebUtilities
}
else
{
_buffer.Write(_rentedBuffer, 0, (int)oldBuffer.Length);
_buffer.Write(_rentedBuffer.AsSpan(0, (int)oldBuffer.Length));
_bytePool.Return(_rentedBuffer);
_rentedBuffer = null;
}
@ -258,7 +260,7 @@ namespace Microsoft.AspNetCore.WebUtilities
if (read > 0)
{
_buffer.Write(buffer, offset, read);
_buffer.Write(buffer.Slice(0, read));
}
else
{
@ -268,24 +270,34 @@ namespace Microsoft.AspNetCore.WebUtilities
return read;
}
public override async Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
public override int Read(byte[] buffer, int offset, int count)
{
return Read(buffer.AsSpan(offset, count));
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return ReadAsync(buffer.AsMemory(offset, count), cancellationToken).AsTask();
}
public override async ValueTask<int> ReadAsync(Memory<byte> buffer, CancellationToken cancellationToken = default)
{
ThrowIfDisposed();
if (_buffer.Position < _buffer.Length || _completelyBuffered)
{
// Just read from the buffer
return await _buffer.ReadAsync(buffer, offset, (int)Math.Min(count, _buffer.Length - _buffer.Position), cancellationToken);
return await _buffer.ReadAsync(buffer, cancellationToken);
}
int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
var read = await _inner.ReadAsync(buffer, cancellationToken);
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
{
Dispose();
throw new IOException("Buffer limit exceeded.");
}
if (_inMemory && _buffer.Length + read > _memoryThreshold)
if (_inMemory && _memoryThreshold - read < _buffer.Length)
{
_inMemory = false;
var oldBuffer = _buffer;
@ -297,11 +309,11 @@ namespace Microsoft.AspNetCore.WebUtilities
try
{
// oldBuffer is a MemoryStream, no need to do async reads.
var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
var copyRead = oldBuffer.Read(rentedBuffer);
while (copyRead > 0)
{
await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken);
copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length);
await _buffer.WriteAsync(rentedBuffer.AsMemory(0, copyRead), cancellationToken);
copyRead = oldBuffer.Read(rentedBuffer);
}
}
finally
@ -311,7 +323,7 @@ namespace Microsoft.AspNetCore.WebUtilities
}
else
{
await _buffer.WriteAsync(_rentedBuffer, 0, (int)oldBuffer.Length, cancellationToken);
await _buffer.WriteAsync(_rentedBuffer.AsMemory(0, (int)oldBuffer.Length), cancellationToken);
_bytePool.Return(_rentedBuffer);
_rentedBuffer = null;
}
@ -319,7 +331,7 @@ namespace Microsoft.AspNetCore.WebUtilities
if (read > 0)
{
await _buffer.WriteAsync(buffer, offset, read, cancellationToken);
await _buffer.WriteAsync(buffer.Slice(0, read), cancellationToken);
}
else
{
@ -349,6 +361,39 @@ namespace Microsoft.AspNetCore.WebUtilities
throw new NotSupportedException();
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
// If we're completed buffered then copy from the underlying source
if (_completelyBuffered)
{
return _buffer.CopyToAsync(destination, bufferSize, cancellationToken);
}
async Task CopyToAsyncImpl()
{
// At least a 4K buffer
byte[] buffer = _bytePool.Rent(Math.Min(bufferSize, 4096));
try
{
while (true)
{
int bytesRead = await ReadAsync(buffer, cancellationToken);
if (bytesRead == 0)
{
break;
}
await destination.WriteAsync(buffer.AsMemory(0, bytesRead), cancellationToken);
}
}
finally
{
_bytePool.Return(buffer);
}
}
return CopyToAsyncImpl();
}
protected override void Dispose(bool disposing)
{
if (!_disposed)

View File

@ -4,6 +4,7 @@
using System;
using System.Buffers;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Moq;
@ -157,7 +158,6 @@ namespace Microsoft.AspNetCore.WebUtilities
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.False(File.Exists(tempFileName));
}
Assert.False(File.Exists(tempFileName));
@ -287,7 +287,6 @@ namespace Microsoft.AspNetCore.WebUtilities
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.False(File.Exists(tempFileName));
}
Assert.False(File.Exists(tempFileName));
@ -351,6 +350,138 @@ namespace Microsoft.AspNetCore.WebUtilities
Assert.False(File.Exists(tempFileName));
}
[Fact]
public async Task CopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
await stream.CopyToAsync(withoutBufferMs);
var withBufferMs = new MemoryStream();
stream.Position = 0;
await stream.CopyToAsync(withBufferMs);
Assert.Equal(data, withoutBufferMs.ToArray());
Assert.Equal(data, withBufferMs.ToArray());
}
[Fact]
public async Task CopyToAsyncWorksWithFileThreshold()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).Reverse().ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 100, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
await stream.CopyToAsync(withoutBufferMs);
var withBufferMs = new MemoryStream();
stream.Position = 0;
await stream.CopyToAsync(withBufferMs);
Assert.Equal(data, withoutBufferMs.ToArray());
Assert.Equal(data, withBufferMs.ToArray());
}
[Fact]
public async Task ReadAsyncThenCopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
await stream.ReadAsync(buffer);
await stream.CopyToAsync(withoutBufferMs);
Assert.Equal(data.AsMemory(0, 100).ToArray(), buffer);
Assert.Equal(data.AsMemory(100).ToArray(), withoutBufferMs.ToArray());
}
[Fact]
public async Task ReadThenCopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read = stream.Read(buffer);
await stream.CopyToAsync(withoutBufferMs);
Assert.Equal(100, read);
Assert.Equal(data.AsMemory(0, read).ToArray(), buffer);
Assert.Equal(data.AsMemory(read).ToArray(), withoutBufferMs.ToArray());
}
[Fact]
public async Task ReadThenSeekThenCopyToAsyncWorks()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read = stream.Read(buffer);
stream.Position = 0;
await stream.CopyToAsync(withoutBufferMs);
Assert.Equal(100, read);
Assert.Equal(data.AsMemory(0, read).ToArray(), buffer);
Assert.Equal(data.ToArray(), withoutBufferMs.ToArray());
}
[Fact]
public void PartialReadThenSeekReplaysBuffer()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read1 = stream.Read(buffer);
stream.Position = 0;
var buffer2 = new byte[200];
var read2 = stream.Read(buffer2);
Assert.Equal(100, read1);
Assert.Equal(100, read2);
Assert.Equal(data.AsMemory(0, read1).ToArray(), buffer);
Assert.Equal(data.AsMemory(0, read2).ToArray(), buffer2.AsMemory(0, read2).ToArray());
}
[Fact]
public async Task PartialReadAsyncThenSeekReplaysBuffer()
{
var data = Enumerable.Range(0, 1024).Select(b => (byte)b).ToArray();
var inner = new MemoryStream(data);
using var stream = new FileBufferingReadStream(inner, 1024 * 1024, bufferLimit: null, GetCurrentDirectory());
var withoutBufferMs = new MemoryStream();
var buffer = new byte[100];
var read1 = await stream.ReadAsync(buffer);
stream.Position = 0;
var buffer2 = new byte[200];
var read2 = await stream.ReadAsync(buffer2);
Assert.Equal(100, read1);
Assert.Equal(100, read2);
Assert.Equal(data.AsMemory(0, read1).ToArray(), buffer);
Assert.Equal(data.AsMemory(0, read2).ToArray(), buffer2.AsMemory(0, read2).ToArray());
}
private static string GetCurrentDirectory()
{
return AppContext.BaseDirectory;

View File

@ -497,8 +497,8 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
var content = "{\"name\": \"Test\"}";
var contentBytes = Encoding.UTF8.GetBytes(content);
var httpContext = GetHttpContext(contentBytes);
var testBufferedReadStream = new Mock<FileBufferingReadStream>(httpContext.Request.Body, 1024) { CallBase = true };
httpContext.Request.Body = testBufferedReadStream.Object;
var testBufferedReadStream = new VerifyDisposeFileBufferingReadStream(httpContext.Request.Body, 1024);
httpContext.Request.Body = testBufferedReadStream;
var formatterContext = CreateInputFormatterContext(typeof(ComplexModel), httpContext);
@ -508,8 +508,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
// Assert
var userModel = Assert.IsType<ComplexModel>(result.Model);
Assert.Equal("Test", userModel.Name);
testBufferedReadStream.Verify(v => v.DisposeAsync(), Times.Never());
Assert.False(testBufferedReadStream.Disposed);
}
[Fact]
@ -635,5 +634,25 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
public byte Small { get; set; }
}
private class VerifyDisposeFileBufferingReadStream : FileBufferingReadStream
{
public bool Disposed { get; private set; }
public VerifyDisposeFileBufferingReadStream(Stream inner, int memoryThreshold) : base(inner, memoryThreshold)
{
}
protected override void Dispose(bool disposing)
{
Disposed = true;
base.Dispose(disposing);
}
public override ValueTask DisposeAsync()
{
Disposed = true;
return base.DisposeAsync();
}
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;
@ -182,8 +183,8 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
var contentBytes = Encoding.UTF8.GetBytes(input);
var httpContext = new DefaultHttpContext();
var testBufferedReadStream = new Mock<FileBufferingReadStream>(new MemoryStream(contentBytes), 1024) { CallBase = true };
httpContext.Request.Body = testBufferedReadStream.Object;
var testBufferedReadStream = new VerifyDisposeFileBufferingReadStream(new MemoryStream(contentBytes), 1024);
httpContext.Request.Body = testBufferedReadStream;
var context = GetInputFormatterContext(httpContext, typeof(TestLevelOne));
// Act
@ -196,8 +197,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
Assert.Equal(expectedInt, model.SampleInt);
Assert.Equal(expectedString, model.sampleString);
testBufferedReadStream.Verify(v => v.DisposeAsync(), Times.Never());
Assert.False(testBufferedReadStream.Disposed);
}
[Fact]
@ -773,5 +773,25 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
// do not do anything
}
}
private class VerifyDisposeFileBufferingReadStream : FileBufferingReadStream
{
public bool Disposed { get; private set; }
public VerifyDisposeFileBufferingReadStream(Stream inner, int memoryThreshold) : base(inner, memoryThreshold)
{
}
protected override void Dispose(bool disposing)
{
Disposed = true;
base.Dispose(disposing);
}
public override ValueTask DisposeAsync()
{
Disposed = true;
return base.DisposeAsync();
}
}
}
}

View File

@ -638,8 +638,8 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
var contentBytes = Encoding.UTF8.GetBytes(input);
var httpContext = new DefaultHttpContext();
var testBufferedReadStream = new Mock<FileBufferingReadStream>(new MemoryStream(contentBytes), 1024) { CallBase = true };
httpContext.Request.Body = testBufferedReadStream.Object;
var testBufferedReadStream = new VerifyDisposeFileBufferingReadStream(new MemoryStream(contentBytes), 1024);
httpContext.Request.Body = testBufferedReadStream;
var context = GetInputFormatterContext(httpContext, typeof(TestLevelOne));
// Act
@ -652,8 +652,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
Assert.Equal(expectedInt, model.SampleInt);
Assert.Equal(expectedString, model.sampleString);
testBufferedReadStream.Verify(v => v.DisposeAsync(), Times.Never());
Assert.False(testBufferedReadStream.Disposed);
}
private InputFormatterContext GetInputFormatterContext(byte[] contentBytes, Type modelType)
@ -713,5 +712,25 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
// do not do anything
}
}
private class VerifyDisposeFileBufferingReadStream : FileBufferingReadStream
{
public bool Disposed { get; private set; }
public VerifyDisposeFileBufferingReadStream(Stream inner, int memoryThreshold) : base(inner, memoryThreshold)
{
}
protected override void Dispose(bool disposing)
{
Disposed = true;
base.Dispose(disposing);
}
public override ValueTask DisposeAsync()
{
Disposed = true;
return base.DisposeAsync();
}
}
}
}