Make FrameRequestStream.CopyToAsync(...) copyless
- Add tests for when the CopyToAsync destinationStream throws. - Add test to verify the destination stream sees the same array written to by the producer.
This commit is contained in:
parent
afa89b3993
commit
63509b9e10
|
|
@ -119,8 +119,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
var task = ValidateState(cancellationToken);
|
||||
if (task == null)
|
||||
{
|
||||
// Needs .AsTask to match Stream's Async method return types
|
||||
return _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken).AsTask();
|
||||
return _body.ReadAsync(new ArraySegment<byte>(buffer, offset, count), cancellationToken);
|
||||
}
|
||||
return task;
|
||||
}
|
||||
|
||||
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
|
||||
{
|
||||
if (destination == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(destination));
|
||||
}
|
||||
if (bufferSize <= 0)
|
||||
{
|
||||
throw new ArgumentException($"{nameof(bufferSize)} must be positive.", nameof(bufferSize));
|
||||
}
|
||||
|
||||
var task = ValidateState(cancellationToken);
|
||||
if (task == null)
|
||||
{
|
||||
return _body.CopyToAsync(destination, cancellationToken);
|
||||
}
|
||||
return task;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,11 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.IO;
|
||||
using System.Numerics;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure;
|
||||
using Microsoft.Extensions.Internal;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
||||
|
|
@ -12,7 +14,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
public abstract class MessageBody
|
||||
{
|
||||
private readonly Frame _context;
|
||||
private int _send100Continue = 1;
|
||||
private bool _send100Continue = true;
|
||||
|
||||
protected MessageBody(Frame context)
|
||||
{
|
||||
|
|
@ -21,80 +23,205 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
|
||||
public bool RequestKeepAlive { get; protected set; }
|
||||
|
||||
public ValueTask<int> ReadAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
|
||||
public Task<int> ReadAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
var send100Continue = 0;
|
||||
var result = ReadAsyncImplementation(buffer, cancellationToken);
|
||||
if (!result.IsCompleted)
|
||||
var task = PeekAsync(cancellationToken);
|
||||
|
||||
if (!task.IsCompleted)
|
||||
{
|
||||
send100Continue = Interlocked.Exchange(ref _send100Continue, 0);
|
||||
TryProduceContinue();
|
||||
|
||||
// Incomplete Task await result
|
||||
return ReadAsyncAwaited(task, buffer);
|
||||
}
|
||||
if (send100Continue == 1)
|
||||
else
|
||||
{
|
||||
_context.FrameControl.ProduceContinue();
|
||||
var readSegment = task.Result;
|
||||
var consumed = CopyReadSegment(readSegment, buffer);
|
||||
|
||||
return consumed == 0 ? TaskCache<int>.DefaultCompletedTask : Task.FromResult(consumed);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
private async Task<int> ReadAsyncAwaited(ValueTask<ArraySegment<byte>> currentTask, ArraySegment<byte> buffer)
|
||||
{
|
||||
return CopyReadSegment(await currentTask, buffer);
|
||||
}
|
||||
|
||||
private int CopyReadSegment(ArraySegment<byte> readSegment, ArraySegment<byte> buffer)
|
||||
{
|
||||
var consumed = Math.Min(readSegment.Count, buffer.Count);
|
||||
|
||||
if (consumed != 0)
|
||||
{
|
||||
Buffer.BlockCopy(readSegment.Array, readSegment.Offset, buffer.Array, buffer.Offset, consumed);
|
||||
ConsumedBytes(consumed);
|
||||
}
|
||||
|
||||
return consumed;
|
||||
}
|
||||
|
||||
public Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
var peekTask = PeekAsync(cancellationToken);
|
||||
|
||||
while (peekTask.IsCompleted)
|
||||
{
|
||||
// ValueTask uses .GetAwaiter().GetResult() if necessary
|
||||
var segment = peekTask.Result;
|
||||
|
||||
if (segment.Count == 0)
|
||||
{
|
||||
return TaskCache.CompletedTask;
|
||||
}
|
||||
|
||||
Task destinationTask;
|
||||
try
|
||||
{
|
||||
destinationTask = destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
|
||||
}
|
||||
catch
|
||||
{
|
||||
ConsumedBytes(segment.Count);
|
||||
throw;
|
||||
}
|
||||
|
||||
if (!destinationTask.IsCompleted)
|
||||
{
|
||||
return CopyToAsyncDestinationAwaited(destinationTask, segment.Count, destination, cancellationToken);
|
||||
}
|
||||
|
||||
ConsumedBytes(segment.Count);
|
||||
|
||||
// Surface errors if necessary
|
||||
destinationTask.GetAwaiter().GetResult();
|
||||
|
||||
peekTask = PeekAsync(cancellationToken);
|
||||
}
|
||||
|
||||
TryProduceContinue();
|
||||
|
||||
return CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken);
|
||||
}
|
||||
|
||||
private async Task CopyToAsyncPeekAwaited(
|
||||
ValueTask<ArraySegment<byte>> peekTask,
|
||||
Stream destination,
|
||||
CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
var segment = await peekTask;
|
||||
|
||||
if (segment.Count == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
await destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
|
||||
}
|
||||
finally
|
||||
{
|
||||
ConsumedBytes(segment.Count);
|
||||
}
|
||||
|
||||
peekTask = PeekAsync(cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
private async Task CopyToAsyncDestinationAwaited(
|
||||
Task destinationTask,
|
||||
int bytesConsumed,
|
||||
Stream destination,
|
||||
CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
try
|
||||
{
|
||||
await destinationTask;
|
||||
}
|
||||
finally
|
||||
{
|
||||
ConsumedBytes(bytesConsumed);
|
||||
}
|
||||
|
||||
var peekTask = PeekAsync(cancellationToken);
|
||||
|
||||
if (!peekTask.IsCompleted)
|
||||
{
|
||||
TryProduceContinue();
|
||||
}
|
||||
|
||||
await CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken);
|
||||
}
|
||||
|
||||
public Task Consume(CancellationToken cancellationToken = default(CancellationToken))
|
||||
{
|
||||
ValueTask<int> result;
|
||||
var send100checked = false;
|
||||
do
|
||||
while (true)
|
||||
{
|
||||
result = ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken);
|
||||
if (!result.IsCompleted)
|
||||
var task = PeekAsync(cancellationToken);
|
||||
if (!task.IsCompleted)
|
||||
{
|
||||
if (!send100checked)
|
||||
{
|
||||
if (Interlocked.Exchange(ref _send100Continue, 0) == 1)
|
||||
{
|
||||
_context.FrameControl.ProduceContinue();
|
||||
}
|
||||
send100checked = true;
|
||||
}
|
||||
TryProduceContinue();
|
||||
|
||||
// Incomplete Task await result
|
||||
return ConsumeAwaited(result.AsTask(), cancellationToken);
|
||||
return ConsumeAwaited(task, cancellationToken);
|
||||
}
|
||||
// ValueTask uses .GetAwaiter().GetResult() if necessary
|
||||
else if (result.Result == 0)
|
||||
{
|
||||
// Completed Task, end of stream
|
||||
return TaskCache.CompletedTask;
|
||||
}
|
||||
|
||||
} while (true);
|
||||
}
|
||||
|
||||
private async Task ConsumeAwaited(Task<int> currentTask, CancellationToken cancellationToken)
|
||||
{
|
||||
if (await currentTask == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
ValueTask<int> result;
|
||||
do
|
||||
{
|
||||
result = ReadAsyncImplementation(default(ArraySegment<byte>), cancellationToken);
|
||||
if (result.IsCompleted)
|
||||
else
|
||||
{
|
||||
// ValueTask uses .GetAwaiter().GetResult() if necessary
|
||||
if (result.Result == 0)
|
||||
if (task.Result.Count == 0)
|
||||
{
|
||||
// Completed Task, end of stream
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Completed Task, get next Task rather than await
|
||||
continue;
|
||||
return TaskCache.CompletedTask;
|
||||
}
|
||||
|
||||
ConsumedBytes(task.Result.Count);
|
||||
}
|
||||
} while (await result != 0);
|
||||
}
|
||||
}
|
||||
|
||||
public abstract ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken);
|
||||
private async Task ConsumeAwaited(ValueTask<ArraySegment<byte>> currentTask, CancellationToken cancellationToken)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
var count = (await currentTask).Count;
|
||||
|
||||
if (count == 0)
|
||||
{
|
||||
// Completed Task, end of stream
|
||||
return;
|
||||
}
|
||||
|
||||
ConsumedBytes(count);
|
||||
currentTask = PeekAsync(cancellationToken);
|
||||
}
|
||||
}
|
||||
|
||||
private void TryProduceContinue()
|
||||
{
|
||||
if (_send100Continue)
|
||||
{
|
||||
_context.FrameControl.ProduceContinue();
|
||||
_send100Continue = false;
|
||||
}
|
||||
}
|
||||
|
||||
private void ConsumedBytes(int count)
|
||||
{
|
||||
var scan = _context.SocketInput.ConsumingStart();
|
||||
scan.Skip(count);
|
||||
_context.SocketInput.ConsumingComplete(scan, scan);
|
||||
|
||||
OnConsumedBytes(count);
|
||||
}
|
||||
|
||||
protected abstract ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken);
|
||||
|
||||
protected virtual void OnConsumedBytes(int count)
|
||||
{
|
||||
}
|
||||
|
||||
public static MessageBody For(
|
||||
HttpVersion httpVersion,
|
||||
|
|
@ -145,9 +272,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
{
|
||||
}
|
||||
|
||||
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
|
||||
protected override ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, buffer.Array == null ? 8192 : buffer.Count);
|
||||
return _context.SocketInput.PeekAsync();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -164,49 +291,65 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
_inputLength = _contentLength;
|
||||
}
|
||||
|
||||
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
|
||||
protected override ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
var input = _context.SocketInput;
|
||||
|
||||
var inputLengthLimit = (int)Math.Min(_inputLength, int.MaxValue);
|
||||
var limit = buffer.Array == null ? inputLengthLimit : Math.Min(buffer.Count, inputLengthLimit);
|
||||
var limit = (int)Math.Min(_inputLength, int.MaxValue);
|
||||
if (limit == 0)
|
||||
{
|
||||
return new ValueTask<int>(0);
|
||||
return new ValueTask<ArraySegment<byte>>();
|
||||
}
|
||||
|
||||
var task = _context.SocketInput.ReadAsync(buffer.Array, buffer.Offset, limit);
|
||||
var task = _context.SocketInput.PeekAsync();
|
||||
|
||||
if (task.IsCompleted)
|
||||
{
|
||||
// .GetAwaiter().GetResult() done by ValueTask if needed
|
||||
var actual = task.Result;
|
||||
_inputLength -= actual;
|
||||
var actual = Math.Min(task.Result.Count, limit);
|
||||
|
||||
if (actual == 0)
|
||||
if (task.Result.Count == 0)
|
||||
{
|
||||
_context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent);
|
||||
}
|
||||
|
||||
return new ValueTask<int>(actual);
|
||||
if (task.Result.Count < _inputLength)
|
||||
{
|
||||
return task;
|
||||
}
|
||||
else
|
||||
{
|
||||
var result = task.Result;
|
||||
var part = new ArraySegment<byte>(result.Array, result.Offset, (int)_inputLength);
|
||||
return new ValueTask<ArraySegment<byte>>(part);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return new ValueTask<int>(ReadAsyncAwaited(task.AsTask()));
|
||||
return new ValueTask<ArraySegment<byte>>(PeekAsyncAwaited(task));
|
||||
}
|
||||
}
|
||||
|
||||
private async Task<int> ReadAsyncAwaited(Task<int> task)
|
||||
private async Task<ArraySegment<byte>> PeekAsyncAwaited(ValueTask<ArraySegment<byte>> task)
|
||||
{
|
||||
var actual = await task;
|
||||
_inputLength -= actual;
|
||||
var segment = await task;
|
||||
|
||||
if (actual == 0)
|
||||
if (segment.Count == 0)
|
||||
{
|
||||
_context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent);
|
||||
}
|
||||
|
||||
return actual;
|
||||
if (segment.Count <= _inputLength)
|
||||
{
|
||||
return segment;
|
||||
}
|
||||
else
|
||||
{
|
||||
return new ArraySegment<byte>(segment.Array, segment.Offset, (int)_inputLength);
|
||||
}
|
||||
}
|
||||
|
||||
protected override void OnConsumedBytes(int count)
|
||||
{
|
||||
_inputLength -= count;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -219,31 +362,39 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
// https://github.com/dotnet/corefx/issues/8825
|
||||
private Vector<byte> _vectorCRs = new Vector<byte>((byte)'\r');
|
||||
|
||||
private readonly SocketInput _input;
|
||||
private readonly FrameRequestHeaders _requestHeaders;
|
||||
private int _inputLength;
|
||||
|
||||
private Mode _mode = Mode.Prefix;
|
||||
private FrameRequestHeaders _requestHeaders;
|
||||
|
||||
public ForChunkedEncoding(bool keepAlive, FrameRequestHeaders headers, Frame context)
|
||||
: base(context)
|
||||
{
|
||||
RequestKeepAlive = keepAlive;
|
||||
_input = _context.SocketInput;
|
||||
_requestHeaders = headers;
|
||||
}
|
||||
|
||||
public override ValueTask<int> ReadAsyncImplementation(ArraySegment<byte> buffer, CancellationToken cancellationToken)
|
||||
protected override ValueTask<ArraySegment<byte>> PeekAsync(CancellationToken cancellationToken)
|
||||
{
|
||||
return new ValueTask<int>(ReadStateMachineAsync(_context.SocketInput, buffer, cancellationToken));
|
||||
return new ValueTask<ArraySegment<byte>>(PeekStateMachineAsync());
|
||||
}
|
||||
|
||||
private async Task<int> ReadStateMachineAsync(SocketInput input, ArraySegment<byte> buffer, CancellationToken cancellationToken)
|
||||
protected override void OnConsumedBytes(int count)
|
||||
{
|
||||
_inputLength -= count;
|
||||
}
|
||||
|
||||
private async Task<ArraySegment<byte>> PeekStateMachineAsync()
|
||||
{
|
||||
while (_mode < Mode.Trailer)
|
||||
{
|
||||
while (_mode == Mode.Prefix)
|
||||
{
|
||||
var fin = input.CheckFinOrThrow();
|
||||
var fin = _input.CheckFinOrThrow();
|
||||
|
||||
ParseChunkedPrefix(input);
|
||||
ParseChunkedPrefix();
|
||||
|
||||
if (_mode != Mode.Prefix)
|
||||
{
|
||||
|
|
@ -254,14 +405,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
|
||||
}
|
||||
|
||||
await input;
|
||||
await _input;
|
||||
}
|
||||
|
||||
while (_mode == Mode.Extension)
|
||||
{
|
||||
var fin = input.CheckFinOrThrow();
|
||||
var fin = _input.CheckFinOrThrow();
|
||||
|
||||
ParseExtension(input);
|
||||
ParseExtension();
|
||||
|
||||
if (_mode != Mode.Extension)
|
||||
{
|
||||
|
|
@ -272,18 +423,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
|
||||
}
|
||||
|
||||
await input;
|
||||
await _input;
|
||||
}
|
||||
|
||||
while (_mode == Mode.Data)
|
||||
{
|
||||
var fin = input.CheckFinOrThrow();
|
||||
var fin = _input.CheckFinOrThrow();
|
||||
|
||||
int actual = ReadChunkedData(input, buffer.Array, buffer.Offset, buffer.Count);
|
||||
var segment = PeekChunkedData();
|
||||
|
||||
if (actual != 0)
|
||||
if (segment.Count != 0)
|
||||
{
|
||||
return actual;
|
||||
return segment;
|
||||
}
|
||||
else if (_mode != Mode.Data)
|
||||
{
|
||||
|
|
@ -294,14 +445,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
|
||||
}
|
||||
|
||||
await input;
|
||||
await _input;
|
||||
}
|
||||
|
||||
while (_mode == Mode.Suffix)
|
||||
{
|
||||
var fin = input.CheckFinOrThrow();
|
||||
var fin = _input.CheckFinOrThrow();
|
||||
|
||||
ParseChunkedSuffix(input);
|
||||
ParseChunkedSuffix();
|
||||
|
||||
if (_mode != Mode.Suffix)
|
||||
{
|
||||
|
|
@ -312,16 +463,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
|
||||
}
|
||||
|
||||
await input;
|
||||
await _input;
|
||||
}
|
||||
}
|
||||
|
||||
// Chunks finished, parse trailers
|
||||
while (_mode == Mode.Trailer)
|
||||
{
|
||||
var fin = input.CheckFinOrThrow();
|
||||
var fin = _input.CheckFinOrThrow();
|
||||
|
||||
ParseChunkedTrailer(input);
|
||||
ParseChunkedTrailer();
|
||||
|
||||
if (_mode != Mode.Trailer)
|
||||
{
|
||||
|
|
@ -332,16 +483,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
_context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete);
|
||||
}
|
||||
|
||||
await input;
|
||||
await _input;
|
||||
}
|
||||
|
||||
if (_mode == Mode.TrailerHeaders)
|
||||
{
|
||||
while (!_context.TakeMessageHeaders(input, _requestHeaders))
|
||||
while (!_context.TakeMessageHeaders(_input, _requestHeaders))
|
||||
{
|
||||
if (input.CheckFinOrThrow())
|
||||
if (_input.CheckFinOrThrow())
|
||||
{
|
||||
if (_context.TakeMessageHeaders(input, _requestHeaders))
|
||||
if (_context.TakeMessageHeaders(_input, _requestHeaders))
|
||||
{
|
||||
break;
|
||||
}
|
||||
|
|
@ -351,18 +502,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
}
|
||||
}
|
||||
|
||||
await input;
|
||||
await _input;
|
||||
}
|
||||
|
||||
_mode = Mode.Complete;
|
||||
}
|
||||
|
||||
return 0;
|
||||
return default(ArraySegment<byte>);
|
||||
}
|
||||
|
||||
private void ParseChunkedPrefix(SocketInput input)
|
||||
private void ParseChunkedPrefix()
|
||||
{
|
||||
var scan = input.ConsumingStart();
|
||||
var scan = _input.ConsumingStart();
|
||||
var consumed = scan;
|
||||
try
|
||||
{
|
||||
|
|
@ -416,13 +567,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
}
|
||||
finally
|
||||
{
|
||||
input.ConsumingComplete(consumed, scan);
|
||||
_input.ConsumingComplete(consumed, scan);
|
||||
}
|
||||
}
|
||||
|
||||
private void ParseExtension(SocketInput input)
|
||||
private void ParseExtension()
|
||||
{
|
||||
var scan = input.ConsumingStart();
|
||||
var scan = _input.ConsumingStart();
|
||||
var consumed = scan;
|
||||
try
|
||||
{
|
||||
|
|
@ -460,36 +611,37 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
}
|
||||
finally
|
||||
{
|
||||
input.ConsumingComplete(consumed, scan);
|
||||
_input.ConsumingComplete(consumed, scan);
|
||||
}
|
||||
}
|
||||
|
||||
private int ReadChunkedData(SocketInput input, byte[] buffer, int offset, int count)
|
||||
private ArraySegment<byte> PeekChunkedData()
|
||||
{
|
||||
var scan = input.ConsumingStart();
|
||||
int actual;
|
||||
try
|
||||
{
|
||||
var limit = buffer == null ? _inputLength : Math.Min(count, _inputLength);
|
||||
scan = scan.CopyTo(buffer, offset, limit, out actual);
|
||||
_inputLength -= actual;
|
||||
}
|
||||
finally
|
||||
{
|
||||
input.ConsumingComplete(scan, scan);
|
||||
}
|
||||
|
||||
if (_inputLength == 0)
|
||||
{
|
||||
_mode = Mode.Suffix;
|
||||
return default(ArraySegment<byte>);
|
||||
}
|
||||
|
||||
return actual;
|
||||
var scan = _input.ConsumingStart();
|
||||
var segment = scan.PeekArraySegment();
|
||||
int actual = Math.Min(segment.Count, _inputLength);
|
||||
// Nothing is consumed yet. ConsumedBytes(int) will move the iterator.
|
||||
_input.ConsumingComplete(scan, scan);
|
||||
|
||||
if (actual == segment.Count)
|
||||
{
|
||||
return segment;
|
||||
}
|
||||
else
|
||||
{
|
||||
return new ArraySegment<byte>(segment.Array, segment.Offset, actual);
|
||||
}
|
||||
}
|
||||
|
||||
private void ParseChunkedSuffix(SocketInput input)
|
||||
private void ParseChunkedSuffix()
|
||||
{
|
||||
var scan = input.ConsumingStart();
|
||||
var scan = _input.ConsumingStart();
|
||||
var consumed = scan;
|
||||
try
|
||||
{
|
||||
|
|
@ -511,13 +663,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
}
|
||||
finally
|
||||
{
|
||||
input.ConsumingComplete(consumed, scan);
|
||||
_input.ConsumingComplete(consumed, scan);
|
||||
}
|
||||
}
|
||||
|
||||
private void ParseChunkedTrailer(SocketInput input)
|
||||
private void ParseChunkedTrailer()
|
||||
{
|
||||
var scan = input.ConsumingStart();
|
||||
var scan = _input.ConsumingStart();
|
||||
var consumed = scan;
|
||||
try
|
||||
{
|
||||
|
|
@ -540,7 +692,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
}
|
||||
finally
|
||||
{
|
||||
input.ConsumingComplete(consumed, scan);
|
||||
_input.ConsumingComplete(consumed, scan);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
||||
{
|
||||
|
|
@ -18,14 +20,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
var end = begin.CopyTo(buffer, offset, count, out actual);
|
||||
input.ConsumingComplete(end, end);
|
||||
|
||||
if (actual != 0)
|
||||
if (actual != 0 || fin)
|
||||
{
|
||||
return new ValueTask<int>(actual);
|
||||
}
|
||||
else if (fin)
|
||||
{
|
||||
return new ValueTask<int>(0);
|
||||
}
|
||||
}
|
||||
|
||||
return new ValueTask<int>(input.ReadAsyncAwaited(buffer, offset, count));
|
||||
|
|
@ -44,13 +42,47 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
|
|||
var end = begin.CopyTo(buffer, offset, count, out actual);
|
||||
input.ConsumingComplete(end, end);
|
||||
|
||||
if (actual != 0)
|
||||
if (actual != 0 || fin)
|
||||
{
|
||||
return actual;
|
||||
}
|
||||
else if (fin)
|
||||
}
|
||||
}
|
||||
|
||||
public static ValueTask<ArraySegment<byte>> PeekAsync(this SocketInput input)
|
||||
{
|
||||
while (input.IsCompleted)
|
||||
{
|
||||
var fin = input.CheckFinOrThrow();
|
||||
|
||||
var begin = input.ConsumingStart();
|
||||
var segment = begin.PeekArraySegment();
|
||||
input.ConsumingComplete(begin, begin);
|
||||
|
||||
if (segment.Count != 0 || fin)
|
||||
{
|
||||
return 0;
|
||||
return new ValueTask<ArraySegment<byte>>(segment);
|
||||
}
|
||||
}
|
||||
|
||||
return new ValueTask<ArraySegment<byte>>(input.PeekAsyncAwaited());
|
||||
}
|
||||
|
||||
private static async Task<ArraySegment<byte>> PeekAsyncAwaited(this SocketInput input)
|
||||
{
|
||||
while (true)
|
||||
{
|
||||
await input;
|
||||
|
||||
var fin = input.CheckFinOrThrow();
|
||||
|
||||
var begin = input.ConsumingStart();
|
||||
var segment = begin.PeekArraySegment();
|
||||
input.ConsumingComplete(begin, begin);
|
||||
|
||||
if (segment.Count != 0 || fin)
|
||||
{
|
||||
return segment;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -232,6 +232,32 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure
|
|||
return new ArraySegment<byte>(array, 0, length);
|
||||
}
|
||||
|
||||
public static ArraySegment<byte> PeekArraySegment(this MemoryPoolIterator iter)
|
||||
{
|
||||
if (iter.IsDefault || iter.IsEnd)
|
||||
{
|
||||
return default(ArraySegment<byte>);
|
||||
}
|
||||
|
||||
if (iter.Index < iter.Block.End)
|
||||
{
|
||||
return new ArraySegment<byte>(iter.Block.Array, iter.Index, iter.Block.End - iter.Index);
|
||||
}
|
||||
|
||||
var block = iter.Block.Next;
|
||||
while (block != null)
|
||||
{
|
||||
if (block.Start < block.End)
|
||||
{
|
||||
return new ArraySegment<byte>(block.Array, block.Start, block.End - block.Start);
|
||||
}
|
||||
block = block.Next;
|
||||
}
|
||||
|
||||
// The following should be unreachable due to the IsEnd check above.
|
||||
throw new InvalidOperationException("This should be unreachable!");
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Checks that up to 8 bytes from <paramref name="begin"/> correspond to a known HTTP method.
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ using System;
|
|||
using System.IO;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Internal.Http;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.KestrelTests
|
||||
|
|
@ -126,5 +127,61 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
Assert.True(task.IsFaulted);
|
||||
Assert.Same(error, task.Exception.InnerException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StopAcceptingReadsCausesReadToThrowObjectDisposedException()
|
||||
{
|
||||
var stream = new FrameRequestStream();
|
||||
stream.StartAcceptingReads(null);
|
||||
stream.StopAcceptingReads();
|
||||
Assert.Throws<ObjectDisposedException>(() => { stream.ReadAsync(new byte[1], 0, 1); });
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AbortCausesCopyToAsyncToCancel()
|
||||
{
|
||||
var stream = new FrameRequestStream();
|
||||
stream.StartAcceptingReads(null);
|
||||
stream.Abort();
|
||||
var task = stream.CopyToAsync(Mock.Of<Stream>());
|
||||
Assert.True(task.IsCanceled);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void AbortWithErrorCausesCopyToAsyncToCancel()
|
||||
{
|
||||
var stream = new FrameRequestStream();
|
||||
stream.StartAcceptingReads(null);
|
||||
var error = new Exception();
|
||||
stream.Abort(error);
|
||||
var task = stream.CopyToAsync(Mock.Of<Stream>());
|
||||
Assert.True(task.IsFaulted);
|
||||
Assert.Same(error, task.Exception.InnerException);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void StopAcceptingReadsCausesCopyToAsyncToThrowObjectDisposedException()
|
||||
{
|
||||
var stream = new FrameRequestStream();
|
||||
stream.StartAcceptingReads(null);
|
||||
stream.StopAcceptingReads();
|
||||
Assert.Throws<ObjectDisposedException>(() => { stream.CopyToAsync(Mock.Of<Stream>()); });
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void NullDestinationCausesCopyToAsyncToThrowArgumentNullException()
|
||||
{
|
||||
var stream = new FrameRequestStream();
|
||||
stream.StartAcceptingReads(null);
|
||||
Assert.Throws<ArgumentNullException>(() => { stream.CopyToAsync(null); });
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void ZeroBufferSizeCausesCopyToAsyncToThrowArgumentException()
|
||||
{
|
||||
var stream = new FrameRequestStream();
|
||||
stream.StartAcceptingReads(null);
|
||||
Assert.Throws<ArgumentException>(() => { stream.CopyToAsync(Mock.Of<Stream>(), 0); });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -88,7 +88,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
public void MemorySeek(string raw, string search, char expectResult, int expectIndex)
|
||||
{
|
||||
var block = _pool.Lease();
|
||||
var chars = raw.ToCharArray().Select(c => (byte)c).ToArray();
|
||||
var chars = raw.ToCharArray().Select(c => (byte) c).ToArray();
|
||||
Buffer.BlockCopy(chars, 0, block.Array, block.Start, chars.Length);
|
||||
block.End += chars.Length;
|
||||
|
||||
|
|
@ -98,20 +98,20 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
int found = -1;
|
||||
if (searchFor.Length == 1)
|
||||
{
|
||||
var search0 = new Vector<byte>((byte)searchFor[0]);
|
||||
var search0 = new Vector<byte>((byte) searchFor[0]);
|
||||
found = begin.Seek(ref search0);
|
||||
}
|
||||
else if (searchFor.Length == 2)
|
||||
{
|
||||
var search0 = new Vector<byte>((byte)searchFor[0]);
|
||||
var search1 = new Vector<byte>((byte)searchFor[1]);
|
||||
var search0 = new Vector<byte>((byte) searchFor[0]);
|
||||
var search1 = new Vector<byte>((byte) searchFor[1]);
|
||||
found = begin.Seek(ref search0, ref search1);
|
||||
}
|
||||
else if (searchFor.Length == 3)
|
||||
{
|
||||
var search0 = new Vector<byte>((byte)searchFor[0]);
|
||||
var search1 = new Vector<byte>((byte)searchFor[1]);
|
||||
var search2 = new Vector<byte>((byte)searchFor[2]);
|
||||
var search0 = new Vector<byte>((byte) searchFor[0]);
|
||||
var search1 = new Vector<byte>((byte) searchFor[1]);
|
||||
var search2 = new Vector<byte>((byte) searchFor[2]);
|
||||
found = begin.Seek(ref search0, ref search1, ref search2);
|
||||
}
|
||||
else
|
||||
|
|
@ -176,7 +176,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
head = blocks[0].GetIterator();
|
||||
for (var i = 0; i < 64; ++i)
|
||||
{
|
||||
Assert.True(head.Put((byte)i), $"Fail to put data at {i}.");
|
||||
Assert.True(head.Put((byte) i), $"Fail to put data at {i}.");
|
||||
}
|
||||
|
||||
// Can't put anything by the end
|
||||
|
|
@ -188,6 +188,112 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PeekArraySegment()
|
||||
{
|
||||
// Arrange
|
||||
var block = _pool.Lease();
|
||||
var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length);
|
||||
block.End += bytes.Length;
|
||||
var scan = block.GetIterator();
|
||||
var originalIndex = scan.Index;
|
||||
|
||||
// Act
|
||||
var result = scan.PeekArraySegment();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(new byte[] {0, 1, 2, 3, 4, 5, 6, 7}, result);
|
||||
Assert.Equal(originalIndex, scan.Index);
|
||||
|
||||
_pool.Return(block);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PeekArraySegmentOnDefaultIteratorReturnsDefaultArraySegment()
|
||||
{
|
||||
// Assert.Equals doesn't work since xunit tries to access the underlying array.
|
||||
Assert.True(default(ArraySegment<byte>).Equals(default(MemoryPoolIterator).PeekArraySegment()));
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PeekArraySegmentAtEndOfDataReturnsDefaultArraySegment()
|
||||
{
|
||||
// Arrange
|
||||
var block = _pool.Lease();
|
||||
var bytes = new byte[] {0, 1, 2, 3, 4, 5, 6, 7};
|
||||
Buffer.BlockCopy(bytes, 0, block.Array, block.Start, bytes.Length);
|
||||
block.End += bytes.Length;
|
||||
block.Start = block.End;
|
||||
|
||||
var scan = block.GetIterator();
|
||||
|
||||
// Act
|
||||
var result = scan.PeekArraySegment();
|
||||
|
||||
// Assert
|
||||
// Assert.Equals doesn't work since xunit tries to access the underlying array.
|
||||
Assert.True(default(ArraySegment<byte>).Equals(result));
|
||||
|
||||
_pool.Return(block);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PeekArraySegmentAtBlockBoundary()
|
||||
{
|
||||
// Arrange
|
||||
var firstBlock = _pool.Lease();
|
||||
var lastBlock = _pool.Lease();
|
||||
|
||||
var firstBytes = new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 };
|
||||
var lastBytes = new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 };
|
||||
|
||||
Buffer.BlockCopy(firstBytes, 0, firstBlock.Array, firstBlock.Start, firstBytes.Length);
|
||||
firstBlock.End += lastBytes.Length;
|
||||
|
||||
firstBlock.Next = lastBlock;
|
||||
Buffer.BlockCopy(lastBytes, 0, lastBlock.Array, lastBlock.Start, lastBytes.Length);
|
||||
lastBlock.End += lastBytes.Length;
|
||||
|
||||
var scan = firstBlock.GetIterator();
|
||||
var originalIndex = scan.Index;
|
||||
var originalBlock = scan.Block;
|
||||
|
||||
// Act
|
||||
var result = scan.PeekArraySegment();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7 }, result);
|
||||
Assert.Equal(originalBlock, scan.Block);
|
||||
Assert.Equal(originalIndex, scan.Index);
|
||||
|
||||
// Act
|
||||
// Advance past the data in the first block
|
||||
scan.Skip(8);
|
||||
result = scan.PeekArraySegment();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }, result);
|
||||
Assert.Equal(originalBlock, scan.Block);
|
||||
Assert.Equal(originalIndex + 8, scan.Index);
|
||||
|
||||
// Act
|
||||
// Add anther empty block between the first and last block
|
||||
var middleBlock = _pool.Lease();
|
||||
firstBlock.Next = middleBlock;
|
||||
middleBlock.Next = lastBlock;
|
||||
result = scan.PeekArraySegment();
|
||||
|
||||
// Assert
|
||||
Assert.Equal(new byte[] { 8, 9, 10, 11, 12, 13, 14, 15 }, result);
|
||||
Assert.Equal(originalBlock, scan.Block);
|
||||
Assert.Equal(originalIndex + 8, scan.Index);
|
||||
|
||||
_pool.Return(firstBlock);
|
||||
_pool.Return(middleBlock);
|
||||
_pool.Return(lastBlock);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public void PeekLong()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -2,10 +2,18 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Server.Kestrel.Internal.Http;
|
||||
using Microsoft.AspNetCore.Server.KestrelTests.TestHelpers;
|
||||
using Microsoft.Extensions.Internal;
|
||||
using Moq;
|
||||
using Xunit;
|
||||
using Xunit.Sdk;
|
||||
|
||||
namespace Microsoft.AspNetCore.Server.KestrelTests
|
||||
{
|
||||
|
|
@ -68,23 +76,119 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
// Input needs to be greater than 4032 bytes to allocate a block not backed by a slab.
|
||||
var largeInput = new string('a', 8192);
|
||||
|
||||
input.Add(largeInput, true);
|
||||
input.Add(largeInput);
|
||||
// Add a smaller block to the end so that SocketInput attempts to return the large
|
||||
// block to the memory pool.
|
||||
input.Add("Hello", true);
|
||||
input.Add("Hello", fin: true);
|
||||
|
||||
var readBuffer = new byte[8192];
|
||||
var ms = new MemoryStream();
|
||||
|
||||
var count1 = await stream.ReadAsync(readBuffer, 0, 8192);
|
||||
Assert.Equal(8192, count1);
|
||||
AssertASCII(largeInput, new ArraySegment<byte>(readBuffer, 0, 8192));
|
||||
await stream.CopyToAsync(ms);
|
||||
var requestArray = ms.ToArray();
|
||||
Assert.Equal(8197, requestArray.Length);
|
||||
AssertASCII(largeInput + "Hello", new ArraySegment<byte>(requestArray, 0, requestArray.Length));
|
||||
|
||||
var count2 = await stream.ReadAsync(readBuffer, 0, 8192);
|
||||
Assert.Equal(5, count2);
|
||||
AssertASCII("Hello", new ArraySegment<byte>(readBuffer, 0, 5));
|
||||
var count = await stream.ReadAsync(new byte[1], 0, 1);
|
||||
Assert.Equal(0, count);
|
||||
}
|
||||
}
|
||||
|
||||
var count3 = await stream.ReadAsync(readBuffer, 0, 8192);
|
||||
Assert.Equal(0, count3);
|
||||
public static IEnumerable<object[]> StreamData => new[]
|
||||
{
|
||||
new object[] { new ThrowOnWriteSynchronousStream() },
|
||||
new object[] { new ThrowOnWriteAsynchronousStream() },
|
||||
};
|
||||
|
||||
public static IEnumerable<object[]> RequestData => new[]
|
||||
{
|
||||
// Remaining Data
|
||||
new object[] { new FrameRequestHeaders { HeaderConnection = "close" }, new[] { "Hello ", "World!" } },
|
||||
// Content-Length
|
||||
new object[] { new FrameRequestHeaders { HeaderContentLength = "12" }, new[] { "Hello ", "World!" } },
|
||||
// Chunked
|
||||
new object[] { new FrameRequestHeaders { HeaderTransferEncoding = "chunked" }, new[] { "6\r\nHello \r\n", "6\r\nWorld!\r\n0\r\n\r\n" } },
|
||||
};
|
||||
|
||||
public static IEnumerable<object[]> CombinedData =>
|
||||
from stream in StreamData
|
||||
from request in RequestData
|
||||
select new[] { stream[0], request[0], request[1] };
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(RequestData))]
|
||||
public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, string[] data)
|
||||
{
|
||||
var writeCount = 0;
|
||||
var writeTcs = new TaskCompletionSource<byte[]>();
|
||||
var mockDestination = new Mock<Stream>();
|
||||
|
||||
mockDestination
|
||||
.Setup(m => m.WriteAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), CancellationToken.None))
|
||||
.Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
|
||||
{
|
||||
writeTcs.SetResult(buffer);
|
||||
writeCount++;
|
||||
})
|
||||
.Returns(TaskCache.CompletedTask);
|
||||
|
||||
using (var input = new TestInput())
|
||||
{
|
||||
var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext);
|
||||
|
||||
var copyToAsyncTask = body.CopyToAsync(mockDestination.Object);
|
||||
|
||||
// The block returned by IncomingStart always has at least 2048 available bytes,
|
||||
// so no need to bounds check in this test.
|
||||
var socketInput = input.FrameContext.SocketInput;
|
||||
var bytes = Encoding.ASCII.GetBytes(data[0]);
|
||||
var block = socketInput.IncomingStart();
|
||||
Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length);
|
||||
socketInput.IncomingComplete(bytes.Length, null);
|
||||
|
||||
// Verify the block passed to WriteAsync is the same one incoming data was written into.
|
||||
Assert.Same(block.Array, await writeTcs.Task);
|
||||
|
||||
writeTcs = new TaskCompletionSource<byte[]>();
|
||||
bytes = Encoding.ASCII.GetBytes(data[1]);
|
||||
block = socketInput.IncomingStart();
|
||||
Buffer.BlockCopy(bytes, 0, block.Array, block.End, bytes.Length);
|
||||
socketInput.IncomingComplete(bytes.Length, null);
|
||||
|
||||
Assert.Same(block.Array, await writeTcs.Task);
|
||||
|
||||
if (headers.HeaderConnection == "close")
|
||||
{
|
||||
socketInput.IncomingFin();
|
||||
}
|
||||
|
||||
await copyToAsyncTask;
|
||||
|
||||
Assert.Equal(2, writeCount);
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(CombinedData))]
|
||||
public async Task CopyToAsyncAdvancesRequestStreamWhenDestinationWriteAsyncThrows(Stream writeStream, FrameRequestHeaders headers, string[] data)
|
||||
{
|
||||
using (var input = new TestInput())
|
||||
{
|
||||
var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext);
|
||||
|
||||
input.Add(data[0]);
|
||||
|
||||
await Assert.ThrowsAsync<XunitException>(() => body.CopyToAsync(writeStream));
|
||||
|
||||
input.Add(data[1], fin: headers.HeaderConnection == "close");
|
||||
|
||||
// "Hello " should have been consumed
|
||||
var readBuffer = new byte[6];
|
||||
var count = await body.ReadAsync(new ArraySegment<byte>(readBuffer, 0, readBuffer.Length));
|
||||
Assert.Equal(6, count);
|
||||
AssertASCII("World!", new ArraySegment<byte>(readBuffer, 0, 6));
|
||||
|
||||
count = await body.ReadAsync(new ArraySegment<byte>(readBuffer, 0, readBuffer.Length));
|
||||
Assert.Equal(0, count);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -98,5 +202,84 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
Assert.Equal(bytes[index], actual.Array[actual.Offset + index]);
|
||||
}
|
||||
}
|
||||
|
||||
private class ThrowOnWriteSynchronousStream : Stream
|
||||
{
|
||||
public override void Flush()
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override long Seek(long offset, SeekOrigin origin)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override void SetLength(long value)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override void Write(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
throw new XunitException();
|
||||
}
|
||||
|
||||
public override bool CanRead { get; }
|
||||
public override bool CanSeek { get; }
|
||||
public override bool CanWrite => true;
|
||||
public override long Length { get; }
|
||||
public override long Position { get; set; }
|
||||
}
|
||||
|
||||
private class ThrowOnWriteAsynchronousStream : Stream
|
||||
{
|
||||
public override void Flush()
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override int Read(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override long Seek(long offset, SeekOrigin origin)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override void SetLength(long value)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override void Write(byte[] buffer, int offset, int count)
|
||||
{
|
||||
throw new NotImplementedException();
|
||||
}
|
||||
|
||||
public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
|
||||
{
|
||||
await Task.Delay(1);
|
||||
throw new XunitException();
|
||||
}
|
||||
|
||||
public override bool CanRead { get; }
|
||||
public override bool CanSeek { get; }
|
||||
public override bool CanWrite => true;
|
||||
public override long Length { get; }
|
||||
public override long Position { get; set; }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -19,7 +19,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
new TheoryData<Mock<IBufferSizeControl>>() { new Mock<IBufferSizeControl>(), null };
|
||||
|
||||
[Theory]
|
||||
[MemberData("MockBufferSizeControlData")]
|
||||
[MemberData(nameof(MockBufferSizeControlData))]
|
||||
public void IncomingDataCallsBufferSizeControlAdd(Mock<IBufferSizeControl> mockBufferSizeControl)
|
||||
{
|
||||
using (var memory = new MemoryPool())
|
||||
|
|
@ -31,7 +31,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData("MockBufferSizeControlData")]
|
||||
[MemberData(nameof(MockBufferSizeControlData))]
|
||||
public void IncomingCompleteCallsBufferSizeControlAdd(Mock<IBufferSizeControl> mockBufferSizeControl)
|
||||
{
|
||||
using (var memory = new MemoryPool())
|
||||
|
|
@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData("MockBufferSizeControlData")]
|
||||
[MemberData(nameof(MockBufferSizeControlData))]
|
||||
public void ConsumingCompleteCallsBufferSizeControlSubtract(Mock<IBufferSizeControl> mockBufferSizeControl)
|
||||
{
|
||||
using (var kestrelEngine = new KestrelEngine(new MockLibuv(), new TestServiceContext()))
|
||||
|
|
@ -154,6 +154,80 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task PeekAsyncRereturnsTheSameData()
|
||||
{
|
||||
using (var memory = new MemoryPool())
|
||||
using (var socketInput = new SocketInput(memory, new SynchronousThreadPool()))
|
||||
{
|
||||
socketInput.IncomingData(new byte[5], 0, 5);
|
||||
|
||||
Assert.True(socketInput.IsCompleted);
|
||||
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
|
||||
|
||||
// The same 5 bytes will be returned again since it hasn't been consumed.
|
||||
Assert.True(socketInput.IsCompleted);
|
||||
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
|
||||
|
||||
var scan = socketInput.ConsumingStart();
|
||||
scan.Skip(3);
|
||||
socketInput.ConsumingComplete(scan, scan);
|
||||
|
||||
// The remaining 2 unconsumed bytes will be returned.
|
||||
Assert.True(socketInput.IsCompleted);
|
||||
Assert.Equal(2, (await socketInput.PeekAsync()).Count);
|
||||
|
||||
scan = socketInput.ConsumingStart();
|
||||
scan.Skip(2);
|
||||
socketInput.ConsumingComplete(scan, scan);
|
||||
|
||||
// Everything has been consume so socketInput is no longer in the completed state
|
||||
Assert.False(socketInput.IsCompleted);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteAwaitingDoesNotCauseZeroLengthRead()
|
||||
{
|
||||
using (var memory = new MemoryPool())
|
||||
using (var socketInput = new SocketInput(memory, new SynchronousThreadPool()))
|
||||
{
|
||||
var readBuffer = new byte[20];
|
||||
|
||||
socketInput.IncomingData(new byte[5], 0, 5);
|
||||
Assert.Equal(5, await socketInput.ReadAsync(readBuffer, 0, 20));
|
||||
|
||||
var readTask = socketInput.ReadAsync(readBuffer, 0, 20);
|
||||
socketInput.CompleteAwaiting();
|
||||
Assert.False(readTask.IsCompleted);
|
||||
|
||||
socketInput.IncomingData(new byte[5], 0, 5);
|
||||
Assert.Equal(5, await readTask);
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task CompleteAwaitingDoesNotCauseZeroLengthPeek()
|
||||
{
|
||||
using (var memory = new MemoryPool())
|
||||
using (var socketInput = new SocketInput(memory, new SynchronousThreadPool()))
|
||||
{
|
||||
socketInput.IncomingData(new byte[5], 0, 5);
|
||||
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
|
||||
|
||||
var scan = socketInput.ConsumingStart();
|
||||
scan.Skip(5);
|
||||
socketInput.ConsumingComplete(scan, scan);
|
||||
|
||||
var peekTask = socketInput.PeekAsync();
|
||||
socketInput.CompleteAwaiting();
|
||||
Assert.False(peekTask.IsCompleted);
|
||||
|
||||
socketInput.IncomingData(new byte[5], 0, 5);
|
||||
Assert.Equal(5, (await socketInput.PeekAsync()).Count);
|
||||
}
|
||||
}
|
||||
|
||||
private static void TestConcurrentFaultedTask(Task t)
|
||||
{
|
||||
Assert.True(t.IsFaulted);
|
||||
|
|
|
|||
Loading…
Reference in New Issue