Update acessTokenProvider selecting logic for restarting connection (#8569)
This commit is contained in:
parent
26c487b0c0
commit
6038621630
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
package com.microsoft.signalr;
|
||||
|
||||
import java.lang.reflect.Array;
|
||||
import java.io.StringReader;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.*;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
|
|
@ -11,6 +11,7 @@ import java.util.concurrent.atomic.AtomicLong;
|
|||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
import com.google.gson.stream.JsonReader;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
|
|
@ -38,6 +39,7 @@ public class HubConnection {
|
|||
private List<OnClosedCallback> onClosedCallbackList;
|
||||
private final boolean skipNegotiate;
|
||||
private Single<String> accessTokenProvider;
|
||||
private Single<String> redirectAccessTokenProvider;
|
||||
private final Map<String, String> headers = new HashMap<>();
|
||||
private ConnectionState connectionState = null;
|
||||
private HttpClient httpClient;
|
||||
|
|
@ -180,6 +182,7 @@ public class HubConnection {
|
|||
logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException());
|
||||
break;
|
||||
case INVOCATION:
|
||||
|
||||
InvocationMessage invocationMessage = (InvocationMessage) message;
|
||||
List<InvocationHandler> handlers = this.handlers.get(invocationMessage.getTarget());
|
||||
if (handlers != null) {
|
||||
|
|
@ -248,17 +251,18 @@ public class HubConnection {
|
|||
throw new RuntimeException(String.format("Unexpected status code returned from negotiate: %d %s.",
|
||||
response.getStatusCode(), response.getStatusText()));
|
||||
}
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(response.getContent());
|
||||
JsonReader reader = new JsonReader(new StringReader(response.getContent()));
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(reader);
|
||||
|
||||
if (negotiateResponse.getError() != null) {
|
||||
throw new RuntimeException(negotiateResponse.getError());
|
||||
}
|
||||
|
||||
if (negotiateResponse.getAccessToken() != null) {
|
||||
this.accessTokenProvider = Single.just(negotiateResponse.getAccessToken());
|
||||
this.redirectAccessTokenProvider = Single.just(negotiateResponse.getAccessToken());
|
||||
// We know the Single is non blocking in this case
|
||||
// It's fine to call blockingGet() on it.
|
||||
String token = this.accessTokenProvider.blockingGet();
|
||||
String token = this.redirectAccessTokenProvider.blockingGet();
|
||||
this.headers.put("Authorization", "Bearer " + token);
|
||||
}
|
||||
|
||||
|
|
@ -296,21 +300,22 @@ public class HubConnection {
|
|||
});
|
||||
|
||||
stopError = null;
|
||||
Single<String> negotiate = null;
|
||||
Single<NegotiateResponse> negotiate = null;
|
||||
if (!skipNegotiate) {
|
||||
negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(baseUrl, 0)));
|
||||
} else {
|
||||
negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(baseUrl)));
|
||||
negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(new NegotiateResponse(baseUrl))));
|
||||
}
|
||||
|
||||
CompletableSubject start = CompletableSubject.create();
|
||||
|
||||
negotiate.flatMapCompletable(url -> {
|
||||
negotiate.flatMapCompletable(negotiateResponse -> {
|
||||
logger.debug("Starting HubConnection.");
|
||||
if (transport == null) {
|
||||
Single<String> tokenProvider = negotiateResponse.getAccessToken() != null ? Single.just(negotiateResponse.getAccessToken()) : accessTokenProvider;
|
||||
switch (transportEnum) {
|
||||
case LONG_POLLING:
|
||||
transport = new LongPollingTransport(headers, httpClient, accessTokenProvider);
|
||||
transport = new LongPollingTransport(headers, httpClient, tokenProvider);
|
||||
break;
|
||||
default:
|
||||
transport = new WebSocketTransport(headers, httpClient);
|
||||
|
|
@ -320,7 +325,7 @@ public class HubConnection {
|
|||
transport.setOnReceive(this.callback);
|
||||
transport.setOnClose((message) -> stopConnection(message));
|
||||
|
||||
return transport.start(url).andThen(Completable.defer(() -> {
|
||||
return transport.start(negotiateResponse.getFinalUrl()).andThen(Completable.defer(() -> {
|
||||
String handshake = HandshakeProtocol.createHandshakeRequestMessage(
|
||||
new HandshakeRequestMessage(protocol.getName(), protocol.getVersion()));
|
||||
|
||||
|
|
@ -376,7 +381,7 @@ public class HubConnection {
|
|||
}, new Date(0), tickRate);
|
||||
}
|
||||
|
||||
private Single<String> startNegotiate(String url, int negotiateAttempts) {
|
||||
private Single<NegotiateResponse> startNegotiate(String url, int negotiateAttempts) {
|
||||
if (hubConnectionState != HubConnectionState.DISCONNECTED) {
|
||||
return Single.just(null);
|
||||
}
|
||||
|
|
@ -409,8 +414,8 @@ public class HubConnection {
|
|||
finalUrl = url + "?id=" + response.getConnectionId();
|
||||
}
|
||||
}
|
||||
|
||||
return Single.just(finalUrl);
|
||||
response.setFinalUrl(finalUrl);
|
||||
return Single.just(response);
|
||||
}
|
||||
|
||||
return startNegotiate(response.getRedirectUrl(), negotiateAttempts + 1);
|
||||
|
|
@ -473,6 +478,9 @@ public class HubConnection {
|
|||
logger.info("HubConnection stopped.");
|
||||
hubConnectionState = HubConnectionState.DISCONNECTED;
|
||||
handshakeResponseSubject.onComplete();
|
||||
redirectAccessTokenProvider = null;
|
||||
transportEnum = TransportEnum.ALL;
|
||||
this.headers.remove("Authorization");
|
||||
} finally {
|
||||
hubConnectionStateLock.unlock();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,10 +16,10 @@ class NegotiateResponse {
|
|||
private String redirectUrl;
|
||||
private String accessToken;
|
||||
private String error;
|
||||
private String finalUrl;
|
||||
|
||||
public NegotiateResponse(String negotiatePayload) {
|
||||
public NegotiateResponse(JsonReader reader) {
|
||||
try {
|
||||
JsonReader reader = new JsonReader(new StringReader(negotiatePayload));
|
||||
reader.beginObject();
|
||||
|
||||
do {
|
||||
|
|
@ -79,6 +79,10 @@ class NegotiateResponse {
|
|||
}
|
||||
}
|
||||
|
||||
public NegotiateResponse(String url) {
|
||||
this.finalUrl = url;
|
||||
}
|
||||
|
||||
public String getConnectionId() {
|
||||
return connectionId;
|
||||
}
|
||||
|
|
@ -98,4 +102,12 @@ class NegotiateResponse {
|
|||
public String getError() {
|
||||
return error;
|
||||
}
|
||||
|
||||
public String getFinalUrl(){
|
||||
return finalUrl;
|
||||
}
|
||||
|
||||
public void setFinalUrl(String url) {
|
||||
this.finalUrl = url;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1651,6 +1651,48 @@ class HubConnectionTest {
|
|||
hubConnection.stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
public void accessTokenProviderReferenceIsKeptAfterNegotiateRedirect() {
|
||||
AtomicReference<String> token = new AtomicReference<>();
|
||||
AtomicReference<String> beforeRedirectToken = new AtomicReference<>();
|
||||
|
||||
TestHttpClient client = new TestHttpClient()
|
||||
.on("POST", "http://example.com/negotiate", (req) -> {
|
||||
beforeRedirectToken.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}"));
|
||||
})
|
||||
.on("POST", "http://testexample.com/negotiate", (req) -> {
|
||||
token.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
|
||||
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
|
||||
});
|
||||
|
||||
MockTransport transport = new MockTransport(true);
|
||||
HubConnection hubConnection = HubConnectionBuilder
|
||||
.create("http://example.com")
|
||||
.withTransportImplementation(transport)
|
||||
.withHttpClient(client)
|
||||
.withAccessTokenProvider(Single.just("User Registered Token"))
|
||||
.build();
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals("Bearer User Registered Token", beforeRedirectToken.get());
|
||||
assertEquals("Bearer newToken", token.get());
|
||||
|
||||
// Clear the tokens to see if they get reset to the proper values
|
||||
beforeRedirectToken.set("");
|
||||
token.set("");
|
||||
|
||||
// Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect.
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop();
|
||||
assertEquals("Bearer User Registered Token", beforeRedirectToken.get());
|
||||
assertEquals("Bearer newToken", token.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void accessTokenProviderIsUsedForNegotiate() {
|
||||
AtomicReference<String> token = new AtomicReference<>();
|
||||
|
|
@ -1702,6 +1744,46 @@ class HubConnectionTest {
|
|||
assertEquals("Bearer newToken", token.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authorizationHeaderFromNegotiateGetsClearedAfterStopping() {
|
||||
AtomicReference<String> token = new AtomicReference<>();
|
||||
AtomicReference<String> beforeRedirectToken = new AtomicReference<>();
|
||||
|
||||
TestHttpClient client = new TestHttpClient()
|
||||
.on("POST", "http://example.com/negotiate", (req) -> {
|
||||
beforeRedirectToken.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}"));
|
||||
})
|
||||
.on("POST", "http://testexample.com/negotiate", (req) -> {
|
||||
token.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
|
||||
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
|
||||
});
|
||||
|
||||
MockTransport transport = new MockTransport(true);
|
||||
HubConnection hubConnection = HubConnectionBuilder
|
||||
.create("http://example.com")
|
||||
.withTransportImplementation(transport)
|
||||
.withHttpClient(client)
|
||||
.build();
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals("Bearer newToken", token.get());
|
||||
|
||||
// Clear the tokens to see if they get reset to the proper values
|
||||
beforeRedirectToken.set("");
|
||||
token.set("");
|
||||
|
||||
// Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect.
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop();
|
||||
assertNull(beforeRedirectToken.get());
|
||||
assertEquals("Bearer newToken", token.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void connectionTimesOutIfServerDoesNotSendMessage() {
|
||||
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");
|
||||
|
|
|
|||
|
|
@ -5,8 +5,11 @@ package com.microsoft.signalr;
|
|||
|
||||
import static org.junit.jupiter.api.Assertions.*;
|
||||
|
||||
import com.google.gson.stream.JsonReader;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.io.StringReader;
|
||||
|
||||
|
||||
class NegotiateResponseTest {
|
||||
@Test
|
||||
|
|
@ -15,7 +18,7 @@ class NegotiateResponseTest {
|
|||
"availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}," +
|
||||
"{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}," +
|
||||
"{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}";
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
|
||||
assertTrue(negotiateResponse.getAvailableTransports().contains("WebSockets"));
|
||||
assertTrue(negotiateResponse.getAvailableTransports().contains("ServerSentEvents"));
|
||||
assertTrue(negotiateResponse.getAvailableTransports().contains("LongPolling"));
|
||||
|
|
@ -29,7 +32,7 @@ class NegotiateResponseTest {
|
|||
String stringNegotiateResponse = "{\"url\":\"www.example.com\"," +
|
||||
"\"accessToken\":\"some_access_token\"," +
|
||||
"\"availableTransports\":[]}";
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
|
||||
assertTrue(negotiateResponse.getAvailableTransports().isEmpty());
|
||||
assertNull(negotiateResponse.getConnectionId());
|
||||
assertEquals("some_access_token", negotiateResponse.getAccessToken());
|
||||
|
|
@ -41,7 +44,7 @@ class NegotiateResponseTest {
|
|||
public void NegotiateResponseIgnoresExtraProperties() {
|
||||
String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," +
|
||||
"\"extra\":\"something\"}";
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
|
||||
assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId());
|
||||
}
|
||||
|
||||
|
|
@ -49,7 +52,7 @@ class NegotiateResponseTest {
|
|||
public void NegotiateResponseIgnoresExtraComplexProperties() {
|
||||
String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," +
|
||||
"\"extra\":[\"something\"]}";
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse);
|
||||
NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse)));
|
||||
assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue