diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs index c08d0eb684..dac00934db 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/Frame.cs @@ -232,7 +232,7 @@ namespace Microsoft.AspNet.Server.Kestrel.Http await ProduceEnd(); - while (await RequestBody.ReadAsync(_nullBuffer, 0, _nullBuffer.Length) != 0) + while (await MessageBody.SkipAsync() != 0) { // Finish reading the request body in case the app did not. } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/MessageBody.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/MessageBody.cs index e8e438512b..393803e752 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/MessageBody.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/MessageBody.cs @@ -38,8 +38,26 @@ namespace Microsoft.AspNet.Server.Kestrel.Http return result; } + public Task SkipAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + Task result = null; + var send100Continue = 0; + result = SkipImplementation(cancellationToken); + if (!result.IsCompleted) + { + send100Continue = Interlocked.Exchange(ref _send100Continue, 0); + } + if (send100Continue == 1) + { + _context.FrameControl.ProduceContinue(); + } + return result; + } + public abstract Task ReadAsyncImplementation(ArraySegment buffer, CancellationToken cancellationToken); + public abstract Task SkipImplementation(CancellationToken cancellationToken); + public static MessageBody For( string httpVersion, IDictionary headers, @@ -110,6 +128,10 @@ namespace Microsoft.AspNet.Server.Kestrel.Http { return _context.SocketInput.ReadAsync(buffer); } + public override Task SkipImplementation(CancellationToken cancellationToken) + { + return _context.SocketInput.SkipAsync(4096); + } } class ForContentLength : MessageBody @@ -146,6 +168,27 @@ namespace Microsoft.AspNet.Server.Kestrel.Http return actual; } + + public override async Task SkipImplementation(CancellationToken cancellationToken) + { + var input = _context.SocketInput; + + var limit = Math.Min(4096, _inputLength); + if (limit == 0) + { + return 0; + } + + var actual = await _context.SocketInput.SkipAsync(limit); + _inputLength -= actual; + + if (actual == 0) + { + throw new InvalidDataException("Unexpected end of request content"); + } + + return actual; + } } @@ -236,6 +279,78 @@ namespace Microsoft.AspNet.Server.Kestrel.Http return 0; } + public override async Task SkipImplementation(CancellationToken cancellationToken) + { + var input = _context.SocketInput; + + while (_mode != Mode.Complete) + { + while (_mode == Mode.ChunkPrefix) + { + var chunkSize = 0; + if (!TakeChunkedLine(input, ref chunkSize)) + { + await input; + } + else if (chunkSize == 0) + { + _mode = Mode.Complete; + } + else + { + _mode = Mode.ChunkData; + } + _inputLength = chunkSize; + } + while (_mode == Mode.ChunkData) + { + var limit = Math.Min(4096, _inputLength); + if (limit != 0) + { + await input; + } + + var begin = input.ConsumingStart(); + int actual; + var end = begin.Skip(limit, out actual); + _inputLength -= actual; + input.ConsumingComplete(end, end); + + if (_inputLength == 0) + { + _mode = Mode.ChunkSuffix; + } + if (actual != 0) + { + return actual; + } + } + while (_mode == Mode.ChunkSuffix) + { + var scan = input.ConsumingStart(); + var consumed = scan; + var ch1 = scan.Take(); + var ch2 = scan.Take(); + if (ch1 == -1 || ch2 == -1) + { + input.ConsumingComplete(consumed, scan); + await input; + } + else if (ch1 == '\r' && ch2 == '\n') + { + input.ConsumingComplete(scan, scan); + _mode = Mode.ChunkPrefix; + } + else + { + throw new NotImplementedException("INVALID REQUEST FORMAT"); + } + } + } + + return 0; + } + private static bool TakeChunkedLine(SocketInput baton, ref int chunkSizeOut) { var scan = baton.ConsumingStart(); diff --git a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs index 9c5d690707..70a491565d 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Http/SocketInputExtensions.cs @@ -29,5 +29,27 @@ namespace Microsoft.AspNet.Server.Kestrel.Http } } } + + public static async Task SkipAsync(this SocketInput input, int limit) + { + while (true) + { + await input; + + var begin = input.ConsumingStart(); + int actual; + var end = begin.Skip(limit, out actual); + input.ConsumingComplete(end, end); + + if (actual != 0) + { + return actual; + } + if (input.RemoteIntakeFin) + { + return 0; + } + } + } } } diff --git a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs index 7de5c7fb10..c14534a176 100644 --- a/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs +++ b/src/Microsoft.AspNet.Server.Kestrel/Infrastructure/MemoryPoolIterator2.cs @@ -619,5 +619,37 @@ namespace Microsoft.AspNet.Server.Kestrel.Infrastructure } } } + public MemoryPoolIterator2 Skip(int limit, out int actual) + { + if (IsDefault) + { + actual = 0; + return this; + } + + var block = _block; + var index = _index; + var remaining = limit; + while (true) + { + var following = block.End - index; + if (remaining <= following) + { + actual = limit; + return new MemoryPoolIterator2(block, index + remaining); + } + else if (block.Next == null) + { + actual = limit - remaining + following; + return new MemoryPoolIterator2(block, index + following); + } + else + { + remaining -= following; + block = block.Next; + index = block.Start; + } + } + } } }