From 3186e1bd72df577274f0e8bd63aa50c511bff4b4 Mon Sep 17 00:00:00 2001 From: Cesar Blum Silveira Date: Thu, 28 Apr 2016 17:39:02 -0700 Subject: [PATCH] Make TakeStartLine more robust (#683). --- .../Http/Frame.cs | 175 +++++++++++++----- .../Http/FrameOfT.cs | 113 ++++++----- .../Http/MessageBody.cs | 29 +-- .../Infrastructure/MemoryPoolIterator.cs | 14 +- .../MemoryPoolIteratorExtensions.cs | 35 ++-- .../BadHttpRequestTests.cs | 72 +++++++ .../ChunkedRequestTests.cs | 10 +- .../EngineTests.cs | 78 ++++++-- .../MemoryPoolIteratorTests.cs | 44 ++++- 9 files changed, 397 insertions(+), 173 deletions(-) create mode 100644 test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs index 4ed675767c..feeca29db9 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs @@ -35,6 +35,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private static readonly byte[] _bytesContentLengthZero = Encoding.ASCII.GetBytes("\r\nContent-Length: 0"); private static readonly byte[] _bytesSpace = Encoding.ASCII.GetBytes(" "); private static readonly byte[] _bytesEndHeaders = Encoding.ASCII.GetBytes("\r\n\r\n"); + private static readonly int _httpVersionLength = "HTTP/1.*".Length; private static Vector _vectorCRs = new Vector((byte)'\r'); private static Vector _vectorColons = new Vector((byte)':'); @@ -45,7 +46,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private readonly object _onStartingSync = new Object(); private readonly object _onCompletedSync = new Object(); - protected bool _corruptedRequest = false; + private bool _requestRejected; private Headers _frameHeaders; private Streams _frameStreams; @@ -60,7 +61,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http protected CancellationTokenSource _abortedCts; protected CancellationToken? _manuallySetRequestAbortToken; - protected bool _responseStarted; + protected RequestProcessingStatus _requestProcessingStatus; protected bool _keepAlive; private bool _autoChunk; protected Exception _applicationException; @@ -96,7 +97,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { return "HTTP/1.0"; } - return ""; + return string.Empty; } set { @@ -167,9 +168,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return cts; } } + public bool HasResponseStarted { - get { return _responseStarted; } + get { return _requestProcessingStatus == RequestProcessingStatus.ResponseStarted; } } protected FrameRequestHeaders FrameRequestHeaders => _frameHeaders.RequestHeaders; @@ -216,7 +218,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _onStarting = null; _onCompleted = null; - _responseStarted = false; + _requestProcessingStatus = RequestProcessingStatus.RequestPending; _keepAlive = false; _autoChunk = false; _applicationException = null; @@ -446,7 +448,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public Task WriteAsync(ArraySegment data, CancellationToken cancellationToken) { - if (!_responseStarted) + if (!HasResponseStarted) { return WriteAsyncAwaited(data, cancellationToken); } @@ -506,7 +508,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public void ProduceContinue() { - if (_responseStarted) return; + if (HasResponseStarted) + { + return; + } StringValues expect; if (_httpVersion == HttpVersionType.Http1_1 && @@ -519,7 +524,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public Task ProduceStartAndFireOnStarting() { - if (_responseStarted) return TaskUtilities.CompletedTask; + if (HasResponseStarted) + { + return TaskUtilities.CompletedTask; + } if (_onStarting != null) { @@ -554,30 +562,49 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private void ProduceStart(bool appCompleted) { - if (_responseStarted) return; - _responseStarted = true; + if (HasResponseStarted) + { + return; + } + + _requestProcessingStatus = RequestProcessingStatus.ResponseStarted; var statusBytes = ReasonPhrases.ToStatusBytes(StatusCode, ReasonPhrase); CreateResponseHeader(statusBytes, appCompleted); } + protected Task TryProduceInvalidRequestResponse() + { + if (_requestProcessingStatus == RequestProcessingStatus.RequestStarted && _requestRejected) + { + if (_frameHeaders == null) + { + InitializeHeaders(); + } + + return ProduceEnd(); + } + + return TaskUtilities.CompletedTask; + } + protected Task ProduceEnd() { - if (_corruptedRequest || _applicationException != null) + if (_requestRejected || _applicationException != null) { - if (_corruptedRequest) + if (_requestRejected) { // 400 Bad Request StatusCode = 400; - } + } else { // 500 Internal Server Error StatusCode = 500; } - if (_responseStarted) + if (HasResponseStarted) { // We can no longer respond with a 500, so we simply close the connection. _requestProcessingStopping = true; @@ -601,7 +628,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } } - if (!_responseStarted) + if (!HasResponseStarted) { return ProduceEndAwaited(); } @@ -709,31 +736,50 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http SocketOutput.ProducingComplete(end); } - protected bool TakeStartLine(SocketInput input) + protected RequestLineStatus TakeStartLine(SocketInput input) { var scan = input.ConsumingStart(); var consumed = scan; + try { + // We may hit this when the client has stopped sending data but + // the connection hasn't closed yet, and therefore Frame.Stop() + // hasn't been called yet. + if (scan.Peek() == -1) + { + return RequestLineStatus.Empty; + } + + _requestProcessingStatus = RequestProcessingStatus.RequestStarted; + string method; var begin = scan; - if (!begin.GetKnownMethod(ref scan, out method)) + if (!begin.GetKnownMethod(out method)) { if (scan.Seek(ref _vectorSpaces) == -1) { - return false; + return RequestLineStatus.MethodIncomplete; } + method = begin.GetAsciiString(scan); - scan.Take(); + if (method == null) + { + RejectRequest("Missing method."); + } + } + else + { + scan.Skip(method.Length); } + scan.Take(); begin = scan; - var needDecode = false; var chFound = scan.Seek(ref _vectorSpaces, ref _vectorQuestionMarks, ref _vectorPercentages); if (chFound == -1) { - return false; + return RequestLineStatus.TargetIncomplete; } else if (chFound == '%') { @@ -741,7 +787,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http chFound = scan.Seek(ref _vectorSpaces, ref _vectorQuestionMarks); if (chFound == -1) { - return false; + return RequestLineStatus.TargetIncomplete; } } @@ -752,35 +798,61 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http if (chFound == '?') { begin = scan; - if (scan.Seek(ref _vectorSpaces) != ' ') + if (scan.Seek(ref _vectorSpaces) == -1) { - return false; + return RequestLineStatus.TargetIncomplete; } queryString = begin.GetAsciiString(scan); } + if (pathBegin.Peek() == ' ') + { + RejectRequest("Missing request target."); + } + scan.Take(); begin = scan; + if (scan.Seek(ref _vectorCRs) == -1) + { + return RequestLineStatus.VersionIncomplete; + } string httpVersion; - if (!begin.GetKnownVersion(ref scan, out httpVersion)) + if (!begin.GetKnownVersion(out httpVersion)) { - scan = begin; - if (scan.Seek(ref _vectorCRs) == -1) - { - return false; - } + // A slower fallback is necessary since the iterator's PeekLong() method + // used in GetKnownVersion() only examines two memory blocks at most. + // Although unlikely, it is possible that the 8 bytes forming the version + // could be spread out on more than two blocks, if the connection + // happens to be unusually slow. httpVersion = begin.GetAsciiString(scan); - scan.Take(); - } - if (scan.Take() != '\n') - { - return false; + if (httpVersion == null) + { + RejectRequest("Missing HTTP version."); + } + else if (httpVersion != "HTTP/1.0" && httpVersion != "HTTP/1.1") + { + RejectRequest("Malformed request."); + } } - // URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11 - // Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8; + // HttpVersion must be set here to send correct response when request is rejected + HttpVersion = httpVersion; + + scan.Take(); + var next = scan.Take(); + if (next == -1) + { + return RequestLineStatus.Incomplete; + } + else if (next != '\n') + { + RejectRequest("Missing LF in request line."); + } + + // URIs are always encoded/escaped to ASCII https://tools.ietf.org/html/rfc3986#page-11 + // Multibyte Internationalized Resource Identifiers (IRIs) are first converted to utf8; // then encoded/escaped to ASCII https://www.ietf.org/rfc/rfc3987.txt "Mapping of IRIs to URIs" string requestUrlPath; if (needDecode) @@ -802,7 +874,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http Method = method; RequestUri = requestUrlPath; QueryString = queryString; - HttpVersion = httpVersion; bool caseMatches; @@ -818,7 +889,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http Path = requestUrlPath; } - return true; + return RequestLineStatus.Done; } finally { @@ -881,7 +952,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return true; } - ReportCorruptedHttpRequest(new BadHttpRequestException("Headers corrupted, invalid header sequence.")); + RejectRequest("Headers corrupted, invalid header sequence."); // Headers corrupted, parsing headers is complete return true; } @@ -994,10 +1065,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http statusCode != 304; } - public void ReportCorruptedHttpRequest(BadHttpRequestException ex) + public void RejectRequest(string message) { - _corruptedRequest = true; + _requestProcessingStopping = true; + _requestRejected = true; + var ex = new BadHttpRequestException(message); Log.ConnectionBadRequest(ConnectionId, ex); + throw ex; } protected void ReportApplicationError(Exception ex) @@ -1024,5 +1098,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http Http1_0 = 0, Http1_1 = 1 } + + protected enum RequestLineStatus + { + Empty, + MethodIncomplete, + TargetIncomplete, + VersionIncomplete, + Incomplete, + Done + } + + protected enum RequestProcessingStatus + { + RequestPending, + RequestStarted, + ResponseStarted + } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs index 21eeaa1c6a..6c03a812d2 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs @@ -5,7 +5,6 @@ using System; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting.Server; -using Microsoft.AspNetCore.Server.Kestrel.Exceptions; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Server.Kestrel.Http @@ -33,7 +32,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { while (!_requestProcessingStopping) { - while (!_requestProcessingStopping && !TakeStartLine(SocketInput)) + while (!_requestProcessingStopping && TakeStartLine(SocketInput) != RequestLineStatus.Done) { if (SocketInput.RemoteIntakeFin) { @@ -41,12 +40,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http // SocketInput.RemoteIntakeFin is set to true to ensure we don't close a // connection without giving the application a chance to respond to a request // sent immediately before the a FIN from the client. - if (TakeStartLine(SocketInput)) + var requestLineStatus = TakeStartLine(SocketInput); + + if (requestLineStatus == RequestLineStatus.Empty) { - break; + return; } - return; + if (requestLineStatus != RequestLineStatus.Done) + { + RejectRequest($"Malformed request: {requestLineStatus}"); + } + + break; } await SocketInput; @@ -62,12 +68,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http // SocketInput.RemoteIntakeFin is set to true to ensure we don't close a // connection without giving the application a chance to respond to a request // sent immediately before the a FIN from the client. - if (TakeMessageHeaders(SocketInput, FrameRequestHeaders)) + if (!TakeMessageHeaders(SocketInput, FrameRequestHeaders)) { - break; + RejectRequest($"Malformed request: invalid headers."); } - return; + break; } await SocketInput; @@ -83,66 +89,55 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _abortedCts = null; _manuallySetRequestAbortToken = null; - if (!_corruptedRequest) + var context = _application.CreateContext(this); + try { - var context = _application.CreateContext(this); - try + await _application.ProcessRequestAsync(context).ConfigureAwait(false); + } + catch (Exception ex) + { + ReportApplicationError(ex); + } + finally + { + // Trigger OnStarting if it hasn't been called yet and the app hasn't + // already failed. If an OnStarting callback throws we can go through + // our normal error handling in ProduceEnd. + // https://github.com/aspnet/KestrelHttpServer/issues/43 + if (!HasResponseStarted && _applicationException == null && _onStarting != null) { - await _application.ProcessRequestAsync(context).ConfigureAwait(false); - } - catch (Exception ex) - { - ReportApplicationError(ex); - } - finally - { - // Trigger OnStarting if it hasn't been called yet and the app hasn't - // already failed. If an OnStarting callback throws we can go through - // our normal error handling in ProduceEnd. - // https://github.com/aspnet/KestrelHttpServer/issues/43 - if (!_responseStarted && _applicationException == null && _onStarting != null) - { - await FireOnStarting(); - } - - PauseStreams(); - - if (_onCompleted != null) - { - await FireOnCompleted(); - } - - _application.DisposeContext(context, _applicationException); + await FireOnStarting(); } - // If _requestAbort is set, the connection has already been closed. - if (Volatile.Read(ref _requestAborted) == 0) + PauseStreams(); + + if (_onCompleted != null) { - ResumeStreams(); - - if (_keepAlive && !_corruptedRequest) - { - try - { - // Finish reading the request body in case the app did not. - await messageBody.Consume(); - } - catch (BadHttpRequestException ex) - { - ReportCorruptedHttpRequest(ex); - } - } - - await ProduceEnd(); + await FireOnCompleted(); } - StopStreams(); + _application.DisposeContext(context, _applicationException); } - if (!_keepAlive || _corruptedRequest) + // If _requestAbort is set, the connection has already been closed. + if (Volatile.Read(ref _requestAborted) == 0) { - // End the connection for non keep alive and Bad Requests - // as data incoming may have been thrown off + ResumeStreams(); + + if (_keepAlive) + { + // Finish reading the request body in case the app did not. + await messageBody.Consume(); + } + + await ProduceEnd(); + } + + StopStreams(); + + if (!_keepAlive) + { + // End the connection for non keep alive as data incoming may have been thrown off return; } } @@ -158,6 +153,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { try { + await TryProduceInvalidRequestResponse(); + ResetComponents(); _abortedCts = null; diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs index 3a151fc560..7bf6c1d7e6 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs @@ -59,7 +59,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return ConsumeAwaited(result.AsTask(), cancellationToken); } // ValueTask uses .GetAwaiter().GetResult() if necessary - else if (result.Result == 0) + else if (result.Result == 0) { // Completed Task, end of stream return TaskUtilities.CompletedTask; @@ -125,8 +125,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http long contentLength; if (!long.TryParse(unparsedContentLength, out contentLength) || contentLength < 0) { - context.ReportCorruptedHttpRequest(new BadHttpRequestException("Invalid content length.")); - return new ForContentLength(keepAlive, 0, context); + context.RejectRequest($"Invalid content length: {unparsedContentLength}"); } else { @@ -142,15 +141,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return new ForRemainingData(context); } - private int ThrowBadRequestException(string message) - { - // returns int so can be used as item non-void function - var ex = new BadHttpRequestException(message); - _context.ReportCorruptedHttpRequest(ex); - - throw ex; - } - private class ForRemainingData : MessageBody { public ForRemainingData(Frame context) @@ -197,7 +187,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _inputLength -= actual; if (actual == 0) { - ThrowBadRequestException("Unexpected end of request content"); + _context.RejectRequest("Unexpected end of request content"); } return actual; } @@ -213,7 +203,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _inputLength -= actual; if (actual == 0) { - ThrowBadRequestException("Unexpected end of request content"); + _context.RejectRequest("Unexpected end of request content"); } return actual; @@ -514,7 +504,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } else { - ThrowBadRequestException("Bad chunk suffix"); + _context.RejectRequest("Bad chunk suffix"); } } finally @@ -568,16 +558,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { return currentParsedSize * 0x10 + (extraHexDigit - ('a' - 10)); } - else - { - return ThrowBadRequestException("Bad chunk size data"); - } } + + _context.RejectRequest("Bad chunk size data"); + return -1; // can't happen, but compiler complains } private void ThrowChunkedRequestIncomplete() { - ThrowBadRequestException("Chunked request incomplete"); + _context.RejectRequest("Chunked request incomplete"); } private enum Mode diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator.cs index 3d42fba3b9..3e17b6c4b8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIterator.cs @@ -122,7 +122,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure { if (wasLastBlock) { - return; + throw new InvalidOperationException("Attempted to skip more bytes than available."); } else { @@ -248,7 +248,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure while (following > 0) { // Need unit tests to test Vector path -#if !DEBUG +#if !DEBUG // Check will be Jitted away https://github.com/dotnet/coreclr/issues/1079 if (Vector.IsHardwareAccelerated) { @@ -269,7 +269,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure return byte0; } // Need unit tests to test Vector path -#if !DEBUG +#if !DEBUG } #endif @@ -330,7 +330,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure { // Need unit tests to test Vector path -#if !DEBUG +#if !DEBUG // Check will be Jitted away https://github.com/dotnet/coreclr/issues/1079 if (Vector.IsHardwareAccelerated) { @@ -369,7 +369,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure return byte1; } // Need unit tests to test Vector path -#if !DEBUG +#if !DEBUG } #endif var pCurrent = (block.DataFixedPtr + index); @@ -436,7 +436,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure while (following > 0) { // Need unit tests to test Vector path -#if !DEBUG +#if !DEBUG // Check will be Jitted away https://github.com/dotnet/coreclr/issues/1079 if (Vector.IsHardwareAccelerated) { @@ -502,7 +502,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure return toReturn; } // Need unit tests to test Vector path -#if !DEBUG +#if !DEBUG } #endif var pCurrent = (block.DataFixedPtr + index); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIteratorExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIteratorExtensions.cs index f17dab1394..47410277fd 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIteratorExtensions.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/MemoryPoolIteratorExtensions.cs @@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure private readonly static long _http10VersionLong = GetAsciiStringAsLong("HTTP/1.0"); private readonly static long _http11VersionLong = GetAsciiStringAsLong("HTTP/1.1"); - + private readonly static long _mask8Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff }); private readonly static long _mask7Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00 }); private readonly static long _mask6Chars = GetMaskAsLong(new byte[] { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x00, 0x00 }); @@ -93,7 +93,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure return null; } - // Bytes out of the range of ascii are treated as "opaque data" + // Bytes out of the range of ascii are treated as "opaque data" // and kept in string as a char value that casts to same input byte value // https://tools.ietf.org/html/rfc7230#section-3.2.4 @@ -283,16 +283,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure /// /// A "known HTTP method" can be an HTTP method name defined in the HTTP/1.1 RFC. /// Since all of those fit in at most 8 bytes, they can be optimally looked up by reading those bytes as a long. Once - /// in that format, it can be checked against the known method. - /// The Known Methods (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE) are all less than 8 bytes + /// in that format, it can be checked against the known method. + /// The Known Methods (CONNECT, DELETE, GET, HEAD, PATCH, POST, PUT, OPTIONS, TRACE) are all less than 8 bytes /// and will be compared with the required space. A mask is used if the Known method is less than 8 bytes. /// To optimize performance the GET method will be checked first. /// /// The iterator from which to start the known string lookup. - /// If we found a valid method, then scan will be updated to new position /// A reference to a pre-allocated known string, if the input matches any. /// true if the input matches a known string, false otherwise. - public static bool GetKnownMethod(this MemoryPoolIterator begin, ref MemoryPoolIterator scan, out string knownMethod) + public static bool GetKnownMethod(this MemoryPoolIterator begin, out string knownMethod) { knownMethod = null; var value = begin.PeekLong(); @@ -300,7 +299,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure if ((value & _mask4Chars) == _httpGetMethodLong) { knownMethod = HttpGetMethod; - scan.Skip(4); return true; } foreach (var x in _knownMethods) @@ -308,7 +306,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure if ((value & x.Item1) == x.Item2) { knownMethod = x.Item3; - scan.Skip(knownMethod.Length + 1); return true; } } @@ -327,10 +324,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure /// To optimize performance the HTTP/1.1 will be checked first. /// /// The iterator from which to start the known string lookup. - /// If we found a valid method, then scan will be updated to new position /// A reference to a pre-allocated known string, if the input matches any. /// true if the input matches a known string, false otherwise. - public static bool GetKnownVersion(this MemoryPoolIterator begin, ref MemoryPoolIterator scan, out string knownVersion) + public static bool GetKnownVersion(this MemoryPoolIterator begin, out string knownVersion) { knownVersion = null; var value = begin.PeekLong(); @@ -338,24 +334,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure if (value == _http11VersionLong) { knownVersion = Http11Version; - scan.Skip(8); - if (scan.Take() == '\r') - { - return true; - } } else if (value == _http10VersionLong) { knownVersion = Http10Version; - scan.Skip(8); - if (scan.Take() == '\r') + } + + if (knownVersion != null) + { + begin.Skip(knownVersion.Length); + + if (begin.Peek() != '\r') { - return true; + knownVersion = null; } } - knownVersion = null; - return false; + return knownVersion != null; } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs new file mode 100644 index 0000000000..8ff6648b0a --- /dev/null +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/BadHttpRequestTests.cs @@ -0,0 +1,72 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.AspNetCore.Server.KestrelTests +{ + public class BadHttpRequestTests + { + [Theory] + [InlineData("/ HTTP/1.1\r\n\r\n")] + [InlineData(" / HTTP/1.1\r\n\r\n")] + [InlineData(" / HTTP/1.1\r\n\r\n")] + [InlineData("GET / HTTP/1.1\r\n\r\n")] + [InlineData("GET / HTTP/1.1\r\n\r\n")] + [InlineData("GET HTTP/1.1\r\n\r\n")] + [InlineData("GET /")] + [InlineData("GET / ")] + [InlineData("GET / H")] + [InlineData("GET / HTTP/1.")] + [InlineData("GET /\r\n")] + [InlineData("GET / \r\n")] + [InlineData("GET / \n")] + [InlineData("GET / http/1.0\r\n\r\n")] + [InlineData("GET / http/1.1\r\n\r\n")] + [InlineData("GET / HTTP/1.1 \r\n\r\n")] + [InlineData("GET / HTTP/1.1a\r\n\r\n")] + [InlineData("GET / HTTP/1.0\n\r\n")] + [InlineData("GET / HTTP/3.0\r\n\r\n")] + [InlineData("GET / H\r\n\r\n")] + [InlineData("GET / HTTP/1.\r\n\r\n")] + [InlineData("GET / hello\r\n\r\n")] + [InlineData("GET / 8charact\r\n\r\n")] + public async Task TestBadRequests(string request) + { + using (var server = new TestServer(context => { return Task.FromResult(0); })) + { + using (var connection = new TestConnection(server.Port)) + { + var receiveTask = Task.Run(async () => + { + await connection.Receive( + "HTTP/1.0 400 Bad Request", + ""); + await connection.ReceiveStartsWith("Date: "); + await connection.ReceiveForcedEnd( + "Content-Length: 0", + "Server: Kestrel", + "", + ""); + }); + + try + { + await connection.SendEnd(request).ConfigureAwait(false); + } + catch (Exception) + { + // TestConnection.SendEnd will start throwing while sending characters + // in cases where the server rejects the request as soon as it + // determines the request line is malformed, even though there + // are more characters following. + } + + await receiveTask; + } + } + } + } +} \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs index 8b34273913..6cbe46b947 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs @@ -217,7 +217,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "POST / HTTP/1.1", "Transfer-Encoding: chunked", "", - "C", + "C", "HelloChunked", "0", ""}; @@ -364,10 +364,10 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var connection = new TestConnection(server.Port)) { await connection.Send( - "POST / HTTP/1.1", - "Transfer-Encoding: chunked", - "", - "Cii"); + "POST / HTTP/1.1", + "Transfer-Encoding: chunked", + "", + "Cii"); await connection.Receive( "HTTP/1.1 400 Bad Request", diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs index b8bdffeebf..a4508b8893 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/EngineTests.cs @@ -728,18 +728,13 @@ namespace Microsoft.AspNetCore.Server.KestrelTests { await connection.SendEnd( "GET /"); - await connection.ReceiveEnd(); - } - - using (var connection = new TestConnection(server.Port)) - { - await connection.SendEnd( - "GET / HTTP/1.1", - "", - "Post / HTTP/1.1"); - await connection.ReceiveEnd( - "HTTP/1.1 200 OK", + await connection.Receive( + "HTTP/1.0 400 Bad Request", + ""); + await connection.ReceiveStartsWith("Date:"); + await connection.ReceiveForcedEnd( "Content-Length: 0", + "Server: Kestrel", "", ""); } @@ -749,12 +744,40 @@ namespace Microsoft.AspNetCore.Server.KestrelTests await connection.SendEnd( "GET / HTTP/1.1", "", - "Post / HTTP/1.1", - "Content-Length: 7"); - await connection.ReceiveEnd( + "POST / HTTP/1.1"); + await connection.Receive( "HTTP/1.1 200 OK", "Content-Length: 0", "", + "HTTP/1.0 400 Bad Request", + ""); + await connection.ReceiveStartsWith("Date:"); + await connection.ReceiveForcedEnd( + "Content-Length: 0", + "Server: Kestrel", + "", + ""); + } + + using (var connection = new TestConnection(server.Port)) + { + await connection.SendEnd( + "GET / HTTP/1.1", + "", + "POST / HTTP/1.1", + "Content-Length: 7"); + await connection.Receive( + "HTTP/1.1 200 OK", + "Content-Length: 0", + "", + "HTTP/1.1 400 Bad Request", + "Connection: close", + ""); + await connection.ReceiveStartsWith("Date:"); + await connection.ReceiveForcedEnd( + "Content-Length: 0", + "Server: Kestrel", + "", ""); } } @@ -1017,7 +1040,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.True(registrationWh.Wait(1000)); } } - + [Theory] [MemberData(nameof(ConnectionFilterData))] public async Task NoErrorsLoggedWhenServerEndsConnectionBeforeClient(ServiceContext testContext) @@ -1049,5 +1072,30 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal(0, testLogger.TotalErrorsLogged); } + + [Theory] + [MemberData(nameof(ConnectionFilterData))] + public async Task NoResponseSentWhenConnectionIsClosedByServerBeforeClientFinishesSendingRequest(ServiceContext testContext) + { + var testLogger = new TestApplicationErrorLogger(); + testContext.Log = new KestrelTrace(testLogger); + + using (var server = new TestServer(httpContext => + { + httpContext.Abort(); + return Task.FromResult(0); + }, testContext)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.Send( + "POST / HTTP/1.0", + "Content-Length: 1", + "", + ""); + await connection.ReceiveEnd(); + } + } + } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs index 834c99baad..4e8a02c365 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs @@ -285,6 +285,40 @@ namespace Microsoft.AspNetCore.Server.KestrelTests _pool.Return(nextBlock); } + [Fact] + public void SkipThrowsWhenSkippingMoreBytesThanAvailableInSingleBlock() + { + // Arrange + var block = _pool.Lease(); + block.End += 5; + + var scan = block.GetIterator(); + + // Act/Assert + Assert.ThrowsAny(() => scan.Skip(8)); + + _pool.Return(block); + } + + [Fact] + public void SkipThrowsWhenSkippingMoreBytesThanAvailableInMultipleBlocks() + { + // Arrange + var block = _pool.Lease(); + block.End += 3; + + var nextBlock = _pool.Lease(); + nextBlock.End += 2; + block.Next = nextBlock; + + var scan = block.GetIterator(); + + // Act/Assert + Assert.ThrowsAny(() => scan.Skip(8)); + + _pool.Return(block); + } + [Theory] [InlineData("CONNECT / HTTP/1.1", ' ', true, MemoryPoolIteratorExtensions.HttpConnectMethod)] [InlineData("DELETE / HTTP/1.1", ' ', true, MemoryPoolIteratorExtensions.HttpDeleteMethod)] @@ -309,12 +343,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); block.End += chars.Length; - var scan = block.GetIterator(); - var begin = scan; + var begin = block.GetIterator(); string knownString; // Act - var result = begin.GetKnownMethod(ref scan, out knownString); + var result = begin.GetKnownMethod(out knownString); // Assert Assert.Equal(expectedResult, result); @@ -337,12 +370,11 @@ namespace Microsoft.AspNetCore.Server.KestrelTests var chars = input.ToCharArray().Select(c => (byte)c).ToArray(); Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); block.End += chars.Length; - var scan = block.GetIterator(); - var begin = scan; + var begin = block.GetIterator(); string knownString; // Act - var result = begin.GetKnownVersion(ref scan, out knownString); + var result = begin.GetKnownVersion(out knownString); // Assert Assert.Equal(expectedResult, result); Assert.Equal(expectedKnownString, knownString);