From d52d7e3284bc5913aebaf410e78155af9d4ee9c0 Mon Sep 17 00:00:00 2001 From: Ryan Nowak Date: Fri, 2 Aug 2019 20:35:29 -0700 Subject: [PATCH] Harden StartCircuit (#12825) * Harden StartCircuit Fixes: #12057 Adds some upfront argument validation as well as error handling for circuit intialization failures. --- ...ft.AspNetCore.Components.netstandard2.0.cs | 4 ++ .../Components/src/LocationChangeException.cs | 23 +++++++ .../Components/src/NavigationManager.cs | 13 +++- .../Components/test/NavigationManagerTest.cs | 29 ++++++++- .../src/Circuits/DefaultCircuitFactory.cs | 2 + src/Components/Server/src/ComponentHub.cs | 65 ++++++++++++++----- .../ComponentHubReliabilityTest.cs | 60 ++++++++++++++++- .../test/testassets/Ignitor/BlazorClient.cs | 44 ++++++------- 8 files changed, 196 insertions(+), 44 deletions(-) create mode 100644 src/Components/Components/src/LocationChangeException.cs diff --git a/src/Components/Components/ref/Microsoft.AspNetCore.Components.netstandard2.0.cs b/src/Components/Components/ref/Microsoft.AspNetCore.Components.netstandard2.0.cs index c6d8c590a0..4f317be0ff 100644 --- a/src/Components/Components/ref/Microsoft.AspNetCore.Components.netstandard2.0.cs +++ b/src/Components/Components/ref/Microsoft.AspNetCore.Components.netstandard2.0.cs @@ -278,6 +278,10 @@ namespace Microsoft.AspNetCore.Components [Microsoft.AspNetCore.Components.ParameterAttribute] public Microsoft.AspNetCore.Components.RenderFragment Body { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } + public sealed partial class LocationChangeException : System.Exception + { + public LocationChangeException(string message, System.Exception innerException) { } + } [System.Runtime.InteropServices.StructLayoutAttribute(System.Runtime.InteropServices.LayoutKind.Sequential)] public readonly partial struct MarkupString { diff --git a/src/Components/Components/src/LocationChangeException.cs b/src/Components/Components/src/LocationChangeException.cs new file mode 100644 index 0000000000..13010eb5c1 --- /dev/null +++ b/src/Components/Components/src/LocationChangeException.cs @@ -0,0 +1,23 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Components +{ + /// + /// An exception thrown when throws an exception. + /// + public sealed class LocationChangeException : Exception + { + /// + /// Creates a new instance of . + /// + /// The exception message. + /// The inner exception. + public LocationChangeException(string message, Exception innerException) + : base(message, innerException) + { + } + } +} diff --git a/src/Components/Components/src/NavigationManager.cs b/src/Components/Components/src/NavigationManager.cs index f7729e27f6..d75077026f 100644 --- a/src/Components/Components/src/NavigationManager.cs +++ b/src/Components/Components/src/NavigationManager.cs @@ -2,7 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Runtime.InteropServices.ComTypes; using Microsoft.AspNetCore.Components.Routing; namespace Microsoft.AspNetCore.Components @@ -129,8 +128,9 @@ namespace Microsoft.AspNetCore.Components _isInitialized = true; - Uri = uri; + // Setting BaseUri before Uri so they get validated. BaseUri = baseUri; + Uri = uri; } /// @@ -201,7 +201,14 @@ namespace Microsoft.AspNetCore.Components /// protected void NotifyLocationChanged(bool isInterceptedLink) { - _locationChanged?.Invoke(this, new LocationChangedEventArgs(_uri, isInterceptedLink)); + try + { + _locationChanged?.Invoke(this, new LocationChangedEventArgs(_uri, isInterceptedLink)); + } + catch (Exception ex) + { + throw new LocationChangeException("An exception occurred while dispatching a location changed event.", ex); + } } private void AssertInitialized() diff --git a/src/Components/Components/test/NavigationManagerTest.cs b/src/Components/Components/test/NavigationManagerTest.cs index a2cc526905..9b857bb044 100644 --- a/src/Components/Components/test/NavigationManagerTest.cs +++ b/src/Components/Components/test/NavigationManagerTest.cs @@ -38,6 +38,24 @@ namespace Microsoft.AspNetCore.Components Assert.Equal(expectedResult, actualResult); } + [Theory] + [InlineData("scheme://host/", "otherscheme://host/")] + [InlineData("scheme://host/", "scheme://otherhost/")] + [InlineData("scheme://host/path/", "scheme://host/")] + public void Initialize_ThrowsForInvalidBaseRelativePaths(string baseUri, string absoluteUri) + { + var navigationManager = new TestNavigationManager(); + + var ex = Assert.Throws(() => + { + navigationManager.Initialize(baseUri, absoluteUri); + }); + + Assert.Equal( + $"The URI '{absoluteUri}' is not contained by the base URI '{baseUri}'.", + ex.Message); + } + [Theory] [InlineData("scheme://host/", "otherscheme://host/")] [InlineData("scheme://host/", "scheme://otherhost/")] @@ -76,9 +94,18 @@ namespace Microsoft.AspNetCore.Components private class TestNavigationManager : NavigationManager { + public TestNavigationManager() + { + } + public TestNavigationManager(string baseUri = null, string uri = null) { - Initialize(baseUri ?? "http://example.com/", uri ?? "http://example.com/welcome-page"); + Initialize(baseUri ?? "http://example.com/", uri ?? baseUri ?? "http://example.com/welcome-page"); + } + + public new void Initialize(string baseUri, string uri) + { + base.Initialize(baseUri, uri); } protected override void NavigateToCore(string uri, bool forceLoad) diff --git a/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs b/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs index a0b9f8173e..d353d5dc2c 100644 --- a/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs +++ b/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs @@ -43,6 +43,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits string uri, ClaimsPrincipal user) { + // We do as much intialization as possible eagerly in this method, which makes the error handling + // story much simpler. If we throw from here, it's handled inside the initial hub method. var components = ResolveComponentMetadata(httpContext, client); var scope = _scopeFactory.CreateScope(); diff --git a/src/Components/Server/src/ComponentHub.cs b/src/Components/Server/src/ComponentHub.cs index 52f12fa90e..4da9a3391e 100644 --- a/src/Components/Server/src/ComponentHub.cs +++ b/src/Components/Server/src/ComponentHub.cs @@ -110,6 +110,18 @@ namespace Microsoft.AspNetCore.Components.Server return null; } + if (baseUri == null || + uri == null || + !Uri.IsWellFormedUriString(baseUri, UriKind.Absolute) || + !Uri.IsWellFormedUriString(uri, UriKind.Absolute)) + { + // We do some really minimal validation here to prevent obviously wrong data from getting in + // without duplicating too much logic. + Log.InvalidInputData(_logger); + _ = NotifyClientError(Clients.Caller, $"The uris provided are invalid."); + return null; + } + var circuitClient = new CircuitClientProxy(Clients.Caller, Context.ConnectionId); if (DefaultCircuitFactory.ResolveComponentMetadata(Context.GetHttpContext(), circuitClient).Count == 0) { @@ -122,26 +134,35 @@ namespace Microsoft.AspNetCore.Components.Server return null; } - var circuitHost = _circuitFactory.CreateCircuitHost( - Context.GetHttpContext(), - circuitClient, - baseUri, - uri, - Context.User); + try + { + var circuitHost = _circuitFactory.CreateCircuitHost( + Context.GetHttpContext(), + circuitClient, + baseUri, + uri, + Context.User); - circuitHost.UnhandledException += CircuitHost_UnhandledException; + circuitHost.UnhandledException += CircuitHost_UnhandledException; - // Fire-and-forget the initialization process, because we can't block the - // SignalR message loop (we'd get a deadlock if any of the initialization - // logic relied on receiving a subsequent message from SignalR), and it will - // take care of its own errors anyway. - _ = circuitHost.InitializeAsync(Context.ConnectionAborted); + // Fire-and-forget the initialization process, because we can't block the + // SignalR message loop (we'd get a deadlock if any of the initialization + // logic relied on receiving a subsequent message from SignalR), and it will + // take care of its own errors anyway. + _ = circuitHost.InitializeAsync(Context.ConnectionAborted); - _circuitRegistry.Register(circuitHost); - - CircuitHost = circuitHost; - - return circuitHost.CircuitId; + // It's safe to *publish* the circuit now because nothing will be able + // to run inside it until after InitializeAsync completes. + _circuitRegistry.Register(circuitHost); + CircuitHost = circuitHost; + return circuitHost.CircuitId; + } + catch (Exception ex) + { + Log.CircuitInitializationFailed(_logger, ex); + NotifyClientError(Clients.Caller, "The circuit failed to initialize."); + return null; + } } /// @@ -292,6 +313,12 @@ namespace Microsoft.AspNetCore.Components.Server private static readonly Action _circuitTerminatedGracefully = LoggerMessage.Define(LogLevel.Debug, new EventId(7, "CircuitTerminatedGracefully"), "Circuit '{CircuitId}' terminated gracefully"); + private static readonly Action _invalidInputData = + LoggerMessage.Define(LogLevel.Debug, new EventId(8, "InvalidInputData"), "Call to '{CallSite}' received invalid input data"); + + private static readonly Action _circuitInitializationFailed = + LoggerMessage.Define(LogLevel.Debug, new EventId(9, "CircuitInitializationFailed"), "Circuit initialization failed"); + public static void NoComponentsRegisteredInEndpoint(ILogger logger, string endpointDisplayName) { _noComponentsRegisteredInEndpoint(logger, endpointDisplayName, null); @@ -317,6 +344,10 @@ namespace Microsoft.AspNetCore.Components.Server public static void CircuitHostNotInitialized(ILogger logger, [CallerMemberName] string callSite = "") => _circuitHostNotInitialized(logger, callSite, null); public static void CircuitTerminatedGracefully(ILogger logger, string circuitId) => _circuitTerminatedGracefully(logger, circuitId, null); + + public static void InvalidInputData(ILogger logger, [CallerMemberName] string callSite = "") => _invalidInputData(logger, callSite, null); + + public static void CircuitInitializationFailed(ILogger logger, Exception exception) => _circuitInitializationFailed(logger, exception); } } } diff --git a/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs b/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs index d510e76c4b..deb353dd9c 100644 --- a/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs +++ b/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs @@ -61,7 +61,65 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests await Client.ExpectCircuitError(() => Client.HubConnection.SendAsync( "StartCircuit", baseUri, - baseUri.GetLeftPart(UriPartial.Authority))); + baseUri + "/home")); + + // Assert + var actualError = Assert.Single(Errors); + Assert.Matches(expectedError, actualError); + Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information); + } + + [Fact] + public async Task CannotStartCircuitWithNullData() + { + // Arrange + var expectedError = "The uris provided are invalid."; + var rootUri = _serverFixture.RootUri; + var uri = new Uri(rootUri, "/subdir"); + Assert.True(await Client.ConnectAsync(uri, prerendered: false, connectAutomatically: false), "Couldn't connect to the app"); + + // Act + await Client.ExpectCircuitError(() => Client.HubConnection.SendAsync("StartCircuit", null, null)); + + // Assert + var actualError = Assert.Single(Errors); + Assert.Matches(expectedError, actualError); + Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information); + } + + [Fact] + public async Task CannotStartCircuitWithInvalidUris() + { + // Arrange + var expectedError = "The uris provided are invalid."; + var rootUri = _serverFixture.RootUri; + var uri = new Uri(rootUri, "/subdir"); + Assert.True(await Client.ConnectAsync(uri, prerendered: false, connectAutomatically: false), "Couldn't connect to the app"); + + // Act + await Client.ExpectCircuitError(() => Client.HubConnection.SendAsync("StartCircuit", uri.AbsoluteUri, "/foo")); + + // Assert + var actualError = Assert.Single(Errors); + Assert.Matches(expectedError, actualError); + Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information); + } + + // This is a hand-chosen example of something that will cause an exception in creating the circuit host. + // We want to test this case so that we know what happens when creating the circuit host blows up. + [Fact] + public async Task StartCircuitCausesInitializationError() + { + // Arrange + var expectedError = "The circuit failed to initialize."; + var rootUri = _serverFixture.RootUri; + var uri = new Uri(rootUri, "/subdir"); + Assert.True(await Client.ConnectAsync(uri, prerendered: false, connectAutomatically: false), "Couldn't connect to the app"); + + // Act + // + // These are valid URIs by the BaseUri doesn't contain the Uri - so it fails to initialize. + await Client.ExpectCircuitError(() => Client.HubConnection.SendAsync("StartCircuit", uri, "http://example.com"), TimeSpan.FromHours(1)); // Assert var actualError = Assert.Single(Errors); diff --git a/src/Components/test/testassets/Ignitor/BlazorClient.cs b/src/Components/test/testassets/Ignitor/BlazorClient.cs index a71851a206..c11c2951dd 100644 --- a/src/Components/test/testassets/Ignitor/BlazorClient.cs +++ b/src/Components/test/testassets/Ignitor/BlazorClient.cs @@ -76,38 +76,38 @@ namespace Ignitor return NextBatchReceived.Completion.Task; } - public Task PrepareForNextJSInterop() + public Task PrepareForNextJSInterop(TimeSpan? timeout) { if (NextJSInteropReceived?.Completion != null) { throw new InvalidOperationException("Invalid state previous task not completed"); } - NextJSInteropReceived = new CancellableOperation(DefaultLatencyTimeout); + NextJSInteropReceived = new CancellableOperation(timeout); return NextJSInteropReceived.Completion.Task; } - public Task PrepareForNextDotNetInterop() + public Task PrepareForNextDotNetInterop(TimeSpan? timeout) { if (NextDotNetInteropCompletionReceived?.Completion != null) { throw new InvalidOperationException("Invalid state previous task not completed"); } - NextDotNetInteropCompletionReceived = new CancellableOperation(DefaultLatencyTimeout); + NextDotNetInteropCompletionReceived = new CancellableOperation(timeout); return NextDotNetInteropCompletionReceived.Completion.Task; } - public Task PrepareForNextCircuitError() + public Task PrepareForNextCircuitError(TimeSpan? timeout) { if (NextErrorReceived?.Completion != null) { throw new InvalidOperationException("Invalid state previous task not completed"); } - NextErrorReceived = new CancellableOperation(DefaultLatencyTimeout); + NextErrorReceived = new CancellableOperation(timeout); return NextErrorReceived.Completion.Task; } @@ -139,23 +139,23 @@ namespace Ignitor await task; } - public async Task ExpectJSInterop(Func action) + public async Task ExpectJSInterop(Func action, TimeSpan? timeout = null) { - var task = WaitForJSInterop(); + var task = WaitForJSInterop(timeout); await action(); await task; } - public async Task ExpectDotNetInterop(Func action) + public async Task ExpectDotNetInterop(Func action, TimeSpan? timeout = null) { - var task = WaitForDotNetInterop(); + var task = WaitForDotNetInterop(timeout); await action(); await task; } - public async Task ExpectCircuitError(Func action) + public async Task ExpectCircuitError(Func action, TimeSpan? timeout = null) { - var task = WaitForCircuitError(); + var task = WaitForCircuitError(timeout); await action(); await task; } @@ -175,42 +175,42 @@ namespace Ignitor return Task.CompletedTask; } - private async Task WaitForJSInterop() + private async Task WaitForJSInterop(TimeSpan? timeout = null) { if (ImplicitWait) { - if (DefaultLatencyTimeout == null) + if (DefaultLatencyTimeout == null && timeout == null) { throw new InvalidOperationException("Implicit wait without DefaultLatencyTimeout is not allowed."); } - await PrepareForNextJSInterop(); + await PrepareForNextJSInterop(timeout ?? DefaultLatencyTimeout); } } - private async Task WaitForDotNetInterop() + private async Task WaitForDotNetInterop(TimeSpan? timeout = null) { if (ImplicitWait) { - if (DefaultLatencyTimeout == null) + if (DefaultLatencyTimeout == null && timeout == null) { throw new InvalidOperationException("Implicit wait without DefaultLatencyTimeout is not allowed."); } - await PrepareForNextDotNetInterop(); + await PrepareForNextDotNetInterop(timeout ?? DefaultLatencyTimeout); } } - private async Task WaitForCircuitError() + private async Task WaitForCircuitError(TimeSpan? timeout = null) { if (ImplicitWait) { - if (DefaultLatencyTimeout == null) + if (DefaultLatencyTimeout == null && timeout == null) { throw new InvalidOperationException("Implicit wait without DefaultLatencyTimeout is not allowed."); } - await PrepareForNextCircuitError(); + await PrepareForNextCircuitError(timeout ?? DefaultLatencyTimeout); } } @@ -246,7 +246,7 @@ namespace Ignitor else { await ExpectRenderBatch( - async () => CircuitId = await HubConnection.InvokeAsync("StartCircuit", uri, new Uri(uri.GetLeftPart(UriPartial.Authority))), + async () => CircuitId = await HubConnection.InvokeAsync("StartCircuit", uri, uri), TimeSpan.FromSeconds(10)); return CircuitId != null; }