Using the optimized method for converting header name to ASCII

This commit is contained in:
moozzyk 2016-05-27 11:03:05 -07:00
parent de022b6051
commit 0342754c57
5 changed files with 499 additions and 451 deletions

View File

@ -172,6 +172,23 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
}
}
[Fact]
public async Task BadRequestWhenNameHeaderNamesContainsNonASCIICharacters()
{
using (var server = new TestServer(context => { return Task.FromResult(0); }))
{
using (var connection = server.CreateConnection())
{
await connection.SendEnd(
"GET / HTTP/1.1",
"Hëädër: value",
"",
"");
await ReceiveBadRequestResponse(connection);
}
}
}
private async Task ReceiveBadRequestResponse(TestConnection connection)
{
await connection.Receive(

View File

@ -1,5 +1,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.Server.Kestrel.Exceptions;
using Microsoft.AspNetCore.Server.Kestrel.Http;
using Microsoft.Extensions.Primitives;
using Xunit;
@ -233,5 +235,16 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
Assert.Null(entries[3].Key);
Assert.Equal(new StringValues(), entries[0].Value);
}
[Fact]
public void AppendThrowsWhenHeaderValueContainsNonASCIICharacters()
{
var headers = new FrameRequestHeaders();
const string key = "\u00141ód\017c";
var encoding = Encoding.GetEncoding("iso-8859-1");
Assert.Throws<BadHttpRequestException>(
() => headers.Append(encoding.GetBytes(key), 0, encoding.GetByteCount(key), key));
}
}
}

View File

@ -61,7 +61,7 @@ namespace Microsoft.AspNetCore.Server.KestrelTests
public async Task Send(params string[] lines)
{
var text = String.Join("\r\n", lines);
var writer = new StreamWriter(_stream, Encoding.ASCII);
var writer = new StreamWriter(_stream, Encoding.GetEncoding("iso-8859-1"));
for (var index = 0; index < text.Length; index++)
{
var ch = text[index];

View File

@ -72,7 +72,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.GeneratedCode
return $"(({array}[{offset / count}] & {mask}{suffix}) == {comp}{suffix})";
}
}
public static string GeneratedFile()
{
var commonHeaders = new[]
@ -199,11 +199,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.GeneratedCode
return $@"
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.Server.Kestrel.Exceptions;
using Microsoft.AspNetCore.Server.Kestrel.Infrastructure;
using Microsoft.Extensions.Primitives;
namespace Microsoft.AspNetCore.Server.Kestrel.Http
namespace Microsoft.AspNetCore.Server.Kestrel.Http
{{
{Each(loops, loop => $@"
public partial class {loop.ClassName}
@ -214,7 +214,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{Each(loop.Bytes, b => $"{b},")}
}};"
: "")}
private long _bits = 0;
private HeaderReferences _headers;
{Each(loop.Headers, header => $@"
@ -239,7 +239,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
public void SetRaw{header.Identifier}(StringValues value, byte[] raw)
{{
{header.SetBit()};
_headers._{header.Identifier} = value;
_headers._{header.Identifier} = value;
_headers._raw{header.Identifier} = raw;
}}")}
protected override int GetCountFast()
@ -266,7 +266,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
")}}}
break;
")}}}
if (MaybeUnknown == null)
if (MaybeUnknown == null)
{{
ThrowKeyNotFoundException();
}}
@ -278,7 +278,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{{{Each(loop.HeadersByLength, byLength => $@"
case {byLength.Key}:
{{{Each(byLength, header => $@"
if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase))
if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase))
{{
if ({header.TestBit()})
{{
@ -304,7 +304,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{{{Each(loop.HeadersByLength, byLength => $@"
case {byLength.Key}:
{{{Each(byLength, header => $@"
if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase))
if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase))
{{
{header.SetBit()};
_headers._{header.Identifier} = value;{(header.EnhancedSetter == false ? "" : $@"
@ -347,7 +347,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{{{Each(loop.HeadersByLength, byLength => $@"
case {byLength.Key}:
{{{Each(byLength, header => $@"
if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase))
if (""{header.Name}"".Equals(key, StringComparison.OrdinalIgnoreCase))
{{
if ({header.TestBit()})
{{
@ -372,7 +372,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
_headers = default(HeaderReferences);
MaybeUnknown?.Clear();
}}
protected override void CopyToFast(KeyValuePair<string, StringValues>[] array, int arrayIndex)
{{
if (arrayIndex < 0)
@ -380,7 +380,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
ThrowArgumentException();
}}
{Each(loop.Headers, header => $@"
if ({header.TestBit()})
if ({header.TestBit()})
{{
if (arrayIndex == array.Length)
{{
@ -397,12 +397,12 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
protected void CopyToFast(ref MemoryPoolIterator output)
{{
{Each(loop.Headers, header => $@"
if ({header.TestBit()})
if ({header.TestBit()})
{{ {(header.EnhancedSetter == false ? "" : $@"
if (_headers._raw{header.Identifier} != null)
if (_headers._raw{header.Identifier} != null)
{{
output.CopyFrom(_headers._raw{header.Identifier}, 0, _headers._raw{header.Identifier}.Length);
}}
}}
else ")}
foreach (var value in _headers._{header.Identifier})
{{
@ -418,17 +418,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
{(loop.ClassName == "FrameRequestHeaders" ? $@"
public unsafe void Append(byte[] keyBytes, int keyOffset, int keyLength, string value)
{{
fixed (byte* ptr = &keyBytes[keyOffset])
{{
var pUB = ptr;
var pUL = (ulong*)pUB;
var pUI = (uint*)pUB;
var key = new string('\0', keyLength);
fixed (byte* ptr = &keyBytes[keyOffset])
{{
var pUB = ptr;
var pUL = (ulong*)pUB;
var pUI = (uint*)pUB;
var pUS = (ushort*)pUB;
switch (keyLength)
{{{Each(loop.HeadersByLength, byLength => $@"
case {byLength.Key}:
{{{Each(byLength, header => $@"
if ({header.EqualIgnoreCaseBytes()})
if ({header.EqualIgnoreCaseBytes()})
{{
if ({header.TestBit()})
{{
@ -445,8 +446,16 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Http
")}}}
break;
")}}}
fixed(char *keyBuffer = key)
{{
if (!AsciiUtilities.TryGetAsciiString(ptr, keyBuffer, keyLength))
{{
throw new BadHttpRequestException(""Invalid characters in header name"");
}}
}}
}}
var key = System.Text.Encoding.ASCII.GetString(keyBytes, keyOffset, keyLength);
StringValues existing;
Unknown.TryGetValue(key, out existing);
Unknown[key] = AppendValue(existing, value);