// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Runtime.ExceptionServices;
using System.Security.Cryptography;
using System.Text.Encodings.Web;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;

namespace Microsoft.AspNetCore.Authentication
{
    /// <summary>
    /// Base class for handlers that authenticate against a remote identity provider. The provider's
    /// response is processed on <see cref="RemoteAuthenticationOptions.CallbackPath"/> and the resulting
    /// principal is signed in through <see cref="SignInScheme"/>.
    /// </summary>
    /// <typeparam name="TOptions">The options type used by the handler.</typeparam>
    public abstract class RemoteAuthenticationHandler<TOptions> : AuthenticationHandler<TOptions>, IAuthenticationRequestHandler
        where TOptions : RemoteAuthenticationOptions, new()
    {
        // Key under which the correlation id is stored in AuthenticationProperties.Items.
        private const string CorrelationProperty = ".xsrf";
        // Value written to the correlation cookie and expected back on validation.
        private const string CorrelationMarker = "N";
        // Key recording which scheme produced a ticket; checked in HandleAuthenticateAsync.
        private const string AuthSchemeKey = ".AuthScheme";
        // Error message used when the remote handler produced no usable result.
        private const string InvalidReturnStateMessage = "Invalid return state, unable to redirect.";

        private static readonly RandomNumberGenerator CryptoRandom = RandomNumberGenerator.Create();

        /// <summary>
        /// The scheme used to persist the identity once the remote flow succeeds
        /// (taken from <see cref="RemoteAuthenticationOptions.SignInScheme"/>).
        /// </summary>
        protected string SignInScheme => Options.SignInScheme;

        /// <summary>
        /// The handler calls methods on the events which give the application control at certain points where processing is occurring.
        /// If it is not provided a default instance is supplied which does nothing when the methods are called.
        /// </summary>
        protected new RemoteAuthenticationEvents Events
        {
            get { return (RemoteAuthenticationEvents)base.Events; }
            set { base.Events = value; }
        }

        /// <summary>
        /// Initializes the handler; all arguments are passed straight to the base <see cref="AuthenticationHandler{TOptions}"/>.
        /// </summary>
        protected RemoteAuthenticationHandler(IOptionsSnapshot<TOptions> options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock)
            : base(options, logger, encoder, clock)
        { }

        /// <summary>
        /// Creates the default <see cref="RemoteAuthenticationEvents"/> instance used when none is configured.
        /// </summary>
        protected override Task<object> CreateEventsAsync() => Task.FromResult<object>(new RemoteAuthenticationEvents());

        /// <summary>
        /// Returns true when the current request path equals <see cref="RemoteAuthenticationOptions.CallbackPath"/>.
        /// </summary>
        public virtual Task<bool> ShouldHandleRequestAsync() => Task.FromResult(Options.CallbackPath == Request.Path);

        /// <summary>
        /// Handles the remote callback. This method:
        /// <list type="bullet">
        /// <item><description>runs <see cref="HandleRemoteAuthenticateAsync"/>;</description></item>
        /// <item><description>raises the RemoteFailure or TicketReceived event;</description></item>
        /// <item><description>signs the principal in with <see cref="SignInScheme"/>;</description></item>
        /// <item><description>redirects to the ticket's return URI, or "/" when none is set.</description></item>
        /// </list>
        /// </summary>
        /// <returns>
        /// true if this handler handled the request; false to let request processing continue.
        /// </returns>
        public virtual async Task<bool> HandleRequestAsync()
        {
            if (!await ShouldHandleRequestAsync())
            {
                return false;
            }

            AuthenticationTicket ticket = null;
            Exception exception = null;
            try
            {
                var authResult = await HandleRemoteAuthenticateAsync();
                if (authResult == null)
                {
                    exception = new InvalidOperationException(InvalidReturnStateMessage);
                }
                else if (authResult.Handled)
                {
                    return true;
                }
                else if (authResult.Skipped || authResult.None)
                {
                    return false;
                }
                else if (!authResult.Succeeded)
                {
                    exception = authResult.Failure ?? new InvalidOperationException(InvalidReturnStateMessage);
                }
                else
                {
                    // Only read the ticket on success. Reading it unconditionally dereferenced a null
                    // authResult, and the resulting NullReferenceException replaced the error set above.
                    ticket = authResult.Ticket;
                }
            }
            catch (Exception ex)
            {
                exception = ex;
            }

            if (exception != null)
            {
                Logger.RemoteAuthenticationError(exception.Message);
                var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception);
                await Events.RemoteFailure(errorContext);

                if (errorContext.Result != null)
                {
                    if (errorContext.Result.Handled)
                    {
                        return true;
                    }
                    else if (errorContext.Result.Skipped)
                    {
                        return false;
                    }
                }

                // Rethrow the same exception object but keep its original stack trace.
                // A plain `throw exception;` would reset the trace of exceptions caught above.
                ExceptionDispatchInfo.Capture(exception).Throw();
            }

            // We have a ticket if we get here
            var ticketContext = new TicketReceivedContext(Context, Scheme, Options, ticket)
            {
                ReturnUri = ticket.Properties.RedirectUri
            };

            // REVIEW: is this safe or good?
            ticket.Properties.RedirectUri = null;

            // Mark which provider produced this identity so we can cross-check later in HandleAuthenticateAsync
            ticketContext.Properties.Items[AuthSchemeKey] = Scheme.Name;

            await Events.TicketReceived(ticketContext);

            if (ticketContext.Result != null)
            {
                if (ticketContext.Result.Handled)
                {
                    Logger.SigninHandled();
                    return true;
                }
                else if (ticketContext.Result.Skipped)
                {
                    Logger.SigninSkipped();
                    return false;
                }
            }

            await Context.SignInAsync(SignInScheme, ticketContext.Principal, ticketContext.Properties);

            // Default redirect path is the base path
            if (string.IsNullOrEmpty(ticketContext.ReturnUri))
            {
                ticketContext.ReturnUri = "/";
            }

            Response.Redirect(ticketContext.ReturnUri);
            return true;
        }

        /// <summary>
        /// Authenticate the user identity with the identity provider.
        /// </summary>
        /// <remarks>
        /// The method processes the request on the endpoint defined by CallbackPath.
        /// </remarks>
        protected abstract Task<HandleRequestResult> HandleRemoteAuthenticateAsync();

        /// <summary>
        /// Authenticates through <see cref="SignInScheme"/>. The result is accepted only if the ticket was
        /// marked as produced by this scheme; a failure from the sign-in scheme is returned as-is.
        /// </summary>
        protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
        {
            var result = await Context.AuthenticateAsync(SignInScheme);
            if (result != null)
            {
                if (result.Failure != null)
                {
                    return result;
                }

                // The SignInScheme may be shared with multiple providers, make sure this provider issued the identity.
                var ticket = result.Ticket;
                if (ticket != null && ticket.Principal != null && ticket.Properties != null
                    && ticket.Properties.Items.TryGetValue(AuthSchemeKey, out var authenticatedScheme)
                    && string.Equals(Scheme.Name, authenticatedScheme, StringComparison.Ordinal))
                {
                    return AuthenticateResult.Success(new AuthenticationTicket(ticket.Principal, ticket.Properties, Scheme.Name));
                }

                return AuthenticateResult.Fail("Not authenticated");
            }

            return AuthenticateResult.Fail("Remote authentication does not directly support AuthenticateAsync");
        }

        /// <summary>
        /// Forwards forbid to <see cref="SignInScheme"/>. The supplied <paramref name="properties"/> are not forwarded.
        /// </summary>
        protected override Task HandleForbiddenAsync(AuthenticationProperties properties) => Context.ForbidAsync(SignInScheme);

        /// <summary>
        /// Generates a random 32-byte correlation id, stores it (base64url-encoded) in
        /// <paramref name="properties"/>, and appends a correlation cookie carrying the marker value.
        /// </summary>
        /// <exception cref="ArgumentNullException"><paramref name="properties"/> is null.</exception>
        protected virtual void GenerateCorrelationId(AuthenticationProperties properties)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            var bytes = new byte[32];
            CryptoRandom.GetBytes(bytes);
            var correlationId = Base64UrlTextEncoder.Encode(bytes);

            var cookieOptions = Options.CorrelationCookie.Build(Context, Clock.UtcNow);

            properties.Items[CorrelationProperty] = correlationId;

            Response.Cookies.Append(GetCorrelationCookieName(correlationId), CorrelationMarker, cookieOptions);
        }

        /// <summary>
        /// Validates the correlation id stored in <paramref name="properties"/> against the request's correlation cookie.
        /// </summary>
        /// <remarks>
        /// The id is removed from <paramref name="properties"/> once read. The cookie, if present, is deleted
        /// before its value is compared.
        /// </remarks>
        /// <returns>
        /// true if the property exists and its cookie holds the expected marker value; otherwise false.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="properties"/> is null.</exception>
        protected virtual bool ValidateCorrelationId(AuthenticationProperties properties)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            if (!properties.Items.TryGetValue(CorrelationProperty, out string correlationId))
            {
                Logger.CorrelationPropertyNotFound(Options.CorrelationCookie.Name);
                return false;
            }

            properties.Items.Remove(CorrelationProperty);

            var cookieName = GetCorrelationCookieName(correlationId);

            var correlationCookie = Request.Cookies[cookieName];
            if (string.IsNullOrEmpty(correlationCookie))
            {
                Logger.CorrelationCookieNotFound(cookieName);
                return false;
            }

            var cookieOptions = Options.CorrelationCookie.Build(Context, Clock.UtcNow);

            Response.Cookies.Delete(cookieName, cookieOptions);

            if (!string.Equals(correlationCookie, CorrelationMarker, StringComparison.Ordinal))
            {
                Logger.UnexpectedCorrelationCookieValue(cookieName, correlationCookie);
                return false;
            }

            return true;
        }

        // Builds the correlation cookie name: configured cookie name + scheme name + "." + correlation id.
        // Shared by GenerateCorrelationId and ValidateCorrelationId so both always agree.
        private string GetCorrelationCookieName(string correlationId)
            => Options.CorrelationCookie.Name + Scheme.Name + "." + correlationId;
    }
}