diff --git a/test/Kestrel.Core.Tests/MessageBodyTests.cs b/test/Kestrel.Core.Tests/MessageBodyTests.cs index 8a563474fe..e9647cbee4 100644 --- a/test/Kestrel.Core.Tests/MessageBodyTests.cs +++ b/test/Kestrel.Core.Tests/MessageBodyTests.cs @@ -2,9 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; using System.IO; -using System.Linq; +using System.IO.Pipelines; using System.Runtime.InteropServices; using System.Text; using System.Threading; @@ -13,6 +12,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; +using Microsoft.AspNetCore.Server.Kestrel.Transport.Abstractions.Internal; using Microsoft.AspNetCore.Testing; using Moq; using Xunit; @@ -401,81 +401,77 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests } } - public static IEnumerable StreamData => new[] - { - new object[] { new ThrowOnWriteSynchronousStream() }, - new object[] { new ThrowOnWriteAsynchronousStream() }, - }; - - public static IEnumerable RequestData => new[] - { - // Content-Length - new object[] { new HttpRequestHeaders { HeaderContentLength = "12" }, new[] { "Hello ", "World!" } }, - // Chunked - new object[] { new HttpRequestHeaders { HeaderTransferEncoding = "chunked" }, new[] { "6\r\nHello \r\n", "6\r\nWorld!\r\n0\r\n\r\n" } }, - }; - - public static IEnumerable CombinedData => - from stream in StreamData - from request in RequestData - select new[] { stream[0], request[0], request[1] }; - - [Theory] - [MemberData(nameof(RequestData))] - public async Task CopyToAsyncDoesNotCopyBlocks(HttpRequestHeaders headers, string[] data) + [Fact] + public async Task CopyToAsyncDoesNotCopyBlocks() { var writeCount = 0; - var writeTcs = new TaskCompletionSource(); + var writeTcs = new TaskCompletionSource<(byte[], int, int)>(); var mockDestination = new Mock() { CallBase = true }; mockDestination .Setup(m => m.WriteAsync(It.IsAny(), It.IsAny(), It.IsAny(), CancellationToken.None)) .Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) => { - writeTcs.SetResult(buffer); + writeTcs.SetResult((buffer, offset, count)); writeCount++; }) .Returns(Task.CompletedTask); - using (var input = new TestInput()) + using (var memoryPool = KestrelMemoryPool.Create()) { - var body = Http1MessageBody.For(HttpVersion.Http11, headers, input.Http1Connection); + var options = new PipeOptions(pool: memoryPool, readerScheduler: PipeScheduler.Inline, writerScheduler: PipeScheduler.Inline, useSynchronizationContext: false); + var pair = DuplexPipe.CreateConnectionPair(options, options); + var transport = pair.Transport; + var application = pair.Application; + var http1ConnectionContext = new Http1ConnectionContext + { + ServiceContext = new TestServiceContext(), + ConnectionFeatures = new FeatureCollection(), + Application = application, + Transport = transport, + MemoryPool = memoryPool, + TimeoutControl = Mock.Of() + }; + var http1Connection = new Http1Connection(http1ConnectionContext) + { + HasStartedConsumingRequestBody = true + }; + + var headers = new HttpRequestHeaders { HeaderContentLength = "12" }; + var body = Http1MessageBody.For(HttpVersion.Http11, headers, http1Connection); var copyToAsyncTask = body.CopyToAsync(mockDestination.Object); - // The block returned by IncomingStart always has at least 2048 available bytes, - // so no need to bounds check in this test. - var bytes = Encoding.ASCII.GetBytes(data[0]); - var buffer = input.Application.Output.GetMemory(2028); - ArraySegment block; - Assert.True(MemoryMarshal.TryGetArray(buffer, out block)); - Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); - input.Application.Output.Advance(bytes.Length); - await input.Application.Output.FlushAsync(); + var bytes = Encoding.ASCII.GetBytes("Hello "); + var buffer = http1Connection.RequestBodyPipe.Writer.GetMemory(2048); + ArraySegment segment; + Assert.True(MemoryMarshal.TryGetArray(buffer, out segment)); + Buffer.BlockCopy(bytes, 0, segment.Array, segment.Offset, bytes.Length); + http1Connection.RequestBodyPipe.Writer.Advance(bytes.Length); + await http1Connection.RequestBodyPipe.Writer.FlushAsync(); - // Verify the block passed to WriteAsync is the same one incoming data was written into. - Assert.Same(block.Array, await writeTcs.Task); + // Verify the block passed to Stream.WriteAsync() is the same one incoming data was written into. + Assert.Equal((segment.Array, segment.Offset, bytes.Length), await writeTcs.Task); - writeTcs = new TaskCompletionSource(); - bytes = Encoding.ASCII.GetBytes(data[1]); - buffer = input.Application.Output.GetMemory(2048); - Assert.True(MemoryMarshal.TryGetArray(buffer, out block)); - Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length); - input.Application.Output.Advance(bytes.Length); - await input.Application.Output.FlushAsync(); + // Verify the again when GetMemory returns the tail space of the same block. + writeTcs = new TaskCompletionSource<(byte[], int, int)>(); + bytes = Encoding.ASCII.GetBytes("World!"); + buffer = http1Connection.RequestBodyPipe.Writer.GetMemory(2048); + Assert.True(MemoryMarshal.TryGetArray(buffer, out segment)); + Buffer.BlockCopy(bytes, 0, segment.Array, segment.Offset, bytes.Length); + http1Connection.RequestBodyPipe.Writer.Advance(bytes.Length); + await http1Connection.RequestBodyPipe.Writer.FlushAsync(); - Assert.Same(block.Array, await writeTcs.Task); + Assert.Equal((segment.Array, segment.Offset, bytes.Length), await writeTcs.Task); - if (headers.HeaderConnection == "close") - { - input.Application.Output.Complete(); - } + http1Connection.RequestBodyPipe.Writer.Complete(); await copyToAsyncTask; Assert.Equal(2, writeCount); - await body.StopAsync(); + // Don't call body.StopAsync() because PumpAsync() was never called. + http1Connection.RequestBodyPipe.Reader.Complete(); } }