From b31cd35f92dd884f4e2400e9cfdcb4d4618066cd Mon Sep 17 00:00:00 2001 From: Steve Sanderson Date: Thu, 4 Jul 2019 11:16:10 +0100 Subject: [PATCH] Ensure DI scope is disposed if prerendering fails --- .../Server/src/Circuits/CircuitPrerenderer.cs | 45 ++++++++++-------- .../test/Circuits/CircuitPrerendererTest.cs | 46 +++++++++++++++++++ 2 files changed, 71 insertions(+), 20 deletions(-) diff --git a/src/Components/Server/src/Circuits/CircuitPrerenderer.cs b/src/Components/Server/src/Circuits/CircuitPrerenderer.cs index e433fd1748..1d3ab61802 100644 --- a/src/Components/Server/src/Circuits/CircuitPrerenderer.cs +++ b/src/Components/Server/src/Circuits/CircuitPrerenderer.cs @@ -8,14 +8,13 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Components.Rendering; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Extensions; -using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Components.Server.Circuits { internal class CircuitPrerenderer : IComponentPrerenderer { private static object CircuitHostKey = new object(); - private static object NavigationStatusKey = new object(); + private static object CancellationStatusKey = new object(); private readonly CircuitFactory _circuitFactory; private readonly CircuitRegistry _registry; @@ -29,15 +28,15 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public async Task PrerenderComponentAsync(ComponentPrerenderingContext prerenderingContext) { var context = prerenderingContext.Context; - var navigationStatus = GetOrCreateNavigationStatus(context); - if (navigationStatus.Navigated) + var cancellationStatus = GetOrCreateCancellationStatus(context); + if (cancellationStatus.Canceled) { // Avoid creating a circuit host if other component earlier in the pipeline already triggered - // a navigation request. Instead rendre nothing + // cancelation (e.g., by navigating or throwing). Instead render nothing. return new ComponentPrerenderResult(Array.Empty()); } - var circuitHost = GetOrCreateCircuitHost(context, navigationStatus); - ComponentRenderedText renderResult = default; + var circuitHost = GetOrCreateCircuitHost(context, cancellationStatus); + ComponentRenderedText renderResult; try { renderResult = await circuitHost.PrerenderComponentAsync( @@ -48,7 +47,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits { // Cleanup the state as we won't need it any longer. // Signal callbacks that we don't have to register the circuit. - await CleanupCircuitState(context, navigationStatus, circuitHost); + await CleanupCircuitState(context, cancellationStatus, circuitHost); // Navigation was attempted during prerendering. if (prerenderingContext.Context.Response.HasStarted) @@ -64,6 +63,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits context.Response.Redirect(navigationException.Location); return new ComponentPrerenderResult(Array.Empty()); } + catch + { + // If prerendering any component fails, cancel prerendering entirely and dispose the DI scope + await CleanupCircuitState(context, cancellationStatus, circuitHost); + throw; + } circuitHost.Descriptors.Add(new ComponentDescriptor { @@ -81,28 +86,28 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits return new ComponentPrerenderResult(result); } - private CircuitNavigationStatus GetOrCreateNavigationStatus(HttpContext context) + private PrerenderingCancellationStatus GetOrCreateCancellationStatus(HttpContext context) { - if (context.Items.TryGetValue(NavigationStatusKey, out var existingHost)) + if (context.Items.TryGetValue(CancellationStatusKey, out var existingValue)) { - return (CircuitNavigationStatus)existingHost; + return (PrerenderingCancellationStatus)existingValue; } else { - var navigationStatus = new CircuitNavigationStatus(); - context.Items[NavigationStatusKey] = navigationStatus; - return navigationStatus; + var cancellationStatus = new PrerenderingCancellationStatus(); + context.Items[CancellationStatusKey] = cancellationStatus; + return cancellationStatus; } } - private static async Task CleanupCircuitState(HttpContext context, CircuitNavigationStatus navigationStatus, CircuitHost circuitHost) + private static async Task CleanupCircuitState(HttpContext context, PrerenderingCancellationStatus cancellationStatus, CircuitHost circuitHost) { - navigationStatus.Navigated = true; + cancellationStatus.Canceled = true; context.Items.Remove(CircuitHostKey); await circuitHost.DisposeAsync(); } - private CircuitHost GetOrCreateCircuitHost(HttpContext context, CircuitNavigationStatus navigationStatus) + private CircuitHost GetOrCreateCircuitHost(HttpContext context, PrerenderingCancellationStatus cancellationStatus) { if (context.Items.TryGetValue(CircuitHostKey, out var existingHost)) { @@ -120,7 +125,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits context.Response.OnCompleted(() => { result.UnhandledException -= CircuitHost_UnhandledException; - if (!navigationStatus.Navigated) + if (!cancellationStatus.Canceled) { _registry.RegisterDisconnectedCircuit(result); } @@ -164,9 +169,9 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits return result; } - private class CircuitNavigationStatus + private class PrerenderingCancellationStatus { - public bool Navigated { get; set; } + public bool Canceled { get; set; } } } } diff --git a/src/Components/Server/test/Circuits/CircuitPrerendererTest.cs b/src/Components/Server/test/Circuits/CircuitPrerendererTest.cs index b96541455c..210bdb80d0 100644 --- a/src/Components/Server/test/Circuits/CircuitPrerendererTest.cs +++ b/src/Components/Server/test/Circuits/CircuitPrerendererTest.cs @@ -111,6 +111,32 @@ namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits }), GetUnwrappedContent(result)); } + [Fact] + public async Task DisposesCircuitScopeEvenIfPrerenderingThrows() + { + // Arrange + var circuitFactory = new MockServiceScopeCircuitFactory(); + var circuitRegistry = new CircuitRegistry( + Options.Create(new CircuitOptions()), + Mock.Of>(), + TestCircuitIdFactory.CreateTestFactory()); + var httpContext = new DefaultHttpContext(); + var prerenderer = new CircuitPrerenderer(circuitFactory, circuitRegistry); + var prerenderingContext = new ComponentPrerenderingContext + { + ComponentType = typeof(ThrowExceptionComponent), + Parameters = ParameterCollection.Empty, + Context = httpContext + }; + + // Act + await Assert.ThrowsAsync(async () => + await prerenderer.PrerenderComponentAsync(prerenderingContext)); + + // Assert + circuitFactory.MockServiceScope.Verify(scope => scope.Dispose(), Times.Once()); + } + class TestCircuitFactory : CircuitFactory { public override CircuitHost CreateCircuitHost(HttpContext httpContext, CircuitClientProxy client, string uriAbsolute, string baseUriAbsolute) @@ -127,6 +153,17 @@ namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits } } + class MockServiceScopeCircuitFactory : CircuitFactory + { + public Mock MockServiceScope { get; } + = new Mock(); + + public override CircuitHost CreateCircuitHost(HttpContext httpContext, CircuitClientProxy client, string uriAbsolute, string baseUriAbsolute) + { + return TestCircuitHost.Create(Guid.NewGuid().ToString(), MockServiceScope.Object); + } + } + class UriDisplayComponent : IComponent { private RenderHandle _renderHandle; @@ -151,5 +188,14 @@ namespace Microsoft.AspNetCore.Components.Server.Tests.Circuits return Task.CompletedTask; } } + + class ThrowExceptionComponent : IComponent + { + public void Configure(RenderHandle renderHandle) + => throw new InvalidTimeZoneException(); + + public Task SetParametersAsync(ParameterCollection parameters) + => Task.CompletedTask; + } } }