Update acessTokenProvider selecting logic for restarting connection (#8569)

This commit is contained in:
Mikael Mengistu 2019-03-19 10:55:53 -07:00 committed by GitHub
parent 26c487b0c0
commit 6038621630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 123 additions and 18 deletions

View File

@ -3,7 +3,7 @@
package com.microsoft.signalr;
import java.lang.reflect.Array;
import java.io.StringReader;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
@ -11,6 +11,7 @@ import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import com.google.gson.stream.JsonReader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -38,6 +39,7 @@ public class HubConnection {
private List<OnClosedCallback> onClosedCallbackList;
private final boolean skipNegotiate;
private Single<String> accessTokenProvider;
private Single<String> redirectAccessTokenProvider;
private final Map<String, String> headers = new HashMap<>();
private ConnectionState connectionState = null;
private HttpClient httpClient;
@ -180,6 +182,7 @@ public class HubConnection {
logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException());
break;
case INVOCATION:
InvocationMessage invocationMessage = (InvocationMessage) message;
List<InvocationHandler> handlers = this.handlers.get(invocationMessage.getTarget());
if (handlers != null) {
@ -248,17 +251,18 @@ public class HubConnection {
throw new RuntimeException(String.format("Unexpected status code returned from negotiate: %d %s.",
response.getStatusCode(), response.getStatusText()));
}
NegotiateResponse negotiateResponse = new NegotiateResponse(response.getContent());
JsonReader reader = new JsonReader(new StringReader(response.getContent()));
NegotiateResponse negotiateResponse = new NegotiateResponse(reader);
if (negotiateResponse.getError() != null) {
throw new RuntimeException(negotiateResponse.getError());
}
if (negotiateResponse.getAccessToken() != null) {
this.accessTokenProvider = Single.just(negotiateResponse.getAccessToken());
this.redirectAccessTokenProvider = Single.just(negotiateResponse.getAccessToken());
// We know the Single is non blocking in this case
// It's fine to call blockingGet() on it.
String token = this.accessTokenProvider.blockingGet();
String token = this.redirectAccessTokenProvider.blockingGet();
this.headers.put("Authorization", "Bearer " + token);
}
@ -296,21 +300,22 @@ public class HubConnection {
});
stopError = null;
Single<String> negotiate = null;
Single<NegotiateResponse> negotiate = null;
if (!skipNegotiate) {
negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(baseUrl, 0)));
} else {
negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(baseUrl)));
negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(new NegotiateResponse(baseUrl))));
}
CompletableSubject start = CompletableSubject.create();
negotiate.flatMapCompletable(url -> {
negotiate.flatMapCompletable(negotiateResponse -> {
logger.debug("Starting HubConnection.");
if (transport == null) {
Single<String> tokenProvider = negotiateResponse.getAccessToken() != null ? Single.just(negotiateResponse.getAccessToken()) : accessTokenProvider;
switch (transportEnum) {
case LONG_POLLING:
transport = new LongPollingTransport(headers, httpClient, accessTokenProvider);
transport = new LongPollingTransport(headers, httpClient, tokenProvider);
break;
default:
transport = new WebSocketTransport(headers, httpClient);
@ -320,7 +325,7 @@ public class HubConnection {
transport.setOnReceive(this.callback);
transport.setOnClose((message) -> stopConnection(message));
return transport.start(url).andThen(Completable.defer(() -> {
return transport.start(negotiateResponse.getFinalUrl()).andThen(Completable.defer(() -> {
String handshake = HandshakeProtocol.createHandshakeRequestMessage(
new HandshakeRequestMessage(protocol.getName(), protocol.getVersion()));
@ -376,7 +381,7 @@ public class HubConnection {
}, new Date(0), tickRate);
}
private Single<String> startNegotiate(String url, int negotiateAttempts) {
private Single<NegotiateResponse> startNegotiate(String url, int negotiateAttempts) {
if (hubConnectionState != HubConnectionState.DISCONNECTED) {
return Single.just(null);
}
@ -409,8 +414,8 @@ public class HubConnection {
finalUrl = url + "?id=" + response.getConnectionId();
}
}
return Single.just(finalUrl);
response.setFinalUrl(finalUrl);
return Single.just(response);
}
return startNegotiate(response.getRedirectUrl(), negotiateAttempts + 1);
@ -473,6 +478,9 @@ public class HubConnection {
logger.info("HubConnection stopped.");
hubConnectionState = HubConnectionState.DISCONNECTED;
handshakeResponseSubject.onComplete();
redirectAccessTokenProvider = null;
transportEnum = TransportEnum.ALL;
this.headers.remove("Authorization");
} finally {
hubConnectionStateLock.unlock();
}

View File

@ -16,10 +16,10 @@ class NegotiateResponse {
private String redirectUrl;
private String accessToken;
private String error;
private String finalUrl;
public NegotiateResponse(String negotiatePayload) {
public NegotiateResponse(JsonReader reader) {
try {
JsonReader reader = new JsonReader(new StringReader(negotiatePayload));
reader.beginObject();
do {
@ -79,6 +79,10 @@ class NegotiateResponse {
}
}
public NegotiateResponse(String url) {
this.finalUrl = url;
}
public String getConnectionId() {
return connectionId;
}
@ -98,4 +102,12 @@ class NegotiateResponse {
public String getError() {
return error;
}
public String getFinalUrl(){
return finalUrl;
}
public void setFinalUrl(String url) {
this.finalUrl = url;
}
}

View File

@ -1651,6 +1651,48 @@ class HubConnectionTest {
hubConnection.stop();
}
@Test
public void accessTokenProviderReferenceIsKeptAfterNegotiateRedirect() {
AtomicReference<String> token = new AtomicReference<>();
AtomicReference<String> beforeRedirectToken = new AtomicReference<>();
TestHttpClient client = new TestHttpClient()
.on("POST", "http://example.com/negotiate", (req) -> {
beforeRedirectToken.set(req.getHeaders().get("Authorization"));
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}"));
})
.on("POST", "http://testexample.com/negotiate", (req) -> {
token.set(req.getHeaders().get("Authorization"));
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
});
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransportImplementation(transport)
.withHttpClient(client)
.withAccessTokenProvider(Single.just("User Registered Token"))
.build();
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals("Bearer User Registered Token", beforeRedirectToken.get());
assertEquals("Bearer newToken", token.get());
// Clear the tokens to see if they get reset to the proper values
beforeRedirectToken.set("");
token.set("");
// Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect.
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
hubConnection.stop();
assertEquals("Bearer User Registered Token", beforeRedirectToken.get());
assertEquals("Bearer newToken", token.get());
}
@Test
public void accessTokenProviderIsUsedForNegotiate() {
AtomicReference<String> token = new AtomicReference<>();
@ -1702,6 +1744,46 @@ class HubConnectionTest {
assertEquals("Bearer newToken", token.get());
}
@Test
public void authorizationHeaderFromNegotiateGetsClearedAfterStopping() {
AtomicReference<String> token = new AtomicReference<>();
AtomicReference<String> beforeRedirectToken = new AtomicReference<>();
TestHttpClient client = new TestHttpClient()
.on("POST", "http://example.com/negotiate", (req) -> {
beforeRedirectToken.set(req.getHeaders().get("Authorization"));
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}"));
})
.on("POST", "http://testexample.com/negotiate", (req) -> {
token.set(req.getHeaders().get("Authorization"));
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
});
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransportImplementation(transport)
.withHttpClient(client)
.build();
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals("Bearer newToken", token.get());
// Clear the tokens to see if they get reset to the proper values
beforeRedirectToken.set("");
token.set("");
// Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect.
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
hubConnection.stop();
assertNull(beforeRedirectToken.get());
assertEquals("Bearer newToken", token.get());
}
@Test
public void connectionTimesOutIfServerDoesNotSendMessage() {
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");

View File

@ -5,8 +5,11 @@ package com.microsoft.signalr;
import static org.junit.jupiter.api.Assertions.*;
import com.google.gson.stream.JsonReader;
import org.junit.jupiter.api.Test;
import java.io.StringReader;
class NegotiateResponseTest {
@Test
@ -15,7 +18,7 @@ class NegotiateResponseTest {
"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}," +
"{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}," +
"{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}";
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
assertTrue(negotiateResponse.getAvailableTransports().contains("WebSockets"));
assertTrue(negotiateResponse.getAvailableTransports().contains("ServerSentEvents"));
assertTrue(negotiateResponse.getAvailableTransports().contains("LongPolling"));
@ -29,7 +32,7 @@ class NegotiateResponseTest {
String stringNegotiateResponse = "{\"url\":\"www.example.com\"," +
"\"accessToken\":\"some_access_token\"," +
"\"availableTransports\":[]}";
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
assertTrue(negotiateResponse.getAvailableTransports().isEmpty());
assertNull(negotiateResponse.getConnectionId());
assertEquals("some_access_token", negotiateResponse.getAccessToken());
@ -41,7 +44,7 @@ class NegotiateResponseTest {
public void NegotiateResponseIgnoresExtraProperties() {
String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," +
"\"extra\":\"something\"}";
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId());
}
@ -49,7 +52,7 @@ class NegotiateResponseTest {
public void NegotiateResponseIgnoresExtraComplexProperties() {
String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," +
"\"extra\":[\"something\"]}";
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId());
}
}