diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index ba4a7ba84d..426f17b283 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -701,7 +701,12 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { return false; } - var method = begin.GetAsciiString(scan); + + string method; + if (!begin.GetKnownString(scan, out method)) + { + method = begin.GetAsciiString(scan); + } scan.Take(); begin = scan; @@ -734,7 +739,12 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { return false; } - var httpVersion = begin.GetAsciiString(scan); + + string httpVersion; + if (!begin.GetKnownString(scan, out httpVersion)) + { + httpVersion = begin.GetAsciiString(scan); + } scan.Take(); if (scan.Take() != '\n') diff --git a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs index 41551d056c..524c7a0278 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs @@ -138,6 +138,49 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure } } + public unsafe long PeekLong() + { + if (_block == null) + { + return -1; + } + else if (_block.End - _index >= sizeof(long)) + { + fixed (byte* ptr = _block.Array) + { + return *(long*)(ptr + _index); + } + } + else if (_block.Next == null) + { + return -1; + } + else + { + var blockBytes = _block.End - _index; + var nextBytes = sizeof(long) - blockBytes; + + if (_block.Next.End - _block.Next.Start < nextBytes) + { + return -1; + } + + long blockLong; + fixed (byte* ptr = _block.Array) + { + blockLong = *(long*)(ptr + _block.End - sizeof(long)); + } + + long nextLong; + fixed (byte* ptr = _block.Next.Array) + { + nextLong = *(long*)(ptr + _block.Next.Start); + } + + return (blockLong >> (sizeof(long) - blockBytes) * 8) | (nextLong << (sizeof(long) - nextBytes) * 8); + } + } + public int Seek(int char0) { if (IsDefault) diff --git a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs index d1f83c3d6e..1fbd76c5a8 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.Text; namespace Microsoft.AspNet.Server.Kestrel.Infrastructure @@ -12,6 +13,62 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure private static Encoding _utf8 = Encoding.UTF8; + public const string HttpConnectMethod = "CONNECT"; + public const string HttpDeleteMethod = "DELETE"; + public const string HttpGetMethod = "GET"; + public const string HttpHeadMethod = "HEAD"; + public const string HttpPatchMethod = "PATCH"; + public const string HttpPostMethod = "POST"; + public const string HttpPutMethod = "PUT"; + public const string HttpOptionsMethod = "OPTIONS"; + public const string HttpTraceMethod = "TRACE"; + + public const string Http10Version = "HTTP/1.0"; + public const string Http11Version = "HTTP/1.1"; + + private static long _httpConnectMethodLong = GetAsciiStringAsLong("CONNECT\0"); + private static long _httpDeleteMethodLong = GetAsciiStringAsLong("DELETE\0\0"); + private static long _httpGetMethodLong = GetAsciiStringAsLong("GET\0\0\0\0\0"); + private static long _httpHeadMethodLong = GetAsciiStringAsLong("HEAD\0\0\0\0"); + private static long _httpPatchMethodLong = GetAsciiStringAsLong("PATCH\0\0\0"); + private static long _httpPostMethodLong = GetAsciiStringAsLong("POST\0\0\0\0"); + private static long _httpPutMethodLong = GetAsciiStringAsLong("PUT\0\0\0\0\0"); + private static long _httpOptionsMethodLong = GetAsciiStringAsLong("OPTIONS\0"); + private static long _httpTraceMethodLong = GetAsciiStringAsLong("TRACE\0\0\0"); + + private static long _http10VersionLong = GetAsciiStringAsLong("HTTP/1.0"); + private static long _http11VersionLong = GetAsciiStringAsLong("HTTP/1.1"); + + private const int PerfectHashDivisor = 37; + private static Tuple[] _knownStrings = new Tuple[PerfectHashDivisor]; + + static MemoryPoolIterator2Extensions() + { + _knownStrings[_httpConnectMethodLong % PerfectHashDivisor] = Tuple.Create(_httpConnectMethodLong, HttpConnectMethod); + _knownStrings[_httpDeleteMethodLong % PerfectHashDivisor] = Tuple.Create(_httpDeleteMethodLong, HttpDeleteMethod); + _knownStrings[_httpGetMethodLong % PerfectHashDivisor] = Tuple.Create(_httpGetMethodLong, HttpGetMethod); + _knownStrings[_httpHeadMethodLong % PerfectHashDivisor] = Tuple.Create(_httpHeadMethodLong, HttpHeadMethod); + _knownStrings[_httpPatchMethodLong % PerfectHashDivisor] = Tuple.Create(_httpPatchMethodLong, HttpPatchMethod); + _knownStrings[_httpPostMethodLong % PerfectHashDivisor] = Tuple.Create(_httpPostMethodLong, HttpPostMethod); + _knownStrings[_httpPutMethodLong % PerfectHashDivisor] = Tuple.Create(_httpPutMethodLong, HttpPutMethod); + _knownStrings[_httpOptionsMethodLong % PerfectHashDivisor] = Tuple.Create(_httpOptionsMethodLong, HttpOptionsMethod); + _knownStrings[_httpTraceMethodLong % PerfectHashDivisor] = Tuple.Create(_httpTraceMethodLong, HttpTraceMethod); + _knownStrings[_http10VersionLong % PerfectHashDivisor] = Tuple.Create(_http10VersionLong, Http10Version); + _knownStrings[_http11VersionLong % PerfectHashDivisor] = Tuple.Create(_http11VersionLong, Http11Version); + } + + private unsafe static long GetAsciiStringAsLong(string str) + { + Debug.Assert(str.Length == 8, "String must be exactly 8 (ASCII) characters long."); + + var bytes = Encoding.ASCII.GetBytes(str); + + fixed (byte* ptr = bytes) + { + return *(long*)ptr; + } + } + private static unsafe string GetAsciiStringStack(byte[] input, int inputOffset, int length) { // avoid declaring other local vars, or doing work with stackalloc @@ -20,6 +77,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure return GetAsciiStringImplementation(output, input, inputOffset, length); } + private static unsafe string GetAsciiStringImplementation(char* output, byte[] input, int inputOffset, int length) { for (var i = 0; i < length; i++) @@ -203,5 +261,55 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure start.CopyTo(array, 0, length, out length); return new ArraySegment(array, 0, length); } + + /// + /// Checks that up to 8 bytes between and correspond to a known HTTP string. + /// + /// + /// A "known HTTP string" can be an HTTP method name defined in the HTTP/1.1 RFC or an HTTP version (HTTP/1.0 or HTTP/1.1). + /// Since all of those fit in at most 8 bytes, they can be optimally looked up by reading those bytes as a long. Once + /// in that format, uninteresting bits are cleared and the remaining long modulo 37 is looked up in a table. + /// The number 37 was chosen because that number allows for a perfect hash of the set of + /// "known strings" (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE, HTTP/1.0 and HTTP/1.1, where strings + /// with less than 8 characters have 0s appended to their ends to fill for the missing bytes). + /// + /// The iterator from which to start the known string lookup. + /// The iterator pointing to the end of the input string. + /// A reference to a pre-allocated known string, if the input matches any. + /// true if the input matches a known string, false otherwise. + public static bool GetKnownString(this MemoryPoolIterator2 begin, MemoryPoolIterator2 end, out string knownString) + { + knownString = null; + + // This optimization only works on little endian environments (for now). + if (!BitConverter.IsLittleEndian) + { + return false; + } + + var inputLength = begin.GetLength(end); + + if (inputLength > sizeof(long)) + { + return false; + } + + var inputLong = begin.PeekLong(); + + if (inputLong == -1) + { + return false; + } + + inputLong &= (long)(unchecked((ulong)~0) >> ((sizeof(long) - inputLength) * 8)); + + var value = _knownStrings[inputLong % PerfectHashDivisor]; + if (value != null && value.Item1 == inputLong) + { + knownString = value.Item2; + } + + return knownString != null; + } } } diff --git a/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs b/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs index b7d101c28c..7368bf08d9 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs @@ -128,5 +128,104 @@ namespace Microsoft.AspNet.Server.KestrelTests // Can't put anything by the end Assert.False(head.Put(0xFF)); } + + [Fact] + public void PeekLong() + { + // Arrange + var block = _pool.Lease(); + var bytes = BitConverter.GetBytes(0x0102030405060708); + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length); + block.End += bytes.Length; + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Act + var result = scan.PeekLong(); + + // Assert + Assert.Equal(0x0102030405060708, result); + Assert.Equal(originalIndex, scan.Index); + } + + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + public void PeekLongAtBlockBoundary(int blockBytes) + { + // Arrange + var nextBlockBytes = 8 - blockBytes; + + var block = _pool.Lease(); + block.End += blockBytes; + + var nextBlock = _pool.Lease(); + nextBlock.End += nextBlockBytes; + + block.Next = nextBlock; + + var bytes = BitConverter.GetBytes(0x0102030405060708); + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, blockBytes); + Buffer.BlockCopy(bytes, blockBytes, nextBlock.Array, nextBlock.Start, nextBlockBytes); + + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Act + var result = scan.PeekLong(); + + // Assert + Assert.Equal(0x0102030405060708, result); + Assert.Equal(originalIndex, scan.Index); + } + + [Theory] + [InlineData("CONNECT / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpConnectMethod)] + [InlineData("DELETE / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpDeleteMethod)] + [InlineData("GET / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpGetMethod)] + [InlineData("HEAD / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpHeadMethod)] + [InlineData("PATCH / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpPatchMethod)] + [InlineData("POST / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpPostMethod)] + [InlineData("PUT / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpPutMethod)] + [InlineData("OPTIONS / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpOptionsMethod)] + [InlineData("TRACE / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpTraceMethod)] + [InlineData("HTTP/1.0\r", '\r', true, MemoryPoolIterator2Extensions.Http10Version)] + [InlineData("HTTP/1.1\r", '\r', true, MemoryPoolIterator2Extensions.Http11Version)] + [InlineData("GET/ HTTP/1.1", ' ', false, null)] + [InlineData("get / HTTP/1.1", ' ', false, null)] + [InlineData("GOT / HTTP/1.1", ' ', false, null)] + [InlineData("ABC / HTTP/1.1", ' ', false, null)] + [InlineData("PO / HTTP/1.1", ' ', false, null)] + [InlineData("PO ST / HTTP/1.1", ' ', false, null)] + [InlineData("HTTP/1.0_\r", '\r', false, null)] + [InlineData("HTTP/1.1_\r", '\r', false, null)] + [InlineData("HTTP/3.0\r", '\r', false, null)] + [InlineData("http/1.0\r", '\r', false, null)] + [InlineData("http/1.1\r", '\r', false, null)] + [InlineData("short ", ' ', false, null)] + public void GetsKnownString(string input, char endChar, bool expectedResult, string expectedKnownString) + { + // Arrange + var block = _pool.Lease(); + var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); + Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); + block.End += chars.Length; + var begin = block.GetIterator(); + var end = begin; + end.Seek(endChar); + string knownString; + + // Act + var result = begin.GetKnownString(end, out knownString); + + // Assert + Assert.Equal(expectedResult, result); + Assert.Equal(expectedKnownString, knownString); + } } }