// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Linq; using System.Reflection; using Microsoft.AspNet.Http; using Microsoft.AspNet.Mvc.Core; using Microsoft.Framework.Internal; using Newtonsoft.Json; using Newtonsoft.Json.Bson; using Newtonsoft.Json.Linq; namespace Microsoft.AspNet.Mvc { /// /// Provides session-state data to the current object. /// public class SessionStateTempDataProvider : ITempDataProvider { private const string TempDataSessionStateKey = "__ControllerTempData"; private readonly JsonSerializer _jsonSerializer = JsonSerializer.Create( new JsonSerializerSettings() { TypeNameHandling = TypeNameHandling.None }); private static readonly MethodInfo _convertArrayMethodInfo = typeof(SessionStateTempDataProvider).GetMethod( nameof(ConvertArray), BindingFlags.Static | BindingFlags.NonPublic); private static readonly MethodInfo _convertDictMethodInfo = typeof(SessionStateTempDataProvider).GetMethod( nameof(ConvertDictionary), BindingFlags.Static | BindingFlags.NonPublic); private static readonly ConcurrentDictionary> _arrayConverters = new ConcurrentDictionary>(); private static readonly ConcurrentDictionary> _dictionaryConverters = new ConcurrentDictionary>(); private static readonly Dictionary _tokenTypeLookup = new Dictionary { { JTokenType.String, typeof(string) }, { JTokenType.Integer, typeof(int) }, { JTokenType.Boolean, typeof(bool) }, { JTokenType.Float, typeof(float) }, { JTokenType.Guid, typeof(Guid) }, { JTokenType.Date, typeof(DateTime) }, { JTokenType.TimeSpan, typeof(TimeSpan) }, { JTokenType.Uri, typeof(Uri) }, }; /// public virtual IDictionary LoadTempData([NotNull] HttpContext context) { if (!IsSessionEnabled(context)) { // Session middleware is not enabled. No-op return null; } var tempDataDictionary = new Dictionary(StringComparer.OrdinalIgnoreCase); var session = context.Session; byte[] value; if (session != null && session.TryGetValue(TempDataSessionStateKey, out value)) { using (var memoryStream = new MemoryStream(value)) using (var writer = new BsonReader(memoryStream)) { tempDataDictionary = _jsonSerializer.Deserialize>(writer); } var convertedDictionary = new Dictionary(tempDataDictionary, StringComparer.OrdinalIgnoreCase); foreach (var item in tempDataDictionary) { var jArrayValue = item.Value as JArray; if (jArrayValue != null && jArrayValue.Count > 0) { var arrayType = jArrayValue[0].Type; Type returnType; if (_tokenTypeLookup.TryGetValue(arrayType, out returnType)) { var arrayConverter = _arrayConverters.GetOrAdd(returnType, type => { return (Func)_convertArrayMethodInfo.MakeGenericMethod(type).CreateDelegate(typeof(Func)); }); var result = arrayConverter(jArrayValue); convertedDictionary[item.Key] = result; } else { var message = Resources.FormatTempData_CannotDeserializeToken(nameof(JToken), arrayType); throw new InvalidOperationException(message); } } else { var jObjectValue = item.Value as JObject; if (jObjectValue == null) { continue; } else if (!jObjectValue.HasValues) { convertedDictionary[item.Key] = null; continue; } var jTokenType = jObjectValue.Properties().First().Value.Type; Type valueType; if (_tokenTypeLookup.TryGetValue(jTokenType, out valueType)) { var dictionaryConverter = _dictionaryConverters.GetOrAdd(valueType, type => { return (Func)_convertDictMethodInfo.MakeGenericMethod(type).CreateDelegate(typeof(Func)); }); var result = dictionaryConverter(jObjectValue); convertedDictionary[item.Key] = result; } else { var message = Resources.FormatTempData_CannotDeserializeToken(nameof(JToken), jTokenType); throw new InvalidOperationException(message); } } } tempDataDictionary = convertedDictionary; // If we got it from Session, remove it so that no other request gets it session.Remove(TempDataSessionStateKey); } else { // Since we call Save() after the response has been sent, we need to initialize an empty session // so that it is established before the headers are sent. session[TempDataSessionStateKey] = new byte[] { }; } return tempDataDictionary; } /// public virtual void SaveTempData([NotNull] HttpContext context, IDictionary values) { var hasValues = (values != null && values.Count > 0); if (hasValues) { foreach (var item in values.Values) { // We want to allow only simple types to be serialized in session. EnsureObjectCanBeSerialized(item); } // Accessing Session property will throw if the session middleware is not enabled. var session = context.Session; using (var memoryStream = new MemoryStream()) using (var writer = new BsonWriter(memoryStream)) { _jsonSerializer.Serialize(writer, values); session[TempDataSessionStateKey] = memoryStream.ToArray(); } } else if (IsSessionEnabled(context)) { var session = context.Session; session.Remove(TempDataSessionStateKey); } } private static bool IsSessionEnabled(HttpContext context) { return context.GetFeature() != null; } internal void EnsureObjectCanBeSerialized(object item) { var itemType = item.GetType(); Type actualType = null; if (itemType.IsArray) { itemType = itemType.GetElementType(); } else if (itemType.GetTypeInfo().IsGenericType) { if (itemType.ExtractGenericInterface(typeof(IList<>)) != null) { var genericTypeArguments = itemType.GetGenericArguments(); Debug.Assert(genericTypeArguments.Length == 1, "IList has one generic argument"); actualType = genericTypeArguments[0]; } else if (itemType.ExtractGenericInterface(typeof(IDictionary<,>)) != null) { var genericTypeArguments = itemType.GetGenericArguments(); Debug.Assert(genericTypeArguments.Length == 2, "IDictionary has two generic arguments"); // Throw if the key type of the dictionary is not string. if (genericTypeArguments[0] != typeof(string)) { var message = Resources.FormatTempData_CannotSerializeDictionary( typeof(SessionStateTempDataProvider).FullName, genericTypeArguments[0]); throw new InvalidOperationException(message); } else { actualType = genericTypeArguments[1]; } } } actualType = actualType ?? itemType; if (!TypeHelper.IsSimpleType(actualType)) { var underlyingType = Nullable.GetUnderlyingType(actualType) ?? actualType; var message = Resources.FormatTempData_CannotSerializeToSession( typeof(SessionStateTempDataProvider).FullName, underlyingType); throw new InvalidOperationException(message); } } private static IList ConvertArray(JArray array) { return array.Values().ToArray(); } private static IDictionary ConvertDictionary(JObject jObject) { var convertedDictionary = new Dictionary(StringComparer.Ordinal); foreach (var item in jObject) { convertedDictionary.Add(item.Key, jObject.Value(item.Key)); } return convertedDictionary; } } }