diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Exceptions/BadHttpRequestException.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Exceptions/BadHttpRequestException.cs new file mode 100644 index 0000000000..19c004857e --- /dev/null +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Exceptions/BadHttpRequestException.cs @@ -0,0 +1,16 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.IO; + +namespace Microsoft.AspNetCore.Server.Kestrel.Exceptions +{ + public sealed class BadHttpRequestException : IOException + { + internal BadHttpRequestException(string message) + : base(message) + { + + } + } +} diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs index 2c83fc7c84..9d14928e16 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/Frame.cs @@ -5,7 +5,6 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; -using System.Net; using System.Numerics; using System.Text; using System.Threading; @@ -13,6 +12,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Exceptions; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; @@ -45,7 +45,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http private readonly object _onStartingSync = new Object(); private readonly object _onCompletedSync = new Object(); - protected bool _poolingPermitted = true; + protected bool _corruptedRequest = false; private Headers _frameHeaders; private Streams _frameStreams; @@ -211,7 +211,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http public void Reset() { - ResetComponents(poolingPermitted: true); + ResetComponents(); _onStarting = null; _onCompleted = null; @@ -248,27 +248,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _abortedCts = null; } - protected void ResetComponents(bool poolingPermitted) + protected void ResetComponents() { - if (_frameHeaders != null) + var frameHeaders = Interlocked.Exchange(ref _frameHeaders, null); + if (frameHeaders != null) { - var frameHeaders = _frameHeaders; - _frameHeaders = null; - RequestHeaders = null; ResponseHeaders = null; - HttpComponentFactory.DisposeHeaders(frameHeaders, poolingPermitted); + HttpComponentFactory.DisposeHeaders(frameHeaders); } - if (_frameStreams != null) + var frameStreams = Interlocked.Exchange(ref _frameStreams, null); + if (frameStreams != null) { - var frameStreams = _frameStreams; - _frameStreams = null; - RequestBody = null; ResponseBody = null; DuplexStream = null; - HttpComponentFactory.DisposeStreams(frameStreams, poolingPermitted); + HttpComponentFactory.DisposeStreams(frameStreams); } } @@ -568,8 +564,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http protected Task ProduceEnd() { - if (_applicationException != null) + if (_corruptedRequest || _applicationException != null) { + if (_corruptedRequest) + { + // 400 Bad Request + StatusCode = 400; + } + else + { + // 500 Internal Server Error + StatusCode = 500; + } + if (_responseStarted) { // We can no longer respond with a 500, so we simply close the connection. @@ -578,7 +585,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } else { - StatusCode = 500; ReasonPhrase = null; var responseHeaders = _frameHeaders.ResponseHeaders; @@ -711,7 +717,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http { string method; var begin = scan; - if (!begin.GetKnownMethod(ref scan,out method)) + if (!begin.GetKnownMethod(ref scan, out method)) { if (scan.Seek(ref _vectorSpaces) == -1) { @@ -834,7 +840,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return true; } - public static bool TakeMessageHeaders(SocketInput input, FrameRequestHeaders requestHeaders) + public bool TakeMessageHeaders(SocketInput input, FrameRequestHeaders requestHeaders) { var scan = input.ConsumingStart(); var consumed = scan; @@ -863,7 +869,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http consumed = scan; return true; } - throw new InvalidDataException("Malformed request"); + + ReportCorruptedHttpRequest(new BadHttpRequestException("Headers corrupted, invalid header sequence.")); + // Headers corrupted, parsing headers is complete + return true; } while ( @@ -953,16 +962,27 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http statusCode != 304; } + public void ReportCorruptedHttpRequest(BadHttpRequestException ex) + { + _corruptedRequest = true; + Log.ConnectionBadRequest(ConnectionId, ex); + } + protected void ReportApplicationError(Exception ex) { if (_applicationException == null) { _applicationException = ex; } + else if (_applicationException is AggregateException) + { + _applicationException = new AggregateException(_applicationException, ex).Flatten(); + } else { _applicationException = new AggregateException(_applicationException, ex); } + Log.ApplicationError(ex); } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs index 77e7c4047b..053688b587 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/FrameOfT.cs @@ -2,11 +2,10 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Net; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Hosting.Server; -using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Server.Kestrel.Exceptions; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Server.Kestrel.Http @@ -64,55 +63,66 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _abortedCts = null; _manuallySetRequestAbortToken = null; - var context = _application.CreateContext(this); - try + if (!_corruptedRequest) { - await _application.ProcessRequestAsync(context).ConfigureAwait(false); - } - catch (Exception ex) - { - ReportApplicationError(ex); - } - finally - { - // Trigger OnStarting if it hasn't been called yet and the app hasn't - // already failed. If an OnStarting callback throws we can go through - // our normal error handling in ProduceEnd. - // https://github.com/aspnet/KestrelHttpServer/issues/43 - if (!_responseStarted && _applicationException == null && _onStarting != null) + var context = _application.CreateContext(this); + try { - await FireOnStarting(); + await _application.ProcessRequestAsync(context).ConfigureAwait(false); } - - PauseStreams(); - - if (_onCompleted != null) + catch (Exception ex) { - await FireOnCompleted(); + ReportApplicationError(ex); } + finally + { + // Trigger OnStarting if it hasn't been called yet and the app hasn't + // already failed. If an OnStarting callback throws we can go through + // our normal error handling in ProduceEnd. + // https://github.com/aspnet/KestrelHttpServer/issues/43 + if (!_responseStarted && _applicationException == null && _onStarting != null) + { + await FireOnStarting(); + } - _application.DisposeContext(context, _applicationException); + PauseStreams(); + + if (_onCompleted != null) + { + await FireOnCompleted(); + } + + _application.DisposeContext(context, _applicationException); + } // If _requestAbort is set, the connection has already been closed. if (Volatile.Read(ref _requestAborted) == 0) { ResumeStreams(); - await ProduceEnd(); - - if (_keepAlive) + if (_keepAlive && !_corruptedRequest) { - // Finish reading the request body in case the app did not. - await messageBody.Consume(); + try + { + // Finish reading the request body in case the app did not. + await messageBody.Consume(); + } + catch (BadHttpRequestException ex) + { + ReportCorruptedHttpRequest(ex); + } } + + await ProduceEnd(); } StopStreams(); } - if (!_keepAlive) + if (!_keepAlive || _corruptedRequest) { - ResetComponents(poolingPermitted: true); + // End the connection for non keep alive and Bad Requests + // as data incoming may have been thrown off return; } } @@ -122,15 +132,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } catch (Exception ex) { - // Error occurred, do not return components to pool - _poolingPermitted = false; Log.LogWarning(0, ex, "Connection processing ended abnormally"); } finally { try { - ResetComponents(poolingPermitted: _poolingPermitted); + ResetComponents(); _abortedCts = null; // If _requestAborted is set, the connection has already been closed. diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs index f17013c513..f3f604f267 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Http/MessageBody.cs @@ -7,6 +7,7 @@ using System.Numerics; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Exceptions; namespace Microsoft.AspNetCore.Server.Kestrel.Http { @@ -118,10 +119,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return new ForChunkedEncoding(keepAlive, headers, context); } - var contentLength = headers.HeaderContentLength.ToString(); - if (contentLength.Length > 0) + var unparsedContentLength = headers.HeaderContentLength.ToString(); + if (unparsedContentLength.Length > 0) { - return new ForContentLength(keepAlive, int.Parse(contentLength), context); + int contentLength; + if (!int.TryParse(unparsedContentLength, out contentLength) || contentLength < 0) + { + context.ReportCorruptedHttpRequest(new BadHttpRequestException("Invalid content length.")); + return new ForContentLength(keepAlive, 0, context); + } + else + { + return new ForContentLength(keepAlive, contentLength, context); + } } if (keepAlive) @@ -132,6 +142,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http return new ForRemainingData(context); } + private int ThrowBadRequestException(string message) + { + // returns int so can be used as item non-void function + var ex = new BadHttpRequestException(message); + _context.ReportCorruptedHttpRequest(ex); + + throw ex; + } + private class ForRemainingData : MessageBody { public ForRemainingData(Frame context) @@ -177,7 +196,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _inputLength -= actual; if (actual == 0) { - throw new InvalidDataException("Unexpected end of request content"); + ThrowBadRequestException("Unexpected end of request content"); } return actual; } @@ -193,7 +212,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http _inputLength -= actual; if (actual == 0) { - throw new InvalidDataException("Unexpected end of request content"); + ThrowBadRequestException("Unexpected end of request content"); } return actual; @@ -290,7 +309,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http if (_mode == Mode.TrailerHeaders) { - while (!Frame.TakeMessageHeaders(input, _requestHeaders)) + while (!_context.TakeMessageHeaders(input, _requestHeaders)) { await GetDataAsync(input); } @@ -451,7 +470,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } else { - ThrowInvalidFormat(); + ThrowBadRequestException("Bad chunk suffix"); } } finally @@ -489,7 +508,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } } - private static int CalculateChunkSize(int extraHexDigit, int currentParsedSize) + private int CalculateChunkSize(int extraHexDigit, int currentParsedSize) { checked { @@ -507,37 +526,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http } else { - return ThrowInvalidFormat(); + return ThrowBadRequestException("Bad chunk size data"); } } } - private static SocketInput GetDataAsync(SocketInput input) + private SocketInput GetDataAsync(SocketInput input) { ThrowIfRequestIncomplete(input); return input; } - private static void ThrowIfRequestIncomplete(SocketInput input) + private void ThrowIfRequestIncomplete(SocketInput input) { if (input.RemoteIntakeFin) { - ThrowRequestIncomplete(); + ThrowBadRequestException("Chunked request incomplete"); } } - private static int ThrowInvalidFormat() - { - // returns int so can be used as item non-void function - throw new InvalidOperationException("Bad request"); - } - - private static void ThrowRequestIncomplete() - { - throw new InvalidOperationException("Chunked request incomplete"); - } - private enum Mode { Prefix, diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/HttpComponentFactory.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/HttpComponentFactory.cs index f58102c1bb..5b346d73fd 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/HttpComponentFactory.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/HttpComponentFactory.cs @@ -34,9 +34,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure return streams; } - public void DisposeStreams(Streams streams, bool poolingPermitted) + public void DisposeStreams(Streams streams) { - if (poolingPermitted && _streamPool.Count < ServerInformation.PoolingParameters.MaxPooledStreams) + if (_streamPool.Count < ServerInformation.PoolingParameters.MaxPooledStreams) { streams.Uninitialize(); @@ -58,9 +58,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure return headers; } - public void DisposeHeaders(Headers headers, bool poolingPermitted) + public void DisposeHeaders(Headers headers) { - if (poolingPermitted && _headerPool.Count < ServerInformation.PoolingParameters.MaxPooledHeaders) + if (_headerPool.Count < ServerInformation.PoolingParameters.MaxPooledHeaders) { headers.Uninitialize(); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IHttpComponentFactory.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IHttpComponentFactory.cs index 620c663fbb..c64f039680 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IHttpComponentFactory.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IHttpComponentFactory.cs @@ -11,10 +11,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure Streams CreateStreams(FrameContext owner); - void DisposeStreams(Streams streams, bool poolingPermitted); + void DisposeStreams(Streams streams); Headers CreateHeaders(DateHeaderValueManager dateValueManager); - void DisposeHeaders(Headers headers, bool poolingPermitted); + void DisposeHeaders(Headers headers); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs index b004cb9eee..b5460ecb4c 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/IKestrelTrace.cs @@ -1,5 +1,6 @@ using System; using Microsoft.Extensions.Logging; +using Microsoft.AspNetCore.Server.Kestrel.Exceptions; namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure { @@ -33,6 +34,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Infrastructure void ConnectionDisconnectedWrite(string connectionId, int count, Exception ex); + void ConnectionBadRequest(string connectionId, BadHttpRequestException ex); + void NotAllConnectionsClosedGracefully(); void ApplicationError(Exception ex); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs index 3c92c771ef..f3f12cbd1f 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Infrastructure/KestrelTrace.cs @@ -3,6 +3,7 @@ using System; using Microsoft.AspNetCore.Server.Kestrel.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Exceptions; using Microsoft.Extensions.Logging; namespace Microsoft.AspNetCore.Server.Kestrel @@ -24,6 +25,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel private static readonly Action _connectionError; private static readonly Action _connectionDisconnectedWrite; private static readonly Action _notAllConnectionsClosedGracefully; + private static readonly Action _connectionBadRequest; protected readonly ILogger _logger; @@ -45,6 +47,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel _connectionError = LoggerMessage.Define(LogLevel.Information, 14, @"Connection id ""{ConnectionId}"" communication error"); _connectionDisconnectedWrite = LoggerMessage.Define(LogLevel.Debug, 15, @"Connection id ""{ConnectionId}"" write of ""{count}"" bytes to disconnected client."); _notAllConnectionsClosedGracefully = LoggerMessage.Define(LogLevel.Debug, 16, "Some connections failed to close gracefully during server shutdown."); + _connectionBadRequest = LoggerMessage.Define(LogLevel.Information, 17, @"Connection id ""{ConnectionId}"" bad request data: ""{message}"""); } public KestrelTrace(ILogger logger) @@ -135,6 +138,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel _notAllConnectionsClosedGracefully(_logger, null); } + public void ConnectionBadRequest(string connectionId, BadHttpRequestException ex) + { + _connectionBadRequest(_logger, connectionId, ex.Message, ex); + } + public virtual void Log(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func formatter) { _logger.Log(logLevel, eventId, state, exception, formatter); diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs index 9408c44b14..f741539480 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/ChunkedRequestTests.cs @@ -185,7 +185,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) { - // read to end + ;// read to end } if (requestsReceived < requestCount) @@ -271,7 +271,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) { - // read to end + ;// read to end } if (requestsReceived < requestCount) @@ -341,7 +341,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [Theory] [MemberData(nameof(ConnectionFilterData))] - public async Task InvalidLengthResultsIn500(ServiceContext testContext) + public async Task InvalidLengthResultsIn400(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -367,14 +367,10 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "POST / HTTP/1.1", "Transfer-Encoding: chunked", "", - "Cio", - "HelloChunked", - "0", - ""); + "Cii"); - // Should really be a 40x as is bad request await connection.Receive( - "HTTP/1.1 500 Internal Server Error", + "HTTP/1.1 400 Bad Request", ""); await connection.ReceiveStartsWith("Date:"); await connection.ReceiveEnd( @@ -388,7 +384,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests [Theory] [MemberData(nameof(ConnectionFilterData))] - public async Task InvalidSizedDataResultsIn500(ServiceContext testContext) + public async Task InvalidSizedDataResultsIn400(ServiceContext testContext) { using (var server = new TestServer(async httpContext => { @@ -415,13 +411,10 @@ namespace Microsoft.AspNetCore.Server.KestrelTests "Transfer-Encoding: chunked", "", "C", - "HelloChunkedInvalid", - "0", - ""); + "HelloChunkedIn"); - // Should really be a 40x as is bad request await connection.Receive( - "HTTP/1.1 500 Internal Server Error", + "HTTP/1.1 400 Bad Request", ""); await connection.ReceiveStartsWith("Date:"); await connection.ReceiveEnd( diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs index 70c4454b53..78d46f5d6c 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs @@ -29,12 +29,18 @@ namespace Microsoft.AspNetCore.Server.KestrelTests using (var pool = new MemoryPool2()) using (var socketInput = new SocketInput(pool, ltp)) { + var connectionContext = new ConnectionContext() + { + DateHeaderValueManager = new DateHeaderValueManager(), + ServerAddress = ServerAddress.FromUrl("http://localhost:5000") + }; + var frame = new Frame(application: null, context: connectionContext); var headerCollection = new FrameRequestHeaders(); var headerArray = Encoding.ASCII.GetBytes(rawHeaders); socketInput.IncomingData(headerArray, 0, headerArray.Length); - var success = Frame.TakeMessageHeaders(socketInput, headerCollection); + var success = frame.TakeMessageHeaders(socketInput, headerCollection); Assert.True(success); Assert.Equal(numHeaders, headerCollection.Count());