// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Security.Cryptography;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNet.Http;
using Microsoft.AspNet.Http.Authentication;
using Microsoft.AspNet.Http.Extensions;
using Microsoft.AspNet.Http.Features.Authentication;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Primitives;
using Newtonsoft.Json.Linq;

namespace Microsoft.AspNet.Authentication.OAuth
{
    /// <summary>
    /// Authentication handler for the OAuth 2.0 authorization-code flow: it builds the
    /// redirect to the authorization endpoint, validates the callback (state and
    /// correlation cookie), exchanges the returned code for tokens over the back channel,
    /// and produces an authentication ticket.
    /// </summary>
    /// <typeparam name="TOptions">The options type, constrained to <see cref="OAuthOptions"/>.</typeparam>
    public class OAuthHandler<TOptions> : RemoteAuthenticationHandler<TOptions> where TOptions : OAuthOptions
    {
        // Shared generator used to produce the random correlation id.
        private static readonly RandomNumberGenerator CryptoRandom = RandomNumberGenerator.Create();

        /// <summary>
        /// Creates a handler that uses <paramref name="backchannel"/> for token endpoint requests.
        /// </summary>
        /// <param name="backchannel">The HTTP client used to call the token endpoint.</param>
        public OAuthHandler(HttpClient backchannel)
        {
            Backchannel = backchannel;
        }

        /// <summary>
        /// Gets the HTTP client used for server-to-server calls (e.g. the token exchange).
        /// </summary>
        protected HttpClient Backchannel { get; private set; }

        /// <summary>
        /// Processes the callback request: reports provider errors found in the query string,
        /// unprotects and validates the <c>state</c>, exchanges the <c>code</c> for tokens and
        /// builds the resulting ticket.
        /// </summary>
        /// <returns>
        /// A failed <see cref="AuthenticateResult"/> describing the first problem encountered,
        /// or a successful result wrapping the ticket from <see cref="CreateTicketAsync"/>.
        /// </returns>
        protected override async Task<AuthenticateResult> HandleRemoteAuthenticateAsync()
        {
            var query = Request.Query;

            // The provider reports errors through "error", "error_description" and "error_uri".
            var error = query["error"];
            if (!StringValues.IsNullOrEmpty(error))
            {
                var failureMessage = new StringBuilder();
                failureMessage.Append(error);

                var errorDescription = query["error_description"];
                if (!StringValues.IsNullOrEmpty(errorDescription))
                {
                    failureMessage.Append(";Description=").Append(errorDescription);
                }

                var errorUri = query["error_uri"];
                if (!StringValues.IsNullOrEmpty(errorUri))
                {
                    failureMessage.Append(";Uri=").Append(errorUri);
                }

                return AuthenticateResult.Fail(failureMessage.ToString());
            }

            var code = query["code"];
            var state = query["state"];

            AuthenticationProperties properties = Options.StateDataFormat.Unprotect(state);
            if (properties == null)
            {
                return AuthenticateResult.Fail("The oauth state was missing or invalid.");
            }

            // OAuth2 10.12 CSRF
            if (!ValidateCorrelationId(properties))
            {
                return AuthenticateResult.Fail("Correlation failed.");
            }

            if (StringValues.IsNullOrEmpty(code))
            {
                return AuthenticateResult.Fail("Code was not found.");
            }

            var tokens = await ExchangeCodeAsync(code, BuildRedirectUri(Options.CallbackPath));

            if (tokens.Error != null)
            {
                return AuthenticateResult.Fail(tokens.Error);
            }

            if (string.IsNullOrEmpty(tokens.AccessToken))
            {
                return AuthenticateResult.Fail("Failed to retrieve access token.");
            }

            var identity = new ClaimsIdentity(Options.ClaimsIssuer);

            if (Options.SaveTokensAsClaims)
            {
                identity.AddClaim(new Claim("access_token", tokens.AccessToken,
                    ClaimValueTypes.String, Options.ClaimsIssuer));

                if (!string.IsNullOrEmpty(tokens.RefreshToken))
                {
                    identity.AddClaim(new Claim("refresh_token", tokens.RefreshToken,
                        ClaimValueTypes.String, Options.ClaimsIssuer));
                }

                if (!string.IsNullOrEmpty(tokens.TokenType))
                {
                    identity.AddClaim(new Claim("token_type", tokens.TokenType,
                        ClaimValueTypes.String, Options.ClaimsIssuer));
                }

                if (!string.IsNullOrEmpty(tokens.ExpiresIn))
                {
                    identity.AddClaim(new Claim("expires_in", tokens.ExpiresIn,
                        ClaimValueTypes.String, Options.ClaimsIssuer));
                }
            }

            // NOTE(review): CreateTicketAsync returns null when the CreatingTicket event clears
            // the principal/identity; confirm AuthenticateResult.Success accepts a null ticket.
            return AuthenticateResult.Success(await CreateTicketAsync(identity, properties, tokens));
        }

        /// <summary>
        /// Posts the authorization code to <c>Options.TokenEndpoint</c> as a form-encoded
        /// <c>authorization_code</c> grant and parses the JSON response.
        /// </summary>
        /// <param name="code">The authorization code received on the callback.</param>
        /// <param name="redirectUri">The redirect URI sent as the <c>redirect_uri</c> parameter.</param>
        /// <returns>
        /// A successful <see cref="OAuthTokenResponse"/> built from the JSON payload, or a failed
        /// one carrying an exception that describes the non-success HTTP response.
        /// </returns>
        protected virtual async Task<OAuthTokenResponse> ExchangeCodeAsync(string code, string redirectUri)
        {
            var tokenRequestParameters = new Dictionary<string, string>()
            {
                { "client_id", Options.ClientId },
                { "redirect_uri", redirectUri },
                { "client_secret", Options.ClientSecret },
                { "code", code },
                { "grant_type", "authorization_code" },
            };

            // Disposing the request also disposes its content; the response is disposed only
            // after its body has been fully read on both branches below.
            using (var requestMessage = new HttpRequestMessage(HttpMethod.Post, Options.TokenEndpoint))
            {
                requestMessage.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
                requestMessage.Content = new FormUrlEncodedContent(tokenRequestParameters);

                using (var response = await Backchannel.SendAsync(requestMessage, Context.RequestAborted))
                {
                    if (response.IsSuccessStatusCode)
                    {
                        var payload = JObject.Parse(await response.Content.ReadAsStringAsync());
                        return OAuthTokenResponse.Success(payload);
                    }

                    var error = "OAuth token endpoint failure: " + await Display(response);
                    return OAuthTokenResponse.Failed(new Exception(error));
                }
            }
        }

        /// <summary>
        /// Formats the status code, headers and body of <paramref name="response"/> into a
        /// single diagnostic string.
        /// </summary>
        /// <param name="response">The HTTP response to describe.</param>
        /// <returns>A string of the form <c>Status: ...;Headers: ...;Body: ...;</c>.</returns>
        private static async Task<string> Display(HttpResponseMessage response)
        {
            var body = await response.Content.ReadAsStringAsync();

            var output = new StringBuilder();
            output.Append("Status: ").Append(response.StatusCode).Append(';');
            output.Append("Headers: ").Append(response.Headers.ToString()).Append(';');
            output.Append("Body: ").Append(body).Append(';');
            return output.ToString();
        }

        /// <summary>
        /// Raises the <c>CreatingTicket</c> event and builds the authentication ticket from the
        /// (possibly modified) event context.
        /// </summary>
        /// <param name="identity">The identity created for the authenticated user.</param>
        /// <param name="properties">The authentication properties recovered from the state.</param>
        /// <param name="tokens">The token response returned by the token endpoint.</param>
        /// <returns>
        /// The ticket for the context's principal, or <c>null</c> when the event handler left
        /// the context without a principal or identity.
        /// </returns>
        protected virtual async Task<AuthenticationTicket> CreateTicketAsync(
            ClaimsIdentity identity,
            AuthenticationProperties properties,
            OAuthTokenResponse tokens)
        {
            var context = new OAuthCreatingTicketContext(
                new ClaimsPrincipal(identity), properties, Context, Options, Backchannel, tokens);

            await Options.Events.CreatingTicket(context);

            if (context.Principal?.Identity == null)
            {
                return null;
            }

            return new AuthenticationTicket(context.Principal, context.Properties, Options.AuthenticationScheme);
        }

        /// <summary>
        /// Handles a challenge by generating a correlation id, building the authorization URL
        /// and raising the <c>RedirectToAuthorizationEndpoint</c> event.
        /// </summary>
        /// <param name="context">The challenge context whose properties seed the request state.</param>
        /// <returns>Always <c>true</c>.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
        protected override async Task<bool> HandleUnauthorizedAsync(ChallengeContext context)
        {
            if (context == null)
            {
                throw new ArgumentNullException(nameof(context));
            }

            var properties = new AuthenticationProperties(context.Properties);

            // Default to returning to the current request when no redirect target was supplied.
            if (string.IsNullOrEmpty(properties.RedirectUri))
            {
                properties.RedirectUri = CurrentUri;
            }

            // OAuth2 10.12 CSRF
            GenerateCorrelationId(properties);

            var authorizationEndpoint = BuildChallengeUrl(properties, BuildRedirectUri(Options.CallbackPath));

            var redirectContext = new OAuthRedirectToAuthorizationContext(
                Context, Options, properties, authorizationEndpoint);
            await Options.Events.RedirectToAuthorizationEndpoint(redirectContext);

            return true;
        }

        /// <summary>
        /// Builds the authorization endpoint URL, including the client id, scope, response type,
        /// redirect URI and the protected state.
        /// </summary>
        /// <param name="properties">The properties to protect and send as the <c>state</c> parameter.</param>
        /// <param name="redirectUri">The callback URI sent as the <c>redirect_uri</c> parameter.</param>
        /// <returns><c>Options.AuthorizationEndpoint</c> followed by the generated query string.</returns>
        protected virtual string BuildChallengeUrl(AuthenticationProperties properties, string redirectUri)
        {
            var scope = FormatScope();
            var state = Options.StateDataFormat.Protect(properties);

            var queryBuilder = new QueryBuilder()
            {
                { "client_id", Options.ClientId },
                { "scope", scope },
                { "response_type", "code" },
                { "redirect_uri", redirectUri },
                { "state", state },
            };

            return Options.AuthorizationEndpoint + queryBuilder.ToString();
        }

        /// <summary>
        /// Joins the configured scopes into a single space-separated string.
        /// </summary>
        /// <returns>The entries of <c>Options.Scope</c> separated by single spaces.</returns>
        protected virtual string FormatScope()
        {
            // OAuth2 3.3 space separated
            return string.Join(" ", Options.Scope);
        }

        /// <summary>
        /// Generates a random correlation id and stores it both in <paramref name="properties"/>
        /// and in a response cookie, so the callback can verify that the two match.
        /// </summary>
        /// <param name="properties">The properties that receive the correlation id.</param>
        /// <exception cref="ArgumentNullException"><paramref name="properties"/> is <c>null</c>.</exception>
        protected void GenerateCorrelationId(AuthenticationProperties properties)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            var correlationKey = Constants.CorrelationPrefix + Options.AuthenticationScheme;

            var nonceBytes = new byte[32];
            CryptoRandom.GetBytes(nonceBytes);
            var correlationId = Base64UrlTextEncoder.Encode(nonceBytes);

            var cookieOptions = new CookieOptions
            {
                HttpOnly = true,
                Secure = Request.IsHttps
            };

            properties.Items[correlationKey] = correlationId;

            Response.Cookies.Append(correlationKey, correlationId, cookieOptions);
        }

        /// <summary>
        /// Checks that the correlation cookie on the request matches the correlation id stored
        /// in <paramref name="properties"/>. The cookie is deleted once found, and the property
        /// entry is removed once read; every failure is logged as a warning.
        /// </summary>
        /// <param name="properties">The properties recovered from the protected state.</param>
        /// <returns>
        /// <c>true</c> when both values are present and ordinally equal; otherwise <c>false</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException"><paramref name="properties"/> is <c>null</c>.</exception>
        protected bool ValidateCorrelationId(AuthenticationProperties properties)
        {
            if (properties == null)
            {
                throw new ArgumentNullException(nameof(properties));
            }

            var correlationKey = Constants.CorrelationPrefix + Options.AuthenticationScheme;

            var correlationCookie = Request.Cookies[correlationKey];
            if (string.IsNullOrEmpty(correlationCookie))
            {
                Logger.LogWarning("{0} cookie not found.", correlationKey);
                return false;
            }

            // The cookie is removed as soon as it has been read, whether or not it matches.
            var cookieOptions = new CookieOptions
            {
                HttpOnly = true,
                Secure = Request.IsHttps
            };
            Response.Cookies.Delete(correlationKey, cookieOptions);

            string correlationExtra;
            if (!properties.Items.TryGetValue(correlationKey, out correlationExtra))
            {
                Logger.LogWarning("{0} state property not found.", correlationKey);
                return false;
            }

            properties.Items.Remove(correlationKey);

            if (!string.Equals(correlationCookie, correlationExtra, StringComparison.Ordinal))
            {
                Logger.LogWarning("{0} correlation cookie and state property mismatch.", correlationKey);
                return false;
            }

            return true;
        }
    }
}