// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections;
using System.Collections.Generic;
using Microsoft.AspNet.Hosting;
using Microsoft.Framework.Internal;

namespace Microsoft.AspNet.Mvc
{
    /// <summary>
    /// A case-insensitive (<see cref="StringComparer.OrdinalIgnoreCase"/>) dictionary of temporary data that is
    /// loaded lazily from an <see cref="ITempDataProvider"/> and written back by <see cref="Save"/>.
    /// </summary>
    /// <remarks>
    /// Entries that were present when the data was loaded, or that were written through this dictionary, survive
    /// <see cref="Save"/> until they are read (through the indexer, <see cref="TryGetValue"/> or enumeration).
    /// Read entries are dropped on <see cref="Save"/> unless retained with <see cref="Keep()"/> or
    /// <see cref="Keep(string)"/>. <see cref="Peek"/>, <see cref="ContainsKey"/> and <see cref="ContainsValue"/>
    /// do not mark entries as read.
    /// </remarks>
    public class TempDataDictionary : ITempDataDictionary
    {
        // Backing store; null until Load() has run.
        private Dictionary<string, object> _data;
        private bool _loaded;
        private readonly ITempDataProvider _provider;
        private readonly IHttpContextAccessor _contextAccessor;

        // Keys that have not been read since they were loaded or written; these survive Save().
        private HashSet<string> _initialKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

        // Keys explicitly retained through Keep(); these survive Save() even after being read.
        private readonly HashSet<string> _retainedKeys = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

        /// <summary>
        /// Initializes a new instance of the <see cref="TempDataDictionary"/> class.
        /// </summary>
        /// <param name="context">The <see cref="IHttpContextAccessor"/> that provides the HttpContext.</param>
        /// <param name="provider">The <see cref="ITempDataProvider"/> used to Load and Save data.</param>
        public TempDataDictionary([NotNull] IHttpContextAccessor context, [NotNull] ITempDataProvider provider)
        {
            _provider = provider;
            _loaded = false;
            _contextAccessor = context;
        }

        /// <inheritdoc />
        public int Count
        {
            get
            {
                Load();
                return _data.Count;
            }
        }

        /// <inheritdoc />
        public ICollection<string> Keys
        {
            get
            {
                Load();
                return _data.Keys;
            }
        }

        /// <inheritdoc />
        public ICollection<object> Values
        {
            get
            {
                Load();
                return _data.Values;
            }
        }

        bool ICollection<KeyValuePair<string, object>>.IsReadOnly
        {
            get
            {
                Load();
                return ((ICollection<KeyValuePair<string, object>>)_data).IsReadOnly;
            }
        }

        /// <summary>
        /// Gets or sets the value for <paramref name="key"/>. Getting marks the key as read, so it is removed on
        /// <see cref="Save"/> unless kept. A missing key yields <c>null</c>. Setting marks the key as unread.
        /// </summary>
        public object this[string key]
        {
            get
            {
                // TryGetValue loads the data and marks the key as read; on a miss, value is left null.
                object value;
                TryGetValue(key, out value);
                return value;
            }
            set
            {
                Load();
                _data[key] = value;

                // A freshly written value counts as unread and survives Save().
                _initialKeys.Add(key);
            }
        }

        /// <inheritdoc />
        public void Keep()
        {
            Load();

            // Retain every key currently in the store.
            _retainedKeys.Clear();
            _retainedKeys.UnionWith(_data.Keys);
        }

        /// <inheritdoc />
        public void Keep(string key)
        {
            Load();
            _retainedKeys.Add(key);
        }

        /// <inheritdoc />
        public void Load()
        {
            if (_loaded)
            {
                return;
            }

            var providerDictionary = _provider.LoadTempData(_contextAccessor.HttpContext);

            // Copy into a case-insensitive store; a null result from the provider means "no temp data".
            _data = (providerDictionary != null) ?
                new Dictionary<string, object>(providerDictionary, StringComparer.OrdinalIgnoreCase) :
                new Dictionary<string, object>(StringComparer.OrdinalIgnoreCase);

            // Everything just loaded starts out unread.
            _initialKeys = new HashSet<string>(_data.Keys, StringComparer.OrdinalIgnoreCase);
            _retainedKeys.Clear();
            _loaded = true;
        }

        /// <inheritdoc />
        public void Save()
        {
            // Nothing was ever loaded, so there is nothing to write back.
            if (!_loaded)
            {
                return;
            }

            // Because it is not possible to delete while enumerating, a copy of the keys must be taken.
            // Use the size of the dictionary as an upper bound to avoid creating more than one copy of the keys.
            var removeCount = 0;
            var keys = new string[_data.Count];
            foreach (var entry in _data)
            {
                // Drop entries that have been read and were not explicitly kept.
                if (!_initialKeys.Contains(entry.Key) && !_retainedKeys.Contains(entry.Key))
                {
                    keys[removeCount] = entry.Key;
                    removeCount++;
                }
            }

            for (var i = 0; i < removeCount; i++)
            {
                _data.Remove(keys[i]);
            }

            _provider.SaveTempData(_contextAccessor.HttpContext, _data);
        }

        /// <inheritdoc />
        public object Peek(string key)
        {
            Load();

            // Reads directly from the store so the key is NOT marked as read.
            object value;
            _data.TryGetValue(key, out value);
            return value;
        }

        /// <inheritdoc />
        public void Add(string key, object value)
        {
            Load();
            _data.Add(key, value);
            _initialKeys.Add(key);
        }

        /// <inheritdoc />
        public void Clear()
        {
            Load();
            _data.Clear();
            _retainedKeys.Clear();
            _initialKeys.Clear();
        }

        /// <inheritdoc />
        public bool ContainsKey(string key)
        {
            Load();
            return _data.ContainsKey(key);
        }

        /// <inheritdoc />
        public bool ContainsValue(object value)
        {
            Load();
            return _data.ContainsValue(value);
        }

        /// <inheritdoc />
        public IEnumerator<KeyValuePair<string, object>> GetEnumerator()
        {
            Load();
            return new TempDataDictionaryEnumerator(this);
        }

        /// <inheritdoc />
        public bool Remove(string key)
        {
            Load();
            _retainedKeys.Remove(key);
            _initialKeys.Remove(key);
            return _data.Remove(key);
        }

        /// <inheritdoc />
        public bool TryGetValue(string key, out object value)
        {
            Load();

            // Mark the key for deletion since it is read.
            _initialKeys.Remove(key);
            return _data.TryGetValue(key, out value);
        }

        void ICollection<KeyValuePair<string, object>>.CopyTo(KeyValuePair<string, object>[] array, int index)
        {
            Load();
            ((ICollection<KeyValuePair<string, object>>)_data).CopyTo(array, index);
        }

        void ICollection<KeyValuePair<string, object>>.Add(KeyValuePair<string, object> keyValuePair)
        {
            Load();

            // Add to the store first so that a duplicate-key failure leaves the read/unread tracking untouched
            // (matches the ordering used by Add(string, object)).
            ((ICollection<KeyValuePair<string, object>>)_data).Add(keyValuePair);
            _initialKeys.Add(keyValuePair.Key);
        }

        bool ICollection<KeyValuePair<string, object>>.Contains(KeyValuePair<string, object> keyValuePair)
        {
            Load();
            return ((ICollection<KeyValuePair<string, object>>)_data).Contains(keyValuePair);
        }

        bool ICollection<KeyValuePair<string, object>>.Remove(KeyValuePair<string, object> keyValuePair)
        {
            Load();

            // Only touch the tracking sets when the pair was actually removed. If the value did not match,
            // the entry is still present and must not be marked as read (which would delete it on Save()).
            if (!((ICollection<KeyValuePair<string, object>>)_data).Remove(keyValuePair))
            {
                return false;
            }

            // Mirror Remove(string): forget both the unread and retained state of the removed key.
            _initialKeys.Remove(keyValuePair.Key);
            _retainedKeys.Remove(keyValuePair.Key);
            return true;
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            Load();
            return new TempDataDictionaryEnumerator(this);
        }

        /// <summary>
        /// Enumerates the underlying store and marks each entry as read when it is accessed through
        /// <see cref="Current"/>.
        /// </summary>
        private sealed class TempDataDictionaryEnumerator : IEnumerator<KeyValuePair<string, object>>
        {
            private readonly IEnumerator<KeyValuePair<string, object>> _enumerator;
            private readonly TempDataDictionary _tempData;

            public TempDataDictionaryEnumerator(TempDataDictionary tempData)
            {
                _tempData = tempData;
                _enumerator = _tempData._data.GetEnumerator();
            }

            public KeyValuePair<string, object> Current
            {
                get
                {
                    var kvp = _enumerator.Current;

                    // Mark the key for deletion since it is read.
                    _tempData._initialKeys.Remove(kvp.Key);
                    return kvp;
                }
            }

            object IEnumerator.Current
            {
                get
                {
                    return Current;
                }
            }

            public bool MoveNext()
            {
                return _enumerator.MoveNext();
            }

            public void Reset()
            {
                _enumerator.Reset();
            }

            void IDisposable.Dispose()
            {
                _enumerator.Dispose();
            }
        }
    }
}