fix #2140 by ensuring the access token flows to WebSocketTransport (#2173)

This commit is contained in:
Andrew Stanton-Nurse 2018-05-01 16:14:24 -07:00 committed by GitHub
parent 295801ac50
commit d711916ad6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 70 additions and 35 deletions

View File

@ -137,7 +137,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
_httpClient = CreateHttpClient();
}
_transportFactory = new DefaultTransportFactory(httpConnectionOptions.Transports, _loggerFactory, _httpClient, httpConnectionOptions);
_transportFactory = new DefaultTransportFactory(httpConnectionOptions.Transports, _loggerFactory, _httpClient, httpConnectionOptions, GetAccessTokenAsync);
_logScope = new ConnectionLogScope();
_scopeDisposable = _logger.BeginScope(_logScope);

View File

@ -1,7 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Net.Http;
using System.Net.Http.Headers;
using System.Threading;

View File

@ -3,8 +3,7 @@
using System;
using System.Net.Http;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.Http.Connections.Client.Internal;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
@ -13,11 +12,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
private readonly HttpClient _httpClient;
private readonly HttpConnectionOptions _httpConnectionOptions;
private readonly Func<Task<string>> _accessTokenProvider;
private readonly HttpTransportType _requestedTransportType;
private readonly ILoggerFactory _loggerFactory;
private static volatile bool _websocketsSupported = true;
public DefaultTransportFactory(HttpTransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient, HttpConnectionOptions httpConnectionOptions)
public DefaultTransportFactory(HttpTransportType requestedTransportType, ILoggerFactory loggerFactory, HttpClient httpClient, HttpConnectionOptions httpConnectionOptions, Func<Task<string>> accessTokenProvider)
{
if (httpClient == null && requestedTransportType != HttpTransportType.WebSockets)
{
@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
_loggerFactory = loggerFactory;
_httpClient = httpClient;
_httpConnectionOptions = httpConnectionOptions;
_accessTokenProvider = accessTokenProvider;
}
public ITransport CreateTransport(HttpTransportType availableServerTransports)
@ -36,7 +37,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
try
{
return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory);
return new WebSocketsTransport(_httpConnectionOptions, _loggerFactory, _accessTokenProvider);
}
catch (PlatformNotSupportedException)
{
@ -46,11 +47,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
if ((availableServerTransports & HttpTransportType.ServerSentEvents & _requestedTransportType) == HttpTransportType.ServerSentEvents)
{
// We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us.
return new ServerSentEventsTransport(_httpClient, _loggerFactory);
}
if ((availableServerTransports & HttpTransportType.LongPolling & _requestedTransportType) == HttpTransportType.LongPolling)
{
// We don't need to give the transport the accessTokenProvider because the HttpClient has a message handler that does the work for us.
return new LongPollingTransport(_httpClient, _loggerFactory);
}

View File

@ -32,12 +32,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
public PipeWriter Output => _transport.Output;
public WebSocketsTransport()
: this(null, null)
{
}
public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory)
public WebSocketsTransport(HttpConnectionOptions httpConnectionOptions, ILoggerFactory loggerFactory, Func<Task<string>> accessTokenProvider)
{
_webSocket = new ClientWebSocket();
@ -79,11 +74,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
_webSocket.Options.UseDefaultCredentials = httpConnectionOptions.UseDefaultCredentials.Value;
}
if (httpConnectionOptions.AccessTokenProvider != null)
{
_accessTokenProvider = httpConnectionOptions.AccessTokenProvider;
}
httpConnectionOptions.WebSocketConfiguration?.Invoke(_webSocket.Options);
_closeTimeout = httpConnectionOptions.CloseTimeout;
@ -94,6 +84,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
_webSocket.Options.SetRequestHeader("X-Requested-With", "XMLHttpRequest");
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<WebSocketsTransport>();
// Ignore the HttpConnectionOptions access token provider. We were given an updated delegate from the HttpConnection.
_accessTokenProvider = accessTokenProvider;
}
public async Task StartAsync(Uri url, TransferFormat transferFormat)
@ -116,10 +109,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
Log.StartTransport(_logger, transferFormat, resolvedUrl);
// We don't need to capture to a local because we never change this delegate.
if (_accessTokenProvider != null)
{
var accessToken = await _accessTokenProvider();
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
if (!string.IsNullOrEmpty(accessToken))
{
_webSocket.Options.SetRequestHeader("Authorization", $"Bearer {accessToken}");
}
}
await _webSocket.ConnectAsync(resolvedUrl, CancellationToken.None);

View File

@ -797,6 +797,34 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory]
[MemberData(nameof(TransportTypes))]
public async Task ClientCanUseJwtBearerTokenForAuthenticationWhenRedirected(HttpTransportType transportType)
{
using (StartVerifableLog(out var loggerFactory, $"{nameof(ClientCanUseJwtBearerTokenForAuthenticationWhenRedirected)}_{transportType}"))
{
var hubConnection = new HubConnectionBuilder()
.WithLoggerFactory(loggerFactory)
.WithUrl(ServerFixture.Url + "/redirect", transportType)
.Build();
try
{
await hubConnection.StartAsync().OrTimeout();
var message = await hubConnection.InvokeAsync<string>(nameof(TestHub.Echo), "Hello, World!").OrTimeout();
Assert.Equal("Hello, World!", message);
}
catch (Exception ex)
{
loggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(TransportTypes))]
public async Task ClientCanSendHeaders(HttpTransportType transportType)

View File

@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.IdentityModel.Tokens;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
@ -69,6 +70,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
await context.Response.WriteAsync(GenerateJwtToken());
return;
}
else if (context.Request.Path.StartsWithSegments("/redirect"))
{
await context.Response.WriteAsync(JsonConvert.SerializeObject(new
{
url = $"{context.Request.Scheme}://{context.Request.Host}/authorizedHub",
accessToken = GenerateJwtToken()
}));
}
});
}

View File

@ -8,7 +8,6 @@ using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.Http.Connections.Client.Internal;
using Microsoft.AspNetCore.SignalR.Tests;
using Moq;

View File

@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[InlineData((HttpTransportType)int.MaxValue)]
public void DefaultTransportFactoryCanBeCreatedWithNoOrUnknownTransportTypeFlags(HttpTransportType transportType)
{
Assert.NotNull(new DefaultTransportFactory(transportType, new LoggerFactory(), new HttpClient(), httpConnectionOptions: null));
Assert.NotNull(new DefaultTransportFactory(transportType, new LoggerFactory(), new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null));
}
[Theory]
@ -33,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void DefaultTransportFactoryCannotBeCreatedWithoutHttpClient(HttpTransportType transportType)
{
var exception = Assert.Throws<ArgumentNullException>(
() => new DefaultTransportFactory(transportType, new LoggerFactory(), httpClient: null, httpConnectionOptions: null));
() => new DefaultTransportFactory(transportType, new LoggerFactory(), httpClient: null, httpConnectionOptions: null, accessTokenProvider: null));
Assert.Equal("httpClient", exception.ParamName);
}
@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[Fact]
public void DefaultTransportFactoryCanBeCreatedWithoutHttpClientIfWebSocketsTransportRequestedExplicitly()
{
new DefaultTransportFactory(HttpTransportType.WebSockets, new LoggerFactory(), httpClient: null, httpConnectionOptions: null);
new DefaultTransportFactory(HttpTransportType.WebSockets, new LoggerFactory(), httpClient: null, httpConnectionOptions: null, accessTokenProvider: null);
}
[ConditionalTheory]
@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[WebSocketsSupportedCondition]
public void DefaultTransportFactoryCreatesRequestedTransportIfAvailable(HttpTransportType requestedTransport, Type expectedTransportType)
{
var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null);
var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null);
Assert.IsType(expectedTransportType,
transportFactory.CreateTransport(AllTransportTypes));
}
@ -64,7 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void DefaultTransportFactoryThrowsIfItCannotCreateRequestedTransport(HttpTransportType requestedTransport)
{
var transportFactory =
new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null);
new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null);
var ex = Assert.Throws<InvalidOperationException>(
() => transportFactory.CreateTransport(~requestedTransport));
@ -76,7 +76,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public void DefaultTransportFactoryCreatesWebSocketsTransportIfAvailable()
{
Assert.IsType<WebSocketsTransport>(
new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null)
new DefaultTransportFactory(AllTransportTypes, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null)
.CreateTransport(AllTransportTypes));
}
@ -88,7 +88,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
if (!TestHelpers.IsWebSocketsSupported())
{
var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null);
var transportFactory = new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null);
Assert.IsType(expectedTransportType,
transportFactory.CreateTransport(AllTransportTypes));
}
@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
if (!TestHelpers.IsWebSocketsSupported())
{
var transportFactory =
new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null);
new DefaultTransportFactory(requestedTransport, loggerFactory: null, httpClient: new HttpClient(), httpConnectionOptions: null, accessTokenProvider: null);
var ex = Assert.Throws<InvalidOperationException>(
() => transportFactory.CreateTransport(AllTransportTypes));

View File

@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
httpOptions.Proxy = Mock.Of<IWebProxy>();
httpOptions.WebSocketConfiguration = options => webSocketsOptions = options;
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: httpOptions, loggerFactory: null);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: httpOptions, loggerFactory: null, accessTokenProvider: null);
Assert.NotNull(webSocketsTransport);
Assert.NotNull(webSocketsOptions);
@ -59,7 +59,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echo"),
TransferFormat.Binary).OrTimeout();
await webSocketsTransport.StopAsync().OrTimeout();
@ -73,7 +73,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/httpheader"),
TransferFormat.Binary).OrTimeout();
@ -101,7 +101,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/httpheader"),
TransferFormat.Binary).OrTimeout();
@ -124,7 +124,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echo"),
TransferFormat.Binary);
webSocketsTransport.Output.Complete();
@ -140,7 +140,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echoAndClose"), transferFormat);
await webSocketsTransport.Output.WriteAsync(new byte[] { 0x42 });
@ -162,7 +162,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
await webSocketsTransport.StartAsync(new Uri(ServerFixture.WebSocketsUrl + "/echo"),
transferFormat).OrTimeout();
@ -180,7 +180,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
using (StartVerifableLog(out var loggerFactory))
{
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory);
var webSocketsTransport = new WebSocketsTransport(httpConnectionOptions: null, loggerFactory: loggerFactory, accessTokenProvider: null);
var exception = await Assert.ThrowsAsync<ArgumentException>(() =>
webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), transferFormat));