diff --git a/src/Components/Server/src/CircuitDisconnectMiddleware.cs b/src/Components/Server/src/CircuitDisconnectMiddleware.cs index 0792e009d3..d64c31e7da 100644 --- a/src/Components/Server/src/CircuitDisconnectMiddleware.cs +++ b/src/Components/Server/src/CircuitDisconnectMiddleware.cs @@ -39,65 +39,71 @@ namespace Microsoft.AspNetCore.Components.Server return; } - var (hasCircuitId, circuitId) = await TryGetCircuitIdAsync(context); - if (!hasCircuitId) + var circuitId = await GetCircuitIdAsync(context); + if (circuitId is null) { context.Response.StatusCode = StatusCodes.Status400BadRequest; return; } - await TerminateCircuitGracefully(circuitId); + await TerminateCircuitGracefully(circuitId.Value); context.Response.StatusCode = StatusCodes.Status200OK; } - private async Task<(bool, string)> TryGetCircuitIdAsync(HttpContext context) + private async Task GetCircuitIdAsync(HttpContext context) { try { if (!context.Request.HasFormContentType) { - return (false, null); + return default; } var form = await context.Request.ReadFormAsync(); - if (!form.TryGetValue(CircuitIdKey, out var circuitId) || !CircuitIdFactory.ValidateCircuitId(circuitId)) + if (!form.TryGetValue(CircuitIdKey, out var text)) { - return (false, null); + return default; } - return (true, circuitId); + if (!CircuitIdFactory.TryParseCircuitId(text, out var circuitId)) + { + Log.InvalidCircuitId(Logger, text); + return default; + } + + return circuitId; } catch { - return (false, null); + return default; } } - private async Task TerminateCircuitGracefully(string circuitId) + private async Task TerminateCircuitGracefully(CircuitId circuitId) { - try - { - await Registry.TerminateAsync(circuitId); - Log.CircuitTerminatedGracefully(Logger, circuitId); - } - catch (Exception e) - { - Log.UnhandledExceptionInCircuit(Logger, circuitId, e); - } + // We don't expect TerminateAsync to throw. + Log.CircuitTerminatingGracefully(Logger, circuitId); + await Registry.TerminateAsync(circuitId); + Log.CircuitTerminatedGracefully(Logger, circuitId); } private class Log { - private static readonly Action _circuitTerminatedGracefully = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CircuitTerminatedGracefully"), "Circuit '{CircuitId}' terminated gracefully"); + private static readonly Action _circuitTerminatingGracefully = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CircuitTerminatingGracefully"), "Circuit with id '{CircuitId}' terminating gracefully."); - private static readonly Action _unhandledExceptionInCircuit = - LoggerMessage.Define(LogLevel.Warning, new EventId(2, "UnhandledExceptionInCircuit"), "Unhandled exception in circuit {CircuitId} while terminating gracefully."); + private static readonly Action _circuitTerminatedGracefully = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "CircuitTerminatedGracefully"), "Circuit with id '{CircuitId}' terminated gracefully."); - public static void CircuitTerminatedGracefully(ILogger logger, string circuitId) => _circuitTerminatedGracefully(logger, circuitId, null); + private static readonly Action _invalidCircuitId = + LoggerMessage.Define(LogLevel.Debug, new EventId(3, "InvalidCircuitId"), "CircuitDisconnectMiddleware recieved an invalid circuit id '{CircuitIdSecret}'."); - public static void UnhandledExceptionInCircuit(ILogger logger, string circuitId, Exception exception) => _unhandledExceptionInCircuit(logger, circuitId, exception); + public static void CircuitTerminatingGracefully(ILogger logger, CircuitId circuitId) => _circuitTerminatingGracefully(logger, circuitId, null); + + public static void CircuitTerminatedGracefully(ILogger logger, CircuitId circuitId) => _circuitTerminatedGracefully(logger, circuitId, null); + + public static void InvalidCircuitId(ILogger logger, string circuitSecret) => _invalidCircuitId(logger, circuitSecret, null); } } } diff --git a/src/Components/Server/src/Circuits/Circuit.cs b/src/Components/Server/src/Circuits/Circuit.cs index 536fb86f93..9d6620d83a 100644 --- a/src/Components/Server/src/Circuits/Circuit.cs +++ b/src/Components/Server/src/Circuits/Circuit.cs @@ -18,6 +18,6 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits /// /// Gets the identifier for the . /// - public string Id => _circuitHost.CircuitId; + public string Id => _circuitHost.CircuitId.Id; } } diff --git a/src/Components/Server/src/Circuits/CircuitHost.cs b/src/Components/Server/src/Circuits/CircuitHost.cs index a2c9ac869b..e664f32918 100644 --- a/src/Components/Server/src/Circuits/CircuitHost.cs +++ b/src/Components/Server/src/Circuits/CircuitHost.cs @@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public event UnhandledExceptionEventHandler UnhandledException; public CircuitHost( - string circuitId, + CircuitId circuitId, IServiceScope scope, CircuitOptions options, CircuitClientProxy client, @@ -48,7 +48,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits CircuitHandler[] circuitHandlers, ILogger logger) { - CircuitId = circuitId ?? throw new ArgumentNullException(nameof(circuitId)); + CircuitId = circuitId; + if (CircuitId.Secret is null) + { + // Prevent the use of a 'default' secret. + throw new ArgumentException(nameof(circuitId)); + } _scope = scope ?? throw new ArgumentNullException(nameof(scope)); _options = options ?? throw new ArgumentNullException(nameof(options)); @@ -70,7 +75,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public CircuitHandle Handle { get; } - public string CircuitId { get; } + public CircuitId CircuitId { get; } public Circuit Circuit { get; } @@ -194,7 +199,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits // exceptions. private async Task OnCircuitOpenedAsync(CancellationToken cancellationToken) { - Log.CircuitOpened(_logger, Circuit.Id); + Log.CircuitOpened(_logger, CircuitId); Renderer.Dispatcher.AssertAccess(); @@ -223,7 +228,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public async Task OnConnectionUpAsync(CancellationToken cancellationToken) { - Log.ConnectionUp(_logger, Circuit.Id, Client.ConnectionId); + Log.ConnectionUp(_logger, CircuitId, Client.ConnectionId); Renderer.Dispatcher.AssertAccess(); @@ -252,7 +257,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public async Task OnConnectionDownAsync(CancellationToken cancellationToken) { - Log.ConnectionDown(_logger, Circuit.Id, Client.ConnectionId); + Log.ConnectionDown(_logger, CircuitId, Client.ConnectionId); Renderer.Dispatcher.AssertAccess(); @@ -281,7 +286,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits private async Task OnCircuitDownAsync(CancellationToken cancellationToken) { - Log.CircuitClosed(_logger, Circuit.Id); + Log.CircuitClosed(_logger, CircuitId); List exceptions = null; @@ -585,19 +590,19 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits private static readonly Action _intializationStarted; private static readonly Action _intializationSucceded; private static readonly Action _intializationFailed; - private static readonly Action _disposeStarted; - private static readonly Action _disposeSucceded; - private static readonly Action _disposeFailed; - private static readonly Action _onCircuitOpened; - private static readonly Action _onConnectionUp; - private static readonly Action _onConnectionDown; - private static readonly Action _onCircuitClosed; + private static readonly Action _disposeStarted; + private static readonly Action _disposeSucceded; + private static readonly Action _disposeFailed; + private static readonly Action _onCircuitOpened; + private static readonly Action _onConnectionUp; + private static readonly Action _onConnectionDown; + private static readonly Action _onCircuitClosed; private static readonly Action _circuitHandlerFailed; - private static readonly Action _circuitUnhandledException; - private static readonly Action _circuitTransmittingClientError; - private static readonly Action _circuitTransmittedClientErrorSuccess; - private static readonly Action _circuitTransmitErrorFailed; - private static readonly Action _unhandledExceptionClientDisconnected; + private static readonly Action _circuitUnhandledException; + private static readonly Action _circuitTransmittingClientError; + private static readonly Action _circuitTransmittedClientErrorSuccess; + private static readonly Action _circuitTransmitErrorFailed; + private static readonly Action _unhandledExceptionClientDisconnected; private static readonly Action _beginInvokeDotNetStatic; private static readonly Action _beginInvokeDotNetInstance; @@ -608,11 +613,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits private static readonly Action _endInvokeJSSucceeded; private static readonly Action _dispatchEventFailedToParseEventData; private static readonly Action _dispatchEventFailedToDispatchEvent; - private static readonly Action _locationChange; - private static readonly Action _locationChangeSucceeded; - private static readonly Action _locationChangeFailed; - private static readonly Action _locationChangeFailedInCircuit; - private static readonly Action _onRenderCompletedFailed; + private static readonly Action _locationChange; + private static readonly Action _locationChangeSucceeded; + private static readonly Action _locationChangeFailed; + private static readonly Action _locationChangeFailedInCircuit; + private static readonly Action _onRenderCompletedFailed; private static class EventIds { @@ -647,7 +652,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public static readonly EventId LocationChangeSucceded = new EventId(209, "LocationChangeSucceeded"); public static readonly EventId LocationChangeFailed = new EventId(210, "LocationChangeFailed"); public static readonly EventId LocationChangeFailedInCircuit = new EventId(211, "LocationChangeFailedInCircuit"); - public static readonly EventId OnRenderCompletedFailed = new EventId(212, " OnRenderCompletedFailed"); + public static readonly EventId OnRenderCompletedFailed = new EventId(212, "OnRenderCompletedFailed"); } static Log() @@ -667,37 +672,37 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits EventIds.InitializationFailed, "Circuit initialization failed."); - _disposeStarted = LoggerMessage.Define( + _disposeStarted = LoggerMessage.Define( LogLevel.Debug, EventIds.DisposeStarted, "Disposing circuit '{CircuitId}' started."); - _disposeSucceded = LoggerMessage.Define( + _disposeSucceded = LoggerMessage.Define( LogLevel.Debug, EventIds.DisposeSucceeded, "Disposing circuit '{CircuitId}' succeded."); - _disposeFailed = LoggerMessage.Define( + _disposeFailed = LoggerMessage.Define( LogLevel.Debug, EventIds.DisposeFailed, "Disposing circuit '{CircuitId}' failed."); - _onCircuitOpened = LoggerMessage.Define( + _onCircuitOpened = LoggerMessage.Define( LogLevel.Debug, EventIds.OnCircuitOpened, "Opening circuit with id '{CircuitId}'."); - _onConnectionUp = LoggerMessage.Define( + _onConnectionUp = LoggerMessage.Define( LogLevel.Debug, EventIds.OnConnectionUp, "Circuit id '{CircuitId}' connected using connection '{ConnectionId}'."); - _onConnectionDown = LoggerMessage.Define( + _onConnectionDown = LoggerMessage.Define( LogLevel.Debug, EventIds.OnConnectionDown, "Circuit id '{CircuitId}' disconnected from connection '{ConnectionId}'."); - _onCircuitClosed = LoggerMessage.Define( + _onCircuitClosed = LoggerMessage.Define( LogLevel.Debug, EventIds.OnCircuitClosed, "Closing circuit with id '{CircuitId}'."); @@ -707,27 +712,27 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits EventIds.CircuitHandlerFailed, "Unhandled error invoking circuit handler type {handlerType}.{handlerMethod}: {Message}"); - _circuitUnhandledException = LoggerMessage.Define( + _circuitUnhandledException = LoggerMessage.Define( LogLevel.Error, EventIds.CircuitUnhandledException, "Unhandled exception in circuit '{CircuitId}'."); - _circuitTransmittingClientError = LoggerMessage.Define( + _circuitTransmittingClientError = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitTransmittingClientError, "About to notify client of an error in circuit '{CircuitId}'."); - _circuitTransmittedClientErrorSuccess = LoggerMessage.Define( + _circuitTransmittedClientErrorSuccess = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitTransmittedClientErrorSuccess, "Successfully transmitted error to client in circuit '{CircuitId}'."); - _circuitTransmitErrorFailed = LoggerMessage.Define( + _circuitTransmitErrorFailed = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitTransmitErrorFailed, "Failed to transmit exception to client in circuit '{CircuitId}'."); - _unhandledExceptionClientDisconnected = LoggerMessage.Define( + _unhandledExceptionClientDisconnected = LoggerMessage.Define( LogLevel.Debug, EventIds.UnhandledExceptionClientDisconnected, "An exception ocurred on the circuit host '{CircuitId}' while the client is disconnected."); @@ -777,27 +782,27 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits EventIds.DispatchEventFailedToDispatchEvent, "There was an error dispatching the event '{EventHandlerId}' to the application."); - _locationChange = LoggerMessage.Define( + _locationChange = LoggerMessage.Define( LogLevel.Debug, EventIds.LocationChange, "Location changing to {URI} in circuit '{CircuitId}'."); - _locationChangeSucceeded = LoggerMessage.Define( + _locationChangeSucceeded = LoggerMessage.Define( LogLevel.Debug, EventIds.LocationChangeSucceded, "Location change to '{URI}' in circuit '{CircuitId}' succeded."); - _locationChangeFailed = LoggerMessage.Define( + _locationChangeFailed = LoggerMessage.Define( LogLevel.Debug, EventIds.LocationChangeFailed, "Location change to '{URI}' in circuit '{CircuitId}' failed."); - _locationChangeFailedInCircuit = LoggerMessage.Define( + _locationChangeFailedInCircuit = LoggerMessage.Define( LogLevel.Error, EventIds.LocationChangeFailed, "Location change to '{URI}' in circuit '{CircuitId}' failed."); - _onRenderCompletedFailed = LoggerMessage.Define( + _onRenderCompletedFailed = LoggerMessage.Define( LogLevel.Debug, EventIds.OnRenderCompletedFailed, "Failed to complete render batch '{RenderId}' in circuit host '{CircuitId}'."); @@ -806,13 +811,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public static void InitializationStarted(ILogger logger) => _intializationStarted(logger, null); public static void InitializationSucceeded(ILogger logger) => _intializationSucceded(logger, null); public static void InitializationFailed(ILogger logger, Exception exception) => _intializationFailed(logger, exception); - public static void DisposeStarted(ILogger logger, string circuitId) => _disposeStarted(logger, circuitId, null); - public static void DisposeSucceeded(ILogger logger, string circuitId) => _disposeSucceded(logger, circuitId, null); - public static void DisposeFailed(ILogger logger, string circuitId, Exception exception) => _disposeFailed(logger, circuitId, exception); - public static void CircuitOpened(ILogger logger, string circuitId) => _onCircuitOpened(logger, circuitId, null); - public static void ConnectionUp(ILogger logger, string circuitId, string connectionId) => _onConnectionUp(logger, circuitId, connectionId, null); - public static void ConnectionDown(ILogger logger, string circuitId, string connectionId) => _onConnectionDown(logger, circuitId, connectionId, null); - public static void CircuitClosed(ILogger logger, string circuitId) => _onCircuitClosed(logger, circuitId, null); + public static void DisposeStarted(ILogger logger, CircuitId circuitId) => _disposeStarted(logger, circuitId, null); + public static void DisposeSucceeded(ILogger logger, CircuitId circuitId) => _disposeSucceded(logger, circuitId, null); + public static void DisposeFailed(ILogger logger, CircuitId circuitId, Exception exception) => _disposeFailed(logger, circuitId, exception); + public static void CircuitOpened(ILogger logger, CircuitId circuitId) => _onCircuitOpened(logger, circuitId, null); + public static void ConnectionUp(ILogger logger, CircuitId circuitId, string connectionId) => _onConnectionUp(logger, circuitId, connectionId, null); + public static void ConnectionDown(ILogger logger, CircuitId circuitId, string connectionId) => _onConnectionDown(logger, circuitId, connectionId, null); + public static void CircuitClosed(ILogger logger, CircuitId circuitId) => _onCircuitClosed(logger, circuitId, null); public static void CircuitHandlerFailed(ILogger logger, CircuitHandler handler, string handlerMethod, Exception exception) { @@ -824,8 +829,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits exception); } - public static void CircuitUnhandledException(ILogger logger, string circuitId, Exception exception) => _circuitUnhandledException(logger, circuitId, exception); - public static void CircuitTransmitErrorFailed(ILogger logger, string circuitId, Exception exception) => _circuitTransmitErrorFailed(logger, circuitId, exception); + public static void CircuitUnhandledException(ILogger logger, CircuitId circuitId, Exception exception) => _circuitUnhandledException(logger, circuitId, exception); + public static void CircuitTransmitErrorFailed(ILogger logger, CircuitId circuitId, Exception exception) => _circuitTransmitErrorFailed(logger, circuitId, exception); public static void EndInvokeDispatchException(ILogger logger, Exception ex) => _endInvokeDispatchException(logger, ex); public static void EndInvokeJSFailed(ILogger logger, long asyncHandle, string arguments) => _endInvokeJSFailed(logger, asyncHandle, arguments, null); public static void EndInvokeJSSucceeded(ILogger logger, long asyncCall) => _endInvokeJSSucceeded(logger, asyncCall, null); @@ -856,14 +861,14 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits } } - public static void LocationChange(ILogger logger, string uri, string circuitId) => _locationChange(logger, uri, circuitId, null); - public static void LocationChangeSucceeded(ILogger logger, string uri, string circuitId) => _locationChangeSucceeded(logger, uri, circuitId, null); - public static void LocationChangeFailed(ILogger logger, string uri, string circuitId, Exception exception) => _locationChangeFailed(logger, uri, circuitId, exception); - public static void LocationChangeFailedInCircuit(ILogger logger, string uri, string circuitId, Exception exception) => _locationChangeFailedInCircuit(logger, uri, circuitId, exception); - public static void UnhandledExceptionClientDisconnected(ILogger logger, string circuitId, Exception exception) => _unhandledExceptionClientDisconnected(logger, circuitId, exception); - public static void CircuitTransmittingClientError(ILogger logger, string circuitId) => _circuitTransmittingClientError(logger, circuitId, null); - public static void CircuitTransmittedClientErrorSuccess(ILogger logger, string circuitId) => _circuitTransmittedClientErrorSuccess(logger, circuitId, null); - public static void OnRenderCompletedFailed(ILogger logger, long renderId, string circuitId, Exception e) => _onRenderCompletedFailed(logger, renderId, circuitId, e); + public static void LocationChange(ILogger logger, string uri, CircuitId circuitId) => _locationChange(logger, uri, circuitId, null); + public static void LocationChangeSucceeded(ILogger logger, string uri, CircuitId circuitId) => _locationChangeSucceeded(logger, uri, circuitId, null); + public static void LocationChangeFailed(ILogger logger, string uri, CircuitId circuitId, Exception exception) => _locationChangeFailed(logger, uri, circuitId, exception); + public static void LocationChangeFailedInCircuit(ILogger logger, string uri, CircuitId circuitId, Exception exception) => _locationChangeFailedInCircuit(logger, uri, circuitId, exception); + public static void UnhandledExceptionClientDisconnected(ILogger logger, CircuitId circuitId, Exception exception) => _unhandledExceptionClientDisconnected(logger, circuitId, exception); + public static void CircuitTransmittingClientError(ILogger logger, CircuitId circuitId) => _circuitTransmittingClientError(logger, circuitId, null); + public static void CircuitTransmittedClientErrorSuccess(ILogger logger, CircuitId circuitId) => _circuitTransmittedClientErrorSuccess(logger, circuitId, null); + public static void OnRenderCompletedFailed(ILogger logger, long renderId, CircuitId circuitId, Exception e) => _onRenderCompletedFailed(logger, renderId, circuitId, e); } } } diff --git a/src/Components/Server/src/Circuits/CircuitId.cs b/src/Components/Server/src/Circuits/CircuitId.cs new file mode 100644 index 0000000000..3d63d2fdb4 --- /dev/null +++ b/src/Components/Server/src/Circuits/CircuitId.cs @@ -0,0 +1,63 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Diagnostics.CodeAnalysis; +using System.Runtime.InteropServices; +using System.Security.Cryptography; + +namespace Microsoft.AspNetCore.Components.Server.Circuits +{ + // Consists of a secret (data protected payload) and a non-secret identifier + // for use in logs and user code. + // + // The contract of this is that the id is derived from the Secret. We use the secret + // for comparisons, but we use the id for display (and exposing to user code). As a result, + // we don't include the id in any comparisons done by this class. + // + // Intentionally not overriding ToString here so that this won't accidentally + // get logged. It's ok to log the secret at TRACE. + internal readonly struct CircuitId : IEquatable + { + public CircuitId(string secret, string id) + { + Secret = secret ?? throw new ArgumentNullException(nameof(secret)); + Id = id ?? throw new ArgumentNullException(nameof(id)); + } + + public string Id { get; } + + public string Secret { get; } + + public bool Equals([AllowNull] CircuitId other) + { + // We want to use a fixed time equality comparison for a *real* comparisons. + // The only use case for Secret being null is with a default struct value, + // which wouldn't be the result of untrusted input. + if (other.Secret == null) + { + return Secret == null; + } + + return + CryptographicOperations.FixedTimeEquals( + MemoryMarshal.AsBytes(Secret.AsSpan()), + MemoryMarshal.AsBytes(other.Secret.AsSpan())); + } + + public override bool Equals(object obj) + { + return obj is CircuitId other ? Equals(other) : false; + } + + public override int GetHashCode() + { + return HashCode.Combine(Secret); + } + + public override string ToString() + { + return Id; + } + } +} diff --git a/src/Components/Server/src/Circuits/CircuitIdFactory.cs b/src/Components/Server/src/Circuits/CircuitIdFactory.cs index 830ac8d51a..3d47161fea 100644 --- a/src/Components/Server/src/Circuits/CircuitIdFactory.cs +++ b/src/Components/Server/src/Circuits/CircuitIdFactory.cs @@ -2,10 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Data; using System.Security.Cryptography; using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.WebUtilities; -using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.Components.Server.Circuits { @@ -13,7 +13,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits // Generates strong cryptographic ids for circuits that are protected with authenticated encryption. internal class CircuitIdFactory { - private const string CircuitIdProtectorPurpose = "Microsoft.AspNetCore.Components.Server"; + private const string CircuitIdProtectorPurpose = "Microsoft.AspNetCore.Components.Server.CircuitIdFactory"; + + // We use 64 bytes, where the last 32 are the public version of the id. + // This way we can always recover the public id from the secret form. + private const int SecretLength = 64; + private const int IdLength = 32; private readonly RandomNumberGenerator _generator = RandomNumberGenerator.Create(); private readonly IDataProtector _protector; @@ -27,29 +32,58 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits // we don't care about the underlying payload, other than its uniqueness and the fact that we // authenticate encrypt it using data protection. // For validation, the fact that we can unprotect the payload is guarantee enough. - public string CreateCircuitId() + public CircuitId CreateCircuitId() { - var buffer = new byte[32]; + var buffer = new byte[SecretLength]; _generator.GetBytes(buffer); - var payload = _protector.Protect(buffer); - return Base64UrlTextEncoder.Encode(payload); + var id = new byte[IdLength]; + Array.Copy( + sourceArray: buffer, + sourceIndex: SecretLength - IdLength, + destinationArray: id, + destinationIndex: 0, + length: IdLength); + + var secret = _protector.Protect(buffer); + return new CircuitId(Base64UrlTextEncoder.Encode(secret), Base64UrlTextEncoder.Encode(id)); } - public bool ValidateCircuitId(string circuitId) + public bool TryParseCircuitId(string text, out CircuitId circuitId) { + if (text is null) + { + circuitId = default; + return false; + } + try { - var protectedBytes = Base64UrlTextEncoder.Decode(circuitId); - _protector.Unprotect(protectedBytes); + var protectedBytes = Base64UrlTextEncoder.Decode(text); + var unprotectedBytes = _protector.Unprotect(protectedBytes); - // Its enough that we prove that we can unprotect the payload to validate the circuit id, - // as this demonstrates that it the id wasn't tampered with. + if (unprotectedBytes.Length != SecretLength) + { + // Wrong length + circuitId = default; + return false; + } + + var id = new byte[IdLength]; + Array.Copy( + sourceArray: unprotectedBytes, + sourceIndex: SecretLength - IdLength, + destinationArray: id, + destinationIndex: 0, + length: IdLength); + + circuitId = new CircuitId(text, Base64UrlTextEncoder.Encode(id)); return true; } catch (Exception) { // The payload format is not correct (either not base64urlencoded or not data protected) + circuitId = default; return false; } } diff --git a/src/Components/Server/src/Circuits/CircuitRegistry.cs b/src/Components/Server/src/Circuits/CircuitRegistry.cs index 3e3461a7dc..ab261a108b 100644 --- a/src/Components/Server/src/Circuits/CircuitRegistry.cs +++ b/src/Components/Server/src/Circuits/CircuitRegistry.cs @@ -47,12 +47,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public CircuitRegistry( IOptions options, ILogger logger, - CircuitIdFactory circuitIdFactory) + CircuitIdFactory CircuitHostFactory) { _options = options.Value; _logger = logger; - _circuitIdFactory = circuitIdFactory; - ConnectedCircuits = new ConcurrentDictionary(StringComparer.Ordinal); + _circuitIdFactory = CircuitHostFactory; + ConnectedCircuits = new ConcurrentDictionary(); DisconnectedCircuits = new MemoryCache(new MemoryCacheOptions { @@ -65,7 +65,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits }; } - internal ConcurrentDictionary ConnectedCircuits { get; } + internal ConcurrentDictionary ConnectedCircuits { get; } internal MemoryCache DisconnectedCircuits { get; } @@ -155,7 +155,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits }; var entry = new DisconnectedCircuitEntry(circuitHost, cancellationTokenSource); - DisconnectedCircuits.Set(circuitHost.CircuitId, entry, entryOptions); + DisconnectedCircuits.Set(circuitHost.CircuitId.Secret, entry, entryOptions); } // ConnectAsync is called from the CircuitHub - but the error handling story is a little bit complicated. @@ -164,20 +164,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits // // The solution is to handle exceptions here, and then return null to represent failure. // - // 1. If the circuit id is invalue return null - // 2. If the circuit is not found return null - // 3. If the circuit is found, but fails to connect, we need to dispose it here and return null - // 4. If everything goes well, return the circuit. - public virtual async Task ConnectAsync(string circuitId, IClientProxy clientProxy, string connectionId, CancellationToken cancellationToken) + // 1. If the circuit is not found return null + // 2. If the circuit is found, but fails to connect, we need to dispose it here and return null + // 3. If everything goes well, return the circuit. + public virtual async Task ConnectAsync(CircuitId circuitId, IClientProxy clientProxy, string connectionId, CancellationToken cancellationToken) { Log.CircuitConnectStarted(_logger, circuitId); - if (!_circuitIdFactory.ValidateCircuitId(circuitId)) - { - Log.InvalidCircuitId(_logger, circuitId); - return null; - } - CircuitHost circuitHost; bool previouslyConnected; @@ -193,8 +186,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits if (circuitHost == null) { - Log.FailedToReconnectToCircuit(_logger, circuitId); - // Failed to connect. Nothing to do here. + Log.FailedToFindCircuit(_logger, circuitId); + // Failed to find a matching circuit. Nothing to do here. return null; } @@ -220,12 +213,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits try { await circuitHandlerTask; - Log.ReconnectionSucceeded(_logger, circuitId); + Log.ReconnectionSucceeded(_logger, circuitHost.CircuitId); return circuitHost; } catch (Exception ex) { - Log.FailedToReconnectToCircuit(_logger, circuitId, ex); + Log.FailedToReconnectToCircuit(_logger, circuitHost.CircuitId, ex); await TerminateAsync(circuitId); // Return null on failure, because we need to clean up the circuit. @@ -233,11 +226,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits } } - protected virtual (CircuitHost circuitHost, bool previouslyConnected) ConnectCore(string circuitId, IClientProxy clientProxy, string connectionId) + protected virtual (CircuitHost circuitHost, bool previouslyConnected) ConnectCore(CircuitId circuitId, IClientProxy clientProxy, string connectionId) { if (ConnectedCircuits.TryGetValue(circuitId, out var connectedCircuitHost)) { - Log.ConnectingToActiveCircuit(_logger, circuitId, connectionId); + Log.ConnectingToActiveCircuit(_logger, connectedCircuitHost.CircuitId, connectionId); // The host is still active i.e. the server hasn't detected the client disconnect. // However the client reconnected establishing a new connection. @@ -245,15 +238,15 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits return (connectedCircuitHost, true); } - if (DisconnectedCircuits.TryGetValue(circuitId, out DisconnectedCircuitEntry disconnectedEntry)) + if (DisconnectedCircuits.TryGetValue(circuitId.Secret, out DisconnectedCircuitEntry disconnectedEntry)) { - Log.ConnectingToDisconnectedCircuit(_logger, circuitId, connectionId); + Log.ConnectingToDisconnectedCircuit(_logger, disconnectedEntry.CircuitHost.CircuitId, connectionId); // The host was in disconnected state. Transfer it to ConnectedCircuits so that it's no longer considered disconnected. // First discard the CancellationTokenSource so that the cache entry does not expire. DisposeTokenSource(disconnectedEntry); - DisconnectedCircuits.Remove(circuitId); + DisconnectedCircuits.Remove(circuitId.Secret); ConnectedCircuits.TryAdd(circuitId, disconnectedEntry.CircuitHost); disconnectedEntry.CircuitHost.Client.Transfer(clientProxy, connectionId); @@ -313,17 +306,18 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits } } - public ValueTask TerminateAsync(string circuitId) + // We don't expect this to throw. User code only runs inside DisposeAsync and that does its own error handling. + public ValueTask TerminateAsync(CircuitId circuitId) { CircuitHost circuitHost; DisconnectedCircuitEntry entry = default; lock (CircuitRegistryLock) { - if (ConnectedCircuits.TryGetValue(circuitId, out circuitHost) || DisconnectedCircuits.TryGetValue(circuitId, out entry)) + if (ConnectedCircuits.TryGetValue(circuitId, out circuitHost) || DisconnectedCircuits.TryGetValue(circuitId.Secret, out entry)) { circuitHost ??= entry.CircuitHost; - DisconnectedCircuits.Remove(circuitHost.CircuitId); - ConnectedCircuits.TryRemove(circuitHost.CircuitId, out _); + DisconnectedCircuits.Remove(circuitId.Secret); + ConnectedCircuits.TryRemove(circuitId, out _); Log.CircuitDisconnectedPermanently(_logger, circuitHost.CircuitId); circuitHost.Client.SetDisconnected(); } @@ -372,36 +366,36 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits { private static readonly Action _exceptionDisposingCircuitHost; private static readonly Action _unhandledExceptionDisposingTokenSource; - private static readonly Action _circuitReconnectStarted; - private static readonly Action _invalidCircuitId; - private static readonly Action _connectingToActiveCircuit; - private static readonly Action _connectingToDisconnectedCircuit; - private static readonly Action _failedToReconnectToCircuit; - private static readonly Action _reconnectionSucceeded; - private static readonly Action _circuitDisconnectStarted; - private static readonly Action _circuitNotActive; - private static readonly Action _circuitConnectedToDifferentConnection; - private static readonly Action _circuitMarkedDisconnected; - private static readonly Action _circuitDisconnectedPermanently; - private static readonly Action _circuitEvicted; - private static readonly Action _circuitExceptionHandlerFailed; + private static readonly Action _circuitReconnectStarted; + private static readonly Action _failedToFindCircuit; + private static readonly Action _connectingToActiveCircuit; + private static readonly Action _connectingToDisconnectedCircuit; + private static readonly Action _failedToReconnectToCircuit; + private static readonly Action _reconnectionSucceeded; + private static readonly Action _circuitDisconnectStarted; + private static readonly Action _circuitNotActive; + private static readonly Action _circuitConnectedToDifferentConnection; + private static readonly Action _circuitMarkedDisconnected; + private static readonly Action _circuitDisconnectedPermanently; + private static readonly Action _circuitEvicted; + private static readonly Action _circuitExceptionHandlerFailed; private static class EventIds { public static readonly EventId ExceptionDisposingCircuit = new EventId(100, "ExceptionDisposingCircuit"); public static readonly EventId ExceptionDisposingTokenSource = new EventId(101, "ExceptionDisposingTokenSource"); public static readonly EventId AttemptingToReconnect = new EventId(102, "AttemptingToReconnect"); - public static readonly EventId InvalidCircuitId = new EventId(103, "InvalidCircuitId"); - public static readonly EventId ConnectingToActiveCircuit = new EventId(104, "ConnectingToActiveCircuit"); - public static readonly EventId ConnectingToDisconnectedCircuit = new EventId(105, "ConnectingToDisconnectedCircuit"); - public static readonly EventId FailedToReconnectToCircuit = new EventId(106, "FailedToReconnectToCircuit"); - public static readonly EventId CircuitDisconnectStarted = new EventId(107, "CircuitDisconnectStarted"); - public static readonly EventId CircuitNotActive = new EventId(108, "CircuitNotActive"); - public static readonly EventId CircuitConnectedToDifferentConnection = new EventId(109, "CircuitConnectedToDifferentConnection"); - public static readonly EventId CircuitMarkedDisconnected = new EventId(110, "CircuitMarkedDisconnected"); - public static readonly EventId CircuitEvicted = new EventId(111, "CircuitEvicted"); - public static readonly EventId CircuitDisconnectedPermanently = new EventId(112, "CircuitDisconnectedPermanently"); - public static readonly EventId CircuitExceptionHandlerFailed = new EventId(113, "CircuitExceptionHandlerFailed"); + public static readonly EventId FailedToFindCircuit = new EventId(104, "FailedToFindCircuit"); + public static readonly EventId ConnectingToActiveCircuit = new EventId(105, "ConnectingToActiveCircuit"); + public static readonly EventId ConnectingToDisconnectedCircuit = new EventId(106, "ConnectingToDisconnectedCircuit"); + public static readonly EventId FailedToReconnectToCircuit = new EventId(107, "FailedToReconnectToCircuit"); + public static readonly EventId CircuitDisconnectStarted = new EventId(108, "CircuitDisconnectStarted"); + public static readonly EventId CircuitNotActive = new EventId(109, "CircuitNotActive"); + public static readonly EventId CircuitConnectedToDifferentConnection = new EventId(110, "CircuitConnectedToDifferentConnection"); + public static readonly EventId CircuitMarkedDisconnected = new EventId(111, "CircuitMarkedDisconnected"); + public static readonly EventId CircuitEvicted = new EventId(112, "CircuitEvicted"); + public static readonly EventId CircuitDisconnectedPermanently = new EventId(113, "CircuitDisconnectedPermanently"); + public static readonly EventId CircuitExceptionHandlerFailed = new EventId(114, "CircuitExceptionHandlerFailed"); } static Log() @@ -416,67 +410,67 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits EventIds.ExceptionDisposingTokenSource, "Exception thrown when disposing token source: {Message}"); - _circuitReconnectStarted = LoggerMessage.Define( + _circuitReconnectStarted = LoggerMessage.Define( LogLevel.Debug, EventIds.AttemptingToReconnect, - "Attempting to reconnect to Circuit with id {CircuitId}."); + "Attempting to reconnect to Circuit with secret {CircuitHost}."); - _invalidCircuitId = LoggerMessage.Define( + _failedToFindCircuit = LoggerMessage.Define( LogLevel.Debug, - EventIds.InvalidCircuitId, - "Failed to validate circuit id {CircuitId}."); + EventIds.FailedToFindCircuit, + "Failed to find a matching circuit for circuit secret {CircuitHost}."); - _connectingToActiveCircuit = LoggerMessage.Define( + _connectingToActiveCircuit = LoggerMessage.Define( LogLevel.Debug, EventIds.ConnectingToActiveCircuit, "Transferring active circuit {CircuitId} to connection {ConnectionId}."); - _connectingToDisconnectedCircuit = LoggerMessage.Define( + _connectingToDisconnectedCircuit = LoggerMessage.Define( LogLevel.Debug, EventIds.ConnectingToDisconnectedCircuit, "Transfering disconnected circuit {CircuitId} to connection {ConnectionId}."); - _failedToReconnectToCircuit = LoggerMessage.Define( + _failedToReconnectToCircuit = LoggerMessage.Define( LogLevel.Debug, EventIds.FailedToReconnectToCircuit, "Failed to reconnect to a circuit with id {CircuitId}."); - _reconnectionSucceeded = LoggerMessage.Define( + _reconnectionSucceeded = LoggerMessage.Define( LogLevel.Debug, EventIds.FailedToReconnectToCircuit, "Reconnect to circuit with id {CircuitId} succeeded."); - _circuitDisconnectStarted = LoggerMessage.Define( + _circuitDisconnectStarted = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitDisconnectStarted, "Attempting to disconnect circuit with id {CircuitId} from connection {ConnectionId}."); - _circuitNotActive = LoggerMessage.Define( + _circuitNotActive = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitNotActive, "Failed to disconnect circuit with id {CircuitId}. The circuit is not active."); - _circuitConnectedToDifferentConnection = LoggerMessage.Define( + _circuitConnectedToDifferentConnection = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitConnectedToDifferentConnection, "Failed to disconnect circuit with id {CircuitId}. The circuit is connected to {ConnectionId}."); - _circuitMarkedDisconnected = LoggerMessage.Define( + _circuitMarkedDisconnected = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitMarkedDisconnected, "Circuit with id {CircuitId} is disconnected."); - _circuitDisconnectedPermanently = LoggerMessage.Define( + _circuitDisconnectedPermanently = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitDisconnectedPermanently, "Circuit with id {CircuitId} has been removed from the registry for permanent disconnection."); - _circuitEvicted = LoggerMessage.Define( + _circuitEvicted = LoggerMessage.Define( LogLevel.Debug, EventIds.CircuitEvicted, "Circuit with id {CircuitId} evicted due to {EvictionReason}."); - _circuitExceptionHandlerFailed = LoggerMessage.Define( + _circuitExceptionHandlerFailed = LoggerMessage.Define( LogLevel.Error, EventIds.CircuitExceptionHandlerFailed, "Exception handler for {CircuitId} failed."); @@ -488,43 +482,43 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public static void ExceptionDisposingTokenSource(ILogger logger, Exception exception) => _unhandledExceptionDisposingTokenSource(logger, exception.Message, exception); - public static void CircuitConnectStarted(ILogger logger, string circuitId) => + public static void CircuitConnectStarted(ILogger logger, CircuitId circuitId) => _circuitReconnectStarted(logger, circuitId, null); - public static void InvalidCircuitId(ILogger logger, string circuitId) => - _invalidCircuitId(logger, circuitId, null); + public static void FailedToFindCircuit(ILogger logger, CircuitId circuitId) => + _failedToFindCircuit(logger, circuitId, null); - public static void ConnectingToActiveCircuit(ILogger logger, string circuitId, string connectionId) => + public static void ConnectingToActiveCircuit(ILogger logger, CircuitId circuitId, string connectionId) => _connectingToActiveCircuit(logger, circuitId, connectionId, null); - public static void ConnectingToDisconnectedCircuit(ILogger logger, string circuitId, string connectionId) => + public static void ConnectingToDisconnectedCircuit(ILogger logger, CircuitId circuitId, string connectionId) => _connectingToDisconnectedCircuit(logger, circuitId, connectionId, null); - public static void FailedToReconnectToCircuit(ILogger logger, string circuitId, Exception exception = null) => + public static void FailedToReconnectToCircuit(ILogger logger, CircuitId circuitId, Exception exception = null) => _failedToReconnectToCircuit(logger, circuitId, exception); - public static void ReconnectionSucceeded(ILogger logger, string circuitId) => + public static void ReconnectionSucceeded(ILogger logger, CircuitId circuitId) => _reconnectionSucceeded(logger, circuitId, null); - public static void CircuitDisconnectStarted(ILogger logger, string circuitId, string connectionId) => + public static void CircuitDisconnectStarted(ILogger logger, CircuitId circuitId, string connectionId) => _circuitDisconnectStarted(logger, circuitId, connectionId, null); - public static void CircuitNotActive(ILogger logger, string circuitId) => + public static void CircuitNotActive(ILogger logger, CircuitId circuitId) => _circuitNotActive(logger, circuitId, null); - public static void CircuitConnectedToDifferentConnection(ILogger logger, string circuitId, string connectionId) => + public static void CircuitConnectedToDifferentConnection(ILogger logger, CircuitId circuitId, string connectionId) => _circuitConnectedToDifferentConnection(logger, circuitId, connectionId, null); - public static void CircuitMarkedDisconnected(ILogger logger, string circuitId) => + public static void CircuitMarkedDisconnected(ILogger logger, CircuitId circuitId) => _circuitMarkedDisconnected(logger, circuitId, null); - public static void CircuitDisconnectedPermanently(ILogger logger, string circuitId) => + public static void CircuitDisconnectedPermanently(ILogger logger, CircuitId circuitId) => _circuitDisconnectedPermanently(logger, circuitId, null); - public static void CircuitEvicted(ILogger logger, string circuitId, EvictionReason evictionReason) => + public static void CircuitEvicted(ILogger logger, CircuitId circuitId, EvictionReason evictionReason) => _circuitEvicted(logger, circuitId, evictionReason, null); - public static void CircuitExceptionHandlerFailed(ILogger logger, string circuitId, Exception exception) => + public static void CircuitExceptionHandlerFailed(ILogger logger, CircuitId circuitId, Exception exception) => _circuitExceptionHandlerFailed(logger, circuitId, exception); } } diff --git a/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs b/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs index 59ac520ea1..83886bb3b8 100644 --- a/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs +++ b/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs @@ -111,11 +111,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits private static class Log { - private static readonly Action _createdConnectedCircuit = - LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CreatedConnectedCircuit"), "Created circuit {CircuitId} for connection {ConnectionId}"); + private static readonly Action _createdConnectedCircuit = + LoggerMessage.Define(LogLevel.Debug, new EventId(1, "CreatedConnectedCircuit"), "Created circuit {CircuitId} for connection {ConnectionId}"); - private static readonly Action _createdDisconnectedCircuit = - LoggerMessage.Define(LogLevel.Debug, new EventId(2, "CreatedDisconnectedCircuit"), "Created circuit {CircuitId} for disconnected client"); + private static readonly Action _createdDisconnectedCircuit = + LoggerMessage.Define(LogLevel.Debug, new EventId(2, "CreatedDisconnectedCircuit"), "Created circuit {CircuitId} for disconnected client"); internal static void CreatedCircuit(ILogger logger, CircuitHost circuitHost) { diff --git a/src/Components/Server/src/ComponentHub.cs b/src/Components/Server/src/ComponentHub.cs index 7ba95300de..a31982ea8f 100644 --- a/src/Components/Server/src/ComponentHub.cs +++ b/src/Components/Server/src/ComponentHub.cs @@ -38,17 +38,20 @@ namespace Microsoft.AspNetCore.Components.Server { private static readonly object CircuitKey = new object(); private readonly CircuitFactory _circuitFactory; + private readonly CircuitIdFactory _circuitIdFactory; private readonly CircuitRegistry _circuitRegistry; private readonly CircuitOptions _options; private readonly ILogger _logger; public ComponentHub( CircuitFactory circuitFactory, + CircuitIdFactory circuitIdFactory, CircuitRegistry circuitRegistry, ILogger logger, IOptions options) { _circuitFactory = circuitFactory; + _circuitIdFactory = circuitIdFactory; _circuitRegistry = circuitRegistry; _options = options.Value; _logger = logger; @@ -129,7 +132,12 @@ namespace Microsoft.AspNetCore.Components.Server // to run inside it until after InitializeAsync completes. _circuitRegistry.Register(circuitHost); SetCircuit(circuitHost); - return circuitHost.CircuitId; + + // Returning the secret here so the client can reconnect. + // + // Logging the secret and circuit ID here so we can associate them with just logs (if TRACE level is on). + Log.CreatedCircuit(_logger, circuitHost.CircuitId, circuitHost.CircuitId.Secret, Context.ConnectionId); + return circuitHost.CircuitId.Secret; } catch (Exception ex) { @@ -142,10 +150,22 @@ namespace Microsoft.AspNetCore.Components.Server } } - public async ValueTask ConnectCircuit(string circuitId) + public async ValueTask ConnectCircuit(string circuitIdSecret) { - // ConnectionAsync will not throw. - var circuitHost = await _circuitRegistry.ConnectAsync(circuitId, Clients.Caller, Context.ConnectionId, Context.ConnectionAborted); + // TryParseCircuitId will not throw. + if (!_circuitIdFactory.TryParseCircuitId(circuitIdSecret, out var circuitId)) + { + // Invalid id. + Log.InvalidCircuitId(_logger, circuitIdSecret); + return false; + } + + // ConnectAsync will not throw. + var circuitHost = await _circuitRegistry.ConnectAsync( + circuitId, + Clients.Caller, + Context.ConnectionId, + Context.ConnectionAborted); if (circuitHost != null) { SetCircuit(circuitHost); @@ -270,8 +290,8 @@ namespace Microsoft.AspNetCore.Components.Server private static readonly Action _unhandledExceptionInCircuit = LoggerMessage.Define(LogLevel.Warning, new EventId(3, "UnhandledExceptionInCircuit"), "Unhandled exception in circuit {CircuitId}"); - private static readonly Action _circuitAlreadyInitialized = - LoggerMessage.Define(LogLevel.Debug, new EventId(4, "CircuitAlreadyInitialized"), "The circuit host '{CircuitId}' has already been initialized"); + private static readonly Action _circuitAlreadyInitialized = + LoggerMessage.Define(LogLevel.Debug, new EventId(4, "CircuitAlreadyInitialized"), "The circuit host '{CircuitId}' has already been initialized"); private static readonly Action _circuitHostNotInitialized = LoggerMessage.Define(LogLevel.Debug, new EventId(5, "CircuitHostNotInitialized"), "Call to '{CallSite}' received before the circuit host initialization"); @@ -285,11 +305,17 @@ namespace Microsoft.AspNetCore.Components.Server private static readonly Action _circuitInitializationFailed = LoggerMessage.Define(LogLevel.Debug, new EventId(8, "CircuitInitializationFailed"), "Circuit initialization failed"); + private static readonly Action _createdCircuit = + LoggerMessage.Define(LogLevel.Debug, new EventId(8, "CreatedCircuit"), "Created circuit '{CircuitId}' with secret '{CircuitIdSecret}' for '{ConnectionId}'"); + + private static readonly Action _invalidCircuitId = + LoggerMessage.Define(LogLevel.Debug, new EventId(9, "InvalidCircuitId"), "ConnectAsync recieved an invalid circuit id '{CircuitIdSecret}'"); + public static void NoComponentsRegisteredInEndpoint(ILogger logger, string endpointDisplayName) => _noComponentsRegisteredInEndpoint(logger, endpointDisplayName, null); public static void ReceivedConfirmationForBatch(ILogger logger, long batchId) => _receivedConfirmationForBatch(logger, batchId, null); - public static void CircuitAlreadyInitialized(ILogger logger, string circuitId) => _circuitAlreadyInitialized(logger, circuitId, null); + public static void CircuitAlreadyInitialized(ILogger logger, CircuitId circuitId) => _circuitAlreadyInitialized(logger, circuitId, null); public static void CircuitHostNotInitialized(ILogger logger, [CallerMemberName] string callSite = "") => _circuitHostNotInitialized(logger, callSite, null); @@ -298,6 +324,28 @@ namespace Microsoft.AspNetCore.Components.Server public static void InvalidInputData(ILogger logger, [CallerMemberName] string callSite = "") => _invalidInputData(logger, callSite, null); public static void CircuitInitializationFailed(ILogger logger, Exception exception) => _circuitInitializationFailed(logger, exception); + + public static void CreatedCircuit(ILogger logger, CircuitId circuitId, string circuitSecret, string connectionId) + { + // Redact the secret unless tracing is on. + if (!logger.IsEnabled(LogLevel.Trace)) + { + circuitSecret = "(redacted)"; + } + + _createdCircuit(logger, circuitId, circuitSecret, connectionId, null); + } + + public static void InvalidCircuitId(ILogger logger, string circuitSecret) + { + // Redact the secret unless tracing is on. + if (!logger.IsEnabled(LogLevel.Trace)) + { + circuitSecret = "(redacted)"; + } + + _invalidCircuitId(logger, circuitSecret, null); + } } } } diff --git a/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs b/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs index 627c51b576..a63ece5b88 100644 --- a/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs +++ b/src/Components/Server/test/CircuitDisconnectMiddlewareTest.cs @@ -138,7 +138,7 @@ namespace Microsoft.AspNetCore.Components.Server { // Arrange var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory(); - var id = circuitIdFactory.CreateCircuitId(); + var circuitId = circuitIdFactory.CreateCircuitId(); var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), NullLogger.Instance, @@ -151,7 +151,7 @@ namespace Microsoft.AspNetCore.Components.Server (ctx) => Task.CompletedTask); using var memory = new MemoryStream(); - await new FormUrlEncodedContent(new Dictionary { ["circuitId"] = id }).CopyToAsync(memory); + await new FormUrlEncodedContent(new Dictionary { ["circuitId"] = circuitId.Secret, }).CopyToAsync(memory); memory.Seek(0, SeekOrigin.Begin); var context = new DefaultHttpContext(); @@ -171,8 +171,8 @@ namespace Microsoft.AspNetCore.Components.Server { // Arrange var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory(); - var id = circuitIdFactory.CreateCircuitId(); - var testCircuitHost = TestCircuitHost.Create(id); + var circuitId = circuitIdFactory.CreateCircuitId(); + var testCircuitHost = TestCircuitHost.Create(circuitId); var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), @@ -188,7 +188,7 @@ namespace Microsoft.AspNetCore.Components.Server (ctx) => Task.CompletedTask); using var memory = new MemoryStream(); - await new FormUrlEncodedContent(new Dictionary { ["circuitId"] = id }).CopyToAsync(memory); + await new FormUrlEncodedContent(new Dictionary { ["circuitId"] = circuitId.Secret, }).CopyToAsync(memory); memory.Seek(0, SeekOrigin.Begin); var context = new DefaultHttpContext(); @@ -208,8 +208,8 @@ namespace Microsoft.AspNetCore.Components.Server { // Arrange var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory(); - var id = circuitIdFactory.CreateCircuitId(); - var circuitHost = TestCircuitHost.Create(id); + var circuitId = circuitIdFactory.CreateCircuitId(); + var circuitHost = TestCircuitHost.Create(circuitId); var registry = new CircuitRegistry( Options.Create(new CircuitOptions()), @@ -226,7 +226,7 @@ namespace Microsoft.AspNetCore.Components.Server (ctx) => Task.CompletedTask); using var memory = new MemoryStream(); - await new FormUrlEncodedContent(new Dictionary { ["circuitId"] = id }).CopyToAsync(memory); + await new FormUrlEncodedContent(new Dictionary { ["circuitId"] = circuitId.Secret }).CopyToAsync(memory); memory.Seek(0, SeekOrigin.Begin); var context = new DefaultHttpContext(); diff --git a/src/Components/Server/test/Circuits/CircuitHostTest.cs b/src/Components/Server/test/Circuits/CircuitHostTest.cs index dc4d5da01f..1d5ec1c772 100644 --- a/src/Components/Server/test/Circuits/CircuitHostTest.cs +++ b/src/Components/Server/test/Circuits/CircuitHostTest.cs @@ -26,9 +26,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var serviceScope = new Mock(); var remoteRenderer = GetRemoteRenderer(); var circuitHost = TestCircuitHost.Create( - Guid.NewGuid().ToString(), - serviceScope.Object, - remoteRenderer); + serviceScope: serviceScope.Object, + remoteRenderer: remoteRenderer); // Act await circuitHost.DisposeAsync(); @@ -52,9 +51,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var remoteRenderer = GetRemoteRenderer(); var circuitHost = TestCircuitHost.Create( - Guid.NewGuid().ToString(), - serviceScope.Object, - remoteRenderer); + serviceScope: serviceScope.Object, + remoteRenderer: remoteRenderer); // Act await circuitHost.DisposeAsync(); @@ -77,9 +75,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits .Throws(); var remoteRenderer = GetRemoteRenderer(); var circuitHost = TestCircuitHost.Create( - Guid.NewGuid().ToString(), - serviceScope.Object, - remoteRenderer, + serviceScope: serviceScope.Object, + remoteRenderer: remoteRenderer, handlers: new[] { handler.Object }); var throwOnDisposeComponent = new ThrowOnDisposeComponent(); @@ -101,9 +98,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var serviceScope = new Mock(); var remoteRenderer = GetRemoteRenderer(); var circuitHost = TestCircuitHost.Create( - Guid.NewGuid().ToString(), - serviceScope.Object, - remoteRenderer); + serviceScope: serviceScope.Object, + remoteRenderer: remoteRenderer); var component = new DispatcherComponent(circuitHost.Renderer.Dispatcher); circuitHost.Renderer.AssignRootComponentId(component); diff --git a/src/Components/Server/test/Circuits/CircuitIdFactoryTest.cs b/src/Components/Server/test/Circuits/CircuitIdFactoryTest.cs index 50b5f1701e..35bc454416 100644 --- a/src/Components/Server/test/Circuits/CircuitIdFactoryTest.cs +++ b/src/Components/Server/test/Circuits/CircuitIdFactoryTest.cs @@ -3,13 +3,12 @@ using System; using System.Linq; -using Microsoft.AspNetCore.DataProtection; using Microsoft.AspNetCore.WebUtilities; using Xunit; namespace Microsoft.AspNetCore.Components.Server.Circuits { - public class CircuitIdFactoryTest + public class circuitIdFactoryTest { [Fact] public void CreateCircuitId_Generates_NewRandomId() @@ -17,12 +16,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var factory = TestCircuitIdFactory.CreateTestFactory(); // Act - var id = factory.CreateCircuitId(); + var secret = factory.CreateCircuitId(); // Assert - Assert.NotNull(id); + Assert.NotNull(secret.Secret); // This is the magic data protection header that validates its protected - Assert.StartsWith("CfDJ", id); + Assert.StartsWith("CfDJ", secret.Secret); } [Fact] @@ -32,13 +31,14 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var factory = TestCircuitIdFactory.CreateTestFactory(); // Act - var ids = Enumerable.Range(0, 100).Select(i => factory.CreateCircuitId()).ToArray(); + var secrets = Enumerable.Range(0, 100).Select(i => factory.CreateCircuitId()).Select(s => s.Secret).ToArray(); // Assert - Assert.All(ids, id => Assert.NotNull(id)); - Assert.Equal(100, ids.Distinct(StringComparer.Ordinal).Count()); + Assert.All(secrets, secret => Assert.NotNull(secret)); + Assert.Equal(100, secrets.Distinct(StringComparer.Ordinal).Count()); } + // Note that this test also verifies that the ID can be reproduced from the secret. [Fact] public void CircuitIds_Roundtrip() { @@ -47,10 +47,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var id = factory.CreateCircuitId(); // Act - var isValid = factory.ValidateCircuitId(id); + var isValid = factory.TryParseCircuitId(id.Secret, out var parsed); // Assert Assert.True(isValid, "Failed to validate id"); + Assert.Equal(id, parsed); + Assert.Equal(id.Secret, parsed.Secret); + Assert.Equal(id.Id, parsed.Id); } [Fact] @@ -60,7 +63,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits var factory = TestCircuitIdFactory.CreateTestFactory(); // Act - var isValid = factory.ValidateCircuitId("$%@&=="); + var isValid = factory.TryParseCircuitId("$%@&==", out _); // Assert Assert.False(isValid, "Accepted an invalid payload"); @@ -71,16 +74,16 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits { // Arrange var factory = TestCircuitIdFactory.CreateTestFactory(); - var id = factory.CreateCircuitId(); - var protectedBytes = Base64UrlTextEncoder.Decode(id); + var secret = factory.CreateCircuitId(); + var protectedBytes = Base64UrlTextEncoder.Decode(secret.Secret); for (int i = protectedBytes.Length - 10; i < protectedBytes.Length; i++) { protectedBytes[i] = 0; } - var tamperedId = Base64UrlTextEncoder.Encode(protectedBytes); + var tampered = Base64UrlTextEncoder.Encode(protectedBytes); // Act - var isValid = factory.ValidateCircuitId(tamperedId); + var isValid = factory.TryParseCircuitId(tampered, out _); // Assert Assert.False(isValid, "Accepted a tampered payload"); diff --git a/src/Components/Server/test/Circuits/CircuitRegistryTest.cs b/src/Components/Server/test/Circuits/CircuitRegistryTest.cs index f57c27dbfa..245f29c114 100644 --- a/src/Components/Server/test/Circuits/CircuitRegistryTest.cs +++ b/src/Components/Server/test/Circuits/CircuitRegistryTest.cs @@ -163,14 +163,14 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits // Arrange var registry = CreateRegistry(); var circuitHost = TestCircuitHost.Create(); - registry.DisconnectedCircuits.Set(circuitHost.CircuitId, circuitHost, new MemoryCacheEntryOptions { Size = 1 }); + registry.DisconnectedCircuits.Set(circuitHost.CircuitId.Secret, circuitHost, new MemoryCacheEntryOptions { Size = 1 }); // Act await registry.DisconnectAsync(circuitHost, circuitHost.Client.ConnectionId); // Assert Assert.Empty(registry.ConnectedCircuits.Values); - Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out _)); + Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _)); } [Fact] @@ -267,7 +267,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits Assert.Same(client, circuitHost.Client.Client); Assert.Equal(newId, circuitHost.Client.ConnectionId); - Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out _)); + Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _)); } [Fact] @@ -297,7 +297,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits Assert.Same(client, circuitHost.Client.Client); Assert.Equal(newId, circuitHost.Client.ConnectionId); - Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out _)); + Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _)); } [Fact] @@ -322,9 +322,9 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits // Act // Verify it's present in the dictionary. - Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out var _)); + Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out var _)); await Task.Run(() => tcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10))); - Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out var _)); + Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out var _)); } [Fact] @@ -355,7 +355,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits Assert.True(registry.ConnectedCircuits.TryGetValue(circuitHost.CircuitId, out var cacheValue)); Assert.Same(circuitHost, cacheValue); // Nothing should be disconnected. - Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out var _)); + Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out var _)); } private class TestCircuitRegistry : CircuitRegistry @@ -370,7 +370,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits public Action OnAfterEntryEvicted { get; set; } - protected override (CircuitHost, bool) ConnectCore(string circuitId, IClientProxy clientProxy, string connectionId) + protected override (CircuitHost, bool) ConnectCore(CircuitId circuitId, IClientProxy clientProxy, string connectionId) { if (BeforeConnect != null) { diff --git a/src/Components/Server/test/Circuits/TestCircuitHost.cs b/src/Components/Server/test/Circuits/TestCircuitHost.cs index e0a2e0bd54..b6c52eaf9c 100644 --- a/src/Components/Server/test/Circuits/TestCircuitHost.cs +++ b/src/Components/Server/test/Circuits/TestCircuitHost.cs @@ -15,13 +15,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits { internal class TestCircuitHost : CircuitHost { - private TestCircuitHost(string circuitId, IServiceScope scope, CircuitOptions options, CircuitClientProxy client, RemoteRenderer renderer, IReadOnlyList descriptors, RemoteJSRuntime jsRuntime, CircuitHandler[] circuitHandlers, ILogger logger) + private TestCircuitHost(CircuitId circuitId, IServiceScope scope, CircuitOptions options, CircuitClientProxy client, RemoteRenderer renderer, IReadOnlyList descriptors, RemoteJSRuntime jsRuntime, CircuitHandler[] circuitHandlers, ILogger logger) : base(circuitId, scope, options, client, renderer, descriptors, jsRuntime, circuitHandlers, logger) { } public static CircuitHost Create( - string circuitId = null, + CircuitId? circuitId = null, IServiceScope serviceScope = null, RemoteRenderer remoteRenderer = null, CircuitHandler[] handlers = null, @@ -44,7 +44,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits handlers = handlers ?? Array.Empty(); return new TestCircuitHost( - circuitId ?? Guid.NewGuid().ToString(), + circuitId is null ? new CircuitId(Guid.NewGuid().ToString(), Guid.NewGuid().ToString()) : circuitId.Value, serviceScope, new CircuitOptions(), clientProxy, diff --git a/src/Components/test/E2ETest/ServerExecutionTests/CircuitGracefulTerminationTests.cs b/src/Components/test/E2ETest/ServerExecutionTests/CircuitGracefulTerminationTests.cs index 188fae57d0..d392289944 100644 --- a/src/Components/test/E2ETest/ServerExecutionTests/CircuitGracefulTerminationTests.cs +++ b/src/Components/test/E2ETest/ServerExecutionTests/CircuitGracefulTerminationTests.cs @@ -68,9 +68,6 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests { // Arrange & Act Browser.Close(); - // Set to null so that other tests in this class can create a new browser if necessary so - // that tests don't fail when running together. - await Task.WhenAny(Task.Delay(10000), GracefulDisconnectCompletionSource.Task); // Assert diff --git a/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs b/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs index 4721ca053f..536717ec41 100644 --- a/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs +++ b/src/Components/test/E2ETest/ServerExecutionTests/ComponentHubReliabilityTest.cs @@ -63,7 +63,7 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests private void LogMessages(WriteContext context) { - var log = new LogMessage(context.LogLevel, context.Message, context.Exception); + var log = new LogMessage(context.LogLevel, context.EventId, context.Message, context.Exception); Logs.Enqueue(log); Output.WriteLine(log.ToString()); } @@ -309,10 +309,11 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests var actualError = Assert.Single(Errors); Assert.Equal(expectedError, actualError); Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information); - Assert.Contains(Logs, l => (l.LogLevel, l.Message, l.Exception?.Message) == - (LogLevel.Debug, - $"Failed to complete render batch '1846' in circuit host '{Client.CircuitId}'.", - "Received an acknowledgement for batch with id '1846' when the last batch produced was '4'.")); + + var entry = Assert.Single(Logs, l => l.EventId.Name == "OnRenderCompletedFailed"); + Assert.Equal(LogLevel.Debug, entry.LogLevel); + Assert.Matches("Failed to complete render batch '1846' in circuit host '.*'\\.", entry.Message); + Assert.Equal("Received an acknowledgement for batch with id '1846' when the last batch produced was '4'.", entry.Exception.Message); } [Fact] @@ -384,10 +385,10 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests var actualError = Assert.Single(Errors); Assert.Equal(expectedError, actualError); Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information); - Assert.Contains(Logs, l => - { - return (l.LogLevel, l.Message) == (LogLevel.Debug, $"Location change to 'http://example.com' in circuit '{Client.CircuitId}' failed."); - }); + + var entry = Assert.Single(Logs, l => l.EventId.Name == "LocationChangeFailed"); + Assert.Equal(LogLevel.Debug, entry.LogLevel); + Assert.Matches("Location change to 'http://example.com' in circuit '.*' failed\\.", entry.Message); } [Fact] @@ -414,10 +415,10 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests // Assert var actualError = Assert.Single(Errors); Assert.Equal(expectedError, actualError); - Assert.Contains(Logs, l => - { - return (l.LogLevel, l.Message) == (LogLevel.Error, $"Location change to '{new Uri(_serverFixture.RootUri,"/test")}' in circuit '{Client.CircuitId}' failed."); - }); + + var entry = Assert.Single(Logs, l => l.EventId.Name == "LocationChangeFailed"); + Assert.Equal(LogLevel.Error, entry.LogLevel); + Assert.Matches($"Location change to '{new Uri(_serverFixture.RootUri, "/test")}' in circuit '.*' failed\\.", entry.Message); } [Theory] @@ -501,20 +502,22 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests [DebuggerDisplay("{LogLevel.ToString(),nq} - {Message ?? \"null\",nq} - {Exception?.Message,nq}")] private class LogMessage { - public LogMessage(LogLevel logLevel, string message, Exception exception) + public LogMessage(LogLevel logLevel, EventId eventId, string message, Exception exception) { LogLevel = logLevel; + EventId = eventId; Message = message; Exception = exception; } public LogLevel LogLevel { get; set; } + public EventId EventId { get; set; } public string Message { get; set; } public Exception Exception { get; set; } public override string ToString() { - return $"{LogLevel}: {Message}{(Exception != null ? Environment.NewLine : "")}{Exception}"; + return $"{LogLevel}: {EventId} {Message}{(Exception != null ? Environment.NewLine : "")}{Exception}"; } } diff --git a/src/Components/test/testassets/TestServer/appsettings.json b/src/Components/test/testassets/TestServer/appsettings.json index d49b185e5d..c21fbf38e5 100644 --- a/src/Components/test/testassets/TestServer/appsettings.json +++ b/src/Components/test/testassets/TestServer/appsettings.json @@ -2,7 +2,7 @@ "Logging": { "IncludeScopes": false, "LogLevel": { - "Microsoft.AspNetCore.Components": "Debug" + "Microsoft.AspNetCore.Components": "Trace" }, "Debug": { "LogLevel": {