diff --git a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/SaveTempDataFilter.cs b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/SaveTempDataFilter.cs index f430ba7b08..912c7c7131 100644 --- a/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/SaveTempDataFilter.cs +++ b/src/Microsoft.AspNetCore.Mvc.ViewFeatures/Internal/SaveTempDataFilter.cs @@ -29,6 +29,36 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal /// public void OnResourceExecuting(ResourceExecutingContext context) { + if (!context.HttpContext.Response.HasStarted) + { + context.HttpContext.Response.OnStarting((state) => + { + var saveTempDataContext = (SaveTempDataContext)state; + + // If temp data was already saved, skip trying to save again as the calls here would potentially fail + // because the session feature might not be available at this point. + // Example: An action returns NoContentResult and since NoContentResult does not write anything to + // the body of the response, this delegate would get executed way late in the pipeline at which point + // the session feature would have been removed. + object obj; + if (saveTempDataContext.HttpContext.Items.TryGetValue(TempDataSavedKey, out obj)) + { + return TaskCache.CompletedTask; + } + + SaveTempData( + result: null, + factory: saveTempDataContext.TempDataDictionaryFactory, + httpContext: saveTempDataContext.HttpContext); + + return TaskCache.CompletedTask; + }, + state: new SaveTempDataContext() + { + HttpContext = context.HttpContext, + TempDataDictionaryFactory = _factory + }); + } } /// @@ -39,34 +69,6 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal /// public void OnResultExecuting(ResultExecutingContext context) { - context.HttpContext.Response.OnStarting((state) => - { - var saveTempDataContext = (SaveTempDataContext)state; - - // If temp data was already saved, skip trying to save again as the calls here would potentially fail - // because the session feature might not be available at this point. - // Example: An action returns NoContentResult and since NoContentResult does not write anything to - // the body of the response, this delegate would get executed way late in the pipeline at which point - // the session feature would have been removed. - object obj; - if (saveTempDataContext.HttpContext.Items.TryGetValue(TempDataSavedKey, out obj)) - { - return TaskCache.CompletedTask; - } - - SaveTempData( - saveTempDataContext.ActionResult, - saveTempDataContext.TempDataDictionaryFactory, - saveTempDataContext.HttpContext); - - return TaskCache.CompletedTask; - }, - state: new SaveTempDataContext() - { - HttpContext = context.HttpContext, - ActionResult = context.Result, - TempDataDictionaryFactory = _factory - }); } /// @@ -78,7 +80,11 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal if (!context.HttpContext.Response.HasStarted) { SaveTempData(context.Result, _factory, context.HttpContext); - context.HttpContext.Items.Add(TempDataSavedKey, true); + // If SaveTempDataFilter got added twice this might already be in there. + if (!context.HttpContext.Items.ContainsKey(TempDataSavedKey)) + { + context.HttpContext.Items.Add(TempDataSavedKey, true); + } } } @@ -94,7 +100,6 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal private class SaveTempDataContext { public HttpContext HttpContext { get; set; } - public IActionResult ActionResult { get; set; } public ITempDataDictionaryFactory TempDataDictionaryFactory { get; set; } } } diff --git a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/TempDataTestBase.cs b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/TempDataTestBase.cs index 0f88652357..ab9c284483 100644 --- a/test/Microsoft.AspNetCore.Mvc.FunctionalTests/TempDataTestBase.cs +++ b/test/Microsoft.AspNetCore.Mvc.FunctionalTests/TempDataTestBase.cs @@ -162,6 +162,20 @@ namespace Microsoft.AspNetCore.Mvc.FunctionalTests Assert.Equal($"Foo 10 3 10/10/2010 00:00:00 {testGuid.ToString()}", body); } + [Fact] + public async Task ResponseWrite_DoesNotCrashSaveTempDataFilter() + { + // Arrange + var nameValueCollection = new List> + { + new KeyValuePair("Name", "Jordan"), + }; + var content = new FormUrlEncodedContent(nameValueCollection); + + // Act, checking it didn't throw + var response = await Client.GetAsync("/TempData/SetTempDataResponseWrite"); + } + [Fact] public async Task SetInActionResultExecution_AvailableForNextRequest() { diff --git a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/SaveTempDataFilterTest.cs b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/SaveTempDataFilterTest.cs index 00f7fd061d..fe7d6f4b9e 100644 --- a/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/SaveTempDataFilterTest.cs +++ b/test/Microsoft.AspNetCore.Mvc.ViewFeatures.Test/Internal/SaveTempDataFilterTest.cs @@ -2,11 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Text; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Mvc.Abstractions; using Microsoft.AspNetCore.Mvc.Filters; +using Microsoft.AspNetCore.Mvc.ModelBinding; using Microsoft.AspNetCore.Routing; using Moq; using Xunit; @@ -28,23 +30,46 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal } [Fact] - public void OnResultExecuting_RegistersOnStartingCallback() + public async Task OnResultExecuting_DoesntThrowIfResponseStarted() + { + // Arrange + var responseFeature = new TestResponseFeature(hasStarted: true); + var httpContext = GetHttpContext(responseFeature); + var tempDataFactory = new Mock(MockBehavior.Loose); + tempDataFactory + .Setup(f => f.GetTempData(It.IsAny())) + .Verifiable(); + var filter = new SaveTempDataFilter(tempDataFactory.Object); + var context = GetResultExecutingContext(httpContext); + filter.OnResultExecuting(context); + + // Act + // Checking it doesn't throw + await responseFeature.FireOnSendingHeadersAsync(); + } + + [Fact] + public void OnResourceExecuting_RegistersOnStartingCallback() { // Arrange var responseFeature = new Mock(MockBehavior.Strict); responseFeature - .Setup(rf => rf.OnStarting(It.IsAny>(), It.IsAny())) + .Setup(rf => rf.OnStarting(It.IsAny>(), It.IsAny())) .Verifiable(); + responseFeature + .SetupGet(rf => rf.HasStarted) + .Returns(false); + var tempDataFactory = new Mock(MockBehavior.Strict); tempDataFactory .Setup(f => f.GetTempData(It.IsAny())) .Verifiable(); var filter = new SaveTempDataFilter(tempDataFactory.Object); var httpContext = GetHttpContext(responseFeature.Object); - var context = GetResultExecutingContext(httpContext); + var context = GetResourceExecutingContext(httpContext); // Act - filter.OnResultExecuting(context); + filter.OnResourceExecuting(context); // Assert responseFeature.Verify(); @@ -52,7 +77,28 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal } [Fact] - public async Task OnResultExecuting_DoesNotSaveTempData_WhenTempDataAlreadySaved() + public void OnResultExecuted_CanBeCalledTwice() + { + // Arrange + var responseFeature = new TestResponseFeature(); + var httpContext = GetHttpContext(responseFeature); + var tempData = GetTempDataDictionary(); + var tempDataFactory = new Mock(MockBehavior.Strict); + tempDataFactory + .Setup(f => f.GetTempData(It.IsAny())) + .Returns(tempData.Object) + .Verifiable(); + var filter = new SaveTempDataFilter(tempDataFactory.Object); + var context = GetResultExecutedContext(httpContext); + + // Act (No Assert) + filter.OnResultExecuted(context); + // Shouldn't have thrown + filter.OnResultExecuted(context); + } + + [Fact] + public async Task OnResourceExecuting_DoesNotSaveTempData_WhenTempDataAlreadySaved() { // Arrange var responseFeature = new TestResponseFeature(); @@ -63,8 +109,8 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal .Setup(f => f.GetTempData(It.IsAny())) .Verifiable(); var filter = new SaveTempDataFilter(tempDataFactory.Object); - var context = GetResultExecutingContext(httpContext); - filter.OnResultExecuting(context); // registers callback + var context = GetResourceExecutingContext(httpContext); + filter.OnResourceExecuting(context); // registers callback // Act await responseFeature.FireOnSendingHeadersAsync(); @@ -81,10 +127,12 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal var tempDataDictionary = GetTempDataDictionary(); var filter = GetFilter(tempDataDictionary.Object); var responseFeature = new TestResponseFeature(); - var actionContext = GetActionContext(GetHttpContext(responseFeature)); - var context = GetResultExecutingContext(actionContext, result); - filter.OnResultExecuting(context); // registers callback + var httpContext = GetHttpContext(responseFeature); + var resourceContext = GetResourceExecutingContext(httpContext); + var resultContext = GetResultExecutedContext(httpContext, result); + filter.OnResourceExecuting(resourceContext); // registers callback + filter.OnResultExecuted(resultContext); // Act await responseFeature.FireOnSendingHeadersAsync(); @@ -93,15 +141,17 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal } [Fact] - public async Task OnResultExecuting_KeepsTempData_ForIKeepTempDataResult() + public async Task OnResourceExecuting_KeepsTempData_ForIKeepTempDataResult() { // Arrange var tempDataDictionary = GetTempDataDictionary(); var filter = GetFilter(tempDataDictionary.Object); var responseFeature = new TestResponseFeature(); - var actionContext = GetActionContext(GetHttpContext(responseFeature)); - var context = GetResultExecutingContext(actionContext, new TestKeepTempDataActionResult()); - filter.OnResultExecuting(context); // registers callback + var httpContext = GetHttpContext(responseFeature); + var resourceContext = GetResourceExecutingContext(httpContext); + var resultContext = GetResultExecutedContext(httpContext, new TestKeepTempDataActionResult()); + filter.OnResourceExecuting(resourceContext); // registers callback + filter.OnResultExecuted(resultContext); // Act await responseFeature.FireOnSendingHeadersAsync(); @@ -118,9 +168,9 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal var tempDataDictionary = GetTempDataDictionary(); var filter = GetFilter(tempDataDictionary.Object); var responseFeature = new TestResponseFeature(); - var actionContext = GetActionContext(GetHttpContext(responseFeature)); - var context = GetResultExecutingContext(actionContext, new TestActionResult()); - filter.OnResultExecuting(context); // registers callback + var actionContext = GetHttpContext(responseFeature); + var context = GetResourceExecutingContext(actionContext); + filter.OnResourceExecuting(context); // registers callback // Act await responseFeature.FireOnSendingHeadersAsync(); @@ -224,6 +274,21 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal return tempDataDictionary; } + private ResourceExecutingContext GetResourceExecutingContext(HttpContext httpContext) + { + if (httpContext == null) + { + httpContext = GetHttpContext(); + } + var actionResult = new TestActionResult(); + + var actionContext = new ActionContext(httpContext, new RouteData(), new ActionDescriptor()); + var filters = new IFilterMetadata[] { }; + var valueProviderFactories = new IValueProviderFactory[] { }; + + return new ResourceExecutingContext(actionContext, filters, valueProviderFactories); + } + private ResultExecutedContext GetResultExecutedContext(HttpContext httpContext = null, IActionResult actionResult = null) { if (httpContext == null) @@ -331,6 +396,11 @@ namespace Microsoft.AspNetCore.Mvc.ViewFeatures.Internal public override void OnStarting(Func callback, object state) { + if (_hasStarted) + { + throw new ArgumentException(); + } + var prior = _responseStartingAsync; _responseStartingAsync = async () => { diff --git a/test/WebSites/BasicWebSite/Controllers/TempDataController.cs b/test/WebSites/BasicWebSite/Controllers/TempDataController.cs index b1383b00d8..1ad0d7a732 100644 --- a/test/WebSites/BasicWebSite/Controllers/TempDataController.cs +++ b/test/WebSites/BasicWebSite/Controllers/TempDataController.cs @@ -3,6 +3,8 @@ using System; using System.Collections.Generic; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc; namespace BasicWebSite.Controllers @@ -59,6 +61,13 @@ namespace BasicWebSite.Controllers return RedirectToAction("GetTempDataMultiple"); } + public async Task SetTempDataResponseWrite() + { + TempData["value1"] = "steve"; + + await Response.WriteAsync("Steve!"); + } + public string GetTempDataMultiple() { var value1 = TempData["key1"].ToString();