// aspnetcore/src/Microsoft.AspNetCore.Authen.../OAuthHandler.cs
// (file-viewer metadata: 228 lines, 8.7 KiB, C#)

// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Globalization;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Authentication;
using Microsoft.AspNetCore.Http.Extensions;
using Microsoft.AspNetCore.Http.Features.Authentication;
using Microsoft.Extensions.Primitives;
using Newtonsoft.Json.Linq;
using System.Threading;
namespace Microsoft.AspNetCore.Authentication.OAuth
{
/// <summary>
/// Remote authentication handler for the OAuth 2.0 authorization code flow.
/// On challenge it builds an authorization-endpoint URL (with protected state) and raises the
/// redirect event; on callback it validates state/correlation, exchanges the authorization code
/// at the token endpoint over <see cref="Backchannel"/>, optionally stores the tokens, and builds
/// an <see cref="AuthenticationTicket"/>.
/// </summary>
/// <typeparam name="TOptions">Handler options; must derive from <see cref="OAuthOptions"/>.</typeparam>
public class OAuthHandler<TOptions> : RemoteAuthenticationHandler<TOptions> where TOptions : OAuthOptions
{
    /// <summary>
    /// Creates the handler.
    /// </summary>
    /// <param name="backchannel">
    /// HTTP client used for the token-endpoint request; it is also handed to the
    /// <c>CreatingTicket</c> event context.
    /// </param>
    public OAuthHandler(HttpClient backchannel)
    {
        Backchannel = backchannel;
    }

    /// <summary>
    /// HTTP client used for server-to-server calls made by this handler.
    /// </summary>
    protected HttpClient Backchannel { get; private set; }

    /// <summary>
    /// Handles the callback from the authorization server.
    /// </summary>
    /// <returns>
    /// A failed result when the provider returned an <c>error</c>, the state is missing or invalid,
    /// correlation fails, no <c>code</c> is present, the token exchange fails or yields no access
    /// token, or ticket creation returns <c>null</c>; otherwise a successful result with the ticket.
    /// </returns>
    protected override async Task<AuthenticateResult> HandleRemoteAuthenticateAsync()
    {
        AuthenticationProperties properties = null;
        var query = Request.Query;

        // The provider reports failures via error / error_description / error_uri query parameters.
        var error = query["error"];
        if (!StringValues.IsNullOrEmpty(error))
        {
            var failureMessage = new StringBuilder();
            failureMessage.Append(error);
            var errorDescription = query["error_description"];
            if (!StringValues.IsNullOrEmpty(errorDescription))
            {
                failureMessage.Append(";Description=").Append(errorDescription);
            }
            var errorUri = query["error_uri"];
            if (!StringValues.IsNullOrEmpty(errorUri))
            {
                failureMessage.Append(";Uri=").Append(errorUri);
            }

            return AuthenticateResult.Fail(failureMessage.ToString());
        }

        var code = query["code"];
        var state = query["state"];

        // The state parameter round-trips the AuthenticationProperties protected in BuildChallengeUrl.
        properties = Options.StateDataFormat.Unprotect(state);
        if (properties == null)
        {
            return AuthenticateResult.Fail("The oauth state was missing or invalid.");
        }

        // OAuth2 10.12 CSRF
        if (!ValidateCorrelationId(properties))
        {
            return AuthenticateResult.Fail("Correlation failed.");
        }

        if (StringValues.IsNullOrEmpty(code))
        {
            return AuthenticateResult.Fail("Code was not found.");
        }

        // The redirect_uri sent here is built from the same CallbackPath used in the challenge.
        var tokens = await ExchangeCodeAsync(code, BuildRedirectUri(Options.CallbackPath));

        if (tokens.Error != null)
        {
            return AuthenticateResult.Fail(tokens.Error);
        }

        if (string.IsNullOrEmpty(tokens.AccessToken))
        {
            return AuthenticateResult.Fail("Failed to retrieve access token.");
        }

        var identity = new ClaimsIdentity(Options.ClaimsIssuer);

        if (Options.SaveTokens)
        {
            var authTokens = new List<AuthenticationToken>();

            authTokens.Add(new AuthenticationToken { Name = "access_token", Value = tokens.AccessToken });
            if (!string.IsNullOrEmpty(tokens.RefreshToken))
            {
                authTokens.Add(new AuthenticationToken { Name = "refresh_token", Value = tokens.RefreshToken });
            }

            if (!string.IsNullOrEmpty(tokens.TokenType))
            {
                authTokens.Add(new AuthenticationToken { Name = "token_type", Value = tokens.TokenType });
            }

            if (!string.IsNullOrEmpty(tokens.ExpiresIn))
            {
                int value;
                // expires_in is a relative lifetime in seconds; convert it to an absolute instant.
                // A non-integer value is ignored rather than failing authentication.
                if (int.TryParse(tokens.ExpiresIn, NumberStyles.Integer, CultureInfo.InvariantCulture, out value))
                {
                    // https://www.w3.org/TR/xmlschema-2/#dateTime
                    // https://msdn.microsoft.com/en-us/library/az4se3k1(v=vs.110).aspx
                    var expiresAt = Options.SystemClock.UtcNow + TimeSpan.FromSeconds(value);
                    authTokens.Add(new AuthenticationToken
                    {
                        Name = "expires_at",
                        Value = expiresAt.ToString("o", CultureInfo.InvariantCulture)
                    });
                }
            }

            properties.StoreTokens(authTokens);
        }

        var ticket = await CreateTicketAsync(identity, properties, tokens);
        if (ticket != null)
        {
            return AuthenticateResult.Success(ticket);
        }
        else
        {
            return AuthenticateResult.Fail("Failed to retrieve user information from remote server.");
        }
    }

    /// <summary>
    /// Exchanges an authorization code for tokens by POSTing a form-encoded
    /// <c>authorization_code</c> grant to <c>Options.TokenEndpoint</c>.
    /// </summary>
    /// <param name="code">The authorization code received on the callback.</param>
    /// <param name="redirectUri">The redirect URI; must match the one sent in the challenge.</param>
    /// <returns>
    /// <c>OAuthTokenResponse.Success</c> with the parsed JSON body on a 2xx status; otherwise
    /// <c>OAuthTokenResponse.Failed</c> wrapping an exception that describes status, headers and body.
    /// </returns>
    protected virtual async Task<OAuthTokenResponse> ExchangeCodeAsync(string code, string redirectUri)
    {
        var tokenRequestParameters = new Dictionary<string, string>()
        {
            { "client_id", Options.ClientId },
            { "redirect_uri", redirectUri },
            { "client_secret", Options.ClientSecret },
            { "code", code },
            { "grant_type", "authorization_code" },
        };

        // Both the request (which owns its form content) and the response are disposable; dispose
        // them once the body has been read so connections and buffers are released promptly
        // instead of waiting for finalization. The parsed payload does not depend on the response.
        using (var requestMessage = new HttpRequestMessage(HttpMethod.Post, Options.TokenEndpoint))
        {
            requestMessage.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
            requestMessage.Content = new FormUrlEncodedContent(tokenRequestParameters);

            using (var response = await Backchannel.SendAsync(requestMessage, Context.RequestAborted))
            {
                if (response.IsSuccessStatusCode)
                {
                    var payload = JObject.Parse(await response.Content.ReadAsStringAsync());
                    return OAuthTokenResponse.Success(payload);
                }
                else
                {
                    var error = "OAuth token endpoint failure: " + await Display(response);
                    return OAuthTokenResponse.Failed(new Exception(error));
                }
            }
        }
    }

    /// <summary>
    /// Formats a response's status, headers and body into a single diagnostic string.
    /// </summary>
    private static async Task<string> Display(HttpResponseMessage response)
    {
        var output = new StringBuilder();
        output.Append("Status: " + response.StatusCode + ";");
        output.Append("Headers: " + response.Headers.ToString() + ";");
        output.Append("Body: " + await response.Content.ReadAsStringAsync() + ";");
        return output.ToString();
    }

    /// <summary>
    /// Builds the authentication ticket and raises <c>Options.Events.CreatingTicket</c>, which may
    /// enrich or replace it.
    /// </summary>
    /// <returns>The ticket held by the event context after the event runs (may be <c>null</c>).</returns>
    protected virtual async Task<AuthenticationTicket> CreateTicketAsync(ClaimsIdentity identity, AuthenticationProperties properties, OAuthTokenResponse tokens)
    {
        var ticket = new AuthenticationTicket(new ClaimsPrincipal(identity), properties, Options.AuthenticationScheme);
        var context = new OAuthCreatingTicketContext(ticket, Context, Options, Backchannel, tokens);
        await Options.Events.CreatingTicket(context);
        return context.Ticket;
    }

    /// <summary>
    /// Starts the authorization flow. Defaults the post-login redirect to the current URI, generates
    /// a correlation id, builds the authorization-endpoint URL and raises
    /// <c>Options.Events.RedirectToAuthorizationEndpoint</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
    /// <returns>Always <c>true</c>.</returns>
    protected override async Task<bool> HandleUnauthorizedAsync(ChallengeContext context)
    {
        if (context == null)
        {
            throw new ArgumentNullException(nameof(context));
        }

        var properties = new AuthenticationProperties(context.Properties);
        if (string.IsNullOrEmpty(properties.RedirectUri))
        {
            properties.RedirectUri = CurrentUri;
        }

        // OAuth2 10.12 CSRF
        GenerateCorrelationId(properties);

        var authorizationEndpoint = BuildChallengeUrl(properties, BuildRedirectUri(Options.CallbackPath));
        var redirectContext = new OAuthRedirectToAuthorizationContext(
            Context, Options,
            properties, authorizationEndpoint);
        await Options.Events.RedirectToAuthorizationEndpoint(redirectContext);
        return true;
    }

    /// <summary>
    /// Builds the authorization-endpoint URL with <c>client_id</c>, <c>scope</c>,
    /// <c>response_type=code</c>, <c>redirect_uri</c> and the protected <c>state</c>.
    /// </summary>
    /// <param name="properties">Properties to protect into the <c>state</c> parameter.</param>
    /// <param name="redirectUri">The callback URI sent as <c>redirect_uri</c>.</param>
    protected virtual string BuildChallengeUrl(AuthenticationProperties properties, string redirectUri)
    {
        var scope = FormatScope();

        var state = Options.StateDataFormat.Protect(properties);
        var queryBuilder = new QueryBuilder()
        {
            { "client_id", Options.ClientId },
            { "scope", scope },
            { "response_type", "code" },
            { "redirect_uri", redirectUri },
            { "state", state },
        };
        return Options.AuthorizationEndpoint + queryBuilder.ToString();
    }

    /// <summary>
    /// Formats <c>Options.Scope</c> for the <c>scope</c> query parameter.
    /// </summary>
    protected virtual string FormatScope()
    {
        // OAuth2 3.3 space separated
        return string.Join(" ", Options.Scope);
    }
}
}