Don't mutate strings in Kestrel (#17556)

* Removed mutating of string-contents in BCryptHandle

* Revert "Removed mutating of string-contents in BCryptHandle"

This reverts commit 5ae80c2834471baf34d1e5a05a42e3cce1ff02d7.

This is a .NET STandard 2.0 project, so no span is available by default. I think it's not worth it to add a reference to System.Memory-package just for this change.

* Better perf for StringUtilities.TryGetAsciiString

* Removed mutating of created string from HttpUtilities

* Use static readonly span-actions as this gives a boost due to not having a null check for the compiler generated cached delegate

* Debug Asserts

* PR Feedback
This commit is contained in:
Günther Foidl 2020-02-10 23:21:08 +01:00 committed by GitHub
parent 568e73e69a
commit aa7804c247
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 256 additions and 98 deletions

View File

@ -2,8 +2,10 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
@ -25,6 +27,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
private const ulong _http11VersionLong = 3543824036068086856; // GetAsciiStringAsLong("HTTP/1.1"); const results in better codegen
private static readonly UTF8EncodingSealed HeaderValueEncoding = new UTF8EncodingSealed();
private static readonly SpanAction<char, IntPtr> _getHeaderName = GetHeaderName;
private static readonly SpanAction<char, IntPtr> _getAsciiStringNonNullCharacters = GetAsciiStringNonNullCharacters;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static void SetKnownMethod(ulong mask, ulong knownMethodUlong, HttpMethod knownMethod, int length)
@ -81,6 +85,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
}
// The same as GetAsciiStringNonNullCharacters but throws BadRequest
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe string GetHeaderName(this ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
@ -88,25 +93,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
return string.Empty;
}
var asciiString = new string('\0', span.Length);
fixed (byte* source = &MemoryMarshal.GetReference(span))
{
return string.Create(span.Length, new IntPtr(source), _getHeaderName);
}
}
fixed (char* output = asciiString)
fixed (byte* buffer = span)
private static unsafe void GetHeaderName(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
BadHttpRequestException.Throw(RequestRejectionReason.InvalidCharactersInHeaderName);
}
}
return asciiString;
}
public static string GetAsciiStringNonNullCharacters(this Span<byte> span)
=> GetAsciiStringNonNullCharacters((ReadOnlySpan<byte>)span);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<byte> span)
{
if (span.IsEmpty)
@ -114,19 +123,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
return string.Empty;
}
var asciiString = new string('\0', span.Length);
fixed (byte* source = &MemoryMarshal.GetReference(span))
{
return string.Create(span.Length, new IntPtr(source), _getAsciiStringNonNullCharacters);
}
}
fixed (char* output = asciiString)
fixed (byte* buffer = span)
private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
throw new InvalidOperationException();
}
}
return asciiString;
}
public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
@ -139,14 +152,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
return string.Empty;
}
var resultString = new string('\0', span.Length);
fixed (char* output = resultString)
fixed (byte* buffer = span)
fixed (byte* source = &MemoryMarshal.GetReference(span))
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
var resultString = string.Create(span.Length, new IntPtr(source), s_getAsciiOrUtf8StringNonNullCharacters);
// If resultString is marked, perform UTF-8 encoding
if (resultString[0] == '\0')
{
// null characters are considered invalid
if (span.IndexOf((byte)0) != -1)
@ -156,15 +167,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
try
{
resultString = HeaderValueEncoding.GetString(buffer, span.Length);
resultString = HeaderValueEncoding.GetString(span);
}
catch (DecoderFallbackException)
{
throw new InvalidOperationException();
}
}
return resultString;
}
}
private static readonly SpanAction<char, IntPtr> s_getAsciiOrUtf8StringNonNullCharacters = GetAsciiOrUTF8StringNonNullCharacters;
private static unsafe void GetAsciiOrUTF8StringNonNullCharacters(Span<char> buffer, IntPtr state)
{
fixed (char* output = &MemoryMarshal.GetReference(buffer))
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
{
// Mark resultString for UTF-8 encoding
output[0] = '\0';
}
}
return resultString;
}
public static string GetAsciiStringEscaped(this Span<byte> span, int maxChars)
@ -283,7 +311,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
{
method = HttpMethod.Head;
}
else if(firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal))
else if (firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal))
{
method = HttpMethod.Post;
}
@ -294,7 +322,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
{
method = HttpMethod.Trace;
}
else if(firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal))
else if (firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal))
{
method = HttpMethod.Patch;
}

View File

@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.Http2Cat
internal class Http2Utilities : IHttpHeadersHandler
{
public static ReadOnlySpan<byte> ClientPreface => new byte[24] { (byte)'P', (byte)'R', (byte)'I', (byte)' ', (byte)'*', (byte)' ', (byte)'H', (byte)'T', (byte)'T', (byte)'P', (byte)'/', (byte)'2', (byte)'.', (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n', (byte)'S', (byte)'M', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
public static readonly int MaxRequestHeaderFieldSize = 16 * 1024;
public const int MaxRequestHeaderFieldSize = 16 * 1024;
public static readonly string FourKHeaderValue = new string('a', 4096);
private static readonly Encoding HeaderValueEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);

View File

@ -2,103 +2,100 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
using System.Text;
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
{
internal class StringUtilities
internal static class StringUtilities
{
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
{
// Calculate end position
Debug.Assert(input != null);
Debug.Assert(output != null);
var end = input + count;
// Start as valid
var isValid = true;
do
Debug.Assert((long)end >= Vector256<sbyte>.Count);
if (Sse2.IsSupported)
{
// If Vector not-accelerated or remaining less than vector size
if (!Vector.IsHardwareAccelerated || input > end - Vector<sbyte>.Count)
if (Avx2.IsSupported && input <= end - Vector256<sbyte>.Count)
{
if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination
Vector256<sbyte> zero = Vector256<sbyte>.Zero;
do
{
// 64-bit: Loop longs by default
while (input <= end - sizeof(long))
var vector = Avx.LoadVector256(input).AsSByte();
if (!CheckBytesInAsciiRange(vector, zero))
{
isValid &= CheckBytesInAsciiRange(((long*)input)[0]);
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
output[4] = (char)input[4];
output[5] = (char)input[5];
output[6] = (char)input[6];
output[7] = (char)input[7];
input += sizeof(long);
output += sizeof(long);
return false;
}
if (input <= end - sizeof(int))
{
isValid &= CheckBytesInAsciiRange(((int*)input)[0]);
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
var tmp0 = Avx2.UnpackLow(vector, zero);
var tmp1 = Avx2.UnpackHigh(vector, zero);
input += sizeof(int);
output += sizeof(int);
}
}
else
// Bring into the right order
var out0 = Avx2.Permute2x128(tmp0, tmp1, 0x20);
var out1 = Avx2.Permute2x128(tmp0, tmp1, 0x31);
Avx.Store((ushort*)output, out0.AsUInt16());
Avx.Store((ushort*)output + Vector256<ushort>.Count, out1.AsUInt16());
input += Vector256<sbyte>.Count;
output += Vector256<sbyte>.Count;
} while (input <= end - Vector256<sbyte>.Count);
if (input == end)
{
// 32-bit: Loop ints by default
while (input <= end - sizeof(int))
{
isValid &= CheckBytesInAsciiRange(((int*)input)[0]);
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
input += sizeof(int);
output += sizeof(int);
}
return true;
}
if (input <= end - sizeof(short))
{
isValid &= CheckBytesInAsciiRange(((short*)input)[0]);
output[0] = (char)input[0];
output[1] = (char)input[1];
input += sizeof(short);
output += sizeof(short);
}
if (input < end)
{
isValid &= CheckBytesInAsciiRange(((sbyte*)input)[0]);
output[0] = (char)input[0];
}
return isValid;
}
// do/while as entry condition already checked
do
if (input <= end - Vector128<sbyte>.Count)
{
Vector128<sbyte> zero = Vector128<sbyte>.Zero;
do
{
var vector = Sse2.LoadVector128(input).AsSByte();
if (!CheckBytesInAsciiRange(vector, zero))
{
return false;
}
var c0 = Sse2.UnpackLow(vector, zero).AsUInt16();
var c1 = Sse2.UnpackHigh(vector, zero).AsUInt16();
Sse2.Store((ushort*)output, c0);
Sse2.Store((ushort*)output + Vector128<ushort>.Count, c1);
input += Vector128<sbyte>.Count;
output += Vector128<sbyte>.Count;
} while (input <= end - Vector128<sbyte>.Count);
if (input == end)
{
return true;
}
}
}
else if (Vector.IsHardwareAccelerated)
{
while (input <= end - Vector<sbyte>.Count)
{
var vector = Unsafe.AsRef<Vector<sbyte>>(input);
isValid &= CheckBytesInAsciiRange(vector);
if (!CheckBytesInAsciiRange(vector))
{
return false;
}
Vector.Widen(
vector,
out Unsafe.AsRef<Vector<short>>(output),
@ -106,13 +103,127 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
input += Vector<sbyte>.Count;
output += Vector<sbyte>.Count;
} while (input <= end - Vector<sbyte>.Count);
}
// Vector path done, loop back to do non-Vector
// If is a exact multiple of vector size, bail now
} while (input < end);
if (input == end)
{
return true;
}
}
return isValid;
if (Environment.Is64BitProcess) // Use Intrinsic switch for branch elimination
{
// 64-bit: Loop longs by default
while (input <= end - sizeof(long))
{
var value = *(long*)input;
if (!CheckBytesInAsciiRange(value))
{
return false;
}
if (Bmi2.X64.IsSupported)
{
// BMI2 will work regardless of the processor's endianness.
((ulong*)output)[0] = Bmi2.X64.ParallelBitDeposit((ulong)value, 0x00FF00FF_00FF00FFul);
((ulong*)output)[1] = Bmi2.X64.ParallelBitDeposit((ulong)(value >> 32), 0x00FF00FF_00FF00FFul);
}
else
{
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
output[4] = (char)input[4];
output[5] = (char)input[5];
output[6] = (char)input[6];
output[7] = (char)input[7];
}
input += sizeof(long);
output += sizeof(long);
}
if (input <= end - sizeof(int))
{
var value = *(int*)input;
if (!CheckBytesInAsciiRange(value))
{
return false;
}
if (Bmi2.IsSupported)
{
// BMI2 will work regardless of the processor's endianness.
((uint*)output)[0] = Bmi2.ParallelBitDeposit((uint)value, 0x00FF00FFu);
((uint*)output)[1] = Bmi2.ParallelBitDeposit((uint)(value >> 16), 0x00FF00FFu);
}
else
{
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
}
input += sizeof(int);
output += sizeof(int);
}
}
else
{
// 32-bit: Loop ints by default
while (input <= end - sizeof(int))
{
var value = *(int*)input;
if (!CheckBytesInAsciiRange(value))
{
return false;
}
if (Bmi2.IsSupported)
{
// BMI2 will work regardless of the processor's endianness.
((uint*)output)[0] = Bmi2.ParallelBitDeposit((uint)value, 0x00FF00FFu);
((uint*)output)[1] = Bmi2.ParallelBitDeposit((uint)(value >> 16), 0x00FF00FFu);
}
else
{
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
}
input += sizeof(int);
output += sizeof(int);
}
}
if (input <= end - sizeof(short))
{
if (!CheckBytesInAsciiRange(((short*)input)[0]))
{
return false;
}
output[0] = (char)input[0];
output[1] = (char)input[1];
input += sizeof(short);
output += sizeof(short);
}
if (input < end)
{
if (!CheckBytesInAsciiRange(((sbyte*)input)[0]))
{
return false;
}
output[0] = (char)input[0];
}
return true;
}
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
@ -365,7 +476,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true).GetByteCount(value);
return !value.Contains('\0');
}
catch (DecoderFallbackException) {
catch (DecoderFallbackException)
{
return false;
}
}
@ -418,6 +530,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
return Vector.GreaterThanAll(check, Vector<sbyte>.Zero);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool CheckBytesInAsciiRange(Vector256<sbyte> check, Vector256<sbyte> zero)
{
Debug.Assert(Avx2.IsSupported);
var mask = Avx2.CompareGreaterThan(check, zero);
return (uint)Avx2.MoveMask(mask) == 0xFFFF_FFFF;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool CheckBytesInAsciiRange(Vector128<sbyte> check, Vector128<sbyte> zero)
{
Debug.Assert(Sse2.IsSupported);
var mask = Sse2.CompareGreaterThan(check, zero);
return Sse2.MoveMask(mask) == 0xFFFF;
}
// Validate: bytes != 0 && bytes <= 127
// Subtract 1 from all bytes to move 0 to high bits
// bitwise or with self to catch all > 127 bytes