From d475d41f7117f3820d422829d6a02952a91f5754 Mon Sep 17 00:00:00 2001 From: Stephen Halter Date: Wed, 26 Oct 2016 17:44:44 -0700 Subject: [PATCH] Fix boundary cases in MemoryPoolIterator.(Try)PeekLong() - Fix edge case where the iterator is at the very end of a block. - Fix edge case where one bits where improperly filled in on a right shift. - Don't use -1 to represent failure. Use bool and an out parameter instead. --- .../Infrastructure/MemoryPoolIterator.cs | 41 ++- .../MemoryPoolIteratorExtensions.cs | 56 ++-- .../MemoryPoolIteratorTests.cs | 265 +++++++++++++++--- 3 files changed, 289 insertions(+), 73 deletions(-) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIterator.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIterator.cs index b0b6837f65..5f7e908c36 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIterator.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIterator.cs @@ -180,38 +180,57 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure } while (true); } - public unsafe long PeekLong() + // NOTE: Little-endian only! + public unsafe bool TryPeekLong(out ulong longValue) { + longValue = 0; + if (_block == null) { - return -1; + return false; } var wasLastBlock = _block.Next == null; + var blockBytes = _block.End - _index; - if (_block.End - _index >= sizeof(long)) + if (blockBytes >= sizeof(ulong)) { - return *(long*)(_block.DataFixedPtr + _index); + longValue = *(ulong*)(_block.DataFixedPtr + _index); + return true; } else if (wasLastBlock) { - return -1; + return false; } else { - var blockBytes = _block.End - _index; - var nextBytes = sizeof(long) - blockBytes; + // Each block will be filled with at least 2048 bytes before the Next pointer is set, so a long + // will cross at most one block boundary assuming there are at least 8 bytes following the iterator. + var nextBytes = sizeof(ulong) - blockBytes; if (_block.Next.End - _block.Next.Start < nextBytes) { - return -1; + return false; } - var blockLong = *(long*)(_block.DataFixedPtr + _block.End - sizeof(long)); + var nextLong = *(ulong*)(_block.Next.DataFixedPtr + _block.Next.Start); - var nextLong = *(long*)(_block.Next.DataFixedPtr + _block.Next.Start); + if (blockBytes == 0) + { + // This case can not fall through to the else block since that would cause a 64-bit right shift + // on blockLong which is equivalent to no shift at all instead of shifting in all zeros. + // https://msdn.microsoft.com/en-us/library/xt18et0d.aspx + longValue = nextLong; + } + else + { + var blockLong = *(ulong*)(_block.DataFixedPtr + _block.End - sizeof(ulong)); - return (blockLong >> (sizeof(long) - blockBytes) * 8) | (nextLong << (sizeof(long) - nextBytes) * 8); + // Ensure that the right shift has a ulong operand so a logical shift is performed. + longValue = (blockLong >> nextBytes * 8) | (nextLong << blockBytes * 8); + } + + return true; } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs index 146d9f127a..d59f689662 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs @@ -17,26 +17,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure public const string Http11Version = "HTTP/1.1"; // readonly primitive statics can be Jit'd to consts https://github.com/dotnet/coreclr/issues/1079 - private readonly static long _httpConnectMethodLong = GetAsciiStringAsLong("CONNECT "); - private readonly static long _httpDeleteMethodLong = GetAsciiStringAsLong("DELETE \0"); - private readonly static long _httpGetMethodLong = GetAsciiStringAsLong("GET \0\0\0\0"); - private readonly static long _httpHeadMethodLong = GetAsciiStringAsLong("HEAD \0\0\0"); - private readonly static long _httpPatchMethodLong = GetAsciiStringAsLong("PATCH \0\0"); - private readonly static long _httpPostMethodLong = GetAsciiStringAsLong("POST \0\0\0"); - private readonly static long _httpPutMethodLong = GetAsciiStringAsLong("PUT \0\0\0\0"); - private readonly static long _httpOptionsMethodLong = GetAsciiStringAsLong("OPTIONS "); - private readonly static long _httpTraceMethodLong = GetAsciiStringAsLong("TRACE \0\0"); + private readonly static ulong _httpConnectMethodLong = GetAsciiStringAsLong("CONNECT "); + private readonly static ulong _httpDeleteMethodLong = GetAsciiStringAsLong("DELETE \0"); + private readonly static ulong _httpGetMethodLong = GetAsciiStringAsLong("GET \0\0\0\0"); + private readonly static ulong _httpHeadMethodLong = GetAsciiStringAsLong("HEAD \0\0\0"); + private readonly static ulong _httpPatchMethodLong = GetAsciiStringAsLong("PATCH \0\0"); + private readonly static ulong _httpPostMethodLong = GetAsciiStringAsLong("POST \0\0\0"); + private readonly static ulong _httpPutMethodLong = GetAsciiStringAsLong("PUT \0\0\0\0"); + private readonly static ulong _httpOptionsMethodLong = GetAsciiStringAsLong("OPTIONS "); + private readonly static ulong _httpTraceMethodLong = GetAsciiStringAsLong("TRACE \0\0"); - private readonly static long _http10VersionLong = GetAsciiStringAsLong("HTTP/1.0"); - private readonly static long _http11VersionLong = GetAsciiStringAsLong("HTTP/1.1"); + private readonly static ulong _http10VersionLong = GetAsciiStringAsLong("HTTP/1.0"); + private readonly static ulong _http11VersionLong = GetAsciiStringAsLong("HTTP/1.1"); - private readonly static long _mask8Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }); - private readonly static long _mask7Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00 }); - private readonly static long _mask6Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00 }); - private readonly static long _mask5Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00 }); - private readonly static long _mask4Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 }); + private readonly static ulong _mask8Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }); + private readonly static ulong _mask7Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00 }); + private readonly static ulong _mask6Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00 }); + private readonly static ulong _mask5Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00 }); + private readonly static ulong _mask4Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 }); - private readonly static Tuple[] _knownMethods = new Tuple[8]; + private readonly static Tuple[] _knownMethods = new Tuple[8]; static MemoryPoolIteratorExtensions() { @@ -50,7 +50,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure _knownMethods[7] = Tuple.Create(_mask8Chars, _httpOptionsMethodLong, HttpMethods.Options); } - private unsafe static long GetAsciiStringAsLong(string str) + private unsafe static ulong GetAsciiStringAsLong(string str) { Debug.Assert(str.Length == 8, "String must be exactly 8 (ASCII) characters long."); @@ -58,16 +58,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure fixed (byte* ptr = &bytes[0]) { - return *(long*)ptr; + return *(ulong*)ptr; } } - private unsafe static long GetMaskAsLong(byte[] bytes) + private unsafe static ulong GetMaskAsLong(byte[] bytes) { Debug.Assert(bytes.Length == 8, "Mask must be exactly 8 bytes long."); fixed (byte* ptr = bytes) { - return *(long*)ptr; + return *(ulong*)ptr; } } @@ -286,7 +286,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure public static bool GetKnownMethod(this MemoryPoolIterator begin, out string knownMethod) { knownMethod = null; - var value = begin.PeekLong(); + + ulong value; + if (!begin.TryPeekLong(out value)) + { + return false; + } if ((value & _mask4Chars) == _httpGetMethodLong) { @@ -321,7 +326,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure public static bool GetKnownVersion(this MemoryPoolIterator begin, out string knownVersion) { knownVersion = null; - var value = begin.PeekLong(); + + ulong value; + if (!begin.TryPeekLong(out value)) + { + return false; + } if (value == _http11VersionLong) { diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs index ae539b10f5..454f66291f 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs @@ -299,23 +299,23 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { // Arrange var block = _pool.Lease(); - var bytes = BitConverter.GetBytes(0x0102030405060708); + var bytes = BitConverter.GetBytes(0x0102030405060708UL); Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length); block.End += bytes.Length; var scan = block.GetIterator(); var originalIndex = scan.Index; - // Act - var result = scan.PeekLong(); - // Assert - Assert.Equal(0x0102030405060708, result); + ulong result; + Assert.True(scan.TryPeekLong(out result)); + Assert.Equal(0x0102030405060708UL, result); Assert.Equal(originalIndex, scan.Index); _pool.Return(block); } [Theory] + [InlineData(0)] [InlineData(1)] [InlineData(2)] [InlineData(3)] @@ -323,31 +323,57 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [InlineData(5)] [InlineData(6)] [InlineData(7)] - public void PeekLongAtBlockBoundary(int blockBytes) + public void PeekLongNotEnoughBytes(int totalBytes) { // Arrange - var nextBlockBytes = 8 - blockBytes; + var block = _pool.Lease(); + var bytes = BitConverter.GetBytes(0x0102030405060708UL); + var bytesLength = totalBytes; + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytesLength); + block.End += bytesLength; + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Assert + ulong result; + Assert.False(scan.TryPeekLong(out result)); + Assert.Equal(originalIndex, scan.Index); + _pool.Return(block); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + public void PeekLongNotEnoughBytesAtBlockBoundary(int firstBlockBytes) + { + // Arrange + var expectedResult = 0x0102030405060708UL; + var nextBlockBytes = 7 - firstBlockBytes; var block = _pool.Lease(); - block.End += blockBytes; + block.End += firstBlockBytes; var nextBlock = _pool.Lease(); nextBlock.End += nextBlockBytes; block.Next = nextBlock; - var bytes = BitConverter.GetBytes(0x0102030405060708); - Buffer.BlockCopy(bytes, 0, block.Array, block.Start, blockBytes); - Buffer.BlockCopy(bytes, blockBytes, nextBlock.Array, nextBlock.Start, nextBlockBytes); + var bytes = BitConverter.GetBytes(expectedResult); + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, firstBlockBytes); + Buffer.BlockCopy(bytes, firstBlockBytes, nextBlock.Array, nextBlock.Start, nextBlockBytes); var scan = block.GetIterator(); var originalIndex = scan.Index; - // Act - var result = scan.PeekLong(); - // Assert - Assert.Equal(0x0102030405060708, result); + ulong result; + Assert.False(scan.TryPeekLong(out result)); Assert.Equal(originalIndex, scan.Index); _pool.Return(block); @@ -355,6 +381,109 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + [InlineData(8)] + public void PeekLongAtBlockBoundary(int firstBlockBytes) + { + // Arrange + var expectedResult = 0x0102030405060708UL; + var nonZeroData = 0xFF00FFFF0000FFFFUL; + var nextBlockBytes = 8 - firstBlockBytes; + + var block = _pool.Lease(); + block.Start += 8; + block.End = block.Start + firstBlockBytes; + + var nextBlock = _pool.Lease(); + nextBlock.Start += 8; + nextBlock.End = nextBlock.Start + nextBlockBytes; + + block.Next = nextBlock; + + var bytes = BitConverter.GetBytes(expectedResult); + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, firstBlockBytes); + Buffer.BlockCopy(bytes, firstBlockBytes, nextBlock.Array, nextBlock.Start, nextBlockBytes); + + // Fill in surrounding bytes with non-zero data + var nonZeroBytes = BitConverter.GetBytes(nonZeroData); + Buffer.BlockCopy(nonZeroBytes, 0, block.Array, block.Start - 8, 8); + Buffer.BlockCopy(nonZeroBytes, 0, block.Array, block.End, 8); + Buffer.BlockCopy(nonZeroBytes, 0, nextBlock.Array, nextBlock.Start - 8, 8); + Buffer.BlockCopy(nonZeroBytes, 0, nextBlock.Array, nextBlock.End, 8); + + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Assert + ulong result; + Assert.True(scan.TryPeekLong(out result)); + Assert.Equal(expectedResult, result); + Assert.Equal(originalIndex, scan.Index); + + _pool.Return(block); + _pool.Return(nextBlock); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + [InlineData(8)] + public void PeekLongAtBlockBoundarayWithMostSignificatBitsSet(int firstBlockBytes) + { + // Arrange + var expectedResult = 0xFF02030405060708UL; + var nonZeroData = 0xFF00FFFF0000FFFFUL; + var nextBlockBytes = 8 - firstBlockBytes; + + var block = _pool.Lease(); + block.Start += 8; + block.End = block.Start + firstBlockBytes; + + var nextBlock = _pool.Lease(); + nextBlock.Start += 8; + nextBlock.End = nextBlock.Start + nextBlockBytes; + + block.Next = nextBlock; + + var expectedBytes = BitConverter.GetBytes(expectedResult); + Buffer.BlockCopy(expectedBytes, 0, block.Array, block.Start, firstBlockBytes); + Buffer.BlockCopy(expectedBytes, firstBlockBytes, nextBlock.Array, nextBlock.Start, nextBlockBytes); + + // Fill in surrounding bytes with non-zero data + var nonZeroBytes = BitConverter.GetBytes(nonZeroData); + Buffer.BlockCopy(nonZeroBytes, 0, block.Array, block.Start - 8, 8); + Buffer.BlockCopy(nonZeroBytes, 0, block.Array, block.End, 8); + Buffer.BlockCopy(nonZeroBytes, 0, nextBlock.Array, nextBlock.Start - 8, 8); + Buffer.BlockCopy(nonZeroBytes, 0, nextBlock.Array, nextBlock.End, 8); + + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Assert + ulong result; + Assert.True(scan.TryPeekLong(out result)); + Assert.Equal(expectedResult, result); + Assert.Equal(originalIndex, scan.Index); + + _pool.Return(block); + _pool.Return(nextBlock); + } + + [Theory] + [InlineData(0)] [InlineData(1)] [InlineData(2)] [InlineData(3)] @@ -432,23 +561,23 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Theory] - [InlineData("CONNECT / HTTP/1.1", ' ', true, "CONNECT")] - [InlineData("DELETE / HTTP/1.1", ' ', true, "DELETE")] - [InlineData("GET / HTTP/1.1", ' ', true, "GET")] - [InlineData("HEAD / HTTP/1.1", ' ', true, "HEAD")] - [InlineData("PATCH / HTTP/1.1", ' ', true, "PATCH")] - [InlineData("POST / HTTP/1.1", ' ', true, "POST")] - [InlineData("PUT / HTTP/1.1", ' ', true, "PUT")] - [InlineData("OPTIONS / HTTP/1.1", ' ', true, "OPTIONS")] - [InlineData("TRACE / HTTP/1.1", ' ', true, "TRACE")] - [InlineData("GET/ HTTP/1.1", ' ', false, null)] - [InlineData("get / HTTP/1.1", ' ', false, null)] - [InlineData("GOT / HTTP/1.1", ' ', false, null)] - [InlineData("ABC / HTTP/1.1", ' ', false, null)] - [InlineData("PO / HTTP/1.1", ' ', false, null)] - [InlineData("PO ST / HTTP/1.1", ' ', false, null)] - [InlineData("short ", ' ', false, null)] - public void GetsKnownMethod(string input, char endChar, bool expectedResult, string expectedKnownString) + [InlineData("CONNECT / HTTP/1.1", true, "CONNECT")] + [InlineData("DELETE / HTTP/1.1", true, "DELETE")] + [InlineData("GET / HTTP/1.1", true, "GET")] + [InlineData("HEAD / HTTP/1.1", true, "HEAD")] + [InlineData("PATCH / HTTP/1.1", true, "PATCH")] + [InlineData("POST / HTTP/1.1", true, "POST")] + [InlineData("PUT / HTTP/1.1", true, "PUT")] + [InlineData("OPTIONS / HTTP/1.1", true, "OPTIONS")] + [InlineData("TRACE / HTTP/1.1", true, "TRACE")] + [InlineData("GET/ HTTP/1.1", false, null)] + [InlineData("get / HTTP/1.1", false, null)] + [InlineData("GOT / HTTP/1.1", false, null)] + [InlineData("ABC / HTTP/1.1", false, null)] + [InlineData("PO / HTTP/1.1", false, null)] + [InlineData("PO ST / HTTP/1.1", false, null)] + [InlineData("short ", false, null)] + public void GetsKnownMethod(string input, bool expectedResult, string expectedKnownString) { // Arrange var block = _pool.Lease(); @@ -465,17 +594,46 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal(expectedResult, result); Assert.Equal(expectedKnownString, knownString); + // Test at boundary + var maxSplit = Math.Min(input.Length, 8); + var nextBlock = _pool.Lease(); + + for (var split = 0; split <= maxSplit; split++) + { + // Arrange + block.Reset(); + nextBlock.Reset(); + + Buffer.BlockCopy(chars, 0, block.Array, block.Start, split); + Buffer.BlockCopy(chars, split, nextBlock.Array, nextBlock.Start, chars.Length - split); + + block.End += split; + nextBlock.End += chars.Length - split; + block.Next = nextBlock; + + var boundaryBegin = block.GetIterator(); + string boundaryKnownString; + + // Act + var boundaryResult = boundaryBegin.GetKnownMethod(out boundaryKnownString); + + // Assert + Assert.Equal(expectedResult, boundaryResult); + Assert.Equal(expectedKnownString, boundaryKnownString); + } + _pool.Return(block); + _pool.Return(nextBlock); } [Theory] - [InlineData("HTTP/1.0\r", '\r', true, MemoryPoolIteratorExtensions.Http10Version)] - [InlineData("HTTP/1.1\r", '\r', true, MemoryPoolIteratorExtensions.Http11Version)] - [InlineData("HTTP/3.0\r", '\r', false, null)] - [InlineData("http/1.0\r", '\r', false, null)] - [InlineData("http/1.1\r", '\r', false, null)] - [InlineData("short ", ' ', false, null)] - public void GetsKnownVersion(string input, char endChar, bool expectedResult, string expectedKnownString) + [InlineData("HTTP/1.0\r", true, MemoryPoolIteratorExtensions.Http10Version)] + [InlineData("HTTP/1.1\r", true, MemoryPoolIteratorExtensions.Http11Version)] + [InlineData("HTTP/3.0\r", false, null)] + [InlineData("http/1.0\r", false, null)] + [InlineData("http/1.1\r", false, null)] + [InlineData("short ", false, null)] + public void GetsKnownVersion(string input, bool expectedResult, string expectedKnownString) { // Arrange var block = _pool.Lease(); @@ -491,7 +649,36 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal(expectedResult, result); Assert.Equal(expectedKnownString, knownString); + // Test at boundary + var maxSplit = Math.Min(input.Length, 9); + var nextBlock = _pool.Lease(); + + for (var split = 0; split <= maxSplit; split++) + { + // Arrange + block.Reset(); + nextBlock.Reset(); + + Buffer.BlockCopy(chars, 0, block.Array, block.Start, split); + Buffer.BlockCopy(chars, split, nextBlock.Array, nextBlock.Start, chars.Length - split); + + block.End += split; + nextBlock.End += chars.Length - split; + block.Next = nextBlock; + + var boundaryBegin = block.GetIterator(); + string boundaryKnownString; + + // Act + var boundaryResult = boundaryBegin.GetKnownVersion(out boundaryKnownString); + + // Assert + Assert.Equal(expectedResult, boundaryResult); + Assert.Equal(expectedKnownString, boundaryKnownString); + } + _pool.Return(block); + _pool.Return(nextBlock); } [Theory]