diff --git a/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs b/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs index ce7b6a1d26..5440426852 100644 --- a/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs +++ b/src/Components/Server/ref/Microsoft.AspNetCore.Components.Server.netcoreapp3.0.cs @@ -35,6 +35,20 @@ namespace Microsoft.AspNetCore.Components.Server public System.TimeSpan JSInteropDefaultCallTimeout { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } public int MaxBufferedUnacknowledgedRenderBatches { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } } } + public abstract partial class RevalidatingServerAuthenticationStateProvider : Microsoft.AspNetCore.Components.Server.ServerAuthenticationStateProvider, System.IDisposable + { + public RevalidatingServerAuthenticationStateProvider(Microsoft.Extensions.Logging.ILoggerFactory loggerFactory) { } + protected abstract System.TimeSpan RevalidationInterval { get; } + protected virtual void Dispose(bool disposing) { } + void System.IDisposable.Dispose() { } + protected abstract System.Threading.Tasks.Task ValidateAuthenticationStateAsync(Microsoft.AspNetCore.Components.AuthenticationState authenticationState, System.Threading.CancellationToken cancellationToken); + } + public partial class ServerAuthenticationStateProvider : Microsoft.AspNetCore.Components.AuthenticationStateProvider, Microsoft.AspNetCore.Components.IHostEnvironmentAuthenticationStateProvider + { + public ServerAuthenticationStateProvider() { } + public override System.Threading.Tasks.Task GetAuthenticationStateAsync() { throw null; } + public void SetAuthenticationState(System.Threading.Tasks.Task authenticationStateTask) { } + } } namespace Microsoft.AspNetCore.Components.Server.Circuits { diff --git a/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs b/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs index df16c45c4e..f3463301b1 100644 --- a/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs +++ b/src/Components/Server/src/Circuits/DefaultCircuitFactory.cs @@ -58,13 +58,6 @@ namespace Microsoft.AspNetCore.Components.Server.Circuits jsRuntime.Initialize(client); componentContext.Initialize(client); - var authenticationStateProvider = scope.ServiceProvider.GetService() as IHostEnvironmentAuthenticationStateProvider; - if (authenticationStateProvider != null) - { - var authenticationState = new AuthenticationState(httpContext.User); // TODO: Get this from the hub connection context instead - authenticationStateProvider.SetAuthenticationState(Task.FromResult(authenticationState)); - } - var navigationManager = (RemoteNavigationManager)scope.ServiceProvider.GetRequiredService(); var navigationInterception = (RemoteNavigationInterception)scope.ServiceProvider.GetRequiredService(); if (client.Connected) diff --git a/src/Components/Server/src/Circuits/RevalidatingServerAuthenticationStateProvider.cs b/src/Components/Server/src/Circuits/RevalidatingServerAuthenticationStateProvider.cs new file mode 100644 index 0000000000..2895d6df94 --- /dev/null +++ b/src/Components/Server/src/Circuits/RevalidatingServerAuthenticationStateProvider.cs @@ -0,0 +1,119 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; + +namespace Microsoft.AspNetCore.Components.Server +{ + /// + /// A base class for services that receive an + /// authentication state from the host environment, and revalidate it at regular intervals. + /// + public abstract class RevalidatingServerAuthenticationStateProvider + : ServerAuthenticationStateProvider, IDisposable + { + private readonly ILogger _logger; + private CancellationTokenSource _loopCancellationTokenSource = new CancellationTokenSource(); + + /// + /// Constructs an instance of . + /// + /// A logger factory. + public RevalidatingServerAuthenticationStateProvider(ILoggerFactory loggerFactory) + { + if (loggerFactory is null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + + _logger = loggerFactory.CreateLogger(); + + // Whenever we receive notification of a new authentication state, cancel any + // existing revalidation loop and start a new one + AuthenticationStateChanged += authenticationStateTask => + { + _loopCancellationTokenSource?.Cancel(); + _loopCancellationTokenSource = new CancellationTokenSource(); + _ = RevalidationLoop(authenticationStateTask, _loopCancellationTokenSource.Token); + }; + } + + /// + /// Gets the interval between revalidation attempts. + /// + protected abstract TimeSpan RevalidationInterval { get; } + + /// + /// Determines whether the authentication state is still valid. + /// + /// The current . + /// A to observe while performing the operation. + /// A that resolves as true if the is still valid, or false if it is not. + protected abstract Task ValidateAuthenticationStateAsync(AuthenticationState authenticationState, CancellationToken cancellationToken); + + private async Task RevalidationLoop(Task authenticationStateTask, CancellationToken cancellationToken) + { + try + { + var authenticationState = await authenticationStateTask; + if (authenticationState.User.Identity.IsAuthenticated) + { + while (!cancellationToken.IsCancellationRequested) + { + bool isValid; + + try + { + await Task.Delay(RevalidationInterval, cancellationToken); + isValid = await ValidateAuthenticationStateAsync(authenticationState, cancellationToken); + } + catch (TaskCanceledException tce) + { + // If it was our cancellation token, then this revalidation loop gracefully completes + // Otherwise, treat it like any other failure + if (tce.CancellationToken == cancellationToken) + { + break; + } + + throw; + } + + if (!isValid) + { + ForceSignOut(); + break; + } + } + } + } + catch (Exception ex) + { + _logger.LogError(ex, "An error occurred while revalidating authentication state"); + ForceSignOut(); + } + } + + private void ForceSignOut() + { + var anonymousUser = new ClaimsPrincipal(new ClaimsIdentity()); + var anonymousState = new AuthenticationState(anonymousUser); + SetAuthenticationState(Task.FromResult(anonymousState)); + } + + void IDisposable.Dispose() + { + _loopCancellationTokenSource?.Cancel(); + Dispose(disposing: true); + } + + /// + protected virtual void Dispose(bool disposing) + { + } + } +} diff --git a/src/Components/Server/src/Circuits/ServerAuthenticationStateProvider.cs b/src/Components/Server/src/Circuits/ServerAuthenticationStateProvider.cs index bd1fecfa68..2708b902e9 100644 --- a/src/Components/Server/src/Circuits/ServerAuthenticationStateProvider.cs +++ b/src/Components/Server/src/Circuits/ServerAuthenticationStateProvider.cs @@ -4,19 +4,21 @@ using System; using System.Threading.Tasks; -namespace Microsoft.AspNetCore.Components.Server.Circuits +namespace Microsoft.AspNetCore.Components.Server { /// /// An intended for use in server-side Blazor. /// - internal class ServerAuthenticationStateProvider : AuthenticationStateProvider, IHostEnvironmentAuthenticationStateProvider + public class ServerAuthenticationStateProvider : AuthenticationStateProvider, IHostEnvironmentAuthenticationStateProvider { private Task _authenticationStateTask; + /// public override Task GetAuthenticationStateAsync() => _authenticationStateTask ?? throw new InvalidOperationException($"{nameof(GetAuthenticationStateAsync)} was called before {nameof(SetAuthenticationState)}."); + /// public void SetAuthenticationState(Task authenticationStateTask) { _authenticationStateTask = authenticationStateTask ?? throw new ArgumentNullException(nameof(authenticationStateTask)); diff --git a/src/Components/Server/test/Circuits/RevalidatingServerAuthenticationStateProvider.cs b/src/Components/Server/test/Circuits/RevalidatingServerAuthenticationStateProvider.cs new file mode 100644 index 0000000000..9f791c41eb --- /dev/null +++ b/src/Components/Server/test/Circuits/RevalidatingServerAuthenticationStateProvider.cs @@ -0,0 +1,256 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Components.Server; +using Microsoft.Extensions.Logging.Abstractions; +using Xunit; + +namespace Microsoft.AspNetCore.Components +{ + public class RevalidatingServerAuthenticationStateProviderTest + { + [Fact] + public void AcceptsAndReturnsAuthStateFromHost() + { + // Arrange + using var provider = new TestRevalidatingServerAuthenticationStateProvider(TimeSpan.MaxValue); + + // Act/Assert: Host can supply a value + var hostAuthStateTask = (new TaskCompletionSource()).Task; + provider.SetAuthenticationState(hostAuthStateTask); + Assert.Same(hostAuthStateTask, provider.GetAuthenticationStateAsync()); + + // Act/Assert: Host can supply a changed value + var hostAuthStateTask2 = (new TaskCompletionSource()).Task; + provider.SetAuthenticationState(hostAuthStateTask2); + Assert.Same(hostAuthStateTask2, provider.GetAuthenticationStateAsync()); + } + + [Fact] + public async Task IfValidateAuthenticationStateAsyncReturnsTrue_ContinuesRevalidating() + { + // Arrange + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.NextValidationResult = Task.FromResult(true); + var didNotifyAuthenticationStateChanged = false; + provider.AuthenticationStateChanged += _ => { didNotifyAuthenticationStateChanged = true; }; + + // Act + for (var i = 0; i < 10; i++) + { + await provider.NextValidateAuthenticationStateAsyncCall; + } + + // Assert + Assert.Equal(10, provider.RevalidationCallLog.Count); + Assert.False(didNotifyAuthenticationStateChanged); + Assert.Equal("test user", (await provider.GetAuthenticationStateAsync()).User.Identity.Name); + } + + [Fact] + public async Task IfValidateAuthenticationStateAsyncReturnsFalse_ForcesSignOut() + { + // Arrange + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.NextValidationResult = Task.FromResult(false); + + var newAuthStateNotificationTcs = new TaskCompletionSource>(); + provider.AuthenticationStateChanged += newStateTask => newAuthStateNotificationTcs.SetResult(newStateTask); + + // Act + var newAuthStateTask = await newAuthStateNotificationTcs.Task; + var newAuthState = await newAuthStateTask; + + // Assert + Assert.False(newAuthState.User.Identity.IsAuthenticated); + + // Assert: no longer revalidates + await Task.Delay(200); + Assert.Single(provider.RevalidationCallLog); + } + + [Fact] + public async Task IfValidateAuthenticationStateAsyncThrows_ForcesSignOut() + { + // Arrange + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.NextValidationResult = Task.FromException(new InvalidTimeZoneException()); + + var newAuthStateNotificationTcs = new TaskCompletionSource>(); + provider.AuthenticationStateChanged += newStateTask => newAuthStateNotificationTcs.SetResult(newStateTask); + + // Act + var newAuthStateTask = await newAuthStateNotificationTcs.Task; + var newAuthState = await newAuthStateTask; + + // Assert + Assert.False(newAuthState.User.Identity.IsAuthenticated); + + // Assert: no longer revalidates + await Task.Delay(200); + Assert.Single(provider.RevalidationCallLog); + } + + [Fact] + public async Task IfHostSuppliesNewAuthenticationState_RestartsRevalidationLoop() + { + // Arrange + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.NextValidationResult = Task.FromResult(true); + await provider.NextValidateAuthenticationStateAsyncCall; + Assert.Collection(provider.RevalidationCallLog, + call => Assert.Equal("test user", call.AuthenticationState.User.Identity.Name)); + + // Act/Assert 1: Can become signed out + // Doesn't revalidate unauthenticated states + provider.SetAuthenticationState(CreateAuthenticationStateTask(null)); + await Task.Delay(200); + Assert.Empty(provider.RevalidationCallLog.Skip(1)); + + // Act/Assert 2: Can become a different user; resumes revalidation + provider.SetAuthenticationState(CreateAuthenticationStateTask("different user")); + await provider.NextValidateAuthenticationStateAsyncCall; + Assert.Collection(provider.RevalidationCallLog.Skip(1), + call => Assert.Equal("different user", call.AuthenticationState.User.Identity.Name)); + } + + [Fact] + public async Task StopsRevalidatingAfterDisposal() + { + // Arrange + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.NextValidationResult = Task.FromResult(true); + + // Act + ((IDisposable)provider).Dispose(); + await Task.Delay(200); + + // Assert + Assert.Empty(provider.RevalidationCallLog); + } + + [Fact] + public async Task SuppliesCancellationTokenThatSignalsWhenRevalidationLoopIsBeingDiscarded() + { + // Arrange + var validationTcs = new TaskCompletionSource(); + var authenticationStateChangedCount = 0; + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.NextValidationResult = validationTcs.Task; + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.AuthenticationStateChanged += _ => { authenticationStateChangedCount++; }; + + // Act/Assert 1: token isn't cancelled initially + await provider.NextValidateAuthenticationStateAsyncCall; + var firstRevalidationCall = provider.RevalidationCallLog.Single(); + Assert.False(firstRevalidationCall.CancellationToken.IsCancellationRequested); + Assert.Equal(0, authenticationStateChangedCount); + + // Have the task throw a TCE to show this doesn't get treated as a failure + firstRevalidationCall.CancellationToken.Register(() => validationTcs.TrySetCanceled(firstRevalidationCall.CancellationToken)); + + // Act/Assert 2: token is cancelled when the loop is superseded + provider.NextValidationResult = Task.FromResult(true); + provider.SetAuthenticationState(CreateAuthenticationStateTask("different user")); + Assert.True(firstRevalidationCall.CancellationToken.IsCancellationRequested); + + // Since we asked for that operation to be cancelled, we don't treat it as a failure and + // don't force a logout + Assert.Equal(1, authenticationStateChangedCount); + Assert.Equal("different user", (await provider.GetAuthenticationStateAsync()).User.Identity.Name); + + // Subsequent revalidation can complete successfully + await provider.NextValidateAuthenticationStateAsyncCall; + Assert.Collection(provider.RevalidationCallLog.Skip(1), + call => Assert.Equal("different user", call.AuthenticationState.User.Identity.Name)); + } + + [Fact] + public async Task IfValidateAuthenticationStateAsyncReturnsUnrelatedCancelledTask_TreatAsFailure() + { + // Arrange + var validationTcs = new TaskCompletionSource(); + var authenticationStateChangedCount = 0; + using var provider = new TestRevalidatingServerAuthenticationStateProvider( + TimeSpan.FromMilliseconds(50)); + provider.NextValidationResult = validationTcs.Task; + provider.SetAuthenticationState(CreateAuthenticationStateTask("test user")); + provider.AuthenticationStateChanged += _ => { authenticationStateChangedCount++; }; + + // Be waiting for the first ValidateAuthenticationStateAsync to complete + await provider.NextValidateAuthenticationStateAsyncCall; + var firstRevalidationCall = provider.RevalidationCallLog.Single(); + Assert.Equal(0, authenticationStateChangedCount); + + // Act: ValidateAuthenticationStateAsync returns cancelled task, but the cancellation + // is unrelated to the CT we supplied + validationTcs.TrySetCanceled(new CancellationTokenSource().Token); + + // Assert: Since we didn't ask for that operation to be cancelled, this is treated as + // a failure to validate, so we force a logout + Assert.Equal(1, authenticationStateChangedCount); + var newAuthState = await provider.GetAuthenticationStateAsync(); + Assert.False(newAuthState.User.Identity.IsAuthenticated); + Assert.Null(newAuthState.User.Identity.Name); + } + + static Task CreateAuthenticationStateTask(string username) + { + var identity = !string.IsNullOrEmpty(username) + ? new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, username) }, "testauth") + : new ClaimsIdentity(); + var authenticationState = new AuthenticationState(new ClaimsPrincipal(identity)); + return Task.FromResult(authenticationState); + } + + class TestRevalidatingServerAuthenticationStateProvider : RevalidatingServerAuthenticationStateProvider + { + private readonly TimeSpan _revalidationInterval; + private TaskCompletionSource _nextValidateAuthenticationStateAsyncCallSource + = new TaskCompletionSource(); + + public TestRevalidatingServerAuthenticationStateProvider(TimeSpan revalidationInterval) + : base(NullLoggerFactory.Instance) + { + _revalidationInterval = revalidationInterval; + } + + public Task NextValidationResult { get; set; } + + public Task NextValidateAuthenticationStateAsyncCall + => _nextValidateAuthenticationStateAsyncCallSource.Task; + + public List<(AuthenticationState AuthenticationState, CancellationToken CancellationToken)> RevalidationCallLog { get; } + = new List<(AuthenticationState, CancellationToken)>(); + + protected override TimeSpan RevalidationInterval => _revalidationInterval; + + protected override Task ValidateAuthenticationStateAsync(AuthenticationState authenticationState, CancellationToken cancellationToken) + { + RevalidationCallLog.Add((authenticationState, cancellationToken)); + var result = NextValidationResult; + var prevCts = _nextValidateAuthenticationStateAsyncCallSource; + _nextValidateAuthenticationStateAsyncCallSource = new TaskCompletionSource(); + prevCts.SetResult(true); + return result; + } + } + } +} diff --git a/src/Components/test/E2ETest/ServerExecutionTests/InteropReliabilityTests.cs b/src/Components/test/E2ETest/ServerExecutionTests/InteropReliabilityTests.cs index cf1be22179..b4ccc6123e 100644 --- a/src/Components/test/E2ETest/ServerExecutionTests/InteropReliabilityTests.cs +++ b/src/Components/test/E2ETest/ServerExecutionTests/InteropReliabilityTests.cs @@ -10,6 +10,8 @@ using Ignitor; using Microsoft.AspNetCore.Components.E2ETest.Infrastructure.ServerFixtures; using Microsoft.AspNetCore.Components.Web; using Microsoft.AspNetCore.SignalR.Client; +using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; @@ -17,6 +19,7 @@ using Xunit; namespace Microsoft.AspNetCore.Components.E2ETest.ServerExecutionTests { + [Flaky("https://github.com/aspnet/AspNetCore/issues/12940", FlakyOn.All)] public class InteropReliabilityTests : IClassFixture { private static readonly TimeSpan DefaultLatencyTimeout = TimeSpan.FromMilliseconds(500); diff --git a/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Areas/Identity/RevalidatingAuthenticationStateProvider.cs b/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Areas/Identity/RevalidatingAuthenticationStateProvider.cs deleted file mode 100644 index 3895c488a9..0000000000 --- a/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Areas/Identity/RevalidatingAuthenticationStateProvider.cs +++ /dev/null @@ -1,98 +0,0 @@ -using System; -using System.Security.Claims; -using System.Threading; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Components; -using Microsoft.AspNetCore.Identity; -using Microsoft.Extensions.DependencyInjection; -using Microsoft.Extensions.Logging; - -namespace BlazorServerWeb_CSharp.Areas.Identity -{ - /// - /// An service that revalidates the - /// authentication state at regular intervals. If a signed-in user's security - /// stamp changes, this revalidation mechanism will sign the user out. - /// - /// The type encapsulating a user. - public class RevalidatingAuthenticationStateProvider - : AuthenticationStateProvider, IDisposable where TUser : class - { - private readonly static TimeSpan RevalidationInterval = TimeSpan.FromMinutes(30); - - private readonly CancellationTokenSource _loopCancellationTokenSource = new CancellationTokenSource(); - private readonly IServiceScopeFactory _scopeFactory; - private readonly ILogger _logger; - private Task _currentAuthenticationStateTask; - - public RevalidatingAuthenticationStateProvider( - IServiceScopeFactory scopeFactory, - SignInManager circuitScopeSignInManager, - ILogger> logger) - { - var initialUser = circuitScopeSignInManager.Context.User; - _currentAuthenticationStateTask = Task.FromResult(new AuthenticationState(initialUser)); - _scopeFactory = scopeFactory; - _logger = logger; - - if (initialUser.Identity.IsAuthenticated) - { - _ = RevalidationLoop(); - } - } - - public override Task GetAuthenticationStateAsync() - => _currentAuthenticationStateTask; - - private async Task RevalidationLoop() - { - var cancellationToken = _loopCancellationTokenSource.Token; - - while (!cancellationToken.IsCancellationRequested) - { - try - { - await Task.Delay(RevalidationInterval, cancellationToken); - } - catch (TaskCanceledException) - { - break; - } - - var isValid = await CheckIfAuthenticationStateIsValidAsync(); - if (!isValid) - { - // Force sign-out. Also stop the revalidation loop, because the user can - // only sign back in by starting a new connection. - var anonymousUser = new ClaimsPrincipal(new ClaimsIdentity()); - _currentAuthenticationStateTask = Task.FromResult(new AuthenticationState(anonymousUser)); - NotifyAuthenticationStateChanged(_currentAuthenticationStateTask); - _loopCancellationTokenSource.Cancel(); - } - } - } - - private async Task CheckIfAuthenticationStateIsValidAsync() - { - try - { - // Get the sign-in manager from a new scope to ensure it fetches fresh data - using (var scope = _scopeFactory.CreateScope()) - { - var signInManager = scope.ServiceProvider.GetRequiredService>(); - var authenticationState = await _currentAuthenticationStateTask; - var validatedUser = await signInManager.ValidateSecurityStampAsync(authenticationState.User); - return validatedUser != null; - } - } - catch (Exception ex) - { - _logger.LogError(ex, "An error occurred while revalidating authentication state"); - return false; - } - } - - void IDisposable.Dispose() - => _loopCancellationTokenSource.Cancel(); - } -} diff --git a/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Areas/Identity/RevalidatingIdentityAuthenticationStateProvider.cs b/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Areas/Identity/RevalidatingIdentityAuthenticationStateProvider.cs new file mode 100644 index 0000000000..dd4fb3e33e --- /dev/null +++ b/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Areas/Identity/RevalidatingIdentityAuthenticationStateProvider.cs @@ -0,0 +1,74 @@ +using System; +using System.Security.Claims; +using System.Threading; +using System.Threading.Tasks; +using Microsoft.AspNetCore.Components; +using Microsoft.AspNetCore.Components.Server; +using Microsoft.AspNetCore.Identity; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; + +namespace BlazorServerWeb_CSharp.Areas.Identity +{ + public class RevalidatingIdentityAuthenticationStateProvider + : RevalidatingServerAuthenticationStateProvider where TUser : class + { + private readonly IServiceScopeFactory _scopeFactory; + private readonly IdentityOptions _options; + + public RevalidatingIdentityAuthenticationStateProvider( + ILoggerFactory loggerFactory, + IServiceScopeFactory scopeFactory, + IOptions optionsAccessor) + : base(loggerFactory) + { + _scopeFactory = scopeFactory; + _options = optionsAccessor.Value; + } + + protected override TimeSpan RevalidationInterval => TimeSpan.FromMinutes(30); + + protected override async Task ValidateAuthenticationStateAsync( + AuthenticationState authenticationState, CancellationToken cancellationToken) + { + // Get the user manager from a new scope to ensure it fetches fresh data + var scope = _scopeFactory.CreateScope(); + try + { + var userManager = scope.ServiceProvider.GetRequiredService>(); + return await ValidateSecurityStampAsync(userManager, authenticationState.User); + } + finally + { + if (scope is IAsyncDisposable asyncDisposable) + { + await asyncDisposable.DisposeAsync(); + } + else + { + scope.Dispose(); + } + } + } + + private async Task ValidateSecurityStampAsync(UserManager userManager, ClaimsPrincipal principal) + { + var user = await userManager.GetUserAsync(principal); + if (user == null) + { + return false; + } + else if (!userManager.SupportsUserSecurityStamp) + { + return true; + } + else + { + var principalStamp = principal.FindFirstValue(_options.ClaimsIdentity.SecurityStampClaimType); + var userStamp = await userManager.GetSecurityStampAsync(user); + return principalStamp == userStamp; + } + } + } +} diff --git a/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Startup.cs b/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Startup.cs index 7840c30734..7a6742000f 100644 --- a/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Startup.cs +++ b/src/ProjectTemplates/Web.ProjectTemplates/content/BlazorServerWeb-CSharp/Startup.cs @@ -127,7 +127,7 @@ namespace BlazorServerWeb_CSharp services.AddRazorPages(); services.AddServerSideBlazor(); #if (IndividualLocalAuth) - services.AddScoped>(); + services.AddScoped>(); #endif services.AddSingleton(); } diff --git a/src/ProjectTemplates/test/template-baselines.json b/src/ProjectTemplates/test/template-baselines.json index 413df6c8e8..48b625bc37 100644 --- a/src/ProjectTemplates/test/template-baselines.json +++ b/src/ProjectTemplates/test/template-baselines.json @@ -905,7 +905,7 @@ "_Imports.razor", "Areas/Identity/Pages/Account/LogOut.cshtml", "Areas/Identity/Pages/Shared/_LoginPartial.cshtml", - "Areas/Identity/RevalidatingAuthenticationStateProvider.cs", + "Areas/Identity/RevalidatingIdentityAuthenticationStateProvider.cs", "Data/ApplicationDbContext.cs", "Data/WeatherForecast.cs", "Data/WeatherForecastService.cs",