diff --git a/KestrelHttpServer.sln b/KestrelHttpServer.sln index 210365479a..2b528d0f1b 100644 --- a/KestrelHttpServer.sln +++ b/KestrelHttpServer.sln @@ -24,6 +24,7 @@ Project("{2150E333-8FDC-42A3-9474-1A3956D46DE8}") = "shared", "shared", "{0EF2AC test\shared\HttpParsingData.cs = test\shared\HttpParsingData.cs test\shared\KestrelTestLoggerFactory.cs = test\shared\KestrelTestLoggerFactory.cs test\shared\LifetimeNotImplemented.cs = test\shared\LifetimeNotImplemented.cs + test\shared\MockConnectionInformation.cs = test\shared\MockConnectionInformation.cs test\shared\MockFrameControl.cs = test\shared\MockFrameControl.cs test\shared\MockLogger.cs = test\shared\MockLogger.cs test\shared\MockSystemClock.cs = test\shared\MockSystemClock.cs diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs index 5ee289501e..0387c0ce8e 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/Frame.cs @@ -98,8 +98,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _requestHeadersTimeoutTicks = ServerOptions.Limits.RequestHeadersTimeout.Ticks; Output = new OutputProducer(frameContext.Output, frameContext.ConnectionId, frameContext.ServiceContext.Log); + RequestBodyPipe = CreateRequestBodyPipe(); } + public IPipe RequestBodyPipe { get; } + public ServiceContext ServiceContext => _frameContext.ServiceContext; public IConnectionInformation ConnectionInformation => _frameContext.ConnectionInformation; @@ -1365,6 +1368,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } + private IPipe CreateRequestBodyPipe() + => ConnectionInformation.PipeFactory.Create(new PipeOptions + { + ReaderScheduler = ServiceContext.ThreadPool, + WriterScheduler = ConnectionInformation.InputWriterScheduler, + MaximumSizeHigh = ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0, + MaximumSizeLow = ServiceContext.ServerOptions.Limits.MaxRequestBufferSize ?? 0 + }); + private enum HttpRequestTarget { Unknown = -1, diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs index c8ea98003f..4fdda6d9f1 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/FrameOfT.cs @@ -93,6 +93,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http InitializeStreams(messageBody); + var messageBodyTask = messageBody.StartAsync(); + var context = _application.CreateContext(this); try { @@ -156,12 +158,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http await ProduceEnd(); } - if (_keepAlive) + if (!_keepAlive) { - // Finish reading the request body in case the app did not. - await messageBody.Consume(); + messageBody.Cancel(); } + // An upgraded request has no defined request body length. + // Cancel any pending read so the read loop ends. + if (_upgrade) + { + Input.CancelPendingRead(); + } + + // Finish reading the request body in case the app did not. + await messageBody.ConsumeAsync(); + if (!HasResponseStarted) { await ProduceEnd(); @@ -189,6 +200,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http // to ensure InitializeStreams has been called. StopStreams(); } + + // At this point both the request body pipe reader and writer should be completed. + await messageBodyTask; + + // ForZeroContentLength does not complete the reader nor the writer + if (_keepAlive && !messageBody.IsEmpty) + { + RequestBodyPipe.Reset(); + } } if (!_keepAlive) @@ -225,6 +245,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http try { Input.Complete(); + // If _requestAborted is set, the connection has already been closed. if (Volatile.Read(ref _requestAborted) == 0) { diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs index ca26fe9095..6ab87c8096 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/MessageBody.cs @@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http private readonly Frame _context; private bool _send100Continue = true; + private volatile bool _canceled; protected MessageBody(Frame context) { @@ -30,179 +31,178 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http public bool RequestUpgrade { get; protected set; } - public Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) + public virtual bool IsEmpty => false; + + public virtual async Task StartAsync() { - var task = PeekAsync(cancellationToken); + Exception error = null; - if (!task.IsCompleted) - { - TryProduceContinue(); - - // Incomplete Task await result - return ReadAsyncAwaited(task, buffer); - } - else - { - var readSegment = task.Result; - var consumed = CopyReadSegment(readSegment, buffer); - - return consumed == 0 ? TaskCache.DefaultCompletedTask : Task.FromResult(consumed); - } - } - - private async Task ReadAsyncAwaited(ValueTask> currentTask, ArraySegment buffer) - { - return CopyReadSegment(await currentTask, buffer); - } - - private int CopyReadSegment(ArraySegment readSegment, ArraySegment buffer) - { - var consumed = Math.Min(readSegment.Count, buffer.Count); - - if (consumed != 0) - { - Buffer.BlockCopy(readSegment.Array, readSegment.Offset, buffer.Array, buffer.Offset, consumed); - ConsumedBytes(consumed); - } - - return consumed; - } - - public Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) - { - var peekTask = PeekAsync(cancellationToken); - - while (peekTask.IsCompleted) - { - // ValueTask uses .GetAwaiter().GetResult() if necessary - var segment = peekTask.Result; - - if (segment.Count == 0) - { - return TaskCache.CompletedTask; - } - - Task destinationTask; - try - { - destinationTask = destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); - } - catch - { - ConsumedBytes(segment.Count); - throw; - } - - if (!destinationTask.IsCompleted) - { - return CopyToAsyncDestinationAwaited(destinationTask, segment.Count, destination, cancellationToken); - } - - ConsumedBytes(segment.Count); - - // Surface errors if necessary - destinationTask.GetAwaiter().GetResult(); - - peekTask = PeekAsync(cancellationToken); - } - - TryProduceContinue(); - - return CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken); - } - - private async Task CopyToAsyncPeekAwaited( - ValueTask> peekTask, - Stream destination, - CancellationToken cancellationToken = default(CancellationToken)) - { - while (true) - { - var segment = await peekTask; - - if (segment.Count == 0) - { - return; - } - - try - { - await destination.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); - } - finally - { - ConsumedBytes(segment.Count); - } - - peekTask = PeekAsync(cancellationToken); - } - } - - private async Task CopyToAsyncDestinationAwaited( - Task destinationTask, - int bytesConsumed, - Stream destination, - CancellationToken cancellationToken = default(CancellationToken)) - { try { - await destinationTask; + while (true) + { + var awaitable = _context.Input.ReadAsync(); + + if (!awaitable.IsCompleted) + { + TryProduceContinue(); + } + + var result = await awaitable; + var readableBuffer = result.Buffer; + var consumed = readableBuffer.Start; + var examined = readableBuffer.End; + + try + { + if (_canceled) + { + break; + } + + if (!readableBuffer.IsEmpty) + { + var writableBuffer = _context.RequestBodyPipe.Writer.Alloc(1); + bool done; + + try + { + done = Read(readableBuffer, writableBuffer, out consumed, out examined); + } + finally + { + writableBuffer.Commit(); + } + + await writableBuffer.FlushAsync(); + + if (done) + { + break; + } + } + else if (result.IsCompleted) + { + _context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent); + } + } + finally + { + _context.Input.Advance(consumed, examined); + } + } + } + catch (Exception ex) + { + error = ex; } finally { - ConsumedBytes(bytesConsumed); + _context.RequestBodyPipe.Writer.Complete(error); } - - var peekTask = PeekAsync(cancellationToken); - - if (!peekTask.IsCompleted) - { - TryProduceContinue(); - } - - await CopyToAsyncPeekAwaited(peekTask, destination, cancellationToken); } - public Task Consume(CancellationToken cancellationToken = default(CancellationToken)) + public void Cancel() + { + _canceled = true; + } + + public virtual async Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) { while (true) { - var task = PeekAsync(cancellationToken); - if (!task.IsCompleted) - { - TryProduceContinue(); + var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.End; - // Incomplete Task await result - return ConsumeAwaited(task, cancellationToken); - } - else + try { - // ValueTask uses .GetAwaiter().GetResult() if necessary - if (task.Result.Count == 0) + if (!readableBuffer.IsEmpty) { - // Completed Task, end of stream - return TaskCache.CompletedTask; + var actual = Math.Min(readableBuffer.Length, buffer.Count); + var slice = readableBuffer.Slice(0, actual); + consumed = readableBuffer.Move(readableBuffer.Start, actual); + slice.CopyTo(buffer); + return actual; } - - ConsumedBytes(task.Result.Count); + else if (result.IsCompleted) + { + return 0; + } + } + finally + { + _context.RequestBodyPipe.Reader.Advance(consumed); } } } - private async Task ConsumeAwaited(ValueTask> currentTask, CancellationToken cancellationToken) + public virtual async Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) { while (true) { - var count = (await currentTask).Count; + var result = await _context.RequestBodyPipe.Reader.ReadAsync(); + var readableBuffer = result.Buffer; + var consumed = readableBuffer.End; - if (count == 0) + try { - // Completed Task, end of stream - return; + if (!readableBuffer.IsEmpty) + { + foreach (var memory in readableBuffer) + { + var array = memory.GetArray(); + await destination.WriteAsync(array.Array, array.Offset, array.Count, cancellationToken); + } + } + else if (result.IsCompleted) + { + return; + } } + finally + { + _context.RequestBodyPipe.Reader.Advance(consumed); + } + } + } - ConsumedBytes(count); - currentTask = PeekAsync(cancellationToken); + public virtual async Task ConsumeAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + Exception error = null; + + try + { + ReadResult result; + do + { + result = await _context.RequestBodyPipe.Reader.ReadAsync(); + _context.RequestBodyPipe.Reader.Advance(result.Buffer.End); + } while (!result.IsCompleted); + } + catch (Exception ex) + { + error = ex; + throw; + } + finally + { + _context.RequestBodyPipe.Reader.Complete(error); + } + } + + protected void Copy(ReadableBuffer readableBuffer, WritableBuffer writableBuffer) + { + if (readableBuffer.IsSingleSpan) + { + writableBuffer.Write(readableBuffer.First.Span); + } + else + { + foreach (var memory in readableBuffer) + { + writableBuffer.Write(memory.Span); + } } } @@ -215,20 +215,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } } - private void ConsumedBytes(int count) - { - var scan = _context.Input.ReadAsync().GetResult().Buffer; - var consumed = scan.Move(scan.Start, count); - _context.Input.Advance(consumed, consumed); - - OnConsumedBytes(count); - } - - protected abstract ValueTask> PeekAsync(CancellationToken cancellationToken); - - protected virtual void OnConsumedBytes(int count) - { - } + protected abstract bool Read(ReadableBuffer readableBuffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined); public static MessageBody For( HttpVersion httpVersion, @@ -316,9 +303,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http RequestUpgrade = true; } - protected override ValueTask> PeekAsync(CancellationToken cancellationToken) + protected override bool Read(ReadableBuffer readableBuffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined) { - return _context.Input.PeekAsync(); + Copy(readableBuffer, writableBuffer); + consumed = readableBuffer.End; + examined = readableBuffer.End; + return false; } } @@ -330,17 +320,31 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http RequestKeepAlive = keepAlive; } - protected override ValueTask> PeekAsync(CancellationToken cancellationToken) + public override bool IsEmpty => true; + + public override Task StartAsync() { - return new ValueTask>(); + return Task.CompletedTask; } - protected override void OnConsumedBytes(int count) + public override Task ReadAsync(ArraySegment buffer, CancellationToken cancellationToken = default(CancellationToken)) { - if (count > 0) - { - throw new InvalidDataException("Consuming non-existent data"); - } + return Task.FromResult(0); + } + + public override Task CopyToAsync(Stream destination, CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.CompletedTask; + } + + public override Task ConsumeAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + return Task.CompletedTask; + } + + protected override bool Read(ReadableBuffer readableBuffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined) + { + throw new NotImplementedException(); } } @@ -357,65 +361,22 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _inputLength = _contentLength; } - protected override ValueTask> PeekAsync(CancellationToken cancellationToken) + protected override bool Read(ReadableBuffer readableBuffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined) { - var limit = (int)Math.Min(_inputLength, int.MaxValue); - if (limit == 0) + if (_inputLength == 0) { - return new ValueTask>(); + throw new InvalidOperationException("Attempted to read from completed Content-Length request body."); } - var task = _context.Input.PeekAsync(); + var actual = (int)Math.Min(readableBuffer.Length, _inputLength); + _inputLength -= actual; - if (task.IsCompleted) - { - // .GetAwaiter().GetResult() done by ValueTask if needed - var actual = Math.Min(task.Result.Count, limit); + consumed = readableBuffer.Move(readableBuffer.Start, actual); + examined = consumed; - if (task.Result.Count == 0) - { - _context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent); - } + Copy(readableBuffer.Slice(0, actual), writableBuffer); - if (task.Result.Count < _inputLength) - { - return task; - } - else - { - var result = task.Result; - var part = new ArraySegment(result.Array, result.Offset, (int)_inputLength); - return new ValueTask>(part); - } - } - else - { - return new ValueTask>(PeekAsyncAwaited(task)); - } - } - - private async Task> PeekAsyncAwaited(ValueTask> task) - { - var segment = await task; - - if (segment.Count == 0) - { - _context.RejectRequest(RequestRejectionReason.UnexpectedEndOfRequestContent); - } - - if (segment.Count <= _inputLength) - { - return segment; - } - else - { - return new ArraySegment(segment.Array, segment.Offset, (int)_inputLength); - } - } - - protected override void OnConsumedBytes(int count) - { - _inputLength -= count; + return _inputLength == 0; } } @@ -441,188 +402,84 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http _requestHeaders = headers; } - protected override ValueTask> PeekAsync(CancellationToken cancellationToken) + protected override bool Read(ReadableBuffer readableBuffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined) { - return new ValueTask>(PeekStateMachineAsync()); - } + consumed = default(ReadCursor); + examined = default(ReadCursor); - protected override void OnConsumedBytes(int count) - { - _inputLength -= count; - } - - private async Task> PeekStateMachineAsync() - { while (_mode < Mode.Trailer) { - while (_mode == Mode.Prefix) + if (_mode == Mode.Prefix) { - var result = await _input.ReadAsync(); - var buffer = result.Buffer; - var consumed = default(ReadCursor); - var examined = default(ReadCursor); + ParseChunkedPrefix(readableBuffer, out consumed, out examined); - try + if (_mode == Mode.Prefix) { - ParseChunkedPrefix(buffer, out consumed, out examined); - } - finally - { - _input.Advance(consumed, examined); - } - - if (_mode != Mode.Prefix) - { - break; - } - else if (result.IsCompleted) - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); + return false; } + readableBuffer = readableBuffer.Slice(consumed); } - while (_mode == Mode.Extension) + if (_mode == Mode.Extension) { - var result = await _input.ReadAsync(); - var buffer = result.Buffer; - var consumed = default(ReadCursor); - var examined = default(ReadCursor); + ParseExtension(readableBuffer, out consumed, out examined); - try + if (_mode == Mode.Extension) { - ParseExtension(buffer, out consumed, out examined); - } - finally - { - _input.Advance(consumed, examined); - } - - if (_mode != Mode.Extension) - { - break; - } - else if (result.IsCompleted) - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); + return false; } + readableBuffer = readableBuffer.Slice(consumed); } - while (_mode == Mode.Data) + if (_mode == Mode.Data) { - var result = await _input.ReadAsync(); - var buffer = result.Buffer; - ArraySegment segment; - try + ReadChunkedData(readableBuffer, writableBuffer, out consumed, out examined); + + if (_mode == Mode.Data) { - segment = PeekChunkedData(buffer); - } - finally - { - _input.Advance(buffer.Start, buffer.Start); + return false; } - if (segment.Count != 0) - { - return segment; - } - else if (_mode != Mode.Data) - { - break; - } - else if (result.IsCompleted) - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); - } + readableBuffer = readableBuffer.Slice(consumed); } - while (_mode == Mode.Suffix) + if (_mode == Mode.Suffix) { - var result = await _input.ReadAsync(); - var buffer = result.Buffer; - var consumed = default(ReadCursor); - var examined = default(ReadCursor); + ParseChunkedSuffix(readableBuffer, out consumed, out examined); - try + if (_mode == Mode.Suffix) { - ParseChunkedSuffix(buffer, out consumed, out examined); - } - finally - { - _input.Advance(consumed, examined); + return false; } - if (_mode != Mode.Suffix) - { - break; - } - else if (result.IsCompleted) - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); - } + readableBuffer = readableBuffer.Slice(consumed); } } // Chunks finished, parse trailers - while (_mode == Mode.Trailer) + if (_mode == Mode.Trailer) { - var result = await _input.ReadAsync(); - var buffer = result.Buffer; - var consumed = default(ReadCursor); - var examined = default(ReadCursor); + ParseChunkedTrailer(readableBuffer, out consumed, out examined); - try + if (_mode == Mode.Trailer) { - ParseChunkedTrailer(buffer, out consumed, out examined); - } - finally - { - _input.Advance(consumed, examined); - } - - if (_mode != Mode.Trailer) - { - break; - } - else if (result.IsCompleted) - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); + return false; } + readableBuffer = readableBuffer.Slice(consumed); } if (_mode == Mode.TrailerHeaders) { - while (true) + if (_context.TakeMessageHeaders(readableBuffer, out consumed, out examined)) { - var result = await _input.ReadAsync(); - var buffer = result.Buffer; - - if (buffer.IsEmpty && result.IsCompleted) - { - _context.RejectRequest(RequestRejectionReason.ChunkedRequestIncomplete); - } - - var consumed = default(ReadCursor); - var examined = default(ReadCursor); - - try - { - if (_context.TakeMessageHeaders(buffer, out consumed, out examined)) - { - break; - } - } - finally - { - _input.Advance(consumed, examined); - } + _mode = Mode.Complete; } - _mode = Mode.Complete; } - return default(ArraySegment); + return _mode == Mode.Complete; } private void ParseChunkedPrefix(ReadableBuffer buffer, out ReadCursor consumed, out ReadCursor examined) @@ -728,24 +585,19 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http } while (_mode == Mode.Extension); } - private ArraySegment PeekChunkedData(ReadableBuffer buffer) + private void ReadChunkedData(ReadableBuffer buffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined) { + var actual = Math.Min(buffer.Length, _inputLength); + consumed = buffer.Move(buffer.Start, actual); + examined = consumed; + + Copy(buffer.Slice(0, actual), writableBuffer); + + _inputLength -= actual; + if (_inputLength == 0) { _mode = Mode.Suffix; - return default(ArraySegment); - } - var segment = buffer.First.GetArray(); - - int actual = Math.Min(segment.Count, _inputLength); - // Nothing is consumed yet. ConsumedBytes(int) will move the iterator. - if (actual == segment.Count) - { - return segment; - } - else - { - return new ArraySegment(segment.Array, segment.Offset, actual); } } @@ -760,12 +612,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http return; } - var sufixBuffer = buffer.Slice(0, 2); - var sufixSpan = sufixBuffer.ToSpan(); - if (sufixSpan[0] == '\r' && sufixSpan[1] == '\n') + var suffixBuffer = buffer.Slice(0, 2); + var suffixSpan = suffixBuffer.ToSpan(); + if (suffixSpan[0] == '\r' && suffixSpan[1] == '\n') { - consumed = sufixBuffer.End; - examined = sufixBuffer.End; + consumed = suffixBuffer.End; + examined = suffixBuffer.End; _mode = Mode.Prefix; } else diff --git a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/PipelineExtensions.cs b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/PipelineExtensions.cs index 1fe8107209..cdc6579650 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/PipelineExtensions.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel.Core/Internal/Http/PipelineExtensions.cs @@ -18,67 +18,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http [ThreadStatic] private static byte[] _numericBytesScratch; - public static ValueTask> PeekAsync(this IPipeReader pipelineReader) - { - var input = pipelineReader.ReadAsync(); - while (input.IsCompleted) - { - var result = input.GetResult(); - try - { - if (!result.Buffer.IsEmpty) - { - var segment = result.Buffer.First; - var data = segment.GetArray(); - - return new ValueTask>(data); - } - else if (result.IsCompleted) - { - return default(ValueTask>); - } - } - finally - { - pipelineReader.Advance(result.Buffer.Start, result.Buffer.IsEmpty - ? result.Buffer.End - : result.Buffer.Start); - } - input = pipelineReader.ReadAsync(); - } - - return new ValueTask>(pipelineReader.PeekAsyncAwaited(input)); - } - - private static async Task> PeekAsyncAwaited(this IPipeReader pipelineReader, ReadableBufferAwaitable readingTask) - { - while (true) - { - var result = await readingTask; - - try - { - if (!result.Buffer.IsEmpty) - { - var segment = result.Buffer.First; - return segment.GetArray(); - } - else if (result.IsCompleted) - { - return default(ArraySegment); - } - } - finally - { - pipelineReader.Advance(result.Buffer.Start, result.Buffer.IsEmpty - ? result.Buffer.End - : result.Buffer.Start); - } - - readingTask = pipelineReader.ReadAsync(); - } - } - [MethodImpl(MethodImplOptions.AggressiveInlining)] public static Span ToSpan(this ReadableBuffer buffer) { diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameResponseHeadersTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameResponseHeadersTests.cs index b7cf438e92..fabae69f65 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameResponseHeadersTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameResponseHeadersTests.cs @@ -7,6 +7,7 @@ using System.Globalization; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Internal.System.IO.Pipelines; using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Primitives; @@ -23,7 +24,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var frameContext = new FrameContext { ServiceContext = new TestServiceContext(), - ConnectionInformation = Mock.Of(), + ConnectionInformation = new MockConnectionInformation + { + PipeFactory = new PipeFactory() + }, TimeoutControl = null }; diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs index 1923a56d9d..ca5eb53008 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/FrameTests.cs @@ -62,7 +62,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests _frameContext = new FrameContext { ServiceContext = _serviceContext, - ConnectionInformation = Mock.Of(), + ConnectionInformation = new MockConnectionInformation + { + PipeFactory = _pipelineFactory + }, TimeoutControl = _timeoutControl.Object, Input = _input.Reader, Output = output diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs index 6b69166ab6..bf61a5df51 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/MessageBodyTests.cs @@ -9,10 +9,14 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; +using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; using Moq; using Xunit; +using Xunit.Abstractions; using Xunit.Sdk; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests @@ -22,7 +26,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [Theory] [InlineData(HttpVersion.Http10)] [InlineData(HttpVersion.Http11)] - public void CanReadFromContentLength(HttpVersion httpVersion) + public async Task CanReadFromContentLength(HttpVersion httpVersion) { using (var input = new TestInput()) { @@ -30,6 +34,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; @@ -40,6 +46,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await bodyTask; } } @@ -54,6 +62,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; @@ -64,11 +74,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await bodyTask; } } [Fact] - public void CanReadFromChunkedEncoding() + public async Task CanReadFromChunkedEncoding() { using (var input = new TestInput()) { @@ -76,6 +88,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("5\r\nHello\r\n"); var buffer = new byte[1024]; @@ -88,6 +102,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await bodyTask; } } @@ -100,6 +116,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("5\r\nHello\r\n"); var buffer = new byte[1024]; @@ -112,13 +130,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(0, count); + + await bodyTask; } } [Theory] [InlineData(HttpVersion.Http10)] [InlineData(HttpVersion.Http11)] - public void CanReadFromRemainingData(HttpVersion httpVersion) + public async Task CanReadFromRemainingData(HttpVersion httpVersion) { using (var input = new TestInput()) { @@ -126,6 +146,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; @@ -133,6 +155,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var count = stream.Read(buffer, 0, buffer.Length); Assert.Equal(5, count); AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + input.Fin(); + + await bodyTask; } } @@ -147,6 +173,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; @@ -154,13 +182,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var count = await stream.ReadAsync(buffer, 0, buffer.Length); Assert.Equal(5, count); AssertASCII("Hello", new ArraySegment(buffer, 0, count)); + + input.Fin(); + + await bodyTask; } } [Theory] [InlineData(HttpVersion.Http10)] [InlineData(HttpVersion.Http11)] - public void ReadFromNoContentLengthReturnsZero(HttpVersion httpVersion) + public async Task ReadFromNoContentLengthReturnsZero(HttpVersion httpVersion) { using (var input = new TestInput()) { @@ -168,10 +200,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; Assert.Equal(0, stream.Read(buffer, 0, buffer.Length)); + + await bodyTask; } } @@ -186,10 +222,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; Assert.Equal(0, await stream.ReadAsync(buffer, 0, buffer.Length)); + + await bodyTask; } } @@ -202,6 +242,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + // Input needs to be greater than 4032 bytes to allocate a block not backed by a slab. var largeInput = new string('a', 8192); @@ -217,8 +259,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(8197, requestArray.Length); AssertASCII(largeInput + "Hello", new ArraySegment(requestArray, 0, requestArray.Length)); - var count = await stream.ReadAsync(new byte[1], 0, 1); - Assert.Equal(0, count); + await bodyTask; } } @@ -267,6 +308,45 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } } + [Fact] + public async Task CopyToAsyncDoesNotCompletePipeReader() + { + using (var input = new TestInput()) + { + var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders { HeaderContentLength = "5" }, input.FrameContext); + var bodyTask = body.StartAsync(); + + input.Add("Hello"); + + using (var ms = new MemoryStream()) + { + await body.CopyToAsync(ms); + } + + Assert.Equal(0, await body.ReadAsync(new ArraySegment(new byte[1]))); + + await bodyTask; + } + } + + [Fact] + public async Task ConsumeAsyncCompletesPipeReader() + { + using (var input = new TestInput()) + { + var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders { HeaderContentLength = "5" }, input.FrameContext); + var bodyTask = body.StartAsync(); + + input.Add("Hello"); + + await body.ConsumeAsync(); + + await Assert.ThrowsAsync(async () => await body.ReadAsync(new ArraySegment(new byte[1]))); + + await bodyTask; + } + } + public static IEnumerable StreamData => new[] { new object[] { new ThrowOnWriteSynchronousStream() }, @@ -306,14 +386,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests using (var input = new TestInput()) { var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext); + var bodyTask = body.StartAsync(); var copyToAsyncTask = body.CopyToAsync(mockDestination.Object); // The block returned by IncomingStart always has at least 2048 available bytes, // so no need to bounds check in this test. - var socketInput = input.Pipe; var bytes = Encoding.ASCII.GetBytes(data[0]); - var buffer = socketInput.Writer.Alloc(2048); + var buffer = input.Pipe.Writer.Alloc(2048); ArraySegment block; Assert.True(buffer.Buffer.TryGetArray(out block)); Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); @@ -325,7 +405,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests writeTcs = new TaskCompletionSource(); bytes = Encoding.ASCII.GetBytes(data[1]); - buffer = socketInput.Writer.Alloc(2048); + buffer = input.Pipe.Writer.Alloc(2048); Assert.True(buffer.Buffer.TryGetArray(out block)); Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); buffer.Advance(bytes.Length); @@ -335,37 +415,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests if (headers.HeaderConnection == "close") { - socketInput.Writer.Complete(); + input.Pipe.Writer.Complete(); } await copyToAsyncTask; Assert.Equal(2, writeCount); - } - } - [Theory] - [MemberData(nameof(CombinedData))] - public async Task CopyToAsyncAdvancesRequestStreamWhenDestinationWriteAsyncThrows(Stream writeStream, FrameRequestHeaders headers, string[] data) - { - using (var input = new TestInput()) - { - var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext); - - input.Add(data[0]); - - await Assert.ThrowsAsync(() => body.CopyToAsync(writeStream)); - - input.Add(data[1]); - - // "Hello " should have been consumed - var readBuffer = new byte[6]; - var count = await body.ReadAsync(new ArraySegment(readBuffer, 0, readBuffer.Length)); - Assert.Equal(6, count); - AssertASCII("World!", new ArraySegment(readBuffer, 0, 6)); - - count = await body.ReadAsync(new ArraySegment(readBuffer, 0, readBuffer.Length)); - Assert.Equal(0, count); + await bodyTask; } } @@ -374,7 +431,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests [InlineData("Keep-Alive, Upgrade")] [InlineData("upgrade, keep-alive")] [InlineData("Upgrade, Keep-Alive")] - public void ConnectionUpgradeKeepAlive(string headerConnection) + public async Task ConnectionUpgradeKeepAlive(string headerConnection) { using (var input = new TestInput()) { @@ -382,11 +439,74 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var stream = new FrameRequestStream(); stream.StartAcceptingReads(body); + var bodyTask = body.StartAsync(); + input.Add("Hello"); var buffer = new byte[1024]; Assert.Equal(5, stream.Read(buffer, 0, buffer.Length)); AssertASCII("Hello", new ArraySegment(buffer, 0, 5)); + + input.Fin(); + + await bodyTask; + } + } + + [Fact] + public async Task StartAsyncDoesNotReturnAfterCancelingInput() + { + using (var input = new TestInput()) + { + var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders { HeaderContentLength = "2" }, input.FrameContext); + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(body); + + var bodyTask = body.StartAsync(); + + // Add some input and consume it to ensure StartAsync is in the loop + input.Add("a"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + input.Pipe.Reader.CancelPendingRead(); + + // Add more input and verify is read + input.Add("b"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + // All input was read, body task should complete + await bodyTask; + } + } + + [Fact] + public async Task StartAsyncReturnsAfterCanceling() + { + using (var input = new TestInput()) + { + var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders { HeaderContentLength = "2" }, input.FrameContext); + var stream = new FrameRequestStream(); + stream.StartAcceptingReads(body); + + var bodyTask = body.StartAsync(); + + // Add some input and consume it to ensure StartAsync is in the loop + input.Add("a"); + Assert.Equal(1, await stream.ReadAsync(new byte[1], 0, 1)); + + body.Cancel(); + + // Add some more data. Checking for cancelation and exiting the loop + // should take priority over reading this data. + input.Add("b"); + + // Unblock the loop + input.Pipe.Reader.CancelPendingRead(); + + await bodyTask.TimeoutAfter(TimeSpan.FromSeconds(10)); + + // There shouldn't be any additional data available + Assert.Equal(0, await stream.ReadAsync(new byte[1], 0, 1)); } } @@ -480,4 +600,4 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public override long Position { get; set; } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/OutputProducerTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/OutputProducerTests.cs index 93c90b10df..0298bbff95 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/OutputProducerTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/OutputProducerTests.cs @@ -54,7 +54,15 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests { var pipe = _pipeFactory.Create(pipeOptions); var serviceContext = new TestServiceContext(); - var frame = new Frame(null, new FrameContext { ServiceContext = serviceContext }); + var frameContext = new FrameContext + { + ServiceContext = serviceContext, + ConnectionInformation = new MockConnectionInformation + { + PipeFactory = _pipeFactory + } + }; + var frame = new Frame(null, frameContext); var socketOutput = new OutputProducer(pipe, "0", serviceContext.Log); return socketOutput; diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs index d9f8854a14..95eae94658 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/StreamsTests.cs @@ -6,6 +6,7 @@ using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Internal.System.IO.Pipelines; using Moq; using Xunit; @@ -41,7 +42,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); Assert.Same(ex, - await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); Assert.Same(ex, await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); @@ -64,7 +65,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(CoreStrings.ResponseStreamWasUpgraded, writeEx.Message); Assert.Same(ex, - await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); + await Assert.ThrowsAsync(() => request.ReadAsync(new byte[1], 0, 1))); Assert.Same(ex, await Assert.ThrowsAsync(() => upgrade.ReadAsync(new byte[1], 0, 1))); @@ -80,9 +81,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests RequestUpgrade = upgradeable; } - protected override ValueTask> PeekAsync(CancellationToken cancellationToken) + protected override bool Read(ReadableBuffer readableBuffer, WritableBuffer writableBuffer, out ReadCursor consumed, out ReadCursor examined) { - return new ValueTask>(new ArraySegment(new byte[1])); + consumed = default(ReadCursor); + examined = default(ReadCursor); + return true; } } } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/TestInput.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/TestInput.cs index 1a6ce35dd3..df90641045 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/TestInput.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Core.Tests/TestInput.cs @@ -25,19 +25,24 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests { _memoryPool = new MemoryPool(); _pipelineFactory = new PipeFactory(); - Pipe = _pipelineFactory.Create(); FrameContext = new Frame(null, new FrameContext { ServiceContext = new TestServiceContext(), - Input = Pipe.Reader + Input = Pipe.Reader, + ConnectionInformation = new MockConnectionInformation + { + PipeFactory = _pipelineFactory + } }); FrameContext.FrameControl = this; } public IPipe Pipe { get; } + public PipeFactory PipeFactory => _pipelineFactory; + public Frame FrameContext { get; set; } public void Add(string text) @@ -46,6 +51,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Pipe.Writer.WriteAsync(data).Wait(); } + public void Fin() + { + Pipe.Writer.Complete(); + } + public void ProduceContinue() { } diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ChunkedRequestTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ChunkedRequestTests.cs index c6e8adbdb2..ce0f9bf2df 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ChunkedRequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/ChunkedRequestTests.cs @@ -72,9 +72,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "5", "Hello", "6", " World", "0", - "", - ""); - await connection.ReceiveEnd( + "", + ""); + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", "Connection: close", $"Date: {testContext.DateHeaderValue}", @@ -103,7 +103,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "5", "Hello", "6", " World", "0", - "", + "", "POST / HTTP/1.0", "Content-Length: 7", "", @@ -115,7 +115,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "Content-Length: 11", "", "Hello World"); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", "Connection: close", $"Date: {testContext.DateHeaderValue}", @@ -185,7 +185,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests [MemberData(nameof(ConnectionAdapterData))] public async Task TrailingHeadersAreParsed(ListenOptions listenOptions) { - var testContext = new TestServiceContext(); var requestCount = 10; var requestsReceived = 0; @@ -196,8 +195,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests var buffer = new byte[200]; - Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"])); - while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) { ;// read to end @@ -217,11 +214,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests response.Headers["Content-Length"] = new[] { "11" }; await response.Body.WriteAsync(Encoding.ASCII.GetBytes("Hello World"), 0, 11); - }, testContext, listenOptions)) + }, new TestServiceContext(), listenOptions)) { var response = string.Join("\r\n", new string[] { "HTTP/1.1 200 OK", - $"Date: {testContext.DateHeaderValue}", + $"Date: {server.Context.DateHeaderValue}", "Content-Length: 11", "", "Hello World"}); @@ -372,8 +369,6 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests var buffer = new byte[200]; - Assert.True(string.IsNullOrEmpty(request.Headers["X-Trailer-Header"])); - while (await request.Body.ReadAsync(buffer, 0, buffer.Length) != 0) { ;// read to end diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/DefaultHeaderTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/DefaultHeaderTests.cs index c9c1a94ae6..167019d218 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/DefaultHeaderTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/DefaultHeaderTests.cs @@ -30,7 +30,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests "", ""); - await connection.ReceiveEnd( + await connection.ReceiveForcedEnd( "HTTP/1.1 200 OK", $"Date: {testContext.DateHeaderValue}", "Server: Kestrel", diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs index 0717a67a2a..0fb5db7212 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestTests.cs @@ -1270,6 +1270,153 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Fact] + public async Task ServerConsumesKeepAliveContentLengthRequest() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(context => Task.CompletedTask)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "hello"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + // If the server consumed the previous request properly, the + // next request should be successful + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "world"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task ServerConsumesKeepAliveChunkedRequest() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(context => Task.CompletedTask)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Transfer-Encoding: chunked", + "", + "5", + "hello", + "5", + "world", + "0", + "Trailer: value", + "", + ""); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + // If the server consumed the previous request properly, the + // next request should be successful + await connection.Send( + "POST / HTTP/1.1", + "Host:", + "Content-Length: 5", + "", + "world"); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task NonKeepAliveRequestNotConsumedByAppCompletes() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(context => Task.CompletedTask)) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "POST / HTTP/1.0", + "Host:", + "Content-Length: 5", + "", + "hello"); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 200 OK", + "Connection: close", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + } + } + + [Fact] + public async Task UpgradedRequestNotConsumedByAppCompletes() + { + // The app doesn't read the request body, so it should be consumed by the server + using (var server = new TestServer(async context => + { + var upgradeFeature = context.Features.Get(); + var duplexStream = await upgradeFeature.UpgradeAsync(); + + var response = Encoding.ASCII.GetBytes("goodbye"); + await duplexStream.WriteAsync(response, 0, response.Length); + })) + { + using (var connection = server.CreateConnection()) + { + await connection.SendAll( + "GET / HTTP/1.1", + "Host:", + "Connection: upgrade", + "", + "hello"); + + await connection.ReceiveForcedEnd( + "HTTP/1.1 101 Switching Protocols", + "Connection: Upgrade", + $"Date: {server.Context.DateHeaderValue}", + "", + "goodbye"); + } + } + } + private async Task TestRemoteIPAddress(string registerAddress, string requestAddress, string expectAddress) { var builder = new WebHostBuilder() diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs index e24e667b7d..2d565d7168 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests/LibuvOutputConsumerTests.cs @@ -699,6 +699,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Libuv.Tests var frame = new Frame(null, new FrameContext { ServiceContext = serviceContext, + ConnectionInformation = new MockConnectionInformation + { + PipeFactory = _pipeFactory + }, TimeoutControl = Mock.Of(), Output = pipe }); diff --git a/test/shared/MockConnectionInformation.cs b/test/shared/MockConnectionInformation.cs new file mode 100644 index 0000000000..1a68548100 --- /dev/null +++ b/test/shared/MockConnectionInformation.cs @@ -0,0 +1,22 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Net; +using Microsoft.AspNetCore.Server.Kestrel.Internal.System.IO.Pipelines; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions; + +namespace Microsoft.AspNetCore.Testing +{ + public class MockConnectionInformation : IConnectionInformation + { + public IPEndPoint RemoteEndPoint { get; } + + public IPEndPoint LocalEndPoint { get; } + + public PipeFactory PipeFactory { get; set; } + + public IScheduler InputWriterScheduler { get; } + + public IScheduler OutputReaderScheduler { get; } + } +}