diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameRequestStream.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameRequestStream.cs index 3889a02db5..b75effe5e8 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameRequestStream.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameRequestStream.cs @@ -119,8 +119,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http var task = ValidateState(cancellationToken); if (task == null) { - // Needs .AsTask to match Stream's Async method return types - return _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken).AsTask(); + return _body.ReadAsync(new ArraySegment(buffer, offset, count), cancellationToken); + } + return task; + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + if (destination == null) + { + throw new ArgumentNullException(nameof(destination)); + } + if (bufferSize <= 0) + { + throw new ArgumentException($"{nameof(bufferSize)} must be positive.", nameof(bufferSize)); + } + + var task = ValidateState(cancellationToken); + if (task == null) + { + return _body.CopyToAsync(destination, cancellationToken); } return task; } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs index 65bbe9a808..784f8d97de 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/MessageBody.cs @@ -2,9 +2,11 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.Numerics; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http @@ -12,7 +14,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public abstract class MessageBody { private readonly Frame _context; - private int _send100Continue = 1; + private bool _send100Continue = true; protected MessageBody(Frame context) { @@ -21,80 +23,205 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http public bool RequestKeepAlive { get; protected set; } - public ValueTask ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) + public Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) { - var send100Continue = 0; - var result = ReadAsyncImplementation(buffer, cancellationToken); - if (!result.IsCompleted) + var task = PeekAsync(cancellationToken); + + if (!task.IsCompleted) { - send100Continue = Interlocked.Exchange(ref _send100Continue, 0); + TryProduceContinue(); + + // Incomplete Task await result + return ReadAsyncAwaited(task, buffer); } - if (send100Continue == 1) + else { - _context.FrameControl.ProduceContinue(); + var readSegment = task.Result; + var consumed = CopyReadSegment(readSegment, buffer); + + return consumed == 0 ? TaskCache.DefaultCompletedTask : Task.FromResult(consumed); } - return result; + } + + private async Task ReadAsyncAwaited(ValueTask> currentTask, ArraySegment buffer) + { + return CopyReadSegment(await currentTask, buffer); + } + + private int CopyReadSegment(ArraySegment readSegment, ArraySegment buffer) + { + var consumed = Math.Min(readSegment.Count, buffer.Count); + + if (consumed != 0) + { + Buffer.BlockCopy(readSegment.Array, readSegment.Offset, buffer.Array, buffer.Offset, consumed); + ConsumedBytes(consumed); + } + + return consumed; + } + + public Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) + { + var peekTask = PeekAsync(cancellationToken); + + while (peekTask.IsCompleted) + { + // ValueTask uses .GetAwaiter().GetResult() if necessary + var segment = peekTask.Result; + + if (segment.Count == 0) + { + return TaskCache.CompletedTask; + } + + Task destinationTask; + try + { + destinationTask = destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); + } + catch + { + ConsumedBytes(segment.Count); + throw; + } + + if (!destinationTask.IsCompleted) + { + return CopyToAsyncDestinationAwaited(destinationTask, segment.Count, destination, cancellationToken); + } + + ConsumedBytes(segment.Count); + + // Surface errors if necessary + destinationTask.GetAwaiter().GetResult(); + + peekTask = PeekAsync(cancellationToken); + } + + TryProduceContinue(); + + return CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken); + } + + private async Task CopyToAsyncPeekAwaited( + ValueTask> peekTask, + Stream destination, + CancellationToken cancellationToken = default(CancellationToken)) + { + while (true) + { + var segment = await peekTask; + + if (segment.Count == 0) + { + return; + } + + try + { + await destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); + } + finally + { + ConsumedBytes(segment.Count); + } + + peekTask = PeekAsync(cancellationToken); + } + } + + private async Task CopyToAsyncDestinationAwaited( + Task destinationTask, + int bytesConsumed, + Stream destination, + CancellationToken cancellationToken = default(CancellationToken)) + { + try + { + await destinationTask; + } + finally + { + ConsumedBytes(bytesConsumed); + } + + var peekTask = PeekAsync(cancellationToken); + + if (!peekTask.IsCompleted) + { + TryProduceContinue(); + } + + await CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken); } public Task Consume(CancellationToken cancellationToken = default(CancellationToken)) { - ValueTask result; - var send100checked = false; - do + while (true) { - result = ReadAsyncImplementation(default(ArraySegment), cancellationToken); - if (!result.IsCompleted) + var task = PeekAsync(cancellationToken); + if (!task.IsCompleted) { - if (!send100checked) - { - if (Interlocked.Exchange(ref _send100Continue, 0) == 1) - { - _context.FrameControl.ProduceContinue(); - } - send100checked = true; - } + TryProduceContinue(); + // Incomplete Task await result - return ConsumeAwaited(result.AsTask(), cancellationToken); + return ConsumeAwaited(task, cancellationToken); } - // ValueTask uses .GetAwaiter().GetResult() if necessary - else if (result.Result == 0) - { - // Completed Task, end of stream - return TaskCache.CompletedTask; - } - - } while (true); - } - - private async Task ConsumeAwaited(Task currentTask, CancellationToken cancellationToken) - { - if (await currentTask == 0) - { - return; - } - - ValueTask result; - do - { - result = ReadAsyncImplementation(default(ArraySegment), cancellationToken); - if (result.IsCompleted) + else { // ValueTask uses .GetAwaiter().GetResult() if necessary - if (result.Result == 0) + if (task.Result.Count == 0) { // Completed Task, end of stream - return; - } - else - { - // Completed Task, get next Task rather than await - continue; + return TaskCache.CompletedTask; } + + ConsumedBytes(task.Result.Count); } - } while (await result != 0); + } } - public abstract ValueTask ReadAsyncImplementation(ArraySegment buffer, CancellationToken cancellationToken); + private async Task ConsumeAwaited(ValueTask> currentTask, CancellationToken cancellationToken) + { + while (true) + { + var count = (await currentTask).Count; + + if (count == 0) + { + // Completed Task, end of stream + return; + } + + ConsumedBytes(count); + currentTask = PeekAsync(cancellationToken); + } + } + + private void TryProduceContinue() + { + if (_send100Continue) + { + _context.FrameControl.ProduceContinue(); + _send100Continue = false; + } + } + + private void ConsumedBytes(int count) + { + var scan = _context.SocketInput.ConsumingStart(); + scan.Skip(count); + _context.SocketInput.ConsumingComplete(scan, scan); + + OnConsumedBytes(count); + } + + protected abstract ValueTask> PeekAsync(CancellationToken cancellationToken); + + protected virtual void OnConsumedBytes(int count) + { + } public static MessageBody For( HttpVersion httpVersion, @@ -145,9 +272,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { } - public override ValueTask ReadAsyncImplementation(ArraySegment buffer, CancellationToken cancellationToken) + protected override ValueTask> PeekAsync(CancellationToken cancellationToken) { - return _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, buffer.Array == null ? 8192 : buffer.Count); + return _context.SocketInput.PeekAsync(); } } @@ -164,49 +291,65 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _inputLength = _contentLength; } - public override ValueTask ReadAsyncImplementation(ArraySegment buffer, CancellationToken cancellationToken) + protected override ValueTask> PeekAsync(CancellationToken cancellationToken) { - var input = _context.SocketInput; - - var inputLengthLimit = (int)Math.Min(_inputLength, int.MaxValue); - var limit = buffer.Array == null ? inputLengthLimit : Math.Min(buffer.Count, inputLengthLimit); + var limit = (int)Math.Min(_inputLength, int.MaxValue); if (limit == 0) { - return new ValueTask(0); + return new ValueTask>(); } - var task = _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, limit); + var task = _context.SocketInput.PeekAsync(); if (task.IsCompleted) { // .GetAwaiter().GetResult() done by ValueTask if needed - var actual = task.Result; - _inputLength -= actual; + var actual = Math.Min(task.Result.Count, limit); - if (actual == 0) + if (task.Result.Count == 0) { _context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent); } - return new ValueTask(actual); + if (task.Result.Count < _inputLength) + { + return task; + } + else + { + var result = task.Result; + var part = new ArraySegment(result.Array, result.Offset, (int)_inputLength); + return new ValueTask>(part); + } } else { - return new ValueTask(ReadAsyncAwaited(task.AsTask())); + return new ValueTask>(PeekAsyncAwaited(task)); } } - private async Task ReadAsyncAwaited(Task task) + private async Task> PeekAsyncAwaited(ValueTask> task) { - var actual = await task; - _inputLength -= actual; + var segment = await task; - if (actual == 0) + if (segment.Count == 0) { _context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent); } - return actual; + if (segment.Count <= _inputLength) + { + return segment; + } + else + { + return new ArraySegment(segment.Array, segment.Offset, (int)_inputLength); + } + } + + protected override void OnConsumedBytes(int count) + { + _inputLength -= count; } } @@ -219,31 +362,39 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http // https://github.com/dotnet/corefx/issues/8825 private Vector _vectorCRs = new Vector((byte)'\r'); + private readonly SocketInput _input; + private readonly FrameRequestHeaders _requestHeaders; private int _inputLength; + private Mode _mode = Mode.Prefix; - private FrameRequestHeaders _requestHeaders; public ForChunkedEncoding(bool keepAlive, FrameRequestHeaders headers, Frame context) : base(context) { RequestKeepAlive = keepAlive; + _input = _context.SocketInput; _requestHeaders = headers; } - public override ValueTask ReadAsyncImplementation(ArraySegment buffer, CancellationToken cancellationToken) + protected override ValueTask> PeekAsync(CancellationToken cancellationToken) { - return new ValueTask(ReadStateMachineAsync(_context.SocketInput, buffer, cancellationToken)); + return new ValueTask>(PeekStateMachineAsync()); } - private async Task ReadStateMachineAsync(SocketInput input, ArraySegment buffer, CancellationToken cancellationToken) + protected override void OnConsumedBytes(int count) + { + _inputLength -= count; + } + + private async Task> PeekStateMachineAsync() { while (_mode < Mode.Trailer) { while (_mode == Mode.Prefix) { - var fin = input.CheckFinOrThrow(); + var fin = _input.CheckFinOrThrow(); - ParseChunkedPrefix(input); + ParseChunkedPrefix(); if (_mode != Mode.Prefix) { @@ -254,14 +405,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await input; + await _input; } while (_mode == Mode.Extension) { - var fin = input.CheckFinOrThrow(); + var fin = _input.CheckFinOrThrow(); - ParseExtension(input); + ParseExtension(); if (_mode != Mode.Extension) { @@ -272,18 +423,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await input; + await _input; } while (_mode == Mode.Data) { - var fin = input.CheckFinOrThrow(); + var fin = _input.CheckFinOrThrow(); - int actual = ReadChunkedData(input, buffer.Array, buffer.Offset, buffer.Count); + var segment = PeekChunkedData(); - if (actual != 0) + if (segment.Count != 0) { - return actual; + return segment; } else if (_mode != Mode.Data) { @@ -294,14 +445,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await input; + await _input; } while (_mode == Mode.Suffix) { - var fin = input.CheckFinOrThrow(); + var fin = _input.CheckFinOrThrow(); - ParseChunkedSuffix(input); + ParseChunkedSuffix(); if (_mode != Mode.Suffix) { @@ -312,16 +463,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await input; + await _input; } } // Chunks finished, parse trailers while (_mode == Mode.Trailer) { - var fin = input.CheckFinOrThrow(); + var fin = _input.CheckFinOrThrow(); - ParseChunkedTrailer(input); + ParseChunkedTrailer(); if (_mode != Mode.Trailer) { @@ -332,16 +483,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); } - await input; + await _input; } if (_mode == Mode.TrailerHeaders) { - while (!_context.TakeMessageHeaders(input, _requestHeaders)) + while (!_context.TakeMessageHeaders(_input, _requestHeaders)) { - if (input.CheckFinOrThrow()) + if (_input.CheckFinOrThrow()) { - if (_context.TakeMessageHeaders(input, _requestHeaders)) + if (_context.TakeMessageHeaders(_input, _requestHeaders)) { break; } @@ -351,18 +502,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } } - await input; + await _input; } _mode = Mode.Complete; } - return 0; + return default(ArraySegment); } - private void ParseChunkedPrefix(SocketInput input) + private void ParseChunkedPrefix() { - var scan = input.ConsumingStart(); + var scan = _input.ConsumingStart(); var consumed = scan; try { @@ -416,13 +567,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } finally { - input.ConsumingComplete(consumed, scan); + _input.ConsumingComplete(consumed, scan); } } - private void ParseExtension(SocketInput input) + private void ParseExtension() { - var scan = input.ConsumingStart(); + var scan = _input.ConsumingStart(); var consumed = scan; try { @@ -460,36 +611,37 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } finally { - input.ConsumingComplete(consumed, scan); + _input.ConsumingComplete(consumed, scan); } } - private int ReadChunkedData(SocketInput input, byte[] buffer, int offset, int count) + private ArraySegment PeekChunkedData() { - var scan = input.ConsumingStart(); - int actual; - try - { - var limit = buffer == null ? _inputLength : Math.Min(count, _inputLength); - scan = scan.CopyTo(buffer, offset, limit, out actual); - _inputLength -= actual; - } - finally - { - input.ConsumingComplete(scan, scan); - } - if (_inputLength == 0) { _mode = Mode.Suffix; + return default(ArraySegment); } - return actual; + var scan = _input.ConsumingStart(); + var segment = scan.PeekArraySegment(); + int actual = Math.Min(segment.Count, _inputLength); + // Nothing is consumed yet. ConsumedBytes(int) will move the iterator. + _input.ConsumingComplete(scan, scan); + + if (actual == segment.Count) + { + return segment; + } + else + { + return new ArraySegment(segment.Array, segment.Offset, actual); + } } - private void ParseChunkedSuffix(SocketInput input) + private void ParseChunkedSuffix() { - var scan = input.ConsumingStart(); + var scan = _input.ConsumingStart(); var consumed = scan; try { @@ -511,13 +663,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } finally { - input.ConsumingComplete(consumed, scan); + _input.ConsumingComplete(consumed, scan); } } - private void ParseChunkedTrailer(SocketInput input) + private void ParseChunkedTrailer() { - var scan = input.ConsumingStart(); + var scan = _input.ConsumingStart(); var consumed = scan; try { @@ -540,7 +692,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http } finally { - input.ConsumingComplete(consumed, scan); + _input.ConsumingComplete(consumed, scan); } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs index 2ff1f4c037..8dd26803ab 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/SocketInputExtensions.cs @@ -1,7 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { @@ -18,14 +20,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http var end = begin.CopyTo(buffer, offset, count, out actual); input.ConsumingComplete(end, end); - if (actual != 0) + if (actual != 0 || fin) { return new ValueTask(actual); } - else if (fin) - { - return new ValueTask(0); - } } return new ValueTask(input.ReadAsyncAwaited(buffer, offset, count)); @@ -44,13 +42,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http var end = begin.CopyTo(buffer, offset, count, out actual); input.ConsumingComplete(end, end); - if (actual != 0) + if (actual != 0 || fin) { return actual; } - else if (fin) + } + } + + public static ValueTask> PeekAsync(this SocketInput input) + { + while (input.IsCompleted) + { + var fin = input.CheckFinOrThrow(); + + var begin = input.ConsumingStart(); + var segment = begin.PeekArraySegment(); + input.ConsumingComplete(begin, begin); + + if (segment.Count != 0 || fin) { - return 0; + return new ValueTask>(segment); + } + } + + return new ValueTask>(input.PeekAsyncAwaited()); + } + + private static async Task> PeekAsyncAwaited(this SocketInput input) + { + while (true) + { + await input; + + var fin = input.CheckFinOrThrow(); + + var begin = input.ConsumingStart(); + var segment = begin.PeekArraySegment(); + input.ConsumingComplete(begin, begin); + + if (segment.Count != 0 || fin) + { + return segment; } } } diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs index 6bb47c0578..6cc8f38bce 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Infrastructure/MemoryPoolIteratorExtensions.cs @@ -232,6 +232,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure return new ArraySegment(array, 0, length); } + public static ArraySegment PeekArraySegment(this MemoryPoolIterator iter) + { + if (iter.IsDefault || iter.IsEnd) + { + return default(ArraySegment); + } + + if (iter.Index < iter.Block.End) + { + return new ArraySegment(iter.Block.Array, iter.Index, iter.Block.End - iter.Index); + } + + var block = iter.Block.Next; + while (block != null) + { + if (block.Start < block.End) + { + return new ArraySegment(block.Array, block.Start, block.End - block.Start); + } + block = block.Next; + } + + // The following should be unreachable due to the IsEnd check above. + throw new InvalidOperationException("This should be unreachable!"); + } + /// /// Checks that up to 8 bytes from correspond to a known HTTP method. /// diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs index cbe4c4ccb6..5d1c96e0f8 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestStreamTests.cs @@ -5,6 +5,7 @@ using System; using System.IO; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; +using Moq; using Xunit; namespace Microsoft.AspNetCore.Server.KestrelTests @@ -126,5 +127,61 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.True(task.IsFaulted); Assert.Same(error, task.Exception.InnerException); } + + [Fact] + public void StopAcceptingReadsCausesReadToThrowObjectDisposedException() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + stream.StopAcceptingReads(); + Assert.Throws(() => { stream.ReadAsync(new byte[1], 0, 1); }); + } + + [Fact] + public void AbortCausesCopyToAsyncToCancel() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + stream.Abort(); + var task = stream.CopyToAsync(Mock.Of()); + Assert.True(task.IsCanceled); + } + + [Fact] + public void AbortWithErrorCausesCopyToAsyncToCancel() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + var error = new Exception(); + stream.Abort(error); + var task = stream.CopyToAsync(Mock.Of()); + Assert.True(task.IsFaulted); + Assert.Same(error, task.Exception.InnerException); + } + + [Fact] + public void StopAcceptingReadsCausesCopyToAsyncToThrowObjectDisposedException() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + stream.StopAcceptingReads(); + Assert.Throws(() => { stream.CopyToAsync(Mock.Of()); }); + } + + [Fact] + public void NullDestinationCausesCopyToAsyncToThrowArgumentNullException() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + Assert.Throws(() => { stream.CopyToAsync(null); }); + } + + [Fact] + public void ZeroBufferSizeCausesCopyToAsyncToThrowArgumentException() + { + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(null); + Assert.Throws(() => { stream.CopyToAsync(Mock.Of(), 0); }); + } } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs index fe450431e6..5424514870 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/MemoryPoolIteratorTests.cs @@ -88,7 +88,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public void MemorySeek(string raw, string search, char expectResult, int expectIndex) { var block = _pool.Lease(); - var chars = raw.ToCharArray().Select(c => (byte)c).ToArray(); + var chars = raw.ToCharArray().Select(c => (byte) c).ToArray(); Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length); block.End += chars.Length; @@ -98,20 +98,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests int found = -1; if (searchFor.Length == 1) { - var search0 = new Vector((byte)searchFor[0]); + var search0 = new Vector((byte) searchFor[0]); found = begin.Seek(ref search0); } else if (searchFor.Length == 2) { - var search0 = new Vector((byte)searchFor[0]); - var search1 = new Vector((byte)searchFor[1]); + var search0 = new Vector((byte) searchFor[0]); + var search1 = new Vector((byte) searchFor[1]); found = begin.Seek(ref search0, ref search1); } else if (searchFor.Length == 3) { - var search0 = new Vector((byte)searchFor[0]); - var search1 = new Vector((byte)searchFor[1]); - var search2 = new Vector((byte)searchFor[2]); + var search0 = new Vector((byte) searchFor[0]); + var search1 = new Vector((byte) searchFor[1]); + var search2 = new Vector((byte) searchFor[2]); found = begin.Seek(ref search0, ref search1, ref search2); } else @@ -176,7 +176,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests head = blocks[0].GetIterator(); for (var i = 0; i < 64; ++i) { - Assert.True(head.Put((byte)i), $"Fail to put data at {i}."); + Assert.True(head.Put((byte) i), $"Fail to put data at {i}."); } // Can't put anything by the end @@ -188,6 +188,112 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } + [Fact] + public void PeekArraySegment() + { + // Arrange + var block = _pool.Lease(); + var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7}; + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length); + block.End += bytes.Length; + var scan = block.GetIterator(); + var originalIndex = scan.Index; + + // Act + var result = scan.PeekArraySegment(); + + // Assert + Assert.Equal(new byte[] {0, 1, 2, 3, 4, 5, 6, 7}, result); + Assert.Equal(originalIndex, scan.Index); + + _pool.Return(block); + } + + [Fact] + public void PeekArraySegmentOnDefaultIteratorReturnsDefaultArraySegment() + { + // Assert.Equals doesn't work since xunit tries to access the underlying array. + Assert.True(default(ArraySegment).Equals(default(MemoryPoolIterator).PeekArraySegment())); + } + + [Fact] + public void PeekArraySegmentAtEndOfDataReturnsDefaultArraySegment() + { + // Arrange + var block = _pool.Lease(); + var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7}; + Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length); + block.End += bytes.Length; + block.Start = block.End; + + var scan = block.GetIterator(); + + // Act + var result = scan.PeekArraySegment(); + + // Assert + // Assert.Equals doesn't work since xunit tries to access the underlying array. + Assert.True(default(ArraySegment).Equals(result)); + + _pool.Return(block); + } + + [Fact] + public void PeekArraySegmentAtBlockBoundary() + { + // Arrange + var firstBlock = _pool.Lease(); + var lastBlock = _pool.Lease(); + + var firstBytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }; + var lastBytes = new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }; + + Buffer.BlockCopy(firstBytes, 0, firstBlock.Array, firstBlock.Start, firstBytes.Length); + firstBlock.End += lastBytes.Length; + + firstBlock.Next = lastBlock; + Buffer.BlockCopy(lastBytes, 0, lastBlock.Array, lastBlock.Start, lastBytes.Length); + lastBlock.End += lastBytes.Length; + + var scan = firstBlock.GetIterator(); + var originalIndex = scan.Index; + var originalBlock = scan.Block; + + // Act + var result = scan.PeekArraySegment(); + + // Assert + Assert.Equal(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }, result); + Assert.Equal(originalBlock, scan.Block); + Assert.Equal(originalIndex, scan.Index); + + // Act + // Advance past the data in the first block + scan.Skip(8); + result = scan.PeekArraySegment(); + + // Assert + Assert.Equal(new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }, result); + Assert.Equal(originalBlock, scan.Block); + Assert.Equal(originalIndex + 8, scan.Index); + + // Act + // Add anther empty block between the first and last block + var middleBlock = _pool.Lease(); + firstBlock.Next = middleBlock; + middleBlock.Next = lastBlock; + result = scan.PeekArraySegment(); + + // Assert + Assert.Equal(new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }, result); + Assert.Equal(originalBlock, scan.Block); + Assert.Equal(originalIndex + 8, scan.Index); + + _pool.Return(firstBlock); + _pool.Return(middleBlock); + _pool.Return(lastBlock); + } + [Fact] public void PeekLong() { diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs index 2dcb19ffef..9dd54a75d0 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/MessageBodyTests.cs @@ -2,10 +2,18 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; using System.Text; +using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; +using Microsoft.AspNetCore.Server.KestrelTests.TestHelpers; +using Microsoft.Extensions.Internal; +using Moq; using Xunit; +using Xunit.Sdk; namespace Microsoft.AspNetCore.Server.KestrelTests { @@ -68,23 +76,119 @@ namespace Microsoft.AspNetCore.Server.KestrelTests // Input needs to be greater than 4032 bytes to allocate a block not backed by a slab. var largeInput = new string('a', 8192); - input.Add(largeInput, true); + input.Add(largeInput); // Add a smaller block to the end so that SocketInput attempts to return the large // block to the memory pool. - input.Add("Hello", true); + input.Add("Hello", fin: true); - var readBuffer = new byte[8192]; + var ms = new MemoryStream(); - var count1 = await stream.ReadAsync(readBuffer, 0, 8192); - Assert.Equal(8192, count1); - AssertASCII(largeInput, new ArraySegment(readBuffer, 0, 8192)); + await stream.CopyToAsync(ms); + var requestArray = ms.ToArray(); + Assert.Equal(8197, requestArray.Length); + AssertASCII(largeInput + "Hello", new ArraySegment(requestArray, 0, requestArray.Length)); - var count2 = await stream.ReadAsync(readBuffer, 0, 8192); - Assert.Equal(5, count2); - AssertASCII("Hello", new ArraySegment(readBuffer, 0, 5)); + var count = await stream.ReadAsync(new byte[1], 0, 1); + Assert.Equal(0, count); + } + } - var count3 = await stream.ReadAsync(readBuffer, 0, 8192); - Assert.Equal(0, count3); + public static IEnumerable StreamData => new[] + { + new object[] { new ThrowOnWriteSynchronousStream() }, + new object[] { new ThrowOnWriteAsynchronousStream() }, + }; + + public static IEnumerable RequestData => new[] + { + // Remaining Data + new object[] { new FrameRequestHeaders { HeaderConnection = "close" }, new[] { "Hello ", "World!" } }, + // Content-Length + new object[] { new FrameRequestHeaders { HeaderContentLength = "12" }, new[] { "Hello ", "World!" } }, + // Chunked + new object[] { new FrameRequestHeaders { HeaderTransferEncoding = "chunked" }, new[] { "6\r\nHello \r\n", "6\r\nWorld!\r\n0\r\n\r\n" } }, + }; + + public static IEnumerable CombinedData => + from stream in StreamData + from request in RequestData + select new[] { stream[0], request[0], request[1] }; + + [Theory] + [MemberData(nameof(RequestData))] + public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, string[] data) + { + var writeCount = 0; + var writeTcs = new TaskCompletionSource(); + var mockDestination = new Mock(); + + mockDestination + .Setup(m => m.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), CancellationToken.None)) + .Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => + { + writeTcs.SetResult(buffer); + writeCount++; + }) + .Returns(TaskCache.CompletedTask); + + using (var input = new TestInput()) + { + var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext); + + var copyToAsyncTask = body.CopyToAsync(mockDestination.Object); + + // The block returned by IncomingStart always has at least 2048 available bytes, + // so no need to bounds check in this test. + var socketInput = input.FrameContext.SocketInput; + var bytes = Encoding.ASCII.GetBytes(data[0]); + var block = socketInput.IncomingStart(); + Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length); + socketInput.IncomingComplete(bytes.Length, null); + + // Verify the block passed to WriteAsync is the same one incoming data was written into. + Assert.Same(block.Array, await writeTcs.Task); + + writeTcs = new TaskCompletionSource(); + bytes = Encoding.ASCII.GetBytes(data[1]); + block = socketInput.IncomingStart(); + Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length); + socketInput.IncomingComplete(bytes.Length, null); + + Assert.Same(block.Array, await writeTcs.Task); + + if (headers.HeaderConnection == "close") + { + socketInput.IncomingFin(); + } + + await copyToAsyncTask; + + Assert.Equal(2, writeCount); + } + } + + [Theory] + [MemberData(nameof(CombinedData))] + public async Task CopyToAsyncAdvancesRequestStreamWhenDestinationWriteAsyncThrows(Stream writeStream, FrameRequestHeaders headers, string[] data) + { + using (var input = new TestInput()) + { + var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext); + + input.Add(data[0]); + + await Assert.ThrowsAsync(() => body.CopyToAsync(writeStream)); + + input.Add(data[1], fin: headers.HeaderConnection == "close"); + + // "Hello " should have been consumed + var readBuffer = new byte[6]; + var count = await body.ReadAsync(new ArraySegment(readBuffer, 0, readBuffer.Length)); + Assert.Equal(6, count); + AssertASCII("World!", new ArraySegment(readBuffer, 0, 6)); + + count = await body.ReadAsync(new ArraySegment(readBuffer, 0, readBuffer.Length)); + Assert.Equal(0, count); } } @@ -98,5 +202,84 @@ namespace Microsoft.AspNetCore.Server.KestrelTests Assert.Equal(bytes[index], actual.Array[actual.Offset + index]); } } + + private class ThrowOnWriteSynchronousStream : Stream + { + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + throw new XunitException(); + } + + public override bool CanRead { get; } + public override bool CanSeek { get; } + public override bool CanWrite => true; + public override long Length { get; } + public override long Position { get; set; } + } + + private class ThrowOnWriteAsynchronousStream : Stream + { + public override void Flush() + { + throw new NotImplementedException(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotImplementedException(); + } + + public override void SetLength(long value) + { + throw new NotImplementedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotImplementedException(); + } + + public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await Task.Delay(1); + throw new XunitException(); + } + + public override bool CanRead { get; } + public override bool CanSeek { get; } + public override bool CanWrite => true; + public override long Length { get; } + public override long Position { get; set; } + } } } \ No newline at end of file diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs index b2bb528649..a751bbe342 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/SocketInputTests.cs @@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests new TheoryData>() { new Mock(), null }; [Theory] - [MemberData("MockBufferSizeControlData")] + [MemberData(nameof(MockBufferSizeControlData))] public void IncomingDataCallsBufferSizeControlAdd(Mock mockBufferSizeControl) { using (var memory = new MemoryPool()) @@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Theory] - [MemberData("MockBufferSizeControlData")] + [MemberData(nameof(MockBufferSizeControlData))] public void IncomingCompleteCallsBufferSizeControlAdd(Mock mockBufferSizeControl) { using (var memory = new MemoryPool()) @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Theory] - [MemberData("MockBufferSizeControlData")] + [MemberData(nameof(MockBufferSizeControlData))] public void ConsumingCompleteCallsBufferSizeControlSubtract(Mock mockBufferSizeControl) { using (var kestrelEngine = new KestrelEngine(new MockLibuv(), new TestServiceContext())) @@ -154,6 +154,80 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } } + [Fact] + public async Task PeekAsyncRereturnsTheSameData() + { + using (var memory = new MemoryPool()) + using (var socketInput = new SocketInput(memory, new SynchronousThreadPool())) + { + socketInput.IncomingData(new byte[5], 0, 5); + + Assert.True(socketInput.IsCompleted); + Assert.Equal(5, (await socketInput.PeekAsync()).Count); + + // The same 5 bytes will be returned again since it hasn't been consumed. + Assert.True(socketInput.IsCompleted); + Assert.Equal(5, (await socketInput.PeekAsync()).Count); + + var scan = socketInput.ConsumingStart(); + scan.Skip(3); + socketInput.ConsumingComplete(scan, scan); + + // The remaining 2 unconsumed bytes will be returned. + Assert.True(socketInput.IsCompleted); + Assert.Equal(2, (await socketInput.PeekAsync()).Count); + + scan = socketInput.ConsumingStart(); + scan.Skip(2); + socketInput.ConsumingComplete(scan, scan); + + // Everything has been consume so socketInput is no longer in the completed state + Assert.False(socketInput.IsCompleted); + } + } + + [Fact] + public async Task CompleteAwaitingDoesNotCauseZeroLengthRead() + { + using (var memory = new MemoryPool()) + using (var socketInput = new SocketInput(memory, new SynchronousThreadPool())) + { + var readBuffer = new byte[20]; + + socketInput.IncomingData(new byte[5], 0, 5); + Assert.Equal(5, await socketInput.ReadAsync(readBuffer, 0, 20)); + + var readTask = socketInput.ReadAsync(readBuffer, 0, 20); + socketInput.CompleteAwaiting(); + Assert.False(readTask.IsCompleted); + + socketInput.IncomingData(new byte[5], 0, 5); + Assert.Equal(5, await readTask); + } + } + + [Fact] + public async Task CompleteAwaitingDoesNotCauseZeroLengthPeek() + { + using (var memory = new MemoryPool()) + using (var socketInput = new SocketInput(memory, new SynchronousThreadPool())) + { + socketInput.IncomingData(new byte[5], 0, 5); + Assert.Equal(5, (await socketInput.PeekAsync()).Count); + + var scan = socketInput.ConsumingStart(); + scan.Skip(5); + socketInput.ConsumingComplete(scan, scan); + + var peekTask = socketInput.PeekAsync(); + socketInput.CompleteAwaiting(); + Assert.False(peekTask.IsCompleted); + + socketInput.IncomingData(new byte[5], 0, 5); + Assert.Equal(5, (await socketInput.PeekAsync()).Count); + } + } + private static void TestConcurrentFaultedTask(Task t) { Assert.True(t.IsFaulted);