diff --git a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs index d29ba1e1b2..d3a332ef48 100644 --- a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs +++ b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs @@ -4,6 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; +using System.Reflection; namespace Microsoft.AspNet.Mvc.ModelBinding.Validation { @@ -36,6 +37,9 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation /// public class DefaultCollectionValidationStrategy : IValidationStrategy { + private static readonly MethodInfo _getEnumerator = typeof(DefaultCollectionValidationStrategy) + .GetMethod(nameof(GetEnumerator), BindingFlags.Static | BindingFlags.NonPublic); + /// /// Gets an instance of . /// @@ -51,14 +55,26 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation string key, object model) { - return new Enumerator(metadata.ElementMetadata, key, (IEnumerable)model); + var enumerator = GetEnumeratorForElementType(metadata, model); + return new Enumerator(metadata.ElementMetadata, key, enumerator); + } + + public static IEnumerator GetEnumeratorForElementType(ModelMetadata metadata, object model) + { + var getEnumeratorMethod = _getEnumerator.MakeGenericMethod(metadata.ElementType); + return (IEnumerator)getEnumeratorMethod.Invoke(null, new object[] { model }); + } + + // Called via reflection. + private static IEnumerator GetEnumerator(object model) + { + return (model as IEnumerable)?.GetEnumerator() ?? ((IEnumerable)model).GetEnumerator(); } private class Enumerator : IEnumerator { private readonly string _key; private readonly ModelMetadata _metadata; - private readonly IEnumerable _model; private readonly IEnumerator _enumerator; private ValidationEntry _entry; @@ -67,13 +83,11 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation public Enumerator( ModelMetadata metadata, string key, - IEnumerable model) + IEnumerator enumerator) { _metadata = metadata; _key = key; - _model = model; - - _enumerator = _model.GetEnumerator(); + _enumerator = enumerator; _index = -1; } @@ -116,7 +130,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation public void Reset() { - throw new NotImplementedException(); + _enumerator.Reset(); } } } diff --git a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs index 0222674e74..220a75a71a 100644 --- a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs +++ b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs @@ -54,7 +54,8 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation string key, object model) { - return new Enumerator(metadata.ElementMetadata, key, ElementKeys, (IEnumerable)model); + var enumerator = DefaultCollectionValidationStrategy.GetEnumeratorForElementType(metadata, model); + return new Enumerator(metadata.ElementMetadata, key, ElementKeys, enumerator); } private class Enumerator : IEnumerator @@ -70,13 +71,13 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation ModelMetadata metadata, string key, IEnumerable elementKeys, - IEnumerable model) + IEnumerator enumerator) { _metadata = metadata; _key = key; _keyEnumerator = elementKeys.GetEnumerator(); - _enumerator = model.GetEnumerator(); + _enumerator = enumerator; } public ValidationEntry Current diff --git a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs index ff67a16152..ab388f0a3b 100644 --- a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs +++ b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using Xunit; @@ -84,6 +86,62 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation }); } + [Fact] + public void EnumerateElements_TwoEnumerableImplemenations() + { + // Arrange + var model = new TwiceEnumerable(new int[] { 2, 3, 5 }); + + var metadata = TestModelMetadataProvider.CreateDefaultProvider().GetMetadataForType(typeof(TwiceEnumerable)); + var strategy = DefaultCollectionValidationStrategy.Instance; + + // Act + var enumerator = strategy.GetChildren(metadata, "prefix", model); + + // Assert + Assert.Collection( + BufferEntries(enumerator).OrderBy(e => e.Key), + e => + { + Assert.Equal("prefix[0]", e.Key); + Assert.Equal(2, e.Model); + Assert.Same(metadata.ElementMetadata, e.Metadata); + }, + e => + { + Assert.Equal("prefix[1]", e.Key); + Assert.Equal(3, e.Model); + Assert.Same(metadata.ElementMetadata, e.Metadata); + }, + e => + { + Assert.Equal("prefix[2]", e.Key); + Assert.Equal(5, e.Model); + Assert.Same(metadata.ElementMetadata, e.Metadata); + }); + } + + // 'int' is chosen by validation because it's declared on the more derived type. + private class TwiceEnumerable : List, IEnumerable + { + private readonly IEnumerable _enumerable; + + public TwiceEnumerable(IEnumerable enumerable) + { + _enumerable = enumerable; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _enumerable.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new InvalidOperationException(); + } + } + private List BufferEntries(IEnumerator enumerator) { var entries = new List(); diff --git a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs index 6dd7654c7a..4ea66e4ad5 100644 --- a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs +++ b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections; using System.Collections.Generic; using System.Linq; using Xunit; @@ -85,6 +87,41 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation }); } + [Fact] + public void EnumerateElements_TwoEnumerableImplemenations() + { + // Arrange + var model = new TwiceEnumerable(new int[] { 2, 3, 5 }); + + var metadata = TestModelMetadataProvider.CreateDefaultProvider().GetMetadataForType(typeof(TwiceEnumerable)); + var strategy = new ExplicitIndexCollectionValidationStrategy(new string[] { "zero", "one", "two" }); + + // Act + var enumerator = strategy.GetChildren(metadata, "prefix", model); + + // Assert + Assert.Collection( + BufferEntries(enumerator).OrderBy(e => e.Key), + e => + { + Assert.Equal("prefix[one]", e.Key); + Assert.Equal(3, e.Model); + Assert.Same(metadata.ElementMetadata, e.Metadata); + }, + e => + { + Assert.Equal("prefix[two]", e.Key); + Assert.Equal(5, e.Model); + Assert.Same(metadata.ElementMetadata, e.Metadata); + }, + e => + { + Assert.Equal("prefix[zero]", e.Key); + Assert.Equal(2, e.Model); + Assert.Same(metadata.ElementMetadata, e.Metadata); + }); + } + [Fact] public void EnumerateElements_RunOutOfIndices() { @@ -145,6 +182,27 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation }); } + // 'int' is chosen by validation because it's declared on the more derived type. + private class TwiceEnumerable : List, IEnumerable + { + private readonly IEnumerable _enumerable; + + public TwiceEnumerable(IEnumerable enumerable) + { + _enumerable = enumerable; + } + + IEnumerator IEnumerable.GetEnumerator() + { + return _enumerable.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + throw new InvalidOperationException(); + } + } + private List BufferEntries(IEnumerator enumerator) { var entries = new List();