diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index faf8ea2b27..3449b27620 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -3,7 +3,7 @@ package com.microsoft.signalr; -import java.lang.reflect.Array; +import java.io.StringReader; import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; @@ -11,6 +11,7 @@ import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import com.google.gson.stream.JsonReader; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -38,6 +39,7 @@ public class HubConnection { private List onClosedCallbackList; private final boolean skipNegotiate; private Single accessTokenProvider; + private Single redirectAccessTokenProvider; private final Map headers = new HashMap<>(); private ConnectionState connectionState = null; private HttpClient httpClient; @@ -180,6 +182,7 @@ public class HubConnection { logger.error("Failed to bind arguments received in invocation '{}' of '{}'.", msg.getInvocationId(), msg.getTarget(), msg.getException()); break; case INVOCATION: + InvocationMessage invocationMessage = (InvocationMessage) message; List handlers = this.handlers.get(invocationMessage.getTarget()); if (handlers != null) { @@ -248,17 +251,18 @@ public class HubConnection { throw new RuntimeException(String.format("Unexpected status code returned from negotiate: %d %s.", response.getStatusCode(), response.getStatusText())); } - NegotiateResponse negotiateResponse = new NegotiateResponse(response.getContent()); + JsonReader reader = new JsonReader(new StringReader(response.getContent())); + NegotiateResponse negotiateResponse = new NegotiateResponse(reader); if (negotiateResponse.getError() != null) { throw new RuntimeException(negotiateResponse.getError()); } if (negotiateResponse.getAccessToken() != null) { - this.accessTokenProvider = Single.just(negotiateResponse.getAccessToken()); + this.redirectAccessTokenProvider = Single.just(negotiateResponse.getAccessToken()); // We know the Single is non blocking in this case // It's fine to call blockingGet() on it. - String token = this.accessTokenProvider.blockingGet(); + String token = this.redirectAccessTokenProvider.blockingGet(); this.headers.put("Authorization", "Bearer " + token); } @@ -296,21 +300,22 @@ public class HubConnection { }); stopError = null; - Single negotiate = null; + Single negotiate = null; if (!skipNegotiate) { negotiate = tokenCompletable.andThen(Single.defer(() -> startNegotiate(baseUrl, 0))); } else { - negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(baseUrl))); + negotiate = tokenCompletable.andThen(Single.defer(() -> Single.just(new NegotiateResponse(baseUrl)))); } CompletableSubject start = CompletableSubject.create(); - negotiate.flatMapCompletable(url -> { + negotiate.flatMapCompletable(negotiateResponse -> { logger.debug("Starting HubConnection."); if (transport == null) { + Single tokenProvider = negotiateResponse.getAccessToken() != null ? Single.just(negotiateResponse.getAccessToken()) : accessTokenProvider; switch (transportEnum) { case LONG_POLLING: - transport = new LongPollingTransport(headers, httpClient, accessTokenProvider); + transport = new LongPollingTransport(headers, httpClient, tokenProvider); break; default: transport = new WebSocketTransport(headers, httpClient); @@ -320,7 +325,7 @@ public class HubConnection { transport.setOnReceive(this.callback); transport.setOnClose((message) -> stopConnection(message)); - return transport.start(url).andThen(Completable.defer(() -> { + return transport.start(negotiateResponse.getFinalUrl()).andThen(Completable.defer(() -> { String handshake = HandshakeProtocol.createHandshakeRequestMessage( new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); @@ -376,7 +381,7 @@ public class HubConnection { }, new Date(0), tickRate); } - private Single startNegotiate(String url, int negotiateAttempts) { + private Single startNegotiate(String url, int negotiateAttempts) { if (hubConnectionState != HubConnectionState.DISCONNECTED) { return Single.just(null); } @@ -409,8 +414,8 @@ public class HubConnection { finalUrl = url + "?id=" + response.getConnectionId(); } } - - return Single.just(finalUrl); + response.setFinalUrl(finalUrl); + return Single.just(response); } return startNegotiate(response.getRedirectUrl(), negotiateAttempts + 1); @@ -473,6 +478,9 @@ public class HubConnection { logger.info("HubConnection stopped."); hubConnectionState = HubConnectionState.DISCONNECTED; handshakeResponseSubject.onComplete(); + redirectAccessTokenProvider = null; + transportEnum = TransportEnum.ALL; + this.headers.remove("Authorization"); } finally { hubConnectionStateLock.unlock(); } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java index 282fe093fc..367fb0d4e6 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/NegotiateResponse.java @@ -16,10 +16,10 @@ class NegotiateResponse { private String redirectUrl; private String accessToken; private String error; + private String finalUrl; - public NegotiateResponse(String negotiatePayload) { + public NegotiateResponse(JsonReader reader) { try { - JsonReader reader = new JsonReader(new StringReader(negotiatePayload)); reader.beginObject(); do { @@ -79,6 +79,10 @@ class NegotiateResponse { } } + public NegotiateResponse(String url) { + this.finalUrl = url; + } + public String getConnectionId() { return connectionId; } @@ -98,4 +102,12 @@ class NegotiateResponse { public String getError() { return error; } + + public String getFinalUrl(){ + return finalUrl; + } + + public void setFinalUrl(String url) { + this.finalUrl = url; + } } diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index 06459feaf0..f68f1f1f48 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -1651,6 +1651,48 @@ class HubConnectionTest { hubConnection.stop(); } + @Test + public void accessTokenProviderReferenceIsKeptAfterNegotiateRedirect() { + AtomicReference token = new AtomicReference<>(); + AtomicReference beforeRedirectToken = new AtomicReference<>(); + + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate", (req) -> { + beforeRedirectToken.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}")); + }) + .on("POST", "http://testexample.com/negotiate", (req) -> { + token.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + }); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .withAccessTokenProvider(Single.just("User Registered Token")) + .build(); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals("Bearer User Registered Token", beforeRedirectToken.get()); + assertEquals("Bearer newToken", token.get()); + + // Clear the tokens to see if they get reset to the proper values + beforeRedirectToken.set(""); + token.set(""); + + // Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect. + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop(); + assertEquals("Bearer User Registered Token", beforeRedirectToken.get()); + assertEquals("Bearer newToken", token.get()); + } + @Test public void accessTokenProviderIsUsedForNegotiate() { AtomicReference token = new AtomicReference<>(); @@ -1702,6 +1744,46 @@ class HubConnectionTest { assertEquals("Bearer newToken", token.get()); } + @Test + public void authorizationHeaderFromNegotiateGetsClearedAfterStopping() { + AtomicReference token = new AtomicReference<>(); + AtomicReference beforeRedirectToken = new AtomicReference<>(); + + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate", (req) -> { + beforeRedirectToken.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"newToken\"}")); + }) + .on("POST", "http://testexample.com/negotiate", (req) -> { + token.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + }); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals("Bearer newToken", token.get()); + + // Clear the tokens to see if they get reset to the proper values + beforeRedirectToken.set(""); + token.set(""); + + // Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect. + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop(); + assertNull(beforeRedirectToken.get()); + assertEquals("Bearer newToken", token.get()); + } + @Test public void connectionTimesOutIfServerDoesNotSendMessage() { HubConnection hubConnection = TestUtils.createHubConnection("http://example.com"); diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java index 9ba6fc556e..a366e4ab8b 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/NegotiateResponseTest.java @@ -5,8 +5,11 @@ package com.microsoft.signalr; import static org.junit.jupiter.api.Assertions.*; +import com.google.gson.stream.JsonReader; import org.junit.jupiter.api.Test; +import java.io.StringReader; + class NegotiateResponseTest { @Test @@ -15,7 +18,7 @@ class NegotiateResponseTest { "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}," + "{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}," + "{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"; - NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse); + NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); assertTrue(negotiateResponse.getAvailableTransports().contains("WebSockets")); assertTrue(negotiateResponse.getAvailableTransports().contains("ServerSentEvents")); assertTrue(negotiateResponse.getAvailableTransports().contains("LongPolling")); @@ -29,7 +32,7 @@ class NegotiateResponseTest { String stringNegotiateResponse = "{\"url\":\"www.example.com\"," + "\"accessToken\":\"some_access_token\"," + "\"availableTransports\":[]}"; - NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse); + NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); assertTrue(negotiateResponse.getAvailableTransports().isEmpty()); assertNull(negotiateResponse.getConnectionId()); assertEquals("some_access_token", negotiateResponse.getAccessToken()); @@ -41,7 +44,7 @@ class NegotiateResponseTest { public void NegotiateResponseIgnoresExtraProperties() { String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + "\"extra\":\"something\"}"; - NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse); + NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId()); } @@ -49,7 +52,7 @@ class NegotiateResponseTest { public void NegotiateResponseIgnoresExtraComplexProperties() { String stringNegotiateResponse = "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\"," + "\"extra\":[\"something\"]}"; - NegotiateResponse negotiateResponse = new NegotiateResponse(stringNegotiateResponse); + NegotiateResponse negotiateResponse = new NegotiateResponse(new JsonReader(new StringReader(stringNegotiateResponse))); assertEquals("bVOiRPG8-6YiJ6d7ZcTOVQ", negotiateResponse.getConnectionId()); } }