diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/SessionStateTempDataProvider.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/SessionStateTempDataProvider.cs index 32ad4e75b8..3146779fea 100644 --- a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/SessionStateTempDataProvider.cs +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/ViewFeatures/SessionStateTempDataProvider.cs @@ -21,7 +21,9 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures /// public class SessionStateTempDataProvider : ITempDataProvider { - private const string TempDataSessionStateKey = "__ControllerTempData"; + // Internal for testing + internal const string TempDataSessionStateKey = "__ControllerTempData"; + private readonly JsonSerializer _jsonSerializer = JsonSerializer.Create( new JsonSerializerSettings() { @@ -61,15 +63,22 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures // Accessing Session property will throw if the session middleware is not enabled. var session = context.Session; - var tempDataDictionary = new Dictionary(StringComparer.OrdinalIgnoreCase); + Dictionary tempDataDictionary = null; byte[] value; if (session.TryGetValue(TempDataSessionStateKey, out value)) { + // If we got it from Session, remove it so that no other request gets it + session.Remove(TempDataSessionStateKey); + using (var memoryStream = new MemoryStream(value)) using (var writer = new BsonReader(memoryStream)) { tempDataDictionary = _jsonSerializer.Deserialize>(writer); + if (tempDataDictionary == null) + { + return new Dictionary(StringComparer.OrdinalIgnoreCase); + } } var convertedDictionary = new Dictionary( @@ -142,9 +151,6 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures } tempDataDictionary = convertedDictionary; - - // If we got it from Session, remove it so that no other request gets it - session.Remove(TempDataSessionStateKey); } else { @@ -153,7 +159,7 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures session.Set(TempDataSessionStateKey, new byte[] { }); } - return tempDataDictionary; + return tempDataDictionary ?? new Dictionary(StringComparer.OrdinalIgnoreCase); } /// @@ -172,8 +178,11 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures { foreach (var item in values.Values) { - // We want to allow only simple types to be serialized in session. - EnsureObjectCanBeSerialized(item); + if (item != null) + { + // We want to allow only simple types to be serialized in session. + EnsureObjectCanBeSerialized(item); + } } using (var memoryStream = new MemoryStream()) diff --git a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/SessionStateTempDataProviderTest.cs b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/SessionStateTempDataProviderTest.cs index 34c5520d34..ba8f46a4d6 100644 --- a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/SessionStateTempDataProviderTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/ViewFeatures/SessionStateTempDataProviderTest.cs @@ -54,6 +54,21 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures Assert.Empty(tempDataDictionary); } + [Fact] + public void Load_ReturnsEmptyDictionary_WhenSessionDataIsEmpty() + { + // Arrange + var testProvider = new SessionStateTempDataProvider(); + var httpContext = GetHttpContext(); + httpContext.Session.Set(SessionStateTempDataProvider.TempDataSessionStateKey, new byte[] { }); + + // Act + var tempDataDictionary = testProvider.LoadTempData(httpContext); + + // Assert + Assert.Empty(tempDataDictionary); + } + public static TheoryData InvalidTypes { get @@ -369,6 +384,26 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures Assert.Null(emptyDictionary); } + [Fact] + public void SaveAndLoad_NullValue_RoundTripsSuccessfully() + { + // Arrange + var testProvider = new SessionStateTempDataProvider(); + var input = new Dictionary + { + { "NullKey", null } + }; + var context = GetHttpContext(); + + // Act + testProvider.SaveTempData(context, input); + var TempData = testProvider.LoadTempData(context); + + // Assert + Assert.True(TempData.ContainsKey("NullKey")); + Assert.Null(TempData["NullKey"]); + } + private class TestItem { public int DummyInt { get; set; }