#1188 Add AuthenticationProperties to HandleRequestResult and RemoteFailureContext

This commit is contained in:
Chris R 2017-07-06 12:36:34 -07:00 committed by Chris Ross (ASP.NET)
parent 5abcfe7e3d
commit 144ee21696
11 changed files with 346 additions and 178 deletions

View File

@ -67,6 +67,10 @@ namespace SocialSample
o.Fields.Add("name");
o.Fields.Add("email");
o.SaveTokens = true;
o.Events = new OAuthEvents()
{
OnRemoteFailure = HandleOnRemoteFailure
};
})
// You must first create an app with Google and add its ID and Secret to your user-secrets.
// https://console.developers.google.com/project
@ -81,6 +85,10 @@ namespace SocialSample
o.Scope.Add("profile");
o.Scope.Add("email");
o.SaveTokens = true;
o.Events = new OAuthEvents()
{
OnRemoteFailure = HandleOnRemoteFailure
};
})
// You must first create an app with Google and add its ID and Secret to your user-secrets.
// https://console.developers.google.com/project
@ -93,12 +101,7 @@ namespace SocialSample
o.SaveTokens = true;
o.Events = new OAuthEvents()
{
OnRemoteFailure = ctx =>
{
ctx.Response.Redirect("/error?FailureMessage=" + UrlEncoder.Default.Encode(ctx.Failure.Message));
ctx.HandleResponse();
return Task.FromResult(0);
}
OnRemoteFailure = HandleOnRemoteFailure
};
o.ClaimActions.MapJsonSubKey("urn:google:image", "image", "url");
o.ClaimActions.Remove(ClaimTypes.GivenName);
@ -116,12 +119,7 @@ namespace SocialSample
o.ClaimActions.MapJsonKey("urn:twitter:profilepicture", "profile_image_url", ClaimTypes.Uri);
o.Events = new TwitterEvents()
{
OnRemoteFailure = ctx =>
{
ctx.Response.Redirect("/error?FailureMessage=" + UrlEncoder.Default.Encode(ctx.Failure.Message));
ctx.HandleResponse();
return Task.FromResult(0);
}
OnRemoteFailure = HandleOnRemoteFailure
};
})
/* Azure AD app model v2 has restrictions that prevent the use of plain HTTP for redirect URLs.
@ -139,6 +137,10 @@ namespace SocialSample
o.TokenEndpoint = MicrosoftAccountDefaults.TokenEndpoint;
o.Scope.Add("https://graph.microsoft.com/user.read");
o.SaveTokens = true;
o.Events = new OAuthEvents()
{
OnRemoteFailure = HandleOnRemoteFailure
};
})
// You must first create an app with Microsoft Account and add its ID and Secret to your user-secrets.
// https://azure.microsoft.com/en-us/documentation/articles/active-directory-v2-app-registration/
@ -148,6 +150,10 @@ namespace SocialSample
o.ClientSecret = Configuration["microsoftaccount:clientsecret"];
o.SaveTokens = true;
o.Scope.Add("offline_access");
o.Events = new OAuthEvents()
{
OnRemoteFailure = HandleOnRemoteFailure
};
})
// You must first create an app with GitHub and add its ID and Secret to your user-secrets.
// https://github.com/settings/applications/
@ -159,6 +165,10 @@ namespace SocialSample
o.AuthorizationEndpoint = "https://github.com/login/oauth/authorize";
o.TokenEndpoint = "https://github.com/login/oauth/access_token";
o.SaveTokens = true;
o.Events = new OAuthEvents()
{
OnRemoteFailure = HandleOnRemoteFailure
};
})
// You must first create an app with GitHub and add its ID and Secret to your user-secrets.
// https://github.com/settings/applications/
@ -180,6 +190,7 @@ namespace SocialSample
o.ClaimActions.MapJsonKey("urn:github:url", "url");
o.Events = new OAuthEvents
{
OnRemoteFailure = HandleOnRemoteFailure,
OnCreatingTicket = async context =>
{
// Get the GitHub user
@ -198,6 +209,30 @@ namespace SocialSample
});
}
private async Task HandleOnRemoteFailure(RemoteFailureContext context)
{
context.Response.StatusCode = 500;
context.Response.ContentType = "text/html";
await context.Response.WriteAsync("<html><body>");
await context.Response.WriteAsync("A remote failure has occurred: " + UrlEncoder.Default.Encode(context.Failure.Message) + "<br>");
if (context.Properties != null)
{
await context.Response.WriteAsync("Properties:<br>");
foreach (var pair in context.Properties.Items)
{
await context.Response.WriteAsync($"-{ UrlEncoder.Default.Encode(pair.Key)}={ UrlEncoder.Default.Encode(pair.Value)}<br>");
}
}
await context.Response.WriteAsync("<a href=\"/\">Home</a>");
await context.Response.WriteAsync("</body></html>");
// context.Response.Redirect("/error?FailureMessage=" + UrlEncoder.Default.Encode(context.Failure.Message));
context.HandleResponse();
}
public void Configure(IApplicationBuilder app)
{
app.UseDeveloperExceptionPage();

View File

@ -44,9 +44,22 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync()
{
AuthenticationProperties properties = null;
var query = Request.Query;
var state = query["state"];
var properties = Options.StateDataFormat.Unprotect(state);
if (properties == null)
{
return HandleRequestResult.Fail("The oauth state was missing or invalid.");
}
// OAuth2 10.12 CSRF
if (!ValidateCorrelationId(properties))
{
return HandleRequestResult.Fail("Correlation failed.", properties);
}
var error = query["error"];
if (!StringValues.IsNullOrEmpty(error))
{
@ -63,39 +76,26 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
failureMessage.Append(";Uri=").Append(errorUri);
}
return HandleRequestResult.Fail(failureMessage.ToString());
return HandleRequestResult.Fail(failureMessage.ToString(), properties);
}
var code = query["code"];
var state = query["state"];
properties = Options.StateDataFormat.Unprotect(state);
if (properties == null)
{
return HandleRequestResult.Fail("The oauth state was missing or invalid.");
}
// OAuth2 10.12 CSRF
if (!ValidateCorrelationId(properties))
{
return HandleRequestResult.Fail("Correlation failed.");
}
if (StringValues.IsNullOrEmpty(code))
{
return HandleRequestResult.Fail("Code was not found.");
return HandleRequestResult.Fail("Code was not found.", properties);
}
var tokens = await ExchangeCodeAsync(code, BuildRedirectUri(Options.CallbackPath));
if (tokens.Error != null)
{
return HandleRequestResult.Fail(tokens.Error);
return HandleRequestResult.Fail(tokens.Error, properties);
}
if (string.IsNullOrEmpty(tokens.AccessToken))
{
return HandleRequestResult.Fail("Failed to retrieve access token.");
return HandleRequestResult.Fail("Failed to retrieve access token.", properties);
}
var identity = new ClaimsIdentity(ClaimsIssuer);
@ -141,7 +141,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
}
else
{
return HandleRequestResult.Fail("Failed to retrieve user information from remote server.");
return HandleRequestResult.Fail("Failed to retrieve user information from remote server.", properties);
}
}

View File

@ -491,13 +491,10 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect
return HandleRequestResult.Fail("No message.");
}
AuthenticationProperties properties = null;
try
{
AuthenticationProperties properties = null;
if (!string.IsNullOrEmpty(authorizationResponse.State))
{
properties = Options.StateDataFormat.Unprotect(authorizationResponse.State);
}
properties = ReadPropertiesAndClearState(authorizationResponse);
var messageReceivedContext = await RunMessageReceivedEventAsync(authorizationResponse, properties);
if (messageReceivedContext.Result != null)
@ -521,8 +518,7 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect
return HandleRequestResult.Fail(Resources.MessageStateIsNullOrEmpty);
}
// if state exists and we failed to 'unprotect' this is not a message we should process.
properties = Options.StateDataFormat.Unprotect(authorizationResponse.State);
properties = ReadPropertiesAndClearState(authorizationResponse);
}
if (properties == null)
@ -533,21 +529,20 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect
// Not for us?
return HandleRequestResult.SkipHandler();
}
// if state exists and we failed to 'unprotect' this is not a message we should process.
return HandleRequestResult.Fail(Resources.MessageStateIsInvalid);
}
properties.Items.TryGetValue(OpenIdConnectDefaults.UserstatePropertiesKey, out string userstate);
authorizationResponse.State = userstate;
if (!ValidateCorrelationId(properties))
{
return HandleRequestResult.Fail("Correlation failed.");
return HandleRequestResult.Fail("Correlation failed.", properties);
}
// if any of the error fields are set, throw error null
if (!string.IsNullOrEmpty(authorizationResponse.Error))
{
return HandleRequestResult.Fail(CreateOpenIdConnectProtocolException(authorizationResponse, response: null));
return HandleRequestResult.Fail(CreateOpenIdConnectProtocolException(authorizationResponse, response: null), properties);
}
if (_configuration == null && Options.ConfigurationManager != null)
@ -635,8 +630,7 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect
// At least a cursory validation is required on the new IdToken, even if we've already validated the one from the authorization response.
// And we'll want to validate the new JWT in ValidateTokenResponse.
JwtSecurityToken tokenEndpointJwt;
var tokenEndpointUser = ValidateToken(tokenEndpointResponse.IdToken, properties, validationParameters, out tokenEndpointJwt);
var tokenEndpointUser = ValidateToken(tokenEndpointResponse.IdToken, properties, validationParameters, out var tokenEndpointJwt);
// Avoid reading & deleting the nonce cookie, running the event, etc, if it was already done as part of the authorization response validation.
if (user == null)
@ -722,10 +716,27 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect
return authenticationFailedContext.Result;
}
return HandleRequestResult.Fail(exception);
return HandleRequestResult.Fail(exception, properties);
}
}
private AuthenticationProperties ReadPropertiesAndClearState(OpenIdConnectMessage message)
{
AuthenticationProperties properties = null;
if (!string.IsNullOrEmpty(message.State))
{
properties = Options.StateDataFormat.Unprotect(message.State);
if (properties != null)
{
// If properties can be decoded from state, clear the message state.
properties.Items.TryGetValue(OpenIdConnectDefaults.UserstatePropertiesKey, out var userstate);
message.State = userstate;
}
}
return properties;
}
private void PopulateSessionProperties(OpenIdConnectMessage message, AuthenticationProperties properties)
{
if (!string.IsNullOrEmpty(message.SessionState))
@ -830,7 +841,7 @@ namespace Microsoft.AspNetCore.Authentication.OpenIdConnect
}
else
{
return HandleRequestResult.Fail("Unknown response type: " + contentType.MediaType);
return HandleRequestResult.Fail("Unknown response type: " + contentType.MediaType, properties);
}
var userInformationReceivedContext = await RunUserInformationReceivedEventAsync(principal, properties, message, user);

View File

@ -46,7 +46,6 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
protected override async Task<HandleRequestResult> HandleRemoteAuthenticateAsync()
{
AuthenticationProperties properties = null;
var query = Request.Query;
var protectedRequestToken = Request.Cookies[Options.StateCookie.Name];
@ -57,25 +56,25 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
return HandleRequestResult.Fail("Invalid state cookie.");
}
properties = requestToken.Properties;
var properties = requestToken.Properties;
// REVIEW: see which of these are really errors
var returnedToken = query["oauth_token"];
if (StringValues.IsNullOrEmpty(returnedToken))
{
return HandleRequestResult.Fail("Missing oauth_token");
return HandleRequestResult.Fail("Missing oauth_token", properties);
}
if (!string.Equals(returnedToken, requestToken.Token, StringComparison.Ordinal))
{
return HandleRequestResult.Fail("Unmatched token");
return HandleRequestResult.Fail("Unmatched token", properties);
}
var oauthVerifier = query["oauth_verifier"];
if (StringValues.IsNullOrEmpty(oauthVerifier))
{
return HandleRequestResult.Fail("Missing or blank oauth_verifier");
return HandleRequestResult.Fail("Missing or blank oauth_verifier", properties);
}
var cookieOptions = Options.StateCookie.Build(Context, Clock.UtcNow);

View File

@ -25,5 +25,10 @@ namespace Microsoft.AspNetCore.Authentication
/// User friendly error message for the error.
/// </summary>
public Exception Failure { get; set; }
/// <summary>
/// Additional state values for the authentication session.
/// </summary>
public AuthenticationProperties Properties { get; set; }
}
}

View File

@ -46,15 +46,33 @@ namespace Microsoft.AspNetCore.Authentication
return new HandleRequestResult() { Failure = failure };
}
/// <summary>
/// Indicates that there was a failure during authentication.
/// </summary>
/// <param name="failure">The failure exception.</param>
/// <param name="properties">Additional state values for the authentication session.</param>
/// <returns>The result.</returns>
public static new HandleRequestResult Fail(Exception failure, AuthenticationProperties properties)
{
return new HandleRequestResult() { Failure = failure, Properties = properties };
}
/// <summary>
/// Indicates that there was a failure during authentication.
/// </summary>
/// <param name="failureMessage">The failure message.</param>
/// <returns>The result.</returns>
public static new HandleRequestResult Fail(string failureMessage)
{
return new HandleRequestResult() { Failure = new Exception(failureMessage) };
}
=> Fail(new Exception(failureMessage));
/// <summary>
/// Indicates that there was a failure during authentication.
/// </summary>
/// <param name="failureMessage">The failure message.</param>
/// <param name="properties">Additional state values for the authentication session.</param>
/// <returns>The result.</returns>
public static new HandleRequestResult Fail(string failureMessage, AuthenticationProperties properties)
=> Fail(new Exception(failureMessage), properties);
/// <summary>
/// Discontinue all processing for this request and return to the client.

View File

@ -49,6 +49,7 @@ namespace Microsoft.AspNetCore.Authentication
AuthenticationTicket ticket = null;
Exception exception = null;
AuthenticationProperties properties = null;
try
{
var authResult = await HandleRemoteAuthenticateAsync();
@ -66,8 +67,8 @@ namespace Microsoft.AspNetCore.Authentication
}
else if (!authResult.Succeeded)
{
exception = authResult.Failure ??
new InvalidOperationException("Invalid return state, unable to redirect.");
exception = authResult.Failure ?? new InvalidOperationException("Invalid return state, unable to redirect.");
properties = authResult.Properties;
}
ticket = authResult?.Ticket;
@ -80,7 +81,10 @@ namespace Microsoft.AspNetCore.Authentication
if (exception != null)
{
Logger.RemoteAuthenticationError(exception.Message);
var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception);
var errorContext = new RemoteFailureContext(Context, Scheme, Options, exception)
{
Properties = properties
};
await Events.RemoteFailure(errorContext);
if (errorContext.Result != null)
@ -95,11 +99,14 @@ namespace Microsoft.AspNetCore.Authentication
}
else if (errorContext.Result.Failure != null)
{
throw new InvalidOperationException("An error was returned from the RemoteFailure event.", errorContext.Result.Failure);
throw new Exception("An error was returned from the RemoteFailure event.", errorContext.Result.Failure);
}
}
throw exception;
if (errorContext.Failure != null)
{
throw new Exception("An error was encountered while handling the remote login.", errorContext.Failure);
}
}
// We have a ticket if we get here
@ -107,7 +114,7 @@ namespace Microsoft.AspNetCore.Authentication
{
ReturnUri = ticket.Properties.RedirectUri
};
// REVIEW: is this safe or good?
ticket.Properties.RedirectUri = null;
// Mark which provider produced this identity so we can cross-check later in HandleAuthenticateAsync

View File

@ -253,6 +253,7 @@ namespace Microsoft.AspNetCore.Authentication.Google
{
o.ClientId = "Test Id";
o.ClientSecret = "Test Secret";
o.StateDataFormat = new TestStateDataFormat();
o.Events = redirect ? new OAuthEvents()
{
OnRemoteFailure = ctx =>
@ -263,7 +264,8 @@ namespace Microsoft.AspNetCore.Authentication.Google
}
} : new OAuthEvents();
});
var sendTask = server.SendAsync("https://example.com/signin-google?error=OMG&error_description=SoBad&error_uri=foobar");
var sendTask = server.SendAsync("https://example.com/signin-google?error=OMG&error_description=SoBad&error_uri=foobar&state=protected_state",
".AspNetCore.Correlation.Google.corrilationId=N");
if (redirect)
{
var transaction = await sendTask;
@ -1075,5 +1077,37 @@ namespace Microsoft.AspNetCore.Authentication.Google
});
return new TestServer(builder);
}
private class TestStateDataFormat : ISecureDataFormat<AuthenticationProperties>
{
private AuthenticationProperties Data { get; set; }
public string Protect(AuthenticationProperties data)
{
return "protected_state";
}
public string Protect(AuthenticationProperties data, string purpose)
{
throw new NotImplementedException();
}
public AuthenticationProperties Unprotect(string protectedText)
{
Assert.Equal("protected_state", protectedText);
var properties = new AuthenticationProperties(new Dictionary<string, string>()
{
{ ".xsrf", "corrilationId" },
{ "testkey", "testvalue" }
});
properties.RedirectUri = "http://testhost/redirect";
return properties;
}
public AuthenticationProperties Unprotect(string protectedText, string purpose)
{
throw new NotImplementedException();
}
}
}
}

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Net;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authentication.Cookies;
@ -10,6 +11,7 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Net.Http.Headers;
using Xunit;
namespace Microsoft.AspNetCore.Authentication.OAuth
@ -20,20 +22,13 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
public async Task VerifySignInSchemeCannotBeSetToSelf()
{
var server = CreateServer(
app => { },
services => services.AddAuthentication().AddOAuth("weeblie", o =>
{
o.SignInScheme = "weeblie";
o.ClientId = "whatever";
o.ClientSecret = "whatever";
}),
context =>
{
// REVIEW: Gross.
context.ChallengeAsync("weeblie").GetAwaiter().GetResult();
return true;
});
var error = await Assert.ThrowsAsync<InvalidOperationException>(() => server.SendAsync("https://example.com/challenge"));
}));
var error = await Assert.ThrowsAsync<InvalidOperationException>(() => server.SendAsync("https://example.com/"));
Assert.Contains("cannot be set to itself", error.Message);
}
@ -54,7 +49,6 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
public async Task ThrowsIfClientIdMissing()
{
var server = CreateServer(
app => { },
services => services.AddAuthentication().AddOAuth("weeblie", o =>
{
o.SignInScheme = "whatever";
@ -62,22 +56,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
o.ClientSecret = "whatever";
o.TokenEndpoint = "/";
o.AuthorizationEndpoint = "/";
}),
context =>
{
// REVIEW: Gross.
Assert.Throws<ArgumentException>("ClientId", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult());
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode);
}));
await Assert.ThrowsAsync<ArgumentException>("ClientId", () => server.SendAsync("http://example.com/"));
}
[Fact]
public async Task ThrowsIfClientSecretMissing()
{
var server = CreateServer(
app => { },
services => services.AddAuthentication().AddOAuth("weeblie", o =>
{
o.SignInScheme = "whatever";
@ -85,22 +71,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
o.CallbackPath = "/";
o.TokenEndpoint = "/";
o.AuthorizationEndpoint = "/";
}),
context =>
{
// REVIEW: Gross.
Assert.Throws<ArgumentException>("ClientSecret", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult());
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode);
}));
await Assert.ThrowsAsync<ArgumentException>("ClientSecret", () => server.SendAsync("http://example.com/"));
}
[Fact]
public async Task ThrowsIfCallbackPathMissing()
{
var server = CreateServer(
app => { },
services => services.AddAuthentication().AddOAuth("weeblie", o =>
{
o.ClientId = "Whatever;";
@ -108,22 +86,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
o.TokenEndpoint = "/";
o.AuthorizationEndpoint = "/";
o.SignInScheme = "eh";
}),
context =>
{
// REVIEW: Gross.
Assert.Throws<ArgumentException>("CallbackPath", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult());
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode);
}));
await Assert.ThrowsAsync<ArgumentException>("CallbackPath", () => server.SendAsync("http://example.com/"));
}
[Fact]
public async Task ThrowsIfTokenEndpointMissing()
{
var server = CreateServer(
app => { },
services => services.AddAuthentication().AddOAuth("weeblie", o =>
{
o.ClientId = "Whatever;";
@ -131,22 +101,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
o.CallbackPath = "/";
o.AuthorizationEndpoint = "/";
o.SignInScheme = "eh";
}),
context =>
{
// REVIEW: Gross.
Assert.Throws<ArgumentException>("TokenEndpoint", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult());
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode);
}));
await Assert.ThrowsAsync<ArgumentException>("TokenEndpoint", () => server.SendAsync("http://example.com/"));
}
[Fact]
public async Task ThrowsIfAuthorizationEndpointMissing()
{
var server = CreateServer(
app => { },
services => services.AddAuthentication().AddOAuth("weeblie", o =>
{
o.ClientId = "Whatever;";
@ -154,22 +116,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
o.CallbackPath = "/";
o.TokenEndpoint = "/";
o.SignInScheme = "eh";
}),
context =>
{
// REVIEW: Gross.
Assert.Throws<ArgumentException>("AuthorizationEndpoint", () => context.ChallengeAsync("weeblie").GetAwaiter().GetResult());
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode);
}));
await Assert.ThrowsAsync<ArgumentException>("AuthorizationEndpoint", () => server.SendAsync("http://example.com/"));
}
[Fact]
public async Task RedirectToIdentityProvider_SetsCorrelationIdCookiePath_ToCallBackPath()
{
var server = CreateServer(
app => { },
s => s.AddAuthentication().AddOAuth(
"Weblie",
opt =>
@ -181,9 +135,9 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
opt.TokenEndpoint = "https://example.com/provider/token";
opt.CallbackPath = "/oauth-callback";
}),
ctx =>
async ctx =>
{
ctx.ChallengeAsync("Weblie").ConfigureAwait(false).GetAwaiter().GetResult();
await ctx.ChallengeAsync("Weblie");
return true;
});
@ -201,7 +155,6 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
public async Task RedirectToAuthorizeEndpoint_CorrelationIdCookieOptions_CanBeOverriden()
{
var server = CreateServer(
app => { },
s => s.AddAuthentication().AddOAuth(
"Weblie",
opt =>
@ -214,9 +167,9 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
opt.CallbackPath = "/oauth-callback";
opt.CorrelationCookie.Path = "/";
}),
ctx =>
async ctx =>
{
ctx.ChallengeAsync("Weblie").ConfigureAwait(false).GetAwaiter().GetResult();
await ctx.ChallengeAsync("Weblie");
return true;
});
@ -230,15 +183,50 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
Assert.Contains("path=/", correlation);
}
private static TestServer CreateServer(Action<IApplicationBuilder> configure, Action<IServiceCollection> configureServices, Func<HttpContext, bool> handler)
[Fact]
public async Task RemoteAuthenticationFailed_OAuthError_IncludesProperties()
{
var server = CreateServer(
s => s.AddAuthentication().AddOAuth(
"Weblie",
opt =>
{
opt.ClientId = "Test Id";
opt.ClientSecret = "secret";
opt.SignInScheme = CookieAuthenticationDefaults.AuthenticationScheme;
opt.AuthorizationEndpoint = "https://example.com/provider/login";
opt.TokenEndpoint = "https://example.com/provider/token";
opt.CallbackPath = "/oauth-callback";
opt.StateDataFormat = new TestStateDataFormat();
opt.Events = new OAuthEvents()
{
OnRemoteFailure = context =>
{
Assert.Contains("declined", context.Failure.Message);
Assert.Equal("testvalue", context.Properties.Items["testkey"]);
context.Response.StatusCode = StatusCodes.Status406NotAcceptable;
context.HandleResponse();
return Task.CompletedTask;
}
};
}));
var transaction = await server.SendAsync("https://www.example.com/oauth-callback?error=declined&state=protected_state",
".AspNetCore.Correlation.Weblie.corrilationId=N");
Assert.Equal(HttpStatusCode.NotAcceptable, transaction.Response.StatusCode);
Assert.Null(transaction.Response.Headers.Location);
}
private static TestServer CreateServer(Action<IServiceCollection> configureServices, Func<HttpContext, Task<bool>> handler = null)
{
var builder = new WebHostBuilder()
.Configure(app =>
{
configure?.Invoke(app);
app.UseAuthentication();
app.Use(async (context, next) =>
{
if (handler == null || !handler(context))
if (handler == null || ! await handler(context))
{
await next();
}
@ -247,5 +235,37 @@ namespace Microsoft.AspNetCore.Authentication.OAuth
.ConfigureServices(configureServices);
return new TestServer(builder);
}
private class TestStateDataFormat : ISecureDataFormat<AuthenticationProperties>
{
private AuthenticationProperties Data { get; set; }
public string Protect(AuthenticationProperties data)
{
return "protected_state";
}
public string Protect(AuthenticationProperties data, string purpose)
{
throw new NotImplementedException();
}
public AuthenticationProperties Unprotect(string protectedText)
{
Assert.Equal("protected_state", protectedText);
var properties = new AuthenticationProperties(new Dictionary<string, string>()
{
{ ".xsrf", "corrilationId" },
{ "testkey", "testvalue" }
});
properties.RedirectUri = "http://testhost/redirect";
return properties;
}
public AuthenticationProperties Unprotect(string protectedText, string purpose)
{
throw new NotImplementedException();
}
}
}
}

View File

@ -95,7 +95,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(remoteFailure);
@ -191,7 +191,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(tokenValidated);
@ -348,7 +348,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(tokenValidated);
@ -532,7 +532,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(tokenValidated);
@ -731,7 +731,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "state=protected_state&code=my_code");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(codeReceived);
@ -943,7 +943,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(tokenValidated);
@ -1186,7 +1186,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
return PostAsync(server, "signin-oidc", "id_token=my_id_token&state=protected_state&code=my_code");
});
Assert.Equal("Authentication was aborted from user code.", exception.Message);
Assert.Equal("Authentication was aborted from user code.", exception.InnerException.Message);
Assert.True(messageReceived);
Assert.True(tokenValidated);
@ -1450,6 +1450,7 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
{
remoteFailure = true;
Assert.Equal("TestException", context.Failure.Message);
Assert.Equal("testvalue", context.Properties.Items["testkey"]);
context.HandleResponse();
context.Response.StatusCode = StatusCodes.Status202Accepted;
return Task.FromResult(0);
@ -1877,7 +1878,8 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect
var properties = new AuthenticationProperties(new Dictionary<string, string>()
{
{ ".xsrf", "corrilationId" },
{ OpenIdConnectDefaults.RedirectUriForCodePropertiesKey, "redirect_uri" }
{ OpenIdConnectDefaults.RedirectUriForCodePropertiesKey, "redirect_uri" },
{ "testkey", "testvalue" }
});
properties.RedirectUri = "http://testhost/redirect";
return properties;

View File

@ -1,6 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Security.Claims;
@ -11,6 +12,7 @@ using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Net.Http.Headers;
using Xunit;
namespace Microsoft.AspNetCore.Authentication.Twitter
@ -60,26 +62,12 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
};
o.BackchannelHttpHandler = new TestHttpMessageHandler
{
Sender = req =>
{
if (req.RequestUri.AbsoluteUri == "https://api.twitter.com/oauth/request_token")
{
return new HttpResponseMessage(HttpStatusCode.OK)
{
Content =
new StringContent("oauth_callback_confirmed=true&oauth_token=test_oauth_token&oauth_token_secret=test_oauth_token_secret",
Encoding.UTF8,
"application/x-www-form-urlencoded")
};
}
return null;
}
Sender = BackchannelRequestToken
};
},
context =>
async context =>
{
// REVIEW: Gross
context.ChallengeAsync("Twitter").GetAwaiter().GetResult();
await context.ChallengeAsync("Twitter");
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
@ -168,7 +156,6 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode);
}
[Fact]
public async Task ChallengeWillTriggerRedirection()
{
@ -178,35 +165,70 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
o.ConsumerSecret = "Test Consumer Secret";
o.BackchannelHttpHandler = new TestHttpMessageHandler
{
Sender = req =>
{
if (req.RequestUri.AbsoluteUri == "https://api.twitter.com/oauth/request_token")
{
return new HttpResponseMessage(HttpStatusCode.OK)
{
Content =
new StringContent("oauth_callback_confirmed=true&oauth_token=test_oauth_token&oauth_token_secret=test_oauth_token_secret",
Encoding.UTF8,
"application/x-www-form-urlencoded")
};
}
return null;
}
Sender = BackchannelRequestToken
};
},
context =>
{
// REVIEW: gross
context.ChallengeAsync("Twitter").GetAwaiter().GetResult();
return true;
});
async context =>
{
await context.ChallengeAsync("Twitter");
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode);
var location = transaction.Response.Headers.Location.AbsoluteUri;
Assert.Contains("https://api.twitter.com/oauth/authenticate?oauth_token=", location);
}
private static TestServer CreateServer(Action<TwitterOptions> options, Func<HttpContext, bool> handler = null)
[Fact]
public async Task BadCallbackCallsRemoteAuthFailedWithState()
{
var server = CreateServer(o =>
{
o.ConsumerKey = "Test Consumer Key";
o.ConsumerSecret = "Test Consumer Secret";
o.BackchannelHttpHandler = new TestHttpMessageHandler
{
Sender = BackchannelRequestToken
};
o.Events = new TwitterEvents()
{
OnRemoteFailure = context =>
{
Assert.NotNull(context.Failure);
Assert.NotNull(context.Properties);
Assert.Equal("testvalue", context.Properties.Items["testkey"]);
context.Response.StatusCode = StatusCodes.Status406NotAcceptable;
context.HandleResponse();
return Task.CompletedTask;
}
};
},
async context =>
{
var properties = new AuthenticationProperties();
properties.Items["testkey"] = "testvalue";
await context.ChallengeAsync("Twitter", properties);
return true;
});
var transaction = await server.SendAsync("http://example.com/challenge");
Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode);
var location = transaction.Response.Headers.Location.AbsoluteUri;
Assert.Contains("https://api.twitter.com/oauth/authenticate?oauth_token=", location);
Assert.True(transaction.Response.Headers.TryGetValues(HeaderNames.SetCookie, out var setCookie));
Assert.True(SetCookieHeaderValue.TryParseList(setCookie.ToList(), out var setCookieValues));
Assert.Single(setCookieValues);
var setCookieValue = setCookieValues.Single();
var cookie = new CookieHeaderValue(setCookieValue.Name, setCookieValue.Value);
var request = new HttpRequestMessage(HttpMethod.Get, "/signin-twitter");
request.Headers.Add(HeaderNames.Cookie, cookie.ToString());
var client = server.CreateClient();
var response = await client.SendAsync(request);
Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode);
}
private static TestServer CreateServer(Action<TwitterOptions> options, Func<HttpContext, Task<bool>> handler = null)
{
var builder = new WebHostBuilder()
.Configure(app =>
@ -228,7 +250,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
{
await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Twitter"));
}
else if (handler == null || !handler(context))
else if (handler == null || ! await handler(context))
{
await next();
}
@ -247,5 +269,20 @@ namespace Microsoft.AspNetCore.Authentication.Twitter
});
return new TestServer(builder);
}
private HttpResponseMessage BackchannelRequestToken(HttpRequestMessage req)
{
if (req.RequestUri.AbsoluteUri == "https://api.twitter.com/oauth/request_token")
{
return new HttpResponseMessage(HttpStatusCode.OK)
{
Content =
new StringContent("oauth_callback_confirmed=true&oauth_token=test_oauth_token&oauth_token_secret=test_oauth_token_secret",
Encoding.UTF8,
"application/x-www-form-urlencoded")
};
}
throw new NotImplementedException(req.RequestUri.AbsoluteUri);
}
}
}