diff --git a/src/Microsoft.AspNet.Security.DataProtection/CryptoUtil.cs b/src/Microsoft.AspNet.Security.DataProtection/CryptoUtil.cs
index 52e556fbcf..823e7aa213 100644
--- a/src/Microsoft.AspNet.Security.DataProtection/CryptoUtil.cs
+++ b/src/Microsoft.AspNet.Security.DataProtection/CryptoUtil.cs
@@ -6,6 +6,7 @@ using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
+using System.Text;
#if !ASPNETCORE50
using System.Runtime.ConstrainedExecution;
@@ -15,6 +16,9 @@ namespace Microsoft.AspNet.Security.DataProtection
{
internal unsafe static class CryptoUtil
{
+ // UTF8 encoding that fails on invalid chars
+ public static readonly UTF8Encoding SecureUtf8Encoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
+
// This isn't a typical Debug.Assert; the check is always performed, even in retail builds.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Assert(bool condition, string message)
diff --git a/src/Microsoft.AspNet.Security.DataProtection/DataProtectionExtensions.cs b/src/Microsoft.AspNet.Security.DataProtection/DataProtectionExtensions.cs
new file mode 100644
index 0000000000..cf7c9fa7cd
--- /dev/null
+++ b/src/Microsoft.AspNet.Security.DataProtection/DataProtectionExtensions.cs
@@ -0,0 +1,59 @@
+// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Security.Cryptography;
+
+namespace Microsoft.AspNet.Security.DataProtection
+{
+ ///
+ /// Helpful extension methods for data protection APIs.
+ ///
+ public static class DataProtectionExtensions
+ {
+ ///
+ /// Cryptographically protects a piece of plaintext data.
+ ///
+ /// The data protector to use for this operation.
+ /// The plaintext data to protect.
+ /// The protected form of the plaintext data.
+ public static string Protect([NotNull] this IDataProtector protector, [NotNull] string unprotectedData)
+ {
+ try
+ {
+ byte[] unprotectedDataAsBytes = CryptoUtil.SecureUtf8Encoding.GetBytes(unprotectedData);
+ byte[] protectedDataAsBytes = protector.Protect(unprotectedDataAsBytes);
+ return WebEncoders.Base64UrlEncode(protectedDataAsBytes);
+ }
+ catch (Exception ex) if (!(ex is CryptographicException))
+ {
+ // Homogenize exceptions to CryptographicException
+ throw Error.CryptCommon_GenericError(ex);
+ }
+ }
+
+ ///
+ /// Cryptographically unprotects a piece of protected data.
+ ///
+ /// The data protector to use for this operation.
+ /// The protected data to unprotect.
+ /// The plaintext form of the protected data.
+ ///
+ /// This method will throw CryptographicException if the input is invalid or malformed.
+ ///
+ public static string Unprotect([NotNull] this IDataProtector protector, [NotNull] string protectedData)
+ {
+ try
+ {
+ byte[] protectedDataAsBytes = WebEncoders.Base64UrlDecode(protectedData);
+ byte[] unprotectedDataAsBytes = protector.Unprotect(protectedDataAsBytes);
+ return CryptoUtil.SecureUtf8Encoding.GetString(unprotectedDataAsBytes);
+ }
+ catch (Exception ex) if (!(ex is CryptographicException))
+ {
+ // Homogenize exceptions to CryptographicException
+ throw Error.CryptCommon_GenericError(ex);
+ }
+ }
+ }
+}
diff --git a/src/Microsoft.AspNet.Security.DataProtection/Dpapi/DpapiDataProtector.cs b/src/Microsoft.AspNet.Security.DataProtection/Dpapi/DpapiDataProtector.cs
index 0bc4cb073d..172a1289cc 100644
--- a/src/Microsoft.AspNet.Security.DataProtection/Dpapi/DpapiDataProtector.cs
+++ b/src/Microsoft.AspNet.Security.DataProtection/Dpapi/DpapiDataProtector.cs
@@ -4,7 +4,6 @@
using System;
using System.IO;
using System.Security.Cryptography;
-using System.Text;
namespace Microsoft.AspNet.Security.DataProtection.Dpapi
{
@@ -12,8 +11,6 @@ namespace Microsoft.AspNet.Security.DataProtection.Dpapi
// or for Windows machines where we can't depend on the user profile.
internal sealed class DpapiDataProtector : IDataProtector
{
- private static readonly UTF8Encoding _secureUtf8Encoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
-
private readonly byte[] _combinedPurposes;
private readonly DataProtectionScope _scope;
private readonly IProtectedData _shim;
@@ -31,7 +28,7 @@ namespace Microsoft.AspNet.Security.DataProtection.Dpapi
using (var memoryStream = new MemoryStream())
{
memoryStream.Write(_combinedPurposes, 0, _combinedPurposes.Length);
- using (var writer = new BinaryWriter(memoryStream, _secureUtf8Encoding, leaveOpen: true))
+ using (var writer = new BinaryWriter(memoryStream, CryptoUtil.SecureUtf8Encoding, leaveOpen: true))
{
writer.Write(purpose);
}
diff --git a/src/Microsoft.AspNet.Security.DataProtection/KeyManagement/KeyRingBasedDataProtector.cs b/src/Microsoft.AspNet.Security.DataProtection/KeyManagement/KeyRingBasedDataProtector.cs
index 3b87e17147..cd5a78e5c1 100644
--- a/src/Microsoft.AspNet.Security.DataProtection/KeyManagement/KeyRingBasedDataProtector.cs
+++ b/src/Microsoft.AspNet.Security.DataProtection/KeyManagement/KeyRingBasedDataProtector.cs
@@ -5,7 +5,6 @@ using System;
using System.Diagnostics;
using System.IO;
using System.Security.Cryptography;
-using System.Text;
using System.Threading;
using Microsoft.AspNet.Security.DataProtection.AuthenticatedEncryption;
@@ -278,10 +277,9 @@ namespace Microsoft.AspNet.Security.DataProtection.KeyManagement
private sealed class PurposeBinaryWriter : BinaryWriter
{
// Strings should never contain invalid UTF16 chars, so we'll use a secure encoding.
- private static readonly UTF8Encoding _secureEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
private static readonly byte[] _guidBuffer = new byte[sizeof(Guid)];
- public PurposeBinaryWriter(MemoryStream stream) : base(stream, _secureEncoding, leaveOpen: true) { }
+ public PurposeBinaryWriter(MemoryStream stream) : base(stream, CryptoUtil.SecureUtf8Encoding, leaveOpen: true) { }
public new void Write7BitEncodedInt(int value)
{
diff --git a/src/Microsoft.AspNet.Security.DataProtection/WebEncoders.cs b/src/Microsoft.AspNet.Security.DataProtection/WebEncoders.cs
new file mode 100644
index 0000000000..36db7b520a
--- /dev/null
+++ b/src/Microsoft.AspNet.Security.DataProtection/WebEncoders.cs
@@ -0,0 +1,133 @@
+// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Diagnostics;
+
+namespace Microsoft.AspNet.Security.DataProtection
+{
+ // Internal copy of HttpAbstractions functionality.
+ internal static class WebEncoders
+ {
+ ///
+ /// Decodes a base64url-encoded string.
+ ///
+ /// The base64url-encoded input to decode.
+ /// The base64url-decoded form of the input.
+ ///
+ /// The input must not contain any whitespace or padding characters.
+ /// Throws FormatException if the input is malformed.
+ ///
+ public static byte[] Base64UrlDecode([NotNull] string input)
+ {
+ // Assumption: input is base64url encoded without padding and contains no whitespace.
+
+ // First, we need to add the padding characters back.
+ int numPaddingCharsToAdd = GetNumBase64PaddingCharsToAddForDecode(input.Length);
+ char[] completeBase64Array = new char[checked(input.Length + numPaddingCharsToAdd)];
+ Debug.Assert(completeBase64Array.Length % 4 == 0, "Invariant: Array length must be a multiple of 4.");
+ input.CopyTo(0, completeBase64Array, 0, input.Length);
+ for (int i = 1; i <= numPaddingCharsToAdd; i++)
+ {
+ completeBase64Array[completeBase64Array.Length - i] = '=';
+ }
+
+ // Next, fix up '-' -> '+' and '_' -> '/'
+ for (int i = 0; i < completeBase64Array.Length; i++)
+ {
+ char c = completeBase64Array[i];
+ if (c == '-')
+ {
+ completeBase64Array[i] = '+';
+ }
+ else if (c == '_')
+ {
+ completeBase64Array[i] = '/';
+ }
+ }
+
+ // Finally, decode.
+ // If the caller provided invalid base64 chars, they'll be caught here.
+ return Convert.FromBase64CharArray(completeBase64Array, 0, completeBase64Array.Length);
+ }
+
+ ///
+ /// Encodes an input using base64url encoding.
+ ///
+ /// The binary input to encode.
+ /// The base64url-encoded form of the input.
+ public static string Base64UrlEncode([NotNull] byte[] input)
+ {
+ // Special-case empty input
+ if (input.Length == 0)
+ {
+ return String.Empty;
+ }
+
+ // We're going to use base64url encoding with no padding characters.
+ // See RFC 4648, Sec. 5.
+ char[] buffer = new char[GetNumBase64CharsRequiredForInput(input.Length)];
+ int numBase64Chars = Convert.ToBase64CharArray(input, 0, input.Length, buffer, 0);
+
+ // Fix up '+' -> '-' and '/' -> '_'
+ for (int i = 0; i < numBase64Chars; i++)
+ {
+ char ch = buffer[i];
+ if (ch == '+')
+ {
+ buffer[i] = '-';
+ }
+ else if (ch == '/')
+ {
+ buffer[i] = '_';
+ }
+ else if (ch == '=')
+ {
+ // We've reached a padding character: truncate the string from this point
+ return new String(buffer, 0, i);
+ }
+ }
+
+ // If we got this far, the buffer didn't contain any padding chars, so turn
+ // it directly into a string.
+ return new String(buffer, 0, numBase64Chars);
+ }
+
+ private static int GetNumBase64CharsRequiredForInput(int inputLength)
+ {
+ int numWholeOrPartialInputBlocks = checked(inputLength + 2) / 3;
+ return checked(numWholeOrPartialInputBlocks * 4);
+ }
+
+ private static int GetNumBase64PaddingCharsInString(string str)
+ {
+ // Assumption: input contains a well-formed base64 string with no whitespace.
+
+ // base64 guaranteed have 0 - 2 padding characters.
+ if (str[str.Length - 1] == '=')
+ {
+ if (str[str.Length - 2] == '=')
+ {
+ return 2;
+ }
+ return 1;
+ }
+ return 0;
+ }
+
+ private static int GetNumBase64PaddingCharsToAddForDecode(int inputLength)
+ {
+ switch (inputLength % 4)
+ {
+ case 0:
+ return 0;
+ case 2:
+ return 2;
+ case 3:
+ return 1;
+ default:
+ throw new FormatException("TODO: Malformed input.");
+ }
+ }
+ }
+}
diff --git a/test/Microsoft.AspNet.Security.DataProtection.Test/DataProtectionExtensionsTests.cs b/test/Microsoft.AspNet.Security.DataProtection.Test/DataProtectionExtensionsTests.cs
new file mode 100644
index 0000000000..53fc15656c
--- /dev/null
+++ b/test/Microsoft.AspNet.Security.DataProtection.Test/DataProtectionExtensionsTests.cs
@@ -0,0 +1,85 @@
+// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
+// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
+
+using System;
+using System.Security.Cryptography;
+using System.Text;
+using Moq;
+using Xunit;
+
+namespace Microsoft.AspNet.Security.DataProtection.Test
+{
+ public class DataProtectionExtensionsTests
+ {
+ [Fact]
+ public void Protect_InvalidUtf_Failure()
+ {
+ // Arrange
+ Mock mockProtector = new Mock();
+
+ // Act & assert
+ var ex = Assert.Throws(() =>
+ {
+ DataProtectionExtensions.Protect(mockProtector.Object, "Hello\ud800");
+ });
+ Assert.IsAssignableFrom(typeof(EncoderFallbackException), ex.InnerException);
+ }
+
+ [Fact]
+ public void Protect_Success()
+ {
+ // Arrange
+ Mock mockProtector = new Mock();
+ mockProtector.Setup(p => p.Protect(new byte[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f })).Returns(new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05 });
+
+ // Act
+ string retVal = DataProtectionExtensions.Protect(mockProtector.Object, "Hello");
+
+ // Assert
+ Assert.Equal("AQIDBAU", retVal);
+ }
+
+ [Fact]
+ public void Unprotect_InvalidBase64BeforeDecryption_Failure()
+ {
+ // Arrange
+ Mock mockProtector = new Mock();
+
+ // Act & assert
+ var ex = Assert.Throws(() =>
+ {
+ DataProtectionExtensions.Unprotect(mockProtector.Object, "A");
+ });
+ Assert.IsAssignableFrom(typeof(FormatException), ex.InnerException);
+ }
+
+ [Fact]
+ public void Unprotect_InvalidUtfAfterDecryption_Failure()
+ {
+ // Arrange
+ Mock mockProtector = new Mock();
+ mockProtector.Setup(p => p.Unprotect(new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05 })).Returns(new byte[] { 0xff });
+
+ // Act & assert
+ var ex = Assert.Throws(() =>
+ {
+ DataProtectionExtensions.Unprotect(mockProtector.Object, "AQIDBAU");
+ });
+ Assert.IsAssignableFrom(typeof(DecoderFallbackException), ex.InnerException);
+ }
+
+ [Fact]
+ public void Unprotect_Success()
+ {
+ // Arrange
+ Mock mockProtector = new Mock();
+ mockProtector.Setup(p => p.Unprotect(new byte[] { 0x01, 0x02, 0x03, 0x04, 0x05 })).Returns(new byte[] { 0x48, 0x65, 0x6c, 0x6c, 0x6f });
+
+ // Act
+ string retVal = DataProtectionExtensions.Unprotect(mockProtector.Object, "AQIDBAU");
+
+ // Assert
+ Assert.Equal("Hello", retVal);
+ }
+ }
+}