Drain chunked extensions + refactor

This commit is contained in:
Ben Adams 2016-02-11 00:58:06 +00:00
parent 331d4a87ac
commit 4bfcd7ba1f
3 changed files with 437 additions and 215 deletions

View File

@ -3,6 +3,7 @@
using System;
using System.IO;
using System.Numerics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
@ -11,10 +12,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
public abstract class MessageBody
{
private readonly FrameContext _context;
private readonly Frame _context;
private int _send100Continue = 1;
protected MessageBody(FrameContext context)
protected MessageBody(Frame context)
{
_context = context;
}
@ -99,7 +100,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public static MessageBody For(
string httpVersion,
FrameRequestHeaders headers,
FrameContext context)
Frame context)
{
// see also http://tools.ietf.org/html/rfc2616#section-4.4
@ -114,7 +115,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
var transferEncoding = headers.HeaderTransferEncoding.ToString();
if (transferEncoding.Length > 0)
{
return new ForChunkedEncoding(keepAlive, context);
return new ForChunkedEncoding(keepAlive, headers, context);
}
var contentLength = headers.HeaderContentLength.ToString();
@ -133,7 +134,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private class ForRemainingData : MessageBody
{
public ForRemainingData(FrameContext context)
public ForRemainingData(Frame context)
: base(context)
{
}
@ -149,7 +150,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
private readonly int _contentLength;
private int _inputLength;
public ForContentLength(bool keepAlive, int contentLength, FrameContext context)
public ForContentLength(bool keepAlive, int contentLength, Frame context)
: base(context)
{
RequestKeepAlive = keepAlive;
@ -204,96 +205,203 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
/// </summary>
private class ForChunkedEncoding : MessageBody
{
private int _inputLength;
private Mode _mode = Mode.ChunkPrefix;
private static Vector<byte> _vectorCRs = new Vector<byte>((byte)'\r');
public ForChunkedEncoding(bool keepAlive, FrameContext context)
private int _inputLength;
private Mode _mode = Mode.Prefix;
private FrameRequestHeaders _requestHeaders;
public ForChunkedEncoding(bool keepAlive, FrameRequestHeaders headers, Frame context)
: base(context)
{
RequestKeepAlive = keepAlive;
_requestHeaders = headers;
}
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
return ReadAsyncAwaited(buffer, cancellationToken);
return ReadStateMachineAsync(_context.SocketInput, buffer, cancellationToken);
}
private async Task<int> ReadAsyncAwaited(ArraySegment<byte> buffer, CancellationToken cancellationToken)
private async Task<int> ReadStateMachineAsync(SocketInput input, ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
var input = _context.SocketInput;
while (_mode != Mode.Trailer && _mode != Mode.Complete)
while (_mode < Mode.Trailer)
{
while (_mode == Mode.ChunkPrefix)
while (_mode == Mode.Prefix)
{
ReadChunkedPrefix(input);
await input;
ParseChunkedPrefix(input);
if (_mode != Mode.Prefix)
{
break;
}
await GetDataAsync(input);
}
while (_mode == Mode.ChunkData)
while (_mode == Mode.Extension)
{
ParseExtension(input);
if (_mode != Mode.Extension)
{
break;
}
await GetDataAsync(input);
}
while (_mode == Mode.Data)
{
int actual = ReadChunkedData(input, buffer.Array, buffer.Offset, buffer.Count);
if (actual != 0)
{
return actual;
}
else if (_mode != Mode.Data)
{
break;
}
await input;
await GetDataAsync(input);
}
while (_mode == Mode.ChunkSuffix)
while (_mode == Mode.Suffix)
{
ReadChunkedSuffix(input);
await input;
ParseChunkedSuffix(input);
if (_mode != Mode.Suffix)
{
break;
}
await GetDataAsync(input);
}
}
// Chunks finished, parse trailers
while (_mode == Mode.Trailer)
{
ReadChunkedTrailer(input);
if (_mode != Mode.Complete && _mode != Mode.TrailerHeaders)
ParseChunkedTrailer(input);
if (_mode != Mode.Trailer)
{
await input;
break;
}
await GetDataAsync(input);
}
if (_mode == Mode.TrailerHeaders)
{
// Take trailer headers
var frame = (Frame)_context;
while (!Frame.TakeMessageHeaders(input, frame._requestHeaders))
while (!Frame.TakeMessageHeaders(input, _requestHeaders))
{
if (input.RemoteIntakeFin)
{
ThrowChunkedRequestIncomplete();
}
await input;
await GetDataAsync(input);
}
_mode = Mode.Complete;
}
return 0;
}
private void ReadChunkedPrefix(SocketInput input)
private void ParseChunkedPrefix(SocketInput input)
{
int chunkSize;
if (TakeChunkedLine(input, out chunkSize))
var scan = input.ConsumingStart();
var consumed = scan;
try
{
if (chunkSize == 0)
var ch1 = scan.Take();
var ch2 = scan.Take();
if (ch1 == -1 || ch2 == -1)
{
_mode = Mode.Trailer;
}
else
{
_mode = Mode.ChunkData;
return;
}
_inputLength = chunkSize;
var chunkSize = CalculateChunkSize(ch1, 0);
ch1 = ch2;
do
{
if (ch1 == ';')
{
consumed = scan;
_inputLength = chunkSize;
_mode = Mode.Extension;
return;
}
ch2 = scan.Take();
if (ch2 == -1)
{
return;
}
if (ch1 == '\r' && ch2 == '\n')
{
consumed = scan;
_inputLength = chunkSize;
if (chunkSize > 0)
{
_mode = Mode.Data;
}
else
{
_mode = Mode.Trailer;
}
return;
}
chunkSize = CalculateChunkSize(ch1, chunkSize);
ch1 = ch2;
} while (ch1 != -1);
}
else if (input.RemoteIntakeFin)
finally
{
ThrowChunkedRequestIncomplete();
input.ConsumingComplete(consumed, scan);
}
}
private void ParseExtension(SocketInput input)
{
var scan = input.ConsumingStart();
var consumed = scan;
try
{
// Chunk-extensions not currently parsed
// Just drain the data
do
{
if (scan.Seek(ref _vectorCRs) == -1)
{
// End marker not found yet
consumed = scan;
return;
};
var ch1 = scan.Take();
var ch2 = scan.Take();
if (ch2 == '\n')
{
consumed = scan;
if (_inputLength > 0)
{
_mode = Mode.Data;
}
else
{
_mode = Mode.Trailer;
}
}
else if (ch2 == -1)
{
return;
}
} while (_mode == Mode.Extension);
}
finally
{
input.ConsumingComplete(consumed, scan);
}
}
@ -314,17 +422,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
if (_inputLength == 0)
{
_mode = Mode.ChunkSuffix;
_mode = Mode.Suffix;
}
else if (actual == 0 && input.RemoteIntakeFin)
else if (actual == 0)
{
ThrowChunkedRequestIncomplete();
ThrowIfRequestIncomplete(input);
}
return actual;
}
private void ReadChunkedSuffix(SocketInput input)
private void ParseChunkedSuffix(SocketInput input)
{
var scan = input.ConsumingStart();
var consumed = scan;
@ -332,18 +440,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{
var ch1 = scan.Take();
var ch2 = scan.Take();
if (ch1 == '\r' && ch2 == '\n')
if (ch1 == -1 || ch2 == -1)
{
return;
}
else if (ch1 == '\r' && ch2 == '\n')
{
consumed = scan;
_mode = Mode.ChunkPrefix;
}
else if (ch1 == -1 || ch2 == -1)
{
if (input.RemoteIntakeFin)
{
ThrowChunkedRequestIncomplete();
}
_mode = Mode.Prefix;
}
else
{
@ -356,7 +460,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
}
}
private void ReadChunkedTrailer(SocketInput input)
private void ParseChunkedTrailer(SocketInput input)
{
var scan = input.ConsumingStart();
var consumed = scan;
@ -365,29 +469,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
var ch1 = scan.Take();
var ch2 = scan.Take();
if (ch1 == '\r' && ch2 == '\n')
if (ch1 == -1 || ch2 == -1)
{
return;
}
else if (ch1 == '\r' && ch2 == '\n')
{
consumed = scan;
_mode = Mode.Complete;
}
else if (ch1 == -1 || ch2 == -1)
{
if (input.RemoteIntakeFin)
{
ThrowChunkedRequestIncomplete();
}
}
else
{
// Post request headers
if (_context is Frame)
{
_mode = Mode.TrailerHeaders;
}
else
{
ThrowTrailingHeadersNotSupported();
}
_mode = Mode.TrailerHeaders;
}
}
finally
@ -396,124 +489,61 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
}
}
private static bool TakeChunkedLine(SocketInput baton, out int chunkSizeOut)
private static int CalculateChunkSize(int extraHexDigit, int currentParsedSize)
{
var scan = baton.ConsumingStart();
var consumed = scan;
try
checked
{
var ch0 = scan.Take();
var chunkSize = 0;
var mode = Mode.ChunkPrefix;
while (ch0 != -1)
if (extraHexDigit >= '0' && extraHexDigit <= '9')
{
var ch1 = scan.Take();
if (ch1 == -1)
{
chunkSizeOut = 0;
return false;
}
if (mode == Mode.ChunkPrefix)
{
if (ch0 >= '0' && ch0 <= '9')
{
chunkSize = chunkSize * 0x10 + (ch0 - '0');
}
else if (ch0 >= 'A' && ch0 <= 'F')
{
chunkSize = chunkSize * 0x10 + (ch0 - ('A' - 10));
}
else if (ch0 >= 'a' && ch0 <= 'f')
{
chunkSize = chunkSize * 0x10 + (ch0 - ('a' - 10));
}
else
{
ThrowInvalidFormat();
}
mode = Mode.ChunkData;
}
else if (mode == Mode.ChunkData)
{
if (ch0 >= '0' && ch0 <= '9')
{
chunkSize = chunkSize * 0x10 + (ch0 - '0');
}
else if (ch0 >= 'A' && ch0 <= 'F')
{
chunkSize = chunkSize * 0x10 + (ch0 - ('A' - 10));
}
else if (ch0 >= 'a' && ch0 <= 'f')
{
chunkSize = chunkSize * 0x10 + (ch0 - ('a' - 10));
}
else if (ch0 == ';')
{
mode = Mode.ChunkSuffix;
}
else if (ch0 == '\r' && ch1 == '\n')
{
consumed = scan;
chunkSizeOut = chunkSize;
return true;
}
else
{
ThrowInvalidFormat();
}
}
else if (mode == Mode.ChunkSuffix)
{
if (ch0 == '\r' && ch1 == '\n')
{
consumed = scan;
chunkSizeOut = chunkSize;
return true;
}
else
{
// chunk-extensions not currently parsed
ThrowChunkedExtensionsNotSupported();
}
}
ch0 = ch1;
return currentParsedSize * 0x10 + (extraHexDigit - '0');
}
else if (extraHexDigit >= 'A' && extraHexDigit <= 'F')
{
return currentParsedSize * 0x10 + (extraHexDigit - ('A' - 10));
}
else if (extraHexDigit >= 'a' && extraHexDigit <= 'f')
{
return currentParsedSize * 0x10 + (extraHexDigit - ('a' - 10));
}
else
{
return ThrowInvalidFormat();
}
chunkSizeOut = 0;
return false;
}
finally
{
baton.ConsumingComplete(consumed, scan);
}
}
private static void ThrowInvalidFormat()
private static SocketInput GetDataAsync(SocketInput input)
{
throw new InvalidOperationException("Bad Request");
ThrowIfRequestIncomplete(input);
return input;
}
private static void ThrowChunkedRequestIncomplete()
private static void ThrowIfRequestIncomplete(SocketInput input)
{
if (input.RemoteIntakeFin)
{
ThrowRequestIncomplete();
}
}
private static int ThrowInvalidFormat()
{
// returns int so can be used as item non-void function
throw new InvalidOperationException("Bad request");
}
private static void ThrowRequestIncomplete()
{
throw new InvalidOperationException("Chunked request incomplete");
}
private static void ThrowChunkedExtensionsNotSupported()
{
throw new NotImplementedException("Chunked-extensions not supported");
}
private static void ThrowTrailingHeadersNotSupported()
{
throw new NotImplementedException("Trailing headers not supported");
}
private enum Mode
{
ChunkPrefix,
ChunkData,
ChunkSuffix,
Prefix,
Extension,
Data,
Suffix,
Trailer,
TrailerHeaders,
Complete

View File

@ -2,15 +2,16 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net.Sockets;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel;
using Microsoft.AspNetCore.Server.Kestrel.Filter;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
using Microsoft.AspNetCore.Testing.xunit;
using Microsoft.Extensions.Logging;
using Xunit;
@ -64,7 +65,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
await response.Body.WriteAsync(bytes, 0, bytes.Length);
}
[ConditionalTheory]
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task Http10TransferEncoding(ServiceContext testContext)
{
@ -88,7 +89,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[ConditionalTheory]
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task Http10KeepAliveTransferEncoding(ServiceContext testContext)
{
@ -123,7 +124,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[ConditionalTheory]
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task RequestBodyIsConsumedAutomaticallyIfAppDoesntConsumeItFully(ServiceContext testContext)
{
@ -171,9 +172,8 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[ConditionalTheory]
[Theory]
[MemberData(nameof(ConnectionFilterData))]
[FrameworkSkipCondition(RuntimeFrameworks.Mono, SkipReason = "Test hangs after execution on Mono.")]
public async Task TrailingHeadersAreParsed(ServiceContext testContext)
{
var requestCount = 10;
@ -186,20 +186,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
var buffer = new byte[200];
Assert.Equal(string.Empty, request.Headers["X-Trailer-Header"]);
Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"]));
while(await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0)
while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0)
{
// read to end
}
if (requestsReceived < requestCount)
{
Assert.Equal(new string('a', requestsReceived), request.Headers["X-Trailer-Header"]);
Assert.Equal(new string('a', requestsReceived), request.Headers["X-Trailer-Header"].ToString());
}
else
{
Assert.Equal(string.Empty, request.Headers["X-Trailer-Header"]);
Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"]));
}
requestsReceived++;
@ -218,41 +218,231 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
var expectedFullResponse = string.Join("", Enumerable.Repeat(response, requestCount + 1));
using (var connection = new TestConnection(server.Port))
IEnumerable<string> sendSequence = new string[] {
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"C",
"HelloChunked",
"0",
""};
for (var i = 1; i < requestCount; i++)
{
await connection.Send(
sendSequence = sendSequence.Concat(new string[] {
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"C", "HelloChunked",
"C",
$"HelloChunk{i:00}",
"0",
"");
string.Concat("X-Trailer-Header: ", new string('a', i)),
"" });
}
for (var i = 1; i < requestCount; i++)
{
await connection.Send(
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"C", "HelloChunked",
"0",
string.Concat("X-Trailer-Header", new string('a', i)),
"");
}
sendSequence = sendSequence.Concat(new string[] {
"POST / HTTP/1.1",
"Content-Length: 7",
"",
"Goodbye"
});
await connection.SendEnd(
"POST / HTTP/1.1",
"Content-Length: 7",
"",
"Goodbye");
var fullRequest = sendSequence.ToArray();
using (var connection = new TestConnection(server.Port))
{
await connection.SendEnd(fullRequest);
await connection.ReceiveEnd(expectedFullResponse);
}
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task ExtensionsAreIgnored(ServiceContext testContext)
{
var requestCount = 10;
var requestsReceived = 0;
using (var server = new TestServer(async httpContext =>
{
var response = httpContext.Response;
var request = httpContext.Request;
var buffer = new byte[200];
Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"]));
while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0)
{
// read to end
}
if (requestsReceived < requestCount)
{
Assert.Equal(new string('a', requestsReceived), request.Headers["X-Trailer-Header"].ToString());
}
else
{
Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"]));
}
requestsReceived++;
response.Headers.Clear();
response.Headers["Content-Length"] = new[] { "11" };
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11);
}, testContext))
{
var response = string.Join("\r\n", new string[] {
"HTTP/1.1 200 OK",
"Content-Length: 11",
"",
"Hello World"});
var expectedFullResponse = string.Join("", Enumerable.Repeat(response, requestCount + 1));
IEnumerable<string> sendSequence = new string[] {
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"C;hello there",
"HelloChunked",
"0;hello there",
""};
for (var i = 1; i < requestCount; i++)
{
sendSequence = sendSequence.Concat(new string[] {
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"C;hello there",
$"HelloChunk{i:00}",
"0;hello there",
string.Concat("X-Trailer-Header: ", new string('a', i)),
"" });
}
sendSequence = sendSequence.Concat(new string[] {
"POST / HTTP/1.1",
"Content-Length: 7",
"",
"Goodbye"
});
var fullRequest = sendSequence.ToArray();
using (var connection = new TestConnection(server.Port))
{
await connection.SendEnd(fullRequest);
await connection.ReceiveEnd(expectedFullResponse);
}
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task InvalidLengthResultsIn500(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
var response = httpContext.Response;
var request = httpContext.Request;
var buffer = new byte[200];
while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0)
{
;// read to end
}
response.Headers.Clear();
response.Headers["Content-Length"] = new[] { "11" };
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11);
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
await connection.Send(
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"Cio",
"HelloChunked",
"0",
"");
// Should really be a 40x as is bad request
await connection.Receive(
"HTTP/1.1 500 Internal Server Error",
"");
await connection.ReceiveStartsWith("Date:");
await connection.ReceiveEnd(
"Content-Length: 0",
"Server: Kestrel",
"",
"");
}
}
}
[Theory]
[MemberData(nameof(ConnectionFilterData))]
public async Task InvalidSizedDataResultsIn500(ServiceContext testContext)
{
using (var server = new TestServer(async httpContext =>
{
var response = httpContext.Response;
var request = httpContext.Request;
var buffer = new byte[200];
while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0)
{
;// read to end
}
response.Headers.Clear();
response.Headers["Content-Length"] = new[] { "11" };
await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11);
}, testContext))
{
using (var connection = new TestConnection(server.Port))
{
await connection.Send(
"POST / HTTP/1.1",
"Transfer-Encoding: chunked",
"",
"C",
"HelloChunkedInvalid",
"0",
"");
// Should really be a 40x as is bad request
await connection.Receive(
"HTTP/1.1 500 Internal Server Error",
"");
await connection.ReceiveStartsWith("Date:");
await connection.ReceiveEnd(
"Content-Length: 0",
"Server: Kestrel",
"",
"");
}
}
}
private class TestApplicationErrorLogger : ILogger
{
// Application errors are logged using 13 as the eventId.
private const int ApplicationErrorEventId = 13;
public int ApplicationErrorsLogged { get; set; }
public IDisposable BeginScopeImpl(object state)
@ -267,8 +457,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception exception, Func<TState, Exception, string> formatter)
{
// Application errors are logged using 13 as the eventId.
if (eventId.Id == 13)
if (eventId.Id == ApplicationErrorEventId)
{
ApplicationErrorsLogged++;
}

View File

@ -18,17 +18,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
{
var trace = new KestrelTrace(new TestKestrelTrace());
var ltp = new LoggingThreadPool(trace);
FrameContext = new FrameContext
var context = new FrameContext()
{
DateHeaderValueManager = new DateHeaderValueManager(),
ServerAddress = ServerAddress.FromUrl("http://localhost:5000"),
ConnectionControl = this,
FrameControl = this
};
FrameContext = new Frame<object>(null, context);
_memoryPool = new MemoryPool2();
FrameContext.SocketInput = new SocketInput(_memoryPool, ltp);
}
public FrameContext FrameContext { get; set; }
public Frame FrameContext { get; set; }
public void Add(string text, bool fin = false)
{