diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index a69c5fc0d2..36c93192b0 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -271,10 +271,10 @@ namespace Microsoft.AspNet.Server.Kestrel.Http _requestProcessingStarted = true; _requestProcessingTask = Task.Factory.StartNew( - (o) => ((Frame)o).RequestProcessingAsync(), - this, - CancellationToken.None, - TaskCreationOptions.DenyChildAttach, + (o) => ((Frame)o).RequestProcessingAsync(), + this, + CancellationToken.None, + TaskCreationOptions.DenyChildAttach, TaskScheduler.Default); } } @@ -710,19 +710,18 @@ namespace Microsoft.AspNet.Server.Kestrel.Http var consumed = scan; try { - var begin = scan; - if (scan.Seek(_vectorSpaces) == -1) - { - return false; - } - string method; - if (!begin.GetKnownString(scan, out method)) + var begin = scan; + if (!begin.GetKnownMethod(ref scan,out method)) { + if (scan.Seek(_vectorSpaces) == -1) + { + return false; + } method = begin.GetAsciiString(scan); + scan.Take(); } - scan.Take(); begin = scan; var needDecode = false; @@ -749,18 +748,19 @@ namespace Microsoft.AspNet.Server.Kestrel.Http scan.Take(); begin = scan; - if (scan.Seek(_vectorCRs) == -1) - { - return false; - } string httpVersion; - if (!begin.GetKnownString(scan, out httpVersion)) + if (!begin.GetKnownVersion(ref scan, out httpVersion)) { + scan = begin; + if (scan.Seek(_vectorCRs) == -1) + { + return false; + } httpVersion = begin.GetAsciiString(scan); - } - scan.Take(); + scan.Take(); + } if (scan.Take() != '\n') { return false; diff --git a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs index a63fafc149..645e0279d7 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs @@ -94,6 +94,43 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure } while (true); } + public void Skip(int bytesToSkip) + { + if (_block == null) + { + return; + } + var following = _block.End - _index; + if (following >= bytesToSkip) + { + _index += bytesToSkip; + return; + } + + var block = _block; + var index = _index; + while (true) + { + if (block.Next == null) + { + return; + } + else + { + bytesToSkip -= following; + block = block.Next; + index = block.Start; + } + following = block.End - index; + if (following >= bytesToSkip) + { + _block = block; + _index = index + bytesToSkip; + return; + } + } + } + public int Peek() { var block = _block; diff --git a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs index b91f455cb4..4b8c66dbbd 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2Extensions.cs @@ -27,35 +27,37 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure public const string Http11Version = "HTTP/1.1"; // readonly primitive statics can be Jit'd to consts https://github.com/dotnet/coreclr/issues/1079 - private readonly static long _httpConnectMethodLong = GetAsciiStringAsLong("CONNECT\0"); - private readonly static long _httpDeleteMethodLong = GetAsciiStringAsLong("DELETE\0\0"); - private readonly static long _httpGetMethodLong = GetAsciiStringAsLong("GET\0\0\0\0\0"); - private readonly static long _httpHeadMethodLong = GetAsciiStringAsLong("HEAD\0\0\0\0"); - private readonly static long _httpPatchMethodLong = GetAsciiStringAsLong("PATCH\0\0\0"); - private readonly static long _httpPostMethodLong = GetAsciiStringAsLong("POST\0\0\0\0"); - private readonly static long _httpPutMethodLong = GetAsciiStringAsLong("PUT\0\0\0\0\0"); - private readonly static long _httpOptionsMethodLong = GetAsciiStringAsLong("OPTIONS\0"); - private readonly static long _httpTraceMethodLong = GetAsciiStringAsLong("TRACE\0\0\0"); + private readonly static long _httpConnectMethodLong = GetAsciiStringAsLong("CONNECT "); + private readonly static long _httpDeleteMethodLong = GetAsciiStringAsLong("DELETE \0"); + private readonly static long _httpGetMethodLong = GetAsciiStringAsLong("GET \0\0\0\0"); + private readonly static long _httpHeadMethodLong = GetAsciiStringAsLong("HEAD \0\0\0"); + private readonly static long _httpPatchMethodLong = GetAsciiStringAsLong("PATCH \0\0"); + private readonly static long _httpPostMethodLong = GetAsciiStringAsLong("POST \0\0\0"); + private readonly static long _httpPutMethodLong = GetAsciiStringAsLong("PUT \0\0\0\0"); + private readonly static long _httpOptionsMethodLong = GetAsciiStringAsLong("OPTIONS "); + private readonly static long _httpTraceMethodLong = GetAsciiStringAsLong("TRACE \0\0"); private readonly static long _http10VersionLong = GetAsciiStringAsLong("HTTP/1.0"); private readonly static long _http11VersionLong = GetAsciiStringAsLong("HTTP/1.1"); + + private readonly static long _mask8Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }); + private readonly static long _mask7Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00 }); + private readonly static long _mask6Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00 }); + private readonly static long _mask5Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00 }); + private readonly static long _mask4Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00 }); - private const int PerfectHashDivisor = 37; - private static readonly Tuple[] _knownStrings = new Tuple[PerfectHashDivisor]; + private readonly static Tuple[] _knownMethods = new Tuple[8]; static MemoryPoolIterator2Extensions() { - _knownStrings[_httpConnectMethodLong % PerfectHashDivisor] = Tuple.Create(_httpConnectMethodLong, HttpConnectMethod); - _knownStrings[_httpDeleteMethodLong % PerfectHashDivisor] = Tuple.Create(_httpDeleteMethodLong, HttpDeleteMethod); - _knownStrings[_httpGetMethodLong % PerfectHashDivisor] = Tuple.Create(_httpGetMethodLong, HttpGetMethod); - _knownStrings[_httpHeadMethodLong % PerfectHashDivisor] = Tuple.Create(_httpHeadMethodLong, HttpHeadMethod); - _knownStrings[_httpPatchMethodLong % PerfectHashDivisor] = Tuple.Create(_httpPatchMethodLong, HttpPatchMethod); - _knownStrings[_httpPostMethodLong % PerfectHashDivisor] = Tuple.Create(_httpPostMethodLong, HttpPostMethod); - _knownStrings[_httpPutMethodLong % PerfectHashDivisor] = Tuple.Create(_httpPutMethodLong, HttpPutMethod); - _knownStrings[_httpOptionsMethodLong % PerfectHashDivisor] = Tuple.Create(_httpOptionsMethodLong, HttpOptionsMethod); - _knownStrings[_httpTraceMethodLong % PerfectHashDivisor] = Tuple.Create(_httpTraceMethodLong, HttpTraceMethod); - _knownStrings[_http10VersionLong % PerfectHashDivisor] = Tuple.Create(_http10VersionLong, Http10Version); - _knownStrings[_http11VersionLong % PerfectHashDivisor] = Tuple.Create(_http11VersionLong, Http11Version); + _knownMethods[0] = Tuple.Create(_mask4Chars, _httpPutMethodLong, HttpPutMethod); + _knownMethods[1] = Tuple.Create(_mask5Chars, _httpPostMethodLong, HttpPostMethod); + _knownMethods[2] = Tuple.Create(_mask5Chars, _httpHeadMethodLong, HttpHeadMethod); + _knownMethods[3] = Tuple.Create(_mask6Chars, _httpTraceMethodLong, HttpTraceMethod); + _knownMethods[4] = Tuple.Create(_mask6Chars, _httpPatchMethodLong, HttpPatchMethod); + _knownMethods[5] = Tuple.Create(_mask7Chars, _httpDeleteMethodLong, HttpDeleteMethod); + _knownMethods[6] = Tuple.Create(_mask8Chars, _httpConnectMethodLong, HttpConnectMethod); + _knownMethods[7] = Tuple.Create(_mask8Chars, _httpOptionsMethodLong, HttpOptionsMethod); } private unsafe static long GetAsciiStringAsLong(string str) @@ -69,6 +71,15 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure return *(long*)ptr; } } + private unsafe static long GetMaskAsLong(byte[] bytes) + { + Debug.Assert(bytes.Length == 8, "Mask must be exactly 8 bytes long."); + + fixed (byte* ptr = bytes) + { + return *(long*)ptr; + } + } private static unsafe string GetAsciiStringStack(byte[] input, int inputOffset, int length) { @@ -264,53 +275,84 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure } /// - /// Checks that up to 8 bytes between and correspond to a known HTTP string. + /// Checks that up to 8 bytes from correspond to a known HTTP method. /// /// - /// A "known HTTP string" can be an HTTP method name defined in the HTTP/1.1 RFC or an HTTP version (HTTP/1.0 or HTTP/1.1). + /// A "known HTTP method" can be an HTTP method name defined in the HTTP/1.1 RFC. /// Since all of those fit in at most 8 bytes, they can be optimally looked up by reading those bytes as a long. Once - /// in that format, uninteresting bits are cleared and the remaining long modulo 37 is looked up in a table. - /// The number 37 was chosen because that number allows for a perfect hash of the set of - /// "known strings" (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE, HTTP/1.0 and HTTP/1.1, where strings - /// with less than 8 characters have 0s appended to their ends to fill for the missing bytes). + /// in that format, it can be checked against the known method. + /// The Known Methods (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE) are all less than 8 bytes + /// and will be compared with the required space. A mask is used if the Known method is less than 8 bytes. + /// To optimize performance the GET method will be checked first. /// /// The iterator from which to start the known string lookup. - /// The iterator pointing to the end of the input string. - /// A reference to a pre-allocated known string, if the input matches any. + /// If we found a valid method, then scan will be updated to new position + /// A reference to a pre-allocated known string, if the input matches any. /// true if the input matches a known string, false otherwise. - public static bool GetKnownString(this MemoryPoolIterator2 begin, MemoryPoolIterator2 end, out string knownString) + public static bool GetKnownMethod(this MemoryPoolIterator2 begin, ref MemoryPoolIterator2 scan, out string knownMethod) { - knownString = null; + knownMethod = null; + var value = begin.PeekLong(); - // This optimization only works on little endian environments (for now). - if (!BitConverter.IsLittleEndian) + if ((value & _mask4Chars) == _httpGetMethodLong) { - return false; + knownMethod = HttpGetMethod; + scan.Skip(4); + return true; + } + foreach (var x in _knownMethods) + { + if ((value & x.Item1) == x.Item2) + { + knownMethod = x.Item3; + scan.Skip(knownMethod.Length + 1); + return true; + } } - var inputLength = begin.GetLength(end); + return false; + } - if (inputLength > sizeof(long)) + /// + /// Checks 9 bytes from correspond to a known HTTP version. + /// + /// + /// A "known HTTP version" Is is either HTTP/1.0 or HTTP/1.1. + /// Since those fit in 8 bytes, they can be optimally looked up by reading those bytes as a long. Once + /// in that format, it can be checked against the known versions. + /// The Known versions will be checked with the required '\r'. + /// To optimize performance the HTTP/1.1 will be checked first. + /// + /// The iterator from which to start the known string lookup. + /// If we found a valid method, then scan will be updated to new position + /// A reference to a pre-allocated known string, if the input matches any. + /// true if the input matches a known string, false otherwise. + public static bool GetKnownVersion(this MemoryPoolIterator2 begin, ref MemoryPoolIterator2 scan, out string knownVersion) + { + knownVersion = null; + var value = begin.PeekLong(); + + if (value == _http11VersionLong) { - return false; + knownVersion = Http11Version; + scan.Skip(8); + if (scan.Take() == '\r') + { + return true; + } + } + else if (value == _http10VersionLong) + { + knownVersion = Http10Version; + scan.Skip(8); + if (scan.Take() == '\r') + { + return true; + } } - var inputLong = begin.PeekLong(); - - if (inputLong == -1) - { - return false; - } - - inputLong &= (long)(unchecked((ulong)~0) >> ((sizeof(long) - inputLength) * 8)); - - var value = _knownStrings[inputLong % PerfectHashDivisor]; - if (value != null && value.Item1 == inputLong) - { - knownString = value.Item2; - } - - return knownString != null; + knownVersion = null; + return false; } } -} +} \ No newline at end of file diff --git a/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs b/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs index 67c4a7e6ae..d7b9f494eb 100644 --- a/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs +++ b/test/Microsoft.AspNet.Server.KestrelTests/MemoryPoolIterator2Tests.cs @@ -164,7 +164,7 @@ namespace Microsoft.AspNet.Server.KestrelTests var block = _pool.Lease(); block.End += blockBytes; - + var nextBlock = _pool.Lease(); nextBlock.End += nextBlockBytes; @@ -185,6 +185,45 @@ namespace Microsoft.AspNet.Server.KestrelTests Assert.Equal(originalIndex, scan.Index); } + [Theory] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + [InlineData(5)] + [InlineData(6)] + [InlineData(7)] + [InlineData(8)] + [InlineData(9)] + public void SkipAtBlockBoundary(int blockBytes) + { + // Arrange + var nextBlockBytes = 10 - blockBytes; + + var block = _pool.Lease(); + block.End += blockBytes; + + var nextBlock = _pool.Lease(); + nextBlock.End += nextBlockBytes; + + block.Next = nextBlock; + + var bytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 }; + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, blockBytes); + Buffer.BlockCopy(bytes, blockBytes, nextBlock.Array, nextBlock.Start, nextBlockBytes); + + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Act + scan.Skip(8); + var result = scan.Take(); + + // Assert + Assert.Equal(0x08, result); + Assert.NotEqual(originalIndex, scan.Index); + } + [Theory] [InlineData("CONNECT / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpConnectMethod)] [InlineData("DELETE / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpDeleteMethod)] @@ -195,38 +234,55 @@ namespace Microsoft.AspNet.Server.KestrelTests [InlineData("PUT / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpPutMethod)] [InlineData("OPTIONS / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpOptionsMethod)] [InlineData("TRACE / HTTP/1.1", ' ', true, MemoryPoolIterator2Extensions.HttpTraceMethod)] - [InlineData("HTTP/1.0\r", '\r', true, MemoryPoolIterator2Extensions.Http10Version)] - [InlineData("HTTP/1.1\r", '\r', true, MemoryPoolIterator2Extensions.Http11Version)] [InlineData("GET/ HTTP/1.1", ' ', false, null)] [InlineData("get / HTTP/1.1", ' ', false, null)] [InlineData("GOT / HTTP/1.1", ' ', false, null)] [InlineData("ABC / HTTP/1.1", ' ', false, null)] [InlineData("PO / HTTP/1.1", ' ', false, null)] [InlineData("PO ST / HTTP/1.1", ' ', false, null)] - [InlineData("HTTP/1.0_\r", '\r', false, null)] - [InlineData("HTTP/1.1_\r", '\r', false, null)] - [InlineData("HTTP/3.0\r", '\r', false, null)] - [InlineData("http/1.0\r", '\r', false, null)] - [InlineData("http/1.1\r", '\r', false, null)] [InlineData("short ", ' ', false, null)] - public void GetsKnownString(string input, char endChar, bool expectedResult, string expectedKnownString) + public void GetsKnownMethod(string input, char endChar, bool expectedResult, string expectedKnownString) { // Arrange var block = _pool.Lease(); var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); block.End += chars.Length; - var begin = block.GetIterator(); - var end = begin; - end.Seek(new Vector((byte)endChar)); + var scan = block.GetIterator(); + var begin = scan; string knownString; // Act - var result = begin.GetKnownString(end, out knownString); + var result = begin.GetKnownMethod(ref scan, out knownString); // Assert Assert.Equal(expectedResult, result); Assert.Equal(expectedKnownString, knownString); } + + [Theory] + [InlineData("HTTP/1.0\r", '\r', true, MemoryPoolIterator2Extensions.Http10Version)] + [InlineData("HTTP/1.1\r", '\r', true, MemoryPoolIterator2Extensions.Http11Version)] + [InlineData("HTTP/3.0\r", '\r', false, null)] + [InlineData("http/1.0\r", '\r', false, null)] + [InlineData("http/1.1\r", '\r', false, null)] + [InlineData("short ", ' ', false, null)] + public void GetsKnownVersion(string input, char endChar, bool expectedResult, string expectedKnownString) + { + // Arrange + var block = _pool.Lease(); + var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); + Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); + block.End += chars.Length; + var scan = block.GetIterator(); + var begin = scan; + string knownString; + + // Act + var result = begin.GetKnownVersion(ref scan, out knownString); + // Assert + Assert.Equal(expectedResult, result); + Assert.Equal(expectedKnownString, knownString); + } } -} +} \ No newline at end of file