// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Claims;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Components.Authorization;
using Microsoft.AspNetCore.Components.Server;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging.Abstractions;
using Xunit;

namespace Microsoft.AspNetCore.Components
{
    /// <summary>
    /// Tests for <see cref="RevalidatingServerAuthenticationStateProvider"/>, exercised through a
    /// test subclass that records every validation call and returns a configurable result.
    /// </summary>
    public class RevalidatingServerAuthenticationStateProviderTest
    {
        [Fact]
        public void AcceptsAndReturnsAuthStateFromHost()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(TimeSpan.MaxValue);

            // Act/Assert: Host can supply a value
            var hostAuthStateTask = (new TaskCompletionSource<AuthenticationState>()).Task;
            provider.SetAuthenticationState(hostAuthStateTask);
            Assert.Same(hostAuthStateTask, provider.GetAuthenticationStateAsync());

            // Act/Assert: Host can supply a changed value
            var hostAuthStateTask2 = (new TaskCompletionSource<AuthenticationState>()).Task;
            provider.SetAuthenticationState(hostAuthStateTask2);
            Assert.Same(hostAuthStateTask2, provider.GetAuthenticationStateAsync());
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncReturnsTrue_ContinuesRevalidating()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(true);
            var didNotifyAuthenticationStateChanged = false;
            provider.AuthenticationStateChanged += _ => { didNotifyAuthenticationStateChanged = true; };

            // Act
            for (var i = 0; i < 10; i++)
            {
                await provider.NextValidateAuthenticationStateAsyncCall;
            }

            // Assert
            Assert.Equal(10, provider.RevalidationCallLog.Count);
            Assert.False(didNotifyAuthenticationStateChanged);
            Assert.Equal("test user", (await provider.GetAuthenticationStateAsync()).User.Identity.Name);
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncReturnsFalse_ForcesSignOut()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(false);

            var newAuthStateNotificationTcs = new TaskCompletionSource<Task<AuthenticationState>>();
            provider.AuthenticationStateChanged += newStateTask => newAuthStateNotificationTcs.SetResult(newStateTask);

            // Act
            var newAuthStateTask = await newAuthStateNotificationTcs.Task;
            var newAuthState = await newAuthStateTask;

            // Assert
            Assert.False(newAuthState.User.Identity.IsAuthenticated);

            // Assert: no longer revalidates
            await Task.Delay(200);
            Assert.Single(provider.RevalidationCallLog);
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncThrows_ForcesSignOut()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromException<bool>(new InvalidTimeZoneException());

            var newAuthStateNotificationTcs = new TaskCompletionSource<Task<AuthenticationState>>();
            provider.AuthenticationStateChanged += newStateTask => newAuthStateNotificationTcs.SetResult(newStateTask);

            // Act
            var newAuthStateTask = await newAuthStateNotificationTcs.Task;
            var newAuthState = await newAuthStateTask;

            // Assert
            Assert.False(newAuthState.User.Identity.IsAuthenticated);

            // Assert: no longer revalidates
            await Task.Delay(200);
            Assert.Single(provider.RevalidationCallLog);
        }

        [Fact]
        public async Task IfHostSuppliesNewAuthenticationState_RestartsRevalidationLoop()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(true);
            await provider.NextValidateAuthenticationStateAsyncCall;
            Assert.Collection(provider.RevalidationCallLog,
                call => Assert.Equal("test user", call.AuthenticationState.User.Identity.Name));

            // Act/Assert 1: Can become signed out
            // Doesn't revalidate unauthenticated states
            provider.SetAuthenticationState(CreateAuthenticationStateTask(null));
            await Task.Delay(200);
            Assert.Empty(provider.RevalidationCallLog.Skip(1));

            // Act/Assert 2: Can become a different user; resumes revalidation
            provider.SetAuthenticationState(CreateAuthenticationStateTask("different user"));
            await provider.NextValidateAuthenticationStateAsyncCall;
            Assert.Collection(provider.RevalidationCallLog.Skip(1),
                call => Assert.Equal("different user", call.AuthenticationState.User.Identity.Name));
        }

        [Fact]
        public async Task StopsRevalidatingAfterDisposal()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(true);

            // Act
            ((IDisposable)provider).Dispose();
            await Task.Delay(200);

            // Assert
            Assert.Empty(provider.RevalidationCallLog);
        }

        [Fact]
        public async Task SuppliesCancellationTokenThatSignalsWhenRevalidationLoopIsBeingDiscarded()
        {
            // Arrange
            var validationTcs = new TaskCompletionSource<bool>();
            var authenticationStateChangedCount = 0;
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.NextValidationResult = validationTcs.Task;
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.AuthenticationStateChanged += _ => { authenticationStateChangedCount++; };

            // Act/Assert 1: token isn't cancelled initially
            await provider.NextValidateAuthenticationStateAsyncCall;
            var firstRevalidationCall = provider.RevalidationCallLog.Single();
            Assert.False(firstRevalidationCall.CancellationToken.IsCancellationRequested);
            Assert.Equal(0, authenticationStateChangedCount);

            // Have the task throw a TCE to show this doesn't get treated as a failure
            firstRevalidationCall.CancellationToken.Register(() => validationTcs.TrySetCanceled(firstRevalidationCall.CancellationToken));

            // Act/Assert 2: token is cancelled when the loop is superseded
            provider.NextValidationResult = Task.FromResult(true);
            provider.SetAuthenticationState(CreateAuthenticationStateTask("different user"));
            Assert.True(firstRevalidationCall.CancellationToken.IsCancellationRequested);

            // Since we asked for that operation to be cancelled, we don't treat it as a failure and
            // don't force a logout
            Assert.Equal(1, authenticationStateChangedCount);
            Assert.Equal("different user", (await provider.GetAuthenticationStateAsync()).User.Identity.Name);

            // Subsequent revalidation can complete successfully
            await provider.NextValidateAuthenticationStateAsyncCall;
            Assert.Collection(provider.RevalidationCallLog.Skip(1),
                call => Assert.Equal("different user", call.AuthenticationState.User.Identity.Name));
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncReturnsUnrelatedCancelledTask_TreatAsFailure()
        {
            // Arrange
            var validationTcs = new TaskCompletionSource<bool>();
            var incrementExecuted = new TaskCompletionSource<bool>();
            var authenticationStateChangedCount = 0;
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.NextValidationResult = validationTcs.Task;
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.AuthenticationStateChanged += _ =>
            {
                authenticationStateChangedCount++;
                incrementExecuted.TrySetResult(true);
            };

            // Be waiting for the first ValidateAuthenticationStateAsync to complete
            await provider.NextValidateAuthenticationStateAsyncCall;
            var firstRevalidationCall = provider.RevalidationCallLog.Single();
            Assert.Equal(0, authenticationStateChangedCount);

            // Act: ValidateAuthenticationStateAsync returns canceled task, but the cancellation
            // is unrelated to the CT we supplied
            validationTcs.TrySetCanceled(new CancellationTokenSource().Token);

            // Assert: Since we didn't ask for that operation to be canceled, this is treated as
            // a failure to validate, so we force a logout
            await incrementExecuted.Task.TimeoutAfter(TimeSpan.FromSeconds(5));
            Assert.Equal(1, authenticationStateChangedCount);
            var newAuthState = await provider.GetAuthenticationStateAsync();
            Assert.False(newAuthState.User.Identity.IsAuthenticated);
            Assert.Null(newAuthState.User.Identity.Name);
        }

        /// <summary>
        /// Builds an already-completed authentication state task for use as host-supplied state.
        /// </summary>
        /// <param name="username">
        /// Value for the <see cref="ClaimTypes.Name"/> claim. When null or empty, the identity is
        /// created with no claims and no authentication type, which the tests above use to
        /// represent a signed-out user.
        /// </param>
        /// <returns>A completed task wrapping the constructed <see cref="AuthenticationState"/>.</returns>
        private static Task<AuthenticationState> CreateAuthenticationStateTask(string username)
        {
            var identity = !string.IsNullOrEmpty(username)
                ? new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, username) }, "testauth")
                : new ClaimsIdentity();
            var authenticationState = new AuthenticationState(new ClaimsPrincipal(identity));
            return Task.FromResult(authenticationState);
        }

        /// <summary>
        /// Test double that logs each call to <see cref="ValidateAuthenticationStateAsync"/>,
        /// returns whatever <see cref="NextValidationResult"/> currently holds, and exposes a task
        /// that completes when the next validation call happens.
        /// </summary>
        private sealed class TestRevalidatingServerAuthenticationStateProvider : RevalidatingServerAuthenticationStateProvider
        {
            private readonly TimeSpan _revalidationInterval;
            private TaskCompletionSource<bool> _nextValidateAuthenticationStateAsyncCallSource
                = new TaskCompletionSource<bool>();

            public TestRevalidatingServerAuthenticationStateProvider(TimeSpan revalidationInterval)
                : base(NullLoggerFactory.Instance)
            {
                _revalidationInterval = revalidationInterval;
            }

            /// <summary>
            /// The task handed back by the next call to <see cref="ValidateAuthenticationStateAsync"/>.
            /// </summary>
            public Task<bool> NextValidationResult { get; set; }

            /// <summary>
            /// Completes the next time <see cref="ValidateAuthenticationStateAsync"/> is invoked.
            /// A new task is substituted on every invocation.
            /// </summary>
            public Task NextValidateAuthenticationStateAsyncCall
                => _nextValidateAuthenticationStateAsyncCallSource.Task;

            /// <summary>
            /// Arguments of every <see cref="ValidateAuthenticationStateAsync"/> call, in call order.
            /// </summary>
            public List<(AuthenticationState AuthenticationState, CancellationToken CancellationToken)> RevalidationCallLog { get; }
                = new List<(AuthenticationState, CancellationToken)>();

            protected override TimeSpan RevalidationInterval => _revalidationInterval;

            /// <summary>
            /// Records the call, signals <see cref="NextValidateAuthenticationStateAsyncCall"/>,
            /// and returns the value of <see cref="NextValidationResult"/> read at entry.
            /// </summary>
            protected override Task<bool> ValidateAuthenticationStateAsync(AuthenticationState authenticationState, CancellationToken cancellationToken)
            {
                RevalidationCallLog.Add((authenticationState, cancellationToken));

                // Capture the result before signalling, so a waiter that reassigns
                // NextValidationResult does not affect what this call returns.
                var result = NextValidationResult;

                // Swap in a fresh source before completing the previous one: continuations may
                // run inline from SetResult, and they must observe a not-yet-completed task if
                // they read NextValidateAuthenticationStateAsyncCall again.
                var previousCallSource = _nextValidateAuthenticationStateAsyncCallSource;
                _nextValidateAuthenticationStateAsyncCallSource = new TaskCompletionSource<bool>();
                previousCallSource.SetResult(true);

                return result;
            }
        }
    }
}