From 43398482a51d48b0f693dc7403076f38af1181cb Mon Sep 17 00:00:00 2001 From: "Chris Ross (ASP.NET)" Date: Mon, 13 Aug 2018 11:58:33 -0700 Subject: [PATCH] Implement MaxRequestBodySize for HTTP/2 #2810 --- .../Internal/Http/Http1MessageBody.cs | 11 - src/Kestrel.Core/Internal/Http/MessageBody.cs | 11 + .../Internal/Http2/Http2MessageBody.cs | 10 + test/Kestrel.Core.Tests/Http2StreamTests.cs | 286 +++++++++++++++++- 4 files changed, 305 insertions(+), 13 deletions(-) diff --git a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs index d2edbff7f8..6e83307de2 100644 --- a/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/Http1MessageBody.cs @@ -399,7 +399,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private const int MaxChunkPrefixBytes = 10; private long _inputLength; - private long _consumedBytes; private Mode _mode = Mode.Prefix; @@ -490,16 +489,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return _mode == Mode.Complete; } - private void AddAndCheckConsumedBytes(long consumedBytes) - { - _consumedBytes += consumedBytes; - - if (_consumedBytes > _context.MaxRequestBodySize) - { - BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); - } - } - private void ParseChunkedPrefix(ReadOnlySequence buffer, out SequencePosition consumed, out SequencePosition examined) { consumed = buffer.Start; diff --git a/src/Kestrel.Core/Internal/Http/MessageBody.cs b/src/Kestrel.Core/Internal/Http/MessageBody.cs index 0cbf0e0ea0..ed0bab633b 100644 --- a/src/Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http/MessageBody.cs @@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private readonly HttpProtocol _context; private bool _send100Continue = true; + private long _consumedBytes; protected MessageBody(HttpProtocol context) { @@ -168,6 +169,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http { } + protected void AddAndCheckConsumedBytes(long consumedBytes) + { + _consumedBytes += consumedBytes; + + if (_consumedBytes > _context.MaxRequestBodySize) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); + } + } + private class ForZeroContentLength : MessageBody { public ForZeroContentLength(bool keepAlive) diff --git a/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs b/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs index 2ac43a2974..f835ffa570 100644 --- a/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs +++ b/src/Kestrel.Core/Internal/Http2/Http2MessageBody.cs @@ -16,6 +16,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 _context = context; } + protected override void OnReadStarting() + { + // Note ContentLength or MaxRequestBodySize may be null + if (_context.RequestHeaders.ContentLength > _context.MaxRequestBodySize) + { + BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge); + } + } + protected override void OnReadStarted() { // Produce 100-continue if no request body data for the stream has arrived yet. @@ -28,6 +37,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2 protected override void OnDataRead(int bytesRead) { _context.OnDataRead(bytesRead); + AddAndCheckConsumedBytes(bytesRead); } protected override Task OnConsumeAsync() => Task.CompletedTask; diff --git a/test/Kestrel.Core.Tests/Http2StreamTests.cs b/test/Kestrel.Core.Tests/Http2StreamTests.cs index f5e0665e74..3331767990 100644 --- a/test/Kestrel.Core.Tests/Http2StreamTests.cs +++ b/test/Kestrel.Core.Tests/Http2StreamTests.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.IO; using System.IO.Pipelines; using System.Linq; +using System.Runtime.ExceptionServices; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; @@ -191,7 +192,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests ConnectionFeatures = new FeatureCollection(), ServiceContext = new TestServiceContext() { - Log = new TestKestrelTrace(_logger) + Log = new TestKestrelTrace(_logger), }, MemoryPool = _memoryPool, Application = _pair.Application, @@ -1231,6 +1232,287 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal("11", _decodedHeaders[HeaderNames.ContentLength]); } + [Fact] + public async Task MaxRequestBodySize_ContentLengthUnder_200() + { + _connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 15; + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.ContentLength, "12"), + }; + await InitializeConnectionAsync(async context => + { + var buffer = new byte[100]; + var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(12, read); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[12].AsSpan(), endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + + [Fact] + public async Task MaxRequestBodySize_ContentLengthOver_413() + { + BadHttpRequestException exception = null; + _connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 10; + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + new KeyValuePair(HeaderNames.ContentLength, "12"), + }; + await InitializeConnectionAsync(async context => + { + exception = await Assert.ThrowsAsync(async () => + { + var buffer = new byte[100]; + while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } + }); + ExceptionDispatchInfo.Capture(exception).Throw(); + }); + + await StartStreamAsync(1, headers, endStream: false); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 59, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("413", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + + Assert.NotNull(exception); + } + + [Fact] + public async Task MaxRequestBodySize_NoContentLength_Under_200() + { + _connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 15; + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + var buffer = new byte[100]; + var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(12, read); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[12].AsSpan(), endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + + [Fact] + public async Task MaxRequestBodySize_NoContentLength_Over_413() + { + BadHttpRequestException exception = null; + _connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 10; + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async context => + { + exception = await Assert.ThrowsAsync(async () => + { + var buffer = new byte[100]; + while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } + }); + ExceptionDispatchInfo.Capture(exception).Throw(); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[6].AsSpan(), endStream: false); + await SendDataAsync(1, new byte[6].AsSpan(), endStream: false); + await SendDataAsync(1, new byte[6].AsSpan(), endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 59, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("413", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + + Assert.NotNull(exception); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MaxRequestBodySize_AppCanLowerLimit(bool includeContentLength) + { + BadHttpRequestException exception = null; + _connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 20; + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + if (includeContentLength) + { + headers.Concat(new[] + { + new KeyValuePair(HeaderNames.ContentLength, "18"), + }); + } + await InitializeConnectionAsync(async context => + { + Assert.False(context.Features.Get().IsReadOnly); + context.Features.Get().MaxRequestBodySize = 17; + exception = await Assert.ThrowsAsync(async () => + { + var buffer = new byte[100]; + while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { } + }); + Assert.True(context.Features.Get().IsReadOnly); + ExceptionDispatchInfo.Capture(exception).Throw(); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[6].AsSpan(), endStream: false); + await SendDataAsync(1, new byte[6].AsSpan(), endStream: false); + await SendDataAsync(1, new byte[6].AsSpan(), endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 59, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("413", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + + Assert.NotNull(exception); + } + + [Theory] + [InlineData(true)] + [InlineData(false)] + public async Task MaxRequestBodySize_AppCanRaiseLimit(bool includeContentLength) + { + _connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 10; + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "POST"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + if (includeContentLength) + { + headers.Concat(new[] + { + new KeyValuePair(HeaderNames.ContentLength, "12"), + }); + } + await InitializeConnectionAsync(async context => + { + Assert.False(context.Features.Get().IsReadOnly); + context.Features.Get().MaxRequestBodySize = 12; + var buffer = new byte[100]; + var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length); + Assert.Equal(12, read); + Assert.True(context.Features.Get().IsReadOnly); + }); + + await StartStreamAsync(1, headers, endStream: false); + await SendDataAsync(1, new byte[12].AsSpan(), endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + [Fact] public async Task ApplicationExeption_BeforeFirstWrite_Sends500() { @@ -1725,7 +2007,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests private async Task ExpectAsync(Http2FrameType type, int withLength, byte withFlags, int withStreamId) { - var frame = await ReceiveFrameAsync(); + var frame = await ReceiveFrameAsync().DefaultTimeout(); Assert.Equal(type, frame.Type); Assert.Equal(withLength, frame.Length);