diff --git a/src/Microsoft.AspNetCore.Authentication.Facebook/FacebookHandler.cs b/src/Microsoft.AspNetCore.Authentication.Facebook/FacebookHandler.cs index 0df42597dd..4deec4df7d 100644 --- a/src/Microsoft.AspNetCore.Authentication.Facebook/FacebookHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.Facebook/FacebookHandler.cs @@ -38,8 +38,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook var response = await Backchannel.GetAsync(endpoint, Context.RequestAborted); if (!response.IsSuccessStatusCode) { - var errorMessage = $"Failed to retrived Facebook user information ({response.StatusCode}) Please check if the authentication information is correct and the corresponding Google API is enabled."; - throw new InvalidOperationException(errorMessage); + throw new HttpRequestException($"Failed to retrived Facebook user information ({response.StatusCode}) Please check if the authentication information is correct and the corresponding Facebook API is enabled."); } var payload = JObject.Parse(await response.Content.ReadAsStringAsync()); diff --git a/src/Microsoft.AspNetCore.Authentication.Google/GoogleHandler.cs b/src/Microsoft.AspNetCore.Authentication.Google/GoogleHandler.cs index 4a81744b05..b9a24f58c2 100644 --- a/src/Microsoft.AspNetCore.Authentication.Google/GoogleHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.Google/GoogleHandler.cs @@ -34,8 +34,7 @@ namespace Microsoft.AspNetCore.Authentication.Google var response = await Backchannel.SendAsync(request, Context.RequestAborted); if (!response.IsSuccessStatusCode) { - var errorMessage = $"Failed to retrived Google user information ({response.StatusCode}) Please check if the authentication information is correct and the corresponding Google API is enabled."; - throw new InvalidOperationException(errorMessage); + throw new HttpRequestException($"An error occurred when retrieving user information ({response.StatusCode}). Please check if the authentication information is correct and the corresponding Google API is enabled."); } var payload = JObject.Parse(await response.Content.ReadAsStringAsync()); diff --git a/src/Microsoft.AspNetCore.Authentication.MicrosoftAccount/MicrosoftAccountHandler.cs b/src/Microsoft.AspNetCore.Authentication.MicrosoftAccount/MicrosoftAccountHandler.cs index 18f5df9d4a..ad4ceb81f7 100644 --- a/src/Microsoft.AspNetCore.Authentication.MicrosoftAccount/MicrosoftAccountHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.MicrosoftAccount/MicrosoftAccountHandler.cs @@ -28,8 +28,7 @@ namespace Microsoft.AspNetCore.Authentication.MicrosoftAccount var response = await Backchannel.SendAsync(request, Context.RequestAborted); if (!response.IsSuccessStatusCode) { - var errorMessage = $"Failed to retrived Microsoft user information ({response.StatusCode}) Please check if the authentication information is correct and the corresponding Google API is enabled."; - throw new InvalidOperationException(errorMessage); + throw new HttpRequestException($"Failed to retrived Microsoft user information ({response.StatusCode}) Please check if the authentication information is correct and the corresponding Microsoft API is enabled."); } var payload = JObject.Parse(await response.Content.ReadAsStringAsync()); diff --git a/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs b/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs index 29d46367db..353b2b5847 100644 --- a/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication.OAuth/OAuthHandler.cs @@ -15,6 +15,7 @@ using Microsoft.AspNetCore.Http.Extensions; using Microsoft.AspNetCore.Http.Features.Authentication; using Microsoft.Extensions.Primitives; using Newtonsoft.Json.Linq; +using System.Threading; namespace Microsoft.AspNetCore.Authentication.OAuth { @@ -119,14 +120,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth properties.StoreTokens(authTokens); } - try + var ticket = await CreateTicketAsync(identity, properties, tokens); + if (ticket != null) { - var ticket = await CreateTicketAsync(identity, properties, tokens); return AuthenticateResult.Success(ticket); } - catch (Exception ex) + else { - return AuthenticateResult.Fail(ex); + return AuthenticateResult.Fail("Failed to retrieve user information from remote server."); } } diff --git a/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs b/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs index 5891a005d3..f152ff1cfb 100644 --- a/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs +++ b/src/Microsoft.AspNetCore.Authentication/RemoteAuthenticationHandler.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features.Authentication; using Microsoft.AspNetCore.Http.Authentication; using Microsoft.Extensions.Logging; +using System.Net.Http; namespace Microsoft.AspNetCore.Authentication { @@ -31,26 +32,49 @@ namespace Microsoft.AspNetCore.Authentication protected virtual async Task HandleRemoteCallbackAsync() { - var authResult = await HandleRemoteAuthenticateAsync(); - if (authResult != null && authResult.Skipped) + AuthenticateResult authResult = null; + Exception exception = null; + + try { - return false; + authResult = await HandleRemoteAuthenticateAsync(); + if (authResult != null && authResult.Skipped == true) + { + return false; + } + else if (authResult == null) + { + exception = new InvalidOperationException("Invalide return state, unable to redirect."); + } + else if (!authResult.Succeeded) + { + exception = authResult?.Failure ?? + new InvalidOperationException("Invalide return state, unable to redirect."); + } } - if (authResult == null || !authResult.Succeeded) + catch (Exception ex) { - var errorContext = new FailureContext(Context, authResult?.Failure ?? new Exception("Invalid return state, unable to redirect.")); - Logger.RemoteAuthenticationError(errorContext.Failure.Message); + exception = ex; + } + + if (exception != null) + { + Logger.RemoteAuthenticationError(exception.Message); + var errorContext = new FailureContext(Context, exception); await Options.Events.RemoteFailure(errorContext); + if (errorContext.HandledResponse) { return true; } - if (errorContext.Skipped) + else if (errorContext.Skipped) { return false; } - - throw new AggregateException("Unhandled remote failure.", errorContext.Failure); + else + { + throw new AggregateException("Unhandled remote failure.", exception); + } } // We have a ticket if we get here