From 057fc816fab1062700959cdaad2a32e9d85efca6 Mon Sep 17 00:00:00 2001 From: Justin Kotalik Date: Wed, 12 Jul 2017 17:59:10 -0700 Subject: [PATCH] Refactor to HttpRequest/Response Streams --- build/dependencies.props | 1 + .../HttpRequestStreamReader.cs | 75 ++++-------- .../HttpResponseStreamWriter.cs | 71 ++++++----- .../HttpRequestStreamReaderTest.cs | 87 ++++++++++++++ .../HttpResponseStreamWriterTest.cs | 110 ++++++++++++++++++ ...osoft.AspNetCore.WebUtilities.Tests.csproj | 1 + 6 files changed, 255 insertions(+), 90 deletions(-) diff --git a/build/dependencies.props b/build/dependencies.props index 3c71bd378d..4e451e9dc7 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -3,6 +3,7 @@ 2.1.0-* 4.4.0-* 2.1.1-* + 4.7.49 2.0.0-* 2.0.0-* 2.0.0-* diff --git a/src/Microsoft.AspNetCore.WebUtilities/HttpRequestStreamReader.cs b/src/Microsoft.AspNetCore.WebUtilities/HttpRequestStreamReader.cs index 24ee35936c..3f9478c5de 100644 --- a/src/Microsoft.AspNetCore.WebUtilities/HttpRequestStreamReader.cs +++ b/src/Microsoft.AspNetCore.WebUtilities/HttpRequestStreamReader.cs @@ -32,6 +32,7 @@ namespace Microsoft.AspNetCore.WebUtilities private int _bytesRead; private bool _isBlocked; + private bool _disposed; public HttpRequestStreamReader(Stream stream, Encoding encoding) : this(stream, encoding, DefaultBufferSize, ArrayPool.Shared, ArrayPool.Shared) @@ -50,44 +51,23 @@ namespace Microsoft.AspNetCore.WebUtilities ArrayPool bytePool, ArrayPool charPool) { - if (stream == null) - { - throw new ArgumentNullException(nameof(stream)); - } - - if (!stream.CanRead) - { - throw new ArgumentException(Resources.HttpRequestStreamReader_StreamNotReadable, nameof(stream)); - } - - if (encoding == null) - { - throw new ArgumentNullException(nameof(encoding)); - } - - if (bytePool == null) - { - throw new ArgumentNullException(nameof(bytePool)); - } - - if (charPool == null) - { - throw new ArgumentNullException(nameof(charPool)); - } + _stream = stream ?? throw new ArgumentNullException(nameof(stream)); + _encoding = encoding ?? throw new ArgumentNullException(nameof(encoding)); + _bytePool = bytePool ?? throw new ArgumentNullException(nameof(bytePool)); + _charPool = charPool ?? throw new ArgumentNullException(nameof(charPool)); if (bufferSize <= 0) { throw new ArgumentOutOfRangeException(nameof(bufferSize)); } + if (!stream.CanRead) + { + throw new ArgumentException(Resources.HttpRequestStreamReader_StreamNotReadable, nameof(stream)); + } - _stream = stream; - _encoding = encoding; _byteBufferSize = bufferSize; - _bytePool = bytePool; - _charPool = charPool; _decoder = encoding.GetDecoder(); - _byteBuffer = _bytePool.Rent(bufferSize); try @@ -98,33 +78,24 @@ namespace Microsoft.AspNetCore.WebUtilities catch { _bytePool.Return(_byteBuffer); - _byteBuffer = null; if (_charBuffer != null) { _charPool.Return(_charBuffer); - _charBuffer = null; } + + throw; } } protected override void Dispose(bool disposing) { - if (disposing && _stream != null) + if (disposing && !_disposed) { - _stream = null; + _disposed = true; - if (_bytePool != null) - { - _bytePool.Return(_byteBuffer); - _byteBuffer = null; - } - - if (_charPool != null) - { - _charPool.Return(_charBuffer); - _charBuffer = null; - } + _bytePool.Return(_byteBuffer); + _charPool.Return(_charBuffer); } base.Dispose(disposing); @@ -132,9 +103,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override int Peek() { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpRequestStreamReader)); } if (_charBufferIndex == _charsRead) @@ -150,9 +121,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override int Read() { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpRequestStreamReader)); } if (_charBufferIndex == _charsRead) @@ -183,9 +154,9 @@ namespace Microsoft.AspNetCore.WebUtilities throw new ArgumentOutOfRangeException(nameof(count)); } - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpRequestStreamReader)); } var charsRead = 0; @@ -246,9 +217,9 @@ namespace Microsoft.AspNetCore.WebUtilities throw new ArgumentOutOfRangeException(nameof(count)); } - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpRequestStreamReader)); } if (_charBufferIndex == _charsRead && await ReadIntoBufferAsync() == 0) diff --git a/src/Microsoft.AspNetCore.WebUtilities/HttpResponseStreamWriter.cs b/src/Microsoft.AspNetCore.WebUtilities/HttpResponseStreamWriter.cs index 4356693d58..050088ccb7 100644 --- a/src/Microsoft.AspNetCore.WebUtilities/HttpResponseStreamWriter.cs +++ b/src/Microsoft.AspNetCore.WebUtilities/HttpResponseStreamWriter.cs @@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.WebUtilities private char[] _charBuffer; private int _charBufferCount; + private bool _disposed; public HttpResponseStreamWriter(Stream stream, Encoding encoding) : this(stream, encoding, DefaultBufferSize, ArrayPool.Shared, ArrayPool.Shared) @@ -46,15 +47,19 @@ namespace Microsoft.AspNetCore.WebUtilities ArrayPool bytePool, ArrayPool charPool) { - if (!stream.CanWrite) - { - throw new ArgumentException(Resources.HttpResponseStreamWriter_StreamNotWritable, nameof(stream)); - } - + _stream = stream ?? throw new ArgumentNullException(nameof(stream)); Encoding = encoding ?? throw new ArgumentNullException(nameof(encoding)); _bytePool = bytePool ?? throw new ArgumentNullException(nameof(bytePool)); _charPool = charPool ?? throw new ArgumentNullException(nameof(charPool)); - _stream = stream ?? throw new ArgumentNullException(nameof(stream)); + + if (bufferSize <= 0) + { + throw new ArgumentOutOfRangeException(nameof(bufferSize)); + } + if (!_stream.CanWrite) + { + throw new ArgumentException(Resources.HttpResponseStreamWriter_StreamNotWritable, nameof(stream)); + } _charBufferSize = bufferSize; @@ -69,12 +74,10 @@ namespace Microsoft.AspNetCore.WebUtilities catch { charPool.Return(_charBuffer); - _charBuffer = null; if (_byteBuffer != null) { bytePool.Return(_byteBuffer); - _byteBuffer = null; } throw; @@ -85,9 +88,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override void Write(char value) { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } if (_charBufferCount == _charBufferSize) @@ -101,9 +104,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override void Write(char[] values, int index, int count) { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } if (values == null) @@ -124,9 +127,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override void Write(string value) { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } if (value == null) @@ -149,9 +152,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override async Task WriteAsync(char value) { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } if (_charBufferCount == _charBufferSize) @@ -165,9 +168,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override async Task WriteAsync(char[] values, int index, int count) { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } if (values == null) @@ -188,9 +191,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override async Task WriteAsync(string value) { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } if (value == null) @@ -216,9 +219,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override void Flush() { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } FlushInternal(flushEncoder: true); @@ -226,9 +229,9 @@ namespace Microsoft.AspNetCore.WebUtilities public override Task FlushAsync() { - if (_stream == null) + if (_disposed) { - throw new ObjectDisposedException("stream"); + throw new ObjectDisposedException(nameof(HttpResponseStreamWriter)); } return FlushInternalAsync(flushEncoder: true); @@ -236,29 +239,21 @@ namespace Microsoft.AspNetCore.WebUtilities protected override void Dispose(bool disposing) { - if (disposing && _stream != null) + if (disposing && !_disposed) { + _disposed = true; try { FlushInternal(flushEncoder: true); } finally { - _stream = null; - - if (_bytePool != null) - { - _bytePool.Return(_byteBuffer); - _byteBuffer = null; - } - - if (_charPool != null) - { - _charPool.Return(_charBuffer); - _charBuffer = null; - } + _bytePool.Return(_byteBuffer); + _charPool.Return(_charBuffer); } } + + base.Dispose(disposing); } // Note: our FlushInternal method does NOT flush the underlying stream. This would result in diff --git a/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpRequestStreamReaderTest.cs b/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpRequestStreamReaderTest.cs index 82163c2b54..ee4d2f2bdc 100644 --- a/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpRequestStreamReaderTest.cs +++ b/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpRequestStreamReaderTest.cs @@ -1,12 +1,16 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using Moq; +using System; +using System.Buffers; using System.Collections.Generic; using System.IO; using System.Text; using System.Threading.Tasks; using Xunit; + namespace Microsoft.AspNetCore.WebUtilities.Test { public class HttpResponseStreamReaderTest @@ -191,6 +195,64 @@ namespace Microsoft.AspNetCore.WebUtilities.Test Assert.Null(eol); } + [Theory] + [MemberData(nameof(HttpRequestNullData))] + public static void NullInputsInConstructor_ExpectArgumentNullException(Stream stream, Encoding encoding, ArrayPool bytePool, ArrayPool charPool) + { + Assert.Throws(() => + { + var httpRequestStreamReader = new HttpRequestStreamReader(stream, encoding, 1, bytePool, charPool); + }); + } + + + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public static void NegativeOrZeroBufferSize_ExpectArgumentOutOfRangeException(int size) + { + Assert.Throws(() => + { + var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, size, ArrayPool.Shared, ArrayPool.Shared); + }); + } + + [Fact] + public static void StreamCannotRead_ExpectArgumentException() + { + var mockStream = new Mock(); + mockStream.Setup(m => m.CanRead).Returns(false); + Assert.Throws(() => + { + var httpRequestStreamReader = new HttpRequestStreamReader(mockStream.Object, Encoding.UTF8, 1, ArrayPool.Shared, ArrayPool.Shared); + }); + } + + [Theory] + [MemberData(nameof(HttpRequestDisposeData))] + public static void StreamDisposed_ExpectedObjectDisposedException(Action action) + { + var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, 10, ArrayPool.Shared, ArrayPool.Shared); + httpRequestStreamReader.Dispose(); + + Assert.Throws(() => + { + action(httpRequestStreamReader); + }); + } + + [Fact] + public static async Task StreamDisposed_ExpectObjectDisposedExceptionAsync() + { + var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, 10, ArrayPool.Shared, ArrayPool.Shared); + httpRequestStreamReader.Dispose(); + + await Assert.ThrowsAsync(() => + { + return httpRequestStreamReader.ReadAsync(new char[10], 0, 1); + }); + } private static HttpRequestStreamReader CreateReader() { var stream = new MemoryStream(); @@ -221,5 +283,30 @@ namespace Microsoft.AspNetCore.WebUtilities.Test return new MemoryStream(data.ToArray()); } + private static IEnumerable HttpRequestNullData() + { + yield return new object[] { null, Encoding.UTF8, ArrayPool.Shared, ArrayPool.Shared }; + yield return new object[] { new MemoryStream(), null, ArrayPool.Shared, ArrayPool.Shared }; + yield return new object[] { new MemoryStream(), Encoding.UTF8, null, ArrayPool.Shared }; + yield return new object[] { new MemoryStream(), Encoding.UTF8, ArrayPool.Shared, null }; + } + + private static IEnumerable HttpRequestDisposeData() + { + yield return new object[] { new Action((httpRequestStreamReader) => + { + var res = httpRequestStreamReader.Read(); + })}; + yield return new object[] { new Action((httpRequestStreamReader) => + { + var res = httpRequestStreamReader.Read(new char[10], 0, 1); + })}; + + yield return new object[] { new Action((httpRequestStreamReader) => + { + var res = httpRequestStreamReader.Peek(); + })}; + + } } } diff --git a/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpResponseStreamWriterTest.cs b/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpResponseStreamWriterTest.cs index 10bd4bda3f..d0c94f9739 100644 --- a/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpResponseStreamWriterTest.cs +++ b/test/Microsoft.AspNetCore.WebUtilities.Tests/HttpResponseStreamWriterTest.cs @@ -1,8 +1,10 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using Moq; using System; using System.Buffers; +using System.Collections.Generic; using System.IO; using System.Text; using System.Threading; @@ -405,6 +407,65 @@ namespace Microsoft.AspNetCore.WebUtilities.Test Assert.Equal(content, actualContent); } + [Theory] + [MemberData(nameof(HttpResponseStreamWriterData))] + public static void NullInputsInConstructor_ExpectArgumentNullException(Stream stream, Encoding encoding, ArrayPool bytePool, ArrayPool charPool) + { + Assert.Throws(() => + { + var httpRequestStreamReader = new HttpResponseStreamWriter(stream, encoding, 1, bytePool, charPool); + }); + } + + [Theory] + [InlineData(0)] + [InlineData(-1)] + public static void NegativeOrZeroBufferSize_ExpectArgumentOutOfRangeException(int size) + { + Assert.Throws(() => + { + var httpRequestStreamReader = new HttpRequestStreamReader(new MemoryStream(), Encoding.UTF8, size, ArrayPool.Shared, ArrayPool.Shared); + }); + } + + [Fact] + public static void StreamCannotRead_ExpectArgumentException() + { + var mockStream = new Mock(); + mockStream.Setup(m => m.CanWrite).Returns(false); + Assert.Throws(() => + { + var httpRequestStreamReader = new HttpRequestStreamReader(mockStream.Object, Encoding.UTF8, 1, ArrayPool.Shared, ArrayPool.Shared); + }); + } + + [Theory] + [MemberData(nameof(HttpResponseDisposeData))] + public static void StreamDisposed_ExpectedObjectDisposedException(Action action) + { + var httpResponseStreamWriter = new HttpResponseStreamWriter(new MemoryStream(), Encoding.UTF8, 10, ArrayPool.Shared, ArrayPool.Shared); + httpResponseStreamWriter.Dispose(); + + Assert.Throws(() => + { + action(httpResponseStreamWriter); + }); + } + + [Theory] + [MemberData(nameof(HttpResponseDisposeDataAsync))] + public static async Task StreamDisposed_ExpectedObjectDisposedExceptionAsync(Func function) + { + var httpResponseStreamWriter = new HttpResponseStreamWriter(new MemoryStream(), Encoding.UTF8, 10, ArrayPool.Shared, ArrayPool.Shared); + httpResponseStreamWriter.Dispose(); + + await Assert.ThrowsAsync(() => + { + return function(httpResponseStreamWriter); + }); + } + + private class TestMemoryStream : MemoryStream { public int FlushCallCount { get; private set; } @@ -459,5 +520,54 @@ namespace Microsoft.AspNetCore.WebUtilities.Test base.Dispose(disposing); } } + + private static IEnumerable HttpResponseStreamWriterData() + { + yield return new object[] { null, Encoding.UTF8, ArrayPool.Shared, ArrayPool.Shared }; + yield return new object[] { new MemoryStream(), null, ArrayPool.Shared, ArrayPool.Shared }; + yield return new object[] { new MemoryStream(), Encoding.UTF8, null, ArrayPool.Shared }; + yield return new object[] { new MemoryStream(), Encoding.UTF8, ArrayPool.Shared, null }; + } + + private static IEnumerable HttpResponseDisposeData() + { + yield return new object[] { new Action((httpResponseStreamWriter) => + { + httpResponseStreamWriter.Write('a'); + })}; + yield return new object[] { new Action((httpResponseStreamWriter) => + { + httpResponseStreamWriter.Write(new char[] { 'a', 'b' }, 0, 1); + })}; + + yield return new object[] { new Action((httpResponseStreamWriter) => + { + httpResponseStreamWriter.Write("hello"); + })}; + yield return new object[] { new Action((httpResponseStreamWriter) => + { + httpResponseStreamWriter.Flush(); + })}; + } + private static IEnumerable HttpResponseDisposeDataAsync() + { + yield return new object[] { new Func(async (httpResponseStreamWriter) => + { + await httpResponseStreamWriter.WriteAsync('a'); + })}; + yield return new object[] { new Func(async (httpResponseStreamWriter) => + { + await httpResponseStreamWriter.WriteAsync(new char[] { 'a', 'b' }, 0, 1); + })}; + + yield return new object[] { new Func(async (httpResponseStreamWriter) => + { + await httpResponseStreamWriter.WriteAsync("hello"); + })}; + yield return new object[] { new Func(async (httpResponseStreamWriter) => + { + await httpResponseStreamWriter.FlushAsync(); + })}; + } } } diff --git a/test/Microsoft.AspNetCore.WebUtilities.Tests/Microsoft.AspNetCore.WebUtilities.Tests.csproj b/test/Microsoft.AspNetCore.WebUtilities.Tests/Microsoft.AspNetCore.WebUtilities.Tests.csproj index c37aac336e..66a940aed2 100644 --- a/test/Microsoft.AspNetCore.WebUtilities.Tests/Microsoft.AspNetCore.WebUtilities.Tests.csproj +++ b/test/Microsoft.AspNetCore.WebUtilities.Tests/Microsoft.AspNetCore.WebUtilities.Tests.csproj @@ -13,6 +13,7 @@ +