diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/BadHttpRequestTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/BadHttpRequestTests.cs index 0394b7fcd3..80f127c9d1 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/BadHttpRequestTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/BadHttpRequestTests.cs @@ -1,11 +1,15 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Collections.Generic; using System.Linq; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Internal; +using Microsoft.Extensions.Logging; +using Moq; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests @@ -14,106 +18,102 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { [Theory] [MemberData(nameof(InvalidRequestLineData))] - public async Task TestInvalidRequestLines(string request) + public Task TestInvalidRequestLines(string request, string expectedExceptionMessage) { - using (var server = new TestServer(context => TaskCache.CompletedTask)) - { - using (var connection = server.CreateConnection()) - { - await connection.SendAll(request); - await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); - } - } + return TestBadRequest( + request, + "400 Bad Request", + expectedExceptionMessage); } [Theory] [MemberData(nameof(UnrecognizedHttpVersionData))] - public async Task TestInvalidRequestLinesWithUnrecognizedVersion(string httpVersion) + public Task TestInvalidRequestLinesWithUnrecognizedVersion(string httpVersion) { - using (var server = new TestServer(context => TaskCache.CompletedTask)) - { - using (var connection = server.CreateConnection()) - { - await connection.SendAll($"GET / {httpVersion}\r\n"); - await ReceiveBadRequestResponse(connection, "505 HTTP Version Not Supported", server.Context.DateHeaderValue); - } - } + return TestBadRequest( + $"GET / {httpVersion}\r\n", + "505 HTTP Version Not Supported", + $"Unrecognized HTTP version: {httpVersion}"); } [Theory] [MemberData(nameof(InvalidRequestHeaderData))] - public async Task TestInvalidHeaders(string rawHeaders) + public Task TestInvalidHeaders(string rawHeaders, string expectedExceptionMessage) { - using (var server = new TestServer(context => TaskCache.CompletedTask)) - { - using (var connection = server.CreateConnection()) - { - await connection.SendAll($"GET / HTTP/1.1\r\n{rawHeaders}"); - await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); - } - } + return TestBadRequest( + $"GET / HTTP/1.1\r\n{rawHeaders}", + "400 Bad Request", + expectedExceptionMessage); } - [Fact] - public async Task BadRequestWhenHeaderNameContainsNonASCIICharacters() + [Theory] + [InlineData("Hea\0der: value", "Invalid characters in header name.")] + [InlineData("Header: va\0lue", "Malformed request: invalid headers.")] + [InlineData("Head\x80r: value", "Invalid characters in header name.")] + [InlineData("Header: valu\x80", "Malformed request: invalid headers.")] + public Task BadRequestWhenHeaderNameContainsNonASCIIOrNullCharacters(string header, string expectedExceptionMessage) { - using (var server = new TestServer(context => { return Task.FromResult(0); })) - { - using (var connection = server.CreateConnection()) - { - await connection.SendAll( - "GET / HTTP/1.1", - "H\u00eb\u00e4d\u00ebr: value", - "", - ""); - await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); - } - } + return TestBadRequest( + $"GET / HTTP/1.1\r\n{header}\r\n\r\n", + "400 Bad Request", + expectedExceptionMessage); } [Theory] [InlineData("POST")] [InlineData("PUT")] - public async Task BadRequestIfMethodRequiresLengthButNoContentLengthOrTransferEncodingInRequest(string method) + public Task BadRequestIfMethodRequiresLengthButNoContentLengthOrTransferEncodingInRequest(string method) { - using (var server = new TestServer(context => { return Task.FromResult(0); })) - { - using (var connection = server.CreateConnection()) - { - await connection.Send($"{method} / HTTP/1.1\r\n\r\n"); - await ReceiveBadRequestResponse(connection, "411 Length Required", server.Context.DateHeaderValue); - } - } + return TestBadRequest( + $"{method} / HTTP/1.1\r\n\r\n", + "411 Length Required", + $"{method} request contains no Content-Length or Transfer-Encoding header"); } [Theory] [InlineData("POST")] [InlineData("PUT")] - public async Task BadRequestIfMethodRequiresLengthButNoContentLengthInHttp10Request(string method) + public Task BadRequestIfMethodRequiresLengthButNoContentLengthInHttp10Request(string method) { - using (var server = new TestServer(context => { return Task.FromResult(0); })) - { - using (var connection = server.CreateConnection()) - { - await connection.Send($"{method} / HTTP/1.0\r\n\r\n"); - await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); - } - } + return TestBadRequest( + $"{method} / HTTP/1.0\r\n\r\n", + "400 Bad Request", + $"{method} request contains no Content-Length header"); } [Theory] [InlineData("NaN")] [InlineData("-1")] - public async Task BadRequestIfContentLengthInvalid(string contentLength) + public Task BadRequestIfContentLengthInvalid(string contentLength) { - using (var server = new TestServer(context => { return Task.FromResult(0); })) + return TestBadRequest( + $"POST / HTTP/1.1\r\nContent-Length: {contentLength}\r\n\r\n", + "400 Bad Request", + $"Invalid content length: {contentLength}"); + } + + private async Task TestBadRequest(string request, string expectedResponseStatusCode, string expectedExceptionMessage) + { + BadHttpRequestException loggedException = null; + var mockKestrelTrace = new Mock(); + mockKestrelTrace + .Setup(trace => trace.IsEnabled(LogLevel.Information)) + .Returns(true); + mockKestrelTrace + .Setup(trace => trace.ConnectionBadRequest(It.IsAny(), It.IsAny())) + .Callback((connectionId, exception) => loggedException = exception); + + using (var server = new TestServer(context => TaskCache.CompletedTask, new TestServiceContext { Log = mockKestrelTrace.Object })) { using (var connection = server.CreateConnection()) { - await connection.SendAll($"GET / HTTP/1.1\r\nContent-Length: {contentLength}\r\n\r\n"); - await ReceiveBadRequestResponse(connection, "400 Bad Request", server.Context.DateHeaderValue); + await connection.SendAll(request); + await ReceiveBadRequestResponse(connection, expectedResponseStatusCode, server.Context.DateHeaderValue); } } + + mockKestrelTrace.Verify(trace => trace.ConnectionBadRequest(It.IsAny(), It.IsAny())); + Assert.Equal(expectedExceptionMessage, loggedException.Message); } private async Task ReceiveBadRequestResponse(TestConnection connection, string expectedResponseStatusCode, string expectedDateHeaderValue) @@ -127,10 +127,25 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests ""); } - public static IEnumerable InvalidRequestLineData => HttpParsingData.InvalidRequestLineData.Select(data => new[] { data[0] }); + public static IEnumerable InvalidRequestLineData => HttpParsingData.InvalidRequestLineData + .Select(requestLine => new object[] + { + requestLine, + $"Invalid request line: {requestLine.Replace("\r", "<0x0D>").Replace("\n", "<0x0A>")}", + }) + .Concat(HttpParsingData.EncodedNullCharInTargetRequestLines.Select(requestLine => new object[] + { + requestLine, + "Invalid request line." + })) + .Concat(HttpParsingData.NullCharInTargetRequestLines.Select(requestLine => new object[] + { + requestLine, + "Invalid request line." + })); public static TheoryData UnrecognizedHttpVersionData => HttpParsingData.UnrecognizedHttpVersionData; - public static IEnumerable InvalidRequestHeaderData => HttpParsingData.InvalidRequestHeaderData.Select(data => new[] { data[0] }); + public static IEnumerable InvalidRequestHeaderData => HttpParsingData.InvalidRequestHeaderData; } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestHeadersTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestHeadersTests.cs index dbf6d857a5..006ac65cd1 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestHeadersTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameRequestHeadersTests.cs @@ -304,14 +304,14 @@ namespace Microsoft.AspNetCore.Server.KestrelTests } [Fact] - public void AppendThrowsWhenHeaderValueContainsNonASCIICharacters() + public void AppendThrowsWhenHeaderNameContainsNonASCIICharacters() { var headers = new FrameRequestHeaders(); const string key = "\u00141ód\017c"; var encoding = Encoding.GetEncoding("iso-8859-1"); var exception = Assert.Throws( - () => headers.Append(encoding.GetBytes(key), key)); + () => headers.Append(encoding.GetBytes(key), "value")); Assert.Equal(StatusCodes.Status400BadRequest, exception.StatusCode); } } diff --git a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs index f48a614579..ebfd64fea4 100644 --- a/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs +++ b/test/Microsoft.AspNetCore.Server.KestrelTests/FrameTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.IO.Pipelines; +using System.Linq; using System.Net; using System.Text; using System.Threading; @@ -759,7 +760,25 @@ namespace Microsoft.AspNetCore.Server.KestrelTests public static IEnumerable ValidRequestLineData => HttpParsingData.ValidRequestLineData; - public static IEnumerable InvalidRequestLineData => HttpParsingData.InvalidRequestLineData; + public static IEnumerable InvalidRequestLineData => HttpParsingData.InvalidRequestLineData + .Select(requestLine => new object[] + { + requestLine, + typeof(BadHttpRequestException), + $"Invalid request line: {requestLine.Replace("\r", "<0x0D>").Replace("\n", "<0x0A>")}", + }) + .Concat(HttpParsingData.EncodedNullCharInTargetRequestLines.Select(requestLine => new object[] + { + requestLine, + typeof(InvalidOperationException), + "The path contains null characters." + })) + .Concat(HttpParsingData.NullCharInTargetRequestLines.Select(requestLine => new object[] + { + requestLine, + typeof(InvalidOperationException), + new InvalidOperationException().Message + })); public static TheoryData UnrecognizedHttpVersionData => HttpParsingData.UnrecognizedHttpVersionData; diff --git a/test/shared/HttpParsingData.cs b/test/shared/HttpParsingData.cs index a5a83c58c6..50a5633c9c 100644 --- a/test/shared/HttpParsingData.cs +++ b/test/shared/HttpParsingData.cs @@ -4,7 +4,6 @@ using System; using System.Collections.Generic; using System.Linq; -using Microsoft.AspNetCore.Server.Kestrel; using Xunit; namespace Microsoft.AspNetCore.Testing @@ -71,150 +70,118 @@ namespace Microsoft.AspNetCore.Testing } } - // All these test cases must end in '\n', otherwise the server will spin forever - public static IEnumerable InvalidRequestLineData + public static IEnumerable InvalidRequestLineData => new[] { - get - { - var invalidRequestLines = new[] - { - "G\r\n", - "GE\r\n", - "GET\r\n", - "GET \r\n", - "GET /\r\n", - "GET / \r\n", - "GET/HTTP/1.1\r\n", - "GET /HTTP/1.1\r\n", - " \r\n", - " \r\n", - "/ HTTP/1.1\r\n", - " / HTTP/1.1\r\n", - "/ \r\n", - "GET \r\n", - "GET HTTP/1.0\r\n", - "GET HTTP/1.1\r\n", - "GET / \n", - "GET / HTTP/1.0\n", - "GET / HTTP/1.1\n", - "GET / HTTP/1.0\rA\n", - "GET / HTTP/1.1\ra\n", - "GET? / HTTP/1.1\r\n", - "GET ? HTTP/1.1\r\n", - "GET /a?b=cHTTP/1.1\r\n", - "GET /a%20bHTTP/1.1\r\n", - "GET /a%20b?c=dHTTP/1.1\r\n", - "GET %2F HTTP/1.1\r\n", - "GET %00 HTTP/1.1\r\n", - "CUSTOM \r\n", - "CUSTOM /\r\n", - "CUSTOM / \r\n", - "CUSTOM /HTTP/1.1\r\n", - "CUSTOM \r\n", - "CUSTOM HTTP/1.0\r\n", - "CUSTOM HTTP/1.1\r\n", - "CUSTOM / \n", - "CUSTOM / HTTP/1.0\n", - "CUSTOM / HTTP/1.1\n", - "CUSTOM / HTTP/1.0\rA\n", - "CUSTOM / HTTP/1.1\ra\n", - "CUSTOM ? HTTP/1.1\r\n", - "CUSTOM /a?b=cHTTP/1.1\r\n", - "CUSTOM /a%20bHTTP/1.1\r\n", - "CUSTOM /a%20b?c=dHTTP/1.1\r\n", - "CUSTOM %2F HTTP/1.1\r\n", - "CUSTOM %00 HTTP/1.1\r\n", - // Bad HTTP Methods (invalid according to RFC) - "( / HTTP/1.0\r\n", - ") / HTTP/1.0\r\n", - "< / HTTP/1.0\r\n", - "> / HTTP/1.0\r\n", - "@ / HTTP/1.0\r\n", - ", / HTTP/1.0\r\n", - "; / HTTP/1.0\r\n", - ": / HTTP/1.0\r\n", - "\\ / HTTP/1.0\r\n", - "\" / HTTP/1.0\r\n", - "/ / HTTP/1.0\r\n", - "[ / HTTP/1.0\r\n", - "] / HTTP/1.0\r\n", - "? / HTTP/1.0\r\n", - "= / HTTP/1.0\r\n", - "{ / HTTP/1.0\r\n", - "} / HTTP/1.0\r\n", - "get@ / HTTP/1.0\r\n", - "post= / HTTP/1.0\r\n", - }; + "G\r\n", + "GE\r\n", + "GET\r\n", + "GET \r\n", + "GET /\r\n", + "GET / \r\n", + "GET/HTTP/1.1\r\n", + "GET /HTTP/1.1\r\n", + " \r\n", + " \r\n", + "/ HTTP/1.1\r\n", + " / HTTP/1.1\r\n", + "/ \r\n", + "GET \r\n", + "GET HTTP/1.0\r\n", + "GET HTTP/1.1\r\n", + "GET / \n", + "GET / HTTP/1.0\n", + "GET / HTTP/1.1\n", + "GET / HTTP/1.0\rA\n", + "GET / HTTP/1.1\ra\n", + "GET? / HTTP/1.1\r\n", + "GET ? HTTP/1.1\r\n", + "GET /a?b=cHTTP/1.1\r\n", + "GET /a%20bHTTP/1.1\r\n", + "GET /a%20b?c=dHTTP/1.1\r\n", + "GET %2F HTTP/1.1\r\n", + "GET %00 HTTP/1.1\r\n", + "CUSTOM \r\n", + "CUSTOM /\r\n", + "CUSTOM / \r\n", + "CUSTOM /HTTP/1.1\r\n", + "CUSTOM \r\n", + "CUSTOM HTTP/1.0\r\n", + "CUSTOM HTTP/1.1\r\n", + "CUSTOM / \n", + "CUSTOM / HTTP/1.0\n", + "CUSTOM / HTTP/1.1\n", + "CUSTOM / HTTP/1.0\rA\n", + "CUSTOM / HTTP/1.1\ra\n", + "CUSTOM ? HTTP/1.1\r\n", + "CUSTOM /a?b=cHTTP/1.1\r\n", + "CUSTOM /a%20bHTTP/1.1\r\n", + "CUSTOM /a%20b?c=dHTTP/1.1\r\n", + "CUSTOM %2F HTTP/1.1\r\n", + "CUSTOM %00 HTTP/1.1\r\n", + // Bad HTTP Methods (invalid according to RFC) + "( / HTTP/1.0\r\n", + ") / HTTP/1.0\r\n", + "< / HTTP/1.0\r\n", + "> / HTTP/1.0\r\n", + "@ / HTTP/1.0\r\n", + ", / HTTP/1.0\r\n", + "; / HTTP/1.0\r\n", + ": / HTTP/1.0\r\n", + "\\ / HTTP/1.0\r\n", + "\" / HTTP/1.0\r\n", + "/ / HTTP/1.0\r\n", + "[ / HTTP/1.0\r\n", + "] / HTTP/1.0\r\n", + "? / HTTP/1.0\r\n", + "= / HTTP/1.0\r\n", + "{ / HTTP/1.0\r\n", + "} / HTTP/1.0\r\n", + "get@ / HTTP/1.0\r\n", + "post= / HTTP/1.0\r\n", + }; - var encodedNullCharInTargetRequestLines = new[] - { - "GET /%00 HTTP/1.1\r\n", - "GET /%00%00 HTTP/1.1\r\n", - "GET /%E8%00%84 HTTP/1.1\r\n", - "GET /%E8%85%00 HTTP/1.1\r\n", - "GET /%F3%00%82%86 HTTP/1.1\r\n", - "GET /%F3%85%00%82 HTTP/1.1\r\n", - "GET /%F3%85%82%00 HTTP/1.1\r\n", - "GET /%E8%85%00 HTTP/1.1\r\n", - "GET /%E8%01%00 HTTP/1.1\r\n", - }; - - var nullCharInTargetRequestLines = new[] - { - "GET \0 HTTP/1.1\r\n", - "GET /\0 HTTP/1.1\r\n", - "GET /\0\0 HTTP/1.1\r\n", - "GET /%C8\0 HTTP/1.1\r\n", - }; - - return invalidRequestLines.Select(requestLine => new object[] - { - requestLine, - typeof(BadHttpRequestException), - $"Invalid request line: {requestLine.Replace("\r", "<0x0D>").Replace("\n", "<0x0A>")}" - }) - .Concat(encodedNullCharInTargetRequestLines.Select(requestLine => new object[] - { - requestLine, - typeof(InvalidOperationException), - $"The path contains null characters." - })) - .Concat(nullCharInTargetRequestLines.Select(requestLine => new object[] - { - requestLine, - typeof(InvalidOperationException), - new InvalidOperationException().Message - })); - } - } - - public static TheoryData UnrecognizedHttpVersionData + public static IEnumerable EncodedNullCharInTargetRequestLines => new[] { - get + "GET /%00 HTTP/1.1\r\n", + "GET /%00%00 HTTP/1.1\r\n", + "GET /%E8%00%84 HTTP/1.1\r\n", + "GET /%E8%85%00 HTTP/1.1\r\n", + "GET /%F3%00%82%86 HTTP/1.1\r\n", + "GET /%F3%85%00%82 HTTP/1.1\r\n", + "GET /%F3%85%82%00 HTTP/1.1\r\n", + "GET /%E8%85%00 HTTP/1.1\r\n", + "GET /%E8%01%00 HTTP/1.1\r\n", + }; + + public static IEnumerable NullCharInTargetRequestLines => new[] { - return new TheoryData - { - "H", - "HT", - "HTT", - "HTTP", - "HTTP/", - "HTTP/1", - "HTTP/1.", - "http/1.0", - "http/1.1", - "HTTP/1.1 ", - "HTTP/1.0a", - "HTTP/1.0ab", - "HTTP/1.1a", - "HTTP/1.1ab", - "HTTP/1.2", - "HTTP/3.0", - "hello", - "8charact", - }; - } - } + "GET \0 HTTP/1.1\r\n", + "GET /\0 HTTP/1.1\r\n", + "GET /\0\0 HTTP/1.1\r\n", + "GET /%C8\0 HTTP/1.1\r\n", + }; + + public static TheoryData UnrecognizedHttpVersionData => new TheoryData + { + "H", + "HT", + "HTT", + "HTTP", + "HTTP/", + "HTTP/1", + "HTTP/1.", + "http/1.0", + "http/1.1", + "HTTP/1.1 ", + "HTTP/1.0a", + "HTTP/1.0ab", + "HTTP/1.1a", + "HTTP/1.1ab", + "HTTP/1.2", + "HTTP/3.0", + "hello", + "8charact", + }; public static IEnumerable InvalidRequestHeaderData {