diff --git a/src/Microsoft.AspNet.Mvc.Abstractions/ModelBinding/ModelStateDictionary.cs b/src/Microsoft.AspNet.Mvc.Abstractions/ModelBinding/ModelStateDictionary.cs index 6ed9d61dce..49ed8d4e4b 100644 --- a/src/Microsoft.AspNet.Mvc.Abstractions/ModelBinding/ModelStateDictionary.cs +++ b/src/Microsoft.AspNet.Mvc.Abstractions/ModelBinding/ModelStateDictionary.cs @@ -4,9 +4,7 @@ using System; using System.Collections; using System.Collections.Generic; -using System.Linq; using Microsoft.AspNet.Mvc.Abstractions; -using Microsoft.Extensions.Internal; namespace Microsoft.AspNet.Mvc.ModelBinding { @@ -22,7 +20,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding /// public static readonly int DefaultMaxAllowedErrors = 200; - private readonly IDictionary _innerDictionary; + private readonly Dictionary _innerDictionary; private int _maxAllowedErrors; /// @@ -55,7 +53,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding throw new ArgumentNullException(nameof(dictionary)); } - _innerDictionary = new CopyOnWriteDictionary( + _innerDictionary = new Dictionary( dictionary, StringComparer.OrdinalIgnoreCase); @@ -124,7 +122,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding /// public bool IsReadOnly { - get { return _innerDictionary.IsReadOnly; } + get { return ((ICollection>)_innerDictionary).IsReadOnly; } } /// @@ -153,7 +151,11 @@ namespace Microsoft.AspNet.Mvc.ModelBinding /// public ModelValidationState ValidationState { - get { return GetValidity(_innerDictionary); } + get + { + var entries = FindKeysWithPrefix(string.Empty); + return GetValidity(entries, defaultState: ModelValidationState.Valid); + } } /// @@ -344,12 +346,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding } var entries = FindKeysWithPrefix(key); - if (!entries.Any()) - { - return ModelValidationState.Unvalidated; - } - - return GetValidity(entries); + return GetValidity(entries, defaultState: ModelValidationState.Unvalidated); } /// @@ -496,9 +493,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding { // If key is null or empty, clear all entries in the dictionary // else just clear the ones that have key as prefix - var entries = (string.IsNullOrEmpty(key)) ? - _innerDictionary : FindKeysWithPrefix(key); - + var entries = FindKeysWithPrefix(key ?? string.Empty); foreach (var entry in entries) { entry.Value.Errors.Clear(); @@ -523,11 +518,16 @@ namespace Microsoft.AspNet.Mvc.ModelBinding return modelState; } - private static ModelValidationState GetValidity(IEnumerable> entries) + private static ModelValidationState GetValidity(PrefixEnumerable entries, ModelValidationState defaultState) { + + var hasEntries = false; var validationState = ModelValidationState.Valid; + foreach (var entry in entries) { + hasEntries = true; + var entryState = entry.Value.ValidationState; if (entryState == ModelValidationState.Unvalidated) { @@ -539,7 +539,8 @@ namespace Microsoft.AspNet.Mvc.ModelBinding validationState = entryState; } } - return validationState; + + return hasEntries ? validationState : defaultState; } private void EnsureMaxErrorsReachedRecorded() @@ -591,7 +592,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding /// public bool Contains(KeyValuePair item) { - return _innerDictionary.Contains(item); + return ((ICollection>)_innerDictionary).Contains(item); } /// @@ -613,13 +614,13 @@ namespace Microsoft.AspNet.Mvc.ModelBinding throw new ArgumentNullException(nameof(array)); } - _innerDictionary.CopyTo(array, arrayIndex); + ((ICollection>)_innerDictionary).CopyTo(array, arrayIndex); } /// public bool Remove(KeyValuePair item) { - return _innerDictionary.Remove(item); + return ((ICollection>)_innerDictionary).Remove(item); } /// @@ -656,70 +657,204 @@ namespace Microsoft.AspNet.Mvc.ModelBinding return GetEnumerator(); } - public IEnumerable> FindKeysWithPrefix(string prefix) + public static bool StartsWithPrefix(string prefix, string key) { if (prefix == null) { throw new ArgumentNullException(nameof(prefix)); } - ModelState exactMatchValue; - if (_innerDictionary.TryGetValue(prefix, out exactMatchValue)) + if (key == null) { - yield return new KeyValuePair(prefix, exactMatchValue); + throw new ArgumentNullException(nameof(key)); } - foreach (var entry in _innerDictionary) + if (StringComparer.OrdinalIgnoreCase.Equals(key, prefix)) { - var key = entry.Key; + return true; + } - if (key.Length <= prefix.Length) + if (key.Length <= prefix.Length) + { + return false; + } + + if (!key.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) + { + if (key.StartsWith("[", StringComparison.OrdinalIgnoreCase)) { - continue; - } + var subKey = key.Substring(key.IndexOf('.') + 1); - if (!key.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) - { - - if (key.StartsWith("[", StringComparison.OrdinalIgnoreCase)) + if (!subKey.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) { - var subKey = key.Substring(key.IndexOf('.') + 1); - - if (!subKey.StartsWith(prefix, StringComparison.OrdinalIgnoreCase)) - { - continue; - } - - if (string.Equals(prefix, subKey, StringComparison.Ordinal)) - { - yield return entry; - continue; - } - - key = subKey; + return false; } - else + + if (string.Equals(prefix, subKey, StringComparison.OrdinalIgnoreCase)) { - continue; + return true; } - } - // Everything is prefixed by the empty string - if (prefix.Length == 0) - { - yield return entry; + key = subKey; } else { - var charAfterPrefix = key[prefix.Length]; - switch (charAfterPrefix) + return false; + } + } + + // Everything is prefixed by the empty string + if (prefix.Length == 0) + { + return true; + } + else + { + var charAfterPrefix = key[prefix.Length]; + switch (charAfterPrefix) + { + case '[': + case '.': + return true; + } + } + + return false; + } + + public PrefixEnumerable FindKeysWithPrefix(string prefix) + { + if (prefix == null) + { + throw new ArgumentNullException(nameof(prefix)); + } + + return new PrefixEnumerable(this, prefix); + } + + public struct PrefixEnumerable : IEnumerable> + { + private readonly ModelStateDictionary _dictionary; + private readonly string _prefix; + + public PrefixEnumerable(ModelStateDictionary dictionary, string prefix) + { + if (dictionary == null) + { + throw new ArgumentNullException(nameof(dictionary)); + } + + if (prefix == null) + { + throw new ArgumentNullException(nameof(prefix)); + } + + _dictionary = dictionary; + _prefix = prefix; + } + + public PrefixEnumerator GetEnumerator() + { + return _dictionary == null ? new PrefixEnumerator() : new PrefixEnumerator(_dictionary, _prefix); + } + + IEnumerator> IEnumerable>.GetEnumerator() + { + return GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } + + public struct PrefixEnumerator : IEnumerator> + { + private readonly ModelStateDictionary _dictionary; + private readonly string _prefix; + + private bool _exactMatchUsed; + private Dictionary.Enumerator _enumerator; + + public PrefixEnumerator(ModelStateDictionary dictionary, string prefix) + { + if (dictionary == null) + { + throw new ArgumentNullException(nameof(dictionary)); + } + + if (prefix == null) + { + throw new ArgumentNullException(nameof(prefix)); + } + + _dictionary = dictionary; + _prefix = prefix; + + _exactMatchUsed = false; + _enumerator = default(Dictionary.Enumerator); + Current = default(KeyValuePair); + } + + public KeyValuePair Current { get; private set; } + + object IEnumerator.Current + { + get + { + return Current; + } + } + + public void Dispose() + { + } + + public bool MoveNext() + { + if (_dictionary == null) + { + return false; + } + + // ModelStateDictionary has a behavior where the first 'match' returned from iterating + // prefixes is the exact match for the prefix (if present). Only after looking for an + // exact match do we fall back to iteration to find 'starts-with' matches. + if (!_exactMatchUsed) + { + _exactMatchUsed = true; + _enumerator = _dictionary._innerDictionary.GetEnumerator(); + + ModelState entry; + if (_dictionary.TryGetValue(_prefix, out entry)) { - case '[': - case '.': - yield return entry; - break; + Current = new KeyValuePair(_prefix, entry); + return true; } } + + while (_enumerator.MoveNext()) + { + if (string.Equals(_prefix, _enumerator.Current.Key, StringComparison.OrdinalIgnoreCase)) + { + // Skip this one. We've already handle the 'exact match' case. + } + else if (StartsWithPrefix(_prefix, _enumerator.Current.Key)) + { + Current = _enumerator.Current; + return true; + } + } + + return false; + } + + public void Reset() + { + _exactMatchUsed = false; + _enumerator = default(Dictionary.Enumerator); + Current = default(KeyValuePair); } } } diff --git a/src/Microsoft.AspNet.Mvc.ViewFeatures/ViewFeatures/ViewDataDictionary.cs b/src/Microsoft.AspNet.Mvc.ViewFeatures/ViewFeatures/ViewDataDictionary.cs index 31bfaea269..b2059cee17 100644 --- a/src/Microsoft.AspNet.Mvc.ViewFeatures/ViewFeatures/ViewDataDictionary.cs +++ b/src/Microsoft.AspNet.Mvc.ViewFeatures/ViewFeatures/ViewDataDictionary.cs @@ -227,7 +227,7 @@ namespace Microsoft.AspNet.Mvc.ViewFeatures /// protected ViewDataDictionary(ViewDataDictionary source, object model, Type declaredModelType) : this(source._metadataProvider, - new ModelStateDictionary(source.ModelState), + source.ModelState, declaredModelType, data: new CopyOnWriteDictionary(source, StringComparer.OrdinalIgnoreCase), templateInfo: new TemplateInfo(source.TemplateInfo)) diff --git a/test/Microsoft.AspNet.Mvc.Abstractions.Test/ModelBinding/ModelStateDictionaryTest.cs b/test/Microsoft.AspNet.Mvc.Abstractions.Test/ModelBinding/ModelStateDictionaryTest.cs index 500d5bd61f..e376a1f9f2 100644 --- a/test/Microsoft.AspNet.Mvc.Abstractions.Test/ModelBinding/ModelStateDictionaryTest.cs +++ b/test/Microsoft.AspNet.Mvc.Abstractions.Test/ModelBinding/ModelStateDictionaryTest.cs @@ -2,8 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Globalization; -using Microsoft.Extensions.Internal; +using System.Collections.Generic; using Xunit; namespace Microsoft.AspNet.Mvc.ModelBinding @@ -156,7 +155,7 @@ namespace Microsoft.AspNet.Mvc.ModelBinding Assert.Equal(0, target.ErrorCount); Assert.Equal(1, target.Count); Assert.Same(modelState, target["key"]); - Assert.IsType>(target.InnerDictionary); + Assert.IsType>(target.InnerDictionary); } [Fact]