// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Server.Kestrel;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Http;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Internal;
using Moq;
using Xunit;
using Xunit.Sdk;

namespace Microsoft.AspNetCore.Server.KestrelTests
{
    /// <summary>
    /// Tests for <see cref="MessageBody"/>: how request bodies are framed for HTTP/1.0 and HTTP/1.1
    /// requests (Content-Length, chunked transfer coding, no length information), and how
    /// <see cref="FrameRequestStream"/> reads and <c>CopyToAsync</c> consume the buffered input.
    /// </summary>
    public class MessageBodyTests
    {
        [Fact]
        public void Http10ConnectionClose()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders { HeaderContentLength = "5" }, input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                var count1 = stream.Read(buffer1, 0, 1024);
                // Check the byte count explicitly; AssertASCII alone only inspects buffer contents.
                Assert.Equal(5, count1);
                AssertASCII("Hello", new ArraySegment<byte>(buffer1, 0, count1));

                var buffer2 = new byte[1024];
                var count2 = stream.Read(buffer2, 0, 1024);
                Assert.Equal(0, count2);
            }
        }

        [Fact]
        public async Task Http10ConnectionCloseAsync()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders { HeaderContentLength = "5" }, input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                var count1 = await stream.ReadAsync(buffer1, 0, 1024);
                // Check the byte count explicitly; AssertASCII alone only inspects buffer contents.
                Assert.Equal(5, count1);
                AssertASCII("Hello", new ArraySegment<byte>(buffer1, 0, count1));

                var buffer2 = new byte[1024];
                var count2 = await stream.ReadAsync(buffer2, 0, 1024);
                Assert.Equal(0, count2);
            }
        }

        [Fact]
        public void Http10NoContentLength()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders(), input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                Assert.Equal(0, stream.Read(buffer1, 0, 1024));
            }
        }

        [Fact]
        public async Task Http10NoContentLengthAsync()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders(), input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                Assert.Equal(0, await stream.ReadAsync(buffer1, 0, 1024));
            }
        }

        [Fact]
        public void Http11NoContentLength()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders(), input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                Assert.Equal(0, stream.Read(buffer1, 0, 1024));
            }
        }

        [Fact]
        public async Task Http11NoContentLengthAsync()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders(), input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                Assert.Equal(0, await stream.ReadAsync(buffer1, 0, 1024));
            }
        }

        [Fact]
        public void Http11NoContentLengthConnectionClose()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders { HeaderConnection = "close" }, input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                Assert.Equal(0, stream.Read(buffer1, 0, 1024));
            }
        }

        [Fact]
        public async Task Http11NoContentLengthConnectionCloseAsync()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders { HeaderConnection = "close" }, input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer1 = new byte[1024];
                Assert.Equal(0, await stream.ReadAsync(buffer1, 0, 1024));
            }
        }

        [Fact]
        public async Task CanHandleLargeBlocks()
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders { HeaderContentLength = "8197" }, input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                // Input needs to be greater than 4032 bytes to allocate a block not backed by a slab.
                var largeInput = new string('a', 8192);

                input.Add(largeInput);
                // Add a smaller block to the end so that the input attempts to return the large
                // block to the memory pool.
                input.Add("Hello", fin: true);

                var ms = new MemoryStream();

                await stream.CopyToAsync(ms);
                var requestArray = ms.ToArray();
                Assert.Equal(8197, requestArray.Length);
                AssertASCII(largeInput + "Hello", new ArraySegment<byte>(requestArray, 0, requestArray.Length));

                var count = await stream.ReadAsync(new byte[1], 0, 1);
                Assert.Equal(0, count);
            }
        }

        [Fact]
        public void ForThrowsWhenFinalTransferCodingIsNotChunked()
        {
            using (var input = new TestInput())
            {
                var ex = Assert.Throws<BadHttpRequestException>(() =>
                    MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders { HeaderTransferEncoding = "chunked, not-chunked" }, input.FrameContext));

                Assert.Equal(StatusCodes.Status400BadRequest, ex.StatusCode);
                Assert.Equal("Final transfer coding is not \"chunked\": \"chunked, not-chunked\"", ex.Message);
            }
        }

        [Theory]
        [InlineData("POST")]
        [InlineData("PUT")]
        public void ForThrowsWhenMethodRequiresLengthButNoContentLengthOrTransferEncodingIsSet(string method)
        {
            using (var input = new TestInput())
            {
                input.FrameContext.Method = method;
                var ex = Assert.Throws<BadHttpRequestException>(() =>
                    MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders(), input.FrameContext));

                Assert.Equal(StatusCodes.Status411LengthRequired, ex.StatusCode);
                Assert.Equal($"{method} request contains no Content-Length or Transfer-Encoding header", ex.Message);
            }
        }

        [Theory]
        [InlineData("POST")]
        [InlineData("PUT")]
        public void ForThrowsWhenMethodRequiresLengthButNoContentLengthSetHttp10(string method)
        {
            using (var input = new TestInput())
            {
                input.FrameContext.Method = method;
                var ex = Assert.Throws<BadHttpRequestException>(() =>
                    MessageBody.For(HttpVersion.Http10, new FrameRequestHeaders(), input.FrameContext));

                Assert.Equal(StatusCodes.Status400BadRequest, ex.StatusCode);
                Assert.Equal($"{method} request contains no Content-Length header", ex.Message);
            }
        }

        /// <summary>
        /// Destination streams whose <c>WriteAsync</c> fails: one throws synchronously,
        /// the other throws after an asynchronous delay.
        /// </summary>
        public static IEnumerable<object[]> StreamData => new[]
        {
            new object[] { new ThrowOnWriteSynchronousStream() },
            new object[] { new ThrowOnWriteAsynchronousStream() },
        };

        /// <summary>
        /// The same 12-byte payload ("Hello World!") delivered in two pieces, framed either by
        /// Content-Length or by chunked transfer coding.
        /// </summary>
        public static IEnumerable<object[]> RequestData => new[]
        {
            // Content-Length
            new object[] { new FrameRequestHeaders { HeaderContentLength = "12" }, new[] { "Hello ", "World!" } },
            // Chunked
            new object[] { new FrameRequestHeaders { HeaderTransferEncoding = "chunked" }, new[] { "6\r\nHello \r\n", "6\r\nWorld!\r\n0\r\n\r\n" } },
        };

        /// <summary>
        /// Cross product of <see cref="StreamData"/> and <see cref="RequestData"/>:
        /// each row is { destination stream, request headers, data pieces }.
        /// </summary>
        public static IEnumerable<object[]> CombinedData =>
            from stream in StreamData
            from request in RequestData
            select new[] { stream[0], request[0], request[1] };

        [Theory]
        [MemberData(nameof(RequestData))]
        public async Task CopyToAsyncDoesNotCopyBlocks(FrameRequestHeaders headers, string[] data)
        {
            var writeCount = 0;
            var writeTcs = new TaskCompletionSource<byte[]>();
            var mockDestination = new Mock<Stream>();

            mockDestination
                .Setup(m => m.WriteAsync(It.IsAny<byte[]>(), It.IsAny<int>(), It.IsAny<int>(), CancellationToken.None))
                .Callback((byte[] buffer, int offset, int count, CancellationToken cancellationToken) =>
                {
                    // Count the write before completing the TCS: SetResult can run the test's
                    // awaiting continuation inline, and that continuation should observe an
                    // up-to-date count.
                    writeCount++;
                    writeTcs.SetResult(buffer);
                })
                .Returns(TaskCache.CompletedTask);

            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext);

                var copyToAsyncTask = body.CopyToAsync(mockDestination.Object);

                // Writer.Alloc is passed a 2048-byte minimum size and every chunk in RequestData is
                // far smaller than that, so no bounds check is needed in this test.
                var socketInput = input.FrameContext.Input;

                var bytes = Encoding.ASCII.GetBytes(data[0]);
                var buffer = socketInput.Writer.Alloc(2048);
                ArraySegment<byte> block;
                Assert.True(buffer.Memory.TryGetArray(out block));
                Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length);
                buffer.Advance(bytes.Length);
                await buffer.FlushAsync();

                // Verify the block passed to WriteAsync is the same one incoming data was written into.
                Assert.Same(block.Array, await writeTcs.Task);

                writeTcs = new TaskCompletionSource<byte[]>();
                bytes = Encoding.ASCII.GetBytes(data[1]);
                buffer = socketInput.Writer.Alloc(2048);
                Assert.True(buffer.Memory.TryGetArray(out block));
                Buffer.BlockCopy(bytes, 0, block.Array, block.Offset, bytes.Length);
                buffer.Advance(bytes.Length);
                await buffer.FlushAsync();

                Assert.Same(block.Array, await writeTcs.Task);

                if (headers.HeaderConnection == "close")
                {
                    socketInput.Writer.Complete();
                }

                await copyToAsyncTask;

                Assert.Equal(2, writeCount);
            }
        }

        [Theory]
        [MemberData(nameof(CombinedData))]
        public async Task CopyToAsyncAdvancesRequestStreamWhenDestinationWriteAsyncThrows(Stream writeStream, FrameRequestHeaders headers, string[] data)
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, headers, input.FrameContext);

                input.Add(data[0]);

                await Assert.ThrowsAsync<XunitException>(() => body.CopyToAsync(writeStream));

                input.Add(data[1], fin: headers.HeaderConnection == "close");

                // "Hello " should have been consumed
                var readBuffer = new byte[6];
                var count = await body.ReadAsync(new ArraySegment<byte>(readBuffer, 0, readBuffer.Length));
                Assert.Equal(6, count);
                AssertASCII("World!", new ArraySegment<byte>(readBuffer, 0, 6));

                count = await body.ReadAsync(new ArraySegment<byte>(readBuffer, 0, readBuffer.Length));
                Assert.Equal(0, count);
            }
        }

        [Theory]
        [InlineData("keep-alive, upgrade")]
        [InlineData("Keep-Alive, Upgrade")]
        [InlineData("upgrade, keep-alive")]
        [InlineData("Upgrade, Keep-Alive")]
        public void ConnectionUpgradeKeepAlive(string headerConnection)
        {
            using (var input = new TestInput())
            {
                var body = MessageBody.For(HttpVersion.Http11, new FrameRequestHeaders { HeaderConnection = headerConnection }, input.FrameContext);
                var stream = new FrameRequestStream();
                stream.StartAcceptingReads(body);

                input.Add("Hello", true);

                var buffer = new byte[1024];
                Assert.Equal(5, stream.Read(buffer, 0, 1024));
                AssertASCII("Hello", new ArraySegment<byte>(buffer, 0, 5));
            }
        }

        /// <summary>
        /// Asserts that <paramref name="actual"/> has exactly the ASCII encoding of
        /// <paramref name="expected"/>: same length and the same bytes in order.
        /// </summary>
        private void AssertASCII(string expected, ArraySegment<byte> actual)
        {
            var encoding = Encoding.ASCII;
            var bytes = encoding.GetBytes(expected);
            Assert.Equal(bytes.Length, actual.Count);
            for (var index = 0; index < bytes.Length; index++)
            {
                Assert.Equal(bytes[index], actual.Array[actual.Offset + index]);
            }
        }

        /// <summary>
        /// Writable stream whose <c>WriteAsync</c> throws <see cref="XunitException"/> synchronously,
        /// before any <see cref="Task"/> is returned. All other I/O members throw
        /// <see cref="NotImplementedException"/>.
        /// </summary>
        private class ThrowOnWriteSynchronousStream : Stream
        {
            public override void Flush()
            {
                throw new NotImplementedException();
            }

            public override int Read(byte[] buffer, int offset, int count)
            {
                throw new NotImplementedException();
            }

            public override long Seek(long offset, SeekOrigin origin)
            {
                throw new NotImplementedException();
            }

            public override void SetLength(long value)
            {
                throw new NotImplementedException();
            }

            public override void Write(byte[] buffer, int offset, int count)
            {
                throw new NotImplementedException();
            }

            public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                throw new XunitException();
            }

            public override bool CanRead { get; }
            public override bool CanSeek { get; }
            public override bool CanWrite => true;
            public override long Length { get; }
            public override long Position { get; set; }
        }

        /// <summary>
        /// Writable stream whose <c>WriteAsync</c> awaits <c>Task.Delay(1)</c> and then throws
        /// <see cref="XunitException"/>, so the failure surfaces through the returned task.
        /// All other I/O members throw <see cref="NotImplementedException"/>.
        /// </summary>
        private class ThrowOnWriteAsynchronousStream : Stream
        {
            public override void Flush()
            {
                throw new NotImplementedException();
            }

            public override int Read(byte[] buffer, int offset, int count)
            {
                throw new NotImplementedException();
            }

            public override long Seek(long offset, SeekOrigin origin)
            {
                throw new NotImplementedException();
            }

            public override void SetLength(long value)
            {
                throw new NotImplementedException();
            }

            public override void Write(byte[] buffer, int offset, int count)
            {
                throw new NotImplementedException();
            }

            public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
            {
                await Task.Delay(1);
                throw new XunitException();
            }

            public override bool CanRead { get; }
            public override bool CanSeek { get; }
            public override bool CanWrite => true;
            public override long Length { get; }
            public override long Position { get; set; }
        }
    }
}