diff --git a/src/Http/WebUtilities/ref/Microsoft.AspNetCore.WebUtilities.netcoreapp3.0.cs b/src/Http/WebUtilities/ref/Microsoft.AspNetCore.WebUtilities.netcoreapp3.0.cs index 45d806e297..afdad71011 100644 --- a/src/Http/WebUtilities/ref/Microsoft.AspNetCore.WebUtilities.netcoreapp3.0.cs +++ b/src/Http/WebUtilities/ref/Microsoft.AspNetCore.WebUtilities.netcoreapp3.0.cs @@ -41,6 +41,7 @@ namespace Microsoft.AspNetCore.WebUtilities } public partial class FileBufferingReadStream : System.IO.Stream { + public FileBufferingReadStream(System.IO.Stream inner, int memoryThreshold) { } public FileBufferingReadStream(System.IO.Stream inner, int memoryThreshold, long? bufferLimit, System.Func tempFileDirectoryAccessor) { } public FileBufferingReadStream(System.IO.Stream inner, int memoryThreshold, long? bufferLimit, System.Func tempFileDirectoryAccessor, System.Buffers.ArrayPool bytePool) { } public FileBufferingReadStream(System.IO.Stream inner, int memoryThreshold, long? bufferLimit, string tempFileDirectory) { } diff --git a/src/Http/WebUtilities/src/FileBufferingReadStream.cs b/src/Http/WebUtilities/src/FileBufferingReadStream.cs index 9dd1fbf13f..3aa1fef6b9 100644 --- a/src/Http/WebUtilities/src/FileBufferingReadStream.cs +++ b/src/Http/WebUtilities/src/FileBufferingReadStream.cs @@ -7,6 +7,7 @@ using System.Diagnostics; using System.IO; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Internal; namespace Microsoft.AspNetCore.WebUtilities { @@ -33,6 +34,16 @@ namespace Microsoft.AspNetCore.WebUtilities private bool _disposed; + /// + /// Initializes a new instance of . + /// + /// The wrapping . + /// The maximum size to buffer in memory. + public FileBufferingReadStream(Stream inner, int memoryThreshold) + : this(inner, memoryThreshold, bufferLimit: null, tempFileDirectoryAccessor: AspNetCoreTempDirectory.TempDirectoryFactory) + { + } + public FileBufferingReadStream( Stream inner, int memoryThreshold, @@ -223,13 +234,19 @@ namespace Microsoft.AspNetCore.WebUtilities { oldBuffer.Position = 0; var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize)); - var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); - while (copyRead > 0) + try { - _buffer.Write(rentedBuffer, 0, copyRead); - copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + while (copyRead > 0) + { + _buffer.Write(rentedBuffer, 0, copyRead); + copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + } + } + finally + { + _bytePool.Return(rentedBuffer); } - _bytePool.Return(rentedBuffer); } else { @@ -277,14 +294,20 @@ namespace Microsoft.AspNetCore.WebUtilities { oldBuffer.Position = 0; var rentedBuffer = _bytePool.Rent(Math.Min((int)oldBuffer.Length, _maxRentedBufferSize)); - // oldBuffer is a MemoryStream, no need to do async reads. - var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); - while (copyRead > 0) + try { - await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken); - copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + // oldBuffer is a MemoryStream, no need to do async reads. + var copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + while (copyRead > 0) + { + await _buffer.WriteAsync(rentedBuffer, 0, copyRead, cancellationToken); + copyRead = oldBuffer.Read(rentedBuffer, 0, rentedBuffer.Length); + } + } + finally + { + _bytePool.Return(rentedBuffer); } - _bytePool.Return(rentedBuffer); } else { @@ -351,4 +374,4 @@ namespace Microsoft.AspNetCore.WebUtilities } } } -} \ No newline at end of file +} diff --git a/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs b/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs index a83f1574eb..a220635632 100644 --- a/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs +++ b/src/Http/WebUtilities/test/FileBufferingReadStreamTests.cs @@ -2,9 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.IO; using System.Text; using System.Threading.Tasks; +using Moq; using Xunit; namespace Microsoft.AspNetCore.WebUtilities @@ -291,9 +293,67 @@ namespace Microsoft.AspNetCore.WebUtilities Assert.False(File.Exists(tempFileName)); } + [Fact] + public void FileBufferingReadStream_UsingMemoryStream_RentsAndReturnsRentedBuffer_WhenCopyingFromMemoryStreamDuringRead() + { + var inner = MakeStream(1024 * 1024 + 25); + string tempFileName; + var arrayPool = new Mock>(); + arrayPool.Setup(p => p.Rent(It.IsAny())) + .Returns((int m) => ArrayPool.Shared.Rent(m)); + arrayPool.Setup(p => p.Return(It.IsAny(), It.IsAny())) + .Callback((byte[] bytes, bool clear) => ArrayPool.Shared.Return(bytes, clear)); + + using (var stream = new FileBufferingReadStream(inner, 1024 * 1024 + 1, 2 * 1024 * 1024, GetCurrentDirectory(), arrayPool.Object)) + { + arrayPool.Verify(v => v.Rent(It.IsAny()), Times.Never()); + + stream.Read(new byte[1024 * 1024]); + Assert.False(File.Exists(stream.TempFileName), "tempFile should not be created as yet"); + + stream.Read(new byte[4]); + Assert.True(File.Exists(stream.TempFileName), "tempFile should be created"); + tempFileName = stream.TempFileName; + + arrayPool.Verify(v => v.Rent(It.IsAny()), Times.Once()); + arrayPool.Verify(v => v.Return(It.IsAny(), It.IsAny()), Times.Once()); + } + + Assert.False(File.Exists(tempFileName)); + } + + [Fact] + public async Task FileBufferingReadStream_UsingMemoryStream_RentsAndReturnsRentedBuffer_WhenCopyingFromMemoryStreamDuringReadAsync() + { + var inner = MakeStream(1024 * 1024 + 25); + string tempFileName; + var arrayPool = new Mock>(); + arrayPool.Setup(p => p.Rent(It.IsAny())) + .Returns((int m) => ArrayPool.Shared.Rent(m)); + arrayPool.Setup(p => p.Return(It.IsAny(), It.IsAny())) + .Callback((byte[] bytes, bool clear) => ArrayPool.Shared.Return(bytes, clear)); + + using (var stream = new FileBufferingReadStream(inner, 1024 * 1024 + 1, 2 * 1024 * 1024, GetCurrentDirectory(), arrayPool.Object)) + { + arrayPool.Verify(v => v.Rent(It.IsAny()), Times.Never()); + + await stream.ReadAsync(new byte[1024 * 1024]); + Assert.False(File.Exists(stream.TempFileName), "tempFile should not be created as yet"); + + await stream.ReadAsync(new byte[4]); + Assert.True(File.Exists(stream.TempFileName), "tempFile should be created"); + tempFileName = stream.TempFileName; + + arrayPool.Verify(v => v.Rent(It.IsAny()), Times.Once()); + arrayPool.Verify(v => v.Return(It.IsAny(), It.IsAny()), Times.Once()); + } + + Assert.False(File.Exists(tempFileName)); + } + private static string GetCurrentDirectory() { return AppContext.BaseDirectory; } } -} \ No newline at end of file +} diff --git a/src/Mvc/Mvc.Core/src/MvcOptions.cs b/src/Mvc/Mvc.Core/src/MvcOptions.cs index 3fa005f96c..f9e13f6317 100644 --- a/src/Mvc/Mvc.Core/src/MvcOptions.cs +++ b/src/Mvc/Mvc.Core/src/MvcOptions.cs @@ -103,7 +103,8 @@ namespace Microsoft.AspNetCore.Mvc public FormatterCollection InputFormatters { get; } /// - /// Gets or sets the flag to buffer the request body in input formatters. Default is false. + /// Gets or sets a value that determines if buffering is disabled for input formatters that + /// synchronously read from the HTTP request body. /// public bool SuppressInputFormatterBuffering { get; set; } diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerInputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerInputFormatter.cs index a849abd151..69f73e688a 100644 --- a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerInputFormatter.cs +++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerInputFormatter.cs @@ -4,14 +4,12 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics; using System.IO; using System.Runtime.Serialization; using System.Text; using System.Threading; using System.Threading.Tasks; using System.Xml; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc.Formatters.Xml; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.WebUtilities; @@ -24,6 +22,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters /// public class XmlDataContractSerializerInputFormatter : TextInputFormatter, IInputFormatterExceptionPolicy { + private const int DefaultMemoryThreshold = 1024 * 30; private readonly ConcurrentDictionary _serializerCache = new ConcurrentDictionary(); private readonly XmlDictionaryReaderQuotas _readerQuotas = FormattingUtilities.GetDefaultXmlReaderQuotas(); private readonly MvcOptions _options; @@ -118,43 +117,55 @@ namespace Microsoft.AspNetCore.Mvc.Formatters } var request = context.HttpContext.Request; + Stream readStream = new NonDisposableStream(request.Body); if (!request.Body.CanSeek && !_options.SuppressInputFormatterBuffering) { // XmlDataContractSerializer does synchronous reads. In order to avoid blocking on the stream, we asynchronously // read everything into a buffer, and then seek back to the beginning. - request.EnableBuffering(); - Debug.Assert(request.Body.CanSeek); + var memoryThreshold = DefaultMemoryThreshold; + if (request.ContentLength.HasValue && request.ContentLength.Value > 0 && request.ContentLength.Value < memoryThreshold) + { + // If the Content-Length is known and is smaller than the default buffer size, use it. + memoryThreshold = (int)request.ContentLength.Value; + } - await request.Body.DrainAsync(CancellationToken.None); - request.Body.Seek(0L, SeekOrigin.Begin); + readStream = new FileBufferingReadStream(request.Body, memoryThreshold); + + await readStream.DrainAsync(CancellationToken.None); + readStream.Seek(0L, SeekOrigin.Begin); } try { - using (var xmlReader = CreateXmlReader(new NonDisposableStream(request.Body), encoding)) + using var xmlReader = CreateXmlReader(readStream, encoding); + var type = GetSerializableType(context.ModelType); + var serializer = GetCachedSerializer(type); + + var deserializedObject = serializer.ReadObject(xmlReader); + + // Unwrap only if the original type was wrapped. + if (type != context.ModelType) { - var type = GetSerializableType(context.ModelType); - var serializer = GetCachedSerializer(type); - - var deserializedObject = serializer.ReadObject(xmlReader); - - // Unwrap only if the original type was wrapped. - if (type != context.ModelType) + if (deserializedObject is IUnwrappable unwrappable) { - if (deserializedObject is IUnwrappable unwrappable) - { - deserializedObject = unwrappable.Unwrap(declaredType: context.ModelType); - } + deserializedObject = unwrappable.Unwrap(declaredType: context.ModelType); } - - return InputFormatterResult.Success(deserializedObject); } + + return InputFormatterResult.Success(deserializedObject); } catch (SerializationException exception) { throw new InputFormatterException(Resources.ErrorDeserializingInputData, exception); } + finally + { + if (readStream is FileBufferingReadStream fileBufferingReadStream) + { + fileBufferingReadStream.Dispose(); + } + } } /// diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerInputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerInputFormatter.cs index 3dad1a7305..6bc52920b1 100644 --- a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerInputFormatter.cs +++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerInputFormatter.cs @@ -4,14 +4,12 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics; using System.IO; using System.Text; using System.Threading; using System.Threading.Tasks; using System.Xml; using System.Xml.Serialization; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc.Formatters.Xml; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.AspNetCore.WebUtilities; @@ -24,6 +22,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters /// public class XmlSerializerInputFormatter : TextInputFormatter, IInputFormatterExceptionPolicy { + private const int DefaultMemoryThreshold = 1024 * 30; private readonly ConcurrentDictionary _serializerCache = new ConcurrentDictionary(); private readonly XmlDictionaryReaderQuotas _readerQuotas = FormattingUtilities.GetDefaultXmlReaderQuotas(); private readonly MvcOptions _options; @@ -99,39 +98,43 @@ namespace Microsoft.AspNetCore.Mvc.Formatters } var request = context.HttpContext.Request; - + Stream readStream = new NonDisposableStream(request.Body); if (!request.Body.CanSeek && !_options.SuppressInputFormatterBuffering) { // XmlSerializer does synchronous reads. In order to avoid blocking on the stream, we asynchronously // read everything into a buffer, and then seek back to the beginning. - request.EnableBuffering(); - Debug.Assert(request.Body.CanSeek); + var memoryThreshold = DefaultMemoryThreshold; + if (request.ContentLength.HasValue && request.ContentLength.Value > 0 && request.ContentLength.Value < memoryThreshold) + { + // If the Content-Length is known and is smaller than the default buffer size, use it. + memoryThreshold = (int)request.ContentLength.Value; + } - await request.Body.DrainAsync(CancellationToken.None); - request.Body.Seek(0L, SeekOrigin.Begin); + readStream = new FileBufferingReadStream(request.Body, memoryThreshold); + + await readStream.DrainAsync(CancellationToken.None); + readStream.Seek(0L, SeekOrigin.Begin); } try { - using (var xmlReader = CreateXmlReader(new NonDisposableStream(request.Body), encoding)) + using var xmlReader = CreateXmlReader(readStream, encoding); + var type = GetSerializableType(context.ModelType); + + var serializer = GetCachedSerializer(type); + + var deserializedObject = serializer.Deserialize(xmlReader); + + // Unwrap only if the original type was wrapped. + if (type != context.ModelType) { - var type = GetSerializableType(context.ModelType); - - var serializer = GetCachedSerializer(type); - - var deserializedObject = serializer.Deserialize(xmlReader); - - // Unwrap only if the original type was wrapped. - if (type != context.ModelType) + if (deserializedObject is IUnwrappable unwrappable) { - if (deserializedObject is IUnwrappable unwrappable) - { - deserializedObject = unwrappable.Unwrap(declaredType: context.ModelType); - } + deserializedObject = unwrappable.Unwrap(declaredType: context.ModelType); } - - return InputFormatterResult.Success(deserializedObject); } + + return InputFormatterResult.Success(deserializedObject); } // XmlSerializer wraps actual exceptions (like FormatException or XmlException) into an InvalidOperationException // https://github.com/dotnet/corefx/blob/master/src/System.Private.Xml/src/System/Xml/Serialization/XmlSerializer.cs#L652 @@ -149,6 +152,13 @@ namespace Microsoft.AspNetCore.Mvc.Formatters { throw new InputFormatterException(Resources.ErrorDeserializingInputData, exception.InnerException); } + finally + { + if (readStream is FileBufferingReadStream fileBufferingReadStream) + { + fileBufferingReadStream.Dispose(); + } + } } /// diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerInputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerInputFormatterTest.cs index b8359803fc..50ddb255a4 100644 --- a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerInputFormatterTest.cs +++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerInputFormatterTest.cs @@ -6,6 +6,7 @@ using System.IO; using System.Linq; using System.Runtime.Serialization; using System.Text; +using System.Threading; using System.Threading.Tasks; using System.Xml; using Microsoft.AspNetCore.Http; @@ -149,7 +150,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml var contentBytes = Encoding.UTF8.GetBytes(input); var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseFeature()); - httpContext.Request.Body = new NonSeekableReadStream(contentBytes); + httpContext.Request.Body = new NonSeekableReadStream(contentBytes, allowSyncReads: true); httpContext.Request.ContentType = "application/json"; var context = GetInputFormatterContext(httpContext, typeof(TestLevelOne)); @@ -163,19 +164,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.Equal(expectedInt, model.SampleInt); Assert.Equal(expectedString, model.sampleString); - - Assert.True(httpContext.Request.Body.CanSeek); - httpContext.Request.Body.Seek(0L, SeekOrigin.Begin); - - result = await formatter.ReadAsync(context); - - // Assert - Assert.NotNull(result); - Assert.False(result.HasError); - model = Assert.IsType(result.Model); - - Assert.Equal(expectedInt, model.SampleInt); - Assert.Equal(expectedString, model.sampleString); } [Fact] @@ -227,8 +215,9 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml var formatter = new XmlDataContractSerializerInputFormatter(new MvcOptions()); var contentBytes = Encoding.UTF8.GetBytes(input); var httpContext = new DefaultHttpContext(); + httpContext.Features.Set(new TestResponseFeature()); - httpContext.Request.Body = new NonSeekableReadStream(contentBytes); + httpContext.Request.Body = new NonSeekableReadStream(contentBytes, allowSyncReads: false); httpContext.Request.ContentType = "application/json"; var context = GetInputFormatterContext(httpContext, typeof(TestLevelOne)); @@ -242,19 +231,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.Equal(expectedInt, model.SampleInt); Assert.Equal(expectedString, model.sampleString); - - Assert.True(httpContext.Request.Body.CanSeek); - httpContext.Request.Body.Seek(0L, SeekOrigin.Begin); - - result = await formatter.ReadAsync(context); - - // Assert - Assert.NotNull(result); - Assert.False(result.HasError); - model = Assert.IsType(result.Model); - - Assert.Equal(expectedInt, model.SampleInt); - Assert.Equal(expectedString, model.sampleString); } [Fact] @@ -287,9 +263,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.Equal(expectedInt, model.SampleInt); Assert.Equal(expectedString, model.sampleString); - - // Reading again should fail as buffering request body is disabled - await Assert.ThrowsAsync(() => formatter.ReadAsync(context)); } [Fact] diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerInputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerInputFormatterTest.cs index 172b49ce67..e37ee32ca3 100644 --- a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerInputFormatterTest.cs +++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerInputFormatterTest.cs @@ -57,7 +57,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml var contentBytes = Encoding.UTF8.GetBytes(input); var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseFeature()); - httpContext.Request.Body = new NonSeekableReadStream(contentBytes); + httpContext.Request.Body = new NonSeekableReadStream(contentBytes, allowSyncReads: true); httpContext.Request.ContentType = "application/json"; var context = GetInputFormatterContext(httpContext, typeof(TestLevelOne)); @@ -69,22 +69,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.False(result.HasError); var model = Assert.IsType(result.Model); - Assert.Equal(expectedInt, model.SampleInt); - Assert.Equal(expectedString, model.sampleString); - Assert.Equal( - XmlConvert.ToDateTime(expectedDateTime, XmlDateTimeSerializationMode.Utc), - model.SampleDate); - - Assert.True(httpContext.Request.Body.CanSeek); - httpContext.Request.Body.Seek(0L, SeekOrigin.Begin); - - result = await formatter.ReadAsync(context); - - // Assert - Assert.NotNull(result); - Assert.False(result.HasError); - model = Assert.IsType(result.Model); - Assert.Equal(expectedInt, model.SampleInt); Assert.Equal(expectedString, model.sampleString); Assert.Equal( @@ -127,9 +111,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.Equal( XmlConvert.ToDateTime(expectedDateTime, XmlDateTimeSerializationMode.Utc), model.SampleDate); - - // Reading again should fail as buffering request body is disabled - await Assert.ThrowsAsync(() => formatter.ReadAsync(context)); } [Fact] @@ -149,7 +130,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml var contentBytes = Encoding.UTF8.GetBytes(input); var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseFeature()); - httpContext.Request.Body = new NonSeekableReadStream(contentBytes); + httpContext.Request.Body = new NonSeekableReadStream(contentBytes, allowSyncReads: false); httpContext.Request.ContentType = "application/json"; var context = GetInputFormatterContext(httpContext, typeof(TestLevelOne)); @@ -161,22 +142,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.False(result.HasError); var model = Assert.IsType(result.Model); - Assert.Equal(expectedInt, model.SampleInt); - Assert.Equal(expectedString, model.sampleString); - Assert.Equal( - XmlConvert.ToDateTime(expectedDateTime, XmlDateTimeSerializationMode.Utc), - model.SampleDate); - - Assert.True(httpContext.Request.Body.CanSeek); - httpContext.Request.Body.Seek(0L, SeekOrigin.Begin); - - result = await formatter.ReadAsync(context); - - // Assert - Assert.NotNull(result); - Assert.False(result.HasError); - model = Assert.IsType(result.Model); - Assert.Equal(expectedInt, model.SampleInt); Assert.Equal(expectedString, model.sampleString); Assert.Equal( diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonInputFormatter.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonInputFormatter.cs index 254ebcb7b3..8275dbc542 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonInputFormatter.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonInputFormatter.cs @@ -3,13 +3,11 @@ using System; using System.Buffers; -using System.Diagnostics; using System.IO; using System.Runtime.ExceptionServices; using System.Text; using System.Threading; using System.Threading.Tasks; -using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc.ModelBinding; using Microsoft.AspNetCore.Mvc.NewtonsoftJson; using Microsoft.AspNetCore.WebUtilities; @@ -24,6 +22,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters /// public class NewtonsoftJsonInputFormatter : TextInputFormatter, IInputFormatterExceptionPolicy { + private const int DefaultMemoryThreshold = 1024 * 30; private readonly IArrayPool _charPool; private readonly ILogger _logger; private readonly ObjectPoolProvider _objectPoolProvider; @@ -129,117 +128,128 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var suppressInputFormatterBuffering = _options.SuppressInputFormatterBuffering; + var readStream = request.Body; if (!request.Body.CanSeek && !suppressInputFormatterBuffering) { // JSON.Net does synchronous reads. In order to avoid blocking on the stream, we asynchronously // read everything into a buffer, and then seek back to the beginning. - request.EnableBuffering(); - Debug.Assert(request.Body.CanSeek); + var memoryThreshold = DefaultMemoryThreshold; + if (request.ContentLength.HasValue && request.ContentLength.Value > 0 && request.ContentLength.Value < memoryThreshold) + { + // If the Content-Length is known and is smaller than the default buffer size, use it. + memoryThreshold = (int)request.ContentLength.Value; + } - await request.Body.DrainAsync(CancellationToken.None); - request.Body.Seek(0L, SeekOrigin.Begin); + readStream = new FileBufferingReadStream(request.Body, memoryThreshold); + + await readStream.DrainAsync(CancellationToken.None); + readStream.Seek(0L, SeekOrigin.Begin); } - using (var streamReader = context.ReaderFactory(request.Body, encoding)) + var successful = true; + Exception exception = null; + object model; + + using (var streamReader = context.ReaderFactory(readStream, encoding)) { - using (var jsonReader = new JsonTextReader(streamReader)) + using var jsonReader = new JsonTextReader(streamReader); + jsonReader.ArrayPool = _charPool; + jsonReader.CloseInput = false; + + var type = context.ModelType; + var jsonSerializer = CreateJsonSerializer(context); + jsonSerializer.Error += ErrorHandler; + try { - jsonReader.ArrayPool = _charPool; - jsonReader.CloseInput = false; + model = jsonSerializer.Deserialize(jsonReader, type); + } + finally + { + // Clean up the error handler since CreateJsonSerializer() pools instances. + jsonSerializer.Error -= ErrorHandler; + ReleaseJsonSerializer(jsonSerializer); - var successful = true; - Exception exception = null; - void ErrorHandler(object sender, Newtonsoft.Json.Serialization.ErrorEventArgs eventArgs) + if (readStream is FileBufferingReadStream fileBufferingReadStream) { - successful = false; - - // When ErrorContext.Path does not include ErrorContext.Member, add Member to form full path. - var path = eventArgs.ErrorContext.Path; - var member = eventArgs.ErrorContext.Member?.ToString(); - var addMember = !string.IsNullOrEmpty(member); - if (addMember) - { - // Path.Member case (path.Length < member.Length) needs no further checks. - if (path.Length == member.Length) - { - // Add Member in Path.Memb case but not for Path.Path. - addMember = !string.Equals(path, member, StringComparison.Ordinal); - } - else if (path.Length > member.Length) - { - // Finally, check whether Path already ends with Member. - if (member[0] == '[') - { - addMember = !path.EndsWith(member, StringComparison.Ordinal); - } - else - { - addMember = !path.EndsWith("." + member, StringComparison.Ordinal); - } - } - } - - if (addMember) - { - path = ModelNames.CreatePropertyModelName(path, member); - } - - // Handle path combinations such as ""+"Property", "Parent"+"Property", or "Parent"+"[12]". - var key = ModelNames.CreatePropertyModelName(context.ModelName, path); - - exception = eventArgs.ErrorContext.Error; - - var metadata = GetPathMetadata(context.Metadata, path); - var modelStateException = WrapExceptionForModelState(exception); - context.ModelState.TryAddModelError(key, modelStateException, metadata); - - _logger.JsonInputException(exception); - - // Error must always be marked as handled - // Failure to do so can cause the exception to be rethrown at every recursive level and - // overflow the stack for x64 CLR processes - eventArgs.ErrorContext.Handled = true; + fileBufferingReadStream.Dispose(); } + } + } - var type = context.ModelType; - var jsonSerializer = CreateJsonSerializer(context); - jsonSerializer.Error += ErrorHandler; - object model; - try - { - model = jsonSerializer.Deserialize(jsonReader, type); - } - finally - { - // Clean up the error handler since CreateJsonSerializer() pools instances. - jsonSerializer.Error -= ErrorHandler; - ReleaseJsonSerializer(jsonSerializer); - } + if (successful) + { + if (model == null && !context.TreatEmptyInputAsDefaultValue) + { + // Some nonempty inputs might deserialize as null, for example whitespace, + // or the JSON-encoded value "null". The upstream BodyModelBinder needs to + // be notified that we don't regard this as a real input so it can register + // a model binding error. + return InputFormatterResult.NoValue(); + } + else + { + return InputFormatterResult.Success(model); + } + } - if (successful) + if (!(exception is JsonException || exception is OverflowException)) + { + var exceptionDispatchInfo = ExceptionDispatchInfo.Capture(exception); + exceptionDispatchInfo.Throw(); + } + + return InputFormatterResult.Failure(); + + void ErrorHandler(object sender, Newtonsoft.Json.Serialization.ErrorEventArgs eventArgs) + { + successful = false; + + // When ErrorContext.Path does not include ErrorContext.Member, add Member to form full path. + var path = eventArgs.ErrorContext.Path; + var member = eventArgs.ErrorContext.Member?.ToString(); + var addMember = !string.IsNullOrEmpty(member); + if (addMember) + { + // Path.Member case (path.Length < member.Length) needs no further checks. + if (path.Length == member.Length) { - if (model == null && !context.TreatEmptyInputAsDefaultValue) + // Add Member in Path.Memb case but not for Path.Path. + addMember = !string.Equals(path, member, StringComparison.Ordinal); + } + else if (path.Length > member.Length) + { + // Finally, check whether Path already ends with Member. + if (member[0] == '[') { - // Some nonempty inputs might deserialize as null, for example whitespace, - // or the JSON-encoded value "null". The upstream BodyModelBinder needs to - // be notified that we don't regard this as a real input so it can register - // a model binding error. - return InputFormatterResult.NoValue(); + addMember = !path.EndsWith(member, StringComparison.Ordinal); } else { - return InputFormatterResult.Success(model); + addMember = !path.EndsWith("." + member, StringComparison.Ordinal); } } - - if (!(exception is JsonException || exception is OverflowException)) - { - var exceptionDispatchInfo = ExceptionDispatchInfo.Capture(exception); - exceptionDispatchInfo.Throw(); - } - - return InputFormatterResult.Failure(); } + + if (addMember) + { + path = ModelNames.CreatePropertyModelName(path, member); + } + + // Handle path combinations such as ""+"Property", "Parent"+"Property", or "Parent"+"[12]". + var key = ModelNames.CreatePropertyModelName(context.ModelName, path); + + exception = eventArgs.ErrorContext.Error; + + var metadata = GetPathMetadata(context.Metadata, path); + var modelStateException = WrapExceptionForModelState(exception); + context.ModelState.TryAddModelError(key, modelStateException, metadata); + + _logger.JsonInputException(exception); + + // Error must always be marked as handled + // Failure to do so can cause the exception to be rethrown at every recursive level and + // overflow the stack for x64 CLR processes + eventArgs.ErrorContext.Handled = true; } } diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonInputFormatterTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonInputFormatterTest.cs index fa1760a65f..f9d79a6f6d 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonInputFormatterTest.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonInputFormatterTest.cs @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var contentBytes = Encoding.UTF8.GetBytes(content); var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseFeature()); - httpContext.Request.Body = new NonSeekableReadStream(contentBytes); + httpContext.Request.Body = new NonSeekableReadStream(contentBytes, allowSyncReads: false); httpContext.Request.ContentType = "application/json"; var formatterContext = CreateInputFormatterContext(typeof(User), httpContext); @@ -54,18 +54,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var userModel = Assert.IsType(result.Model); Assert.Equal("Person Name", userModel.Name); Assert.Equal(30, userModel.Age); - - Assert.True(httpContext.Request.Body.CanSeek); - httpContext.Request.Body.Seek(0L, SeekOrigin.Begin); - - result = await formatter.ReadAsync(formatterContext); - - // Assert - Assert.False(result.HasError); - - userModel = Assert.IsType(result.Model); - Assert.Equal("Person Name", userModel.Name); - Assert.Equal(30, userModel.Age); } [Fact] @@ -102,13 +90,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var userModel = Assert.IsType(result.Model); Assert.Equal("Person Name", userModel.Name); Assert.Equal(30, userModel.Age); - - Assert.False(httpContext.Request.Body.CanSeek); - result = await formatter.ReadAsync(formatterContext); - - // Assert - Assert.False(result.HasError); - Assert.Null(result.Model); } [Fact] diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonPatchInputFormatterTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonPatchInputFormatterTest.cs index c451729acc..f7d8ee32b7 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonPatchInputFormatterTest.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonPatchInputFormatterTest.cs @@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var httpContext = new DefaultHttpContext(); httpContext.Features.Set(new TestResponseFeature()); - httpContext.Request.Body = new NonSeekableReadStream(contentBytes); + httpContext.Request.Body = new NonSeekableReadStream(contentBytes, allowSyncReads: false); httpContext.Request.ContentType = "application/json"; var formatterContext = CreateInputFormatterContext(typeof(JsonPatchDocument), httpContext); @@ -55,18 +55,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters Assert.Equal("add", patchDocument.Operations[0].op); Assert.Equal("Customer/Name", patchDocument.Operations[0].path); Assert.Equal("John", patchDocument.Operations[0].value); - - Assert.True(httpContext.Request.Body.CanSeek); - httpContext.Request.Body.Seek(0L, SeekOrigin.Begin); - - result = await formatter.ReadAsync(formatterContext); - - // Assert - Assert.False(result.HasError); - patchDocument = Assert.IsType>(result.Model); - Assert.Equal("add", patchDocument.Operations[0].op); - Assert.Equal("Customer/Name", patchDocument.Operations[0].path); - Assert.Equal("John", patchDocument.Operations[0].value); } [Fact] diff --git a/src/Mvc/shared/Mvc.Core.TestCommon/NonSeekableReadableStream.cs b/src/Mvc/shared/Mvc.Core.TestCommon/NonSeekableReadableStream.cs index b5ccd405ac..62be4fc02b 100644 --- a/src/Mvc/shared/Mvc.Core.TestCommon/NonSeekableReadableStream.cs +++ b/src/Mvc/shared/Mvc.Core.TestCommon/NonSeekableReadableStream.cs @@ -11,15 +11,17 @@ namespace Microsoft.AspNetCore.Mvc public class NonSeekableReadStream : Stream { private Stream _inner; + private readonly bool _allowSyncReads; - public NonSeekableReadStream(byte[] data) - : this(new MemoryStream(data)) + public NonSeekableReadStream(byte[] data, bool allowSyncReads = true) + : this(new MemoryStream(data), allowSyncReads) { } - public NonSeekableReadStream(Stream inner) + public NonSeekableReadStream(Stream inner, bool allowSyncReads) { _inner = inner; + _allowSyncReads = allowSyncReads; } public override bool CanRead => _inner.CanRead; @@ -61,6 +63,11 @@ namespace Microsoft.AspNetCore.Mvc public override int Read(byte[] buffer, int offset, int count) { + if (!_allowSyncReads) + { + throw new InvalidOperationException("Cannot perform synchronous reads"); + } + count = Math.Max(count, 1); return _inner.Read(buffer, offset, count); }