Convert TakeStartLine and TakeMessageHeaders to be state machines (#1401)

- Fewer passes over the buffer
- Single pass to find all start line delimiters instead
of calling IndexOf multiple times.
- Made TakeStartLine and TakeMessageHeaders state machines
- Only check length against remaining bytes once
- Change variable names to match TakeStartLine
- Use ReadableBuffer.First.Span instead of ToSpan()
- Added test for missing path with a querystring
This commit is contained in:
David Fowler 2017-02-27 11:55:30 -08:00 committed by GitHub
parent 5692f51bf7
commit c6705d8693
2 changed files with 327 additions and 200 deletions

View File

@ -982,10 +982,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
Output.ProducingComplete(end);
}
public bool TakeStartLine(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined)
public unsafe bool TakeStartLine(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined)
{
var start = buffer.Start;
var end = buffer.Start;
var bufferEnd = buffer.End;
examined = buffer.End;
consumed = buffer.Start;
@ -997,19 +998,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_requestProcessingStatus = RequestProcessingStatus.RequestStarted;
ReadableBuffer limitedBuffer;
var overLength = false;
if (buffer.Length >= ServerOptions.Limits.MaxRequestLineSize)
{
limitedBuffer = buffer.Slice(0, ServerOptions.Limits.MaxRequestLineSize);
}
else
{
limitedBuffer = buffer;
bufferEnd = buffer.Move(start, ServerOptions.Limits.MaxRequestLineSize);
overLength = true;
}
if (ReadCursorOperations.Seek(limitedBuffer.Start, limitedBuffer.End, out end, ByteLF) == -1)
if (ReadCursorOperations.Seek(start, bufferEnd, out end, ByteLF) == -1)
{
if (limitedBuffer.Length == ServerOptions.Limits.MaxRequestLineSize)
if (overLength)
{
RejectRequest(RequestRejectionReason.RequestLineTooLong);
}
@ -1021,26 +1020,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
const int stackAllocLimit = 512;
// Move 1 byte past the \r
end = limitedBuffer.Move(end, 1);
var startLineBuffer = limitedBuffer.Slice(0, end);
// Move 1 byte past the \n
end = buffer.Move(end, 1);
var startLineBuffer = buffer.Slice(start, end);
Span<byte> span;
if (startLineBuffer.IsSingleSpan)
{
// No copies, directly use the one and only span
span = startLineBuffer.ToSpan();
span = startLineBuffer.First.Span;
}
else if (startLineBuffer.Length < stackAllocLimit)
{
unsafe
{
// Multiple buffers and < stackAllocLimit, copy into a stack buffer
byte* stackBuffer = stackalloc byte[startLineBuffer.Length];
span = new Span<byte>(stackBuffer, startLineBuffer.Length);
startLineBuffer.CopyTo(span);
}
// Multiple buffers and < stackAllocLimit, copy into a stack buffer
byte* stackBuffer = stackalloc byte[startLineBuffer.Length];
span = new Span<byte>(stackBuffer, startLineBuffer.Length);
startLineBuffer.CopyTo(span);
}
else
{
@ -1049,113 +1045,163 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
startLineBuffer.CopyTo(span);
}
var methodEnd = 0;
if (!span.GetKnownMethod(out string method))
var needDecode = false;
var pathStart = -1;
var queryStart = -1;
var queryEnd = -1;
var pathEnd = -1;
var versionStart = -1;
var queryString = "";
var httpVersion = "";
var method = "";
var state = StartLineState.KnownMethod;
fixed (byte* data = &span.DangerousGetPinnableReference())
{
methodEnd = span.IndexOf(ByteSpace);
if (methodEnd == -1)
var length = span.Length;
for (var i = 0; i < length; i++)
{
RejectRequestLine(start, end);
}
var ch = data[i];
method = span.Slice(0, methodEnd).GetAsciiString();
if (method == null)
{
RejectRequestLine(start, end);
}
// Note: We're not in the fast path any more (GetKnownMethod should have handled any HTTP Method we're aware of)
// So we can be a tiny bit slower and more careful here.
for (int i = 0; i < method.Length; i++)
{
if (!IsValidTokenChar(method[i]))
switch (state)
{
RejectRequestLine(start, end);
case StartLineState.KnownMethod:
if (span.GetKnownMethod(out method))
{
// Update the index, current char, state and jump directly
// to the next state
i += method.Length + 1;
ch = data[i];
state = StartLineState.Path;
goto case StartLineState.Path;
}
state = StartLineState.UnknownMethod;
goto case StartLineState.UnknownMethod;
case StartLineState.UnknownMethod:
if (ch == ByteSpace)
{
method = span.Slice(0, i).GetAsciiString();
if (method == null)
{
RejectRequestLine(start, end);
}
state = StartLineState.Path;
}
else if (!IsValidTokenChar((char)ch))
{
RejectRequestLine(start, end);
}
break;
case StartLineState.Path:
if (ch == ByteSpace)
{
pathEnd = i;
if (pathStart == -1)
{
// Empty path is illegal
RejectRequestLine(start, end);
}
// No query string found
queryStart = queryEnd = i;
state = StartLineState.KnownVersion;
}
else if (ch == ByteQuestionMark)
{
pathEnd = i;
if (pathStart == -1)
{
// Empty path is illegal
RejectRequestLine(start, end);
}
queryStart = i;
state = StartLineState.QueryString;
}
else if (ch == BytePercentage)
{
needDecode = true;
}
if (pathStart == -1)
{
pathStart = i;
}
break;
case StartLineState.QueryString:
if (ch == ByteSpace)
{
queryEnd = i;
state = StartLineState.KnownVersion;
queryString = span.Slice(queryStart, queryEnd - queryStart).GetAsciiString() ?? string.Empty;
}
break;
case StartLineState.KnownVersion:
// REVIEW: We don't *need* to slice here but it makes the API
// nicer, slicing should be free :)
if (span.Slice(i).GetKnownVersion(out httpVersion))
{
// Update the index, current char, state and jump directly
// to the next state
i += httpVersion.Length + 1;
ch = data[i];
state = StartLineState.NewLine;
goto case StartLineState.NewLine;
}
versionStart = i;
state = StartLineState.UnknownVersion;
goto case StartLineState.UnknownVersion;
case StartLineState.UnknownVersion:
if (ch == ByteCR)
{
var versionSpan = span.Slice(versionStart, i - versionStart);
if (versionSpan.Length == 0)
{
RejectRequestLine(start, end);
}
else
{
RejectRequest(RequestRejectionReason.UnrecognizedHTTPVersion, versionSpan.GetAsciiStringEscaped());
}
}
break;
case StartLineState.NewLine:
if (ch != ByteLF)
{
RejectRequestLine(start, end);
}
state = StartLineState.Complete;
break;
case StartLineState.Complete:
break;
default:
break;
}
}
}
else
{
methodEnd += method.Length;
}
var needDecode = false;
var pathBegin = methodEnd + 1;
var pathToEndSpan = span.Slice(pathBegin, span.Length - pathBegin);
pathBegin = 0;
// TODO: IndexOfAny
var spaceIndex = pathToEndSpan.IndexOf(ByteSpace);
var questionMarkIndex = pathToEndSpan.IndexOf(ByteQuestionMark);
var percentageIndex = pathToEndSpan.IndexOf(BytePercentage);
var pathEnd = MinNonZero(spaceIndex, questionMarkIndex, percentageIndex);
if (spaceIndex == -1 && questionMarkIndex == -1 && percentageIndex == -1)
{
RejectRequestLine(start, end);
}
else if (percentageIndex != -1)
{
needDecode = true;
pathEnd = MinNonZero(spaceIndex, questionMarkIndex);
if (questionMarkIndex == -1 && spaceIndex == -1)
{
RejectRequestLine(start, end);
}
}
var queryString = "";
var queryEnd = pathEnd;
if (questionMarkIndex != -1)
{
queryEnd = spaceIndex;
if (spaceIndex == -1)
{
RejectRequestLine(start, end);
}
queryString = pathToEndSpan.Slice(pathEnd, queryEnd - pathEnd).GetAsciiString();
}
if (pathBegin == pathEnd)
if (state != StartLineState.Complete)
{
RejectRequestLine(start, end);
}
var versionBegin = queryEnd + 1;
var versionToEndSpan = pathToEndSpan.Slice(versionBegin, pathToEndSpan.Length - versionBegin);
versionBegin = 0;
var versionEnd = versionToEndSpan.IndexOf(ByteCR);
if (versionEnd == -1)
{
RejectRequestLine(start, end);
}
if (!versionToEndSpan.GetKnownVersion(out string httpVersion))
{
httpVersion = versionToEndSpan.Slice(0, versionEnd).GetAsciiStringEscaped();
if (httpVersion == string.Empty)
{
RejectRequestLine(start, end);
}
else
{
RejectRequest(RequestRejectionReason.UnrecognizedHTTPVersion, httpVersion);
}
}
if (versionToEndSpan[versionEnd + 1] != ByteLF)
{
RejectRequestLine(start, end);
}
var pathBuffer = pathToEndSpan.Slice(pathBegin, pathEnd);
var targetBuffer = pathToEndSpan.Slice(pathBegin, queryEnd);
var pathBuffer = span.Slice(pathStart, pathEnd - pathStart);
var targetBuffer = span.Slice(pathStart, queryEnd - pathStart);
// URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11
// Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8;
@ -1198,8 +1244,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
RawTarget = rawTarget;
HttpVersion = httpVersion;
bool caseMatches;
if (RequestUrlStartsWithPathBase(normalizedTarget, out caseMatches))
if (RequestUrlStartsWithPathBase(normalizedTarget, out bool caseMatches))
{
PathBase = caseMatches ? _pathBase : normalizedTarget.Substring(0, _pathBase.Length);
Path = normalizedTarget.Substring(_pathBase.Length);
@ -1215,24 +1260,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
QueryString = string.Empty;
}
return true;
}
// Returns the smaller of v1 and v2, where a value of -1 is treated as "absent":
// it is mapped to int.MaxValue so that it can never be chosen as the minimum.
// If both arguments are -1 the result is int.MaxValue.
// NOTE(review): despite the name, the sentinel that is skipped is -1, not zero;
// the visible callers pass IndexOf results — confirm -1 is their "not found" value.
private int MinNonZero(int v1, int v2)
{
// Replace the -1 sentinel with the largest int so Math.Min ignores it
v1 = v1 == -1 ? int.MaxValue : v1;
v2 = v2 == -1 ? int.MaxValue : v2;
return Math.Min(v1, v2);
}
// Three-argument variant: returns the smallest of v1, v2 and v3, where a value
// of -1 is treated as "absent" and mapped to int.MaxValue so it is never chosen.
// If all three arguments are -1 the result is int.MaxValue.
// NOTE(review): despite the name, the sentinel that is skipped is -1, not zero.
private int MinNonZero(int v1, int v2, int v3)
{
// Replace the -1 sentinel with the largest int so Math.Min ignores it
v1 = v1 == -1 ? int.MaxValue : v1;
v2 = v2 == -1 ? int.MaxValue : v2;
v3 = v3 == -1 ? int.MaxValue : v3;
return Math.Min(Math.Min(v1, v2), v3);
}
private void RejectRequestLine(ReadCursor start, ReadCursor end)
{
const int MaxRequestLineError = 32;
@ -1297,13 +1328,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
return true;
}
public bool TakeMessageHeaders(ReadableBuffer buffer, FrameRequestHeaders requestHeaders, out ReadCursor consumed, out ReadCursor examined)
public unsafe bool TakeMessageHeaders(ReadableBuffer buffer, FrameRequestHeaders requestHeaders, out ReadCursor consumed, out ReadCursor examined)
{
consumed = buffer.Start;
examined = buffer.End;
var bufferLength = buffer.Length;
var bufferEnd = buffer.End;
var reader = new ReadableBufferReader(buffer);
// Make sure the buffer is limited
var overLength = false;
if (buffer.Length >= _remainingRequestHeadersBytesAllowed)
{
bufferEnd = buffer.Move(consumed, _remainingRequestHeadersBytesAllowed);
// If we sliced, it means the current buffer is bigger than what we're
// allowed to look at
overLength = true;
}
while (true)
{
@ -1349,24 +1391,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
// Reset the reader since we're not at the end of headers
reader = start;
// Now parse a single header
ReadableBuffer limitedBuffer;
var overLength = false;
if (bufferLength >= _remainingRequestHeadersBytesAllowed)
{
limitedBuffer = buffer.Slice(consumed, _remainingRequestHeadersBytesAllowed);
// If we sliced it means the current buffer bigger than what we're
// allowed to look at
overLength = true;
}
else
{
limitedBuffer = buffer;
}
if (ReadCursorOperations.Seek(consumed, limitedBuffer.End, out var lineEnd, ByteLF) == -1)
if (ReadCursorOperations.Seek(consumed, bufferEnd, out var lineEnd, ByteLF) == -1)
{
// We didn't find a \n in the current buffer and we had to slice it so it's an issue
if (overLength)
@ -1381,28 +1406,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
const int stackAllocLimit = 512;
if (lineEnd != limitedBuffer.End)
if (lineEnd != bufferEnd)
{
lineEnd = limitedBuffer.Move(lineEnd, 1);
lineEnd = buffer.Move(lineEnd, 1);
}
var headerBuffer = limitedBuffer.Slice(consumed, lineEnd);
var headerBuffer = buffer.Slice(consumed, lineEnd);
Span<byte> span;
if (headerBuffer.IsSingleSpan)
{
// No copies, directly use the one and only span
span = headerBuffer.ToSpan();
span = headerBuffer.First.Span;
}
else if (headerBuffer.Length < stackAllocLimit)
{
unsafe
{
// Multiple buffers and < stackAllocLimit, copy into a stack buffer
byte* stackBuffer = stackalloc byte[headerBuffer.Length];
span = new Span<byte>(stackBuffer, headerBuffer.Length);
headerBuffer.CopyTo(span);
}
// Multiple buffers and < stackAllocLimit, copy into a stack buffer
byte* stackBuffer = stackalloc byte[headerBuffer.Length];
span = new Span<byte>(stackBuffer, headerBuffer.Length);
headerBuffer.CopyTo(span);
}
else
{
@ -1411,59 +1433,139 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
headerBuffer.CopyTo(span);
}
int endNameIndex = span.IndexOf(ByteColon);
if (endNameIndex == -1)
var state = HeaderState.Name;
var nameStart = 0;
var nameEnd = -1;
var valueStart = -1;
var valueEnd = -1;
var nameHasWhitespace = false;
var previouslyWhitespace = false;
var headerLineLength = span.Length;
fixed (byte* data = &span.DangerousGetPinnableReference())
{
for (var i = 0; i < headerLineLength; i++)
{
var ch = data[i];
switch (state)
{
case HeaderState.Name:
if (ch == ByteColon)
{
if (nameHasWhitespace)
{
RejectRequest(RequestRejectionReason.WhitespaceIsNotAllowedInHeaderName);
}
state = HeaderState.Whitespace;
nameEnd = i;
}
if (ch == ByteSpace || ch == ByteTab)
{
nameHasWhitespace = true;
}
break;
case HeaderState.Whitespace:
{
var whitespace = ch == ByteTab || ch == ByteSpace || ch == ByteCR;
if (!whitespace)
{
// Mark the first non-whitespace char as the start of the
// header value and change the state to expect the header value
valueStart = i;
state = HeaderState.ExpectValue;
}
// If we see a CR then jump to the next state directly
else if (ch == ByteCR)
{
state = HeaderState.ExpectValue;
goto case HeaderState.ExpectValue;
}
}
break;
case HeaderState.ExpectValue:
{
var whitespace = ch == ByteTab || ch == ByteSpace;
if (whitespace)
{
if (!previouslyWhitespace)
{
// If we see a whitespace char then maybe it's end of the
// header value
valueEnd = i;
}
}
else if (ch == ByteCR)
{
// If we see a CR and we haven't ever seen whitespace then
// this is the end of the header value
if (valueEnd == -1)
{
valueEnd = i;
}
// We never saw a non whitespace character before the CR
if (valueStart == -1)
{
valueStart = valueEnd;
}
state = HeaderState.ExpectNewLine;
}
else
{
// If we find a non whitespace char that isn't CR then reset the end index
valueEnd = -1;
}
previouslyWhitespace = whitespace;
}
break;
case HeaderState.ExpectNewLine:
if (ch != ByteLF)
{
RejectRequest(RequestRejectionReason.HeaderValueMustNotContainCR);
}
state = HeaderState.Complete;
break;
default:
break;
}
}
}
if (state == HeaderState.Name)
{
RejectRequest(RequestRejectionReason.NoColonCharacterFoundInHeaderLine);
}
var nameBuffer = span.Slice(0, endNameIndex);
if (nameBuffer.IndexOf(ByteSpace) != -1 || nameBuffer.IndexOf(ByteTab) != -1)
{
RejectRequest(RequestRejectionReason.WhitespaceIsNotAllowedInHeaderName);
}
int endValueIndex = span.IndexOf(ByteCR);
if (endValueIndex == -1)
if (state == HeaderState.ExpectValue || state == HeaderState.Whitespace)
{
RejectRequest(RequestRejectionReason.MissingCRInHeaderLine);
}
var lineSuffix = span.Slice(endValueIndex);
if (lineSuffix.Length < 2)
if (state != HeaderState.Complete)
{
return false;
}
// This check and MissingCRInHeaderLine is a bit backwards, we should do it at once instead of having another seek
if (lineSuffix[1] != ByteLF)
{
RejectRequest(RequestRejectionReason.HeaderValueMustNotContainCR);
}
// Trim trailing whitespace from header value by repeatedly advancing to next
// whitespace or CR.
//
// - If CR is found, this is the end of the header value.
// - If whitespace is found, this is the _tentative_ end of the header value.
// If non-whitespace is found after it and it's not CR, seek again to the next
// whitespace or CR for a new (possibly tentative) end of value.
var valueBuffer = span.Slice(endNameIndex + 1, endValueIndex - (endNameIndex + 1));
// TODO: Trim else where
var value = valueBuffer.GetAsciiString()?.Trim() ?? string.Empty;
var headerLineLength = span.Length;
// -1 so that we can re-check the extra \r
// Skip the reader forward past the header line
reader.Skip(headerLineLength);
// Before accepting the header line, we need to see at least one character
// > so we can make sure there's no space or tab
var next = reader.Peek();
// We cant check for line continuations to reject everything we've done so far
// TODO: We don't need to reject the line here, we can use the state machine
// to store the fact that we're reading a header value
if (next == -1)
{
// If we can't see the next char then reject the entire line
return false;
}
@ -1489,10 +1591,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
RejectRequest(RequestRejectionReason.HeaderValueLineFoldingNotSupported);
}
var nameBuffer = span.Slice(nameStart, nameEnd - nameStart);
var valueBuffer = span.Slice(valueStart, valueEnd - valueStart);
var value = valueBuffer.GetAsciiString() ?? string.Empty;
// Update the frame state only after we know there's no header line continuation
_remainingRequestHeadersBytesAllowed -= headerLineLength;
bufferLength -= headerLineLength;
_requestHeadersParsed++;
requestHeaders.Append(nameBuffer, value);
@ -1638,5 +1743,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
RequestStarted,
ResponseStarted
}
// States of the single-pass request start line parser in TakeStartLine.
// Parsing begins in KnownMethod; the request line is rejected unless the
// parser has reached Complete once every byte of the line has been visited.
private enum StartLineState
{
// Initial state: try to match the method with GetKnownMethod; on success
// jump straight to Path, otherwise fall through to UnknownMethod.
KnownMethod,
// Scanning an unrecognized method byte by byte until a space is found;
// any byte that is not a valid token char rejects the request line.
UnknownMethod,
// Scanning the path: a space ends it with no query string, '?' ends it and
// starts the query string, '%' marks the path as needing decoding.
// An empty path rejects the request line.
Path,
// Scanning the query string until a space is found.
QueryString,
// Try to match the version with GetKnownVersion; on success jump straight
// to NewLine, otherwise fall through to UnknownVersion.
KnownVersion,
// Scanning an unrecognized version until CR; reaching CR always rejects the
// request (invalid request line if the version is empty, unrecognized
// HTTP version otherwise).
UnknownVersion,
// Expecting LF; any other byte rejects the request line.
NewLine,
// The whole start line was parsed successfully.
Complete
}
// States of the single-pass header line parser in TakeMessageHeaders.
// Parsing of each header line begins in Name and must end in Complete.
private enum HeaderState
{
// Scanning the header name until ':'; a space or tab seen before the colon
// rejects the request. Ending the line in this state means no colon was found.
Name,
// Skipping tab/space between the colon and the value; the first other byte
// marks the start of the value, and a CR is handed directly to ExpectValue.
Whitespace,
// Scanning the header value until CR, tracking the tentative end of the value
// so trailing whitespace is excluded. Ending the line in this state (or in
// Whitespace) means the CR was missing.
ExpectValue,
// A CR was seen; the next byte must be LF, otherwise the request is rejected
// because header values must not contain CR.
ExpectNewLine,
// The header line was parsed successfully.
Complete
}
}
}

View File

@ -572,6 +572,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
[InlineData("GET HTTP/1.1\r\n", "Invalid request line: GET HTTP/1.1<0x0D><0x0A>")]
[InlineData("GET / HTTP/1.1\n", "Invalid request line: GET / HTTP/1.1<0x0A>")]
[InlineData("GET / \r\n", "Invalid request line: GET / <0x0D><0x0A>")]
[InlineData("GET ? HTTP/1.1\r\n", "Invalid request line: GET ? HTTP/1.1<0x0D><0x0A>")]
[InlineData("GET / HTTP/1.1\ra\n", "Invalid request line: GET / HTTP/1.1<0x0D>a<0x0A>")]
public async Task TakeStartLineThrowsWhenInvalid(string requestLine, string expectedExceptionMessage)
{