Merge branch 'release' into dev

This commit is contained in:
Chris R 2016-04-25 12:00:37 -07:00
commit 3e69df87f8
19 changed files with 894 additions and 235 deletions

View File

@ -14,7 +14,10 @@ namespace Microsoft.AspNetCore.Http.Features
{
public class FormFeature : IFormFeature
{
private static readonly FormOptions DefaultFormOptions = new FormOptions();
private readonly HttpRequest _request;
private readonly FormOptions _options;
private Task<IFormCollection> _parsedFormTask;
private IFormCollection _form;
@ -27,15 +30,24 @@ namespace Microsoft.AspNetCore.Http.Features
Form = form;
}
public FormFeature(HttpRequest request)
: this(request, DefaultFormOptions)
{
}
public FormFeature(HttpRequest request, FormOptions options)
{
if (request == null)
{
throw new ArgumentNullException(nameof(request));
}
if (options == null)
{
throw new ArgumentNullException(nameof(options));
}
_request = request;
_options = options;
}
private MediaTypeHeaderValue ContentType
@ -118,6 +130,11 @@ namespace Microsoft.AspNetCore.Http.Features
cancellationToken.ThrowIfCancellationRequested();
if (_options.BufferBody)
{
_request.EnableRewind(_options.MemoryBufferThreshold, _options.BufferBodyLengthLimit);
}
FormCollection formFields = null;
FormFileCollection files = null;
@ -129,14 +146,27 @@ namespace Microsoft.AspNetCore.Http.Features
if (HasApplicationFormContentType(contentType))
{
var encoding = FilterEncoding(contentType.Encoding);
formFields = new FormCollection(await FormReader.ReadFormAsync(_request.Body, encoding, cancellationToken));
using (var formReader = new FormReader(_request.Body, encoding)
{
KeyCountLimit = _options.KeyCountLimit,
KeyLengthLimit = _options.KeyLengthLimit,
ValueLengthLimit = _options.ValueLengthLimit,
})
{
formFields = new FormCollection(await formReader.ReadFormAsync(cancellationToken));
}
}
else if (HasMultipartFormContentType(contentType))
{
var formAccumulator = new KeyValueAccumulator();
var boundary = GetBoundary(contentType);
var multipartReader = new MultipartReader(boundary, _request.Body);
var boundary = GetBoundary(contentType, _options.MultipartBoundaryLengthLimit);
var multipartReader = new MultipartReader(boundary, _request.Body)
{
HeadersCountLimit = _options.MultipartHeadersCountLimit,
HeadersLengthLimit = _options.MultipartHeadersLengthLimit,
BodyLengthLimit = _options.MultipartBodyLengthLimit,
};
var section = await multipartReader.ReadNextSectionAsync(cancellationToken);
while (section != null)
{
@ -145,7 +175,8 @@ namespace Microsoft.AspNetCore.Http.Features
if (HasFileContentDisposition(contentDisposition))
{
// Enable buffering for the file if not already done for the full body
section.EnableRewind(_request.HttpContext.Response.RegisterForDispose);
section.EnableRewind(_request.HttpContext.Response.RegisterForDispose,
_options.MemoryBufferThreshold, _options.MultipartBodyLengthLimit);
// Find the end
await section.Body.DrainAsync(cancellationToken);
@ -169,6 +200,10 @@ namespace Microsoft.AspNetCore.Http.Features
{
files = new FormFileCollection();
}
if (files.Count >= _options.KeyCountLimit)
{
throw new InvalidDataException($"Form key count limit {_options.KeyCountLimit} exceeded.");
}
files.Add(file);
}
else if (HasFormDataContentDisposition(contentDisposition))
@ -177,14 +212,20 @@ namespace Microsoft.AspNetCore.Http.Features
//
// value
// Do not limit the key name length here because the mulipart headers length limit is already in effect.
var key = HeaderUtilities.RemoveQuotes(contentDisposition.Name);
MediaTypeHeaderValue mediaType;
MediaTypeHeaderValue.TryParse(section.ContentType, out mediaType);
var encoding = FilterEncoding(mediaType?.Encoding);
using (var reader = new StreamReader(section.Body, encoding, detectEncodingFromByteOrderMarks: true, bufferSize: 1024, leaveOpen: true))
{
// The value length limit is enforced by MultipartBodyLengthLimit
var value = await reader.ReadToEndAsync();
formAccumulator.Append(key, value);
if (formAccumulator.Count > _options.KeyCountLimit)
{
throw new InvalidDataException($"Form key count limit {_options.KeyCountLimit} exceeded.");
}
}
}
else
@ -261,13 +302,17 @@ namespace Microsoft.AspNetCore.Http.Features
}
// Content-Type: multipart/form-data; boundary="----WebKitFormBoundarymx2fSWqWSd0OxQqq"
// TODO: Limit the length of boundary we accept. The spec says ~70 chars.
private static string GetBoundary(MediaTypeHeaderValue contentType)
// The spec says 70 characters is a reasonable limit.
private static string GetBoundary(MediaTypeHeaderValue contentType, int lengthLimit)
{
var boundary = HeaderUtilities.RemoveQuotes(contentType.Boundary);
if (string.IsNullOrWhiteSpace(boundary))
{
throw new InvalidOperationException("Missing content-type boundary.");
throw new InvalidDataException("Missing content-type boundary.");
}
if (boundary.Length > lengthLimit)
{
throw new InvalidDataException($"Multipart boundary length limit {lengthLimit} exceeded.");
}
return boundary;
}

View File

@ -0,0 +1,26 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.WebUtilities;
namespace Microsoft.AspNetCore.Http.Features
{
public class FormOptions
{
public const int DefaultMemoryBufferThreshold = 1024 * 64;
public const int DefaultBufferBodyLengthLimit = 1024 * 1024 * 128;
public const int DefaultMultipartBoundaryLengthLimit = 128;
public const long DefaultMultipartBodyLengthLimit = 1024 * 1024 * 128;
public bool BufferBody { get; set; } = false;
public int MemoryBufferThreshold { get; set; } = DefaultMemoryBufferThreshold;
public long BufferBodyLengthLimit { get; set; } = DefaultBufferBodyLengthLimit;
public int KeyCountLimit { get; set; } = FormReader.DefaultKeyCountLimit;
public int KeyLengthLimit { get; set; } = FormReader.DefaultKeyLengthLimit;
public int ValueLengthLimit { get; set; } = FormReader.DefaultValueLengthLimit;
public int MultipartBoundaryLengthLimit { get; set; } = DefaultMultipartBoundaryLengthLimit;
public int MultipartHeadersCountLimit { get; set; } = MultipartReader.DefaultHeadersCountLimit;
public int MultipartHeadersLengthLimit { get; set; } = MultipartReader.DefaultHeadersLengthLimit;
public long MultipartBodyLengthLimit { get; set; } = DefaultMultipartBodyLengthLimit;
}
}

View File

@ -5,6 +5,7 @@ using System;
using System.Text;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.Http
{
@ -12,20 +13,26 @@ namespace Microsoft.AspNetCore.Http
{
private readonly ObjectPool<StringBuilder> _builderPool;
private readonly IHttpContextAccessor _httpContextAccessor;
private readonly FormOptions _formOptions;
public HttpContextFactory(ObjectPoolProvider poolProvider)
: this(poolProvider, httpContextAccessor: null)
public HttpContextFactory(ObjectPoolProvider poolProvider, IOptions<FormOptions> formOptions)
: this(poolProvider, formOptions, httpContextAccessor: null)
{
}
public HttpContextFactory(ObjectPoolProvider poolProvider, IHttpContextAccessor httpContextAccessor)
public HttpContextFactory(ObjectPoolProvider poolProvider, IOptions<FormOptions> formOptions, IHttpContextAccessor httpContextAccessor)
{
if (poolProvider == null)
{
throw new ArgumentNullException(nameof(poolProvider));
}
if (formOptions == null)
{
throw new ArgumentNullException(nameof(formOptions));
}
_builderPool = poolProvider.CreateStringBuilderPool();
_formOptions = formOptions.Value;
_httpContextAccessor = httpContextAccessor;
}
@ -45,6 +52,9 @@ namespace Microsoft.AspNetCore.Http
_httpContextAccessor.HttpContext = httpContext;
}
var formFeature = new FormFeature(httpContext.Request, _formOptions);
featureCollection.Set<IFormFeature>(formFeature);
return httpContext;
}

View File

@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Http.Internal
}
}
public static HttpRequest EnableRewind(this HttpRequest request, int bufferThreshold = DefaultBufferThreshold)
public static HttpRequest EnableRewind(this HttpRequest request, int bufferThreshold = DefaultBufferThreshold, long? bufferLimit = null)
{
if (request == null)
{
@ -48,14 +48,15 @@ namespace Microsoft.AspNetCore.Http.Internal
var body = request.Body;
if (!body.CanSeek)
{
var fileStream = new FileBufferingReadStream(body, bufferThreshold, _getTempDirectory);
var fileStream = new FileBufferingReadStream(body, bufferThreshold, bufferLimit, _getTempDirectory);
request.Body = fileStream;
request.HttpContext.Response.RegisterForDispose(fileStream);
}
return request;
}
public static MultipartSection EnableRewind(this MultipartSection section, Action<IDisposable> registerForDispose, int bufferThreshold = DefaultBufferThreshold)
public static MultipartSection EnableRewind(this MultipartSection section, Action<IDisposable> registerForDispose,
int bufferThreshold = DefaultBufferThreshold, long? bufferLimit = null)
{
if (section == null)
{
@ -69,7 +70,7 @@ namespace Microsoft.AspNetCore.Http.Internal
var body = section.Body;
if (!body.CanSeek)
{
var fileStream = new FileBufferingReadStream(body, bufferThreshold, _getTempDirectory);
var fileStream = new FileBufferingReadStream(body, bufferThreshold, bufferLimit, _getTempDirectory);
section.Body = fileStream;
registerForDispose(fileStream);
}

View File

@ -0,0 +1,48 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;
namespace Microsoft.AspNetCore.Http
{
public static class RequestFormReaderExtensions
{
/// <summary>
/// Read the request body as a form with the given options. These options will only be used
/// if the form has not already been read.
/// </summary>
/// <param name="request">The request.</param>
/// <param name="options">Options for reading the form.</param>
/// <param name="cancellationToken"></param>
/// <returns>The parsed form.</returns>
public static Task<IFormCollection> ReadFormAsync(this HttpRequest request, FormOptions options,
CancellationToken cancellationToken = new CancellationToken())
{
if (request == null)
{
throw new ArgumentNullException(nameof(request));
}
if (options == null)
{
throw new ArgumentNullException(nameof(options));
}
if (!request.HasFormContentType)
{
throw new InvalidOperationException("Incorrect Content-Type: " + request.ContentType);
}
var features = request.HttpContext.Features;
var formFeature = features.Get<IFormFeature>();
if (formFeature == null || formFeature.Form == null)
{
// We haven't read the form yet, replace the reader with one using our own options.
features.Set<IFormFeature>(new FormFeature(request, options));
}
return request.ReadFormAsync(cancellationToken);
}
}
}

View File

@ -21,6 +21,7 @@
"Microsoft.AspNetCore.Http.Abstractions": "1.0.0-*",
"Microsoft.AspNetCore.WebUtilities": "1.0.0-*",
"Microsoft.Extensions.ObjectPool": "1.0.0-*",
"Microsoft.Extensions.Options": "1.0.0-*",
"Microsoft.Net.Http.Headers": "1.0.0-*",
"System.Buffers": "4.0.0-*"
},

View File

@ -352,7 +352,7 @@ namespace Microsoft.AspNetCore.WebUtilities
{
if (builder.Length > lengthLimit)
{
throw new InvalidOperationException("Line length limit exceeded: " + lengthLimit.ToString());
throw new InvalidDataException($"Line length limit {lengthLimit} exceeded.");
}
ProcessLineChar(builder, ref foundCR, ref foundCRLF);
}
@ -372,7 +372,7 @@ namespace Microsoft.AspNetCore.WebUtilities
{
if (builder.Length > lengthLimit)
{
throw new InvalidOperationException("Line length limit exceeded: " + lengthLimit.ToString());
throw new InvalidDataException($"Line length limit {lengthLimit} exceeded.");
}
ProcessLineChar(builder, ref foundCR, ref foundCRLF);

View File

@ -21,8 +21,10 @@ namespace Microsoft.AspNetCore.WebUtilities
private readonly Stream _inner;
private readonly ArrayPool<byte> _bytePool;
private readonly int _memoryThreshold;
private readonly long? _bufferLimit;
private string _tempFileDirectory;
private readonly Func<string> _tempFileDirectoryAccessor;
private string _tempFileName;
private Stream _buffer;
private byte[] _rentedBuffer;
@ -31,18 +33,19 @@ namespace Microsoft.AspNetCore.WebUtilities
private bool _disposed;
// TODO: allow for an optional buffer size limit to prevent filling hard disks. 1gb?
public FileBufferingReadStream(
Stream inner,
int memoryThreshold,
long? bufferLimit,
Func<string> tempFileDirectoryAccessor)
: this(inner, memoryThreshold, tempFileDirectoryAccessor, ArrayPool<byte>.Shared)
: this(inner, memoryThreshold, bufferLimit, tempFileDirectoryAccessor, ArrayPool<byte>.Shared)
{
}
public FileBufferingReadStream(
Stream inner,
int memoryThreshold,
long? bufferLimit,
Func<string> tempFileDirectoryAccessor,
ArrayPool<byte> bytePool)
{
@ -70,18 +73,23 @@ namespace Microsoft.AspNetCore.WebUtilities
_inner = inner;
_memoryThreshold = memoryThreshold;
_bufferLimit = bufferLimit;
_tempFileDirectoryAccessor = tempFileDirectoryAccessor;
}
// TODO: allow for an optional buffer size limit to prevent filling hard disks. 1gb?
public FileBufferingReadStream(Stream inner, int memoryThreshold, string tempFileDirectory)
: this(inner, memoryThreshold, tempFileDirectory, ArrayPool<byte>.Shared)
public FileBufferingReadStream(
Stream inner,
int memoryThreshold,
long? bufferLimit,
string tempFileDirectory)
: this(inner, memoryThreshold, bufferLimit, tempFileDirectory, ArrayPool<byte>.Shared)
{
}
public FileBufferingReadStream(
Stream inner,
int memoryThreshold,
long? bufferLimit,
string tempFileDirectory,
ArrayPool<byte> bytePool)
{
@ -109,9 +117,20 @@ namespace Microsoft.AspNetCore.WebUtilities
_inner = inner;
_memoryThreshold = memoryThreshold;
_bufferLimit = bufferLimit;
_tempFileDirectory = tempFileDirectory;
}
public bool InMemory
{
get { return _inMemory; }
}
public string TempFileName
{
get { return _tempFileName; }
}
public override bool CanRead
{
get { return true; }
@ -173,8 +192,8 @@ namespace Microsoft.AspNetCore.WebUtilities
Debug.Assert(_tempFileDirectory != null);
}
var fileName = Path.Combine(_tempFileDirectory, "ASPNET_" + Guid.NewGuid().ToString() + ".tmp");
return new FileStream(fileName, FileMode.Create, FileAccess.ReadWrite, FileShare.Delete, 1024 * 16,
_tempFileName = Path.Combine(_tempFileDirectory, "ASPNETCORE_" + Guid.NewGuid().ToString() + ".tmp");
return new FileStream(_tempFileName, FileMode.Create, FileAccess.ReadWrite, FileShare.Delete, 1024 * 16,
FileOptions.Asynchronous | FileOptions.DeleteOnClose | FileOptions.SequentialScan);
}
@ -189,6 +208,12 @@ namespace Microsoft.AspNetCore.WebUtilities
int read = _inner.Read(buffer, offset, count);
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
{
Dispose();
throw new IOException("Buffer limit exceeded.");
}
if (_inMemory && _buffer.Length + read > _memoryThreshold)
{
_inMemory = false;
@ -285,6 +310,12 @@ namespace Microsoft.AspNetCore.WebUtilities
int read = await _inner.ReadAsync(buffer, offset, count, cancellationToken);
if (_bufferLimit.HasValue && _bufferLimit - read < _buffer.Length)
{
Dispose();
throw new IOException("Buffer limit exceeded.");
}
if (_inMemory && _buffer.Length + read > _memoryThreshold)
{
_inMemory = false;

View File

@ -17,6 +17,10 @@ namespace Microsoft.AspNetCore.WebUtilities
/// </summary>
public class FormReader : IDisposable
{
public const int DefaultKeyCountLimit = 1024;
public const int DefaultKeyLengthLimit = 1024 * 2;
public const int DefaultValueLengthLimit = 1024 * 1024 * 4;
private const int _rentedCharPoolLength = 8192;
private readonly TextReader _reader;
private readonly char[] _buffer;
@ -43,6 +47,11 @@ namespace Microsoft.AspNetCore.WebUtilities
_reader = new StringReader(data);
}
public FormReader(Stream stream)
: this(stream, Encoding.UTF8, ArrayPool<char>.Shared)
{
}
public FormReader(Stream stream, Encoding encoding)
: this(stream, encoding, ArrayPool<char>.Shared)
{
@ -65,6 +74,21 @@ namespace Microsoft.AspNetCore.WebUtilities
_reader = new StreamReader(stream, encoding, detectEncodingFromByteOrderMarks: true, bufferSize: 1024 * 2, leaveOpen: true);
}
/// <summary>
/// The limit on the number of form keys to allow in ReadForm or ReadFormAsync.
/// </summary>
public int KeyCountLimit { get; set; } = DefaultKeyCountLimit;
/// <summary>
/// The limit on the length of form keys.
/// </summary>
public int KeyLengthLimit { get; set; } = DefaultKeyLengthLimit;
/// <summary>
/// The limit on the length of form values.
/// </summary>
public int ValueLengthLimit { get; set; } = DefaultValueLengthLimit;
// Format: key1=value1&key2=value2
/// <summary>
/// Reads the next key value pair from the form.
@ -73,12 +97,12 @@ namespace Microsoft.AspNetCore.WebUtilities
/// <returns>The next key value pair, or null when the end of the form is reached.</returns>
public KeyValuePair<string, string>? ReadNextPair()
{
var key = ReadWord('=');
var key = ReadWord('=', KeyLengthLimit);
if (string.IsNullOrEmpty(key) && _bufferCount == 0)
{
return null;
}
var value = ReadWord('&');
var value = ReadWord('&', ValueLengthLimit);
return new KeyValuePair<string, string>(key, value);
}
@ -88,20 +112,19 @@ namespace Microsoft.AspNetCore.WebUtilities
/// </summary>
/// <param name="cancellationToken"></param>
/// <returns>The next key value pair, or null when the end of the form is reached.</returns>
public async Task<KeyValuePair<string, string>?> ReadNextPairAsync(CancellationToken cancellationToken)
public async Task<KeyValuePair<string, string>?> ReadNextPairAsync(CancellationToken cancellationToken = new CancellationToken())
{
var key = await ReadWordAsync('=', cancellationToken);
var key = await ReadWordAsync('=', KeyLengthLimit, cancellationToken);
if (string.IsNullOrEmpty(key) && _bufferCount == 0)
{
return null;
}
var value = await ReadWordAsync('&', cancellationToken);
var value = await ReadWordAsync('&', ValueLengthLimit, cancellationToken);
return new KeyValuePair<string, string>(key, value);
}
private string ReadWord(char seperator)
private string ReadWord(char seperator, int limit)
{
// TODO: Configurable value size limit
while (true)
{
// Empty
@ -110,26 +133,16 @@ namespace Microsoft.AspNetCore.WebUtilities
Buffer();
}
// End
if (_bufferCount == 0)
string word;
if (ReadChar(seperator, limit, out word))
{
return BuildWord();
return word;
}
var c = _buffer[_bufferOffset++];
_bufferCount--;
if (c == seperator)
{
return BuildWord();
}
_builder.Append(c);
}
}
private async Task<string> ReadWordAsync(char seperator, CancellationToken cancellationToken)
private async Task<string> ReadWordAsync(char seperator, int limit, CancellationToken cancellationToken)
{
// TODO: Configurable value size limit
while (true)
{
// Empty
@ -138,23 +151,40 @@ namespace Microsoft.AspNetCore.WebUtilities
await BufferAsync(cancellationToken);
}
// End
if (_bufferCount == 0)
string word;
if (ReadChar(seperator, limit, out word))
{
return BuildWord();
return word;
}
var c = _buffer[_bufferOffset++];
_bufferCount--;
if (c == seperator)
{
return BuildWord();
}
_builder.Append(c);
}
}
private bool ReadChar(char seperator, int limit, out string word)
{
// End
if (_bufferCount == 0)
{
word = BuildWord();
return true;
}
var c = _buffer[_bufferOffset++];
_bufferCount--;
if (c == seperator)
{
word = BuildWord();
return true;
}
if (_builder.Length >= limit)
{
throw new InvalidDataException($"Form key or value length limit {limit} exceeded.");
}
_builder.Append(c);
word = null;
return false;
}
// '+' un-escapes to ' ', %HH un-escapes as ASCII (or utf-8?)
private string BuildWord()
{
@ -181,56 +211,44 @@ namespace Microsoft.AspNetCore.WebUtilities
/// <summary>
/// Parses text from an HTTP form body.
/// </summary>
/// <param name="text">The HTTP form body to parse.</param>
/// <returns>The collection containing the parsed HTTP form body.</returns>
public static Dictionary<string, StringValues> ReadForm(string text)
public Dictionary<string, StringValues> ReadForm()
{
using (var reader = new FormReader(text))
var accumulator = new KeyValueAccumulator();
var pair = ReadNextPair();
while (pair.HasValue)
{
var accumulator = new KeyValueAccumulator();
var pair = reader.ReadNextPair();
while (pair.HasValue)
accumulator.Append(pair.Value.Key, pair.Value.Value);
if (accumulator.Count > KeyCountLimit)
{
accumulator.Append(pair.Value.Key, pair.Value.Value);
pair = reader.ReadNextPair();
throw new InvalidDataException($"Form key count limit {KeyCountLimit} exceeded.");
}
return accumulator.GetResults();
pair = ReadNextPair();
}
return accumulator.GetResults();
}
/// <summary>
/// Parses an HTTP form body.
/// </summary>
/// <param name="stream">The HTTP form body to parse.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/>.</param>
/// <returns>The collection containing the parsed HTTP form body.</returns>
public static Task<Dictionary<string, StringValues>> ReadFormAsync(Stream stream, CancellationToken cancellationToken = new CancellationToken())
public async Task<Dictionary<string, StringValues>> ReadFormAsync(CancellationToken cancellationToken = new CancellationToken())
{
return ReadFormAsync(stream, Encoding.UTF8, cancellationToken);
}
/// <summary>
/// Parses an HTTP form body.
/// </summary>
/// <param name="stream">The HTTP form body to parse.</param>
/// <param name="encoding">The character encoding to use.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/>.</param>
/// <returns>The collection containing the parsed HTTP form body.</returns>
public static async Task<Dictionary<string, StringValues>> ReadFormAsync(Stream stream, Encoding encoding, CancellationToken cancellationToken = new CancellationToken())
{
using (var reader = new FormReader(stream, encoding))
var accumulator = new KeyValueAccumulator();
var pair = await ReadNextPairAsync(cancellationToken);
while (pair.HasValue)
{
var accumulator = new KeyValueAccumulator();
var pair = await reader.ReadNextPairAsync(cancellationToken);
while (pair.HasValue)
accumulator.Append(pair.Value.Key, pair.Value.Value);
if (accumulator.Count > KeyCountLimit)
{
accumulator.Append(pair.Value.Key, pair.Value.Value);
pair = await reader.ReadNextPairAsync(cancellationToken);
throw new InvalidDataException($"Form key count limit {KeyCountLimit} exceeded.");
}
return accumulator.GetResults();
pair = await ReadNextPairAsync(cancellationToken);
}
return accumulator.GetResults();
}
public void Dispose()

View File

@ -63,6 +63,8 @@ namespace Microsoft.AspNetCore.WebUtilities
public bool HasValues => _accumulator != null;
public int Count => _accumulator?.Count ?? 0;
public Dictionary<string, StringValues> GetResults()
{
if (_expandingAccumulator != null)

View File

@ -14,6 +14,8 @@ namespace Microsoft.AspNetCore.WebUtilities
// https://www.ietf.org/rfc/rfc2046.txt
public class MultipartReader
{
public const int DefaultHeadersCountLimit = 16;
public const int DefaultHeadersLengthLimit = 1024 * 16;
private const int DefaultBufferSize = 1024 * 4;
private readonly BufferedReadStream _stream;
@ -44,18 +46,24 @@ namespace Microsoft.AspNetCore.WebUtilities
_stream = new BufferedReadStream(stream, bufferSize);
_boundary = new MultipartBoundary(boundary, false);
// This stream will drain any preamble data and remove the first boundary marker.
_currentStream = new MultipartReaderStream(_stream, _boundary);
// TODO: HeadersLengthLimit can't be modified until after the constructor.
_currentStream = new MultipartReaderStream(_stream, _boundary) { LengthLimit = HeadersLengthLimit };
}
/// <summary>
/// The limit for individual header lines inside a multipart section.
/// The limit for the number of headers to read.
/// </summary>
public int HeaderLengthLimit { get; set; } = 1024 * 4;
public int HeadersCountLimit { get; set; } = DefaultHeadersCountLimit;
/// <summary>
/// The combined size limit for headers per multipart section.
/// </summary>
public int TotalHeaderSizeLimit { get; set; } = 1024 * 16;
public int HeadersLengthLimit { get; set; } = DefaultHeadersLengthLimit;
/// <summary>
/// The optional limit for the total response body length.
/// </summary>
public long? BodyLengthLimit { get; set; }
public async Task<MultipartSection> ReadNextSectionAsync(CancellationToken cancellationToken = new CancellationToken())
{
@ -65,12 +73,12 @@ namespace Microsoft.AspNetCore.WebUtilities
if (_currentStream.FinalBoundaryFound)
{
// There may be trailer data after the last boundary.
await _stream.DrainAsync(cancellationToken);
await _stream.DrainAsync(HeadersLengthLimit, cancellationToken);
return null;
}
var headers = await ReadHeadersAsync(cancellationToken);
_boundary.ExpectLeadingCrlf = true;
_currentStream = new MultipartReaderStream(_stream, _boundary);
_currentStream = new MultipartReaderStream(_stream, _boundary) { LengthLimit = BodyLengthLimit };
long? baseStreamOffset = _stream.CanSeek ? (long?)_stream.Position : null;
return new MultipartSection() { Headers = headers, Body = _currentStream, BaseStreamOffset = baseStreamOffset };
}
@ -79,23 +87,29 @@ namespace Microsoft.AspNetCore.WebUtilities
{
int totalSize = 0;
var accumulator = new KeyValueAccumulator();
var line = await _stream.ReadLineAsync(HeaderLengthLimit, cancellationToken);
var line = await _stream.ReadLineAsync(HeadersLengthLimit - totalSize, cancellationToken);
while (!string.IsNullOrEmpty(line))
{
if (HeadersLengthLimit - totalSize < line.Length)
{
throw new InvalidDataException($"Multipart headers length limit {HeadersLengthLimit} exceeded.");
}
totalSize += line.Length;
if (totalSize > TotalHeaderSizeLimit)
{
throw new InvalidOperationException("Total header size limit exceeded: " + TotalHeaderSizeLimit.ToString());
}
int splitIndex = line.IndexOf(':');
Debug.Assert(splitIndex > 0, $"Invalid header line: {line}");
if (splitIndex >= 0)
if (splitIndex <= 0)
{
var name = line.Substring(0, splitIndex);
var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim();
accumulator.Append(name, value);
throw new InvalidDataException($"Invalid header line: {line}");
}
line = await _stream.ReadLineAsync(HeaderLengthLimit, cancellationToken);
var name = line.Substring(0, splitIndex);
var value = line.Substring(splitIndex + 1, line.Length - splitIndex - 1).Trim();
accumulator.Append(name, value);
if (accumulator.Count > HeadersCountLimit)
{
throw new InvalidDataException($"Multipart headers count limit {HeadersCountLimit} exceeded.");
}
line = await _stream.ReadLineAsync(HeadersLengthLimit - totalSize, cancellationToken);
}
return accumulator.GetResults();

View File

@ -57,6 +57,8 @@ namespace Microsoft.AspNetCore.WebUtilities
public bool FinalBoundaryFound { get; private set; }
public long? LengthLimit { get; set; }
public override bool CanRead
{
get { return true; }
@ -159,6 +161,10 @@ namespace Microsoft.AspNetCore.WebUtilities
if (_observedLength < _position)
{
_observedLength = _position;
if (LengthLimit.HasValue && _observedLength > LengthLimit.Value)
{
throw new InvalidDataException($"Multipart body length limit {LengthLimit.Value} exceeded.");
}
}
return read;
}

View File

@ -14,19 +14,32 @@ namespace Microsoft.AspNetCore.WebUtilities
public static Task DrainAsync(this Stream stream, CancellationToken cancellationToken)
{
return stream.DrainAsync(ArrayPool<byte>.Shared, cancellationToken);
return stream.DrainAsync(ArrayPool<byte>.Shared, null, cancellationToken);
}
public static async Task DrainAsync(this Stream stream, ArrayPool<byte> bytePool, CancellationToken cancellationToken)
public static Task DrainAsync(this Stream stream, long? limit, CancellationToken cancellationToken)
{
return stream.DrainAsync(ArrayPool<byte>.Shared, limit, cancellationToken);
}
public static async Task DrainAsync(this Stream stream, ArrayPool<byte> bytePool, long? limit, CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
var buffer = bytePool.Rent(_maxReadBufferSize);
long total = 0;
try
{
while (await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken) > 0)
var read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken);
while (read > 0)
{
// Not all streams support cancellation directly.
cancellationToken.ThrowIfCancellationRequested();
if (limit.HasValue && limit.Value - total < read)
{
throw new InvalidDataException($"The stream exceeded the data limit {limit.Value}.");
}
total += read;
read = await stream.ReadAsync(buffer, 0, buffer.Length, cancellationToken);
}
}
finally

View File

@ -3,11 +3,9 @@
using System;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.WebUtilities;
using Xunit;
namespace Microsoft.AspNetCore.Http.Features
@ -26,14 +24,8 @@ namespace Microsoft.AspNetCore.Http.Features
context.Request.ContentType = "application/x-www-form-urlencoded; charset=utf-8";
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
IFormFeature formFeature = new FormFeature(context.Request, new FormOptions() { BufferBody = bufferRequest });
context.Features.Set<IFormFeature>(formFeature);
var formCollection = await context.Request.ReadFormAsync();
@ -55,80 +47,6 @@ namespace Microsoft.AspNetCore.Http.Features
await responseFeature.CompleteAsync();
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyKeyAtEndAllowed(bool bufferRequest)
{
var formContent = Encoding.UTF8.GetBytes("=bar");
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
Assert.Equal("bar", formCollection[""].FirstOrDefault());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed(bool bufferRequest)
{
var formContent = Encoding.UTF8.GetBytes("=bar&baz=2");
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
Assert.Equal("bar", formCollection[""].FirstOrDefault());
Assert.Equal("2", formCollection["baz"].FirstOrDefault());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyValuedAtEndAllowed(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("foo=");
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
// Assert
Assert.Equal("", formCollection["foo"].FirstOrDefault());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed(bool bufferRequest)
{
// Arrange
var formContent = Encoding.UTF8.GetBytes("foo=&baz=2");
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
var formCollection = await FormReader.ReadFormAsync(body);
// Assert
Assert.Equal("", formCollection["foo"].FirstOrDefault());
Assert.Equal("2", formCollection["baz"].FirstOrDefault());
}
private const string MultipartContentType = "multipart/form-data; boundary=WebKitFormBoundary5pDRpGheQXaM8k3T";
private const string EmptyMultipartForm =
"--WebKitFormBoundary5pDRpGheQXaM8k3T--";
@ -170,14 +88,8 @@ namespace Microsoft.AspNetCore.Http.Features
context.Request.ContentType = MultipartContentType;
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
IFormFeature formFeature = new FormFeature(context.Request, new FormOptions() { BufferBody = bufferRequest });
context.Features.Set<IFormFeature>(formFeature);
var formCollection = context.Request.Form;
@ -211,14 +123,8 @@ namespace Microsoft.AspNetCore.Http.Features
context.Request.ContentType = MultipartContentType;
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
IFormFeature formFeature = new FormFeature(context.Request, new FormOptions() { BufferBody = bufferRequest });
context.Features.Set<IFormFeature>(formFeature);
var formCollection = context.Request.Form;
@ -254,14 +160,8 @@ namespace Microsoft.AspNetCore.Http.Features
context.Request.ContentType = MultipartContentType;
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
IFormFeature formFeature = new FormFeature(context.Request, new FormOptions() { BufferBody = bufferRequest });
context.Features.Set<IFormFeature>(formFeature);
var formCollection = await context.Request.ReadFormAsync();
@ -308,14 +208,8 @@ namespace Microsoft.AspNetCore.Http.Features
context.Request.ContentType = MultipartContentType;
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
IFormFeature formFeature = new FormFeature(context.Request, new FormOptions() { BufferBody = bufferRequest });
context.Features.Set<IFormFeature>(formFeature);
var formCollection = await context.Request.ReadFormAsync();
@ -367,14 +261,8 @@ namespace Microsoft.AspNetCore.Http.Features
context.Request.ContentType = MultipartContentType;
context.Request.Body = new NonSeekableReadStream(formContent);
if (bufferRequest)
{
context.Request.EnableRewind();
}
// Not cached yet
var formFeature = context.Features.Get<IFormFeature>();
Assert.Null(formFeature);
IFormFeature formFeature = new FormFeature(context.Request, new FormOptions() { BufferBody = bufferRequest });
context.Features.Set<IFormFeature>(formFeature);
var formCollection = await context.Request.ReadFormAsync();

View File

@ -3,6 +3,7 @@
using Microsoft.AspNetCore.Http.Features;
using Microsoft.Extensions.ObjectPool;
using Microsoft.Extensions.Options;
using Xunit;
namespace Microsoft.AspNetCore.Http
@ -14,7 +15,7 @@ namespace Microsoft.AspNetCore.Http
{
// Arrange
var accessor = new HttpContextAccessor();
var contextFactory = new HttpContextFactory(new DefaultObjectPoolProvider(), accessor);
var contextFactory = new HttpContextFactory(new DefaultObjectPoolProvider(), Options.Create(new FormOptions()), accessor);
// Act
var context = contextFactory.Create(new FeatureCollection());
@ -27,7 +28,7 @@ namespace Microsoft.AspNetCore.Http
public void AllowsCreatingContextWithoutSettingAccessor()
{
// Arrange
var contextFactory = new HttpContextFactory(new DefaultObjectPoolProvider());
var contextFactory = new HttpContextFactory(new DefaultObjectPoolProvider(), Options.Create(new FormOptions()));
// Act & Assert
var context = contextFactory.Create(new FeatureCollection());

View File

@ -0,0 +1,294 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNetCore.WebUtilities
{
public class FileBufferingReadStreamTests
{
private Stream MakeStream(int size)
{
// TODO: Fill with random data? Make readonly?
return new MemoryStream(new byte[size]);
}
[Fact]
public void FileBufferingReadStream_Properties_ExpectedValues()
{
var inner = MakeStream(1024 * 2);
using (var stream = new FileBufferingReadStream(inner, 1024, null, Directory.GetCurrentDirectory()))
{
Assert.True(stream.CanRead);
Assert.True(stream.CanSeek);
Assert.False(stream.CanWrite);
Assert.Equal(0, stream.Length); // Nothing buffered yet
Assert.Equal(0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
}
}
[Fact]
public void FileBufferingReadStream_SyncReadUnderThreshold_DoesntCreateFile()
{
var inner = MakeStream(1024 * 2);
using (var stream = new FileBufferingReadStream(inner, 1024 * 3, null, Directory.GetCurrentDirectory()))
{
var bytes = new byte[1000];
var read0 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read1 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read1);
Assert.Equal(read0 + read1, stream.Length);
Assert.Equal(read0 + read1, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read2 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(inner.Length - read0 - read1, read2);
Assert.Equal(read0 + read1 + read2, stream.Length);
Assert.Equal(read0 + read1 + read2, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read3 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(0, read3);
}
}
[Fact]
public void FileBufferingReadStream_SyncReadOverThreshold_CreatesFile()
{
var inner = MakeStream(1024 * 2);
string tempFileName;
using (var stream = new FileBufferingReadStream(inner, 1024, null, Directory.GetCurrentDirectory()))
{
var bytes = new byte[1000];
var read0 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read1 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read1);
Assert.Equal(read0 + read1, stream.Length);
Assert.Equal(read0 + read1, stream.Position);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
tempFileName = stream.TempFileName;
Assert.True(File.Exists(tempFileName));
var read2 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(inner.Length - read0 - read1, read2);
Assert.Equal(read0 + read1 + read2, stream.Length);
Assert.Equal(read0 + read1 + read2, stream.Position);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.True(File.Exists(tempFileName));
var read3 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(0, read3);
}
Assert.False(File.Exists(tempFileName));
}
[Fact]
public void FileBufferingReadStream_SyncReadWithInMemoryLimit_EnforcesLimit()
{
var inner = MakeStream(1024 * 2);
using (var stream = new FileBufferingReadStream(inner, 1024, 900, Directory.GetCurrentDirectory()))
{
var bytes = new byte[500];
var read0 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var exception = Assert.Throws<IOException>(() => stream.Read(bytes, 0, bytes.Length));
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
Assert.False(File.Exists(stream.TempFileName));
}
}
[Fact]
public void FileBufferingReadStream_SyncReadWithOnDiskLimit_EnforcesLimit()
{
var inner = MakeStream(1024 * 2);
string tempFileName;
using (var stream = new FileBufferingReadStream(inner, 512, 1024, Directory.GetCurrentDirectory()))
{
var bytes = new byte[500];
var read0 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read1 = stream.Read(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read1);
Assert.Equal(read0 + read1, stream.Length);
Assert.Equal(read0 + read1, stream.Position);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
tempFileName = stream.TempFileName;
Assert.True(File.Exists(tempFileName));
var exception = Assert.Throws<IOException>(() => stream.Read(bytes, 0, bytes.Length));
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.False(File.Exists(tempFileName));
}
Assert.False(File.Exists(tempFileName));
}
///////////////////
[Fact]
public async Task FileBufferingReadStream_AsyncReadUnderThreshold_DoesntCreateFile()
{
var inner = MakeStream(1024 * 2);
using (var stream = new FileBufferingReadStream(inner, 1024 * 3, null, Directory.GetCurrentDirectory()))
{
var bytes = new byte[1000];
var read0 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read1 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read1);
Assert.Equal(read0 + read1, stream.Length);
Assert.Equal(read0 + read1, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read2 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(inner.Length - read0 - read1, read2);
Assert.Equal(read0 + read1 + read2, stream.Length);
Assert.Equal(read0 + read1 + read2, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read3 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(0, read3);
}
}
[Fact]
public async Task FileBufferingReadStream_AsyncReadOverThreshold_CreatesFile()
{
var inner = MakeStream(1024 * 2);
string tempFileName;
using (var stream = new FileBufferingReadStream(inner, 1024, null, Directory.GetCurrentDirectory()))
{
var bytes = new byte[1000];
var read0 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read1 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read1);
Assert.Equal(read0 + read1, stream.Length);
Assert.Equal(read0 + read1, stream.Position);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
tempFileName = stream.TempFileName;
Assert.True(File.Exists(tempFileName));
var read2 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(inner.Length - read0 - read1, read2);
Assert.Equal(read0 + read1 + read2, stream.Length);
Assert.Equal(read0 + read1 + read2, stream.Position);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.True(File.Exists(tempFileName));
var read3 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(0, read3);
}
Assert.False(File.Exists(tempFileName));
}
[Fact]
public async Task FileBufferingReadStream_AsyncReadWithInMemoryLimit_EnforcesLimit()
{
var inner = MakeStream(1024 * 2);
using (var stream = new FileBufferingReadStream(inner, 1024, 900, Directory.GetCurrentDirectory()))
{
var bytes = new byte[500];
var read0 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var exception = await Assert.ThrowsAsync<IOException>(() => stream.ReadAsync(bytes, 0, bytes.Length));
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
Assert.False(File.Exists(stream.TempFileName));
}
}
[Fact]
public async Task FileBufferingReadStream_AsyncReadWithOnDiskLimit_EnforcesLimit()
{
var inner = MakeStream(1024 * 2);
string tempFileName;
using (var stream = new FileBufferingReadStream(inner, 512, 1024, Directory.GetCurrentDirectory()))
{
var bytes = new byte[500];
var read0 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read0);
Assert.Equal(read0, stream.Length);
Assert.Equal(read0, stream.Position);
Assert.True(stream.InMemory);
Assert.Null(stream.TempFileName);
var read1 = await stream.ReadAsync(bytes, 0, bytes.Length);
Assert.Equal(bytes.Length, read1);
Assert.Equal(read0 + read1, stream.Length);
Assert.Equal(read0 + read1, stream.Position);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
tempFileName = stream.TempFileName;
Assert.True(File.Exists(tempFileName));
var exception = await Assert.ThrowsAsync<IOException>(() => stream.ReadAsync(bytes, 0, bytes.Length));
Assert.Equal("Buffer limit exceeded.", exception.Message);
Assert.False(stream.InMemory);
Assert.NotNull(stream.TempFileName);
Assert.False(File.Exists(tempFileName));
}
Assert.False(File.Exists(tempFileName));
}
}
}

View File

@ -0,0 +1,156 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.AspNetCore.WebUtilities
{
public class FormReaderTests
{
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyKeyAtEndAllowed(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "=bar");
var formCollection = await new FormReader(body).ReadFormAsync();
Assert.Equal("bar", formCollection[""].ToString());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyKeyWithAdditionalEntryAllowed(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "=bar&baz=2");
var formCollection = await new FormReader(body).ReadFormAsync();
Assert.Equal("bar", formCollection[""].ToString());
Assert.Equal("2", formCollection["baz"].ToString());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyValuedAtEndAllowed(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=");
var formCollection = await new FormReader(body).ReadFormAsync();
Assert.Equal("", formCollection["foo"].ToString());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_EmptyValuedWithAdditionalEntryAllowed(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=&baz=2");
var formCollection = await new FormReader(body).ReadFormAsync();
Assert.Equal("", formCollection["foo"].ToString());
Assert.Equal("2", formCollection["baz"].ToString());
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_KeyCountLimitMet_Success(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=1&bar=2&baz=3&baz=4");
var formCollection = await new FormReader(body) { KeyCountLimit = 3 }.ReadFormAsync();
Assert.Equal("1", formCollection["foo"].ToString());
Assert.Equal("2", formCollection["bar"].ToString());
Assert.Equal("3,4", formCollection["baz"].ToString());
Assert.Equal(3, formCollection.Count);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_KeyCountLimitExceeded_Throw(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=1&baz=2&bar=3&baz=4&baf=5");
var exception = await Assert.ThrowsAsync<InvalidDataException>(
() => new FormReader(body) { KeyCountLimit = 3 }.ReadFormAsync());
Assert.Equal("Form key count limit 3 exceeded.", exception.Message);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_KeyLengthLimitMet_Success(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=1&bar=2&baz=3&baz=4");
var formCollection = await new FormReader(body) { KeyLengthLimit = 10 }.ReadFormAsync();
Assert.Equal("1", formCollection["foo"].ToString());
Assert.Equal("2", formCollection["bar"].ToString());
Assert.Equal("3,4", formCollection["baz"].ToString());
Assert.Equal(3, formCollection.Count);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_KeyLengthLimitExceeded_Throw(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=1&baz1234567890=2");
var exception = await Assert.ThrowsAsync<InvalidDataException>(
() => new FormReader(body) { KeyLengthLimit = 10 }.ReadFormAsync());
Assert.Equal("Form key or value length limit 10 exceeded.", exception.Message);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_ValueLengthLimitMet_Success(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=1&bar=1234567890&baz=3&baz=4");
var formCollection = await new FormReader(body) { ValueLengthLimit = 10 }.ReadFormAsync();
Assert.Equal("1", formCollection["foo"].ToString());
Assert.Equal("1234567890", formCollection["bar"].ToString());
Assert.Equal("3,4", formCollection["baz"].ToString());
Assert.Equal(3, formCollection.Count);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task ReadFormAsync_ValueLengthLimitExceeded_Throw(bool bufferRequest)
{
var body = MakeStream(bufferRequest, "foo=1&baz=1234567890123");
var exception = await Assert.ThrowsAsync<InvalidDataException>(
() => new FormReader(body) { ValueLengthLimit = 10 }.ReadFormAsync());
Assert.Equal("Form key or value length limit 10 exceeded.", exception.Message);
}
private static Stream MakeStream(bool bufferRequest, string text)
{
var formContent = Encoding.UTF8.GetBytes(text);
Stream body = new MemoryStream(formContent);
if (!bufferRequest)
{
body = new NonSeekableReadStream(body);
}
return body;
}
}
}

View File

@ -18,6 +18,13 @@ namespace Microsoft.AspNetCore.WebUtilities
"Content-Disposition: form-data; name=\"text\"\r\n" +
"\r\n" +
"text default\r\n" +
"--9051914041544843365972754266--\r\n";
private const string OnePartBodyTwoHeaders =
"--9051914041544843365972754266\r\n" +
"Content-Disposition: form-data; name=\"text\"\r\n" +
"Custom-header: custom-value\r\n" +
"\r\n" +
"text default\r\n" +
"--9051914041544843365972754266--\r\n";
private const string OnePartBodyWithTrailingWhitespace =
"--9051914041544843365972754266 \r\n" +
@ -115,6 +122,32 @@ namespace Microsoft.AspNetCore.WebUtilities
Assert.Null(await reader.ReadNextSectionAsync());
}
[Fact]
public async Task MutipartReader_HeaderCountExceeded_Throws()
{
var stream = MakeStream(OnePartBodyTwoHeaders);
var reader = new MultipartReader(Boundary, stream)
{
HeadersCountLimit = 1,
};
var exception = await Assert.ThrowsAsync<InvalidDataException>(() => reader.ReadNextSectionAsync());
Assert.Equal("Multipart headers count limit 1 exceeded.", exception.Message);
}
[Fact]
public async Task MutipartReader_HeadersLengthExceeded_Throws()
{
var stream = MakeStream(OnePartBodyTwoHeaders);
var reader = new MultipartReader(Boundary, stream)
{
HeadersLengthLimit = 60,
};
var exception = await Assert.ThrowsAsync<InvalidDataException>(() => reader.ReadNextSectionAsync());
Assert.Equal("Line length limit 17 exceeded.", exception.Message);
}
[Fact]
public async Task MutipartReader_ReadSinglePartBodyWithTrailingWhitespace_Success()
{

View File

@ -0,0 +1,72 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.WebUtilities
{
public class NonSeekableReadStream : Stream
{
private Stream _inner;
public NonSeekableReadStream(byte[] data)
: this(new MemoryStream(data))
{
}
public NonSeekableReadStream(Stream inner)
{
_inner = inner;
}
public override bool CanRead => _inner.CanRead;
public override bool CanSeek => false;
public override bool CanWrite => false;
public override long Length
{
get { throw new NotSupportedException(); }
}
public override long Position
{
get { throw new NotSupportedException(); }
set { throw new NotSupportedException(); }
}
public override void Flush()
{
throw new NotImplementedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotSupportedException();
}
public override int Read(byte[] buffer, int offset, int count)
{
return _inner.Read(buffer, offset, count);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _inner.ReadAsync(buffer, offset, count, cancellationToken);
}
}
}