Fix #1140 by plumbing WebSocketOptions up to HttpOptions (#1174)

This commit is contained in:
Andrew Stanton-Nurse 2017-12-02 17:01:53 -08:00 committed by GitHub
parent 3005337a9c
commit e4671392ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 84 additions and 11 deletions

View File

@ -1,10 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Net.Http;
using System.Net.WebSockets;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Client.Http;
@ -17,6 +18,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public static readonly string HttpMessageHandlerKey = "HttpMessageHandler";
public static readonly string HeadersKey = "Headers";
public static readonly string JwtBearerTokenFactoryKey = "JwtBearerTokenFactory";
public static readonly string WebSocketOptionsKey = "WebSocketOptions";
public static IHubConnectionBuilder WithUrl(this IHubConnectionBuilder hubConnectionBuilder, string url)
{
@ -42,7 +44,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
HttpMessageHandler = hubConnectionBuilder.GetMessageHandler(),
Headers = headers != null ? new ReadOnlyDictionary<string, string>(headers) : null,
JwtBearerTokenFactory = hubConnectionBuilder.GetJwtBearerTokenFactory()
JwtBearerTokenFactory = hubConnectionBuilder.GetJwtBearerTokenFactory(),
WebSocketOptions = hubConnectionBuilder.GetWebSocketOptions(),
};
return new HttpConnection(url,
@ -96,6 +99,18 @@ namespace Microsoft.AspNetCore.SignalR.Client
return hubConnectionBuilder;
}
public static IHubConnectionBuilder WithWebSocketOptions(this IHubConnectionBuilder hubConnectionBuilder, Action<ClientWebSocketOptions> configureWebSocketOptions)
{
if (configureWebSocketOptions == null)
{
throw new ArgumentNullException(nameof(configureWebSocketOptions));
}
hubConnectionBuilder.AddSetting(WebSocketOptionsKey, configureWebSocketOptions);
return hubConnectionBuilder;
}
public static TransportType GetTransport(this IHubConnectionBuilder hubConnectionBuilder)
{
if (hubConnectionBuilder.TryGetSetting<TransportType>(TransportTypeKey, out var transportType))
@ -131,5 +146,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
return null;
}
public static Action<ClientWebSocketOptions> GetWebSocketOptions(this IHubConnectionBuilder hubConnectionBuilder)
{
hubConnectionBuilder.TryGetSetting<Action<ClientWebSocketOptions>>(WebSocketOptionsKey, out var webSocketOptions);
return webSocketOptions;
}
}
}

View File

@ -1,9 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.WebSockets;
namespace Microsoft.AspNetCore.Sockets.Client.Http
{
@ -12,5 +13,15 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http
public HttpMessageHandler HttpMessageHandler { get; set; }
public IReadOnlyCollection<KeyValuePair<string, string>> Headers { get; set; }
public Func<string> JwtBearerTokenFactory { get; set; }
/// <summary>
/// Gets or sets a delegate that will be invoked with the <see cref="ClientWebSocketOptions"/> object used
/// by the <see cref="WebSocketsTransport"/> to configure the WebSocket.
/// </summary>
/// <remarks>
/// This delegate is invoked after headers from <see cref="Headers"/> and the JWT bearer token from <see cref="JwtBearerTokenFactory"/>
/// has been applied.
/// </remarks>
public Action<ClientWebSocketOptions> WebSocketOptions { get; set; }
}
}

View File

@ -36,6 +36,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
public WebSocketsTransport(HttpOptions httpOptions, ILoggerFactory loggerFactory)
{
_webSocket = new ClientWebSocket();
if (httpOptions?.Headers != null)
{
foreach (var header in httpOptions.Headers)
@ -49,6 +50,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {httpOptions.JwtBearerTokenFactory()}");
}
httpOptions?.WebSocketOptions?.Invoke(_webSocket.Options);
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<WebSocketsTransport>();
}

View File

@ -6,8 +6,8 @@ using System.Collections.Generic;
using System.Linq;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
@ -622,6 +622,39 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Fact]
public async Task WebSocketOptionsAreApplied()
{
using (StartLog(out var loggerFactory))
{
// System.Net has a TransportType type which means we need to fully-qualify this rather than 'use' the namespace
var cookieJar = new System.Net.CookieContainer();
cookieJar.Add(new System.Net.Cookie("Foo", "Bar", "/", new Uri(_serverFixture.Url).Host));
var hubConnection = new HubConnectionBuilder()
.WithUrl(_serverFixture.Url + "/default")
.WithTransport(TransportType.WebSockets)
.WithLoggerFactory(loggerFactory)
.WithWebSocketOptions(options => options.Cookies = cookieJar)
.Build();
try
{
await hubConnection.StartAsync().OrTimeout();
var cookieValue = await hubConnection.InvokeAsync<string>("GetCookieValue", new object[] { "Foo" }).OrTimeout();
Assert.Equal("Bar", cookieValue);
}
catch (Exception ex)
{
loggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "Exception from test");
throw;
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
}
}
}
public static IEnumerable<object[]> HubProtocolsAndTransportsAndHubPaths
{
get

View File

@ -39,6 +39,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
var headers = Context.Connection.GetHttpContext().Request.Headers;
return headerNames.Select(h => (string)headers[h]);
}
public string GetCookieValue(string cookieName)
{
return Context.Connection.GetHttpContext().Request.Cookies[cookieName];
}
}
public class DynamicTestHub : DynamicHub

View File

@ -6,8 +6,8 @@ using System.Net;
using System.Net.Http;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Client.Tests;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Client.Http;
@ -285,7 +285,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object });
var onReceivedInvoked = false;
connection.OnReceived( _ =>
connection.OnReceived(_ =>
{
onReceivedInvoked = true;
return Task.CompletedTask;
@ -432,10 +432,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null,
httpOptions: new HttpOptions { HttpMessageHandler = mockHttpHandler.Object });
connection.OnReceived( _ =>
{
throw new OperationCanceledException();
});
connection.OnReceived(_ =>
{
throw new OperationCanceledException();
});
await connection.StartAsync();
channel.Writer.TryWrite(Array.Empty<byte>());
@ -960,7 +960,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
[InlineData("http://fakeuri.org/endpoint", "http://fakeuri.org/endpoint/negotiate")]
[InlineData("http://fakeuri.org/endpoint/", "http://fakeuri.org/endpoint/negotiate")]
[InlineData("http://fakeuri.org/endpoint?q=1/0", "http://fakeuri.org/endpoint/negotiate?q=1/0")]
public async Task query(string requested, string expectedNegotiate)
public async Task CorrectlyHandlesQueryStringWhenAppendingNegotiateToUrl(string requested, string expectedNegotiate)
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()