diff --git a/samples/HelloWorld/Program.cs b/samples/HelloWorld/Program.cs index 3eb30c4f7e..243af9a49d 100644 --- a/samples/HelloWorld/Program.cs +++ b/samples/HelloWorld/Program.cs @@ -48,7 +48,7 @@ namespace HelloWorld // Request // context.Request.ProtocolVersion // context.Request.IsLocal - // context.Request.Headers // TODO: Header helpers? + // context.Request.Headers // context.Request.Method // context.Request.Body // Content-Length - long? diff --git a/src/Microsoft.Net.Server/AuthenticationManager.cs b/src/Microsoft.Net.Server/AuthenticationManager.cs index e7b1ab6e71..458b46261d 100644 --- a/src/Microsoft.Net.Server/AuthenticationManager.cs +++ b/src/Microsoft.Net.Server/AuthenticationManager.cs @@ -24,6 +24,7 @@ using System; using System.Collections.Generic; using System.Diagnostics; +using System.Linq; using System.Runtime.InteropServices; using System.Security.Claims; using System.Security.Principal; @@ -152,22 +153,7 @@ namespace Microsoft.Net.Server if (challenges.Count > 0) { - // TODO: We need a better header API that just lets us append values. - // Append to the existing header, if any. Some clients (IE, Chrome) require each challenges to be sent on their own line/header. - string[] oldValues; - string[] newValues; - if (context.Response.Headers.TryGetValue(HttpKnownHeaderNames.WWWAuthenticate, out oldValues)) - { - newValues = new string[oldValues.Length + challenges.Count]; - Array.Copy(oldValues, newValues, oldValues.Length); - challenges.CopyTo(newValues, oldValues.Length); - } - else - { - newValues = new string[challenges.Count]; - challenges.CopyTo(newValues, 0); - } - context.Response.Headers[HttpKnownHeaderNames.WWWAuthenticate] = newValues; + context.Response.Headers.AppendValues(HttpKnownHeaderNames.WWWAuthenticate, challenges.ToArray()); } } diff --git a/src/Microsoft.Net.Server/DictionaryExtensions.cs b/src/Microsoft.Net.Server/DictionaryExtensions.cs deleted file mode 100644 index 7aa0e65d2c..0000000000 --- a/src/Microsoft.Net.Server/DictionaryExtensions.cs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) Microsoft Open Technologies, Inc. -// All Rights Reserved -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR -// CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING -// WITHOUT LIMITATION ANY IMPLIED WARRANTIES OR CONDITIONS OF -// TITLE, FITNESS FOR A PARTICULAR PURPOSE, MERCHANTABLITY OR -// NON-INFRINGEMENT. -// See the Apache 2 License for the specific language governing -// permissions and limitations under the License. - -// ----------------------------------------------------------------------- -// -// Copyright (c) Microsoft Corporation. All rights reserved. -// -// ----------------------------------------------------------------------- - -using System; -using System.Linq; -using System.Text; - -namespace System.Collections.Generic -{ - internal static class DictionaryExtensions - { - internal static void Append(this IDictionary dictionary, string key, string value) - { - string[] orriginalValues; - if (dictionary.TryGetValue(key, out orriginalValues)) - { - string[] newValues = new string[orriginalValues.Length + 1]; - orriginalValues.CopyTo(newValues, 0); - newValues[newValues.Length - 1] = value; - dictionary[key] = newValues; - } - else - { - dictionary[key] = new string[] { value }; - } - } - - internal static string Get(this IDictionary dictionary, string key) - { - string[] values; - if (dictionary.TryGetValue(key, out values)) - { - return string.Join(", ", values); - } - return null; - } - - internal static T Get(this IDictionary dictionary, string key, T fallback = default(T)) - { - object values; - if (dictionary.TryGetValue(key, out values)) - { - return (T)values; - } - return fallback; - } - } -} diff --git a/src/Microsoft.Net.Server/Microsoft.Net.Server.kproj b/src/Microsoft.Net.Server/Microsoft.Net.Server.kproj index 7c64fe5b48..7ef3155968 100644 --- a/src/Microsoft.Net.Server/Microsoft.Net.Server.kproj +++ b/src/Microsoft.Net.Server/Microsoft.Net.Server.kproj @@ -22,7 +22,6 @@ - @@ -52,6 +51,8 @@ + + diff --git a/src/Microsoft.Net.Server/NativeInterop/UnsafeNativeMethods.cs b/src/Microsoft.Net.Server/NativeInterop/UnsafeNativeMethods.cs index 096887987f..acf409ed44 100644 --- a/src/Microsoft.Net.Server/NativeInterop/UnsafeNativeMethods.cs +++ b/src/Microsoft.Net.Server/NativeInterop/UnsafeNativeMethods.cs @@ -961,9 +961,9 @@ namespace Microsoft.Net.Server { headerValue = string.Empty; } - // Note that Http.Sys currently collapses all headers of the same name to a single string, so - // append will just set. - unknownHeaders.Append(headerName, headerValue); + // Note that Http.Sys currently collapses all headers of the same name to a single coma seperated string, + // so we can just call Set. + unknownHeaders[headerName] = new[] { headerValue }; } pUnknownHeader++; } diff --git a/src/Microsoft.Net.Server/RequestProcessing/HeaderCollection.cs b/src/Microsoft.Net.Server/RequestProcessing/HeaderCollection.cs new file mode 100644 index 0000000000..62a800cd71 --- /dev/null +++ b/src/Microsoft.Net.Server/RequestProcessing/HeaderCollection.cs @@ -0,0 +1,183 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections; +using System.Collections.Generic; + +namespace Microsoft.Net.Server +{ + public class HeaderCollection : IDictionary + { + public HeaderCollection() + : this(new Dictionary(4, StringComparer.OrdinalIgnoreCase)) + { + } + + public HeaderCollection(IDictionary store) + { + Store = store; + } + + private IDictionary Store { get; set; } + + public string this[string key] + { + get { return Get(key); } + set + { + if (string.IsNullOrEmpty(value)) + { + Remove(key); + } + else + { + Set(key, value); + } + } + } + + string[] IDictionary.this[string key] + { + get { return Store[key]; } + set { Store[key] = value; } + } + + public int Count + { + get { return Store.Count; } + } + + public bool IsReadOnly + { + get { return false; } + } + + public ICollection Keys + { + get { return Store.Keys; } + } + + public ICollection Values + { + get { return Store.Values; } + } + + public void Add(KeyValuePair item) + { + Store.Add(item); + } + + public void Add(string key, string[] value) + { + Store.Add(key, value); + } + + public void Append(string key, string value) + { + string[] values; + if (Store.TryGetValue(key, out values)) + { + var newValues = new string[values.Length + 1]; + Array.Copy(values, newValues, values.Length); + newValues[values.Length] = value; + Store[key] = newValues; + } + else + { + Set(key, value); + } + } + + public void AppendValues(string key, params string[] values) + { + string[] oldValues; + if (Store.TryGetValue(key, out oldValues)) + { + var newValues = new string[oldValues.Length + values.Length]; + Array.Copy(oldValues, newValues, oldValues.Length); + Array.Copy(values, 0, newValues, oldValues.Length, values.Length); + Store[key] = newValues; + } + else + { + SetValues(key, values); + } + } + + public void Clear() + { + Store.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return Store.Contains(item); + } + + public bool ContainsKey(string key) + { + return Store.ContainsKey(key); + } + + public void CopyTo(KeyValuePair[] array, int arrayIndex) + { + Store.CopyTo(array, arrayIndex); + } + + public string Get(string key) + { + string[] values; + if (Store.TryGetValue(key, out values)) + { + return string.Join(", ", values); + } + return null; + } + + public IEnumerator> GetEnumerator() + { + return Store.GetEnumerator(); + } + + public IEnumerable GetValues(string key) + { + string[] values; + if (Store.TryGetValue(key, out values)) + { + return HeaderParser.SplitValues(values); + } + return HeaderParser.Empty; + } + + public bool Remove(KeyValuePair item) + { + return Store.Remove(item); + } + + public bool Remove(string key) + { + return Store.Remove(key); + } + + public void Set(string key, string value) + { + Store[key] = new[] { value }; + } + + public void SetValues(string key, params string[] values) + { + Store[key] = values; + } + + public bool TryGetValue(string key, out string[] value) + { + return Store.TryGetValue(key, out value); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return GetEnumerator(); + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Net.Server/RequestProcessing/HeaderParser.cs b/src/Microsoft.Net.Server/RequestProcessing/HeaderParser.cs new file mode 100644 index 0000000000..fb217f06b1 --- /dev/null +++ b/src/Microsoft.Net.Server/RequestProcessing/HeaderParser.cs @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; + +namespace Microsoft.Net.Server +{ + internal static class HeaderParser + { + internal static IEnumerable Empty = new string[0]; + + // Split on commas, except in quotes + internal static IEnumerable SplitValues(string[] values) + { + foreach (var value in values) + { + int start = 0; + bool inQuotes = false; + int current = 0; + for ( ; current < value.Length; current++) + { + char ch = value[current]; + if (inQuotes) + { + if (ch == '"') + { + inQuotes = false; + } + } + else if (ch == '"') + { + inQuotes = true; + } + else if (ch == ',') + { + var subValue = value.Substring(start, current - start); + if (!string.IsNullOrWhiteSpace(subValue)) + { + yield return subValue.Trim(); + start = current + 1; + } + } + } + + if (start < current) + { + var subValue = value.Substring(start, current - start); + if (!string.IsNullOrWhiteSpace(subValue)) + { + yield return subValue.Trim(); + start = current + 1; + } + } + } + } + } +} \ No newline at end of file diff --git a/src/Microsoft.Net.Server/RequestProcessing/Request.cs b/src/Microsoft.Net.Server/RequestProcessing/Request.cs index 40bb8318bf..4e05202106 100644 --- a/src/Microsoft.Net.Server/RequestProcessing/Request.cs +++ b/src/Microsoft.Net.Server/RequestProcessing/Request.cs @@ -61,7 +61,7 @@ namespace Microsoft.Net.Server private X509Certificate _clientCert; - private IDictionary _headers; + private HeaderCollection _headers; private BoundaryType _contentBoundaryType; private long? _contentLength; private Stream _nativeStream; @@ -140,7 +140,7 @@ namespace Microsoft.Net.Server } _httpMethod = UnsafeNclNativeMethods.HttpApi.GetVerb(RequestBuffer, OriginalBlobAddress); - _headers = new RequestHeaders(_nativeRequestContext); + _headers = new HeaderCollection(new RequestHeaders(_nativeRequestContext)); UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_V2* requestV2 = (UnsafeNclNativeMethods.HttpApi.HTTP_REQUEST_V2*)memoryBlob.RequestBlob; _user = AuthenticationManager.GetUser(requestV2->pRequestInfo); @@ -250,7 +250,7 @@ namespace Microsoft.Net.Server } } - public IDictionary Headers + public HeaderCollection Headers { get { return _headers; } } @@ -424,17 +424,6 @@ namespace Microsoft.Net.Server return UnsafeNclNativeMethods.HttpApi.GetKnownVerb(RequestBuffer, OriginalBlobAddress); } - // TODO: We need an easier to user header collection that has this built in - internal string GetHeader(string headerName) - { - string[] values; - if (Headers.TryGetValue(headerName, out values)) - { - return string.Join(", ", values); - } - return string.Empty; - } - // Populates the client certificate. The result may be null if there is no client cert. // TODO: Does it make sense for this to be invoked multiple times (e.g. renegotiate)? Client and server code appear to // enable this, but it's unclear what Http.Sys would do. diff --git a/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs b/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs index 3864652cef..98b9c99cb6 100644 --- a/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs +++ b/src/Microsoft.Net.Server/RequestProcessing/RequestContext.cs @@ -179,28 +179,28 @@ namespace Microsoft.Net.Server } // Connection: Upgrade (some odd clients send Upgrade,KeepAlive) - string connection = Request.GetHeader(HttpKnownHeaderNames.Connection); + string connection = Request.Headers[HttpKnownHeaderNames.Connection] ?? string.Empty; if (connection.IndexOf(HttpKnownHeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase) < 0) { return false; } // Upgrade: websocket - string upgrade = Request.GetHeader(HttpKnownHeaderNames.Upgrade); + string upgrade = Request.Headers[HttpKnownHeaderNames.Upgrade]; if (!string.Equals(WebSocketHelpers.WebSocketUpgradeToken, upgrade, StringComparison.OrdinalIgnoreCase)) { return false; } // Sec-WebSocket-Version: 13 - string version = Request.GetHeader(HttpKnownHeaderNames.SecWebSocketVersion); + string version = Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; if (!string.Equals(WebSocketConstants.SupportedProtocolVersion, version, StringComparison.OrdinalIgnoreCase)) { return false; } // Sec-WebSocket-Key: {base64string} - string key = Request.GetHeader(HttpKnownHeaderNames.SecWebSocketKey); + string key = Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; if (!WebSocketHelpers.IsValidWebSocketKey(key)) { return false; @@ -229,28 +229,28 @@ namespace Microsoft.Net.Server } // Connection: Upgrade (some odd clients send Upgrade,KeepAlive) - string connection = Request.GetHeader(HttpKnownHeaderNames.Connection); + string connection = Request.Headers[HttpKnownHeaderNames.Connection] ?? string.Empty; if (connection.IndexOf(HttpKnownHeaderNames.Upgrade, StringComparison.OrdinalIgnoreCase) < 0) { throw new InvalidOperationException("The Connection header is invalid: " + connection); } // Upgrade: websocket - string upgrade = Request.GetHeader(HttpKnownHeaderNames.Upgrade); + string upgrade = Request.Headers[HttpKnownHeaderNames.Upgrade]; if (!string.Equals(WebSocketHelpers.WebSocketUpgradeToken, upgrade, StringComparison.OrdinalIgnoreCase)) { throw new InvalidOperationException("The Upgrade header is invalid: " + upgrade); } // Sec-WebSocket-Version: 13 - string version = Request.GetHeader(HttpKnownHeaderNames.SecWebSocketVersion); + string version = Request.Headers[HttpKnownHeaderNames.SecWebSocketVersion]; if (!string.Equals(WebSocketConstants.SupportedProtocolVersion, version, StringComparison.OrdinalIgnoreCase)) { throw new InvalidOperationException("The Sec-WebSocket-Version header is invalid or not supported: " + version); } // Sec-WebSocket-Key: {base64string} - string key = Request.GetHeader(HttpKnownHeaderNames.SecWebSocketKey); + string key = Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; if (!WebSocketHelpers.IsValidWebSocketKey(key)) { throw new InvalidOperationException("The Sec-WebSocket-Key header is invalid: " + upgrade); @@ -317,29 +317,22 @@ namespace Microsoft.Net.Server { try { - // TODO: We need a better header collection API. ValidateWebSocketRequest(); - string subProtocols = string.Empty; - string[] values; - if (Request.Headers.TryGetValue(HttpKnownHeaderNames.SecWebSocketProtocol, out values)) - { - subProtocols = string.Join(", ", values); - } - + var subProtocols = Request.Headers.GetValues(HttpKnownHeaderNames.SecWebSocketProtocol); bool shouldSendSecWebSocketProtocolHeader = WebSocketHelpers.ProcessWebSocketProtocolHeader(subProtocols, subProtocol); if (shouldSendSecWebSocketProtocolHeader) { - Response.Headers[HttpKnownHeaderNames.SecWebSocketProtocol] = new[] { subProtocol }; + Response.Headers[HttpKnownHeaderNames.SecWebSocketProtocol] = subProtocol; } // negotiate the websocket key return value - string secWebSocketKey = Request.Headers[HttpKnownHeaderNames.SecWebSocketKey].First(); + string secWebSocketKey = Request.Headers[HttpKnownHeaderNames.SecWebSocketKey]; string secWebSocketAccept = WebSocketHelpers.GetSecWebSocketAcceptString(secWebSocketKey); - Response.Headers.Add(HttpKnownHeaderNames.Connection, new[] { HttpKnownHeaderNames.Upgrade }); - Response.Headers.Add(HttpKnownHeaderNames.Upgrade, new[] { WebSocketHelpers.WebSocketUpgradeToken }); - Response.Headers.Add(HttpKnownHeaderNames.SecWebSocketAccept, new[] { secWebSocketAccept }); + Response.Headers.AppendValues(HttpKnownHeaderNames.Connection, HttpKnownHeaderNames.Upgrade); + Response.Headers.AppendValues(HttpKnownHeaderNames.Upgrade, WebSocketHelpers.WebSocketUpgradeToken); + Response.Headers.AppendValues(HttpKnownHeaderNames.SecWebSocketAccept, secWebSocketAccept); Stream opaqueStream = await UpgradeAsync(); diff --git a/src/Microsoft.Net.Server/RequestProcessing/RequestHeaders.cs b/src/Microsoft.Net.Server/RequestProcessing/RequestHeaders.cs index 46f432f10b..113f5bc3fd 100644 --- a/src/Microsoft.Net.Server/RequestProcessing/RequestHeaders.cs +++ b/src/Microsoft.Net.Server/RequestProcessing/RequestHeaders.cs @@ -102,17 +102,27 @@ namespace Microsoft.Net.Server } } - bool IDictionary.ContainsKey(string key) + public bool ContainsKey(string key) { return PropertiesContainsKey(key) || Extra.ContainsKey(key); } - ICollection IDictionary.Keys + public ICollection Keys { get { return PropertiesKeys().Concat(Extra.Keys).ToArray(); } } - bool IDictionary.Remove(string key) + ICollection IDictionary.Values + { + get { return PropertiesValues().Concat(Extra.Values).ToArray(); } + } + + public int Count + { + get { return PropertiesKeys().Count() + Extra.Count; } + } + + public bool Remove(string key) { // Although this is a mutating operation, Extra is used instead of StrongExtra, // because if a real dictionary has not been allocated the default behavior of the @@ -120,16 +130,11 @@ namespace Microsoft.Net.Server return PropertiesTryRemove(key) || Extra.Remove(key); } - bool IDictionary.TryGetValue(string key, out string[] value) + public bool TryGetValue(string key, out string[] value) { return PropertiesTryGetValue(key, out value) || Extra.TryGetValue(key, out value); } - ICollection IDictionary.Values - { - get { return PropertiesValues().Concat(Extra.Values).ToArray(); } - } - void ICollection>.Add(KeyValuePair item) { ((IDictionary)this).Add(item.Key, item.Value); @@ -155,11 +160,6 @@ namespace Microsoft.Net.Server PropertiesEnumerable().Concat(Extra).ToArray().CopyTo(array, arrayIndex); } - int ICollection>.Count - { - get { return PropertiesKeys().Count() + Extra.Count; } - } - bool ICollection>.IsReadOnly { get { return false; } diff --git a/src/Microsoft.Net.Server/RequestProcessing/Response.cs b/src/Microsoft.Net.Server/RequestProcessing/Response.cs index d803bb4a23..ec5eeecbf4 100644 --- a/src/Microsoft.Net.Server/RequestProcessing/Response.cs +++ b/src/Microsoft.Net.Server/RequestProcessing/Response.cs @@ -38,7 +38,7 @@ namespace Microsoft.Net.Server private static readonly string[] ZeroContentLength = new[] { "0" }; private ResponseState _responseState; - private IDictionary _headers; + private HeaderCollection _headers; private string _reasonPhrase; private ResponseStream _nativeStream; private long _contentLength; @@ -53,7 +53,7 @@ namespace Microsoft.Net.Server // TODO: Verbose log _requestContext = httpContext; _nativeResponse = new UnsafeNclNativeMethods.HttpApi.HTTP_RESPONSE_V2(); - _headers = new Dictionary(StringComparer.OrdinalIgnoreCase); + _headers = new HeaderCollection(); _boundaryType = BoundaryType.None; _nativeResponse.Response_V1.StatusCode = (ushort)HttpStatusCode.OK; _nativeResponse.Response_V1.Version.MajorVersion = 1; @@ -156,17 +156,9 @@ namespace Microsoft.Net.Server return true; } - public IDictionary Headers + public HeaderCollection Headers { get { return _headers; } - set - { - if (value == null) - { - throw new ArgumentNullException("value"); - } - _headers = value; - } } internal long CalculatedLength @@ -214,11 +206,11 @@ namespace Microsoft.Net.Server if (value.Value == 0) { - Headers[HttpKnownHeaderNames.ContentLength] = ZeroContentLength; + ((IDictionary)Headers)[HttpKnownHeaderNames.ContentLength] = ZeroContentLength; } else { - Headers[HttpKnownHeaderNames.ContentLength] = new[] { value.Value.ToString(CultureInfo.InvariantCulture) }; + Headers[HttpKnownHeaderNames.ContentLength] = value.Value.ToString(CultureInfo.InvariantCulture); } } } @@ -239,7 +231,7 @@ namespace Microsoft.Net.Server } else { - Headers[HttpKnownHeaderNames.ContentType] = new[] { value }; + Headers[HttpKnownHeaderNames.ContentType] = value; } } } @@ -535,7 +527,7 @@ namespace Microsoft.Net.Server else if (endOfRequest) { // The request is ending without a body, add a Content-Length: 0 header. - Headers[HttpKnownHeaderNames.ContentLength] = new string[] { "0" }; + Headers[HttpKnownHeaderNames.ContentLength] = "0"; _boundaryType = BoundaryType.ContentLength; _contentLength = 0; flags = UnsafeNclNativeMethods.HttpApi.HTTP_FLAGS.NONE; @@ -551,7 +543,7 @@ namespace Microsoft.Net.Server } else { - Headers[HttpKnownHeaderNames.TransferEncoding] = new string[] { "chunked" }; + Headers[HttpKnownHeaderNames.TransferEncoding] = "chunked"; _boundaryType = BoundaryType.Chunked; } @@ -561,7 +553,7 @@ namespace Microsoft.Net.Server } else { - Headers[HttpKnownHeaderNames.ContentLength] = new string[] { "0" }; + Headers[HttpKnownHeaderNames.ContentLength] = "0"; _contentLength = 0; _boundaryType = BoundaryType.ContentLength; } @@ -583,7 +575,7 @@ namespace Microsoft.Net.Server { if (Request.ProtocolVersion.Minor == 0 && !keepAliveSet) { - Headers[HttpKnownHeaderNames.KeepAlive] = new string[] { "true" }; + Headers[HttpKnownHeaderNames.KeepAlive] = "true"; } } return flags; diff --git a/src/Microsoft.Net.WebSockets/WebSocketHelpers.cs b/src/Microsoft.Net.WebSockets/WebSocketHelpers.cs index 026f295077..fc8a830408 100644 --- a/src/Microsoft.Net.WebSockets/WebSocketHelpers.cs +++ b/src/Microsoft.Net.WebSockets/WebSocketHelpers.cs @@ -22,10 +22,12 @@ //------------------------------------------------------------------------------ using System; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Diagnostics.Contracts; using System.Globalization; using System.IO; +using System.Linq; using System.Net.WebSockets; using System.Runtime.CompilerServices; using System.Security.Cryptography; @@ -132,9 +134,9 @@ namespace Microsoft.Net.WebSockets } // return value here signifies if a Sec-WebSocket-Protocol header should be returned by the server. - public static bool ProcessWebSocketProtocolHeader(string clientSecWebSocketProtocol, string subProtocol) + public static bool ProcessWebSocketProtocolHeader(IEnumerable clientSecWebSocketProtocols, string subProtocol) { - if (string.IsNullOrEmpty(clientSecWebSocketProtocol)) + if (clientSecWebSocketProtocols == null || !clientSecWebSocketProtocols.Any()) { // client hasn't specified any Sec-WebSocket-Protocol header if (subProtocol != null) @@ -158,14 +160,10 @@ namespace Microsoft.Net.WebSockets // here, we know that the client has specified something, it's not empty // and the server has specified exactly one protocol - string[] requestProtocols = clientSecWebSocketProtocol.Split(new char[] { ',' }, - StringSplitOptions.RemoveEmptyEntries); - // client specified protocols, serverOptions has exactly 1 non-empty entry. Check that // this exists in the list the client specified. - for (int i = 0; i < requestProtocols.Length; i++) + foreach (var currentRequestProtocol in clientSecWebSocketProtocols) { - string currentRequestProtocol = requestProtocols[i].Trim(); if (string.Compare(subProtocol, currentRequestProtocol, StringComparison.OrdinalIgnoreCase) == 0) { return true; @@ -174,7 +172,7 @@ namespace Microsoft.Net.WebSockets throw new WebSocketException(WebSocketError.UnsupportedProtocol, SR.GetString(SR.net_WebSockets_AcceptUnsupportedProtocol, - clientSecWebSocketProtocol, + string.Join(", ", clientSecWebSocketProtocols), subProtocol)); } diff --git a/test/Microsoft.Net.Server.FunctionalTests/OpaqueUpgradeTests.cs b/test/Microsoft.Net.Server.FunctionalTests/OpaqueUpgradeTests.cs index 4b4d82bee6..cbf99ec975 100644 --- a/test/Microsoft.Net.Server.FunctionalTests/OpaqueUpgradeTests.cs +++ b/test/Microsoft.Net.Server.FunctionalTests/OpaqueUpgradeTests.cs @@ -25,7 +25,7 @@ namespace Microsoft.Net.Server byte[] body = Encoding.UTF8.GetBytes("Hello World"); context.Response.Body.Write(body, 0, body.Length); - context.Response.Headers["Upgrade"] = new[] { "WebSocket" }; // Win8.1 blocks anything but WebSocket + context.Response.Headers["Upgrade"] = "WebSocket"; // Win8.1 blocks anything but WebSocket Assert.ThrowsAsync(async () => await context.UpgradeAsync()); context.Dispose(); HttpResponseMessage response = await clientTask; @@ -44,7 +44,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); Assert.True(context.IsUpgradableRequest); - context.Response.Headers["Upgrade"] = new[] { "WebSocket" }; // Win8.1 blocks anything but WebSocket + context.Response.Headers["Upgrade"] = "WebSocket"; // Win8.1 blocks anything but WebSocket Stream serverStream = await context.UpgradeAsync(); Assert.True(serverStream.CanRead); Assert.True(serverStream.CanWrite); @@ -86,7 +86,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); Assert.True(context.IsUpgradableRequest); - context.Response.Headers["Upgrade"] = new[] { "WebSocket" }; // Win8.1 blocks anything but WebSocket + context.Response.Headers["Upgrade"] = "WebSocket"; // Win8.1 blocks anything but WebSocket Stream serverStream = await context.UpgradeAsync(); Stream clientStream = await clientTask; diff --git a/test/Microsoft.Net.Server.FunctionalTests/RequestHeaderTests.cs b/test/Microsoft.Net.Server.FunctionalTests/RequestHeaderTests.cs index 7a4122db53..fbc6cf55d5 100644 --- a/test/Microsoft.Net.Server.FunctionalTests/RequestHeaderTests.cs +++ b/test/Microsoft.Net.Server.FunctionalTests/RequestHeaderTests.cs @@ -25,9 +25,11 @@ namespace Microsoft.Net.Server // NOTE: The System.Net client only sends the Connection: keep-alive header on the first connection per service-point. // Assert.Equal(2, requestHeaders.Count); // Assert.Equal("Keep-Alive", requestHeaders.Get("Connection")); - Assert.Equal("localhost:8080", requestHeaders["Host"].First()); + Assert.Equal("localhost:8080", requestHeaders["Host"]); string[] values; Assert.False(requestHeaders.TryGetValue("Accept", out values)); + Assert.False(requestHeaders.ContainsKey("Accept")); + Assert.Null(requestHeaders["Accept"]); context.Dispose(); string response = await responseTask; @@ -46,13 +48,15 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var requestHeaders = context.Request.Headers; Assert.Equal(4, requestHeaders.Count); - Assert.Equal("localhost:8080", requestHeaders["Host"].First()); - Assert.Equal("close", requestHeaders["Connection"].First()); - Assert.Equal(1, requestHeaders["Custom-Header"].Length); + Assert.Equal("localhost:8080", requestHeaders["Host"]); + Assert.Equal(new[] { "localhost:8080" }, requestHeaders.GetValues("Host")); + Assert.Equal("close", requestHeaders["Connection"]); + Assert.Equal(new[] { "close" }, requestHeaders.GetValues("Connection")); // Apparently Http.Sys squashes request headers together. - Assert.Equal("custom1, and custom2, custom3", requestHeaders["Custom-Header"].First()); - Assert.Equal(1, requestHeaders["Spacer-Header"].Length); - Assert.Equal("spacervalue, spacervalue", requestHeaders["Spacer-Header"].First()); + Assert.Equal("custom1, and custom2, custom3", requestHeaders["Custom-Header"]); + Assert.Equal(new[] { "custom1", "and custom2", "custom3" }, requestHeaders.GetValues("Custom-Header")); + Assert.Equal("spacervalue, spacervalue", requestHeaders["Spacer-Header"]); + Assert.Equal(new[] { "spacervalue", "spacervalue" }, requestHeaders.GetValues("Spacer-Header")); context.Dispose(); await responseTask; diff --git a/test/Microsoft.Net.Server.FunctionalTests/ResponseBodyTests.cs b/test/Microsoft.Net.Server.FunctionalTests/ResponseBodyTests.cs index 25a125372f..30b32dcc07 100644 --- a/test/Microsoft.Net.Server.FunctionalTests/ResponseBodyTests.cs +++ b/test/Microsoft.Net.Server.FunctionalTests/ResponseBodyTests.cs @@ -45,7 +45,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Request.Headers["transfeR-Encoding"] = new[] { " CHunked " }; + context.Request.Headers["transfeR-Encoding"] = " CHunked "; Stream stream = context.Response.Body; stream.EndWrite(stream.BeginWrite(new byte[10], 0, 10, null, null)); stream.Write(new byte[10], 0, 10); @@ -70,7 +70,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { " 30 " }; + context.Response.Headers["Content-lenGth"] = " 30 "; Stream stream = context.Response.Body; stream.EndWrite(stream.BeginWrite(new byte[10], 0, 10, null, null)); stream.Write(new byte[10], 0, 10); @@ -132,7 +132,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { " 20 " }; + context.Response.Headers["Content-lenGth"] = " 20 "; context.Response.Body.Write(new byte[5], 0, 5); context.Dispose(); @@ -148,7 +148,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { " 10 " }; + context.Response.Headers["Content-lenGth"] = " 10 "; context.Response.Body.Write(new byte[5], 0, 5); Assert.Throws(() => context.Response.Body.Write(new byte[6], 0, 6)); context.Dispose(); @@ -165,7 +165,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { " 10 " }; + context.Response.Headers["Content-lenGth"] = " 10 "; context.Response.Body.Write(new byte[10], 0, 10); Assert.Throws(() => context.Response.Body.Write(new byte[6], 0, 6)); context.Dispose(); diff --git a/test/Microsoft.Net.Server.FunctionalTests/ResponseHeaderTests.cs b/test/Microsoft.Net.Server.FunctionalTests/ResponseHeaderTests.cs index 4411ba1d74..d0290f0693 100644 --- a/test/Microsoft.Net.Server.FunctionalTests/ResponseHeaderTests.cs +++ b/test/Microsoft.Net.Server.FunctionalTests/ResponseHeaderTests.cs @@ -44,7 +44,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders["WWW-Authenticate"] = new string[] { "custom1" }; + responseHeaders["WWW-Authenticate"] = "custom1"; context.Dispose(); // HttpClient would merge the headers no matter what @@ -68,7 +68,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders["WWW-Authenticate"] = new string[] { "custom1, and custom2", "custom3" }; + responseHeaders.SetValues("WWW-Authenticate", "custom1, and custom2", "custom3"); context.Dispose(); // HttpClient would merge the headers no matter what @@ -92,7 +92,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders["Custom-Header1"] = new string[] { "custom1, and custom2", "custom3" }; + responseHeaders.SetValues("Custom-Header1", "custom1, and custom2", "custom3"); context.Dispose(); // HttpClient would merge the headers no matter what @@ -115,7 +115,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders["Connection"] = new string[] { "Close" }; + responseHeaders["Connection"] = "Close"; context.Dispose(); HttpResponseMessage response = await responseTask; @@ -198,7 +198,7 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders["Transfer-Encoding"] = new string[] { "chunked" }; + responseHeaders["Transfer-Encoding"] = "chunked"; await context.Response.Body.WriteAsync(new byte[10], 0, 10); context.Dispose(); @@ -223,8 +223,8 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders.Add("Custom1", new string[] { "value1a", "value1b" }); - responseHeaders.Add("Custom2", new string[] { "value2a, value2b" }); + responseHeaders.SetValues("Custom1", "value1a", "value1b"); + responseHeaders.SetValues("Custom2", "value2a, value2b"); var body = context.Response.Body; body.Flush(); Assert.Throws(() => context.Response.StatusCode = 404); @@ -254,12 +254,12 @@ namespace Microsoft.Net.Server var context = await server.GetContextAsync(); var responseHeaders = context.Response.Headers; - responseHeaders.Add("Custom1", new string[] { "value1a", "value1b" }); - responseHeaders.Add("Custom2", new string[] { "value2a, value2b" }); + responseHeaders.SetValues("Custom1", "value1a", "value1b"); + responseHeaders.SetValues("Custom2", "value2a, value2b"); var body = context.Response.Body; await body.FlushAsync(); Assert.Throws(() => context.Response.StatusCode = 404); - responseHeaders.Add("Custom3", new string[] { "value3a, value3b", "value3c" }); // Ignored + responseHeaders.SetValues("Custom3", "value3a, value3b", "value3c"); // Ignored context.Dispose(); diff --git a/test/Microsoft.Net.Server.FunctionalTests/ResponseSendFileTests.cs b/test/Microsoft.Net.Server.FunctionalTests/ResponseSendFileTests.cs index ff9fb85c9c..3714c56b98 100644 --- a/test/Microsoft.Net.Server.FunctionalTests/ResponseSendFileTests.cs +++ b/test/Microsoft.Net.Server.FunctionalTests/ResponseSendFileTests.cs @@ -4,9 +4,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; -using System.Net; using System.Net.Http; -using System.Text; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -91,7 +89,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Transfer-EncodinG"] = new[] { "CHUNKED" }; + context.Response.Headers["Transfer-EncodinG"] = "CHUNKED"; await context.Response.SendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); context.Dispose(); @@ -112,7 +110,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Transfer-EncodinG"] = new[] { "CHUNKED" }; + context.Response.Headers["Transfer-EncodinG"] = "CHUNKED"; await context.Response.SendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); await context.Response.SendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); context.Dispose(); @@ -206,7 +204,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { FileLength.ToString() }; + context.Response.Headers["Content-lenGth"] = FileLength.ToString(); await context.Response.SendFileAsync(AbsoluteFilePath, 0, null, CancellationToken.None); HttpResponseMessage response = await responseTask; @@ -227,7 +225,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { "10" }; + context.Response.Headers["Content-lenGth"] = "10"; await context.Response.SendFileAsync(AbsoluteFilePath, 0, 10, CancellationToken.None); context.Dispose(); @@ -249,7 +247,7 @@ namespace Microsoft.Net.Server Task responseTask = SendRequestAsync(Address); var context = await server.GetContextAsync(); - context.Response.Headers["Content-lenGth"] = new[] { "0" }; + context.Response.Headers["Content-lenGth"] = "0"; await context.Response.SendFileAsync(AbsoluteFilePath, 0, 0, CancellationToken.None); context.Dispose();