Create CircuitSecret

Fixes: #13012

This change introduces a circuit id 'secret' as the concept that's used when
doing handshaking between the client and the server, and makes CircuitId
(visible to user-code) a separate concept (not a secret, can't be used
to open a connection).

The scope of this grew once I realized that we probably shouldn't be
logging Circuit Secret in so many places, we should be logging CircuitId
as the piece of data we use for correlation, and try to keep the secret
out of logs except where really necessary (and with trace level).

I ended up creating a new type to represent the combination of the
circuit id and secret and prevent
accidental misuse, and then chased down all of the build errors.

As an extra detail, the circuit id is part of the data-protected payload
that's used as the secret. This way we can always get the id back from
the secret without any external storage.
This commit is contained in:
Ryan Nowak 2019-08-14 18:32:55 -07:00
parent 7f054152db
commit 3b51b55176
16 changed files with 403 additions and 254 deletions

View File

@ -39,65 +39,71 @@ namespace Microsoft.AspNetCore.Components.Server
return;
}
var (hasCircuitId, circuitId) = await TryGetCircuitIdAsync(context);
if (!hasCircuitId)
var circuitId = await GetCircuitIdAsync(context);
if (circuitId is null)
{
context.Response.StatusCode = StatusCodes.Status400BadRequest;
return;
}
await TerminateCircuitGracefully(circuitId);
await TerminateCircuitGracefully(circuitId.Value);
context.Response.StatusCode = StatusCodes.Status200OK;
}
private async Task<(bool, string)> TryGetCircuitIdAsync(HttpContext context)
private async Task<CircuitId?> GetCircuitIdAsync(HttpContext context)
{
try
{
if (!context.Request.HasFormContentType)
{
return (false, null);
return default;
}
var form = await context.Request.ReadFormAsync();
if (!form.TryGetValue(CircuitIdKey, out var circuitId) || !CircuitIdFactory.ValidateCircuitId(circuitId))
if (!form.TryGetValue(CircuitIdKey, out var text))
{
return (false, null);
return default;
}
return (true, circuitId);
if (!CircuitIdFactory.TryParseCircuitId(text, out var circuitId))
{
Log.InvalidCircuitId(Logger, text);
return default;
}
return circuitId;
}
catch
{
return (false, null);
return default;
}
}
private async Task TerminateCircuitGracefully(string circuitId)
private async Task TerminateCircuitGracefully(CircuitId circuitId)
{
try
{
await Registry.TerminateAsync(circuitId);
Log.CircuitTerminatedGracefully(Logger, circuitId);
}
catch (Exception e)
{
Log.UnhandledExceptionInCircuit(Logger, circuitId, e);
}
// We don't expect TerminateAsync to throw.
Log.CircuitTerminatingGracefully(Logger, circuitId);
await Registry.TerminateAsync(circuitId);
Log.CircuitTerminatedGracefully(Logger, circuitId);
}
private class Log
{
private static readonly Action<ILogger, string, Exception> _circuitTerminatedGracefully =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(1, "CircuitTerminatedGracefully"), "Circuit '{CircuitId}' terminated gracefully");
private static readonly Action<ILogger, CircuitId, Exception> _circuitTerminatingGracefully =
LoggerMessage.Define<CircuitId>(LogLevel.Debug, new EventId(1, "CircuitTerminatingGracefully"), "Circuit with id '{CircuitId}' terminating gracefully.");
private static readonly Action<ILogger, string, Exception> _unhandledExceptionInCircuit =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(2, "UnhandledExceptionInCircuit"), "Unhandled exception in circuit {CircuitId} while terminating gracefully.");
private static readonly Action<ILogger, CircuitId, Exception> _circuitTerminatedGracefully =
LoggerMessage.Define<CircuitId>(LogLevel.Debug, new EventId(2, "CircuitTerminatedGracefully"), "Circuit with id '{CircuitId}' terminated gracefully.");
public static void CircuitTerminatedGracefully(ILogger logger, string circuitId) => _circuitTerminatedGracefully(logger, circuitId, null);
private static readonly Action<ILogger, string, Exception> _invalidCircuitId =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(3, "InvalidCircuitId"), "CircuitDisconnectMiddleware recieved an invalid circuit id '{CircuitIdSecret}'.");
public static void UnhandledExceptionInCircuit(ILogger logger, string circuitId, Exception exception) => _unhandledExceptionInCircuit(logger, circuitId, exception);
public static void CircuitTerminatingGracefully(ILogger logger, CircuitId circuitId) => _circuitTerminatingGracefully(logger, circuitId, null);
public static void CircuitTerminatedGracefully(ILogger logger, CircuitId circuitId) => _circuitTerminatedGracefully(logger, circuitId, null);
public static void InvalidCircuitId(ILogger logger, string circuitSecret) => _invalidCircuitId(logger, circuitSecret, null);
}
}
}

View File

@ -18,6 +18,6 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
/// <summary>
/// Gets the identifier for the <see cref="Circuit"/>.
/// </summary>
public string Id => _circuitHost.CircuitId;
public string Id => _circuitHost.CircuitId.Id;
}
}

View File

@ -38,7 +38,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public event UnhandledExceptionEventHandler UnhandledException;
public CircuitHost(
string circuitId,
CircuitId circuitId,
IServiceScope scope,
CircuitOptions options,
CircuitClientProxy client,
@ -48,7 +48,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
CircuitHandler[] circuitHandlers,
ILogger logger)
{
CircuitId = circuitId ?? throw new ArgumentNullException(nameof(circuitId));
CircuitId = circuitId;
if (CircuitId.Secret is null)
{
// Prevent the use of a 'default' secret.
throw new ArgumentException(nameof(circuitId));
}
_scope = scope ?? throw new ArgumentNullException(nameof(scope));
_options = options ?? throw new ArgumentNullException(nameof(options));
@ -70,7 +75,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public CircuitHandle Handle { get; }
public string CircuitId { get; }
public CircuitId CircuitId { get; }
public Circuit Circuit { get; }
@ -194,7 +199,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
// exceptions.
private async Task OnCircuitOpenedAsync(CancellationToken cancellationToken)
{
Log.CircuitOpened(_logger, Circuit.Id);
Log.CircuitOpened(_logger, CircuitId);
Renderer.Dispatcher.AssertAccess();
@ -223,7 +228,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task OnConnectionUpAsync(CancellationToken cancellationToken)
{
Log.ConnectionUp(_logger, Circuit.Id, Client.ConnectionId);
Log.ConnectionUp(_logger, CircuitId, Client.ConnectionId);
Renderer.Dispatcher.AssertAccess();
@ -252,7 +257,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public async Task OnConnectionDownAsync(CancellationToken cancellationToken)
{
Log.ConnectionDown(_logger, Circuit.Id, Client.ConnectionId);
Log.ConnectionDown(_logger, CircuitId, Client.ConnectionId);
Renderer.Dispatcher.AssertAccess();
@ -281,7 +286,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
private async Task OnCircuitDownAsync(CancellationToken cancellationToken)
{
Log.CircuitClosed(_logger, Circuit.Id);
Log.CircuitClosed(_logger, CircuitId);
List<Exception> exceptions = null;
@ -585,19 +590,19 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
private static readonly Action<ILogger, Exception> _intializationStarted;
private static readonly Action<ILogger, Exception> _intializationSucceded;
private static readonly Action<ILogger, Exception> _intializationFailed;
private static readonly Action<ILogger, string, Exception> _disposeStarted;
private static readonly Action<ILogger, string, Exception> _disposeSucceded;
private static readonly Action<ILogger, string, Exception> _disposeFailed;
private static readonly Action<ILogger, string, Exception> _onCircuitOpened;
private static readonly Action<ILogger, string, string, Exception> _onConnectionUp;
private static readonly Action<ILogger, string, string, Exception> _onConnectionDown;
private static readonly Action<ILogger, string, Exception> _onCircuitClosed;
private static readonly Action<ILogger, CircuitId, Exception> _disposeStarted;
private static readonly Action<ILogger, CircuitId, Exception> _disposeSucceded;
private static readonly Action<ILogger, CircuitId, Exception> _disposeFailed;
private static readonly Action<ILogger, CircuitId, Exception> _onCircuitOpened;
private static readonly Action<ILogger, CircuitId, string, Exception> _onConnectionUp;
private static readonly Action<ILogger, CircuitId, string, Exception> _onConnectionDown;
private static readonly Action<ILogger, CircuitId, Exception> _onCircuitClosed;
private static readonly Action<ILogger, Type, string, string, Exception> _circuitHandlerFailed;
private static readonly Action<ILogger, string, Exception> _circuitUnhandledException;
private static readonly Action<ILogger, string, Exception> _circuitTransmittingClientError;
private static readonly Action<ILogger, string, Exception> _circuitTransmittedClientErrorSuccess;
private static readonly Action<ILogger, string, Exception> _circuitTransmitErrorFailed;
private static readonly Action<ILogger, string, Exception> _unhandledExceptionClientDisconnected;
private static readonly Action<ILogger, CircuitId, Exception> _circuitUnhandledException;
private static readonly Action<ILogger, CircuitId, Exception> _circuitTransmittingClientError;
private static readonly Action<ILogger, CircuitId, Exception> _circuitTransmittedClientErrorSuccess;
private static readonly Action<ILogger, CircuitId, Exception> _circuitTransmitErrorFailed;
private static readonly Action<ILogger, CircuitId, Exception> _unhandledExceptionClientDisconnected;
private static readonly Action<ILogger, string, string, string, Exception> _beginInvokeDotNetStatic;
private static readonly Action<ILogger, string, long, string, Exception> _beginInvokeDotNetInstance;
@ -608,11 +613,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
private static readonly Action<ILogger, long, Exception> _endInvokeJSSucceeded;
private static readonly Action<ILogger, Exception> _dispatchEventFailedToParseEventData;
private static readonly Action<ILogger, string, Exception> _dispatchEventFailedToDispatchEvent;
private static readonly Action<ILogger, string, string, Exception> _locationChange;
private static readonly Action<ILogger, string, string, Exception> _locationChangeSucceeded;
private static readonly Action<ILogger, string, string, Exception> _locationChangeFailed;
private static readonly Action<ILogger, string, string, Exception> _locationChangeFailedInCircuit;
private static readonly Action<ILogger, long, string, Exception> _onRenderCompletedFailed;
private static readonly Action<ILogger, string, CircuitId, Exception> _locationChange;
private static readonly Action<ILogger, string, CircuitId, Exception> _locationChangeSucceeded;
private static readonly Action<ILogger, string, CircuitId, Exception> _locationChangeFailed;
private static readonly Action<ILogger, string, CircuitId, Exception> _locationChangeFailedInCircuit;
private static readonly Action<ILogger, long, CircuitId, Exception> _onRenderCompletedFailed;
private static class EventIds
{
@ -647,7 +652,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public static readonly EventId LocationChangeSucceded = new EventId(209, "LocationChangeSucceeded");
public static readonly EventId LocationChangeFailed = new EventId(210, "LocationChangeFailed");
public static readonly EventId LocationChangeFailedInCircuit = new EventId(211, "LocationChangeFailedInCircuit");
public static readonly EventId OnRenderCompletedFailed = new EventId(212, " OnRenderCompletedFailed");
public static readonly EventId OnRenderCompletedFailed = new EventId(212, "OnRenderCompletedFailed");
}
static Log()
@ -667,37 +672,37 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
EventIds.InitializationFailed,
"Circuit initialization failed.");
_disposeStarted = LoggerMessage.Define<string>(
_disposeStarted = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.DisposeStarted,
"Disposing circuit '{CircuitId}' started.");
_disposeSucceded = LoggerMessage.Define<string>(
_disposeSucceded = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.DisposeSucceeded,
"Disposing circuit '{CircuitId}' succeded.");
_disposeFailed = LoggerMessage.Define<string>(
_disposeFailed = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.DisposeFailed,
"Disposing circuit '{CircuitId}' failed.");
_onCircuitOpened = LoggerMessage.Define<string>(
_onCircuitOpened = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.OnCircuitOpened,
"Opening circuit with id '{CircuitId}'.");
_onConnectionUp = LoggerMessage.Define<string, string>(
_onConnectionUp = LoggerMessage.Define<CircuitId, string>(
LogLevel.Debug,
EventIds.OnConnectionUp,
"Circuit id '{CircuitId}' connected using connection '{ConnectionId}'.");
_onConnectionDown = LoggerMessage.Define<string, string>(
_onConnectionDown = LoggerMessage.Define<CircuitId, string>(
LogLevel.Debug,
EventIds.OnConnectionDown,
"Circuit id '{CircuitId}' disconnected from connection '{ConnectionId}'.");
_onCircuitClosed = LoggerMessage.Define<string>(
_onCircuitClosed = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.OnCircuitClosed,
"Closing circuit with id '{CircuitId}'.");
@ -707,27 +712,27 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
EventIds.CircuitHandlerFailed,
"Unhandled error invoking circuit handler type {handlerType}.{handlerMethod}: {Message}");
_circuitUnhandledException = LoggerMessage.Define<string>(
_circuitUnhandledException = LoggerMessage.Define<CircuitId>(
LogLevel.Error,
EventIds.CircuitUnhandledException,
"Unhandled exception in circuit '{CircuitId}'.");
_circuitTransmittingClientError = LoggerMessage.Define<string>(
_circuitTransmittingClientError = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.CircuitTransmittingClientError,
"About to notify client of an error in circuit '{CircuitId}'.");
_circuitTransmittedClientErrorSuccess = LoggerMessage.Define<string>(
_circuitTransmittedClientErrorSuccess = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.CircuitTransmittedClientErrorSuccess,
"Successfully transmitted error to client in circuit '{CircuitId}'.");
_circuitTransmitErrorFailed = LoggerMessage.Define<string>(
_circuitTransmitErrorFailed = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.CircuitTransmitErrorFailed,
"Failed to transmit exception to client in circuit '{CircuitId}'.");
_unhandledExceptionClientDisconnected = LoggerMessage.Define<string>(
_unhandledExceptionClientDisconnected = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.UnhandledExceptionClientDisconnected,
"An exception ocurred on the circuit host '{CircuitId}' while the client is disconnected.");
@ -777,27 +782,27 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
EventIds.DispatchEventFailedToDispatchEvent,
"There was an error dispatching the event '{EventHandlerId}' to the application.");
_locationChange = LoggerMessage.Define<string, string>(
_locationChange = LoggerMessage.Define<string, CircuitId>(
LogLevel.Debug,
EventIds.LocationChange,
"Location changing to {URI} in circuit '{CircuitId}'.");
_locationChangeSucceeded = LoggerMessage.Define<string, string>(
_locationChangeSucceeded = LoggerMessage.Define<string, CircuitId>(
LogLevel.Debug,
EventIds.LocationChangeSucceded,
"Location change to '{URI}' in circuit '{CircuitId}' succeded.");
_locationChangeFailed = LoggerMessage.Define<string, string>(
_locationChangeFailed = LoggerMessage.Define<string, CircuitId>(
LogLevel.Debug,
EventIds.LocationChangeFailed,
"Location change to '{URI}' in circuit '{CircuitId}' failed.");
_locationChangeFailedInCircuit = LoggerMessage.Define<string, string>(
_locationChangeFailedInCircuit = LoggerMessage.Define<string, CircuitId>(
LogLevel.Error,
EventIds.LocationChangeFailed,
"Location change to '{URI}' in circuit '{CircuitId}' failed.");
_onRenderCompletedFailed = LoggerMessage.Define<long, string>(
_onRenderCompletedFailed = LoggerMessage.Define<long, CircuitId>(
LogLevel.Debug,
EventIds.OnRenderCompletedFailed,
"Failed to complete render batch '{RenderId}' in circuit host '{CircuitId}'.");
@ -806,13 +811,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public static void InitializationStarted(ILogger logger) => _intializationStarted(logger, null);
public static void InitializationSucceeded(ILogger logger) => _intializationSucceded(logger, null);
public static void InitializationFailed(ILogger logger, Exception exception) => _intializationFailed(logger, exception);
public static void DisposeStarted(ILogger logger, string circuitId) => _disposeStarted(logger, circuitId, null);
public static void DisposeSucceeded(ILogger logger, string circuitId) => _disposeSucceded(logger, circuitId, null);
public static void DisposeFailed(ILogger logger, string circuitId, Exception exception) => _disposeFailed(logger, circuitId, exception);
public static void CircuitOpened(ILogger logger, string circuitId) => _onCircuitOpened(logger, circuitId, null);
public static void ConnectionUp(ILogger logger, string circuitId, string connectionId) => _onConnectionUp(logger, circuitId, connectionId, null);
public static void ConnectionDown(ILogger logger, string circuitId, string connectionId) => _onConnectionDown(logger, circuitId, connectionId, null);
public static void CircuitClosed(ILogger logger, string circuitId) => _onCircuitClosed(logger, circuitId, null);
public static void DisposeStarted(ILogger logger, CircuitId circuitId) => _disposeStarted(logger, circuitId, null);
public static void DisposeSucceeded(ILogger logger, CircuitId circuitId) => _disposeSucceded(logger, circuitId, null);
public static void DisposeFailed(ILogger logger, CircuitId circuitId, Exception exception) => _disposeFailed(logger, circuitId, exception);
public static void CircuitOpened(ILogger logger, CircuitId circuitId) => _onCircuitOpened(logger, circuitId, null);
public static void ConnectionUp(ILogger logger, CircuitId circuitId, string connectionId) => _onConnectionUp(logger, circuitId, connectionId, null);
public static void ConnectionDown(ILogger logger, CircuitId circuitId, string connectionId) => _onConnectionDown(logger, circuitId, connectionId, null);
public static void CircuitClosed(ILogger logger, CircuitId circuitId) => _onCircuitClosed(logger, circuitId, null);
public static void CircuitHandlerFailed(ILogger logger, CircuitHandler handler, string handlerMethod, Exception exception)
{
@ -824,8 +829,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
exception);
}
public static void CircuitUnhandledException(ILogger logger, string circuitId, Exception exception) => _circuitUnhandledException(logger, circuitId, exception);
public static void CircuitTransmitErrorFailed(ILogger logger, string circuitId, Exception exception) => _circuitTransmitErrorFailed(logger, circuitId, exception);
public static void CircuitUnhandledException(ILogger logger, CircuitId circuitId, Exception exception) => _circuitUnhandledException(logger, circuitId, exception);
public static void CircuitTransmitErrorFailed(ILogger logger, CircuitId circuitId, Exception exception) => _circuitTransmitErrorFailed(logger, circuitId, exception);
public static void EndInvokeDispatchException(ILogger logger, Exception ex) => _endInvokeDispatchException(logger, ex);
public static void EndInvokeJSFailed(ILogger logger, long asyncHandle, string arguments) => _endInvokeJSFailed(logger, asyncHandle, arguments, null);
public static void EndInvokeJSSucceeded(ILogger logger, long asyncCall) => _endInvokeJSSucceeded(logger, asyncCall, null);
@ -856,14 +861,14 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
}
public static void LocationChange(ILogger logger, string uri, string circuitId) => _locationChange(logger, uri, circuitId, null);
public static void LocationChangeSucceeded(ILogger logger, string uri, string circuitId) => _locationChangeSucceeded(logger, uri, circuitId, null);
public static void LocationChangeFailed(ILogger logger, string uri, string circuitId, Exception exception) => _locationChangeFailed(logger, uri, circuitId, exception);
public static void LocationChangeFailedInCircuit(ILogger logger, string uri, string circuitId, Exception exception) => _locationChangeFailedInCircuit(logger, uri, circuitId, exception);
public static void UnhandledExceptionClientDisconnected(ILogger logger, string circuitId, Exception exception) => _unhandledExceptionClientDisconnected(logger, circuitId, exception);
public static void CircuitTransmittingClientError(ILogger logger, string circuitId) => _circuitTransmittingClientError(logger, circuitId, null);
public static void CircuitTransmittedClientErrorSuccess(ILogger logger, string circuitId) => _circuitTransmittedClientErrorSuccess(logger, circuitId, null);
public static void OnRenderCompletedFailed(ILogger logger, long renderId, string circuitId, Exception e) => _onRenderCompletedFailed(logger, renderId, circuitId, e);
public static void LocationChange(ILogger logger, string uri, CircuitId circuitId) => _locationChange(logger, uri, circuitId, null);
public static void LocationChangeSucceeded(ILogger logger, string uri, CircuitId circuitId) => _locationChangeSucceeded(logger, uri, circuitId, null);
public static void LocationChangeFailed(ILogger logger, string uri, CircuitId circuitId, Exception exception) => _locationChangeFailed(logger, uri, circuitId, exception);
public static void LocationChangeFailedInCircuit(ILogger logger, string uri, CircuitId circuitId, Exception exception) => _locationChangeFailedInCircuit(logger, uri, circuitId, exception);
public static void UnhandledExceptionClientDisconnected(ILogger logger, CircuitId circuitId, Exception exception) => _unhandledExceptionClientDisconnected(logger, circuitId, exception);
public static void CircuitTransmittingClientError(ILogger logger, CircuitId circuitId) => _circuitTransmittingClientError(logger, circuitId, null);
public static void CircuitTransmittedClientErrorSuccess(ILogger logger, CircuitId circuitId) => _circuitTransmittedClientErrorSuccess(logger, circuitId, null);
public static void OnRenderCompletedFailed(ILogger logger, long renderId, CircuitId circuitId, Exception e) => _onRenderCompletedFailed(logger, renderId, circuitId, e);
}
}
}

View File

@ -0,0 +1,63 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
namespace Microsoft.AspNetCore.Components.Server.Circuits
{
// Consists of a secret (data protected payload) and a non-secret identifier
// for use in logs and user code.
//
// The contract of this is that the id is derived from the Secret. We use the secret
// for comparisons, but we use the id for display (and exposing to user code). As a result,
// we don't include the id in any comparisons done by this class.
//
// Intentionally not overriding ToString here so that this won't accidentally
// get logged. It's ok to log the secret at TRACE.
internal readonly struct CircuitId : IEquatable<CircuitId>
{
public CircuitId(string secret, string id)
{
Secret = secret ?? throw new ArgumentNullException(nameof(secret));
Id = id ?? throw new ArgumentNullException(nameof(id));
}
public string Id { get; }
public string Secret { get; }
public bool Equals([AllowNull] CircuitId other)
{
// We want to use a fixed time equality comparison for a *real* comparisons.
// The only use case for Secret being null is with a default struct value,
// which wouldn't be the result of untrusted input.
if (other.Secret == null)
{
return Secret == null;
}
return
CryptographicOperations.FixedTimeEquals(
MemoryMarshal.AsBytes(Secret.AsSpan()),
MemoryMarshal.AsBytes(other.Secret.AsSpan()));
}
public override bool Equals(object obj)
{
return obj is CircuitId other ? Equals(other) : false;
}
public override int GetHashCode()
{
return HashCode.Combine(Secret);
}
public override string ToString()
{
return Id;
}
}
}

View File

@ -2,10 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Data;
using System.Security.Cryptography;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.AspNetCore.WebUtilities;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.Components.Server.Circuits
{
@ -13,7 +13,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
// Generates strong cryptographic ids for circuits that are protected with authenticated encryption.
internal class CircuitIdFactory
{
private const string CircuitIdProtectorPurpose = "Microsoft.AspNetCore.Components.Server";
private const string CircuitIdProtectorPurpose = "Microsoft.AspNetCore.Components.Server.CircuitIdFactory";
// We use 64 bytes, where the last 32 are the public version of the id.
// This way we can always recover the public id from the secret form.
private const int SecretLength = 64;
private const int IdLength = 32;
private readonly RandomNumberGenerator _generator = RandomNumberGenerator.Create();
private readonly IDataProtector _protector;
@ -27,29 +32,58 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
// we don't care about the underlying payload, other than its uniqueness and the fact that we
// authenticate encrypt it using data protection.
// For validation, the fact that we can unprotect the payload is guarantee enough.
public string CreateCircuitId()
public CircuitId CreateCircuitId()
{
var buffer = new byte[32];
var buffer = new byte[SecretLength];
_generator.GetBytes(buffer);
var payload = _protector.Protect(buffer);
return Base64UrlTextEncoder.Encode(payload);
var id = new byte[IdLength];
Array.Copy(
sourceArray: buffer,
sourceIndex: SecretLength - IdLength,
destinationArray: id,
destinationIndex: 0,
length: IdLength);
var secret = _protector.Protect(buffer);
return new CircuitId(Base64UrlTextEncoder.Encode(secret), Base64UrlTextEncoder.Encode(id));
}
public bool ValidateCircuitId(string circuitId)
public bool TryParseCircuitId(string text, out CircuitId circuitId)
{
if (text is null)
{
circuitId = default;
return false;
}
try
{
var protectedBytes = Base64UrlTextEncoder.Decode(circuitId);
_protector.Unprotect(protectedBytes);
var protectedBytes = Base64UrlTextEncoder.Decode(text);
var unprotectedBytes = _protector.Unprotect(protectedBytes);
// Its enough that we prove that we can unprotect the payload to validate the circuit id,
// as this demonstrates that it the id wasn't tampered with.
if (unprotectedBytes.Length != SecretLength)
{
// Wrong length
circuitId = default;
return false;
}
var id = new byte[IdLength];
Array.Copy(
sourceArray: unprotectedBytes,
sourceIndex: SecretLength - IdLength,
destinationArray: id,
destinationIndex: 0,
length: IdLength);
circuitId = new CircuitId(text, Base64UrlTextEncoder.Encode(id));
return true;
}
catch (Exception)
{
// The payload format is not correct (either not base64urlencoded or not data protected)
circuitId = default;
return false;
}
}

View File

@ -47,12 +47,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public CircuitRegistry(
IOptions<CircuitOptions> options,
ILogger<CircuitRegistry> logger,
CircuitIdFactory circuitIdFactory)
CircuitIdFactory CircuitHostFactory)
{
_options = options.Value;
_logger = logger;
_circuitIdFactory = circuitIdFactory;
ConnectedCircuits = new ConcurrentDictionary<string, CircuitHost>(StringComparer.Ordinal);
_circuitIdFactory = CircuitHostFactory;
ConnectedCircuits = new ConcurrentDictionary<CircuitId, CircuitHost>();
DisconnectedCircuits = new MemoryCache(new MemoryCacheOptions
{
@ -65,7 +65,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
};
}
internal ConcurrentDictionary<string, CircuitHost> ConnectedCircuits { get; }
internal ConcurrentDictionary<CircuitId, CircuitHost> ConnectedCircuits { get; }
internal MemoryCache DisconnectedCircuits { get; }
@ -155,7 +155,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
};
var entry = new DisconnectedCircuitEntry(circuitHost, cancellationTokenSource);
DisconnectedCircuits.Set(circuitHost.CircuitId, entry, entryOptions);
DisconnectedCircuits.Set(circuitHost.CircuitId.Secret, entry, entryOptions);
}
// ConnectAsync is called from the CircuitHub - but the error handling story is a little bit complicated.
@ -164,20 +164,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
//
// The solution is to handle exceptions here, and then return null to represent failure.
//
// 1. If the circuit id is invalue return null
// 2. If the circuit is not found return null
// 3. If the circuit is found, but fails to connect, we need to dispose it here and return null
// 4. If everything goes well, return the circuit.
public virtual async Task<CircuitHost> ConnectAsync(string circuitId, IClientProxy clientProxy, string connectionId, CancellationToken cancellationToken)
// 1. If the circuit is not found return null
// 2. If the circuit is found, but fails to connect, we need to dispose it here and return null
// 3. If everything goes well, return the circuit.
public virtual async Task<CircuitHost> ConnectAsync(CircuitId circuitId, IClientProxy clientProxy, string connectionId, CancellationToken cancellationToken)
{
Log.CircuitConnectStarted(_logger, circuitId);
if (!_circuitIdFactory.ValidateCircuitId(circuitId))
{
Log.InvalidCircuitId(_logger, circuitId);
return null;
}
CircuitHost circuitHost;
bool previouslyConnected;
@ -193,8 +186,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
if (circuitHost == null)
{
Log.FailedToReconnectToCircuit(_logger, circuitId);
// Failed to connect. Nothing to do here.
Log.FailedToFindCircuit(_logger, circuitId);
// Failed to find a matching circuit. Nothing to do here.
return null;
}
@ -220,12 +213,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
try
{
await circuitHandlerTask;
Log.ReconnectionSucceeded(_logger, circuitId);
Log.ReconnectionSucceeded(_logger, circuitHost.CircuitId);
return circuitHost;
}
catch (Exception ex)
{
Log.FailedToReconnectToCircuit(_logger, circuitId, ex);
Log.FailedToReconnectToCircuit(_logger, circuitHost.CircuitId, ex);
await TerminateAsync(circuitId);
// Return null on failure, because we need to clean up the circuit.
@ -233,11 +226,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
}
protected virtual (CircuitHost circuitHost, bool previouslyConnected) ConnectCore(string circuitId, IClientProxy clientProxy, string connectionId)
protected virtual (CircuitHost circuitHost, bool previouslyConnected) ConnectCore(CircuitId circuitId, IClientProxy clientProxy, string connectionId)
{
if (ConnectedCircuits.TryGetValue(circuitId, out var connectedCircuitHost))
{
Log.ConnectingToActiveCircuit(_logger, circuitId, connectionId);
Log.ConnectingToActiveCircuit(_logger, connectedCircuitHost.CircuitId, connectionId);
// The host is still active i.e. the server hasn't detected the client disconnect.
// However the client reconnected establishing a new connection.
@ -245,15 +238,15 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
return (connectedCircuitHost, true);
}
if (DisconnectedCircuits.TryGetValue(circuitId, out DisconnectedCircuitEntry disconnectedEntry))
if (DisconnectedCircuits.TryGetValue(circuitId.Secret, out DisconnectedCircuitEntry disconnectedEntry))
{
Log.ConnectingToDisconnectedCircuit(_logger, circuitId, connectionId);
Log.ConnectingToDisconnectedCircuit(_logger, disconnectedEntry.CircuitHost.CircuitId, connectionId);
// The host was in disconnected state. Transfer it to ConnectedCircuits so that it's no longer considered disconnected.
// First discard the CancellationTokenSource so that the cache entry does not expire.
DisposeTokenSource(disconnectedEntry);
DisconnectedCircuits.Remove(circuitId);
DisconnectedCircuits.Remove(circuitId.Secret);
ConnectedCircuits.TryAdd(circuitId, disconnectedEntry.CircuitHost);
disconnectedEntry.CircuitHost.Client.Transfer(clientProxy, connectionId);
@ -313,17 +306,18 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
}
}
public ValueTask TerminateAsync(string circuitId)
// We don't expect this to throw. User code only runs inside DisposeAsync and that does its own error handling.
public ValueTask TerminateAsync(CircuitId circuitId)
{
CircuitHost circuitHost;
DisconnectedCircuitEntry entry = default;
lock (CircuitRegistryLock)
{
if (ConnectedCircuits.TryGetValue(circuitId, out circuitHost) || DisconnectedCircuits.TryGetValue(circuitId, out entry))
if (ConnectedCircuits.TryGetValue(circuitId, out circuitHost) || DisconnectedCircuits.TryGetValue(circuitId.Secret, out entry))
{
circuitHost ??= entry.CircuitHost;
DisconnectedCircuits.Remove(circuitHost.CircuitId);
ConnectedCircuits.TryRemove(circuitHost.CircuitId, out _);
DisconnectedCircuits.Remove(circuitId.Secret);
ConnectedCircuits.TryRemove(circuitId, out _);
Log.CircuitDisconnectedPermanently(_logger, circuitHost.CircuitId);
circuitHost.Client.SetDisconnected();
}
@ -372,36 +366,36 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
{
private static readonly Action<ILogger, string, Exception> _exceptionDisposingCircuitHost;
private static readonly Action<ILogger, string, Exception> _unhandledExceptionDisposingTokenSource;
private static readonly Action<ILogger, string, Exception> _circuitReconnectStarted;
private static readonly Action<ILogger, string, Exception> _invalidCircuitId;
private static readonly Action<ILogger, string, string, Exception> _connectingToActiveCircuit;
private static readonly Action<ILogger, string, string, Exception> _connectingToDisconnectedCircuit;
private static readonly Action<ILogger, string, Exception> _failedToReconnectToCircuit;
private static readonly Action<ILogger, string, Exception> _reconnectionSucceeded;
private static readonly Action<ILogger, string, string, Exception> _circuitDisconnectStarted;
private static readonly Action<ILogger, string, Exception> _circuitNotActive;
private static readonly Action<ILogger, string, string, Exception> _circuitConnectedToDifferentConnection;
private static readonly Action<ILogger, string, Exception> _circuitMarkedDisconnected;
private static readonly Action<ILogger, string, Exception> _circuitDisconnectedPermanently;
private static readonly Action<ILogger, string, EvictionReason, Exception> _circuitEvicted;
private static readonly Action<ILogger, string, Exception> _circuitExceptionHandlerFailed;
private static readonly Action<ILogger, CircuitId, Exception> _circuitReconnectStarted;
private static readonly Action<ILogger, CircuitId, Exception> _failedToFindCircuit;
private static readonly Action<ILogger, CircuitId, string, Exception> _connectingToActiveCircuit;
private static readonly Action<ILogger, CircuitId, string, Exception> _connectingToDisconnectedCircuit;
private static readonly Action<ILogger, CircuitId, Exception> _failedToReconnectToCircuit;
private static readonly Action<ILogger, CircuitId, Exception> _reconnectionSucceeded;
private static readonly Action<ILogger, CircuitId, string, Exception> _circuitDisconnectStarted;
private static readonly Action<ILogger, CircuitId, Exception> _circuitNotActive;
private static readonly Action<ILogger, CircuitId, string, Exception> _circuitConnectedToDifferentConnection;
private static readonly Action<ILogger, CircuitId, Exception> _circuitMarkedDisconnected;
private static readonly Action<ILogger, CircuitId, Exception> _circuitDisconnectedPermanently;
private static readonly Action<ILogger, CircuitId, EvictionReason, Exception> _circuitEvicted;
private static readonly Action<ILogger, CircuitId, Exception> _circuitExceptionHandlerFailed;
private static class EventIds
{
public static readonly EventId ExceptionDisposingCircuit = new EventId(100, "ExceptionDisposingCircuit");
public static readonly EventId ExceptionDisposingTokenSource = new EventId(101, "ExceptionDisposingTokenSource");
public static readonly EventId AttemptingToReconnect = new EventId(102, "AttemptingToReconnect");
public static readonly EventId InvalidCircuitId = new EventId(103, "InvalidCircuitId");
public static readonly EventId ConnectingToActiveCircuit = new EventId(104, "ConnectingToActiveCircuit");
public static readonly EventId ConnectingToDisconnectedCircuit = new EventId(105, "ConnectingToDisconnectedCircuit");
public static readonly EventId FailedToReconnectToCircuit = new EventId(106, "FailedToReconnectToCircuit");
public static readonly EventId CircuitDisconnectStarted = new EventId(107, "CircuitDisconnectStarted");
public static readonly EventId CircuitNotActive = new EventId(108, "CircuitNotActive");
public static readonly EventId CircuitConnectedToDifferentConnection = new EventId(109, "CircuitConnectedToDifferentConnection");
public static readonly EventId CircuitMarkedDisconnected = new EventId(110, "CircuitMarkedDisconnected");
public static readonly EventId CircuitEvicted = new EventId(111, "CircuitEvicted");
public static readonly EventId CircuitDisconnectedPermanently = new EventId(112, "CircuitDisconnectedPermanently");
public static readonly EventId CircuitExceptionHandlerFailed = new EventId(113, "CircuitExceptionHandlerFailed");
public static readonly EventId FailedToFindCircuit = new EventId(104, "FailedToFindCircuit");
public static readonly EventId ConnectingToActiveCircuit = new EventId(105, "ConnectingToActiveCircuit");
public static readonly EventId ConnectingToDisconnectedCircuit = new EventId(106, "ConnectingToDisconnectedCircuit");
public static readonly EventId FailedToReconnectToCircuit = new EventId(107, "FailedToReconnectToCircuit");
public static readonly EventId CircuitDisconnectStarted = new EventId(108, "CircuitDisconnectStarted");
public static readonly EventId CircuitNotActive = new EventId(109, "CircuitNotActive");
public static readonly EventId CircuitConnectedToDifferentConnection = new EventId(110, "CircuitConnectedToDifferentConnection");
public static readonly EventId CircuitMarkedDisconnected = new EventId(111, "CircuitMarkedDisconnected");
public static readonly EventId CircuitEvicted = new EventId(112, "CircuitEvicted");
public static readonly EventId CircuitDisconnectedPermanently = new EventId(113, "CircuitDisconnectedPermanently");
public static readonly EventId CircuitExceptionHandlerFailed = new EventId(114, "CircuitExceptionHandlerFailed");
}
static Log()
@ -416,67 +410,67 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
EventIds.ExceptionDisposingTokenSource,
"Exception thrown when disposing token source: {Message}");
_circuitReconnectStarted = LoggerMessage.Define<string>(
_circuitReconnectStarted = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.AttemptingToReconnect,
"Attempting to reconnect to Circuit with id {CircuitId}.");
"Attempting to reconnect to Circuit with secret {CircuitHost}.");
_invalidCircuitId = LoggerMessage.Define<string>(
_failedToFindCircuit = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.InvalidCircuitId,
"Failed to validate circuit id {CircuitId}.");
EventIds.FailedToFindCircuit,
"Failed to find a matching circuit for circuit secret {CircuitHost}.");
_connectingToActiveCircuit = LoggerMessage.Define<string, string>(
_connectingToActiveCircuit = LoggerMessage.Define<CircuitId, string>(
LogLevel.Debug,
EventIds.ConnectingToActiveCircuit,
"Transferring active circuit {CircuitId} to connection {ConnectionId}.");
_connectingToDisconnectedCircuit = LoggerMessage.Define<string, string>(
_connectingToDisconnectedCircuit = LoggerMessage.Define<CircuitId, string>(
LogLevel.Debug,
EventIds.ConnectingToDisconnectedCircuit,
"Transfering disconnected circuit {CircuitId} to connection {ConnectionId}.");
_failedToReconnectToCircuit = LoggerMessage.Define<string>(
_failedToReconnectToCircuit = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.FailedToReconnectToCircuit,
"Failed to reconnect to a circuit with id {CircuitId}.");
_reconnectionSucceeded = LoggerMessage.Define<string>(
_reconnectionSucceeded = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.FailedToReconnectToCircuit,
"Reconnect to circuit with id {CircuitId} succeeded.");
_circuitDisconnectStarted = LoggerMessage.Define<string, string>(
_circuitDisconnectStarted = LoggerMessage.Define<CircuitId, string>(
LogLevel.Debug,
EventIds.CircuitDisconnectStarted,
"Attempting to disconnect circuit with id {CircuitId} from connection {ConnectionId}.");
_circuitNotActive = LoggerMessage.Define<string>(
_circuitNotActive = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.CircuitNotActive,
"Failed to disconnect circuit with id {CircuitId}. The circuit is not active.");
_circuitConnectedToDifferentConnection = LoggerMessage.Define<string, string>(
_circuitConnectedToDifferentConnection = LoggerMessage.Define<CircuitId, string>(
LogLevel.Debug,
EventIds.CircuitConnectedToDifferentConnection,
"Failed to disconnect circuit with id {CircuitId}. The circuit is connected to {ConnectionId}.");
_circuitMarkedDisconnected = LoggerMessage.Define<string>(
_circuitMarkedDisconnected = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.CircuitMarkedDisconnected,
"Circuit with id {CircuitId} is disconnected.");
_circuitDisconnectedPermanently = LoggerMessage.Define<string>(
_circuitDisconnectedPermanently = LoggerMessage.Define<CircuitId>(
LogLevel.Debug,
EventIds.CircuitDisconnectedPermanently,
"Circuit with id {CircuitId} has been removed from the registry for permanent disconnection.");
_circuitEvicted = LoggerMessage.Define<string, EvictionReason>(
_circuitEvicted = LoggerMessage.Define<CircuitId, EvictionReason>(
LogLevel.Debug,
EventIds.CircuitEvicted,
"Circuit with id {CircuitId} evicted due to {EvictionReason}.");
_circuitExceptionHandlerFailed = LoggerMessage.Define<string>(
_circuitExceptionHandlerFailed = LoggerMessage.Define<CircuitId>(
LogLevel.Error,
EventIds.CircuitExceptionHandlerFailed,
"Exception handler for {CircuitId} failed.");
@ -488,43 +482,43 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public static void ExceptionDisposingTokenSource(ILogger logger, Exception exception) =>
_unhandledExceptionDisposingTokenSource(logger, exception.Message, exception);
public static void CircuitConnectStarted(ILogger logger, string circuitId) =>
public static void CircuitConnectStarted(ILogger logger, CircuitId circuitId) =>
_circuitReconnectStarted(logger, circuitId, null);
public static void InvalidCircuitId(ILogger logger, string circuitId) =>
_invalidCircuitId(logger, circuitId, null);
public static void FailedToFindCircuit(ILogger logger, CircuitId circuitId) =>
_failedToFindCircuit(logger, circuitId, null);
public static void ConnectingToActiveCircuit(ILogger logger, string circuitId, string connectionId) =>
public static void ConnectingToActiveCircuit(ILogger logger, CircuitId circuitId, string connectionId) =>
_connectingToActiveCircuit(logger, circuitId, connectionId, null);
public static void ConnectingToDisconnectedCircuit(ILogger logger, string circuitId, string connectionId) =>
public static void ConnectingToDisconnectedCircuit(ILogger logger, CircuitId circuitId, string connectionId) =>
_connectingToDisconnectedCircuit(logger, circuitId, connectionId, null);
public static void FailedToReconnectToCircuit(ILogger logger, string circuitId, Exception exception = null) =>
public static void FailedToReconnectToCircuit(ILogger logger, CircuitId circuitId, Exception exception = null) =>
_failedToReconnectToCircuit(logger, circuitId, exception);
public static void ReconnectionSucceeded(ILogger logger, string circuitId) =>
public static void ReconnectionSucceeded(ILogger logger, CircuitId circuitId) =>
_reconnectionSucceeded(logger, circuitId, null);
public static void CircuitDisconnectStarted(ILogger logger, string circuitId, string connectionId) =>
public static void CircuitDisconnectStarted(ILogger logger, CircuitId circuitId, string connectionId) =>
_circuitDisconnectStarted(logger, circuitId, connectionId, null);
public static void CircuitNotActive(ILogger logger, string circuitId) =>
public static void CircuitNotActive(ILogger logger, CircuitId circuitId) =>
_circuitNotActive(logger, circuitId, null);
public static void CircuitConnectedToDifferentConnection(ILogger logger, string circuitId, string connectionId) =>
public static void CircuitConnectedToDifferentConnection(ILogger logger, CircuitId circuitId, string connectionId) =>
_circuitConnectedToDifferentConnection(logger, circuitId, connectionId, null);
public static void CircuitMarkedDisconnected(ILogger logger, string circuitId) =>
public static void CircuitMarkedDisconnected(ILogger logger, CircuitId circuitId) =>
_circuitMarkedDisconnected(logger, circuitId, null);
public static void CircuitDisconnectedPermanently(ILogger logger, string circuitId) =>
public static void CircuitDisconnectedPermanently(ILogger logger, CircuitId circuitId) =>
_circuitDisconnectedPermanently(logger, circuitId, null);
public static void CircuitEvicted(ILogger logger, string circuitId, EvictionReason evictionReason) =>
public static void CircuitEvicted(ILogger logger, CircuitId circuitId, EvictionReason evictionReason) =>
_circuitEvicted(logger, circuitId, evictionReason, null);
public static void CircuitExceptionHandlerFailed(ILogger logger, string circuitId, Exception exception) =>
public static void CircuitExceptionHandlerFailed(ILogger logger, CircuitId circuitId, Exception exception) =>
_circuitExceptionHandlerFailed(logger, circuitId, exception);
}
}

View File

@ -111,11 +111,11 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
private static class Log
{
private static readonly Action<ILogger, string, string, Exception> _createdConnectedCircuit =
LoggerMessage.Define<string, string>(LogLevel.Debug, new EventId(1, "CreatedConnectedCircuit"), "Created circuit {CircuitId} for connection {ConnectionId}");
private static readonly Action<ILogger, CircuitId, string, Exception> _createdConnectedCircuit =
LoggerMessage.Define<CircuitId, string>(LogLevel.Debug, new EventId(1, "CreatedConnectedCircuit"), "Created circuit {CircuitId} for connection {ConnectionId}");
private static readonly Action<ILogger, string, Exception> _createdDisconnectedCircuit =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(2, "CreatedDisconnectedCircuit"), "Created circuit {CircuitId} for disconnected client");
private static readonly Action<ILogger, CircuitId, Exception> _createdDisconnectedCircuit =
LoggerMessage.Define<CircuitId>(LogLevel.Debug, new EventId(2, "CreatedDisconnectedCircuit"), "Created circuit {CircuitId} for disconnected client");
internal static void CreatedCircuit(ILogger logger, CircuitHost circuitHost)
{

View File

@ -38,17 +38,20 @@ namespace Microsoft.AspNetCore.Components.Server
{
private static readonly object CircuitKey = new object();
private readonly CircuitFactory _circuitFactory;
private readonly CircuitIdFactory _circuitIdFactory;
private readonly CircuitRegistry _circuitRegistry;
private readonly CircuitOptions _options;
private readonly ILogger _logger;
public ComponentHub(
CircuitFactory circuitFactory,
CircuitIdFactory circuitIdFactory,
CircuitRegistry circuitRegistry,
ILogger<ComponentHub> logger,
IOptions<CircuitOptions> options)
{
_circuitFactory = circuitFactory;
_circuitIdFactory = circuitIdFactory;
_circuitRegistry = circuitRegistry;
_options = options.Value;
_logger = logger;
@ -129,7 +132,12 @@ namespace Microsoft.AspNetCore.Components.Server
// to run inside it until after InitializeAsync completes.
_circuitRegistry.Register(circuitHost);
SetCircuit(circuitHost);
return circuitHost.CircuitId;
// Returning the secret here so the client can reconnect.
//
// Logging the secret and circuit ID here so we can associate them with just logs (if TRACE level is on).
Log.CreatedCircuit(_logger, circuitHost.CircuitId, circuitHost.CircuitId.Secret, Context.ConnectionId);
return circuitHost.CircuitId.Secret;
}
catch (Exception ex)
{
@ -142,10 +150,22 @@ namespace Microsoft.AspNetCore.Components.Server
}
}
public async ValueTask<bool> ConnectCircuit(string circuitId)
public async ValueTask<bool> ConnectCircuit(string circuitIdSecret)
{
// ConnectionAsync will not throw.
var circuitHost = await _circuitRegistry.ConnectAsync(circuitId, Clients.Caller, Context.ConnectionId, Context.ConnectionAborted);
// TryParseCircuitId will not throw.
if (!_circuitIdFactory.TryParseCircuitId(circuitIdSecret, out var circuitId))
{
// Invalid id.
Log.InvalidCircuitId(_logger, circuitIdSecret);
return false;
}
// ConnectAsync will not throw.
var circuitHost = await _circuitRegistry.ConnectAsync(
circuitId,
Clients.Caller,
Context.ConnectionId,
Context.ConnectionAborted);
if (circuitHost != null)
{
SetCircuit(circuitHost);
@ -270,8 +290,8 @@ namespace Microsoft.AspNetCore.Components.Server
private static readonly Action<ILogger, string, Exception> _unhandledExceptionInCircuit =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(3, "UnhandledExceptionInCircuit"), "Unhandled exception in circuit {CircuitId}");
private static readonly Action<ILogger, string, Exception> _circuitAlreadyInitialized =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(4, "CircuitAlreadyInitialized"), "The circuit host '{CircuitId}' has already been initialized");
private static readonly Action<ILogger, CircuitId, Exception> _circuitAlreadyInitialized =
LoggerMessage.Define<CircuitId>(LogLevel.Debug, new EventId(4, "CircuitAlreadyInitialized"), "The circuit host '{CircuitId}' has already been initialized");
private static readonly Action<ILogger, string, Exception> _circuitHostNotInitialized =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(5, "CircuitHostNotInitialized"), "Call to '{CallSite}' received before the circuit host initialization");
@ -285,11 +305,17 @@ namespace Microsoft.AspNetCore.Components.Server
private static readonly Action<ILogger, Exception> _circuitInitializationFailed =
LoggerMessage.Define(LogLevel.Debug, new EventId(8, "CircuitInitializationFailed"), "Circuit initialization failed");
private static readonly Action<ILogger, CircuitId, string, string, Exception> _createdCircuit =
LoggerMessage.Define<CircuitId, string, string>(LogLevel.Debug, new EventId(8, "CreatedCircuit"), "Created circuit '{CircuitId}' with secret '{CircuitIdSecret}' for '{ConnectionId}'");
private static readonly Action<ILogger, string, Exception> _invalidCircuitId =
LoggerMessage.Define<string>(LogLevel.Debug, new EventId(9, "InvalidCircuitId"), "ConnectAsync recieved an invalid circuit id '{CircuitIdSecret}'");
public static void NoComponentsRegisteredInEndpoint(ILogger logger, string endpointDisplayName) => _noComponentsRegisteredInEndpoint(logger, endpointDisplayName, null);
public static void ReceivedConfirmationForBatch(ILogger logger, long batchId) => _receivedConfirmationForBatch(logger, batchId, null);
public static void CircuitAlreadyInitialized(ILogger logger, string circuitId) => _circuitAlreadyInitialized(logger, circuitId, null);
public static void CircuitAlreadyInitialized(ILogger logger, CircuitId circuitId) => _circuitAlreadyInitialized(logger, circuitId, null);
public static void CircuitHostNotInitialized(ILogger logger, [CallerMemberName] string callSite = "") => _circuitHostNotInitialized(logger, callSite, null);
@ -298,6 +324,28 @@ namespace Microsoft.AspNetCore.Components.Server
public static void InvalidInputData(ILogger logger, [CallerMemberName] string callSite = "") => _invalidInputData(logger, callSite, null);
public static void CircuitInitializationFailed(ILogger logger, Exception exception) => _circuitInitializationFailed(logger, exception);
public static void CreatedCircuit(ILogger logger, CircuitId circuitId, string circuitSecret, string connectionId)
{
// Redact the secret unless tracing is on.
if (!logger.IsEnabled(LogLevel.Trace))
{
circuitSecret = "(redacted)";
}
_createdCircuit(logger, circuitId, circuitSecret, connectionId, null);
}
public static void InvalidCircuitId(ILogger logger, string circuitSecret)
{
// Redact the secret unless tracing is on.
if (!logger.IsEnabled(LogLevel.Trace))
{
circuitSecret = "(redacted)";
}
_invalidCircuitId(logger, circuitSecret, null);
}
}
}
}

View File

@ -138,7 +138,7 @@ namespace Microsoft.AspNetCore.Components.Server
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var id = circuitIdFactory.CreateCircuitId();
var circuitId = circuitIdFactory.CreateCircuitId();
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
NullLogger<CircuitRegistry>.Instance,
@ -151,7 +151,7 @@ namespace Microsoft.AspNetCore.Components.Server
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = id }).CopyToAsync(memory);
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = circuitId.Secret, }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
@ -171,8 +171,8 @@ namespace Microsoft.AspNetCore.Components.Server
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var id = circuitIdFactory.CreateCircuitId();
var testCircuitHost = TestCircuitHost.Create(id);
var circuitId = circuitIdFactory.CreateCircuitId();
var testCircuitHost = TestCircuitHost.Create(circuitId);
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
@ -188,7 +188,7 @@ namespace Microsoft.AspNetCore.Components.Server
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = id }).CopyToAsync(memory);
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = circuitId.Secret, }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
@ -208,8 +208,8 @@ namespace Microsoft.AspNetCore.Components.Server
{
// Arrange
var circuitIdFactory = TestCircuitIdFactory.CreateTestFactory();
var id = circuitIdFactory.CreateCircuitId();
var circuitHost = TestCircuitHost.Create(id);
var circuitId = circuitIdFactory.CreateCircuitId();
var circuitHost = TestCircuitHost.Create(circuitId);
var registry = new CircuitRegistry(
Options.Create(new CircuitOptions()),
@ -226,7 +226,7 @@ namespace Microsoft.AspNetCore.Components.Server
(ctx) => Task.CompletedTask);
using var memory = new MemoryStream();
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = id }).CopyToAsync(memory);
await new FormUrlEncodedContent(new Dictionary<string, string> { ["circuitId"] = circuitId.Secret }).CopyToAsync(memory);
memory.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();

View File

@ -26,9 +26,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var serviceScope = new Mock<IServiceScope>();
var remoteRenderer = GetRemoteRenderer();
var circuitHost = TestCircuitHost.Create(
Guid.NewGuid().ToString(),
serviceScope.Object,
remoteRenderer);
serviceScope: serviceScope.Object,
remoteRenderer: remoteRenderer);
// Act
await circuitHost.DisposeAsync();
@ -52,9 +51,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var remoteRenderer = GetRemoteRenderer();
var circuitHost = TestCircuitHost.Create(
Guid.NewGuid().ToString(),
serviceScope.Object,
remoteRenderer);
serviceScope: serviceScope.Object,
remoteRenderer: remoteRenderer);
// Act
await circuitHost.DisposeAsync();
@ -77,9 +75,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
.Throws<InvalidTimeZoneException>();
var remoteRenderer = GetRemoteRenderer();
var circuitHost = TestCircuitHost.Create(
Guid.NewGuid().ToString(),
serviceScope.Object,
remoteRenderer,
serviceScope: serviceScope.Object,
remoteRenderer: remoteRenderer,
handlers: new[] { handler.Object });
var throwOnDisposeComponent = new ThrowOnDisposeComponent();
@ -101,9 +98,8 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var serviceScope = new Mock<IServiceScope>();
var remoteRenderer = GetRemoteRenderer();
var circuitHost = TestCircuitHost.Create(
Guid.NewGuid().ToString(),
serviceScope.Object,
remoteRenderer);
serviceScope: serviceScope.Object,
remoteRenderer: remoteRenderer);
var component = new DispatcherComponent(circuitHost.Renderer.Dispatcher);
circuitHost.Renderer.AssignRootComponentId(component);

View File

@ -3,13 +3,12 @@
using System;
using System.Linq;
using Microsoft.AspNetCore.DataProtection;
using Microsoft.AspNetCore.WebUtilities;
using Xunit;
namespace Microsoft.AspNetCore.Components.Server.Circuits
{
public class CircuitIdFactoryTest
public class circuitIdFactoryTest
{
[Fact]
public void CreateCircuitId_Generates_NewRandomId()
@ -17,12 +16,12 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var factory = TestCircuitIdFactory.CreateTestFactory();
// Act
var id = factory.CreateCircuitId();
var secret = factory.CreateCircuitId();
// Assert
Assert.NotNull(id);
Assert.NotNull(secret.Secret);
// This is the magic data protection header that validates its protected
Assert.StartsWith("CfDJ", id);
Assert.StartsWith("CfDJ", secret.Secret);
}
[Fact]
@ -32,13 +31,14 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var factory = TestCircuitIdFactory.CreateTestFactory();
// Act
var ids = Enumerable.Range(0, 100).Select(i => factory.CreateCircuitId()).ToArray();
var secrets = Enumerable.Range(0, 100).Select(i => factory.CreateCircuitId()).Select(s => s.Secret).ToArray();
// Assert
Assert.All(ids, id => Assert.NotNull(id));
Assert.Equal(100, ids.Distinct(StringComparer.Ordinal).Count());
Assert.All(secrets, secret => Assert.NotNull(secret));
Assert.Equal(100, secrets.Distinct(StringComparer.Ordinal).Count());
}
// Note that this test also verifies that the ID can be reproduced from the secret.
[Fact]
public void CircuitIds_Roundtrip()
{
@ -47,10 +47,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var id = factory.CreateCircuitId();
// Act
var isValid = factory.ValidateCircuitId(id);
var isValid = factory.TryParseCircuitId(id.Secret, out var parsed);
// Assert
Assert.True(isValid, "Failed to validate id");
Assert.Equal(id, parsed);
Assert.Equal(id.Secret, parsed.Secret);
Assert.Equal(id.Id, parsed.Id);
}
[Fact]
@ -60,7 +63,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
var factory = TestCircuitIdFactory.CreateTestFactory();
// Act
var isValid = factory.ValidateCircuitId("$%@&==");
var isValid = factory.TryParseCircuitId("$%@&==", out _);
// Assert
Assert.False(isValid, "Accepted an invalid payload");
@ -71,16 +74,16 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
{
// Arrange
var factory = TestCircuitIdFactory.CreateTestFactory();
var id = factory.CreateCircuitId();
var protectedBytes = Base64UrlTextEncoder.Decode(id);
var secret = factory.CreateCircuitId();
var protectedBytes = Base64UrlTextEncoder.Decode(secret.Secret);
for (int i = protectedBytes.Length - 10; i < protectedBytes.Length; i++)
{
protectedBytes[i] = 0;
}
var tamperedId = Base64UrlTextEncoder.Encode(protectedBytes);
var tampered = Base64UrlTextEncoder.Encode(protectedBytes);
// Act
var isValid = factory.ValidateCircuitId(tamperedId);
var isValid = factory.TryParseCircuitId(tampered, out _);
// Assert
Assert.False(isValid, "Accepted a tampered payload");

View File

@ -163,14 +163,14 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
// Arrange
var registry = CreateRegistry();
var circuitHost = TestCircuitHost.Create();
registry.DisconnectedCircuits.Set(circuitHost.CircuitId, circuitHost, new MemoryCacheEntryOptions { Size = 1 });
registry.DisconnectedCircuits.Set(circuitHost.CircuitId.Secret, circuitHost, new MemoryCacheEntryOptions { Size = 1 });
// Act
await registry.DisconnectAsync(circuitHost, circuitHost.Client.ConnectionId);
// Assert
Assert.Empty(registry.ConnectedCircuits.Values);
Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out _));
Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _));
}
[Fact]
@ -267,7 +267,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
Assert.Same(client, circuitHost.Client.Client);
Assert.Equal(newId, circuitHost.Client.ConnectionId);
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out _));
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _));
}
[Fact]
@ -297,7 +297,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
Assert.Same(client, circuitHost.Client.Client);
Assert.Equal(newId, circuitHost.Client.ConnectionId);
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out _));
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out _));
}
[Fact]
@ -322,9 +322,9 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
// Act
// Verify it's present in the dictionary.
Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out var _));
Assert.True(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out var _));
await Task.Run(() => tcs.Task.TimeoutAfter(TimeSpan.FromSeconds(10)));
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out var _));
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out var _));
}
[Fact]
@ -355,7 +355,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
Assert.True(registry.ConnectedCircuits.TryGetValue(circuitHost.CircuitId, out var cacheValue));
Assert.Same(circuitHost, cacheValue);
// Nothing should be disconnected.
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId, out var _));
Assert.False(registry.DisconnectedCircuits.TryGetValue(circuitHost.CircuitId.Secret, out var _));
}
private class TestCircuitRegistry : CircuitRegistry
@ -370,7 +370,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
public Action OnAfterEntryEvicted { get; set; }
protected override (CircuitHost, bool) ConnectCore(string circuitId, IClientProxy clientProxy, string connectionId)
protected override (CircuitHost, bool) ConnectCore(CircuitId circuitId, IClientProxy clientProxy, string connectionId)
{
if (BeforeConnect != null)
{

View File

@ -15,13 +15,13 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
{
internal class TestCircuitHost : CircuitHost
{
private TestCircuitHost(string circuitId, IServiceScope scope, CircuitOptions options, CircuitClientProxy client, RemoteRenderer renderer, IReadOnlyList<ComponentDescriptor> descriptors, RemoteJSRuntime jsRuntime, CircuitHandler[] circuitHandlers, ILogger logger)
private TestCircuitHost(CircuitId circuitId, IServiceScope scope, CircuitOptions options, CircuitClientProxy client, RemoteRenderer renderer, IReadOnlyList<ComponentDescriptor> descriptors, RemoteJSRuntime jsRuntime, CircuitHandler[] circuitHandlers, ILogger logger)
: base(circuitId, scope, options, client, renderer, descriptors, jsRuntime, circuitHandlers, logger)
{
}
public static CircuitHost Create(
string circuitId = null,
CircuitId? circuitId = null,
IServiceScope serviceScope = null,
RemoteRenderer remoteRenderer = null,
CircuitHandler[] handlers = null,
@ -44,7 +44,7 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits
handlers = handlers ?? Array.Empty<CircuitHandler>();
return new TestCircuitHost(
circuitId ?? Guid.NewGuid().ToString(),
circuitId is null ? new CircuitId(Guid.NewGuid().ToString(), Guid.NewGuid().ToString()) : circuitId.Value,
serviceScope,
new CircuitOptions(),
clientProxy,

View File

@ -68,9 +68,6 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
{
// Arrange & Act
Browser.Close();
// Set to null so that other tests in this class can create a new browser if necessary so
// that tests don't fail when running together.
await Task.WhenAny(Task.Delay(10000), GracefulDisconnectCompletionSource.Task);
// Assert

View File

@ -63,7 +63,7 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
private void LogMessages(WriteContext context)
{
var log = new LogMessage(context.LogLevel, context.Message, context.Exception);
var log = new LogMessage(context.LogLevel, context.EventId, context.Message, context.Exception);
Logs.Enqueue(log);
Output.WriteLine(log.ToString());
}
@ -309,10 +309,11 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
var actualError = Assert.Single(Errors);
Assert.Equal(expectedError, actualError);
Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information);
Assert.Contains(Logs, l => (l.LogLevel, l.Message, l.Exception?.Message) ==
(LogLevel.Debug,
$"Failed to complete render batch '1846' in circuit host '{Client.CircuitId}'.",
"Received an acknowledgement for batch with id '1846' when the last batch produced was '4'."));
var entry = Assert.Single(Logs, l => l.EventId.Name == "OnRenderCompletedFailed");
Assert.Equal(LogLevel.Debug, entry.LogLevel);
Assert.Matches("Failed to complete render batch '1846' in circuit host '.*'\\.", entry.Message);
Assert.Equal("Received an acknowledgement for batch with id '1846' when the last batch produced was '4'.", entry.Exception.Message);
}
[Fact]
@ -384,10 +385,10 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
var actualError = Assert.Single(Errors);
Assert.Equal(expectedError, actualError);
Assert.DoesNotContain(Logs, l => l.LogLevel > LogLevel.Information);
Assert.Contains(Logs, l =>
{
return (l.LogLevel, l.Message) == (LogLevel.Debug, $"Location change to 'http://example.com' in circuit '{Client.CircuitId}' failed.");
});
var entry = Assert.Single(Logs, l => l.EventId.Name == "LocationChangeFailed");
Assert.Equal(LogLevel.Debug, entry.LogLevel);
Assert.Matches("Location change to 'http://example.com' in circuit '.*' failed\\.", entry.Message);
}
[Fact]
@ -414,10 +415,10 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
// Assert
var actualError = Assert.Single(Errors);
Assert.Equal(expectedError, actualError);
Assert.Contains(Logs, l =>
{
return (l.LogLevel, l.Message) == (LogLevel.Error, $"Location change to '{new Uri(_serverFixture.RootUri,"/test")}' in circuit '{Client.CircuitId}' failed.");
});
var entry = Assert.Single(Logs, l => l.EventId.Name == "LocationChangeFailed");
Assert.Equal(LogLevel.Error, entry.LogLevel);
Assert.Matches($"Location change to '{new Uri(_serverFixture.RootUri, "/test")}' in circuit '.*' failed\\.", entry.Message);
}
[Theory]
@ -501,20 +502,22 @@ namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests
[DebuggerDisplay("{LogLevel.ToString(),nq} - {Message ?? \"null\",nq} - {Exception?.Message,nq}")]
private class LogMessage
{
public LogMessage(LogLevel logLevel, string message, Exception exception)
public LogMessage(LogLevel logLevel, EventId eventId, string message, Exception exception)
{
LogLevel = logLevel;
EventId = eventId;
Message = message;
Exception = exception;
}
public LogLevel LogLevel { get; set; }
public EventId EventId { get; set; }
public string Message { get; set; }
public Exception Exception { get; set; }
public override string ToString()
{
return $"{LogLevel}: {Message}{(Exception != null ? Environment.NewLine : "")}{Exception}";
return $"{LogLevel}: {EventId} {Message}{(Exception != null ? Environment.NewLine : "")}{Exception}";
}
}

View File

@ -2,7 +2,7 @@
"Logging": {
"IncludeScopes": false,
"LogLevel": {
"Microsoft.AspNetCore.Components": "Debug"
"Microsoft.AspNetCore.Components": "Trace"
},
"Debug": {
"LogLevel": {