Make FrameRequestStream.CopyToAsync(...) copyless

- Add tests for when the CopyToAsync destinationStream throws.
- Add test to verify the destination stream sees the same array written to by
  the producer.
This commit is contained in:
Stephen Halter 2016-09-13 11:03:12 -07:00
parent afa89b3993
commit 63509b9e10
8 changed files with 806 additions and 158 deletions

View File

@ -119,8 +119,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
var task = ValidateState(cancellationToken);
if (task == null)
{
// Needs .AsTask to match Stream's Async method return types
return _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken).AsTask();
return _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken);
}
return task;
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
if (destination == null)
{
throw new ArgumentNullException(nameof(destination));
}
if (bufferSize <= 0)
{
throw new ArgumentException($"{nameof(bufferSize)} must be positive.", nameof(bufferSize));
}
var task = ValidateState(cancellationToken);
if (task == null)
{
return _body.CopyToAsync(destination, cancellationToken);
}
return task;
}

View File

@ -2,9 +2,11 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Numerics;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
@ -12,7 +14,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
public abstract class MessageBody
{
private readonly Frame _context;
private int _send100Continue = 1;
private bool _send100Continue = true;
protected MessageBody(Frame context)
{
@ -21,80 +23,205 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
public bool RequestKeepAlive { get; protected set; }
public ValueTask<int> ReadAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
public Task<int> ReadAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
{
var send100Continue = 0;
var result = ReadAsyncImplementation(buffer, cancellationToken);
if (!result.IsCompleted)
var task = PeekAsync(cancellationToken);
if (!task.IsCompleted)
{
send100Continue = Interlocked.Exchange(ref _send100Continue, 0);
TryProduceContinue();
// Incomplete Task await result
return ReadAsyncAwaited(task, buffer);
}
if (send100Continue == 1)
else
{
_context.FrameControl.ProduceContinue();
var readSegment = task.Result;
var consumed = CopyReadSegment(readSegment, buffer);
return consumed == 0 ? TaskCache<int>.DefaultCompletedTask : Task.FromResult(consumed);
}
return result;
}
private async Task<int> ReadAsyncAwaited(ValueTask<ArraySegment<byte>> currentTask, ArraySegment<byte> buffer)
{
return CopyReadSegment(await currentTask, buffer);
}
private int CopyReadSegment(ArraySegment<byte> readSegment, ArraySegment<byte> buffer)
{
var consumed = Math.Min(readSegment.Count, buffer.Count);
if (consumed != 0)
{
Buffer.BlockCopy(readSegment.Array, readSegment.Offset, buffer.Array, buffer.Offset, consumed);
ConsumedBytes(consumed);
}
return consumed;
}
public Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken))
{
var peekTask = PeekAsync(cancellationToken);
while (peekTask.IsCompleted)
{
// ValueTask uses .GetAwaiter().GetResult() if necessary
var segment = peekTask.Result;
if (segment.Count == 0)
{
return TaskCache.CompletedTask;
}
Task destinationTask;
try
{
destinationTask = destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
}
catch
{
ConsumedBytes(segment.Count);
throw;
}
if (!destinationTask.IsCompleted)
{
return CopyToAsyncDestinationAwaited(destinationTask, segment.Count, destination, cancellationToken);
}
ConsumedBytes(segment.Count);
// Surface errors if necessary
destinationTask.GetAwaiter().GetResult();
peekTask = PeekAsync(cancellationToken);
}
TryProduceContinue();
return CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken);
}
private async Task CopyToAsyncPeekAwaited(
ValueTask<ArraySegment<byte>> peekTask,
Stream destination,
CancellationToken cancellationToken = default(CancellationToken))
{
while (true)
{
var segment = await peekTask;
if (segment.Count == 0)
{
return;
}
try
{
await destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
}
finally
{
ConsumedBytes(segment.Count);
}
peekTask = PeekAsync(cancellationToken);
}
}
private async Task CopyToAsyncDestinationAwaited(
Task destinationTask,
int bytesConsumed,
Stream destination,
CancellationToken cancellationToken = default(CancellationToken))
{
try
{
await destinationTask;
}
finally
{
ConsumedBytes(bytesConsumed);
}
var peekTask = PeekAsync(cancellationToken);
if (!peekTask.IsCompleted)
{
TryProduceContinue();
}
await CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken);
}
public Task Consume(CancellationToken cancellationToken = default(CancellationToken))
{
ValueTask<int> result;
var send100checked = false;
do
while (true)
{
result = ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken);
if (!result.IsCompleted)
var task = PeekAsync(cancellationToken);
if (!task.IsCompleted)
{
if (!send100checked)
{
if (Interlocked.Exchange(ref _send100Continue, 0) == 1)
{
_context.FrameControl.ProduceContinue();
}
send100checked = true;
}
TryProduceContinue();
// Incomplete Task await result
return ConsumeAwaited(result.AsTask(), cancellationToken);
return ConsumeAwaited(task, cancellationToken);
}
// ValueTask uses .GetAwaiter().GetResult() if necessary
else if (result.Result == 0)
{
// Completed Task, end of stream
return TaskCache.CompletedTask;
}
} while (true);
}
private async Task ConsumeAwaited(Task<int> currentTask, CancellationToken cancellationToken)
{
if (await currentTask == 0)
{
return;
}
ValueTask<int> result;
do
{
result = ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken);
if (result.IsCompleted)
else
{
// ValueTask uses .GetAwaiter().GetResult() if necessary
if (result.Result == 0)
if (task.Result.Count == 0)
{
// Completed Task, end of stream
return;
}
else
{
// Completed Task, get next Task rather than await
continue;
return TaskCache.CompletedTask;
}
ConsumedBytes(task.Result.Count);
}
} while (await result != 0);
}
}
public abstract ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken);
private async Task ConsumeAwaited(ValueTask<ArraySegment<byte>> currentTask, CancellationToken cancellationToken)
{
while (true)
{
var count = (await currentTask).Count;
if (count == 0)
{
// Completed Task, end of stream
return;
}
ConsumedBytes(count);
currentTask = PeekAsync(cancellationToken);
}
}
private void TryProduceContinue()
{
if (_send100Continue)
{
_context.FrameControl.ProduceContinue();
_send100Continue = false;
}
}
private void ConsumedBytes(int count)
{
var scan = _context.SocketInput.ConsumingStart();
scan.Skip(count);
_context.SocketInput.ConsumingComplete(scan, scan);
OnConsumedBytes(count);
}
protected abstract ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken);
protected virtual void OnConsumedBytes(int count)
{
}
public static MessageBody For(
HttpVersion httpVersion,
@ -145,9 +272,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
{
}
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
protected override ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken)
{
return _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, buffer.Array == null ? 8192 : buffer.Count);
return _context.SocketInput.PeekAsync();
}
}
@ -164,49 +291,65 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_inputLength = _contentLength;
}
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
protected override ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken)
{
var input = _context.SocketInput;
var inputLengthLimit = (int)Math.Min(_inputLength, int.MaxValue);
var limit = buffer.Array == null ? inputLengthLimit : Math.Min(buffer.Count, inputLengthLimit);
var limit = (int)Math.Min(_inputLength, int.MaxValue);
if (limit == 0)
{
return new ValueTask<int>(0);
return new ValueTask<ArraySegment<byte>>();
}
var task = _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, limit);
var task = _context.SocketInput.PeekAsync();
if (task.IsCompleted)
{
// .GetAwaiter().GetResult() done by ValueTask if needed
var actual = task.Result;
_inputLength -= actual;
var actual = Math.Min(task.Result.Count, limit);
if (actual == 0)
if (task.Result.Count == 0)
{
_context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent);
}
return new ValueTask<int>(actual);
if (task.Result.Count < _inputLength)
{
return task;
}
else
{
var result = task.Result;
var part = new ArraySegment<byte>(result.Array, result.Offset, (int)_inputLength);
return new ValueTask<ArraySegment<byte>>(part);
}
}
else
{
return new ValueTask<int>(ReadAsyncAwaited(task.AsTask()));
return new ValueTask<ArraySegment<byte>>(PeekAsyncAwaited(task));
}
}
private async Task<int> ReadAsyncAwaited(Task<int> task)
private async Task<ArraySegment<byte>> PeekAsyncAwaited(ValueTask<ArraySegment<byte>> task)
{
var actual = await task;
_inputLength -= actual;
var segment = await task;
if (actual == 0)
if (segment.Count == 0)
{
_context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent);
}
return actual;
if (segment.Count <= _inputLength)
{
return segment;
}
else
{
return new ArraySegment<byte>(segment.Array, segment.Offset, (int)_inputLength);
}
}
protected override void OnConsumedBytes(int count)
{
_inputLength -= count;
}
}
@ -219,31 +362,39 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
// https://github.com/dotnet/corefx/issues/8825
private Vector<byte> _vectorCRs = new Vector<byte>((byte)'\r');
private readonly SocketInput _input;
private readonly FrameRequestHeaders _requestHeaders;
private int _inputLength;
private Mode _mode = Mode.Prefix;
private FrameRequestHeaders _requestHeaders;
public ForChunkedEncoding(bool keepAlive, FrameRequestHeaders headers, Frame context)
: base(context)
{
RequestKeepAlive = keepAlive;
_input = _context.SocketInput;
_requestHeaders = headers;
}
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
protected override ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken)
{
return new ValueTask<int>(ReadStateMachineAsync(_context.SocketInput, buffer, cancellationToken));
return new ValueTask<ArraySegment<byte>>(PeekStateMachineAsync());
}
private async Task<int> ReadStateMachineAsync(SocketInput input, ArraySegment<byte> buffer, CancellationToken cancellationToken)
protected override void OnConsumedBytes(int count)
{
_inputLength -= count;
}
private async Task<ArraySegment<byte>> PeekStateMachineAsync()
{
while (_mode < Mode.Trailer)
{
while (_mode == Mode.Prefix)
{
var fin = input.CheckFinOrThrow();
var fin = _input.CheckFinOrThrow();
ParseChunkedPrefix(input);
ParseChunkedPrefix();
if (_mode != Mode.Prefix)
{
@ -254,14 +405,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
}
await input;
await _input;
}
while (_mode == Mode.Extension)
{
var fin = input.CheckFinOrThrow();
var fin = _input.CheckFinOrThrow();
ParseExtension(input);
ParseExtension();
if (_mode != Mode.Extension)
{
@ -272,18 +423,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
}
await input;
await _input;
}
while (_mode == Mode.Data)
{
var fin = input.CheckFinOrThrow();
var fin = _input.CheckFinOrThrow();
int actual = ReadChunkedData(input, buffer.Array, buffer.Offset, buffer.Count);
var segment = PeekChunkedData();
if (actual != 0)
if (segment.Count != 0)
{
return actual;
return segment;
}
else if (_mode != Mode.Data)
{
@ -294,14 +445,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
}
await input;
await _input;
}
while (_mode == Mode.Suffix)
{
var fin = input.CheckFinOrThrow();
var fin = _input.CheckFinOrThrow();
ParseChunkedSuffix(input);
ParseChunkedSuffix();
if (_mode != Mode.Suffix)
{
@ -312,16 +463,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
}
await input;
await _input;
}
}
// Chunks finished, parse trailers
while (_mode == Mode.Trailer)
{
var fin = input.CheckFinOrThrow();
var fin = _input.CheckFinOrThrow();
ParseChunkedTrailer(input);
ParseChunkedTrailer();
if (_mode != Mode.Trailer)
{
@ -332,16 +483,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
}
await input;
await _input;
}
if (_mode == Mode.TrailerHeaders)
{
while (!_context.TakeMessageHeaders(input, _requestHeaders))
while (!_context.TakeMessageHeaders(_input, _requestHeaders))
{
if (input.CheckFinOrThrow())
if (_input.CheckFinOrThrow())
{
if (_context.TakeMessageHeaders(input, _requestHeaders))
if (_context.TakeMessageHeaders(_input, _requestHeaders))
{
break;
}
@ -351,18 +502,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
}
}
await input;
await _input;
}
_mode = Mode.Complete;
}
return 0;
return default(ArraySegment<byte>);
}
private void ParseChunkedPrefix(SocketInput input)
private void ParseChunkedPrefix()
{
var scan = input.ConsumingStart();
var scan = _input.ConsumingStart();
var consumed = scan;
try
{
@ -416,13 +567,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
}
finally
{
input.ConsumingComplete(consumed, scan);
_input.ConsumingComplete(consumed, scan);
}
}
private void ParseExtension(SocketInput input)
private void ParseExtension()
{
var scan = input.ConsumingStart();
var scan = _input.ConsumingStart();
var consumed = scan;
try
{
@ -460,36 +611,37 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
}
finally
{
input.ConsumingComplete(consumed, scan);
_input.ConsumingComplete(consumed, scan);
}
}
private int ReadChunkedData(SocketInput input, byte[] buffer, int offset, int count)
private ArraySegment<byte> PeekChunkedData()
{
var scan = input.ConsumingStart();
int actual;
try
{
var limit = buffer == null ? _inputLength : Math.Min(count, _inputLength);
scan = scan.CopyTo(buffer, offset, limit, out actual);
_inputLength -= actual;
}
finally
{
input.ConsumingComplete(scan, scan);
}
if (_inputLength == 0)
{
_mode = Mode.Suffix;
return default(ArraySegment<byte>);
}
return actual;
var scan = _input.ConsumingStart();
var segment = scan.PeekArraySegment();
int actual = Math.Min(segment.Count, _inputLength);
// Nothing is consumed yet. ConsumedBytes(int) will move the iterator.
_input.ConsumingComplete(scan, scan);
if (actual == segment.Count)
{
return segment;
}
else
{
return new ArraySegment<byte>(segment.Array, segment.Offset, actual);
}
}
private void ParseChunkedSuffix(SocketInput input)
private void ParseChunkedSuffix()
{
var scan = input.ConsumingStart();
var scan = _input.ConsumingStart();
var consumed = scan;
try
{
@ -511,13 +663,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
}
finally
{
input.ConsumingComplete(consumed, scan);
_input.ConsumingComplete(consumed, scan);
}
}
private void ParseChunkedTrailer(SocketInput input)
private void ParseChunkedTrailer()
{
var scan = input.ConsumingStart();
var scan = _input.ConsumingStart();
var consumed = scan;
try
{
@ -540,7 +692,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
}
finally
{
input.ConsumingComplete(consumed, scan);
_input.ConsumingComplete(consumed, scan);
}
}

View File

@ -1,7 +1,9 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure;
namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
{
@ -18,14 +20,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
var end = begin.CopyTo(buffer, offset, count, out actual);
input.ConsumingComplete(end, end);
if (actual != 0)
if (actual != 0 || fin)
{
return new ValueTask<int>(actual);
}
else if (fin)
{
return new ValueTask<int>(0);
}
}
return new ValueTask<int>(input.ReadAsyncAwaited(buffer, offset, count));
@ -44,13 +42,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
var end = begin.CopyTo(buffer, offset, count, out actual);
input.ConsumingComplete(end, end);
if (actual != 0)
if (actual != 0 || fin)
{
return actual;
}
else if (fin)
}
}
public static ValueTask<ArraySegment<byte>> PeekAsync(this SocketInput input)
{
while (input.IsCompleted)
{
var fin = input.CheckFinOrThrow();
var begin = input.ConsumingStart();
var segment = begin.PeekArraySegment();
input.ConsumingComplete(begin, begin);
if (segment.Count != 0 || fin)
{
return 0;
return new ValueTask<ArraySegment<byte>>(segment);
}
}
return new ValueTask<ArraySegment<byte>>(input.PeekAsyncAwaited());
}
private static async Task<ArraySegment<byte>> PeekAsyncAwaited(this SocketInput input)
{
while (true)
{
await input;
var fin = input.CheckFinOrThrow();
var begin = input.ConsumingStart();
var segment = begin.PeekArraySegment();
input.ConsumingComplete(begin, begin);
if (segment.Count != 0 || fin)
{
return segment;
}
}
}

View File

@ -232,6 +232,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure
return new ArraySegment<byte>(array, 0, length);
}
public static ArraySegment<byte> PeekArraySegment(this MemoryPoolIterator iter)
{
if (iter.IsDefault || iter.IsEnd)
{
return default(ArraySegment<byte>);
}
if (iter.Index < iter.Block.End)
{
return new ArraySegment<byte>(iter.Block.Array, iter.Index, iter.Block.End - iter.Index);
}
var block = iter.Block.Next;
while (block != null)
{
if (block.Start < block.End)
{
return new ArraySegment<byte>(block.Array, block.Start, block.End - block.Start);
}
block = block.Next;
}
// The following should be unreachable due to the IsEnd check above.
throw new InvalidOperationException("This should be unreachable!");
}
/// <summary>
/// Checks that up to 8 bytes from <paramref name="begin"/> correspond to a known HTTP method.
/// </summary>

View File

@ -5,6 +5,7 @@ using System;
using System.IO;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Http;
using Moq;
using Xunit;
namespace Microsoft.AspNetCore.Server.KestrelTests
@ -126,5 +127,61 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.True(task.IsFaulted);
Assert.Same(error, task.Exception.InnerException);
}
[Fact]
public void StopAcceptingReadsCausesReadToThrowObjectDisposedException()
{
var stream = new FrameRequestStream();
stream.StartAcceptingReads(null);
stream.StopAcceptingReads();
Assert.Throws<ObjectDisposedException>(() => { stream.ReadAsync(new byte[1], 0, 1); });
}
[Fact]
public void AbortCausesCopyToAsyncToCancel()
{
var stream = new FrameRequestStream();
stream.StartAcceptingReads(null);
stream.Abort();
var task = stream.CopyToAsync(Mock.Of<Stream>());
Assert.True(task.IsCanceled);
}
[Fact]
public void AbortWithErrorCausesCopyToAsyncToCancel()
{
var stream = new FrameRequestStream();
stream.StartAcceptingReads(null);
var error = new Exception();
stream.Abort(error);
var task = stream.CopyToAsync(Mock.Of<Stream>());
Assert.True(task.IsFaulted);
Assert.Same(error, task.Exception.InnerException);
}
[Fact]
public void StopAcceptingReadsCausesCopyToAsyncToThrowObjectDisposedException()
{
var stream = new FrameRequestStream();
stream.StartAcceptingReads(null);
stream.StopAcceptingReads();
Assert.Throws<ObjectDisposedException>(() => { stream.CopyToAsync(Mock.Of<Stream>()); });
}
[Fact]
public void NullDestinationCausesCopyToAsyncToThrowArgumentNullException()
{
var stream = new FrameRequestStream();
stream.StartAcceptingReads(null);
Assert.Throws<ArgumentNullException>(() => { stream.CopyToAsync(null); });
}
[Fact]
public void ZeroBufferSizeCausesCopyToAsyncToThrowArgumentException()
{
var stream = new FrameRequestStream();
stream.StartAcceptingReads(null);
Assert.Throws<ArgumentException>(() => { stream.CopyToAsync(Mock.Of<Stream>(), 0); });
}
}
}

View File

@ -88,7 +88,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
public void MemorySeek(string raw, string search, char expectResult, int expectIndex)
{
var block = _pool.Lease();
var chars = raw.ToCharArray().Select(c => (byte)c).ToArray();
var chars = raw.ToCharArray().Select(c => (byte) c).ToArray();
Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length);
block.End += chars.Length;
@ -98,20 +98,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
int found = -1;
if (searchFor.Length == 1)
{
var search0 = new Vector<byte>((byte)searchFor[0]);
var search0 = new Vector<byte>((byte) searchFor[0]);
found = begin.Seek(ref search0);
}
else if (searchFor.Length == 2)
{
var search0 = new Vector<byte>((byte)searchFor[0]);
var search1 = new Vector<byte>((byte)searchFor[1]);
var search0 = new Vector<byte>((byte) searchFor[0]);
var search1 = new Vector<byte>((byte) searchFor[1]);
found = begin.Seek(ref search0, ref search1);
}
else if (searchFor.Length == 3)
{
var search0 = new Vector<byte>((byte)searchFor[0]);
var search1 = new Vector<byte>((byte)searchFor[1]);
var search2 = new Vector<byte>((byte)searchFor[2]);
var search0 = new Vector<byte>((byte) searchFor[0]);
var search1 = new Vector<byte>((byte) searchFor[1]);
var search2 = new Vector<byte>((byte) searchFor[2]);
found = begin.Seek(ref search0, ref search1, ref search2);
}
else
@ -176,7 +176,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
head = blocks[0].GetIterator();
for (var i = 0; i < 64; ++i)
{
Assert.True(head.Put((byte)i), $"Fail to put data at {i}.");
Assert.True(head.Put((byte) i), $"Fail to put data at {i}.");
}
// Can't put anything by the end
@ -188,6 +188,112 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public void PeekArraySegment()
{
// Arrange
var block = _pool.Lease();
var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length);
block.End += bytes.Length;
var scan = block.GetIterator();
var originalIndex = scan.Index;
// Act
var result = scan.PeekArraySegment();
// Assert
Assert.Equal(new byte[] {0, 1, 2, 3, 4, 5, 6, 7}, result);
Assert.Equal(originalIndex, scan.Index);
_pool.Return(block);
}
[Fact]
public void PeekArraySegmentOnDefaultIteratorReturnsDefaultArraySegment()
{
// Assert.Equals doesn't work since xunit tries to access the underlying array.
Assert.True(default(ArraySegment<byte>).Equals(default(MemoryPoolIterator).PeekArraySegment()));
}
[Fact]
public void PeekArraySegmentAtEndOfDataReturnsDefaultArraySegment()
{
// Arrange
var block = _pool.Lease();
var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length);
block.End += bytes.Length;
block.Start = block.End;
var scan = block.GetIterator();
// Act
var result = scan.PeekArraySegment();
// Assert
// Assert.Equals doesn't work since xunit tries to access the underlying array.
Assert.True(default(ArraySegment<byte>).Equals(result));
_pool.Return(block);
}
[Fact]
public void PeekArraySegmentAtBlockBoundary()
{
// Arrange
var firstBlock = _pool.Lease();
var lastBlock = _pool.Lease();
var firstBytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 };
var lastBytes = new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 };
Buffer.BlockCopy(firstBytes, 0, firstBlock.Array, firstBlock.Start, firstBytes.Length);
firstBlock.End += lastBytes.Length;
firstBlock.Next = lastBlock;
Buffer.BlockCopy(lastBytes, 0, lastBlock.Array, lastBlock.Start, lastBytes.Length);
lastBlock.End += lastBytes.Length;
var scan = firstBlock.GetIterator();
var originalIndex = scan.Index;
var originalBlock = scan.Block;
// Act
var result = scan.PeekArraySegment();
// Assert
Assert.Equal(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }, result);
Assert.Equal(originalBlock, scan.Block);
Assert.Equal(originalIndex, scan.Index);
// Act
// Advance past the data in the first block
scan.Skip(8);
result = scan.PeekArraySegment();
// Assert
Assert.Equal(new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }, result);
Assert.Equal(originalBlock, scan.Block);
Assert.Equal(originalIndex + 8, scan.Index);
// Act
// Add anther empty block between the first and last block
var middleBlock = _pool.Lease();
firstBlock.Next = middleBlock;
middleBlock.Next = lastBlock;
result = scan.PeekArraySegment();
// Assert
Assert.Equal(new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }, result);
Assert.Equal(originalBlock, scan.Block);
Assert.Equal(originalIndex + 8, scan.Index);
_pool.Return(firstBlock);
_pool.Return(middleBlock);
_pool.Return(lastBlock);
}
[Fact]
public void PeekLong()
{

View File

@ -2,10 +2,18 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Http;
using Microsoft.AspNetCore.Server.KestrelTests.TestHelpers;
using Microsoft.Extensions.Internal;
using Moq;
using Xunit;
using Xunit.Sdk;
namespace Microsoft.AspNetCore.Server.KestrelTests
{
@ -68,23 +76,119 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
// Input needs to be greater than 4032 bytes to allocate a block not backed by a slab.
var largeInput = new string('a', 8192);
input.Add(largeInput, true);
input.Add(largeInput);
// Add a smaller block to the end so that SocketInput attempts to return the large
// block to the memory pool.
input.Add("Hello", true);
input.Add("Hello", fin: true);
var readBuffer = new byte[8192];
var ms = new MemoryStream();
var count1 = await stream.ReadAsync(readBuffer, 0, 8192);
Assert.Equal(8192, count1);
AssertASCII(largeInput, new ArraySegment<byte>(readBuffer, 0, 8192));
await stream.CopyToAsync(ms);
var requestArray = ms.ToArray();
Assert.Equal(8197, requestArray.Length);
AssertASCII(largeInput + "Hello", new ArraySegment<byte>(requestArray, 0, requestArray.Length));
var count2 = await stream.ReadAsync(readBuffer, 0, 8192);
Assert.Equal(5, count2);
AssertASCII("Hello", new ArraySegment<byte>(readBuffer, 0, 5));
var count = await stream.ReadAsync(new byte[1], 0, 1);
Assert.Equal(0, count);
}
}
var count3 = await stream.ReadAsync(readBuffer, 0, 8192);
Assert.Equal(0, count3);
public static IEnumerable<object[]> StreamData => new[]
{
new object[] { new ThrowOnWriteSynchronousStream() },
new object[] { new ThrowOnWriteAsynchronousStream() },
};
public static IEnumerable<object[]> RequestData => new[]
{
// Remaining Data
new object[] { new FrameRequestHeaders { HeaderConnection = "close" }, new[] { "Hello ", "World!" } },
// Content-Length
new object[] { new FrameRequestHeaders { HeaderContentLength = "12" }, new[] { "Hello ", "World!" } },
// Chunked
new object[] { new FrameRequestHeaders { HeaderTransferEncoding = "chunked" }, new[] { "6\r\nHello \r\n", "6\r\nWorld!\r\n0\r\n\r\n" } },
};
public static IEnumerable<object[]> CombinedData =>
from stream in StreamData
from request in RequestData
select new[] { stream[0], request[0], request[1] };
[Theory]
[MemberData(nameof(RequestData))]
public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, string[] data)
{
var writeCount = 0;
var writeTcs = new TaskCompletionSource<byte[]>();
var mockDestination = new Mock<Stream>();
mockDestination
.Setup(m => m.WriteAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), CancellationToken.None))
.Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
{
writeTcs.SetResult(buffer);
writeCount++;
})
.Returns(TaskCache.CompletedTask);
using (var input = new TestInput())
{
var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext);
var copyToAsyncTask = body.CopyToAsync(mockDestination.Object);
// The block returned by IncomingStart always has at least 2048 available bytes,
// so no need to bounds check in this test.
var socketInput = input.FrameContext.SocketInput;
var bytes = Encoding.ASCII.GetBytes(data[0]);
var block = socketInput.IncomingStart();
Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length);
socketInput.IncomingComplete(bytes.Length, null);
// Verify the block passed to WriteAsync is the same one incoming data was written into.
Assert.Same(block.Array, await writeTcs.Task);
writeTcs = new TaskCompletionSource<byte[]>();
bytes = Encoding.ASCII.GetBytes(data[1]);
block = socketInput.IncomingStart();
Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length);
socketInput.IncomingComplete(bytes.Length, null);
Assert.Same(block.Array, await writeTcs.Task);
if (headers.HeaderConnection == "close")
{
socketInput.IncomingFin();
}
await copyToAsyncTask;
Assert.Equal(2, writeCount);
}
}
[Theory]
[MemberData(nameof(CombinedData))]
public async Task CopyToAsyncAdvancesRequestStreamWhenDestinationWriteAsyncThrows(Stream writeStream, FrameRequestHeaders headers, string[] data)
{
using (var input = new TestInput())
{
var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext);
input.Add(data[0]);
await Assert.ThrowsAsync<XunitException>(() => body.CopyToAsync(writeStream));
input.Add(data[1], fin: headers.HeaderConnection == "close");
// "Hello " should have been consumed
var readBuffer = new byte[6];
var count = await body.ReadAsync(new ArraySegment<byte>(readBuffer, 0, readBuffer.Length));
Assert.Equal(6, count);
AssertASCII("World!", new ArraySegment<byte>(readBuffer, 0, 6));
count = await body.ReadAsync(new ArraySegment<byte>(readBuffer, 0, readBuffer.Length));
Assert.Equal(0, count);
}
}
@ -98,5 +202,84 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.Equal(bytes[index], actual.Array[actual.Offset + index]);
}
}
private class ThrowOnWriteSynchronousStream : Stream
{
public override void Flush()
{
throw new NotImplementedException();
}
public override int Read(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotImplementedException();
}
public override void SetLength(long value)
{
throw new NotImplementedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
throw new XunitException();
}
public override bool CanRead { get; }
public override bool CanSeek { get; }
public override bool CanWrite => true;
public override long Length { get; }
public override long Position { get; set; }
}
private class ThrowOnWriteAsynchronousStream : Stream
{
public override void Flush()
{
throw new NotImplementedException();
}
public override int Read(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotImplementedException();
}
public override void SetLength(long value)
{
throw new NotImplementedException();
}
public override void Write(byte[] buffer, int offset, int count)
{
throw new NotImplementedException();
}
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
await Task.Delay(1);
throw new XunitException();
}
public override bool CanRead { get; }
public override bool CanSeek { get; }
public override bool CanWrite => true;
public override long Length { get; }
public override long Position { get; set; }
}
}
}

View File

@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
new TheoryData<Mock<IBufferSizeControl>>() { new Mock<IBufferSizeControl>(), null };
[Theory]
[MemberData("MockBufferSizeControlData")]
[MemberData(nameof(MockBufferSizeControlData))]
public void IncomingDataCallsBufferSizeControlAdd(Mock<IBufferSizeControl> mockBufferSizeControl)
{
using (var memory = new MemoryPool())
@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
[Theory]
[MemberData("MockBufferSizeControlData")]
[MemberData(nameof(MockBufferSizeControlData))]
public void IncomingCompleteCallsBufferSizeControlAdd(Mock<IBufferSizeControl> mockBufferSizeControl)
{
using (var memory = new MemoryPool())
@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
[Theory]
[MemberData("MockBufferSizeControlData")]
[MemberData(nameof(MockBufferSizeControlData))]
public void ConsumingCompleteCallsBufferSizeControlSubtract(Mock<IBufferSizeControl> mockBufferSizeControl)
{
using (var kestrelEngine = new KestrelEngine(new MockLibuv(), new TestServiceContext()))
@ -154,6 +154,80 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task PeekAsyncRereturnsTheSameData()
{
using (var memory = new MemoryPool())
using (var socketInput = new SocketInput(memory, new SynchronousThreadPool()))
{
socketInput.IncomingData(new byte[5], 0, 5);
Assert.True(socketInput.IsCompleted);
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
// The same 5 bytes will be returned again since it hasn't been consumed.
Assert.True(socketInput.IsCompleted);
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
var scan = socketInput.ConsumingStart();
scan.Skip(3);
socketInput.ConsumingComplete(scan, scan);
// The remaining 2 unconsumed bytes will be returned.
Assert.True(socketInput.IsCompleted);
Assert.Equal(2, (await socketInput.PeekAsync()).Count);
scan = socketInput.ConsumingStart();
scan.Skip(2);
socketInput.ConsumingComplete(scan, scan);
// Everything has been consume so socketInput is no longer in the completed state
Assert.False(socketInput.IsCompleted);
}
}
[Fact]
public async Task CompleteAwaitingDoesNotCauseZeroLengthRead()
{
using (var memory = new MemoryPool())
using (var socketInput = new SocketInput(memory, new SynchronousThreadPool()))
{
var readBuffer = new byte[20];
socketInput.IncomingData(new byte[5], 0, 5);
Assert.Equal(5, await socketInput.ReadAsync(readBuffer, 0, 20));
var readTask = socketInput.ReadAsync(readBuffer, 0, 20);
socketInput.CompleteAwaiting();
Assert.False(readTask.IsCompleted);
socketInput.IncomingData(new byte[5], 0, 5);
Assert.Equal(5, await readTask);
}
}
[Fact]
public async Task CompleteAwaitingDoesNotCauseZeroLengthPeek()
{
using (var memory = new MemoryPool())
using (var socketInput = new SocketInput(memory, new SynchronousThreadPool()))
{
socketInput.IncomingData(new byte[5], 0, 5);
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
var scan = socketInput.ConsumingStart();
scan.Skip(5);
socketInput.ConsumingComplete(scan, scan);
var peekTask = socketInput.PeekAsync();
socketInput.CompleteAwaiting();
Assert.False(peekTask.IsCompleted);
socketInput.IncomingData(new byte[5], 0, 5);
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
}
}
private static void TestConcurrentFaultedTask(Task t)
{
Assert.True(t.IsFaulted);