From b94912bcb12fbc9c27067fb9098a0abe744351d8 Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Mon, 13 Mar 2017 22:32:28 +0000 Subject: [PATCH] InitializeHeaders only at start of parsing/Fix remaining (#1488) * Don't reinitialize header collection each loop * Correct remaining tracking value * Add tests --- .../Internal/Http/FrameOfT.cs | 4 +- .../Internal/Http/KestrelHttpParser.cs | 2 +- .../RequestHeaderLimitsTests.cs | 154 +++++++++++++++++- test/shared/TestServer.cs | 6 +- 4 files changed, 160 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs index 0065f6b8cf..7af359cfc4 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/FrameOfT.cs @@ -36,14 +36,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http { ConnectionControl.SetTimeout(_keepAliveMilliseconds, TimeoutAction.CloseConnection); + InitializeHeaders(); + while (!_requestProcessingStopping) { var result = await Input.Reader.ReadAsync(); var examined = result.Buffer.End; var consumed = result.Buffer.End; - InitializeHeaders(); - try { ParseRequest(result.Buffer, out consumed, out examined); diff --git a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/KestrelHttpParser.cs b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/KestrelHttpParser.cs index 17c0ebbf30..71a49c2f45 100644 --- a/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/KestrelHttpParser.cs +++ b/src/Microsoft.AspNetCore.Server.Kestrel/Internal/Http/KestrelHttpParser.cs @@ -194,7 +194,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http while (!reader.End) { var span = reader.Span; - var remaining = span.Length; + var remaining = span.Length - reader.Index; fixed (byte* pBuffer = &span.DangerousGetPinnableReference()) { diff --git a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs index 7ce1faf4bc..f7c1d8bc1a 100644 --- a/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs +++ b/test/Microsoft.AspNetCore.Server.Kestrel.FunctionalTests/RequestHeaderLimitsTests.cs @@ -1,11 +1,16 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; using System.Linq; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Testing; +using Microsoft.AspNetCore.Server.Kestrel.Internal.Http; using Xunit; +using Microsoft.Extensions.Primitives; +using System.Collections; +using System.Collections.Generic; namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests { @@ -75,6 +80,108 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } + [Theory] + [InlineData(1, 1)] + [InlineData(5, 5)] + [InlineData(100, 100)] + [InlineData(600, 100)] + [InlineData(700, 1)] + [InlineData(1, 700)] + public async Task ServerAcceptsHeadersAcrossSends(int header0Count, int header1Count) + { + var headers0 = MakeHeaders(header0Count); + var headers1 = MakeHeaders(header1Count, header0Count); + + using (var server = CreateServer(maxRequestHeaderCount: header0Count + header1Count)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.SendAll("GET / HTTP/1.1\r\n"); + // Wait for parsing to start + await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame?.RequestHeaders != null); + + Assert.Equal(0, server.Frame.RequestHeaders.Count); + + await connection.SendAll(headers0); + // Wait for headers to be parsed + await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count); + + Assert.Equal(header0Count, server.Frame.RequestHeaders.Count); + + await connection.SendAll(headers1); + // Wait for headers to be parsed + await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count + header1Count); + + Assert.Equal(header0Count + header1Count, server.Frame.RequestHeaders.Count); + + await connection.SendAll("\r\n"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + + [Theory] + [InlineData(1, 1)] + [InlineData(5, 5)] + public async Task ServerKeepsSameHeaderCollectionAcrossSends(int header0Count, int header1Count) + { + var headers0 = MakeHeaders(header0Count); + var headers1 = MakeHeaders(header0Count, header1Count); + + using (var server = CreateServer(maxRequestHeaderCount: header0Count + header1Count)) + { + using (var connection = new TestConnection(server.Port)) + { + await connection.SendAll("GET / HTTP/1.1\r\n"); + // Wait for parsing to start + await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame?.RequestHeaders != null); + + Assert.Equal(0, server.Frame.RequestHeaders.Count); + + var newRequestHeaders = new RequestHeadersWrapper(server.Frame.RequestHeaders); + server.Frame.RequestHeaders = newRequestHeaders; + + Assert.Same(newRequestHeaders, server.Frame.RequestHeaders); + + await connection.SendAll(headers0); + // Wait for headers to be parsed + await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count); + + Assert.Same(newRequestHeaders, server.Frame.RequestHeaders); + Assert.Equal(header0Count, server.Frame.RequestHeaders.Count); + + await connection.SendAll(headers1); + // Wait for headers to be parsed + await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count + header1Count); + + Assert.Equal(header0Count + header1Count, server.Frame.RequestHeaders.Count); + + Assert.Same(newRequestHeaders, server.Frame.RequestHeaders); + + await connection.SendAll("\r\n"); + await connection.ReceiveEnd( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "hello, world", + "0", + "", + ""); + } + } + } + [Theory] [InlineData(1)] [InlineData(5)] @@ -122,11 +229,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests } } - private static string MakeHeaders(int count) + private static async Task WaitForCondition(TimeSpan timeout, Func condition) + { + const int MaxWaitLoop = 150; + + var delay = (int)Math.Ceiling(timeout.TotalMilliseconds / MaxWaitLoop); + + var waitLoop = 0; + while (waitLoop < MaxWaitLoop && !condition()) + { + // Wait for parsing condition to trigger + await Task.Delay(delay); + waitLoop++; + } + } + + private static string MakeHeaders(int count, int startAt = 0) { return string.Join("", Enumerable .Range(0, count) - .Select(i => $"Header-{i}: value{i}\r\n")); + .Select(i => $"Header-{startAt + i}: value{startAt + i}\r\n")); } private TestServer CreateServer(int? maxRequestHeaderCount = null, int? maxRequestHeadersTotalSize = null) @@ -148,5 +270,33 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests ServerOptions = options }); } + + private class RequestHeadersWrapper : IHeaderDictionary + { + IHeaderDictionary _innerHeaders; + + public RequestHeadersWrapper(IHeaderDictionary headers) + { + _innerHeaders = headers; + } + + public StringValues this[string key] { get => _innerHeaders[key]; set => _innerHeaders[key] = value; } + public long? ContentLength { get => _innerHeaders.ContentLength; set => _innerHeaders.ContentLength = value; } + public ICollection Keys => _innerHeaders.Keys; + public ICollection Values => _innerHeaders.Values; + public int Count => _innerHeaders.Count; + public bool IsReadOnly => _innerHeaders.IsReadOnly; + public void Add(string key, StringValues value) => _innerHeaders.Add(key, value); + public void Add(KeyValuePair item) => _innerHeaders.Add(item); + public void Clear() => _innerHeaders.Clear(); + public bool Contains(KeyValuePair item) => _innerHeaders.Contains(item); + public bool ContainsKey(string key) => _innerHeaders.ContainsKey(key); + public void CopyTo(KeyValuePair[] array, int arrayIndex) => _innerHeaders.CopyTo(array, arrayIndex); + public IEnumerator> GetEnumerator() => _innerHeaders.GetEnumerator(); + public bool Remove(string key) => _innerHeaders.Remove(key); + public bool Remove(KeyValuePair item) => _innerHeaders.Remove(item); + public bool TryGetValue(string key, out StringValues value) => _innerHeaders.TryGetValue(key, out value); + IEnumerator IEnumerable.GetEnumerator() => _innerHeaders.GetEnumerator(); + } } } \ No newline at end of file diff --git a/test/shared/TestServer.cs b/test/shared/TestServer.cs index 480e456bd2..95d5d81699 100644 --- a/test/shared/TestServer.cs +++ b/test/shared/TestServer.cs @@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.Testing private KestrelEngine _engine; private IDisposable _server; private ListenOptions _listenOptions; + private Frame _frame; public TestServer(RequestDelegate app) : this(app, new TestServiceContext()) @@ -46,7 +47,8 @@ namespace Microsoft.AspNetCore.Testing context.FrameFactory = connectionContext => { - return new Frame(new DummyApplication(app, httpContextFactory), connectionContext); + _frame = new Frame(new DummyApplication(app, httpContextFactory), connectionContext); + return _frame; }; try @@ -65,6 +67,8 @@ namespace Microsoft.AspNetCore.Testing public int Port => _listenOptions.IPEndPoint.Port; + public Frame Frame => _frame; + public TestServiceContext Context { get; } public TestConnection CreateConnection()