#578 Do not buffer the request body by default when reading forms.

This commit is contained in:
Chris R 2016-03-17 15:17:02 -07:00
parent 5da3673777
commit ce408a999e
6 changed files with 261 additions and 46 deletions

View File

@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.Http.Internal
if (_tempDirectory == null)
{
// Look for folders in the following order.
var temp = Environment.GetEnvironmentVariable("ASPNET_TEMP") ?? // ASPNET_TEMP - User set temporary location.
var temp = Environment.GetEnvironmentVariable("ASPNETCORE_TEMP") ?? // ASPNETCORE_TEMP - User set temporary location.
Path.GetTempPath(); // Fall back.
if (!Directory.Exists(temp))
@ -54,5 +54,26 @@ namespace Microsoft.AspNetCore.Http.Internal
}
return request;
}
public static MultipartSection EnableRewind(this MultipartSection section, Action<IDisposable> registerForDispose, int bufferThreshold = DefaultBufferThreshold)
{
if (section == null)
{
throw new ArgumentNullException(nameof(section));
}
if (registerForDispose == null)
{
throw new ArgumentNullException(nameof(registerForDispose));
}
var body = section.Body;
if (!body.CanSeek)
{
var fileStream = new FileBufferingReadStream(body, bufferThreshold, _getTempDirectory);
section.Body = fileStream;
registerForDispose(fileStream);
}
return section;
}
}
}

View File

@ -118,8 +118,6 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
cancellationToken.ThrowIfCancellationRequested();
_request.EnableRewind();
FormCollection formFields = null;
FormFileCollection files = null;
@ -146,16 +144,27 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
ContentDispositionHeaderValue.TryParse(section.ContentDisposition, out contentDisposition);
if (HasFileContentDisposition(contentDisposition))
{
// Enable buffering for the file if not already done for the full body
section.EnableRewind(_request.HttpContext.Response.RegisterForDispose);
// Find the end
await section.Body.DrainAsync(cancellationToken);
var name = HeaderUtilities.RemoveQuotes(contentDisposition.Name) ?? string.Empty;
var fileName = HeaderUtilities.RemoveQuotes(contentDisposition.FileName) ?? string.Empty;
var file = new FormFile(_request.Body, section.BaseStreamOffset.Value, section.Body.Length, name, fileName)
FormFile file;
if (section.BaseStreamOffset.HasValue)
{
Headers = new HeaderDictionary(section.Headers),
};
// Relative reference to buffered request body
file = new FormFile(_request.Body, section.BaseStreamOffset.Value, section.Body.Length, name, fileName);
}
else
{
// Individually buffered file body
file = new FormFile(section.Body, 0, section.Body.Length, name, fileName);
}
file.Headers = new HeaderDictionary(section.Headers);
if (files == null)
{
files = new FormFileCollection();
@ -194,7 +203,10 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
}
// Rewind so later readers don't have to.
_request.Body.Seek(0, SeekOrigin.Begin);
if (_request.Body.CanSeek)
{
_request.Body.Seek(0, SeekOrigin.Begin);
}
if (formFields != null)
{

View File

@ -25,17 +25,17 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
public Stream Body { get; set; }
public bool HasStarted
public virtual bool HasStarted
{
get { return false; }
}
public void OnStarting(Func<object, Task> callback, object state)
public virtual void OnStarting(Func<object, Task> callback, object state)
{
throw new NotImplementedException();
}
public void OnCompleted(Func<object, Task> callback, object state)
public virtual void OnCompleted(Func<object, Task> callback, object state)
{
throw new NotImplementedException();
}

View File

@ -0,0 +1,30 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features.Internal;
namespace Microsoft.AspNetCore.Http.Features.Internal
{
public class FakeResponseFeature : HttpResponseFeature
{
List<Tuple<Func<object, Task>, object>> _onCompletedCallbacks = new List<Tuple<Func<object, Task>, object>>();
public override void OnCompleted(Func<object, Task> callback, object state)
{
_onCompletedCallbacks.Add(new Tuple<Func<object, Task>, object>(callback, state));
}
public async Task CompleteAsync()
{
var callbacks = _onCompletedCallbacks;
_onCompletedCallbacks = null;
foreach (var callback in callbacks)
{
await callback.Item1(callback.Item2);
}
}
}
}

View File

@ -13,68 +13,94 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
{
public class FormFeatureTests
{
[Fact]
public async Task ReadFormAsync_SimpleData_ReturnsParsedFormCollection()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_SimpleData_ReturnsParsedFormCollection(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("foo=bar&baz=2");
var context = new DefaultHttpContext();
var responseFeature = new FakeResponseFeature();
context.Features.Set<IHttpResponseFeature>(responseFeature);
context.Request.ContentType = "application/x-www-form-urlencoded; charset=utf-8";
context.Request.Body = new MemoryStream(formContent);
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
// Act
var formCollection = await context.Request.ReadFormAsync();
// Assert
Assert.Equal("bar", formCollection["foo"]);
Assert.Equal("2", formCollection["baz"]);
Assert.Equal(0, context.Request.Body.Position);
Assert.True(context.Request.Body.CanSeek);
Assert.Equal(bufferRequest, context.Request.Body.CanSeek);
if (bufferRequest)
{
Assert.Equal(0, context.Request.Body.Position);
}
// Cached
formFeature = context.Features.Get<IFormFeature>();
Assert.NotNull(formFeature);
Assert.NotNull(formFeature.Form);
Assert.Same(formFeature.Form, formCollection);
// Cleanup
await responseFeature.CompleteAsync();
}
[Fact]
public async Task ReadFormAsync_EmptyKeyAtEndAllowed()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyKeyAtEndAllowed(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("=bar");
var body = new MemoryStream(formContent);
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
// Assert
Assert.Equal("bar", formCollection[""].FirstOrDefault());
}
[Fact]
public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("=bar&baz=2");
var body = new MemoryStream(formContent);
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
// Assert
Assert.Equal("bar", formCollection[""].FirstOrDefault());
Assert.Equal("2", formCollection["baz"].FirstOrDefault());
}
[Fact]
public async Task ReadFormAsync_EmptyValuedAtEndAllowed()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyValuedAtEndAllowed(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("foo=");
var body = new MemoryStream(formContent);
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
@ -82,12 +108,18 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
Assert.Equal("", formCollection["foo"].FirstOrDefault());
}
[Fact]
public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("foo=&baz=2");
var body = new MemoryStream(formContent);
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
@ -125,13 +157,22 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
"<html><body>Hello World</body></html>\r\n" +
"--WebKitFormBoundary5pDRpGheQXaM8k3T--";
[Fact]
public async Task ReadForm_EmptyMultipart_ReturnsParsedFormCollection()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadForm_EmptyMultipart_ReturnsParsedFormCollection(bool bufferRequest)
{
var formContent = Encoding.UTF8.GetBytes(EmptyMultipartForm);
var context = new DefaultHttpContext();
var responseFeature = new FakeResponseFeature();
context.Features.Set<IHttpResponseFeature>(responseFeature);
context.Request.ContentType = MultipartContentType;
context.Request.Body = new MemoryStream(formContent);
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
@ -152,15 +193,27 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
Assert.Equal(0, formCollection.Count);
Assert.NotNull(formCollection.Files);
Assert.Equal(0, formCollection.Files.Count);
// Cleanup
await responseFeature.CompleteAsync();
}
[Fact]
public async Task ReadForm_MultipartWithField_ReturnsParsedFormCollection()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadForm_MultipartWithField_ReturnsParsedFormCollection(bool bufferRequest)
{
var formContent = Encoding.UTF8.GetBytes(MultipartFormWithField);
var context = new DefaultHttpContext();
var responseFeature = new FakeResponseFeature();
context.Features.Set<IHttpResponseFeature>(responseFeature);
context.Request.ContentType = MultipartContentType;
context.Request.Body = new MemoryStream(formContent);
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
@ -183,15 +236,27 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
Assert.NotNull(formCollection.Files);
Assert.Equal(0, formCollection.Files.Count);
// Cleanup
await responseFeature.CompleteAsync();
}
[Fact]
public async Task ReadFormAsync_MultipartWithFile_ReturnsParsedFormCollection()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_MultipartWithFile_ReturnsParsedFormCollection(bool bufferRequest)
{
var formContent = Encoding.UTF8.GetBytes(MultipartFormWithFile);
var context = new DefaultHttpContext();
var responseFeature = new FakeResponseFeature();
context.Features.Set<IHttpResponseFeature>(responseFeature);
context.Request.ContentType = MultipartContentType;
context.Request.Body = new MemoryStream(formContent);
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
@ -222,18 +287,30 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
var body = file.OpenReadStream();
using (var reader = new StreamReader(body))
{
Assert.True(body.CanSeek);
var content = reader.ReadToEnd();
Assert.Equal(content, "<html><body>Hello World</body></html>");
}
await responseFeature.CompleteAsync();
}
[Fact]
public async Task ReadFormAsync_MultipartWithFieldAndFile_ReturnsParsedFormCollection()
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_MultipartWithFieldAndFile_ReturnsParsedFormCollection(bool bufferRequest)
{
var formContent = Encoding.UTF8.GetBytes(MultipartFormWithFieldAndFile);
var context = new DefaultHttpContext();
var responseFeature = new FakeResponseFeature();
context.Features.Set<IHttpResponseFeature>(responseFeature);
context.Request.ContentType = MultipartContentType;
context.Request.Body = new MemoryStream(formContent);
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
@ -263,9 +340,12 @@ namespace Microsoft.AspNetCore.Http.Features.Internal
var body = file.OpenReadStream();
using (var reader = new StreamReader(body))
{
Assert.True(body.CanSeek);
var content = reader.ReadToEnd();
Assert.Equal(content, "<html><body>Hello World</body></html>");
}
await responseFeature.CompleteAsync();
}
}
}

View File

@ -0,0 +1,72 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Http.Features.Internal
{
public class NonSeekableReadStream : Stream
{
private Stream _inner;
public NonSeekableReadStream(byte[] data)
: this(new MemoryStream(data))
{
}
public NonSeekableReadStream(Stream inner)
{
_inner = inner;
}
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length
{
get { throw new NotSupportedException(); }
}
public override long Position
{
get { throw new NotSupportedException(); }
set { throw new NotSupportedException(); }
}
public override void Flush()
{
throw new NotImplementedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
public override int Read(byte[] buffer, int offset, int count)
{
return _inner.Read(buffer, offset, count);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _inner.ReadAsync(buffer, offset, count, cancellationToken);
}
}
}