diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs index 580280e11d..28510c2d5b 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/ControllerActionInvoker.cs @@ -282,7 +282,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure } } - private async Task InvokeNextActionFilterAsync() + private Task InvokeNextActionFilterAsync() { try { @@ -292,7 +292,11 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure var isCompleted = false; while (!isCompleted) { - await Next(ref next, ref scope, ref state, ref isCompleted); + var lastTask = Next(ref next, ref scope, ref state, ref isCompleted); + if (!lastTask.IsCompletedSuccessfully) + { + return Awaited(this, lastTask, next, scope, state, isCompleted); + } } } catch (Exception exception) @@ -304,14 +308,59 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure } Debug.Assert(_actionExecutedContext != null); + return Task.CompletedTask; + + static async Task Awaited(ControllerActionInvoker invoker, Task lastTask, State next, Scope scope, object state, bool isCompleted) + { + try + { + await lastTask; + + while (!isCompleted) + { + await invoker.Next(ref next, ref scope, ref state, ref isCompleted); + } + } + catch (Exception exception) + { + invoker._actionExecutedContext = new ActionExecutedContext(invoker._controllerContext, invoker._filters, invoker._instance) + { + ExceptionDispatchInfo = ExceptionDispatchInfo.Capture(exception), + }; + } + + Debug.Assert(invoker._actionExecutedContext != null); + } } - private async Task InvokeNextActionFilterAwaitedAsync() + private Task InvokeNextActionFilterAwaitedAsync() { Debug.Assert(_actionExecutingContext != null); if (_actionExecutingContext.Result != null) { // If we get here, it means that an async filter set a result AND called next(). This is forbidden. + return Throw(); + } + + var task = InvokeNextActionFilterAsync(); + if (!task.IsCompletedSuccessfully) + { + return Awaited(this, task); + } + + Debug.Assert(_actionExecutedContext != null); + return Task.FromResult(_actionExecutedContext); + + static async Task Awaited(ControllerActionInvoker invoker, Task task) + { + await task; + + Debug.Assert(invoker._actionExecutedContext != null); + return invoker._actionExecutedContext; + } +#pragma warning disable CS1998 + static async Task Throw() + { var message = Resources.FormatAsyncActionFilter_InvalidShortCircuit( typeof(IAsyncActionFilter).Name, nameof(ActionExecutingContext.Result), @@ -320,69 +369,119 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure throw new InvalidOperationException(message); } - - await InvokeNextActionFilterAsync(); - - Debug.Assert(_actionExecutedContext != null); - return _actionExecutedContext; +#pragma warning restore CS1998 } - private async Task InvokeActionMethodAsync() + private Task InvokeActionMethodAsync() { - var controllerContext = _controllerContext; - var objectMethodExecutor = _cacheEntry.ObjectMethodExecutor; - var controller = _instance; - var arguments = _arguments; - var actionMethodExecutor = _cacheEntry.ActionMethodExecutor; - var orderedArguments = PrepareArguments(arguments, objectMethodExecutor); - - var diagnosticListener = _diagnosticListener; - var logger = _logger; - - IActionResult result = null; - try + if (_diagnosticListener.IsEnabled() || _logger.IsEnabled(LogLevel.Trace)) { - diagnosticListener.BeforeActionMethod( - controllerContext, - arguments, - controller); - logger.ActionMethodExecuting(controllerContext, orderedArguments); - var stopwatch = ValueStopwatch.StartNew(); - var actionResultValueTask = actionMethodExecutor.Execute(_mapper, objectMethodExecutor, controller, orderedArguments); - if (actionResultValueTask.IsCompletedSuccessfully) - { - result = actionResultValueTask.Result; - } - else - { - result = await actionResultValueTask; - } - - _result = result; - logger.ActionMethodExecuted(controllerContext, result, stopwatch.GetElapsedTime()); + return Logged(this); } - finally + + var objectMethodExecutor = _cacheEntry.ObjectMethodExecutor; + var actionMethodExecutor = _cacheEntry.ActionMethodExecutor; + var orderedArguments = PrepareArguments(_arguments, objectMethodExecutor); + + var actionResultValueTask = actionMethodExecutor.Execute(_mapper, objectMethodExecutor, _instance, orderedArguments); + if (actionResultValueTask.IsCompletedSuccessfully) { - diagnosticListener.AfterActionMethod( - controllerContext, - arguments, - controllerContext, - result); + _result = actionResultValueTask.Result; + } + else + { + return Awaited(this, actionResultValueTask); + } + + return Task.CompletedTask; + + static async Task Awaited(ControllerActionInvoker invoker, ValueTask actionResultValueTask) + { + invoker._result = await actionResultValueTask; + } + + static async Task Logged(ControllerActionInvoker invoker) + { + var controllerContext = invoker._controllerContext; + var objectMethodExecutor = invoker._cacheEntry.ObjectMethodExecutor; + var controller = invoker._instance; + var arguments = invoker._arguments; + var actionMethodExecutor = invoker._cacheEntry.ActionMethodExecutor; + var orderedArguments = PrepareArguments(arguments, objectMethodExecutor); + + var diagnosticListener = invoker._diagnosticListener; + var logger = invoker._logger; + + IActionResult result = null; + try + { + diagnosticListener.BeforeActionMethod( + controllerContext, + arguments, + controller); + logger.ActionMethodExecuting(controllerContext, orderedArguments); + var stopwatch = ValueStopwatch.StartNew(); + var actionResultValueTask = actionMethodExecutor.Execute(invoker._mapper, objectMethodExecutor, controller, orderedArguments); + if (actionResultValueTask.IsCompletedSuccessfully) + { + result = actionResultValueTask.Result; + } + else + { + result = await actionResultValueTask; + } + + invoker._result = result; + logger.ActionMethodExecuted(controllerContext, result, stopwatch.GetElapsedTime()); + } + finally + { + diagnosticListener.AfterActionMethod( + controllerContext, + arguments, + controllerContext, + result); + } } } /// for details on what the /// variables in this method represent. - protected override async Task InvokeInnerFilterAsync() + protected override Task InvokeInnerFilterAsync() { - var next = State.ActionBegin; - var scope = Scope.Invoker; - var state = (object)null; - var isCompleted = false; - - while (!isCompleted) + try { - await Next(ref next, ref scope, ref state, ref isCompleted); + var next = State.ActionBegin; + var scope = Scope.Invoker; + var state = (object)null; + var isCompleted = false; + + while (!isCompleted) + { + var lastTask = Next(ref next, ref scope, ref state, ref isCompleted); + if (!lastTask.IsCompletedSuccessfully) + { + return Awaited(this, lastTask, next, scope, state, isCompleted); + } + } + + return Task.CompletedTask; + } + catch (Exception ex) + { + // Wrap non task-wrapped exceptions in a Task, + // as this isn't done automatically since the method is not async. + return Task.FromException(ex); + } + + static async Task Awaited(ControllerActionInvoker invoker, Task lastTask, State next, Scope scope, object state, bool isCompleted) + { + await lastTask; + + while (!isCompleted) + { + await invoker.Next(ref next, ref scope, ref state, ref isCompleted); + } } } diff --git a/src/Mvc/Mvc.ViewFeatures/ref/Microsoft.AspNetCore.Mvc.ViewFeatures.netcoreapp3.0.cs b/src/Mvc/Mvc.ViewFeatures/ref/Microsoft.AspNetCore.Mvc.ViewFeatures.netcoreapp3.0.cs index 8e358a9505..5e7561c95c 100644 --- a/src/Mvc/Mvc.ViewFeatures/ref/Microsoft.AspNetCore.Mvc.ViewFeatures.netcoreapp3.0.cs +++ b/src/Mvc/Mvc.ViewFeatures/ref/Microsoft.AspNetCore.Mvc.ViewFeatures.netcoreapp3.0.cs @@ -29,7 +29,6 @@ namespace Microsoft.AspNetCore.Mvc [Microsoft.AspNetCore.Mvc.NonActionAttribute] public virtual void OnActionExecuting(Microsoft.AspNetCore.Mvc.Filters.ActionExecutingContext context) { } [Microsoft.AspNetCore.Mvc.NonActionAttribute] - [System.Diagnostics.DebuggerStepThroughAttribute] public virtual System.Threading.Tasks.Task OnActionExecutionAsync(Microsoft.AspNetCore.Mvc.Filters.ActionExecutingContext context, Microsoft.AspNetCore.Mvc.Filters.ActionExecutionDelegate next) { throw null; } [Microsoft.AspNetCore.Mvc.NonActionAttribute] public virtual Microsoft.AspNetCore.Mvc.PartialViewResult PartialView() { throw null; } diff --git a/src/Mvc/Mvc.ViewFeatures/src/Controller.cs b/src/Mvc/Mvc.ViewFeatures/src/Controller.cs index 6d2a841d7c..e90117df62 100644 --- a/src/Mvc/Mvc.ViewFeatures/src/Controller.cs +++ b/src/Mvc/Mvc.ViewFeatures/src/Controller.cs @@ -340,7 +340,7 @@ namespace Microsoft.AspNetCore.Mvc /// of to continue execution of the action. /// A instance. [NonAction] - public virtual async Task OnActionExecutionAsync( + public virtual Task OnActionExecutionAsync( ActionExecutingContext context, ActionExecutionDelegate next) { @@ -357,7 +357,20 @@ namespace Microsoft.AspNetCore.Mvc OnActionExecuting(context); if (context.Result == null) { - OnActionExecuted(await next()); + var task = next(); + if (!task.IsCompletedSuccessfully) + { + return Awaited(this, task); + } + + OnActionExecuted(task.Result); + } + + return Task.CompletedTask; + + static async Task Awaited(Controller controller, Task task) + { + controller.OnActionExecuted(await task); } }