diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs index 846e8f19d5..3d4679d89b 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs @@ -5,7 +5,6 @@ using System; using System.Collections; using System.Collections.Concurrent; using System.Collections.Generic; -using System.Diagnostics; using System.Reflection; using System.Threading.Tasks; using Microsoft.AspNetCore.Mvc.Core; @@ -17,8 +16,6 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson namespace Microsoft.AspNetCore.Mvc.Infrastructure #endif { - using ReaderFunc = Func, Task>; - /// /// Type that reads an instance into a /// generic collection instance. @@ -34,8 +31,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure nameof(ReadInternal), BindingFlags.NonPublic | BindingFlags.Instance); - private readonly ConcurrentDictionary _asyncEnumerableConverters = - new ConcurrentDictionary(); + private readonly ConcurrentDictionary>> _asyncEnumerableConverters = + new ConcurrentDictionary>>(); private readonly MvcOptions _mvcOptions; /// @@ -48,37 +45,39 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure } /// - /// Reads a into an . + /// Attempts to produces a delagate that reads a into an . /// - /// The to read. - /// The . - public Task ReadAsync(IAsyncEnumerable value) + /// The type to read. + /// A delegate that when awaited reads the . + /// when is an instance of , othwerise . + public bool TryGetReader(Type type, out Func> reader) { - if (value == null) - { - throw new ArgumentNullException(nameof(value)); - } - - var type = value.GetType(); - if (!_asyncEnumerableConverters.TryGetValue(type, out var result)) + if (!_asyncEnumerableConverters.TryGetValue(type, out reader)) { var enumerableType = ClosedGenericMatcher.ExtractGenericInterface(type, typeof(IAsyncEnumerable<>)); - Debug.Assert(enumerableType != null); + if (enumerableType is null) + { + // Not an IAsyncEnumerable. Cache this result so we avoid reflection the next time we see this type. + reader = null; + _asyncEnumerableConverters.TryAdd(type, reader); + } + else + { + var enumeratedObjectType = enumerableType.GetGenericArguments()[0]; - var enumeratedObjectType = enumerableType.GetGenericArguments()[0]; + var converter = (Func>)Converter + .MakeGenericMethod(enumeratedObjectType) + .CreateDelegate(typeof(Func>), this); - var converter = (ReaderFunc)Converter - .MakeGenericMethod(enumeratedObjectType) - .CreateDelegate(typeof(ReaderFunc), this); - - _asyncEnumerableConverters.TryAdd(type, converter); - result = converter; + reader = converter; + _asyncEnumerableConverters.TryAdd(type, reader); + } } - return result(value); + return reader != null; } - private async Task ReadInternal(IAsyncEnumerable value) + private async Task ReadInternal(object value) { var asyncEnumerable = (IAsyncEnumerable)value; var result = new List(); diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ObjectResultExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ObjectResultExecutor.cs index 689604162a..0fad09d57e 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ObjectResultExecutor.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ObjectResultExecutor.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure /// public class ObjectResultExecutor : IActionResultExecutor { - private readonly AsyncEnumerableReader _asyncEnumerableReader; + private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; /// /// Creates a new . @@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure WriterFactory = writerFactory.CreateWriter; Logger = loggerFactory.CreateLogger(); var options = mvcOptions?.Value ?? throw new ArgumentNullException(nameof(mvcOptions)); - _asyncEnumerableReader = new AsyncEnumerableReader(options); + _asyncEnumerableReaderFactory = new AsyncEnumerableReader(options); } /// @@ -117,19 +118,19 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure var value = result.Value; - if (value is IAsyncEnumerable asyncEnumerable) + if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) { - return ExecuteAsyncEnumerable(context, result, asyncEnumerable); + return ExecuteAsyncEnumerable(context, result, value, reader); } return ExecuteAsyncCore(context, result, objectType, value); } - private async Task ExecuteAsyncEnumerable(ActionContext context, ObjectResult result, IAsyncEnumerable asyncEnumerable) + private async Task ExecuteAsyncEnumerable(ActionContext context, ObjectResult result, object asyncEnumerable, Func> reader) { Log.BufferingAsyncEnumerable(Logger, asyncEnumerable); - var enumerated = await _asyncEnumerableReader.ReadAsync(asyncEnumerable); + var enumerated = await reader(asyncEnumerable); await ExecuteAsyncCore(context, result, enumerated.GetType(), enumerated); } @@ -194,7 +195,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure "Buffering IAsyncEnumerable instance of type '{Type}'."); } - public static void BufferingAsyncEnumerable(ILogger logger, IAsyncEnumerable asyncEnumerable) + public static void BufferingAsyncEnumerable(ILogger logger, object asyncEnumerable) => _bufferingAsyncEnumerable(logger, asyncEnumerable.GetType().FullName, null); } } diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs index 176279550c..1a4960ba5d 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure private readonly JsonOptions _options; private readonly ILogger _logger; - private readonly AsyncEnumerableReader _asyncEnumerableReader; + private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; public SystemTextJsonResultExecutor( IOptions options, @@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure { _options = options.Value; _logger = logger; - _asyncEnumerableReader = new AsyncEnumerableReader(mvcOptions.Value); + _asyncEnumerableReaderFactory = new AsyncEnumerableReader(mvcOptions.Value); } public async Task ExecuteAsync(ActionContext context, JsonResult result) @@ -76,10 +76,10 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure try { var value = result.Value; - if (value is IAsyncEnumerable asyncEnumerable) + if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) { - Log.BufferingAsyncEnumerable(_logger, asyncEnumerable); - value = await _asyncEnumerableReader.ReadAsync(asyncEnumerable); + Log.BufferingAsyncEnumerable(_logger, value); + value = await reader(value); } var type = value?.GetType() ?? typeof(object); @@ -154,7 +154,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure _jsonResultExecuting(logger, type, null); } - public static void BufferingAsyncEnumerable(ILogger logger, IAsyncEnumerable asyncEnumerable) + public static void BufferingAsyncEnumerable(ILogger logger, object asyncEnumerable) => _bufferingAsyncEnumerable(logger, asyncEnumerable.GetType().FullName, null); } } diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs index 27baf232c9..4a3a861ed4 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs @@ -5,46 +5,173 @@ using System; using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; -using Microsoft.Extensions.Options; using Xunit; namespace Microsoft.AspNetCore.Mvc.Infrastructure { public class AsyncEnumerableReaderTest { - [Fact] - public async Task ReadAsync_ReadsIAsyncEnumerable() + [Theory] + [InlineData(typeof(Range))] + [InlineData(typeof(IEnumerable))] + [InlineData(typeof(List))] + public void TryGetReader_ReturnsFalse_IfTypeIsNotIAsyncEnumerable(Type type) { // Arrange var options = new MvcOptions(); - var reader = new AsyncEnumerableReader(options); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable = TestEnumerable(); // Act - var result = await reader.ReadAsync(TestEnumerable()); + var result = readerFactory.TryGetReader(type, out var reader); // Assert - var collection = Assert.IsAssignableFrom>(result); + Assert.False(result); + } + + [Fact] + public async Task TryGetReader_ReturnsReaderForIAsyncEnumerable() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable = TestEnumerable(); + + // Act + var result = readerFactory.TryGetReader(asyncEnumerable.GetType(), out var reader); + + // Assert + Assert.True(result); + var readCollection = await reader(asyncEnumerable); + var collection = Assert.IsAssignableFrom>(readCollection); Assert.Equal(new[] { "0", "1", "2", }, collection); } [Fact] - public async Task ReadAsync_ReadsIAsyncEnumerable_ImplementingMultipleAsyncEnumerableInterfaces() + public async Task TryGetReader_ReturnsReaderForIAsyncEnumerableOfValueType() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable = PrimitiveEnumerable(); + + // Act + var result = readerFactory.TryGetReader(asyncEnumerable.GetType(), out var reader); + + // Assert + Assert.True(result); + var readCollection = await reader(asyncEnumerable); + var collection = Assert.IsAssignableFrom>(readCollection); + Assert.Equal(new[] { 0, 1, 2, }, collection); + } + + [Fact] + public void TryGetReader_ReturnsCachedDelegate() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable1 = TestEnumerable(); + var asyncEnumerable2 = TestEnumerable(); + + // Act + Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader1)); + Assert.True(readerFactory.TryGetReader(asyncEnumerable2.GetType(), out var reader2)); + + // Assert + Assert.Same(reader1, reader2); + } + + [Fact] + public void TryGetReader_ReturnsCachedDelegate_WhenTypeImplementsMultipleIAsyncEnumerableContracts() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable1 = new MultiAsyncEnumerable(); + var asyncEnumerable2 = new MultiAsyncEnumerable(); + + // Act + Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader1)); + Assert.True(readerFactory.TryGetReader(asyncEnumerable2.GetType(), out var reader2)); + + // Assert + Assert.Same(reader1, reader2); + } + + [Fact] + public async Task CachedDelegate_CanReadEnumerableInstanceMultipleTimes() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable1 = TestEnumerable(); + var asyncEnumerable2 = TestEnumerable(); + var expected = new[] { "0", "1", "2" }; + + // Act + Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader)); + + // Assert + Assert.Equal(expected, await reader(asyncEnumerable1)); + Assert.Equal(expected, await reader(asyncEnumerable2)); + } + + [Fact] + public async Task CachedDelegate_CanReadEnumerableInstanceMultipleTimes_ThatProduceDifferentResults() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable1 = TestEnumerable(); + var asyncEnumerable2 = TestEnumerable(4); + + // Act + Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader)); + + // Assert + Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1)); + Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2)); + } + + [Fact] + public void TryGetReader_ReturnsDifferentInstancesForDifferentEnumerables() + { + // Arrange + var options = new MvcOptions(); + var readerFactory = new AsyncEnumerableReader(options); + var enumerable1 = TestEnumerable(); + var enumerable2 = TestEnumerable2(); + + // Act + Assert.True(readerFactory.TryGetReader(enumerable1.GetType(), out var reader1)); + Assert.True(readerFactory.TryGetReader(enumerable2.GetType(), out var reader2)); + + // Assert + Assert.NotSame(reader1, reader2); + } + + [Fact] + public async Task Reader_ReadsIAsyncEnumerable_ImplementingMultipleAsyncEnumerableInterfaces() { // This test ensures the reader does not fail if you have a type that implements IAsyncEnumerable for multiple Ts // Arrange var options = new MvcOptions(); - var reader = new AsyncEnumerableReader(options); + var readerFactory = new AsyncEnumerableReader(options); + var asyncEnumerable = new MultiAsyncEnumerable(); // Act - var result = await reader.ReadAsync(new MultiAsyncEnumerable()); + var result = readerFactory.TryGetReader(asyncEnumerable.GetType(), out var reader); // Assert - var collection = Assert.IsAssignableFrom>(result); + Assert.True(result); + var readCollection = await reader(asyncEnumerable); + var collection = Assert.IsAssignableFrom>(readCollection); Assert.Equal(new[] { "0", "1", "2", }, collection); } - [Fact] - public async Task ReadAsync_ThrowsIfBufferimitIsReached() + [Fact] + public async Task Reader_ThrowsIfBufferLimitIsReached() { // Arrange var enumerable = TestEnumerable(11); @@ -52,10 +179,11 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure "This limit is in place to prevent infinite streams of 'IAsyncEnumerable<>' from continuing indefinitely. If this is not a programming mistake, " + $"consider ways to reduce the collection size, or consider manually converting '{enumerable.GetType()}' into a list rather than increasing the limit."; var options = new MvcOptions { MaxIAsyncEnumerableBufferLimit = 10 }; - var reader = new AsyncEnumerableReader(options); + var readerFactory = new AsyncEnumerableReader(options); // Act - var ex = await Assert.ThrowsAsync(() => reader.ReadAsync(enumerable)); + Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); + var ex = await Assert.ThrowsAsync(() => reader(enumerable)); // Assert Assert.Equal(expected, ex.Message); @@ -70,6 +198,22 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure } } + public static async IAsyncEnumerable TestEnumerable2() + { + await Task.Yield(); + yield return "Hello"; + yield return "world"; + } + + public static async IAsyncEnumerable PrimitiveEnumerable(int count = 3) + { + await Task.Yield(); + for (var i = 0; i < count; i++) + { + yield return i; + } + } + public class MultiAsyncEnumerable : IAsyncEnumerable, IAsyncEnumerable { public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs index 7b74959cdc..8d0fa3ef84 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs @@ -311,6 +311,24 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.StartsWith("Property 'JsonResult.SerializerSettings' must be an instance of type", ex.Message); } + [Fact] + public async Task ExecuteAsync_WithNullValue() + { + // Arrange + var expected = Encoding.UTF8.GetBytes("null"); + + var context = GetActionContext(); + var result = new JsonResult(value: null); + var executor = CreateExecutor(); + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + Assert.Equal(expected, written); + } + [Fact] public async Task ExecuteAsync_SerializesAsyncEnumerables() { @@ -329,6 +347,24 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.Equal(expected, written); } + [Fact] + public async Task ExecuteAsync_SerializesAsyncEnumerablesOfPrimtives() + { + // Arrange + var expected = Encoding.UTF8.GetBytes(JsonSerializer.Serialize(new[] { 1, 2 })); + + var context = GetActionContext(); + var result = new JsonResult(TestAsyncPrimitiveEnumerable()); + var executor = CreateExecutor(); + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + Assert.Equal(expected, written); + } + protected IActionResultExecutor CreateExecutor() => CreateExecutor(NullLoggerFactory.Instance); protected abstract IActionResultExecutor CreateExecutor(ILoggerFactory loggerFactory); @@ -380,5 +416,12 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure yield return "Hello"; yield return "world"; } + + private async IAsyncEnumerable TestAsyncPrimitiveEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } } } diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/ObjectResultExecutorTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/ObjectResultExecutorTest.cs index 8ab2d76678..395362f8bf 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/ObjectResultExecutorTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/ObjectResultExecutorTest.cs @@ -361,6 +361,28 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure MediaTypeAssert.Equal(expectedContentType, responseContentType); } + [Fact] + public async Task ObjectResult_NullValue() + { + // Arrange + var executor = CreateExecutor(); + var result = new ObjectResult(value: null); + var formatter = new TestJsonOutputFormatter(); + result.Formatters.Add(formatter); + + var actionContext = new ActionContext() + { + HttpContext = GetHttpContext(), + }; + + // Act + await executor.ExecuteAsync(actionContext, result); + + // Assert + var formatterContext = formatter.LastOutputFormatterContext; + Assert.Null(formatterContext.Object); + } + [Fact] public async Task ObjectResult_ReadsAsyncEnumerables() { diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs index eac7d6400a..a01ef805ae 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs @@ -3,7 +3,6 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Mvc.Formatters; @@ -31,7 +30,7 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson private readonly MvcOptions _mvcOptions; private readonly MvcNewtonsoftJsonOptions _jsonOptions; private readonly IArrayPool _charPool; - private readonly AsyncEnumerableReader _asyncEnumerableReader; + private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; /// /// Creates a new . @@ -73,7 +72,7 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson _mvcOptions = mvcOptions?.Value ?? throw new ArgumentNullException(nameof(mvcOptions)); _jsonOptions = jsonOptions.Value; _charPool = new JsonArrayPool(charPool); - _asyncEnumerableReader = new AsyncEnumerableReader(_mvcOptions); + _asyncEnumerableReaderFactory = new AsyncEnumerableReader(_mvcOptions); } /// @@ -133,10 +132,10 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson var jsonSerializer = JsonSerializer.Create(jsonSerializerSettings); var value = result.Value; - if (result.Value is IAsyncEnumerable asyncEnumerable) + if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) { - Log.BufferingAsyncEnumerable(_logger, asyncEnumerable); - value = await _asyncEnumerableReader.ReadAsync(asyncEnumerable); + Log.BufferingAsyncEnumerable(_logger, value); + value = await reader(value); } jsonSerializer.Serialize(jsonWriter, value); @@ -201,7 +200,7 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson _jsonResultExecuting(logger, type, null); } - public static void BufferingAsyncEnumerable(ILogger logger, IAsyncEnumerable asyncEnumerable) + public static void BufferingAsyncEnumerable(ILogger logger, object asyncEnumerable) => _bufferingAsyncEnumerable(logger, asyncEnumerable.GetType().FullName, null); } }