// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using Moq;
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;

namespace Microsoft.AspNetCore.WebUtilities
{
    /// <summary>
    /// Unit tests for <see cref="HttpRequestStreamReader"/>: character and line reads
    /// (sync and async), constructor argument validation, and use after disposal.
    /// </summary>
    public class HttpRequestStreamReaderTest
    {
        // Fixture text written to the stream by CreateStream(). It deliberately mixes
        // control characters, multi-byte characters and all three line terminators
        // ("\r", "\r\n", "\n") so that the ReadLine tests can slice it by index.
        private static readonly char[] CharData = new char[]
        {
            char.MinValue,
            char.MaxValue,
            '\t',
            ' ',
            '$',
            '@',
            '#',
            '\0',
            '\v',
            '\'',
            '\u3190',
            '\uC3A0',
            'A',
            '5',
            '\r',
            '\uFE70',
            '-',
            ';',
            '\r',
            '\n',
            'T',
            '3',
            '\n',
            'K',
            '\u00E6',
        };

        [Fact]
        public static async Task ReadToEndAsync()
        {
            // Arrange
            var reader = new HttpRequestStreamReader(GetLargeStream(), Encoding.UTF8);

            // Act
            var result = await reader.ReadToEndAsync();

            // Assert
            // GetLargeStream() repeats a 5-byte ASCII payload 1000 times.
            Assert.Equal(5000, result.Length);
        }

        [Fact]
        public static async Task ReadToEndAsync_Reads_Asynchronously()
        {
            // Arrange
            // The wrapper throws on synchronous Read, so this test fails if the
            // reader falls back to a synchronous read of the underlying stream.
            var stream = new AsyncOnlyStreamWrapper(GetLargeStream());
            var reader = new HttpRequestStreamReader(stream, Encoding.UTF8);
            var streamReader = new StreamReader(GetLargeStream());
            string expected = await streamReader.ReadToEndAsync();

            // Act
            var actual = await reader.ReadToEndAsync();

            // Assert
            Assert.Equal(expected, actual);
        }

        [Fact]
        public static void TestRead()
        {
            // Arrange
            var reader = CreateReader();

            // Act & Assert
            for (var i = 0; i < CharData.Length; i++)
            {
                var tmp = reader.Read();
                Assert.Equal((int)CharData[i], tmp);
            }
        }

        [Fact]
        public static void TestPeek()
        {
            // Arrange
            var reader = CreateReader();

            // Act & Assert
            for (var i = 0; i < CharData.Length; i++)
            {
                var peek = reader.Peek();
                Assert.Equal((int)CharData[i], peek);

                // Advance so the next Peek observes the next character.
                reader.Read();
            }
        }

        [Fact]
        public static void EmptyStream()
        {
            // Arrange
            var reader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8);
            var buffer = new char[10];

            // Act
            var read = reader.Read(buffer, 0, 1);

            // Assert
            Assert.Equal(0, read);
        }

        [Fact]
        public static void Read_ReadAllCharactersAtOnce()
        {
            // Arrange
            var reader = CreateReader();
            var chars = new char[CharData.Length];

            // Act
            var read = reader.Read(chars, 0, chars.Length);

            // Assert
            Assert.Equal(chars.Length, read);
            for (var i = 0; i < CharData.Length; i++)
            {
                Assert.Equal(CharData[i], chars[i]);
            }
        }

        [Fact]
        public static async Task ReadAsync_ReadInTwoChunks()
        {
            // Arrange
            var reader = CreateReader();
            var chars = new char[CharData.Length];

            // Act
            // Read 3 characters into the destination starting at offset 4.
            var read = await reader.ReadAsync(chars, 4, 3);

            // Assert
            Assert.Equal(3, read);
            for (var i = 0; i < 3; i++)
            {
                Assert.Equal(CharData[i], chars[i + 4]);
            }
        }

        [Theory]
        [MemberData(nameof(ReadLineData))]
        public static async Task ReadLine_ReadMultipleLines(Func<HttpRequestStreamReader, Task<string?>> action)
        {
            // Arrange
            var reader = CreateReader();
            var valueString = new string(CharData);

            // Act & Assert
            // Line 1: everything before the first '\r'.
            var data = await action(reader);
            Assert.Equal(valueString.Substring(0, valueString.IndexOf('\r')), data);

            // Line 2: the three characters between the first '\r' and the "\r\n" pair.
            data = await action(reader);
            Assert.Equal(valueString.Substring(valueString.IndexOf('\r') + 1, 3), data);

            // Line 3: the two characters between the first '\n' and the last '\n'.
            data = await action(reader);
            Assert.Equal(valueString.Substring(valueString.IndexOf('\n') + 1, 2), data);

            // Line 4: the unterminated tail after the last '\n'.
            data = await action(reader);
            Assert.Equal((valueString.Substring(valueString.LastIndexOf('\n') + 1)), data);
        }

        [Theory]
        [MemberData(nameof(ReadLineData))]
        public static async Task ReadLine_ReadWithNoNewlines(Func<HttpRequestStreamReader, Task<string?>> action)
        {
            // Arrange
            var reader = CreateReader();
            var valueString = new string(CharData);
            var temp = new char[10];

            // Act
            // Consume one character first so the line read starts mid-buffer.
            reader.Read(temp, 0, 1);
            var data = await action(reader);

            // Assert
            Assert.Equal(valueString.Substring(1, valueString.IndexOf('\r') - 1), data);
        }

        [Theory]
        [MemberData(nameof(ReadLineData))]
        public static async Task ReadLine_MultipleContinuousLines(Func<HttpRequestStreamReader, Task<string?>> action)
        {
            // Arrange
            var stream = new MemoryStream();
            var writer = new StreamWriter(stream);
            writer.Write("\n\n\r\r\n\r");
            writer.Flush();
            stream.Position = 0;

            var reader = new HttpRequestStreamReader(stream, Encoding.UTF8);

            // Act & Assert
            // The input is five terminators ("\n", "\n", "\r", "\r\n", "\r"),
            // i.e. five empty lines followed by end of stream.
            for (var i = 0; i < 5; i++)
            {
                var data = await action(reader);
                Assert.Equal(string.Empty, data);
            }

            var eof = await action(reader);
            Assert.Null(eof);
        }

        [Theory]
        [MemberData(nameof(ReadLineData))]
        public static async Task ReadLine_CarriageReturnAndLineFeedAcrossBufferBundaries(Func<HttpRequestStreamReader, Task<string?>> action)
        {
            // Arrange
            var stream = new MemoryStream();
            var writer = new StreamWriter(stream);
            writer.Write("123456789\r\nfoo");
            writer.Flush();
            stream.Position = 0;

            // With a buffer size of 10, '\r' is the 10th character and '\n' the 11th,
            // so the "\r\n" pair is split across two buffer fills.
            var reader = new HttpRequestStreamReader(stream, Encoding.UTF8, 10);

            // Act & Assert
            var data = await action(reader);
            Assert.Equal("123456789", data);

            data = await action(reader);
            Assert.Equal("foo", data);

            var eof = await action(reader);
            Assert.Null(eof);
        }

        [Theory]
        [MemberData(nameof(ReadLineData))]
        public static async Task ReadLine_EOF(Func<HttpRequestStreamReader, Task<string?>> action)
        {
            // Arrange
            var stream = new MemoryStream();
            var reader = new HttpRequestStreamReader(stream, Encoding.UTF8);

            // Act & Assert
            var eof = await action(reader);
            Assert.Null(eof);
        }

        [Theory]
        [MemberData(nameof(ReadLineData))]
        public static async Task ReadLine_NewLineOnly(Func<HttpRequestStreamReader, Task<string?>> action)
        {
            // Arrange
            var stream = new MemoryStream();
            var writer = new StreamWriter(stream);
            writer.Write("\r\n");
            writer.Flush();
            stream.Position = 0;

            var reader = new HttpRequestStreamReader(stream, Encoding.UTF8);

            // Act & Assert
            var empty = await action(reader);
            Assert.Equal(string.Empty, empty);
        }

        [Fact]
        public static void Read_Span_ReadAllCharactersAtOnce()
        {
            // Arrange
            var reader = CreateReader();
            var chars = new char[CharData.Length];
            var span = new Span<char>(chars);

            // Act
            var read = reader.Read(span);

            // Assert
            Assert.Equal(chars.Length, read);
            for (var i = 0; i < CharData.Length; i++)
            {
                Assert.Equal(CharData[i], chars[i]);
            }
        }

        [Fact]
        public static void Read_Span_WithMoreDataThanInternalBufferSize()
        {
            // Arrange
            // Buffer size 10 is smaller than CharData, forcing more than one buffer fill.
            var reader = CreateReader(10);
            var chars = new char[CharData.Length];
            var span = new Span<char>(chars);

            // Act
            var read = reader.Read(span);

            // Assert
            Assert.Equal(chars.Length, read);
            for (var i = 0; i < CharData.Length; i++)
            {
                Assert.Equal(CharData[i], chars[i]);
            }
        }

        [Fact]
        public async static Task ReadAsync_Memory_ReadAllCharactersAtOnce()
        {
            // Arrange
            var reader = CreateReader();
            var chars = new char[CharData.Length];
            var memory = new Memory<char>(chars);

            // Act
            var read = await reader.ReadAsync(memory);

            // Assert
            Assert.Equal(chars.Length, read);
            for (var i = 0; i < CharData.Length; i++)
            {
                Assert.Equal(CharData[i], chars[i]);
            }
        }

        [Fact]
        public async static Task ReadAsync_Memory_WithMoreDataThanInternalBufferSize()
        {
            // Arrange
            // Buffer size 10 is smaller than CharData, forcing more than one buffer fill.
            var reader = CreateReader(10);
            var chars = new char[CharData.Length];
            var memory = new Memory<char>(chars);

            // Act
            var read = await reader.ReadAsync(memory);

            // Assert
            Assert.Equal(chars.Length, read);
            for (var i = 0; i < CharData.Length; i++)
            {
                Assert.Equal(CharData[i], chars[i]);
            }
        }

        [Theory]
        [MemberData(nameof(HttpRequestNullData))]
        public static void NullInputsInConstructor_ExpectArgumentNullException(Stream stream, Encoding encoding, ArrayPool<byte> bytePool, ArrayPool<char> charPool)
        {
            Assert.Throws<ArgumentNullException>(() =>
            {
                var httpRequestStreamReader = new HttpRequestStreamReader(stream, encoding, 1, bytePool, charPool);
            });
        }

        [Theory]
        [InlineData(0)]
        [InlineData(-1)]
        public static void NegativeOrZeroBufferSize_ExpectArgumentOutOfRangeException(int size)
        {
            Assert.Throws<ArgumentOutOfRangeException>(() =>
            {
                var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, size, ArrayPool<byte>.Shared, ArrayPool<char>.Shared);
            });
        }

        [Fact]
        public static void StreamCannotRead_ExpectArgumentException()
        {
            var mockStream = new Mock<Stream>();
            mockStream.Setup(m => m.CanRead).Returns(false);
            Assert.Throws<ArgumentException>(() =>
            {
                var httpRequestStreamReader = new HttpRequestStreamReader(mockStream.Object, Encoding.UTF8, 1, ArrayPool<byte>.Shared, ArrayPool<char>.Shared);
            });
        }

        [Theory]
        [MemberData(nameof(HttpRequestDisposeData))]
        public static void StreamDisposed_ExpectedObjectDisposedException(Action<HttpRequestStreamReader> action)
        {
            var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, 10, ArrayPool<byte>.Shared, ArrayPool<char>.Shared);
            httpRequestStreamReader.Dispose();

            Assert.Throws<ObjectDisposedException>(() =>
            {
                action(httpRequestStreamReader);
            });
        }

        [Theory]
        [MemberData(nameof(HttpRequestDisposeDataAsync))]
        public static async Task StreamDisposed_ExpectObjectDisposedExceptionAsync(Func<HttpRequestStreamReader, Task> action)
        {
            var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, 10, ArrayPool<byte>.Shared, ArrayPool<char>.Shared);
            httpRequestStreamReader.Dispose();

            await Assert.ThrowsAsync<ObjectDisposedException>(() => action(httpRequestStreamReader));
        }

        /// <summary>Creates a UTF-8 reader over a stream containing <see cref="CharData"/>.</summary>
        private static HttpRequestStreamReader CreateReader()
        {
            MemoryStream stream = CreateStream();
            return new HttpRequestStreamReader(stream, Encoding.UTF8);
        }

        /// <summary>
        /// Creates a UTF-8 reader over a stream containing <see cref="CharData"/>,
        /// passing <paramref name="bufferSize"/> as the reader's buffer size.
        /// </summary>
        private static HttpRequestStreamReader CreateReader(int bufferSize)
        {
            MemoryStream stream = CreateStream();
            return new HttpRequestStreamReader(stream, Encoding.UTF8, bufferSize);
        }

        /// <summary>
        /// Writes <see cref="CharData"/> into a new <see cref="MemoryStream"/> and rewinds it.
        /// </summary>
        private static MemoryStream CreateStream()
        {
            var stream = new MemoryStream();

            // The writer is intentionally not disposed: disposing it would close
            // the MemoryStream that is handed back to the caller.
            var writer = new StreamWriter(stream);
            writer.Write(CharData);
            writer.Flush();
            stream.Position = 0;
            return stream;
        }

        /// <summary>Returns a stream over the five ASCII bytes of "HELLO".</summary>
        private static MemoryStream GetSmallStream()
        {
            var testData = new byte[] { 72, 69, 76, 76, 79 };
            return new MemoryStream(testData);
        }

        /// <summary>
        /// Returns a 5000-byte stream: the five ASCII bytes of "HELLO" repeated 1000 times.
        /// </summary>
        private static MemoryStream GetLargeStream()
        {
            var testData = new byte[] { 72, 69, 76, 76, 79 };

            var data = new List<byte>();
            for (var i = 0; i < 1000; i++)
            {
                data.AddRange(testData);
            }

            return new MemoryStream(data.ToArray());
        }

        // Constructor argument sets (stream, encoding, bytePool, charPool), each with
        // exactly one null entry.
        public static IEnumerable<object?[]> HttpRequestNullData()
        {
            yield return new object?[] { null, Encoding.UTF8, ArrayPool<byte>.Shared, ArrayPool<char>.Shared };
            yield return new object?[] { new MemoryStream(), null, ArrayPool<byte>.Shared, ArrayPool<char>.Shared };
            yield return new object?[] { new MemoryStream(), Encoding.UTF8, null, ArrayPool<char>.Shared };
            yield return new object?[] { new MemoryStream(), Encoding.UTF8, ArrayPool<byte>.Shared, null };
        }

        // Synchronous read operations invoked against an already-disposed reader.
        public static IEnumerable<object[]> HttpRequestDisposeData()
        {
            yield return new object[] { new Action<HttpRequestStreamReader>((httpRequestStreamReader) =>
            {
                var res = httpRequestStreamReader.Read();
            })};
            yield return new object[] { new Action<HttpRequestStreamReader>((httpRequestStreamReader) =>
            {
                var res = httpRequestStreamReader.Read(new char[10], 0, 1);
            })};
            yield return new object[] { new Action<HttpRequestStreamReader>((httpRequestStreamReader) =>
            {
                var res = httpRequestStreamReader.Read(new Span<char>(new char[10], 0, 1));
            })};
            yield return new object[] { new Action<HttpRequestStreamReader>((httpRequestStreamReader) =>
            {
                var res = httpRequestStreamReader.Peek();
            })};
        }

        // Asynchronous read operations invoked against an already-disposed reader.
        public static IEnumerable<object[]> HttpRequestDisposeDataAsync()
        {
            yield return new object[] { new Func<HttpRequestStreamReader, Task>(async (httpRequestStreamReader) =>
            {
                await httpRequestStreamReader.ReadAsync(new char[10], 0, 1);
            })};
            yield return new object[] { new Func<HttpRequestStreamReader, Task>(async (httpRequestStreamReader) =>
            {
                await httpRequestStreamReader.ReadAsync(new Memory<char>(new char[10], 0, 1));
            })};
        }

        // Lets every ReadLine theory run once through ReadLine (wrapped in a completed
        // task) and once through ReadLineAsync, using a single delegate shape.
        public static IEnumerable<object[]> ReadLineData()
        {
            yield return new object[] { new Func<HttpRequestStreamReader, Task<string?>>((httpRequestStreamReader) =>
                 Task.FromResult(httpRequestStreamReader.ReadLine())
            )};
            yield return new object[] { new Func<HttpRequestStreamReader, Task<string?>>((httpRequestStreamReader) =>
                 httpRequestStreamReader.ReadLineAsync()
            )};
        }

        /// <summary>
        /// Stream decorator that forwards asynchronous operations to the wrapped stream
        /// and throws <see cref="InvalidOperationException"/> from the synchronous
        /// <see cref="Flush"/>, <see cref="Read(byte[], int, int)"/> and
        /// <see cref="Write(byte[], int, int)"/> overrides.
        /// </summary>
        private class AsyncOnlyStreamWrapper : Stream
        {
            private readonly Stream _inner;

            public AsyncOnlyStreamWrapper(Stream inner)
            {
                _inner = inner;
            }

            public override bool CanRead => _inner.CanRead;

            public override bool CanSeek => _inner.CanSeek;

            public override bool CanWrite => _inner.CanWrite;

            public override long Length => _inner.Length;

            public override long Position
            {
                get => _inner.Position;
                set => _inner.Position = value;
            }

            public override void Flush()
            {
                throw SyncOperationForbiddenException();
            }

            public override Task FlushAsync(CancellationToken cancellationToken)
            {
                return _inner.FlushAsync(cancellationToken);
            }

            public override int Read(byte[] buffer, int offset, int count)
            {
                throw SyncOperationForbiddenException();
            }

            public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                return _inner.ReadAsync(buffer, offset, count, cancellationToken);
            }

            public override long Seek(long offset, SeekOrigin origin)
            {
                return _inner.Seek(offset, origin);
            }

            public override void SetLength(long value)
            {
                _inner.SetLength(value);
            }

            public override void Write(byte[] buffer, int offset, int count)
            {
                throw SyncOperationForbiddenException();
            }

            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                return _inner.WriteAsync(buffer, offset, count, cancellationToken);
            }

            protected override void Dispose(bool disposing)
            {
                // Only touch the managed inner stream on an explicit dispose, then
                // let the base class finish its own cleanup.
                if (disposing)
                {
                    _inner.Dispose();
                }

                base.Dispose(disposing);
            }

            public override ValueTask DisposeAsync()
            {
                return _inner.DisposeAsync();
            }

            // Builds (does not throw) the exception used by every synchronous override.
            private static Exception SyncOperationForbiddenException()
            {
                return new InvalidOperationException("The stream cannot be accessed synchronously");
            }
        }
    }
}