diff --git a/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs b/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs index 5c174ad74c..a7d00602ce 100644 --- a/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs +++ b/src/Microsoft.Net.Http.Server/RequestProcessing/Response.cs @@ -575,7 +575,6 @@ namespace Microsoft.Net.Http.Server { if (headerPair.Value.Count == 0) { - // TODO: Have the collection exclude empty headers. continue; } // See if this is an unknown header @@ -602,7 +601,6 @@ namespace Microsoft.Net.Http.Server { if (headerPair.Value.Count == 0) { - // TODO: Have the collection exclude empty headers. continue; } headerName = headerPair.Key; @@ -632,7 +630,7 @@ namespace Microsoft.Net.Http.Server unknownHeaders[_nativeResponse.Response_V1.Headers.UnknownHeaderCount].pName = (sbyte*)gcHandle.AddrOfPinnedObject(); // Add Value - headerValue = headerValues[headerValueIndex]; + headerValue = headerValues[headerValueIndex] ?? string.Empty; bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; unknownHeaders[_nativeResponse.Response_V1.Headers.UnknownHeaderCount].RawValueLength = (ushort)bytes.Length; HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); @@ -644,16 +642,13 @@ namespace Microsoft.Net.Http.Server } else if (headerPair.Value.Count == 1) { - headerValue = headerValues[0]; - if (headerValue != null) - { - bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; - pKnownHeaders[lookup].RawValueLength = (ushort)bytes.Length; - HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); - gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); - pinnedHeaders.Add(gcHandle); - pKnownHeaders[lookup].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); - } + headerValue = headerValues[0] ?? string.Empty; + bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; + pKnownHeaders[lookup].RawValueLength = (ushort)bytes.Length; + HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); + gcHandle = GCHandle.Alloc(bytes, GCHandleType.Pinned); + pinnedHeaders.Add(gcHandle); + pKnownHeaders[lookup].pRawValue = (sbyte*)gcHandle.AddrOfPinnedObject(); } else { @@ -681,7 +676,7 @@ namespace Microsoft.Net.Http.Server for (int headerValueIndex = 0; headerValueIndex < headerValues.Count; headerValueIndex++) { // Add Value - headerValue = headerValues[headerValueIndex]; + headerValue = headerValues[headerValueIndex] ?? string.Empty; bytes = new byte[HeaderEncoding.GetByteCount(headerValue)]; nativeHeaderValues[header.KnownHeaderCount].RawValueLength = (ushort)bytes.Length; HeaderEncoding.GetBytes(headerValue, 0, bytes.Length, bytes, 0); diff --git a/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ResponseHeaderTests.cs b/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ResponseHeaderTests.cs index bd8bb9cf5f..08b422d815 100644 --- a/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ResponseHeaderTests.cs +++ b/test/Microsoft.AspNet.Server.WebListener.FunctionalTests/ResponseHeaderTests.cs @@ -23,7 +23,7 @@ using System.Net.Http; using System.Text; using System.Threading.Tasks; using Microsoft.AspNet.Http.Features; -using Microsoft.AspNet.Http.Internal; +using Microsoft.Extensions.Primitives; using Xunit; namespace Microsoft.AspNet.Server.WebListener @@ -259,6 +259,58 @@ namespace Microsoft.AspNet.Server.WebListener } } + [Theory, MemberData(nameof(NullHeaderData))] + public async Task Headers_IgnoreNullHeaders(string headerName, StringValues headerValue, StringValues expectedValue) + { + string address; + using (Utilities.CreateHttpServer(out address, httpContext => + { + var responseHeaders = httpContext.Response.Headers; + responseHeaders.Add(headerName, headerValue); + return Task.FromResult(0); + })) + { + HttpResponseMessage response = await SendRequestAsync(address); + response.EnsureSuccessStatusCode(); + var headers = response.Headers; + + if (expectedValue.Count == 0) + { + Assert.False(headers.Contains(headerName)); + } + else + { + Assert.True(headers.Contains(headerName)); + Assert.Equal(headers.GetValues(headerName), expectedValue); + } + } + } + + public static TheoryData NullHeaderData + { + get + { + var dataset = new TheoryData(); + + // Unknown headers + dataset.Add("NullString", (string)null, (string)null); + dataset.Add("EmptyString", "", ""); + dataset.Add("NullStringArray", new string[] { null }, ""); + dataset.Add("EmptyStringArray", new string[] { "" }, ""); + dataset.Add("MixedStringArray", new string[] { null, "" }, new string[] { "", "" }); + // Known headers + dataset.Add("Location", (string)null, (string)null); + dataset.Add("Location", "", (string)null); + dataset.Add("Location", new string[] { null }, (string)null); + dataset.Add("Location", new string[] { "" }, (string)null); + dataset.Add("Location", new string[] { "a" }, "a"); + dataset.Add("Location", new string[] { null, "" }, (string)null); + dataset.Add("Location", new string[] { null, "", "a", "b" }, new string[] { "a", "b" }); + + return dataset; + } + } + private async Task SendRequestAsync(string uri) { using (HttpClient client = new HttpClient())