265 lines
12 KiB
C#
265 lines
12 KiB
C#
// Copyright (c) .NET Foundation. All rights reserved.
|
|
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.Linq;
|
|
using System.Security.Claims;
|
|
using System.Threading;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.AspNetCore.Components.Authorization;
|
|
using Microsoft.AspNetCore.Components.Server;
|
|
using Microsoft.AspNetCore.Testing;
|
|
using Microsoft.Extensions.Logging.Abstractions;
|
|
using Xunit;
|
|
|
|
namespace Microsoft.AspNetCore.Components
{
    /// <summary>
    /// Tests for <see cref="RevalidatingServerAuthenticationStateProvider"/>. Each test drives a
    /// <see cref="TestRevalidatingServerAuthenticationStateProvider"/>, which records every
    /// revalidation call and returns a validation result chosen by the test.
    /// </summary>
    public class RevalidatingServerAuthenticationStateProviderTest
    {
        // Upper bound on how long a test waits for a provider-driven notification, so that a
        // regression fails the test instead of hanging the whole test run.
        private static readonly TimeSpan NotificationTimeout = TimeSpan.FromSeconds(5);

        [Fact]
        public void AcceptsAndReturnsAuthStateFromHost()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(TimeSpan.MaxValue);

            // Act/Assert: Host can supply a value. The supplied task is never completed, and the
            // provider must hand back that very same instance.
            var hostAuthStateTask = (new TaskCompletionSource<AuthenticationState>()).Task;
            provider.SetAuthenticationState(hostAuthStateTask);
            Assert.Same(hostAuthStateTask, provider.GetAuthenticationStateAsync());

            // Act/Assert: Host can supply a changed value
            var hostAuthStateTask2 = (new TaskCompletionSource<AuthenticationState>()).Task;
            provider.SetAuthenticationState(hostAuthStateTask2);
            Assert.Same(hostAuthStateTask2, provider.GetAuthenticationStateAsync());
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncReturnsTrue_ContinuesRevalidating()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(true);
            var didNotifyAuthenticationStateChanged = false;
            provider.AuthenticationStateChanged += _ => { didNotifyAuthenticationStateChanged = true; };

            // Act: wait for ten consecutive revalidation calls
            for (var i = 0; i < 10; i++)
            {
                await provider.NextValidateAuthenticationStateAsyncCall;
            }

            // Assert: every revalidation succeeded, so the user is unchanged and no state
            // change was announced
            Assert.Equal(10, provider.RevalidationCallLog.Count);
            Assert.False(didNotifyAuthenticationStateChanged);
            Assert.Equal("test user", (await provider.GetAuthenticationStateAsync()).User.Identity.Name);
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncReturnsFalse_ForcesSignOut()
        {
            await AssertValidationResultForcesSignOutAsync(Task.FromResult(false));
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncThrows_ForcesSignOut()
        {
            await AssertValidationResultForcesSignOutAsync(
                Task.FromException<bool>(new InvalidTimeZoneException()));
        }

        /// <summary>
        /// Shared body of the sign-out tests. It starts revalidating an authenticated user with
        /// <paramref name="validationResult"/> as the validation outcome and waits for the next
        /// state-change notification. It then asserts that the new state is unauthenticated and
        /// that only a single revalidation call was ever made.
        /// </summary>
        /// <param name="validationResult">Task returned from every ValidateAuthenticationStateAsync call.</param>
        private static async Task AssertValidationResultForcesSignOutAsync(Task<bool> validationResult)
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = validationResult;

            // TrySetResult (not SetResult): an unexpected second notification must not throw
            // InvalidOperationException from inside the provider's event invocation.
            var newAuthStateNotificationTcs = new TaskCompletionSource<Task<AuthenticationState>>();
            provider.AuthenticationStateChanged += newStateTask => newAuthStateNotificationTcs.TrySetResult(newStateTask);

            // Act: bounded wait, so a missing notification fails the test rather than hanging it
            await newAuthStateNotificationTcs.Task.TimeoutAfter(NotificationTimeout);
            var newAuthStateTask = await newAuthStateNotificationTcs.Task;
            var newAuthState = await newAuthStateTask;

            // Assert
            Assert.False(newAuthState.User.Identity.IsAuthenticated);

            // Assert: no longer revalidates
            await Task.Delay(200);
            Assert.Single(provider.RevalidationCallLog);
        }

        [Fact]
        public async Task IfHostSuppliesNewAuthenticationState_RestartsRevalidationLoop()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(true);
            await provider.NextValidateAuthenticationStateAsyncCall;
            Assert.Collection(provider.RevalidationCallLog,
                call => Assert.Equal("test user", call.AuthenticationState.User.Identity.Name));

            // Act/Assert 1: Can become signed out
            // Doesn't revalidate unauthenticated states
            provider.SetAuthenticationState(CreateAuthenticationStateTask(null));
            await Task.Delay(200);
            Assert.Empty(provider.RevalidationCallLog.Skip(1));

            // Act/Assert 2: Can become a different user; resumes revalidation
            provider.SetAuthenticationState(CreateAuthenticationStateTask("different user"));
            await provider.NextValidateAuthenticationStateAsyncCall;
            Assert.Collection(provider.RevalidationCallLog.Skip(1),
                call => Assert.Equal("different user", call.AuthenticationState.User.Identity.Name));
        }

        [Fact]
        public async Task StopsRevalidatingAfterDisposal()
        {
            // Arrange
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.NextValidationResult = Task.FromResult(true);

            // Act: dispose immediately, before the 50ms revalidation interval is expected to
            // elapse. The 'using' declaration above disposes the provider a second time at the end.
            ((IDisposable)provider).Dispose();
            await Task.Delay(200);

            // Assert
            Assert.Empty(provider.RevalidationCallLog);
        }

        [Fact]
        public async Task SuppliesCancellationTokenThatSignalsWhenRevalidationLoopIsBeingDiscarded()
        {
            // Arrange
            var validationTcs = new TaskCompletionSource<bool>();
            var authenticationStateChangedCount = 0;
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.NextValidationResult = validationTcs.Task;
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.AuthenticationStateChanged += _ => { authenticationStateChangedCount++; };

            // Act/Assert 1: token isn't cancelled initially
            await provider.NextValidateAuthenticationStateAsyncCall;
            var firstRevalidationCall = provider.RevalidationCallLog.Single();
            Assert.False(firstRevalidationCall.CancellationToken.IsCancellationRequested);
            Assert.Equal(0, authenticationStateChangedCount);

            // Have the task throw a TCE to show this doesn't get treated as a failure. The TCE
            // carries the very token the provider supplied, which is what marks it as "ours".
            firstRevalidationCall.CancellationToken.Register(() => validationTcs.TrySetCanceled(firstRevalidationCall.CancellationToken));

            // Act/Assert 2: token is cancelled when the loop is superseded
            provider.NextValidationResult = Task.FromResult(true);
            provider.SetAuthenticationState(CreateAuthenticationStateTask("different user"));
            Assert.True(firstRevalidationCall.CancellationToken.IsCancellationRequested);

            // Since we asked for that operation to be cancelled, we don't treat it as a failure and
            // don't force a logout. The single notification is the host's own state change.
            Assert.Equal(1, authenticationStateChangedCount);
            Assert.Equal("different user", (await provider.GetAuthenticationStateAsync()).User.Identity.Name);

            // Subsequent revalidation can complete successfully
            await provider.NextValidateAuthenticationStateAsyncCall;
            Assert.Collection(provider.RevalidationCallLog.Skip(1),
                call => Assert.Equal("different user", call.AuthenticationState.User.Identity.Name));
        }

        [Fact]
        public async Task IfValidateAuthenticationStateAsyncReturnsUnrelatedCancelledTask_TreatAsFailure()
        {
            // Arrange
            var validationTcs = new TaskCompletionSource<bool>();
            var incrementExecuted = new TaskCompletionSource<bool>();
            var authenticationStateChangedCount = 0;
            using var provider = new TestRevalidatingServerAuthenticationStateProvider(
                TimeSpan.FromMilliseconds(50));
            provider.NextValidationResult = validationTcs.Task;
            provider.SetAuthenticationState(CreateAuthenticationStateTask("test user"));
            provider.AuthenticationStateChanged += _ =>
            {
                authenticationStateChangedCount++;
                incrementExecuted.TrySetResult(true);
            };

            // Be waiting for the first ValidateAuthenticationStateAsync to complete
            await provider.NextValidateAuthenticationStateAsyncCall;
            var firstRevalidationCall = provider.RevalidationCallLog.Single();
            Assert.Equal(0, authenticationStateChangedCount);

            // Act: ValidateAuthenticationStateAsync returns canceled task, but the cancellation
            // is unrelated to the CT we supplied
            using var unrelatedCts = new CancellationTokenSource();
            validationTcs.TrySetCanceled(unrelatedCts.Token);

            // Assert: Since we didn't ask for that operation to be canceled, this is treated as
            // a failure to validate, so we force a logout
            await incrementExecuted.Task.TimeoutAfter(NotificationTimeout);
            Assert.Equal(1, authenticationStateChangedCount);
            var newAuthState = await provider.GetAuthenticationStateAsync();
            Assert.False(newAuthState.User.Identity.IsAuthenticated);
            Assert.Null(newAuthState.User.Identity.Name);
        }

        // Returns an already-completed task holding either of two states.
        // - When username is non-empty: a principal with a "testauth" identity carrying a Name claim.
        // - When username is null or empty: a principal with a bare ClaimsIdentity (no authentication type).
        static Task<AuthenticationState> CreateAuthenticationStateTask(string username)
        {
            var identity = !string.IsNullOrEmpty(username)
                ? new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, username) }, "testauth")
                : new ClaimsIdentity();
            var authenticationState = new AuthenticationState(new ClaimsPrincipal(identity));
            return Task.FromResult(authenticationState);
        }

        /// <summary>
        /// Test double that logs every revalidation call and returns
        /// <see cref="NextValidationResult"/> from each one.
        /// </summary>
        class TestRevalidatingServerAuthenticationStateProvider : RevalidatingServerAuthenticationStateProvider
        {
            private readonly TimeSpan _revalidationInterval;

            // Completed, then replaced, each time ValidateAuthenticationStateAsync runs.
            // NOTE: created without TaskCreationOptions.RunContinuationsAsynchronously, so a test
            // awaiting it may resume inline inside SetResult. The replacement source is installed
            // *before* SetResult, so such a test always observes the source for the next call.
            private TaskCompletionSource<object> _nextValidateAuthenticationStateAsyncCallSource
                = new TaskCompletionSource<object>();

            public TestRevalidatingServerAuthenticationStateProvider(TimeSpan revalidationInterval)
                : base(NullLoggerFactory.Instance)
            {
                _revalidationInterval = revalidationInterval;
            }

            // Returned from every ValidateAuthenticationStateAsync call until the test changes it.
            public Task<bool> NextValidationResult { get; set; }

            // Completes when the next ValidateAuthenticationStateAsync call happens.
            public Task NextValidateAuthenticationStateAsyncCall
                => _nextValidateAuthenticationStateAsyncCallSource.Task;

            // Arguments of every ValidateAuthenticationStateAsync call, in call order.
            public List<(AuthenticationState AuthenticationState, CancellationToken CancellationToken)> RevalidationCallLog { get; }
                = new List<(AuthenticationState, CancellationToken)>();

            protected override TimeSpan RevalidationInterval => _revalidationInterval;

            protected override Task<bool> ValidateAuthenticationStateAsync(AuthenticationState authenticationState, CancellationToken cancellationToken)
            {
                RevalidationCallLog.Add((authenticationState, cancellationToken));
                var result = NextValidationResult;

                // Swap in a fresh source first, then signal the waiter (see field comment).
                var completedCallSource = _nextValidateAuthenticationStateAsyncCallSource;
                _nextValidateAuthenticationStateAsyncCallSource = new TaskCompletionSource<object>();
                completedCallSource.SetResult(true);

                return result;
            }
        }
    }
}
|