diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index 0dc5183f61..d44cd464c5 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -41,7 +41,8 @@ public class HubConnection { private final boolean skipNegotiate; private Single accessTokenProvider; private Single redirectAccessTokenProvider; - private final Map headers = new HashMap<>(); + private final Map headers; + private final Map localHeaders = new HashMap<>(); private ConnectionState connectionState = null; private HttpClient httpClient; private String stopError; @@ -147,10 +148,7 @@ public class HubConnection { this.handshakeResponseTimeout = handshakeResponseTimeout; } - if (headers != null) { - this.headers.putAll(headers); - } - + this.headers = headers; this.skipNegotiate = skipNegotiate; this.callback = (payload) -> { @@ -255,7 +253,7 @@ public class HubConnection { private Single handleNegotiate(String url) { HttpRequest request = new HttpRequest(); - request.addHeaders(this.headers); + request.addHeaders(this.localHeaders); return httpClient.post(Negotiate.resolveNegotiateUrl(url), request).map((response) -> { if (response.getStatusCode() != 200) { @@ -274,7 +272,7 @@ public class HubConnection { // We know the Single is non blocking in this case // It's fine to call blockingGet() on it. String token = this.redirectAccessTokenProvider.blockingGet(); - this.headers.put("Authorization", "Bearer " + token); + this.localHeaders.put("Authorization", "Bearer " + token); } return negotiateResponse; @@ -303,9 +301,13 @@ public class HubConnection { handshakeResponseSubject = CompletableSubject.create(); handshakeReceived = false; CompletableSubject tokenCompletable = CompletableSubject.create(); + if (headers != null) { + this.localHeaders.putAll(headers); + } + accessTokenProvider.subscribe(token -> { if (token != null && !token.isEmpty()) { - this.headers.put("Authorization", "Bearer " + token); + this.localHeaders.put("Authorization", "Bearer " + token); } tokenCompletable.onComplete(); }); @@ -326,10 +328,10 @@ public class HubConnection { Single tokenProvider = negotiateResponse.getAccessToken() != null ? Single.just(negotiateResponse.getAccessToken()) : accessTokenProvider; switch (transportEnum) { case LONG_POLLING: - transport = new LongPollingTransport(headers, httpClient, tokenProvider); + transport = new LongPollingTransport(localHeaders, httpClient, tokenProvider); break; default: - transport = new WebSocketTransport(headers, httpClient); + transport = new WebSocketTransport(localHeaders, httpClient); } } @@ -493,7 +495,7 @@ public class HubConnection { redirectAccessTokenProvider = null; connectionId = null; transportEnum = TransportEnum.ALL; - this.headers.remove("Authorization"); + this.localHeaders.clear(); } finally { hubConnectionStateLock.unlock(); } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/WebSocketTransport.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/WebSocketTransport.java index bbdf97d95f..f15ffd4bce 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/WebSocketTransport.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/WebSocketTransport.java @@ -15,9 +15,8 @@ class WebSocketTransport implements Transport { private OnReceiveCallBack onReceiveCallBack; private TransportOnClosedCallback onClose; private String url; - private final HttpClient client; - private final Map headers; - + private HttpClient client; + private Map headers; private final Logger logger = LoggerFactory.getLogger(WebSocketTransport.class); private static final String HTTP = "http"; @@ -25,7 +24,6 @@ class WebSocketTransport implements Transport { private static final String WS = "ws"; private static final String WSS = "wss"; - public WebSocketTransport(Map headers, HttpClient client) { this.client = client; this.headers = headers; diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index 919dffe743..6181ca22b3 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -1727,7 +1727,7 @@ class HubConnectionTest { beforeRedirectToken.set(""); token.set(""); - // Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect. + // Restart the connection to make sure that the original accessTokenProvider that we registered is still registered before the redirect. hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); hubConnection.stop(); @@ -1818,7 +1818,7 @@ class HubConnectionTest { beforeRedirectToken.set(""); token.set(""); - // Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect. + // Restart the connection to make sure that the original accessTokenProvider that we registered is still registered before the redirect. hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); hubConnection.stop(); @@ -1826,6 +1826,52 @@ class HubConnectionTest { assertEquals("Bearer newToken", token.get()); } + @Test + public void authorizationHeaderFromNegotiateGetsSetToNewValue() { + AtomicReference token = new AtomicReference<>(); + AtomicReference redirectToken = new AtomicReference<>(); + AtomicInteger redirectCount = new AtomicInteger(); + + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate", (req) -> { + if(redirectCount.get() == 0){ + redirectCount.incrementAndGet(); + redirectToken.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"firstRedirectToken\"}")); + } else { + redirectToken.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"secondRedirectToken\"}")); + } + }) + .on("POST", "http://testexample.com/negotiate", (req) -> { + token.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + }); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals("Bearer firstRedirectToken", token.get()); + + // Clear the tokens to see if they get reset to the proper values + redirectToken.set(""); + token.set(""); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop(); + assertNull(redirectToken.get()); + assertEquals("Bearer secondRedirectToken", token.get()); + } + @Test public void connectionTimesOutIfServerDoesNotSendMessage() { HubConnection hubConnection = TestUtils.createHubConnection("http://example.com"); @@ -1885,6 +1931,79 @@ class HubConnectionTest { assertEquals("ExampleValue", header.get()); } + @Test + public void headersAreNotClearedWhenConnectionIsRestarted() { + AtomicReference header = new AtomicReference<>(); + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate", + (req) -> { + header.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + }); + + MockTransport transport = new MockTransport(); + HubConnection hubConnection = HubConnectionBuilder.create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .withHeader("Authorization", "ExampleValue") + .build(); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop(); + assertEquals("ExampleValue", header.get()); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("ExampleValue", header.get()); + } + + @Test + public void userSetAuthHeaderIsNotClearedAfterRedirect() { + AtomicReference beforeRedirectHeader = new AtomicReference<>(); + AtomicReference afterRedirectHeader = new AtomicReference<>(); + + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate", + (req) -> { + beforeRedirectHeader.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"redirectToken\"}\"}")); + }) + .on("POST", "http://testexample.com/negotiate", + (req) -> { + afterRedirectHeader.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); + }); + + MockTransport transport = new MockTransport(); + HubConnection hubConnection = HubConnectionBuilder.create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .withHeader("Authorization", "ExampleValue") + .build(); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop().blockingAwait(); + assertEquals("ExampleValue", beforeRedirectHeader.get()); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("Bearer redirectToken", afterRedirectHeader.get()); + + // Making sure you can do this after restarting the HubConnection. + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + hubConnection.stop().blockingAwait(); + assertEquals("ExampleValue", beforeRedirectHeader.get()); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("Bearer redirectToken", afterRedirectHeader.get()); + } + @Test public void sameHeaderSetTwiceGetsOverwritten() { AtomicReference header = new AtomicReference<>(); diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java index aff1e49598..2631084a34 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportUrlFormatTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.*; import java.util.HashMap; import java.util.stream.Stream; +import io.reactivex.Single; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; @@ -30,4 +31,4 @@ class WebSocketTransportUrlFormatTest { } catch (Exception e) {} assertEquals(expectedUrl, webSocketTransport.getUrl()); } -} \ No newline at end of file +}