diff --git a/src/Microsoft.AspNetCore.WebUtilities/FileBufferingReadStream.cs b/src/Microsoft.AspNetCore.WebUtilities/FileBufferingReadStream.cs index ada0c67bf4..3d10321c7f 100644 --- a/src/Microsoft.AspNetCore.WebUtilities/FileBufferingReadStream.cs +++ b/src/Microsoft.AspNetCore.WebUtilities/FileBufferingReadStream.cs @@ -192,10 +192,26 @@ namespace Microsoft.AspNetCore.WebUtilities if (_inMemory && _buffer.Length + read > _memoryThreshold) { _inMemory = false; + var oldBuffer = _buffer; _buffer = CreateTempFile(); - _buffer.Write(_rentedBuffer, 0, (int)_buffer.Length); - _bytePool.Return(_rentedBuffer); - _rentedBuffer = null; + if (_rentedBuffer == null) + { + oldBuffer.Position = 0; + var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize)); + var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + while (copyRead > 0) + { + _buffer.Write(rentedBuffer, 0, copyRead); + copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + } + _bytePool.Return(rentedBuffer); + } + else + { + _buffer.Write(_rentedBuffer, 0, (int)oldBuffer.Length); + _bytePool.Return(_rentedBuffer); + _rentedBuffer = null; + } } if (read > 0) @@ -272,10 +288,27 @@ namespace Microsoft.AspNetCore.WebUtilities if (_inMemory && _buffer.Length + read > _memoryThreshold) { _inMemory = false; + var oldBuffer = _buffer; _buffer = CreateTempFile(); - await _buffer.WriteAsync(_rentedBuffer, 0, (int)_buffer.Length, cancellationToken); - _bytePool.Return(_rentedBuffer); - _rentedBuffer = null; + if (_rentedBuffer == null) + { + oldBuffer.Position = 0; + var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize)); + // oldBuffer is a MemoryStream, no need to do async reads. + var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + while (copyRead > 0) + { + await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken); + copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + } + _bytePool.Return(rentedBuffer); + } + else + { + await _buffer.WriteAsync(_rentedBuffer, 0, (int)oldBuffer.Length, cancellationToken); + _bytePool.Return(_rentedBuffer); + _rentedBuffer = null; + } } if (read > 0) diff --git a/test/Microsoft.AspNetCore.Http.Tests/Features/FormFeatureTests.cs b/test/Microsoft.AspNetCore.Http.Tests/Features/FormFeatureTests.cs index e4f01b86bb..a1472eeefa 100644 --- a/test/Microsoft.AspNetCore.Http.Tests/Features/FormFeatureTests.cs +++ b/test/Microsoft.AspNetCore.Http.Tests/Features/FormFeatureTests.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.IO; using System.Linq; using System.Text; @@ -347,5 +348,129 @@ namespace Microsoft.AspNetCore.Http.Features await responseFeature.CompleteAsync(); } + + [Theory] + // FileBufferingReadStream transitions to disk storage after 30kb, and stops pooling buffers at 1mb. + [InlineData(true, 1024)] + [InlineData(false, 1024)] + [InlineData(true, 40 * 1024)] + [InlineData(false, 40 * 1024)] + [InlineData(true, 4 * 1024 * 1024)] + [InlineData(false, 4 * 1024 * 1024)] + public async Task ReadFormAsync_MultipartWithFieldAndMediumFile_ReturnsParsedFormCollection(bool bufferRequest, int fileSize) + { + var fileContents = CreateFile(fileSize); + var formContent = CreateMultipartWithFormAndFile(fileContents); + var context = new DefaultHttpContext(); + var responseFeature = new FakeResponseFeature(); + context.Features.Set(responseFeature); + context.Request.ContentType = MultipartContentType; + context.Request.Body = new NonSeekableReadStream(formContent); + + if (bufferRequest) + { + context.Request.EnableRewind(); + } + + // Not cached yet + var formFeature = context.Features.Get(); + Assert.Null(formFeature); + + var formCollection = await context.Request.ReadFormAsync(); + + Assert.NotNull(formCollection); + + // Cached + formFeature = context.Features.Get(); + Assert.NotNull(formFeature); + Assert.NotNull(formFeature.Form); + Assert.Same(formFeature.Form, formCollection); + Assert.Same(formCollection, context.Request.Form); + + // Content + Assert.Equal(1, formCollection.Count); + Assert.Equal("Foo", formCollection["description"]); + + Assert.NotNull(formCollection.Files); + Assert.Equal(1, formCollection.Files.Count); + + var file = formCollection.Files["myfile1"]; + Assert.Equal("text/html", file.ContentType); + Assert.Equal(@"form-data; name=""myfile1""; filename=""temp.html""", file.ContentDisposition); + using (var body = file.OpenReadStream()) + { + Assert.True(body.CanSeek); + CompareStreams(fileContents, body); + } + + await responseFeature.CompleteAsync(); + } + + private Stream CreateFile(int size) + { + var stream = new MemoryStream(size); + var bytes = Encoding.ASCII.GetBytes("HelloWorld_ABCDEFGHIJKLMNOPQRSTUVWXYZ.abcdefghijklmnopqrstuvwxyz,0123456789;"); + int written = 0; + while (written < size) + { + var toWrite = Math.Min(size - written, bytes.Length); + stream.Write(bytes, 0, toWrite); + written += toWrite; + } + stream.Position = 0; + return stream; + } + + private Stream CreateMultipartWithFormAndFile(Stream fileContents) + { + var stream = new MemoryStream(); + var header = +"--WebKitFormBoundary5pDRpGheQXaM8k3T\r\n" + +"Content-Disposition: form-data; name=\"description\"\r\n" + +"\r\n" + +"Foo\r\n" + +"--WebKitFormBoundary5pDRpGheQXaM8k3T\r\n" + +"Content-Disposition: form-data; name=\"myfile1\"; filename=\"temp.html\"\r\n" + +"Content-Type: text/html\r\n" + +"\r\n"; + var footer = +"\r\n--WebKitFormBoundary5pDRpGheQXaM8k3T--"; + + var bytes = Encoding.ASCII.GetBytes(header); + stream.Write(bytes, 0, bytes.Length); + + fileContents.CopyTo(stream); + fileContents.Position = 0; + + bytes = Encoding.ASCII.GetBytes(footer); + stream.Write(bytes, 0, bytes.Length); + stream.Position = 0; + return stream; + } + + private void CompareStreams(Stream streamA, Stream streamB) + { + Assert.Equal(streamA.Length, streamB.Length); + byte[] bytesA = new byte[1024], bytesB = new byte[1024]; + var readA = streamA.Read(bytesA, 0, bytesA.Length); + var readB = streamB.Read(bytesB, 0, bytesB.Length); + Assert.Equal(readA, readB); + var loops = 0; + while (readA > 0) + { + for (int i = 0; i < readA; i++) + { + if (bytesA[i] != bytesB[i]) + { + throw new Exception($"Value mismatch at loop {loops}, index {i}; A:{bytesA[i]}, B:{bytesB[i]}"); + } + } + + readA = streamA.Read(bytesA, 0, bytesA.Length); + readB = streamB.Read(bytesB, 0, bytesB.Length); + Assert.Equal(readA, readB); + loops++; + } + } } }