From 0f013999555bab185dfd49005b71c0426e0fabfb Mon Sep 17 00:00:00 2001 From: Chris R Date: Fri, 8 Jan 2016 12:05:36 -0800 Subject: [PATCH] #123 Default headers to UTF8 --- .../RequestProcessing/HeaderEncoding.cs | 92 ++++--------------- .../RequestProcessing/Response.cs | 15 +-- src/Microsoft.Net.Http.Server/WebListener.cs | 3 +- .../RequestHeaderTests.cs | 34 ++++++- 4 files changed, 52 insertions(+), 92 deletions(-) diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/HeaderEncoding.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/HeaderEncoding.cs index 9b3726be86..39f658eb7d 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/HeaderEncoding.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/HeaderEncoding.cs @@ -21,94 +21,34 @@ // // ----------------------------------------------------------------------- -using System; -using System.Collections.Generic; -using System.Linq; using System.Text; namespace Microsoft.Net.Http.Server { - // we use this static class as a helper class to encode/decode HTTP headers. - // what we need is a 1-1 correspondence between a char in the range U+0000-U+00FF - // and a byte in the range 0x00-0xFF (which is the range that can hit the network). - // The Latin-1 encoding (ISO-88591-1) (GetEncoding(28591)) works for byte[] to string, but is a little slow. - // It doesn't work for string -> byte[] because of best-fit-mapping problems. internal static class HeaderEncoding { - internal static unsafe string GetString(byte[] bytes, int byteIndex, int byteCount) - { - fixed (byte* pBytes = bytes) - return GetString(pBytes + byteIndex, byteCount); - } + // It should just be ASCII or ANSI, but they break badly with un-expected values. We use UTF-8 because it's the same for + // ASCII, and because some old client would send UTF8 Host headers and expect UTF8 Location responses + // (e.g. IE and HttpWebRequest on intranets). + private static Encoding Encoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: false); internal static unsafe string GetString(sbyte* pBytes, int byteCount) { - return GetString((byte*)pBytes, byteCount); + // net451: return new string(pBytes, 0, byteCount, Encoding); + + var charCount = Encoding.GetCharCount((byte*)pBytes, byteCount); + var chars = new char[charCount]; + fixed (char* pChars = chars) + { + var count = Encoding.GetChars((byte*)pBytes, byteCount, pChars, charCount); + System.Diagnostics.Debug.Assert(count == charCount); + } + return new string(chars); } - internal static unsafe string GetString(byte* pBytes, int byteCount) + internal static byte[] GetBytes(string myString) { - if (byteCount < 1) - { - return string.Empty; - } - - string s = new String('\0', byteCount); - - fixed (char* pStr = s) - { - char* pString = pStr; - while (byteCount >= 8) - { - pString[0] = (char)pBytes[0]; - pString[1] = (char)pBytes[1]; - pString[2] = (char)pBytes[2]; - pString[3] = (char)pBytes[3]; - pString[4] = (char)pBytes[4]; - pString[5] = (char)pBytes[5]; - pString[6] = (char)pBytes[6]; - pString[7] = (char)pBytes[7]; - pString += 8; - pBytes += 8; - byteCount -= 8; - } - for (int i = 0; i < byteCount; i++) - { - pString[i] = (char)pBytes[i]; - } - } - - return s; - } - - internal static int GetByteCount(string myString) - { - return myString.Length; - } - internal static unsafe void GetBytes(string myString, int charIndex, int charCount, byte[] bytes, int byteIndex) - { - if (myString.Length == 0) - { - return; - } - fixed (byte* bufferPointer = bytes) - { - byte* newBufferPointer = bufferPointer + byteIndex; - int finalIndex = charIndex + charCount; - while (charIndex < finalIndex) - { - *newBufferPointer++ = (byte)myString[charIndex++]; - } - } - } - internal static unsafe byte[] GetBytes(string myString) - { - byte[] bytes = new byte[myString.Length]; - if (myString.Length != 0) - { - GetBytes(myString, 0, myString.Length, bytes, 0); - } - return bytes; + return Encoding.GetBytes(myString); } } } diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs index a7d00602ce..668ba5ae42 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs @@ -401,11 +401,10 @@ namespace Microsoft.Net.Http.Server cachePolicy.SecondsToLive = (uint)Math.Min(cacheTtl.Value.Ticks / TimeSpan.TicksPerSecond, Int32.MaxValue); } - byte[] reasonPhraseBytes = new byte[HeaderEncoding.GetByteCount(reasonPhrase)]; + byte[] reasonPhraseBytes = HeaderEncoding.GetBytes(reasonPhrase); fixed (byte* pReasonPhrase = reasonPhraseBytes) { _nativeResponse.Response_V1.ReasonLength = (ushort)reasonPhraseBytes.Length; - HeaderEncoding.GetBytes(reasonPhrase, 0, reasonPhraseBytes.Length, reasonPhraseBytes, 0); _nativeResponse.Response_V1.pReason = (sbyte*)pReasonPhrase; fixed (HttpApi.HTTP_RESPONSE_V2* pResponse = &_nativeResponse) { @@ -622,18 +621,16 @@ namespace Microsoft.Net.Http.Server for (int headerValueIndex = 0; headerValueIndex < headerValues.Count; headerValueIndex++) { // Add Name - bytes = new byte[HeaderEncoding.GetByteCount(headerName)]; + bytes = HeaderEncoding.GetBytes(headerName); unknownHeaders[_nativeResponse.Response_V1.Headers.UnknownHeaderCount].NameLength = (ushort)bytes.Length; - HeaderEncoding.GetBytes(headerName, 0, bytes.Length, bytes, 0); gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); pinnedHeaders.Add(gcHandle); unknownHeaders[_nativeResponse.Response_V1.Headers.UnknownHeaderCount].pName = (sbyte*)gcHandle.AddrOfPinnedObject(); // Add Value headerValue = headerValues[headerValueIndex] ?? string.Empty; - bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + bytes = HeaderEncoding.GetBytes(headerValue); unknownHeaders[_nativeResponse.Response_V1.Headers.UnknownHeaderCount].RawValueLength = (ushort)bytes.Length; - HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); pinnedHeaders.Add(gcHandle); unknownHeaders[_nativeResponse.Response_V1.Headers.UnknownHeaderCount].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); @@ -643,9 +640,8 @@ namespace Microsoft.Net.Http.Server else if (headerPair.Value.Count == 1) { headerValue = headerValues[0] ?? string.Empty; - bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + bytes = HeaderEncoding.GetBytes(headerValue); pKnownHeaders[lookup].RawValueLength = (ushort)bytes.Length; - HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); pinnedHeaders.Add(gcHandle); pKnownHeaders[lookup].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); @@ -677,9 +673,8 @@ namespace Microsoft.Net.Http.Server { // Add Value headerValue = headerValues[headerValueIndex] ?? string.Empty; - bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + bytes = HeaderEncoding.GetBytes(headerValue); nativeHeaderValues[header.KnownHeaderCount].RawValueLength = (ushort)bytes.Length; - HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); pinnedHeaders.Add(gcHandle); nativeHeaderValues[header.KnownHeaderCount].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); diff --git a/src/Microsoft.Net.Http.Server/WebListener.cs b/src/Microsoft.Net.Http.Server/WebListener.cs index 29ecb64eba..942cc18c8f 100644 --- a/src/Microsoft.Net.Http.Server/WebListener.cs +++ b/src/Microsoft.Net.Http.Server/WebListener.cs @@ -774,9 +774,8 @@ namespace Microsoft.Net.Http.Server { // Add Value string headerValue = authChallenges[headerValueIndex]; - byte[] bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + byte[] bytes = HeaderEncoding.GetBytes(headerValue); nativeHeaderValues[header.KnownHeaderCount].RawValueLength = (ushort)bytes.Length; - HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); pinnedHeaders.Add(gcHandle); nativeHeaderValues[header.KnownHeaderCount].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); diff --git a/test/Microsoft.Net.Http.Server.FunctionalTests/RequestHeaderTests.cs b/test/Microsoft.Net.Http.Server.FunctionalTests/RequestHeaderTests.cs index f28e6850a5..40980cfbd3 100644 --- a/test/Microsoft.Net.Http.Server.FunctionalTests/RequestHeaderTests.cs +++ b/test/Microsoft.Net.Http.Server.FunctionalTests/RequestHeaderTests.cs @@ -1,7 +1,6 @@ -// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. using System; -using System.Linq; using System.Net.Http; using System.Net.Sockets; using System.Text; @@ -64,7 +63,34 @@ namespace Microsoft.Net.Http.Server await responseTask; } } - + + [Fact] + public async Task RequestHeaders_ClientSendsUtf8Headers_Success() + { + string address; + using (var server = Utilities.CreateHttpServer(out address)) + { + string[] customValues = new string[] { "custom1, and custom测试2", "custom3" }; + Task responseTask = SendRequestAsync(address, "Custom-Header", customValues); + + var context = await server.GetContextAsync(); + var requestHeaders = context.Request.Headers; + Assert.Equal(4, requestHeaders.Count); + Assert.Equal(new Uri(address).Authority, requestHeaders["Host"]); + Assert.Equal(new[] { new Uri(address).Authority }, requestHeaders.GetValues("Host")); + Assert.Equal("close", requestHeaders["Connection"]); + Assert.Equal(new[] { "close" }, requestHeaders.GetValues("Connection")); + // Apparently Http.Sys squashes request headers together. + Assert.Equal("custom1, and custom测试2, custom3", requestHeaders["Custom-Header"]); + Assert.Equal(new[] { "custom1", "and custom测试2", "custom3" }, requestHeaders.GetValues("Custom-Header")); + Assert.Equal("spacervalue, spacervalue", requestHeaders["Spacer-Header"]); + Assert.Equal(new[] { "spacervalue", "spacervalue" }, requestHeaders.GetValues("Spacer-Header")); + context.Dispose(); + + await responseTask; + } + } + private async Task SendRequestAsync(string uri) { using (HttpClient client = new HttpClient()) @@ -90,7 +116,7 @@ namespace Microsoft.Net.Http.Server } builder.AppendLine(); - byte[] request = Encoding.ASCII.GetBytes(builder.ToString()); + byte[] request = Encoding.UTF8.GetBytes(builder.ToString()); Socket socket = new Socket(SocketType.Stream, ProtocolType.Tcp); socket.Connect(uri.Host, uri.Port);