Implement MaxRequestBodySize for HTTP/2 #2810

This commit is contained in:
Chris Ross (ASP.NET) 2018-08-13 11:58:33 -07:00
parent cd6de2fa18
commit 43398482a5
4 changed files with 305 additions and 13 deletions

View File

@ -399,7 +399,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private const int MaxChunkPrefixBytes = 10;
private long _inputLength;
private long _consumedBytes;
private Mode _mode = Mode.Prefix;
@ -490,16 +489,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
return _mode == Mode.Complete;
}
private void AddAndCheckConsumedBytes(long consumedBytes)
{
_consumedBytes += consumedBytes;
if (_consumedBytes > _context.MaxRequestBodySize)
{
BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge);
}
}
private void ParseChunkedPrefix(ReadOnlySequence<byte> buffer, out SequencePosition consumed, out SequencePosition examined)
{
consumed = buffer.Start;

View File

@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
private readonly HttpProtocol _context;
private bool _send100Continue = true;
private long _consumedBytes;
protected MessageBody(HttpProtocol context)
{
@ -168,6 +169,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
{
}
protected void AddAndCheckConsumedBytes(long consumedBytes)
{
_consumedBytes += consumedBytes;
if (_consumedBytes > _context.MaxRequestBodySize)
{
BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge);
}
}
private class ForZeroContentLength : MessageBody
{
public ForZeroContentLength(bool keepAlive)

View File

@ -16,6 +16,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
_context = context;
}
protected override void OnReadStarting()
{
// Note ContentLength or MaxRequestBodySize may be null
if (_context.RequestHeaders.ContentLength > _context.MaxRequestBodySize)
{
BadHttpRequestException.Throw(RequestRejectionReason.RequestBodyTooLarge);
}
}
protected override void OnReadStarted()
{
// Produce 100-continue if no request body data for the stream has arrived yet.
@ -28,6 +37,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2
protected override void OnDataRead(int bytesRead)
{
_context.OnDataRead(bytesRead);
AddAndCheckConsumedBytes(bytesRead);
}
protected override Task OnConsumeAsync() => Task.CompletedTask;

View File

@ -8,6 +8,7 @@ using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Runtime.ExceptionServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
@ -191,7 +192,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
ConnectionFeatures = new FeatureCollection(),
ServiceContext = new TestServiceContext()
{
Log = new TestKestrelTrace(_logger)
Log = new TestKestrelTrace(_logger),
},
MemoryPool = _memoryPool,
Application = _pair.Application,
@ -1231,6 +1232,287 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
Assert.Equal("11", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task MaxRequestBodySize_ContentLengthUnder_200()
{
_connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 15;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.ContentLength, "12"),
};
await InitializeConnectionAsync(async context =>
{
var buffer = new byte[100];
var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length);
Assert.Equal(12, read);
});
await StartStreamAsync(1, headers, endStream: false);
await SendDataAsync(1, new byte[12].AsSpan(), endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task MaxRequestBodySize_ContentLengthOver_413()
{
BadHttpRequestException exception = null;
_connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 10;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.ContentLength, "12"),
};
await InitializeConnectionAsync(async context =>
{
exception = await Assert.ThrowsAsync<BadHttpRequestException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
});
ExceptionDispatchInfo.Capture(exception).Throw();
});
await StartStreamAsync(1, headers, endStream: false);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 59,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("413", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
Assert.NotNull(exception);
}
[Fact]
public async Task MaxRequestBodySize_NoContentLength_Under_200()
{
_connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 15;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
var buffer = new byte[100];
var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length);
Assert.Equal(12, read);
});
await StartStreamAsync(1, headers, endStream: false);
await SendDataAsync(1, new byte[12].AsSpan(), endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task MaxRequestBodySize_NoContentLength_Over_413()
{
BadHttpRequestException exception = null;
_connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 10;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
await InitializeConnectionAsync(async context =>
{
exception = await Assert.ThrowsAsync<BadHttpRequestException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
});
ExceptionDispatchInfo.Capture(exception).Throw();
});
await StartStreamAsync(1, headers, endStream: false);
await SendDataAsync(1, new byte[6].AsSpan(), endStream: false);
await SendDataAsync(1, new byte[6].AsSpan(), endStream: false);
await SendDataAsync(1, new byte[6].AsSpan(), endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 59,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("413", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
Assert.NotNull(exception);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task MaxRequestBodySize_AppCanLowerLimit(bool includeContentLength)
{
BadHttpRequestException exception = null;
_connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 20;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
if (includeContentLength)
{
headers.Concat(new[]
{
new KeyValuePair<string, string>(HeaderNames.ContentLength, "18"),
});
}
await InitializeConnectionAsync(async context =>
{
Assert.False(context.Features.Get<IHttpMaxRequestBodySizeFeature>().IsReadOnly);
context.Features.Get<IHttpMaxRequestBodySizeFeature>().MaxRequestBodySize = 17;
exception = await Assert.ThrowsAsync<BadHttpRequestException>(async () =>
{
var buffer = new byte[100];
while (await context.Request.Body.ReadAsync(buffer, 0, buffer.Length) > 0) { }
});
Assert.True(context.Features.Get<IHttpMaxRequestBodySizeFeature>().IsReadOnly);
ExceptionDispatchInfo.Capture(exception).Throw();
});
await StartStreamAsync(1, headers, endStream: false);
await SendDataAsync(1, new byte[6].AsSpan(), endStream: false);
await SendDataAsync(1, new byte[6].AsSpan(), endStream: false);
await SendDataAsync(1, new byte[6].AsSpan(), endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 59,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("413", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
Assert.NotNull(exception);
}
[Theory]
[InlineData(true)]
[InlineData(false)]
public async Task MaxRequestBodySize_AppCanRaiseLimit(bool includeContentLength)
{
_connectionContext.ServiceContext.ServerOptions.Limits.MaxRequestBodySize = 10;
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
new KeyValuePair<string, string>(HeaderNames.Path, "/"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
};
if (includeContentLength)
{
headers.Concat(new[]
{
new KeyValuePair<string, string>(HeaderNames.ContentLength, "12"),
});
}
await InitializeConnectionAsync(async context =>
{
Assert.False(context.Features.Get<IHttpMaxRequestBodySizeFeature>().IsReadOnly);
context.Features.Get<IHttpMaxRequestBodySizeFeature>().MaxRequestBodySize = 12;
var buffer = new byte[100];
var read = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length);
Assert.Equal(12, read);
Assert.True(context.Features.Get<IHttpMaxRequestBodySizeFeature>().IsReadOnly);
});
await StartStreamAsync(1, headers, endStream: false);
await SendDataAsync(1, new byte[12].AsSpan(), endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 55,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
_hpackDecoder.Decode(headersFrame.HeadersPayload, endHeaders: false, handler: this);
Assert.Equal(3, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]);
}
[Fact]
public async Task ApplicationExeption_BeforeFirstWrite_Sends500()
{
@ -1725,7 +2007,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
private async Task<Http2Frame> ExpectAsync(Http2FrameType type, int withLength, byte withFlags, int withStreamId)
{
var frame = await ReceiveFrameAsync();
var frame = await ReceiveFrameAsync().DefaultTimeout();
Assert.Equal(type, frame.Type);
Assert.Equal(withLength, frame.Length);