diff --git a/src/Microsoft.AspNetCore.WebUtilities/FormReader.cs b/src/Microsoft.AspNetCore.WebUtilities/FormReader.cs index db2866b7fd..9080ca1daf 100644 --- a/src/Microsoft.AspNetCore.WebUtilities/FormReader.cs +++ b/src/Microsoft.AspNetCore.WebUtilities/FormReader.cs @@ -28,6 +28,9 @@ namespace Microsoft.AspNetCore.WebUtilities private readonly StringBuilder _builder = new StringBuilder(); private int _bufferOffset; private int _bufferCount; + private string _currentKey; + private string _currentValue; + private bool _endOfStream; private bool _disposed; public FormReader(string data) @@ -97,13 +100,29 @@ namespace Microsoft.AspNetCore.WebUtilities /// The next key value pair, or null when the end of the form is reached. public KeyValuePair? ReadNextPair() { - var key = ReadWord('=', KeyLengthLimit); - if (string.IsNullOrEmpty(key) && _bufferCount == 0) + ReadNextPairImpl(); + if (ReadSucceded()) { - return null; + return new KeyValuePair(_currentKey, _currentValue); + } + return null; + } + + private void ReadNextPairImpl() + { + StartReadNextPair(); + while (!_endOfStream) + { + // Empty + if (_bufferCount == 0) + { + Buffer(); + } + if (TryReadNextPair()) + { + break; + } } - var value = ReadWord('&', ValueLengthLimit); - return new KeyValuePair(key, value); } // Format: key1=value1&key2=value2 @@ -114,51 +133,74 @@ namespace Microsoft.AspNetCore.WebUtilities /// The next key value pair, or null when the end of the form is reached. public async Task?> ReadNextPairAsync(CancellationToken cancellationToken = new CancellationToken()) { - var key = await ReadWordAsync('=', KeyLengthLimit, cancellationToken); - if (string.IsNullOrEmpty(key) && _bufferCount == 0) + await ReadNextPairAsyncImpl(cancellationToken); + if (ReadSucceded()) { - return null; + return new KeyValuePair(_currentKey, _currentValue); } - var value = await ReadWordAsync('&', ValueLengthLimit, cancellationToken); - return new KeyValuePair(key, value); + return null; } - private string ReadWord(char seperator, int limit) + private async Task ReadNextPairAsyncImpl(CancellationToken cancellationToken = new CancellationToken()) { - while (true) - { - // Empty - if (_bufferCount == 0) - { - Buffer(); - } - - string word; - if (ReadChar(seperator, limit, out word)) - { - return word; - } - } - } - - private async Task ReadWordAsync(char seperator, int limit, CancellationToken cancellationToken) - { - while (true) + StartReadNextPair(); + while (!_endOfStream) { // Empty if (_bufferCount == 0) { await BufferAsync(cancellationToken); } - - string word; - if (ReadChar(seperator, limit, out word)) + if (TryReadNextPair()) { - return word; + break; } } } + private void StartReadNextPair() + { + _currentKey = null; + _currentValue = null; + } + + private bool TryReadNextPair() + { + if (_currentKey == null) + { + if (!TryReadWord('=', KeyLengthLimit, out _currentKey)) + { + return false; + } + + if (_bufferCount == 0) + { + return false; + } + } + + if (_currentValue == null) + { + if (!TryReadWord('&', ValueLengthLimit, out _currentValue)) + { + return false; + } + } + return true; + } + + private bool TryReadWord(char seperator, int limit, out string value) + { + do + { + if (ReadChar(seperator, limit, out value)) + { + return true; + } + } while (_bufferCount > 0); + return false; + } + private bool ReadChar(char seperator, int limit, out string word) { // End @@ -198,6 +240,7 @@ namespace Microsoft.AspNetCore.WebUtilities { _bufferOffset = 0; _bufferCount = _reader.Read(_buffer, 0, _buffer.Length); + _endOfStream = _bufferCount == 0; } private async Task BufferAsync(CancellationToken cancellationToken) @@ -206,6 +249,7 @@ namespace Microsoft.AspNetCore.WebUtilities cancellationToken.ThrowIfCancellationRequested(); _bufferOffset = 0; _bufferCount = await _reader.ReadAsync(_buffer, 0, _buffer.Length); + _endOfStream = _bufferCount == 0; } /// @@ -215,17 +259,11 @@ namespace Microsoft.AspNetCore.WebUtilities public Dictionary ReadForm() { var accumulator = new KeyValueAccumulator(); - var pair = ReadNextPair(); - while (pair.HasValue) + while (!_endOfStream) { - accumulator.Append(pair.Value.Key, pair.Value.Value); - if (accumulator.Count > KeyCountLimit) - { - throw new InvalidDataException($"Form key count limit {KeyCountLimit} exceeded."); - } - pair = ReadNextPair(); + ReadNextPairImpl(); + Append(ref accumulator); } - return accumulator.GetResults(); } @@ -237,18 +275,29 @@ namespace Microsoft.AspNetCore.WebUtilities public async Task> ReadFormAsync(CancellationToken cancellationToken = new CancellationToken()) { var accumulator = new KeyValueAccumulator(); - var pair = await ReadNextPairAsync(cancellationToken); - while (pair.HasValue) + while (!_endOfStream) { - accumulator.Append(pair.Value.Key, pair.Value.Value); + await ReadNextPairAsyncImpl(cancellationToken); + Append(ref accumulator); + } + return accumulator.GetResults(); + } + + private bool ReadSucceded() + { + return _currentKey != null && _currentValue != null; + } + + private void Append(ref KeyValueAccumulator accumulator) + { + if (ReadSucceded()) + { + accumulator.Append(_currentKey, _currentValue); if (accumulator.Count > KeyCountLimit) { throw new InvalidDataException($"Form key count limit {KeyCountLimit} exceeded."); } - pair = await ReadNextPairAsync(cancellationToken); } - - return accumulator.GetResults(); } public void Dispose() diff --git a/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderAsyncTest.cs b/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderAsyncTest.cs new file mode 100644 index 0000000000..0a7b5e20a9 --- /dev/null +++ b/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderAsyncTest.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.WebUtilities +{ + public class FormReaderAsyncTest : FormReaderTests + { + protected override async Task> ReadFormAsync(FormReader reader) + { + return await reader.ReadFormAsync(); + } + + protected override async Task?> ReadPair(FormReader reader) + { + return await reader.ReadNextPairAsync(); + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderTests.cs b/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderTests.cs index fb0557d275..f28f307f05 100644 --- a/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderTests.cs +++ b/test/Microsoft.AspNetCore.WebUtilities.Tests/FormReaderTests.cs @@ -1,10 +1,11 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Collections.Generic; using System.IO; -using System.Linq; using System.Text; using System.Threading.Tasks; +using Microsoft.Extensions.Primitives; using Xunit; namespace Microsoft.AspNetCore.WebUtilities @@ -18,7 +19,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "=bar"); - var formCollection = await new FormReader(body).ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body)); Assert.Equal("bar", formCollection[""].ToString()); } @@ -30,7 +31,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "=bar&baz=2"); - var formCollection = await new FormReader(body).ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body)); Assert.Equal("bar", formCollection[""].ToString()); Assert.Equal("2", formCollection["baz"].ToString()); @@ -43,7 +44,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "foo="); - var formCollection = await new FormReader(body).ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body)); Assert.Equal("", formCollection["foo"].ToString()); } @@ -55,7 +56,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "foo=&baz=2"); - var formCollection = await new FormReader(body).ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body)); Assert.Equal("", formCollection["foo"].ToString()); Assert.Equal("2", formCollection["baz"].ToString()); @@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "foo=1&bar=2&baz=3&baz=4"); - var formCollection = await new FormReader(body) { KeyCountLimit = 3 }.ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body) { KeyCountLimit = 3 }); Assert.Equal("1", formCollection["foo"].ToString()); Assert.Equal("2", formCollection["bar"].ToString()); @@ -84,7 +85,7 @@ namespace Microsoft.AspNetCore.WebUtilities var body = MakeStream(bufferRequest, "foo=1&baz=2&bar=3&baz=4&baf=5"); var exception = await Assert.ThrowsAsync( - () => new FormReader(body) { KeyCountLimit = 3 }.ReadFormAsync()); + () => ReadFormAsync(new FormReader(body) { KeyCountLimit = 3 })); Assert.Equal("Form key count limit 3 exceeded.", exception.Message); } @@ -95,7 +96,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "foo=1&bar=2&baz=3&baz=4"); - var formCollection = await new FormReader(body) { KeyLengthLimit = 10 }.ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body) { KeyLengthLimit = 10 }); Assert.Equal("1", formCollection["foo"].ToString()); Assert.Equal("2", formCollection["bar"].ToString()); @@ -111,7 +112,7 @@ namespace Microsoft.AspNetCore.WebUtilities var body = MakeStream(bufferRequest, "foo=1&baz1234567890=2"); var exception = await Assert.ThrowsAsync( - () => new FormReader(body) { KeyLengthLimit = 10 }.ReadFormAsync()); + () => ReadFormAsync(new FormReader(body) { KeyLengthLimit = 10 })); Assert.Equal("Form key or value length limit 10 exceeded.", exception.Message); } @@ -122,7 +123,7 @@ namespace Microsoft.AspNetCore.WebUtilities { var body = MakeStream(bufferRequest, "foo=1&bar=1234567890&baz=3&baz=4"); - var formCollection = await new FormReader(body) { ValueLengthLimit = 10 }.ReadFormAsync(); + var formCollection = await ReadFormAsync(new FormReader(body) { ValueLengthLimit = 10 }); Assert.Equal("1", formCollection["foo"].ToString()); Assert.Equal("1234567890", formCollection["bar"].ToString()); @@ -138,10 +139,54 @@ namespace Microsoft.AspNetCore.WebUtilities var body = MakeStream(bufferRequest, "foo=1&baz=1234567890123"); var exception = await Assert.ThrowsAsync( - () => new FormReader(body) { ValueLengthLimit = 10 }.ReadFormAsync()); + () => ReadFormAsync(new FormReader(body) { ValueLengthLimit = 10 })); Assert.Equal("Form key or value length limit 10 exceeded.", exception.Message); } + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadNextPair_ReadsAllPairs(bool bufferRequest) + { + var body = MakeStream(bufferRequest, "foo=&baz=2"); + + var reader = new FormReader(body); + + var pair = (KeyValuePair)await ReadPair(reader); + + Assert.Equal("foo", pair.Key); + Assert.Equal("", pair.Value); + + pair = (KeyValuePair)await ReadPair(reader); + + Assert.Equal("baz", pair.Key); + Assert.Equal("2", pair.Value); + + Assert.Null(await ReadPair(reader)); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task ReadNextPair_ReturnsNullOnEmptyStream(bool bufferRequest) + { + var body = MakeStream(bufferRequest, ""); + + var reader = new FormReader(body); + + Assert.Null(await ReadPair(reader)); + } + + protected virtual Task> ReadFormAsync(FormReader reader) + { + return Task.FromResult(reader.ReadForm()); + } + + protected virtual Task?> ReadPair(FormReader reader) + { + return Task.FromResult(reader.ReadNextPair()); + } + private static Stream MakeStream(bool bufferRequest, string text) { var formContent = Encoding.UTF8.GetBytes(text); diff --git a/test/Microsoft.AspNetCore.WebUtilities.Tests/NonSeekableReadStream.cs b/test/Microsoft.AspNetCore.WebUtilities.Tests/NonSeekableReadStream.cs index 11c5d0ccc9..f3c77abb38 100644 --- a/test/Microsoft.AspNetCore.WebUtilities.Tests/NonSeekableReadStream.cs +++ b/test/Microsoft.AspNetCore.WebUtilities.Tests/NonSeekableReadStream.cs @@ -61,11 +61,13 @@ namespace Microsoft.AspNetCore.WebUtilities public override int Read(byte[] buffer, int offset, int count) { + count = Math.Max(count, 1); return _inner.Read(buffer, offset, count); } public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) { + count = Math.Max(count, 1); return _inner.ReadAsync(buffer, offset, count, cancellationToken); } }