Don't mutate strings in Kestrel (#17556)
* Removed mutating of string-contents in BCryptHandle * Revert "Removed mutating of string-contents in BCryptHandle" This reverts commit 5ae80c2834471baf34d1e5a05a42e3cce1ff02d7. This is a .NET STandard 2.0 project, so no span is available by default. I think it's not worth it to add a reference to System.Memory-package just for this change. * Better perf for StringUtilities.TryGetAsciiString * Removed mutating of created string from HttpUtilities * Use static readonly span-actions as this gives a boost due to not having a null check for the compiler generated cached delegate * Debug Asserts * PR Feedback
This commit is contained in:
parent
568e73e69a
commit
aa7804c247
|
|
@ -2,8 +2,10 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers;
|
||||
using System.Diagnostics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Text;
|
||||
using Microsoft.AspNetCore.Http;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
|
||||
|
|
@ -25,6 +27,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
private const ulong _http11VersionLong = 3543824036068086856; // GetAsciiStringAsLong("HTTP/1.1"); const results in better codegen
|
||||
|
||||
private static readonly UTF8EncodingSealed HeaderValueEncoding = new UTF8EncodingSealed();
|
||||
private static readonly SpanAction<char, IntPtr> _getHeaderName = GetHeaderName;
|
||||
private static readonly SpanAction<char, IntPtr> _getAsciiStringNonNullCharacters = GetAsciiStringNonNullCharacters;
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static void SetKnownMethod(ulong mask, ulong knownMethodUlong, HttpMethod knownMethod, int length)
|
||||
|
|
@ -81,6 +85,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
}
|
||||
|
||||
// The same as GetAsciiStringNonNullCharacters but throws BadRequest
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static unsafe string GetHeaderName(this ReadOnlySpan<byte> span)
|
||||
{
|
||||
if (span.IsEmpty)
|
||||
|
|
@ -88,25 +93,29 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
return string.Empty;
|
||||
}
|
||||
|
||||
var asciiString = new string('\0', span.Length);
|
||||
fixed (byte* source = &MemoryMarshal.GetReference(span))
|
||||
{
|
||||
return string.Create(span.Length, new IntPtr(source), _getHeaderName);
|
||||
}
|
||||
}
|
||||
|
||||
fixed (char* output = asciiString)
|
||||
fixed (byte* buffer = span)
|
||||
private static unsafe void GetHeaderName(Span<char> buffer, IntPtr state)
|
||||
{
|
||||
fixed (char* output = &MemoryMarshal.GetReference(buffer))
|
||||
{
|
||||
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
|
||||
// in the string
|
||||
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
|
||||
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
|
||||
{
|
||||
BadHttpRequestException.Throw(RequestRejectionReason.InvalidCharactersInHeaderName);
|
||||
}
|
||||
}
|
||||
|
||||
return asciiString;
|
||||
}
|
||||
|
||||
public static string GetAsciiStringNonNullCharacters(this Span<byte> span)
|
||||
=> GetAsciiStringNonNullCharacters((ReadOnlySpan<byte>)span);
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
public static unsafe string GetAsciiStringNonNullCharacters(this ReadOnlySpan<byte> span)
|
||||
{
|
||||
if (span.IsEmpty)
|
||||
|
|
@ -114,19 +123,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
return string.Empty;
|
||||
}
|
||||
|
||||
var asciiString = new string('\0', span.Length);
|
||||
fixed (byte* source = &MemoryMarshal.GetReference(span))
|
||||
{
|
||||
return string.Create(span.Length, new IntPtr(source), _getAsciiStringNonNullCharacters);
|
||||
}
|
||||
}
|
||||
|
||||
fixed (char* output = asciiString)
|
||||
fixed (byte* buffer = span)
|
||||
private static unsafe void GetAsciiStringNonNullCharacters(Span<char> buffer, IntPtr state)
|
||||
{
|
||||
fixed (char* output = &MemoryMarshal.GetReference(buffer))
|
||||
{
|
||||
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
|
||||
// in the string
|
||||
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
|
||||
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
|
||||
{
|
||||
throw new InvalidOperationException();
|
||||
}
|
||||
}
|
||||
return asciiString;
|
||||
}
|
||||
|
||||
public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
|
||||
|
|
@ -139,14 +152,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
return string.Empty;
|
||||
}
|
||||
|
||||
var resultString = new string('\0', span.Length);
|
||||
|
||||
fixed (char* output = resultString)
|
||||
fixed (byte* buffer = span)
|
||||
fixed (byte* source = &MemoryMarshal.GetReference(span))
|
||||
{
|
||||
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
|
||||
// in the string
|
||||
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
|
||||
var resultString = string.Create(span.Length, new IntPtr(source), s_getAsciiOrUtf8StringNonNullCharacters);
|
||||
|
||||
// If resultString is marked, perform UTF-8 encoding
|
||||
if (resultString[0] == '\0')
|
||||
{
|
||||
// null characters are considered invalid
|
||||
if (span.IndexOf((byte)0) != -1)
|
||||
|
|
@ -156,15 +167,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
|
||||
try
|
||||
{
|
||||
resultString = HeaderValueEncoding.GetString(buffer, span.Length);
|
||||
resultString = HeaderValueEncoding.GetString(span);
|
||||
}
|
||||
catch (DecoderFallbackException)
|
||||
{
|
||||
throw new InvalidOperationException();
|
||||
}
|
||||
}
|
||||
|
||||
return resultString;
|
||||
}
|
||||
}
|
||||
|
||||
private static readonly SpanAction<char, IntPtr> s_getAsciiOrUtf8StringNonNullCharacters = GetAsciiOrUTF8StringNonNullCharacters;
|
||||
|
||||
private static unsafe void GetAsciiOrUTF8StringNonNullCharacters(Span<char> buffer, IntPtr state)
|
||||
{
|
||||
fixed (char* output = &MemoryMarshal.GetReference(buffer))
|
||||
{
|
||||
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
|
||||
// in the string
|
||||
if (!StringUtilities.TryGetAsciiString((byte*)state.ToPointer(), output, buffer.Length))
|
||||
{
|
||||
// Mark resultString for UTF-8 encoding
|
||||
output[0] = '\0';
|
||||
}
|
||||
}
|
||||
return resultString;
|
||||
}
|
||||
|
||||
public static string GetAsciiStringEscaped(this Span<byte> span, int maxChars)
|
||||
|
|
@ -283,7 +311,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
{
|
||||
method = HttpMethod.Head;
|
||||
}
|
||||
else if(firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal))
|
||||
else if (firstChar == 'P' && string.Equals(value, HttpMethods.Post, StringComparison.Ordinal))
|
||||
{
|
||||
method = HttpMethod.Post;
|
||||
}
|
||||
|
|
@ -294,7 +322,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
{
|
||||
method = HttpMethod.Trace;
|
||||
}
|
||||
else if(firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal))
|
||||
else if (firstChar == 'P' && string.Equals(value, HttpMethods.Patch, StringComparison.Ordinal))
|
||||
{
|
||||
method = HttpMethod.Patch;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.Http2Cat
|
|||
internal class Http2Utilities : IHttpHeadersHandler
|
||||
{
|
||||
public static ReadOnlySpan<byte> ClientPreface => new byte[24] { (byte)'P', (byte)'R', (byte)'I', (byte)' ', (byte)'*', (byte)' ', (byte)'H', (byte)'T', (byte)'T', (byte)'P', (byte)'/', (byte)'2', (byte)'.', (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n', (byte)'S', (byte)'M', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' };
|
||||
public static readonly int MaxRequestHeaderFieldSize = 16 * 1024;
|
||||
public const int MaxRequestHeaderFieldSize = 16 * 1024;
|
||||
public static readonly string FourKHeaderValue = new string('a', 4096);
|
||||
private static readonly Encoding HeaderValueEncoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
|
||||
|
||||
|
|
|
|||
|
|
@ -2,103 +2,100 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Buffers.Binary;
|
||||
using System.Diagnostics;
|
||||
using System.Numerics;
|
||||
using System.Runtime.CompilerServices;
|
||||
using System.Runtime.InteropServices;
|
||||
using System.Runtime.Intrinsics;
|
||||
using System.Runtime.Intrinsics.X86;
|
||||
using System.Text;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
||||
{
|
||||
internal class StringUtilities
|
||||
internal static class StringUtilities
|
||||
{
|
||||
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
|
||||
public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
|
||||
{
|
||||
// Calculate end position
|
||||
Debug.Assert(input != null);
|
||||
Debug.Assert(output != null);
|
||||
|
||||
var end = input + count;
|
||||
// Start as valid
|
||||
var isValid = true;
|
||||
|
||||
do
|
||||
Debug.Assert((long)end >= Vector256<sbyte>.Count);
|
||||
|
||||
if (Sse2.IsSupported)
|
||||
{
|
||||
// If Vector not-accelerated or remaining less than vector size
|
||||
if (!Vector.IsHardwareAccelerated || input > end - Vector<sbyte>.Count)
|
||||
if (Avx2.IsSupported && input <= end - Vector256<sbyte>.Count)
|
||||
{
|
||||
if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination
|
||||
Vector256<sbyte> zero = Vector256<sbyte>.Zero;
|
||||
|
||||
do
|
||||
{
|
||||
// 64-bit: Loop longs by default
|
||||
while (input <= end - sizeof(long))
|
||||
var vector = Avx.LoadVector256(input).AsSByte();
|
||||
if (!CheckBytesInAsciiRange(vector, zero))
|
||||
{
|
||||
isValid &= CheckBytesInAsciiRange(((long*)input)[0]);
|
||||
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
output[2] = (char)input[2];
|
||||
output[3] = (char)input[3];
|
||||
output[4] = (char)input[4];
|
||||
output[5] = (char)input[5];
|
||||
output[6] = (char)input[6];
|
||||
output[7] = (char)input[7];
|
||||
|
||||
input += sizeof(long);
|
||||
output += sizeof(long);
|
||||
return false;
|
||||
}
|
||||
if (input <= end - sizeof(int))
|
||||
{
|
||||
isValid &= CheckBytesInAsciiRange(((int*)input)[0]);
|
||||
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
output[2] = (char)input[2];
|
||||
output[3] = (char)input[3];
|
||||
var tmp0 = Avx2.UnpackLow(vector, zero);
|
||||
var tmp1 = Avx2.UnpackHigh(vector, zero);
|
||||
|
||||
input += sizeof(int);
|
||||
output += sizeof(int);
|
||||
}
|
||||
}
|
||||
else
|
||||
// Bring into the right order
|
||||
var out0 = Avx2.Permute2x128(tmp0, tmp1, 0x20);
|
||||
var out1 = Avx2.Permute2x128(tmp0, tmp1, 0x31);
|
||||
|
||||
Avx.Store((ushort*)output, out0.AsUInt16());
|
||||
Avx.Store((ushort*)output + Vector256<ushort>.Count, out1.AsUInt16());
|
||||
|
||||
input += Vector256<sbyte>.Count;
|
||||
output += Vector256<sbyte>.Count;
|
||||
} while (input <= end - Vector256<sbyte>.Count);
|
||||
|
||||
if (input == end)
|
||||
{
|
||||
// 32-bit: Loop ints by default
|
||||
while (input <= end - sizeof(int))
|
||||
{
|
||||
isValid &= CheckBytesInAsciiRange(((int*)input)[0]);
|
||||
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
output[2] = (char)input[2];
|
||||
output[3] = (char)input[3];
|
||||
|
||||
input += sizeof(int);
|
||||
output += sizeof(int);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (input <= end - sizeof(short))
|
||||
{
|
||||
isValid &= CheckBytesInAsciiRange(((short*)input)[0]);
|
||||
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
|
||||
input += sizeof(short);
|
||||
output += sizeof(short);
|
||||
}
|
||||
if (input < end)
|
||||
{
|
||||
isValid &= CheckBytesInAsciiRange(((sbyte*)input)[0]);
|
||||
output[0] = (char)input[0];
|
||||
}
|
||||
|
||||
return isValid;
|
||||
}
|
||||
|
||||
// do/while as entry condition already checked
|
||||
do
|
||||
if (input <= end - Vector128<sbyte>.Count)
|
||||
{
|
||||
Vector128<sbyte> zero = Vector128<sbyte>.Zero;
|
||||
|
||||
do
|
||||
{
|
||||
var vector = Sse2.LoadVector128(input).AsSByte();
|
||||
if (!CheckBytesInAsciiRange(vector, zero))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
var c0 = Sse2.UnpackLow(vector, zero).AsUInt16();
|
||||
var c1 = Sse2.UnpackHigh(vector, zero).AsUInt16();
|
||||
|
||||
Sse2.Store((ushort*)output, c0);
|
||||
Sse2.Store((ushort*)output + Vector128<ushort>.Count, c1);
|
||||
|
||||
input += Vector128<sbyte>.Count;
|
||||
output += Vector128<sbyte>.Count;
|
||||
} while (input <= end - Vector128<sbyte>.Count);
|
||||
|
||||
if (input == end)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
else if (Vector.IsHardwareAccelerated)
|
||||
{
|
||||
while (input <= end - Vector<sbyte>.Count)
|
||||
{
|
||||
var vector = Unsafe.AsRef<Vector<sbyte>>(input);
|
||||
isValid &= CheckBytesInAsciiRange(vector);
|
||||
if (!CheckBytesInAsciiRange(vector))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
Vector.Widen(
|
||||
vector,
|
||||
out Unsafe.AsRef<Vector<short>>(output),
|
||||
|
|
@ -106,13 +103,127 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
|
||||
input += Vector<sbyte>.Count;
|
||||
output += Vector<sbyte>.Count;
|
||||
} while (input <= end - Vector<sbyte>.Count);
|
||||
}
|
||||
|
||||
// Vector path done, loop back to do non-Vector
|
||||
// If is a exact multiple of vector size, bail now
|
||||
} while (input < end);
|
||||
if (input == end)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return isValid;
|
||||
if (Environment.Is64BitProcess) // Use Intrinsic switch for branch elimination
|
||||
{
|
||||
// 64-bit: Loop longs by default
|
||||
while (input <= end - sizeof(long))
|
||||
{
|
||||
var value = *(long*)input;
|
||||
if (!CheckBytesInAsciiRange(value))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Bmi2.X64.IsSupported)
|
||||
{
|
||||
// BMI2 will work regardless of the processor's endianness.
|
||||
((ulong*)output)[0] = Bmi2.X64.ParallelBitDeposit((ulong)value, 0x00FF00FF_00FF00FFul);
|
||||
((ulong*)output)[1] = Bmi2.X64.ParallelBitDeposit((ulong)(value >> 32), 0x00FF00FF_00FF00FFul);
|
||||
}
|
||||
else
|
||||
{
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
output[2] = (char)input[2];
|
||||
output[3] = (char)input[3];
|
||||
output[4] = (char)input[4];
|
||||
output[5] = (char)input[5];
|
||||
output[6] = (char)input[6];
|
||||
output[7] = (char)input[7];
|
||||
}
|
||||
|
||||
input += sizeof(long);
|
||||
output += sizeof(long);
|
||||
}
|
||||
|
||||
if (input <= end - sizeof(int))
|
||||
{
|
||||
var value = *(int*)input;
|
||||
if (!CheckBytesInAsciiRange(value))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Bmi2.IsSupported)
|
||||
{
|
||||
// BMI2 will work regardless of the processor's endianness.
|
||||
((uint*)output)[0] = Bmi2.ParallelBitDeposit((uint)value, 0x00FF00FFu);
|
||||
((uint*)output)[1] = Bmi2.ParallelBitDeposit((uint)(value >> 16), 0x00FF00FFu);
|
||||
}
|
||||
else
|
||||
{
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
output[2] = (char)input[2];
|
||||
output[3] = (char)input[3];
|
||||
}
|
||||
|
||||
input += sizeof(int);
|
||||
output += sizeof(int);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// 32-bit: Loop ints by default
|
||||
while (input <= end - sizeof(int))
|
||||
{
|
||||
var value = *(int*)input;
|
||||
if (!CheckBytesInAsciiRange(value))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Bmi2.IsSupported)
|
||||
{
|
||||
// BMI2 will work regardless of the processor's endianness.
|
||||
((uint*)output)[0] = Bmi2.ParallelBitDeposit((uint)value, 0x00FF00FFu);
|
||||
((uint*)output)[1] = Bmi2.ParallelBitDeposit((uint)(value >> 16), 0x00FF00FFu);
|
||||
}
|
||||
else
|
||||
{
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
output[2] = (char)input[2];
|
||||
output[3] = (char)input[3];
|
||||
}
|
||||
|
||||
input += sizeof(int);
|
||||
output += sizeof(int);
|
||||
}
|
||||
}
|
||||
|
||||
if (input <= end - sizeof(short))
|
||||
{
|
||||
if (!CheckBytesInAsciiRange(((short*)input)[0]))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
output[0] = (char)input[0];
|
||||
output[1] = (char)input[1];
|
||||
|
||||
input += sizeof(short);
|
||||
output += sizeof(short);
|
||||
}
|
||||
|
||||
if (input < end)
|
||||
{
|
||||
if (!CheckBytesInAsciiRange(((sbyte*)input)[0]))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
output[0] = (char)input[0];
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
|
||||
|
|
@ -365,7 +476,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true).GetByteCount(value);
|
||||
return !value.Contains('\0');
|
||||
}
|
||||
catch (DecoderFallbackException) {
|
||||
catch (DecoderFallbackException)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -418,6 +530,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure
|
|||
return Vector.GreaterThanAll(check, Vector<sbyte>.Zero);
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static bool CheckBytesInAsciiRange(Vector256<sbyte> check, Vector256<sbyte> zero)
|
||||
{
|
||||
Debug.Assert(Avx2.IsSupported);
|
||||
|
||||
var mask = Avx2.CompareGreaterThan(check, zero);
|
||||
return (uint)Avx2.MoveMask(mask) == 0xFFFF_FFFF;
|
||||
}
|
||||
|
||||
[MethodImpl(MethodImplOptions.AggressiveInlining)]
|
||||
private static bool CheckBytesInAsciiRange(Vector128<sbyte> check, Vector128<sbyte> zero)
|
||||
{
|
||||
Debug.Assert(Sse2.IsSupported);
|
||||
|
||||
var mask = Sse2.CompareGreaterThan(check, zero);
|
||||
return Sse2.MoveMask(mask) == 0xFFFF;
|
||||
}
|
||||
|
||||
// Validate: bytes != 0 && bytes <= 127
|
||||
// Subtract 1 from all bytes to move 0 to high bits
|
||||
// bitwise or with self to catch all > 127 bytes
|
||||
|
|
|
|||
Loading…
Reference in New Issue