diff --git a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs
index d29ba1e1b2..d3a332ef48 100644
--- a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs
+++ b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/DefaultCollectionValidationStrategy.cs
@@ -4,6 +4,7 @@
using System;
using System.Collections;
using System.Collections.Generic;
+using System.Reflection;
namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
{
@@ -36,6 +37,9 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
///
public class DefaultCollectionValidationStrategy : IValidationStrategy
{
+ private static readonly MethodInfo _getEnumerator = typeof(DefaultCollectionValidationStrategy)
+ .GetMethod(nameof(GetEnumerator), BindingFlags.Static | BindingFlags.NonPublic);
+
///
/// Gets an instance of .
///
@@ -51,14 +55,26 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
string key,
object model)
{
- return new Enumerator(metadata.ElementMetadata, key, (IEnumerable)model);
+ var enumerator = GetEnumeratorForElementType(metadata, model);
+ return new Enumerator(metadata.ElementMetadata, key, enumerator);
+ }
+
+ public static IEnumerator GetEnumeratorForElementType(ModelMetadata metadata, object model)
+ {
+ var getEnumeratorMethod = _getEnumerator.MakeGenericMethod(metadata.ElementType);
+ return (IEnumerator)getEnumeratorMethod.Invoke(null, new object[] { model });
+ }
+
+ // Called via reflection.
+ private static IEnumerator GetEnumerator(object model)
+ {
+ return (model as IEnumerable)?.GetEnumerator() ?? ((IEnumerable)model).GetEnumerator();
}
private class Enumerator : IEnumerator
{
private readonly string _key;
private readonly ModelMetadata _metadata;
- private readonly IEnumerable _model;
private readonly IEnumerator _enumerator;
private ValidationEntry _entry;
@@ -67,13 +83,11 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
public Enumerator(
ModelMetadata metadata,
string key,
- IEnumerable model)
+ IEnumerator enumerator)
{
_metadata = metadata;
_key = key;
- _model = model;
-
- _enumerator = _model.GetEnumerator();
+ _enumerator = enumerator;
_index = -1;
}
@@ -116,7 +130,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
public void Reset()
{
- throw new NotImplementedException();
+ _enumerator.Reset();
}
}
}
diff --git a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs
index 0222674e74..220a75a71a 100644
--- a/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs
+++ b/src/Microsoft.AspNet.Mvc.Core/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategy.cs
@@ -54,7 +54,8 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
string key,
object model)
{
- return new Enumerator(metadata.ElementMetadata, key, ElementKeys, (IEnumerable)model);
+ var enumerator = DefaultCollectionValidationStrategy.GetEnumeratorForElementType(metadata, model);
+ return new Enumerator(metadata.ElementMetadata, key, ElementKeys, enumerator);
}
private class Enumerator : IEnumerator
@@ -70,13 +71,13 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
ModelMetadata metadata,
string key,
IEnumerable elementKeys,
- IEnumerable model)
+ IEnumerator enumerator)
{
_metadata = metadata;
_key = key;
_keyEnumerator = elementKeys.GetEnumerator();
- _enumerator = model.GetEnumerator();
+ _enumerator = enumerator;
}
public ValidationEntry Current
diff --git a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs
index ff67a16152..ab388f0a3b 100644
--- a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs
+++ b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/DefaultCollectionValidationStrategyTest.cs
@@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+using System;
+using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Xunit;
@@ -84,6 +86,62 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
});
}
+ [Fact]
+ public void EnumerateElements_TwoEnumerableImplemenations()
+ {
+ // Arrange
+ var model = new TwiceEnumerable(new int[] { 2, 3, 5 });
+
+ var metadata = TestModelMetadataProvider.CreateDefaultProvider().GetMetadataForType(typeof(TwiceEnumerable));
+ var strategy = DefaultCollectionValidationStrategy.Instance;
+
+ // Act
+ var enumerator = strategy.GetChildren(metadata, "prefix", model);
+
+ // Assert
+ Assert.Collection(
+ BufferEntries(enumerator).OrderBy(e => e.Key),
+ e =>
+ {
+ Assert.Equal("prefix[0]", e.Key);
+ Assert.Equal(2, e.Model);
+ Assert.Same(metadata.ElementMetadata, e.Metadata);
+ },
+ e =>
+ {
+ Assert.Equal("prefix[1]", e.Key);
+ Assert.Equal(3, e.Model);
+ Assert.Same(metadata.ElementMetadata, e.Metadata);
+ },
+ e =>
+ {
+ Assert.Equal("prefix[2]", e.Key);
+ Assert.Equal(5, e.Model);
+ Assert.Same(metadata.ElementMetadata, e.Metadata);
+ });
+ }
+
+ // 'int' is chosen by validation because it's declared on the more derived type.
+ private class TwiceEnumerable : List, IEnumerable
+ {
+ private readonly IEnumerable _enumerable;
+
+ public TwiceEnumerable(IEnumerable enumerable)
+ {
+ _enumerable = enumerable;
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return _enumerable.GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ throw new InvalidOperationException();
+ }
+ }
+
private List BufferEntries(IEnumerator enumerator)
{
var entries = new List();
diff --git a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs
index 6dd7654c7a..4ea66e4ad5 100644
--- a/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs
+++ b/test/Microsoft.AspNet.Mvc.Core.Test/ModelBinding/Validation/ExplicitIndexCollectionValidationStrategyTest.cs
@@ -1,6 +1,8 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+using System;
+using System.Collections;
using System.Collections.Generic;
using System.Linq;
using Xunit;
@@ -85,6 +87,41 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
});
}
+ [Fact]
+ public void EnumerateElements_TwoEnumerableImplemenations()
+ {
+ // Arrange
+ var model = new TwiceEnumerable(new int[] { 2, 3, 5 });
+
+ var metadata = TestModelMetadataProvider.CreateDefaultProvider().GetMetadataForType(typeof(TwiceEnumerable));
+ var strategy = new ExplicitIndexCollectionValidationStrategy(new string[] { "zero", "one", "two" });
+
+ // Act
+ var enumerator = strategy.GetChildren(metadata, "prefix", model);
+
+ // Assert
+ Assert.Collection(
+ BufferEntries(enumerator).OrderBy(e => e.Key),
+ e =>
+ {
+ Assert.Equal("prefix[one]", e.Key);
+ Assert.Equal(3, e.Model);
+ Assert.Same(metadata.ElementMetadata, e.Metadata);
+ },
+ e =>
+ {
+ Assert.Equal("prefix[two]", e.Key);
+ Assert.Equal(5, e.Model);
+ Assert.Same(metadata.ElementMetadata, e.Metadata);
+ },
+ e =>
+ {
+ Assert.Equal("prefix[zero]", e.Key);
+ Assert.Equal(2, e.Model);
+ Assert.Same(metadata.ElementMetadata, e.Metadata);
+ });
+ }
+
[Fact]
public void EnumerateElements_RunOutOfIndices()
{
@@ -145,6 +182,27 @@ namespace Microsoft.AspNet.Mvc.ModelBinding.Validation
});
}
+ // 'int' is chosen by validation because it's declared on the more derived type.
+ private class TwiceEnumerable : List, IEnumerable
+ {
+ private readonly IEnumerable _enumerable;
+
+ public TwiceEnumerable(IEnumerable enumerable)
+ {
+ _enumerable = enumerable;
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ return _enumerable.GetEnumerator();
+ }
+
+ IEnumerator IEnumerable.GetEnumerator()
+ {
+ throw new InvalidOperationException();
+ }
+ }
+
private List BufferEntries(IEnumerator enumerator)
{
var entries = new List();