InitializeHeaders only at start of parsing/Fix remaining (#1488)

* Don't reinitialize header collection each loop
* Correct remaining tracking value
* Add tests
This commit is contained in:
Ben Adams 2017-03-13 22:32:28 +00:00 committed by David Fowler
parent 5644310811
commit b94912bcb1
4 changed files with 160 additions and 6 deletions

View File

@ -36,14 +36,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
{
ConnectionControl.SetTimeout(_keepAliveMilliseconds, TimeoutAction.CloseConnection);
InitializeHeaders();
while (!_requestProcessingStopping)
{
var result = await Input.Reader.ReadAsync();
var examined = result.Buffer.End;
var consumed = result.Buffer.End;
InitializeHeaders();
try
{
ParseRequest(result.Buffer, out consumed, out examined);

View File

@ -194,7 +194,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Internal.Http
while (!reader.End)
{
var span = reader.Span;
var remaining = span.Length;
var remaining = span.Length - reader.Index;
fixed (byte* pBuffer = &span.DangerousGetPinnableReference())
{

View File

@ -1,11 +1,16 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Testing;
using Microsoft.AspNetCore.Server.Kestrel.Internal.Http;
using Xunit;
using Microsoft.Extensions.Primitives;
using System.Collections;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
{
@ -75,6 +80,108 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
}
}
[Theory]
[InlineData(1, 1)]
[InlineData(5, 5)]
[InlineData(100, 100)]
[InlineData(600, 100)]
[InlineData(700, 1)]
[InlineData(1, 700)]
public async Task ServerAcceptsHeadersAcrossSends(int header0Count, int header1Count)
{
var headers0 = MakeHeaders(header0Count);
var headers1 = MakeHeaders(header1Count, header0Count);
using (var server = CreateServer(maxRequestHeaderCount: header0Count + header1Count))
{
using (var connection = new TestConnection(server.Port))
{
await connection.SendAll("GET / HTTP/1.1\r\n");
// Wait for parsing to start
await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame?.RequestHeaders != null);
Assert.Equal(0, server.Frame.RequestHeaders.Count);
await connection.SendAll(headers0);
// Wait for headers to be parsed
await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count);
Assert.Equal(header0Count, server.Frame.RequestHeaders.Count);
await connection.SendAll(headers1);
// Wait for headers to be parsed
await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count + header1Count);
Assert.Equal(header0Count + header1Count, server.Frame.RequestHeaders.Count);
await connection.SendAll("\r\n");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
"Transfer-Encoding: chunked",
"",
"c",
"hello, world",
"0",
"",
"");
}
}
}
[Theory]
[InlineData(1, 1)]
[InlineData(5, 5)]
public async Task ServerKeepsSameHeaderCollectionAcrossSends(int header0Count, int header1Count)
{
var headers0 = MakeHeaders(header0Count);
var headers1 = MakeHeaders(header0Count, header1Count);
using (var server = CreateServer(maxRequestHeaderCount: header0Count + header1Count))
{
using (var connection = new TestConnection(server.Port))
{
await connection.SendAll("GET / HTTP/1.1\r\n");
// Wait for parsing to start
await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame?.RequestHeaders != null);
Assert.Equal(0, server.Frame.RequestHeaders.Count);
var newRequestHeaders = new RequestHeadersWrapper(server.Frame.RequestHeaders);
server.Frame.RequestHeaders = newRequestHeaders;
Assert.Same(newRequestHeaders, server.Frame.RequestHeaders);
await connection.SendAll(headers0);
// Wait for headers to be parsed
await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count);
Assert.Same(newRequestHeaders, server.Frame.RequestHeaders);
Assert.Equal(header0Count, server.Frame.RequestHeaders.Count);
await connection.SendAll(headers1);
// Wait for headers to be parsed
await WaitForCondition(TimeSpan.FromSeconds(1), () => server.Frame.RequestHeaders.Count >= header0Count + header1Count);
Assert.Equal(header0Count + header1Count, server.Frame.RequestHeaders.Count);
Assert.Same(newRequestHeaders, server.Frame.RequestHeaders);
await connection.SendAll("\r\n");
await connection.ReceiveEnd(
"HTTP/1.1 200 OK",
$"Date: {server.Context.DateHeaderValue}",
"Transfer-Encoding: chunked",
"",
"c",
"hello, world",
"0",
"",
"");
}
}
}
[Theory]
[InlineData(1)]
[InlineData(5)]
@ -122,11 +229,26 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
}
}
private static string MakeHeaders(int count)
private static async Task WaitForCondition(TimeSpan timeout, Func<bool> condition)
{
const int MaxWaitLoop = 150;
var delay = (int)Math.Ceiling(timeout.TotalMilliseconds / MaxWaitLoop);
var waitLoop = 0;
while (waitLoop < MaxWaitLoop && !condition())
{
// Wait for parsing condition to trigger
await Task.Delay(delay);
waitLoop++;
}
}
private static string MakeHeaders(int count, int startAt = 0)
{
return string.Join("", Enumerable
.Range(0, count)
.Select(i => $"Header-{i}: value{i}\r\n"));
.Select(i => $"Header-{startAt + i}: value{startAt + i}\r\n"));
}
private TestServer CreateServer(int? maxRequestHeaderCount = null, int? maxRequestHeadersTotalSize = null)
@ -148,5 +270,33 @@ namespace Microsoft.AspNetCore.Server.Kestrel.FunctionalTests
ServerOptions = options
});
}
private class RequestHeadersWrapper : IHeaderDictionary
{
IHeaderDictionary _innerHeaders;
public RequestHeadersWrapper(IHeaderDictionary headers)
{
_innerHeaders = headers;
}
public StringValues this[string key] { get => _innerHeaders[key]; set => _innerHeaders[key] = value; }
public long? ContentLength { get => _innerHeaders.ContentLength; set => _innerHeaders.ContentLength = value; }
public ICollection<string> Keys => _innerHeaders.Keys;
public ICollection<StringValues> Values => _innerHeaders.Values;
public int Count => _innerHeaders.Count;
public bool IsReadOnly => _innerHeaders.IsReadOnly;
public void Add(string key, StringValues value) => _innerHeaders.Add(key, value);
public void Add(KeyValuePair<string, StringValues> item) => _innerHeaders.Add(item);
public void Clear() => _innerHeaders.Clear();
public bool Contains(KeyValuePair<string, StringValues> item) => _innerHeaders.Contains(item);
public bool ContainsKey(string key) => _innerHeaders.ContainsKey(key);
public void CopyTo(KeyValuePair<string, StringValues>[] array, int arrayIndex) => _innerHeaders.CopyTo(array, arrayIndex);
public IEnumerator<KeyValuePair<string, StringValues>> GetEnumerator() => _innerHeaders.GetEnumerator();
public bool Remove(string key) => _innerHeaders.Remove(key);
public bool Remove(KeyValuePair<string, StringValues> item) => _innerHeaders.Remove(item);
public bool TryGetValue(string key, out StringValues value) => _innerHeaders.TryGetValue(key, out value);
IEnumerator IEnumerable.GetEnumerator() => _innerHeaders.GetEnumerator();
}
}
}

View File

@ -18,6 +18,7 @@ namespace Microsoft.AspNetCore.Testing
private KestrelEngine _engine;
private IDisposable _server;
private ListenOptions _listenOptions;
private Frame<HttpContext> _frame;
public TestServer(RequestDelegate app)
: this(app, new TestServiceContext())
@ -46,7 +47,8 @@ namespace Microsoft.AspNetCore.Testing
context.FrameFactory = connectionContext =>
{
return new Frame<HttpContext>(new DummyApplication(app, httpContextFactory), connectionContext);
_frame = new Frame<HttpContext>(new DummyApplication(app, httpContextFactory), connectionContext);
return _frame;
};
try
@ -65,6 +67,8 @@ namespace Microsoft.AspNetCore.Testing
public int Port => _listenOptions.IPEndPoint.Port;
public Frame<HttpContext> Frame => _frame;
public TestServiceContext Context { get; }
public TestConnection CreateConnection()