From e4671392ecd9c4c23c2d65492c29f322f9610efe Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Sat, 2 Dec 2017 17:01:53 -0800 Subject: [PATCH] Fix #1140 by plumbing WebSocketOptions up to HttpOptions (#1174) --- .../HubConnectionBuilderHttpExtensions.cs | 25 +++++++++++-- .../HttpOptions.cs | 13 ++++++- .../WebSocketsTransport.cs | 3 ++ .../HubConnectionTests.cs | 35 ++++++++++++++++++- .../Hubs.cs | 5 +++ .../HttpConnectionTests.cs | 14 ++++---- 6 files changed, 84 insertions(+), 11 deletions(-) diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs index cdab76ef41..1333f83380 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionBuilderHttpExtensions.cs @@ -1,10 +1,11 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; using System.Collections.ObjectModel; using System.Net.Http; +using System.Net.WebSockets; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Client.Http; @@ -17,6 +18,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public static readonly string HttpMessageHandlerKey = "HttpMessageHandler"; public static readonly string HeadersKey = "Headers"; public static readonly string JwtBearerTokenFactoryKey = "JwtBearerTokenFactory"; + public static readonly string WebSocketOptionsKey = "WebSocketOptions"; public static IHubConnectionBuilder WithUrl(this IHubConnectionBuilder hubConnectionBuilder, string url) { @@ -42,7 +44,8 @@ namespace Microsoft.AspNetCore.SignalR.Client { HttpMessageHandler = hubConnectionBuilder.GetMessageHandler(), Headers = headers != null ? new ReadOnlyDictionary(headers) : null, - JwtBearerTokenFactory = hubConnectionBuilder.GetJwtBearerTokenFactory() + JwtBearerTokenFactory = hubConnectionBuilder.GetJwtBearerTokenFactory(), + WebSocketOptions = hubConnectionBuilder.GetWebSocketOptions(), }; return new HttpConnection(url, @@ -96,6 +99,18 @@ namespace Microsoft.AspNetCore.SignalR.Client return hubConnectionBuilder; } + public static IHubConnectionBuilder WithWebSocketOptions(this IHubConnectionBuilder hubConnectionBuilder, Action configureWebSocketOptions) + { + if (configureWebSocketOptions == null) + { + throw new ArgumentNullException(nameof(configureWebSocketOptions)); + } + + hubConnectionBuilder.AddSetting(WebSocketOptionsKey, configureWebSocketOptions); + + return hubConnectionBuilder; + } + public static TransportType GetTransport(this IHubConnectionBuilder hubConnectionBuilder) { if (hubConnectionBuilder.TryGetSetting(TransportTypeKey, out var transportType)) @@ -131,5 +146,11 @@ namespace Microsoft.AspNetCore.SignalR.Client return null; } + + public static Action GetWebSocketOptions(this IHubConnectionBuilder hubConnectionBuilder) + { + hubConnectionBuilder.TryGetSetting>(WebSocketOptionsKey, out var webSocketOptions); + return webSocketOptions; + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs index 7ccfacef2c..9dc47da2a2 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs @@ -1,9 +1,10 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; using System.Collections.Generic; using System.Net.Http; +using System.Net.WebSockets; namespace Microsoft.AspNetCore.Sockets.Client.Http { @@ -12,5 +13,15 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http public HttpMessageHandler HttpMessageHandler { get; set; } public IReadOnlyCollection> Headers { get; set; } public Func JwtBearerTokenFactory { get; set; } + + /// + /// Gets or sets a delegate that will be invoked with the object used + /// by the to configure the WebSocket. + /// + /// + /// This delegate is invoked after headers from and the JWT bearer token from + /// has been applied. + /// + public Action WebSocketOptions { get; set; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index e4ee9a1fb1..fdaa2b64c0 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -36,6 +36,7 @@ namespace Microsoft.AspNetCore.Sockets.Client public WebSocketsTransport(HttpOptions httpOptions, ILoggerFactory loggerFactory) { _webSocket = new ClientWebSocket(); + if (httpOptions?.Headers != null) { foreach (var header in httpOptions.Headers) @@ -49,6 +50,8 @@ namespace Microsoft.AspNetCore.Sockets.Client _webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.JwtBearerTokenFactory()}"); } + httpOptions?.WebSocketOptions?.Invoke(_webSocket.Options); + _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 594d674bc3..e8b7ef1525 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -6,8 +6,8 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; @@ -622,6 +622,39 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task WebSocketOptionsAreApplied() + { + using (StartLog(out var loggerFactory)) + { + // System.Net has a TransportType type which means we need to fully-qualify this rather than 'use' the namespace + var cookieJar = new System.Net.CookieContainer(); + cookieJar.Add(new System.Net.Cookie("Foo", "Bar", "/", new Uri(_serverFixture.Url).Host)); + + var hubConnection = new HubConnectionBuilder() + .WithUrl(_serverFixture.Url + "/default") + .WithTransport(TransportType.WebSockets) + .WithLoggerFactory(loggerFactory) + .WithWebSocketOptions(options => options.Cookies = cookieJar) + .Build(); + try + { + await hubConnection.StartAsync().OrTimeout(); + var cookieValue = await hubConnection.InvokeAsync("GetCookieValue", new object[] { "Foo" }).OrTimeout(); + Assert.Equal("Bar", cookieValue); + } + catch (Exception ex) + { + loggerFactory.CreateLogger().LogError(ex, "Exception from test"); + throw; + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + } + } + } + public static IEnumerable HubProtocolsAndTransportsAndHubPaths { get diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 0fc7da25cf..c17318a31c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -39,6 +39,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests var headers = Context.Connection.GetHttpContext().Request.Headers; return headerNames.Select(h => (string)headers[h]); } + + public string GetCookieValue(string cookieName) + { + return Context.Connection.GetHttpContext().Request.Cookies[cookieName]; + } } public class DynamicTestHub : DynamicHub diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 588b504484..89cbb5c8e2 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -6,8 +6,8 @@ using System.Net; using System.Net.Http; using System.Text; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Client.Http; @@ -285,7 +285,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); var onReceivedInvoked = false; - connection.OnReceived( _ => + connection.OnReceived(_ => { onReceivedInvoked = true; return Task.CompletedTask; @@ -432,10 +432,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object }); - connection.OnReceived( _ => - { - throw new OperationCanceledException(); - }); + connection.OnReceived(_ => + { + throw new OperationCanceledException(); + }); await connection.StartAsync(); channel.Writer.TryWrite(Array.Empty()); @@ -960,7 +960,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests [InlineData("http://fakeuri.org/endpoint", "http://fakeuri.org/endpoint/negotiate")] [InlineData("http://fakeuri.org/endpoint/", "http://fakeuri.org/endpoint/negotiate")] [InlineData("http://fakeuri.org/endpoint?q=1/0", "http://fakeuri.org/endpoint/negotiate?q=1/0")] - public async Task query(string requested, string expectedNegotiate) + public async Task CorrectlyHandlesQueryStringWhenAppendingNegotiateToUrl(string requested, string expectedNegotiate) { var mockHttpHandler = new Mock(); mockHttpHandler.Protected()