diff --git a/src/Hosting/TestHost/src/WebSocketClient.cs b/src/Hosting/TestHost/src/WebSocketClient.cs index d312c0ebbd..9f3e0500be 100644 --- a/src/Hosting/TestHost/src/WebSocketClient.cs +++ b/src/Hosting/TestHost/src/WebSocketClient.cs @@ -4,12 +4,14 @@ using System; using System.Collections.Generic; using System.IO; +using System.Linq; using System.Net.WebSockets; using System.Security.Cryptography; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; +using Microsoft.Net.Http.Headers; namespace Microsoft.AspNetCore.TestHost { @@ -72,10 +74,15 @@ namespace Microsoft.AspNetCore.TestHost request.PathBase = _pathBase; } request.QueryString = QueryString.FromUriComponent(uri); - request.Headers.Add("Connection", new string[] { "Upgrade" }); - request.Headers.Add("Upgrade", new string[] { "websocket" }); - request.Headers.Add("Sec-WebSocket-Version", new string[] { "13" }); - request.Headers.Add("Sec-WebSocket-Key", new string[] { CreateRequestKey() }); + request.Headers.Add(HeaderNames.Connection, new string[] { "Upgrade" }); + request.Headers.Add(HeaderNames.Upgrade, new string[] { "websocket" }); + request.Headers.Add(HeaderNames.SecWebSocketVersion, new string[] { "13" }); + request.Headers.Add(HeaderNames.SecWebSocketKey, new string[] { CreateRequestKey() }); + if (SubProtocols.Any()) + { + request.Headers.Add(HeaderNames.SecWebSocketProtocol, SubProtocols.ToArray()); + } + request.Body = Stream.Null; // WebSocket diff --git a/src/Hosting/TestHost/test/TestClientTests.cs b/src/Hosting/TestHost/test/TestClientTests.cs index 322ab54a6f..aa1082b529 100644 --- a/src/Hosting/TestHost/test/TestClientTests.cs +++ b/src/Hosting/TestHost/test/TestClientTests.cs @@ -16,6 +16,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; +using Microsoft.Net.Http.Headers; using Xunit; namespace Microsoft.AspNetCore.TestHost @@ -172,6 +173,7 @@ namespace Microsoft.AspNetCore.TestHost { if (ctx.WebSockets.IsWebSocketRequest) { + Assert.False(ctx.Request.Headers.ContainsKey(HeaderNames.SecWebSocketProtocol)); var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); var receiveArray = new byte[1024]; while (true) @@ -232,6 +234,58 @@ namespace Microsoft.AspNetCore.TestHost clientSocket.Dispose(); } + [Fact] + public async Task WebSocketSubProtocolsWorks() + { + // Arrange + RequestDelegate appDelegate = async ctx => + { + if (ctx.WebSockets.IsWebSocketRequest) + { + if (ctx.WebSockets.WebSocketRequestedProtocols.Contains("alpha") && + ctx.WebSockets.WebSocketRequestedProtocols.Contains("bravo")) + { + // according to rfc6455, the "server needs to include the same field and one of the selected subprotocol values" + // however, this isn't enforced by either our server or client so it's possible to accept an arbitrary protocol. + // Done here to demonstrate not "correct" behaviour, simply to show it's possible. Other clients may not allow this. + var websocket = await ctx.WebSockets.AcceptWebSocketAsync("charlie"); + await websocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Normal Closure", CancellationToken.None); + } + else + { + var subprotocols = ctx.WebSockets.WebSocketRequestedProtocols.Any() + ? string.Join(", ", ctx.WebSockets.WebSocketRequestedProtocols) + : ""; + var closeReason = "Unexpected subprotocols: " + subprotocols; + var websocket = await ctx.WebSockets.AcceptWebSocketAsync(); + await websocket.CloseAsync(WebSocketCloseStatus.InternalServerError, closeReason, CancellationToken.None); + } + } + }; + var builder = new WebHostBuilder() + .Configure(app => + { + app.Run(appDelegate); + }); + var server = new TestServer(builder); + + // Act + var client = server.CreateWebSocketClient(); + client.SubProtocols.Add("alpha"); + client.SubProtocols.Add("bravo"); + var clientSocket = await client.ConnectAsync(new Uri("wss://localhost"), CancellationToken.None); + var buffer = new byte[1024]; + var result = await clientSocket.ReceiveAsync(new ArraySegment(buffer), CancellationToken.None); + + // Assert + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal("Normal Closure", result.CloseStatusDescription); + Assert.Equal(WebSocketState.CloseReceived, clientSocket.State); + Assert.Equal("charlie", clientSocket.SubProtocol); + + clientSocket.Dispose(); + } + [ConditionalFact] public async Task WebSocketAcceptThrowsWhenCancelled() {