aspnetcore/src/Microsoft.AspNetCore.Authen.../RemoteAuthenticationHandler.cs

234 lines
8.6 KiB
C#

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Security.Cryptography;
using System.Text.Encodings.Web;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
namespace Microsoft.AspNetCore.Authentication
{
/// <summary>
/// Base class for authentication handlers that process a callback request on
/// <c>Options.CallbackPath</c>, obtain a ticket from <see cref="HandleRemoteAuthenticateAsync"/>,
/// and sign the resulting principal in with <see cref="SignInScheme"/>.
/// </summary>
/// <typeparam name="TOptions">The options type; must derive from <c>RemoteAuthenticationOptions</c>.</typeparam>
public abstract class RemoteAuthenticationHandler<TOptions> : AuthenticationHandler<TOptions>, IAuthenticationRequestHandler
    where TOptions : RemoteAuthenticationOptions, new()
{
    // Properties.Items key holding the correlation id; written by GenerateCorrelationId
    // and read (then removed) by ValidateCorrelationId.
    private const string CorrelationProperty = ".xsrf";

    // Value stored in the correlation cookie; ValidateCorrelationId requires an exact (ordinal) match.
    private const string CorrelationMarker = "N";

    // Properties.Items key recording the name of the scheme that produced the ticket.
    private const string AuthSchemeKey = ".AuthScheme";

    // Shared cryptographic RNG used to produce correlation ids.
    private static readonly RandomNumberGenerator CryptoRandom = RandomNumberGenerator.Create();

    /// <summary>
    /// The scheme used for sign-in, authenticate and forbid delegation; taken from <c>Options.SignInScheme</c>.
    /// </summary>
    protected string SignInScheme => Options.SignInScheme;

    /// <summary>
    /// The handler calls methods on the events which give the application control at certain points where processing is occurring.
    /// If it is not provided a default instance is supplied which does nothing when the methods are called.
    /// </summary>
    protected new RemoteAuthenticationEvents Events
    {
        get { return (RemoteAuthenticationEvents)base.Events; }
        set { base.Events = value; }
    }

    /// <summary>
    /// Initializes the handler by forwarding all arguments to the base <c>AuthenticationHandler</c> constructor.
    /// </summary>
    /// <param name="options">The options snapshot for <typeparamref name="TOptions"/>.</param>
    /// <param name="logger">The logger factory.</param>
    /// <param name="encoder">The URL encoder.</param>
    /// <param name="clock">The system clock.</param>
    protected RemoteAuthenticationHandler(IOptionsSnapshot<TOptions> options, ILoggerFactory logger, UrlEncoder encoder, ISystemClock clock)
        : base(options, logger, encoder, clock) { }

    /// <summary>
    /// Creates the default events object, a new <c>RemoteAuthenticationEvents</c>.
    /// </summary>
    /// <returns>A completed task wrapping the new events instance.</returns>
    protected override Task<object> CreateEventsAsync()
        => Task.FromResult<object>(new RemoteAuthenticationEvents());

    /// <summary>
    /// Decides whether the current request should be processed by <see cref="HandleRequestAsync"/>.
    /// </summary>
    /// <returns><c>true</c> when <c>Options.CallbackPath</c> equals <c>Request.Path</c>; otherwise <c>false</c>.</returns>
    public virtual Task<bool> ShouldHandleRequestAsync()
        => Task.FromResult(Options.CallbackPath == Request.Path);

    /// <summary>
    /// Processes the callback request: runs <see cref="HandleRemoteAuthenticateAsync"/>, raises the
    /// <c>RemoteFailure</c> or <c>TicketReceived</c> event, signs the principal in with
    /// <see cref="SignInScheme"/> and redirects to the return URI (or "/" when none was supplied).
    /// </summary>
    /// <returns>
    /// <c>true</c> when the remote result or an event result is marked handled, or after the sign-in redirect is issued;
    /// <c>false</c> when <see cref="ShouldHandleRequestAsync"/> declines the request, or the remote result
    /// or an event result is marked skipped (or the remote result is <c>None</c>).
    /// </returns>
    /// <exception cref="Exception">
    /// The failure from remote authentication is rethrown when the <c>RemoteFailure</c> event does not
    /// mark its result as handled or skipped.
    /// </exception>
    public virtual async Task<bool> HandleRequestAsync()
    {
        if (!await ShouldHandleRequestAsync())
        {
            return false;
        }

        AuthenticationTicket ticket = null;
        Exception exception = null;
        try
        {
            var authResult = await HandleRemoteAuthenticateAsync();
            if (authResult == null)
            {
                exception = new InvalidOperationException("Invalid return state, unable to redirect.");
            }
            else if (authResult.Handled)
            {
                return true;
            }
            else if (authResult.Skipped || authResult.None)
            {
                return false;
            }
            else if (!authResult.Succeeded)
            {
                exception = authResult.Failure ??
                            new InvalidOperationException("Invalid return state, unable to redirect.");
            }

            // A null result was already recorded as a failure above. It must not be dereferenced here:
            // the NullReferenceException would be caught below and replace the intended
            // InvalidOperationException that is reported to RemoteFailure and rethrown.
            ticket = authResult?.Ticket;
        }
        catch (Exception ex)
        {
            exception = ex;
        }

        if (exception != null)
        {
            Logger.RemoteAuthenticationError(exception.Message);
            var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception);
            await Events.RemoteFailure(errorContext);

            if (errorContext.Result != null)
            {
                if (errorContext.Result.Handled)
                {
                    return true;
                }
                else if (errorContext.Result.Skipped)
                {
                    return false;
                }
            }

            // The event did not take over the response, so surface the failure to the caller.
            throw exception;
        }

        // We have a ticket if we get here
        var ticketContext = new TicketReceivedContext(Context, Scheme, Options, ticket)
        {
            ReturnUri = ticket.Properties.RedirectUri
        };

        // REVIEW: is this safe or good?
        // The redirect target now lives on ticketContext.ReturnUri; clear it from the properties passed to sign-in.
        ticket.Properties.RedirectUri = null;

        // Mark which provider produced this identity so we can cross-check later in HandleAuthenticateAsync
        ticketContext.Properties.Items[AuthSchemeKey] = Scheme.Name;

        await Events.TicketReceived(ticketContext);

        if (ticketContext.Result != null)
        {
            if (ticketContext.Result.Handled)
            {
                Logger.SigninHandled();
                return true;
            }
            else if (ticketContext.Result.Skipped)
            {
                Logger.SigninSkipped();
                return false;
            }
        }

        await Context.SignInAsync(SignInScheme, ticketContext.Principal, ticketContext.Properties);

        // Default redirect path is the base path
        if (string.IsNullOrEmpty(ticketContext.ReturnUri))
        {
            ticketContext.ReturnUri = "/";
        }

        Response.Redirect(ticketContext.ReturnUri);
        return true;
    }

    /// <summary>
    /// Authenticate the user identity with the identity provider.
    ///
    /// The method processes the request on the endpoint defined by CallbackPath.
    /// </summary>
    /// <returns>The result consumed by <see cref="HandleRequestAsync"/>.</returns>
    protected abstract Task<HandleRequestResult> HandleRemoteAuthenticateAsync();

    /// <summary>
    /// Authenticates by delegating to <see cref="SignInScheme"/> and accepting the returned ticket only
    /// when its <c>.AuthScheme</c> item equals this scheme's name (ordinal comparison).
    /// </summary>
    /// <returns>
    /// The delegated result unchanged when it carries a failure; a success result with a ticket re-issued
    /// under this scheme's name when the marker matches; otherwise a failed result.
    /// </returns>
    protected override async Task<AuthenticateResult> HandleAuthenticateAsync()
    {
        var result = await Context.AuthenticateAsync(SignInScheme);
        if (result != null)
        {
            if (result.Failure != null)
            {
                return result;
            }

            // The SignInScheme may be shared with multiple providers, make sure this provider issued the identity.
            var ticket = result.Ticket;
            if (ticket != null && ticket.Principal != null && ticket.Properties != null
                && ticket.Properties.Items.TryGetValue(AuthSchemeKey, out string authenticatedScheme)
                && string.Equals(Scheme.Name, authenticatedScheme, StringComparison.Ordinal))
            {
                return AuthenticateResult.Success(new AuthenticationTicket(ticket.Principal,
                    ticket.Properties, Scheme.Name));
            }

            return AuthenticateResult.Fail("Not authenticated");
        }

        return AuthenticateResult.Fail("Remote authentication does not directly support AuthenticateAsync");
    }

    /// <summary>
    /// Handles a forbid request by delegating to <see cref="SignInScheme"/>.
    /// </summary>
    /// <param name="properties">Not used by this implementation; the forbid call is made without properties.</param>
    /// <returns>The task returned by <c>Context.ForbidAsync</c>.</returns>
    protected override Task HandleForbiddenAsync(AuthenticationProperties properties)
        => Context.ForbidAsync(SignInScheme);

    /// <summary>
    /// Generates a random correlation id, stores it in <paramref name="properties"/> under the
    /// <c>.xsrf</c> key, and appends a response cookie whose name embeds the scheme name and the id.
    /// </summary>
    /// <param name="properties">The authentication properties that receive the correlation id.</param>
    /// <exception cref="ArgumentNullException"><paramref name="properties"/> is <c>null</c>.</exception>
    protected virtual void GenerateCorrelationId(AuthenticationProperties properties)
    {
        if (properties == null)
        {
            throw new ArgumentNullException(nameof(properties));
        }

        // 32 random bytes, encoded with Base64UrlTextEncoder so the id can be embedded in the cookie name.
        var bytes = new byte[32];
        CryptoRandom.GetBytes(bytes);
        var correlationId = Base64UrlTextEncoder.Encode(bytes);

        var cookieOptions = Options.CorrelationCookie.Build(Context, Clock.UtcNow);

        properties.Items[CorrelationProperty] = correlationId;

        // Must be built the same way as in ValidateCorrelationId so the cookie can be found on the callback.
        var cookieName = Options.CorrelationCookie.Name + Scheme.Name + "." + correlationId;

        Response.Cookies.Append(cookieName, CorrelationMarker, cookieOptions);
    }

    /// <summary>
    /// Validates that the correlation id stored in <paramref name="properties"/> has a matching request
    /// cookie carrying the expected marker value. The id is removed from <paramref name="properties"/>
    /// and, when the cookie is present, a cookie deletion is written to the response.
    /// </summary>
    /// <param name="properties">The authentication properties expected to contain the correlation id.</param>
    /// <returns>
    /// <c>true</c> when the id is present, the cookie exists and its value equals the marker;
    /// otherwise <c>false</c> (each failure is logged).
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="properties"/> is <c>null</c>.</exception>
    protected virtual bool ValidateCorrelationId(AuthenticationProperties properties)
    {
        if (properties == null)
        {
            throw new ArgumentNullException(nameof(properties));
        }

        if (!properties.Items.TryGetValue(CorrelationProperty, out string correlationId))
        {
            Logger.CorrelationPropertyNotFound(Options.CorrelationCookie.Name);
            return false;
        }

        // Remove the id from the properties once it has been read.
        properties.Items.Remove(CorrelationProperty);

        var cookieName = Options.CorrelationCookie.Name + Scheme.Name + "." + correlationId;

        var correlationCookie = Request.Cookies[cookieName];
        if (string.IsNullOrEmpty(correlationCookie))
        {
            Logger.CorrelationCookieNotFound(cookieName);
            return false;
        }

        // Delete the cookie before checking its value, so it is cleared even when validation fails below.
        var cookieOptions = Options.CorrelationCookie.Build(Context, Clock.UtcNow);
        Response.Cookies.Delete(cookieName, cookieOptions);

        if (!string.Equals(correlationCookie, CorrelationMarker, StringComparison.Ordinal))
        {
            Logger.UnexpectedCorrelationCookieValue(cookieName, correlationCookie);
            return false;
        }

        return true;
    }
}
}