diff --git a/src/Microsoft.Net.WebSockets/Constants.cs b/src/Microsoft.Net.WebSockets/Constants.cs index 8fb002913e..6877c70d0b 100644 --- a/src/Microsoft.Net.WebSockets/Constants.cs +++ b/src/Microsoft.Net.WebSockets/Constants.cs @@ -10,7 +10,8 @@ namespace Microsoft.Net.WebSockets { public static class Headers { - public const string WebSocketVersion = "Sec-WebSocket-Version"; + public const string SecWebSocketVersion = "Sec-WebSocket-Version"; + public const string SecWebSocketProtocol = "Sec-WebSocket-Protocol"; public const string SupportedVersion = "13"; } diff --git a/src/Microsoft.Net.WebSockets/WebSocketClient.cs b/src/Microsoft.Net.WebSockets/WebSocketClient.cs index 60636a05f0..ca77944110 100644 --- a/src/Microsoft.Net.WebSockets/WebSocketClient.cs +++ b/src/Microsoft.Net.WebSockets/WebSocketClient.cs @@ -1,5 +1,7 @@ using System; +using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net; using System.Net.WebSockets; using System.Threading; @@ -26,6 +28,13 @@ namespace Microsoft.Net.WebSockets.Client { ReceiveBufferSize = 1024 * 16; KeepAliveInterval = TimeSpan.FromMinutes(2); + SubProtocols = new List(); + } + + public IList SubProtocols + { + get; + private set; } public TimeSpan KeepAliveInterval @@ -64,8 +73,11 @@ namespace Microsoft.Net.WebSockets.Client CancellationTokenRegistration cancellation = cancellationToken.Register(() => request.Abort()); - request.Headers[Constants.Headers.WebSocketVersion] = Constants.Headers.SupportedVersion; - // TODO: Sub-protocols + request.Headers[Constants.Headers.SecWebSocketVersion] = Constants.Headers.SupportedVersion; + if (SubProtocols.Count > 0) + { + request.Headers[Constants.Headers.SecWebSocketProtocol] = string.Join(", ", SubProtocols); + } if (ConfigureRequest != null) { @@ -85,14 +97,19 @@ namespace Microsoft.Net.WebSockets.Client if (response.StatusCode != HttpStatusCode.SwitchingProtocols) { response.Dispose(); - throw new InvalidOperationException("Incomplete handshake"); + throw new InvalidOperationException("Incomplete handshake, invalid status code: " + response.StatusCode); } + // TODO: Validate Sec-WebSocket-Key/Sec-WebSocket-Accept - // TODO: Sub protocol + string subProtocol = response.Headers[Constants.Headers.SecWebSocketProtocol]; + if (!string.IsNullOrEmpty(subProtocol) && !SubProtocols.Contains(subProtocol, StringComparer.OrdinalIgnoreCase)) + { + throw new InvalidOperationException("Incomplete handshake, the server specified an unknown sub-protocol: " + subProtocol); + } Stream stream = response.GetResponseStream(); - return CommonWebSocket.CreateClientWebSocket(stream, null, KeepAliveInterval, ReceiveBufferSize, useZeroMask: UseZeroMask); + return CommonWebSocket.CreateClientWebSocket(stream, subProtocol, KeepAliveInterval, ReceiveBufferSize, useZeroMask: UseZeroMask); } } } diff --git a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs index 91d1045e4c..f4c6b207fe 100644 --- a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs +++ b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs @@ -35,6 +35,32 @@ namespace Microsoft.Net.WebSockets.Test } } + [Fact] + public async Task NegotiateSubProtocol_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + client.SubProtocols.Add("alpha"); + client.SubProtocols.Add("bravo"); + client.SubProtocols.Add("charlie"); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + Assert.Equal("alpha, bravo, charlie", serverContext.Request.Headers["Sec-WebSocket-Protocol"]); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync("Bravo"); + + WebSocket clientSocket = await clientConnect; + Assert.Equal("Bravo", clientSocket.SubProtocol); + clientSocket.Dispose(); + } + } + [Fact] public async Task SendShortData_Success() {