From ce408a999ea271f8b2083cf5019226a99f00fd0c Mon Sep 17 00:00:00 2001 From: Chris R Date: Thu, 17 Mar 2016 15:17:02 -0700 Subject: [PATCH] #578 Do not buffer the request body by default when reading forms. --- .../BufferingHelper.cs | 23 ++- .../Features/FormFeature.cs | 24 ++- .../Features/HttpResponseFeature.cs | 6 +- .../FakeResponseFeature.cs | 30 ++++ .../FormFeatureTests.cs | 152 +++++++++++++----- .../NonSeekableReadStream.cs | 72 +++++++++ 6 files changed, 261 insertions(+), 46 deletions(-) create mode 100644 test/Microsoft.AspNetCore.Http.Tests/FakeResponseFeature.cs create mode 100644 test/Microsoft.AspNetCore.Http.Tests/NonSeekableReadStream.cs diff --git a/src/Microsoft.AspNetCore.Http/BufferingHelper.cs b/src/Microsoft.AspNetCore.Http/BufferingHelper.cs index 80f95ecbf1..f7929f1532 100644 --- a/src/Microsoft.AspNetCore.Http/BufferingHelper.cs +++ b/src/Microsoft.AspNetCore.Http/BufferingHelper.cs @@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Http.Internal if (_tempDirectory == null) { // Look for folders in the following order. - var temp = Environment.GetEnvironmentVariable("ASPNET_TEMP") ?? // ASPNET_TEMP - User set temporary location. + var temp = Environment.GetEnvironmentVariable("ASPNETCORE_TEMP") ?? // ASPNETCORE_TEMP - User set temporary location. Path.GetTempPath(); // Fall back. if (!Directory.Exists(temp)) @@ -54,5 +54,26 @@ namespace Microsoft.AspNetCore.Http.Internal } return request; } + + public static MultipartSection EnableRewind(this MultipartSection section, Action registerForDispose, int bufferThreshold = DefaultBufferThreshold) + { + if (section == null) + { + throw new ArgumentNullException(nameof(section)); + } + if (registerForDispose == null) + { + throw new ArgumentNullException(nameof(registerForDispose)); + } + + var body = section.Body; + if (!body.CanSeek) + { + var fileStream = new FileBufferingReadStream(body, bufferThreshold, _getTempDirectory); + section.Body = fileStream; + registerForDispose(fileStream); + } + return section; + } } } \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Http/Features/FormFeature.cs b/src/Microsoft.AspNetCore.Http/Features/FormFeature.cs index 004e2892ce..1a9f00fa2b 100644 --- a/src/Microsoft.AspNetCore.Http/Features/FormFeature.cs +++ b/src/Microsoft.AspNetCore.Http/Features/FormFeature.cs @@ -118,8 +118,6 @@ namespace Microsoft.AspNetCore.Http.Features.Internal cancellationToken.ThrowIfCancellationRequested(); - _request.EnableRewind(); - FormCollection formFields = null; FormFileCollection files = null; @@ -146,16 +144,27 @@ namespace Microsoft.AspNetCore.Http.Features.Internal ContentDispositionHeaderValue.TryParse(section.ContentDisposition, out contentDisposition); if (HasFileContentDisposition(contentDisposition)) { + // Enable buffering for the file if not already done for the full body + section.EnableRewind(_request.HttpContext.Response.RegisterForDispose); // Find the end await section.Body.DrainAsync(cancellationToken); var name = HeaderUtilities.RemoveQuotes(contentDisposition.Name) ?? string.Empty; var fileName = HeaderUtilities.RemoveQuotes(contentDisposition.FileName) ?? string.Empty; - var file = new FormFile(_request.Body, section.BaseStreamOffset.Value, section.Body.Length, name, fileName) + FormFile file; + if (section.BaseStreamOffset.HasValue) { - Headers = new HeaderDictionary(section.Headers), - }; + // Relative reference to buffered request body + file = new FormFile(_request.Body, section.BaseStreamOffset.Value, section.Body.Length, name, fileName); + } + else + { + // Individually buffered file body + file = new FormFile(section.Body, 0, section.Body.Length, name, fileName); + } + file.Headers = new HeaderDictionary(section.Headers); + if (files == null) { files = new FormFileCollection(); @@ -194,7 +203,10 @@ namespace Microsoft.AspNetCore.Http.Features.Internal } // Rewind so later readers don't have to. - _request.Body.Seek(0, SeekOrigin.Begin); + if (_request.Body.CanSeek) + { + _request.Body.Seek(0, SeekOrigin.Begin); + } if (formFields != null) { diff --git a/src/Microsoft.AspNetCore.Http/Features/HttpResponseFeature.cs b/src/Microsoft.AspNetCore.Http/Features/HttpResponseFeature.cs index c40074b92b..ef8f845ca1 100644 --- a/src/Microsoft.AspNetCore.Http/Features/HttpResponseFeature.cs +++ b/src/Microsoft.AspNetCore.Http/Features/HttpResponseFeature.cs @@ -25,17 +25,17 @@ namespace Microsoft.AspNetCore.Http.Features.Internal public Stream Body { get; set; } - public bool HasStarted + public virtual bool HasStarted { get { return false; } } - public void OnStarting(Func callback, object state) + public virtual void OnStarting(Func callback, object state) { throw new NotImplementedException(); } - public void OnCompleted(Func callback, object state) + public virtual void OnCompleted(Func callback, object state) { throw new NotImplementedException(); } diff --git a/test/Microsoft.AspNetCore.Http.Tests/FakeResponseFeature.cs b/test/Microsoft.AspNetCore.Http.Tests/FakeResponseFeature.cs new file mode 100644 index 0000000000..10cd5cc6d5 --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/FakeResponseFeature.cs @@ -0,0 +1,30 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features.Internal; + +namespace Microsoft.AspNetCore.Http.Features.Internal +{ + public class FakeResponseFeature : HttpResponseFeature + { + List, object>> _onCompletedCallbacks = new List, object>>(); + + public override void OnCompleted(Func callback, object state) + { + _onCompletedCallbacks.Add(new Tuple, object>(callback, state)); + } + + public async Task CompleteAsync() + { + var callbacks = _onCompletedCallbacks; + _onCompletedCallbacks = null; + foreach (var callback in callbacks) + { + await callback.Item1(callback.Item2); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Http.Tests/FormFeatureTests.cs b/test/Microsoft.AspNetCore.Http.Tests/FormFeatureTests.cs index 0ea7fe9681..280e77774f 100644 --- a/test/Microsoft.AspNetCore.Http.Tests/FormFeatureTests.cs +++ b/test/Microsoft.AspNetCore.Http.Tests/FormFeatureTests.cs @@ -13,68 +13,94 @@ namespace Microsoft.AspNetCore.Http.Features.Internal { public class FormFeatureTests { - [Fact] - public async Task ReadFormAsync_SimpleData_ReturnsParsedFormCollection() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_SimpleData_ReturnsParsedFormCollection(bool bufferRequest) { - // Arrange var formContent = Encoding.UTF8.GetBytes("foo=bar&baz=2"); var context = new DefaultHttpContext(); + var responseFeature = new FakeResponseFeature(); + context.Features.Set(responseFeature); context.Request.ContentType = "application/x-www-form-urlencoded; charset=utf-8"; - context.Request.Body = new MemoryStream(formContent); + context.Request.Body = new NonSeekableReadStream(formContent); + + if (bufferRequest) + { + context.Request.EnableRewind(); + } // Not cached yet var formFeature = context.Features.Get(); Assert.Null(formFeature); - // Act var formCollection = await context.Request.ReadFormAsync(); - // Assert Assert.Equal("bar", formCollection["foo"]); Assert.Equal("2", formCollection["baz"]); - Assert.Equal(0, context.Request.Body.Position); - Assert.True(context.Request.Body.CanSeek); + Assert.Equal(bufferRequest, context.Request.Body.CanSeek); + if (bufferRequest) + { + Assert.Equal(0, context.Request.Body.Position); + } // Cached formFeature = context.Features.Get(); Assert.NotNull(formFeature); Assert.NotNull(formFeature.Form); Assert.Same(formFeature.Form, formCollection); + + // Cleanup + await responseFeature.CompleteAsync(); } - [Fact] - public async Task ReadFormAsync_EmptyKeyAtEndAllowed() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_EmptyKeyAtEndAllowed(bool bufferRequest) { - // Arrange var formContent = Encoding.UTF8.GetBytes("=bar"); - var body = new MemoryStream(formContent); + Stream body = new MemoryStream(formContent); + if (!bufferRequest) + { + body = new NonSeekableReadStream(body); + } var formCollection = await FormReader.ReadFormAsync(body); - // Assert Assert.Equal("bar", formCollection[""].FirstOrDefault()); } - [Fact] - public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed(bool bufferRequest) { - // Arrange var formContent = Encoding.UTF8.GetBytes("=bar&baz=2"); - var body = new MemoryStream(formContent); + Stream body = new MemoryStream(formContent); + if (!bufferRequest) + { + body = new NonSeekableReadStream(body); + } var formCollection = await FormReader.ReadFormAsync(body); - // Assert Assert.Equal("bar", formCollection[""].FirstOrDefault()); Assert.Equal("2", formCollection["baz"].FirstOrDefault()); } - [Fact] - public async Task ReadFormAsync_EmptyValuedAtEndAllowed() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_EmptyValuedAtEndAllowed(bool bufferRequest) { // Arrange var formContent = Encoding.UTF8.GetBytes("foo="); - var body = new MemoryStream(formContent); + Stream body = new MemoryStream(formContent); + if (!bufferRequest) + { + body = new NonSeekableReadStream(body); + } var formCollection = await FormReader.ReadFormAsync(body); @@ -82,12 +108,18 @@ namespace Microsoft.AspNetCore.Http.Features.Internal Assert.Equal("", formCollection["foo"].FirstOrDefault()); } - [Fact] - public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed(bool bufferRequest) { // Arrange var formContent = Encoding.UTF8.GetBytes("foo=&baz=2"); - var body = new MemoryStream(formContent); + Stream body = new MemoryStream(formContent); + if (!bufferRequest) + { + body = new NonSeekableReadStream(body); + } var formCollection = await FormReader.ReadFormAsync(body); @@ -125,13 +157,22 @@ namespace Microsoft.AspNetCore.Http.Features.Internal "Hello World\r\n" + "--WebKitFormBoundary5pDRpGheQXaM8k3T--"; - [Fact] - public async Task ReadForm_EmptyMultipart_ReturnsParsedFormCollection() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadForm_EmptyMultipart_ReturnsParsedFormCollection(bool bufferRequest) { var formContent = Encoding.UTF8.GetBytes(EmptyMultipartForm); var context = new DefaultHttpContext(); + var responseFeature = new FakeResponseFeature(); + context.Features.Set(responseFeature); context.Request.ContentType = MultipartContentType; - context.Request.Body = new MemoryStream(formContent); + context.Request.Body = new NonSeekableReadStream(formContent); + + if (bufferRequest) + { + context.Request.EnableRewind(); + } // Not cached yet var formFeature = context.Features.Get(); @@ -152,15 +193,27 @@ namespace Microsoft.AspNetCore.Http.Features.Internal Assert.Equal(0, formCollection.Count); Assert.NotNull(formCollection.Files); Assert.Equal(0, formCollection.Files.Count); + + // Cleanup + await responseFeature.CompleteAsync(); } - [Fact] - public async Task ReadForm_MultipartWithField_ReturnsParsedFormCollection() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadForm_MultipartWithField_ReturnsParsedFormCollection(bool bufferRequest) { var formContent = Encoding.UTF8.GetBytes(MultipartFormWithField); var context = new DefaultHttpContext(); + var responseFeature = new FakeResponseFeature(); + context.Features.Set(responseFeature); context.Request.ContentType = MultipartContentType; - context.Request.Body = new MemoryStream(formContent); + context.Request.Body = new NonSeekableReadStream(formContent); + + if (bufferRequest) + { + context.Request.EnableRewind(); + } // Not cached yet var formFeature = context.Features.Get(); @@ -183,15 +236,27 @@ namespace Microsoft.AspNetCore.Http.Features.Internal Assert.NotNull(formCollection.Files); Assert.Equal(0, formCollection.Files.Count); + + // Cleanup + await responseFeature.CompleteAsync(); } - [Fact] - public async Task ReadFormAsync_MultipartWithFile_ReturnsParsedFormCollection() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_MultipartWithFile_ReturnsParsedFormCollection(bool bufferRequest) { var formContent = Encoding.UTF8.GetBytes(MultipartFormWithFile); var context = new DefaultHttpContext(); + var responseFeature = new FakeResponseFeature(); + context.Features.Set(responseFeature); context.Request.ContentType = MultipartContentType; - context.Request.Body = new MemoryStream(formContent); + context.Request.Body = new NonSeekableReadStream(formContent); + + if (bufferRequest) + { + context.Request.EnableRewind(); + } // Not cached yet var formFeature = context.Features.Get(); @@ -222,18 +287,30 @@ namespace Microsoft.AspNetCore.Http.Features.Internal var body = file.OpenReadStream(); using (var reader = new StreamReader(body)) { + Assert.True(body.CanSeek); var content = reader.ReadToEnd(); Assert.Equal(content, "Hello World"); } + + await responseFeature.CompleteAsync(); } - [Fact] - public async Task ReadFormAsync_MultipartWithFieldAndFile_ReturnsParsedFormCollection() + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadFormAsync_MultipartWithFieldAndFile_ReturnsParsedFormCollection(bool bufferRequest) { var formContent = Encoding.UTF8.GetBytes(MultipartFormWithFieldAndFile); var context = new DefaultHttpContext(); + var responseFeature = new FakeResponseFeature(); + context.Features.Set(responseFeature); context.Request.ContentType = MultipartContentType; - context.Request.Body = new MemoryStream(formContent); + context.Request.Body = new NonSeekableReadStream(formContent); + + if (bufferRequest) + { + context.Request.EnableRewind(); + } // Not cached yet var formFeature = context.Features.Get(); @@ -263,9 +340,12 @@ namespace Microsoft.AspNetCore.Http.Features.Internal var body = file.OpenReadStream(); using (var reader = new StreamReader(body)) { + Assert.True(body.CanSeek); var content = reader.ReadToEnd(); Assert.Equal(content, "Hello World"); } + + await responseFeature.CompleteAsync(); } } } diff --git a/test/Microsoft.AspNetCore.Http.Tests/NonSeekableReadStream.cs b/test/Microsoft.AspNetCore.Http.Tests/NonSeekableReadStream.cs new file mode 100644 index 0000000000..2e6af5ab4d --- /dev/null +++ b/test/Microsoft.AspNetCore.Http.Tests/NonSeekableReadStream.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Http.Features.Internal +{ + public class NonSeekableReadStream : Stream + { + private Stream _inner; + + public NonSeekableReadStream(byte[] data) + : this(new MemoryStream(data)) + { + } + + public NonSeekableReadStream(Stream inner) + { + _inner = inner; + } + + public override bool CanRead => _inner.CanRead; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public override void Flush() + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + return _inner.Read(buffer, offset, count); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _inner.ReadAsync(buffer, offset, count, cancellationToken); + } + } +}