Don't remove auth header when stopping HubConnection (#10635)
This commit is contained in:
parent
41ce223c1c
commit
f4e3c0a171
|
|
@ -41,7 +41,8 @@ public class HubConnection {
|
|||
private final boolean skipNegotiate;
|
||||
private Single<String> accessTokenProvider;
|
||||
private Single<String> redirectAccessTokenProvider;
|
||||
private final Map<String, String> headers = new HashMap<>();
|
||||
private final Map<String, String> headers;
|
||||
private final Map<String, String> localHeaders = new HashMap<>();
|
||||
private ConnectionState connectionState = null;
|
||||
private HttpClient httpClient;
|
||||
private String stopError;
|
||||
|
|
@ -147,10 +148,7 @@ public class HubConnection {
|
|||
this.handshakeResponseTimeout = handshakeResponseTimeout;
|
||||
}
|
||||
|
||||
if (headers != null) {
|
||||
this.headers.putAll(headers);
|
||||
}
|
||||
|
||||
this.headers = headers;
|
||||
this.skipNegotiate = skipNegotiate;
|
||||
|
||||
this.callback = (payload) -> {
|
||||
|
|
@ -255,7 +253,7 @@ public class HubConnection {
|
|||
|
||||
private Single<NegotiateResponse> handleNegotiate(String url) {
|
||||
HttpRequest request = new HttpRequest();
|
||||
request.addHeaders(this.headers);
|
||||
request.addHeaders(this.localHeaders);
|
||||
|
||||
return httpClient.post(Negotiate.resolveNegotiateUrl(url), request).map((response) -> {
|
||||
if (response.getStatusCode() != 200) {
|
||||
|
|
@ -274,7 +272,7 @@ public class HubConnection {
|
|||
// We know the Single is non blocking in this case
|
||||
// It's fine to call blockingGet() on it.
|
||||
String token = this.redirectAccessTokenProvider.blockingGet();
|
||||
this.headers.put("Authorization", "Bearer " + token);
|
||||
this.localHeaders.put("Authorization", "Bearer " + token);
|
||||
}
|
||||
|
||||
return negotiateResponse;
|
||||
|
|
@ -303,9 +301,13 @@ public class HubConnection {
|
|||
handshakeResponseSubject = CompletableSubject.create();
|
||||
handshakeReceived = false;
|
||||
CompletableSubject tokenCompletable = CompletableSubject.create();
|
||||
if (headers != null) {
|
||||
this.localHeaders.putAll(headers);
|
||||
}
|
||||
|
||||
accessTokenProvider.subscribe(token -> {
|
||||
if (token != null && !token.isEmpty()) {
|
||||
this.headers.put("Authorization", "Bearer " + token);
|
||||
this.localHeaders.put("Authorization", "Bearer " + token);
|
||||
}
|
||||
tokenCompletable.onComplete();
|
||||
});
|
||||
|
|
@ -326,10 +328,10 @@ public class HubConnection {
|
|||
Single<String> tokenProvider = negotiateResponse.getAccessToken() != null ? Single.just(negotiateResponse.getAccessToken()) : accessTokenProvider;
|
||||
switch (transportEnum) {
|
||||
case LONG_POLLING:
|
||||
transport = new LongPollingTransport(headers, httpClient, tokenProvider);
|
||||
transport = new LongPollingTransport(localHeaders, httpClient, tokenProvider);
|
||||
break;
|
||||
default:
|
||||
transport = new WebSocketTransport(headers, httpClient);
|
||||
transport = new WebSocketTransport(localHeaders, httpClient);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -493,7 +495,7 @@ public class HubConnection {
|
|||
redirectAccessTokenProvider = null;
|
||||
connectionId = null;
|
||||
transportEnum = TransportEnum.ALL;
|
||||
this.headers.remove("Authorization");
|
||||
this.localHeaders.clear();
|
||||
} finally {
|
||||
hubConnectionStateLock.unlock();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,9 +15,8 @@ class WebSocketTransport implements Transport {
|
|||
private OnReceiveCallBack onReceiveCallBack;
|
||||
private TransportOnClosedCallback onClose;
|
||||
private String url;
|
||||
private final HttpClient client;
|
||||
private final Map<String, String> headers;
|
||||
|
||||
private HttpClient client;
|
||||
private Map<String, String> headers;
|
||||
private final Logger logger = LoggerFactory.getLogger(WebSocketTransport.class);
|
||||
|
||||
private static final String HTTP = "http";
|
||||
|
|
@ -25,7 +24,6 @@ class WebSocketTransport implements Transport {
|
|||
private static final String WS = "ws";
|
||||
private static final String WSS = "wss";
|
||||
|
||||
|
||||
public WebSocketTransport(Map<String, String> headers, HttpClient client) {
|
||||
this.client = client;
|
||||
this.headers = headers;
|
||||
|
|
|
|||
|
|
@ -1727,7 +1727,7 @@ class HubConnectionTest {
|
|||
beforeRedirectToken.set("");
|
||||
token.set("");
|
||||
|
||||
// Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect.
|
||||
// Restart the connection to make sure that the original accessTokenProvider that we registered is still registered before the redirect.
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop();
|
||||
|
|
@ -1818,7 +1818,7 @@ class HubConnectionTest {
|
|||
beforeRedirectToken.set("");
|
||||
token.set("");
|
||||
|
||||
// Restart the connection to make sure that the orignal accessTokenProvider that we registered is still registered before the redirect.
|
||||
// Restart the connection to make sure that the original accessTokenProvider that we registered is still registered before the redirect.
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop();
|
||||
|
|
@ -1826,6 +1826,52 @@ class HubConnectionTest {
|
|||
assertEquals("Bearer newToken", token.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void authorizationHeaderFromNegotiateGetsSetToNewValue() {
|
||||
AtomicReference<String> token = new AtomicReference<>();
|
||||
AtomicReference<String> redirectToken = new AtomicReference<>();
|
||||
AtomicInteger redirectCount = new AtomicInteger();
|
||||
|
||||
TestHttpClient client = new TestHttpClient()
|
||||
.on("POST", "http://example.com/negotiate", (req) -> {
|
||||
if(redirectCount.get() == 0){
|
||||
redirectCount.incrementAndGet();
|
||||
redirectToken.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"firstRedirectToken\"}"));
|
||||
} else {
|
||||
redirectToken.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"secondRedirectToken\"}"));
|
||||
}
|
||||
})
|
||||
.on("POST", "http://testexample.com/negotiate", (req) -> {
|
||||
token.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
|
||||
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
|
||||
});
|
||||
|
||||
MockTransport transport = new MockTransport(true);
|
||||
HubConnection hubConnection = HubConnectionBuilder
|
||||
.create("http://example.com")
|
||||
.withTransportImplementation(transport)
|
||||
.withHttpClient(client)
|
||||
.build();
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals("Bearer firstRedirectToken", token.get());
|
||||
|
||||
// Clear the tokens to see if they get reset to the proper values
|
||||
redirectToken.set("");
|
||||
token.set("");
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop();
|
||||
assertNull(redirectToken.get());
|
||||
assertEquals("Bearer secondRedirectToken", token.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void connectionTimesOutIfServerDoesNotSendMessage() {
|
||||
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");
|
||||
|
|
@ -1885,6 +1931,79 @@ class HubConnectionTest {
|
|||
assertEquals("ExampleValue", header.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void headersAreNotClearedWhenConnectionIsRestarted() {
|
||||
AtomicReference<String> header = new AtomicReference<>();
|
||||
TestHttpClient client = new TestHttpClient()
|
||||
.on("POST", "http://example.com/negotiate",
|
||||
(req) -> {
|
||||
header.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
|
||||
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
|
||||
});
|
||||
|
||||
MockTransport transport = new MockTransport();
|
||||
HubConnection hubConnection = HubConnectionBuilder.create("http://example.com")
|
||||
.withTransportImplementation(transport)
|
||||
.withHttpClient(client)
|
||||
.withHeader("Authorization", "ExampleValue")
|
||||
.build();
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop();
|
||||
assertEquals("ExampleValue", header.get());
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
assertEquals("ExampleValue", header.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void userSetAuthHeaderIsNotClearedAfterRedirect() {
|
||||
AtomicReference<String> beforeRedirectHeader = new AtomicReference<>();
|
||||
AtomicReference<String> afterRedirectHeader = new AtomicReference<>();
|
||||
|
||||
TestHttpClient client = new TestHttpClient()
|
||||
.on("POST", "http://example.com/negotiate",
|
||||
(req) -> {
|
||||
beforeRedirectHeader.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\",\"accessToken\":\"redirectToken\"}\"}"));
|
||||
})
|
||||
.on("POST", "http://testexample.com/negotiate",
|
||||
(req) -> {
|
||||
afterRedirectHeader.set(req.getHeaders().get("Authorization"));
|
||||
return Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
|
||||
+ "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"));
|
||||
});
|
||||
|
||||
MockTransport transport = new MockTransport();
|
||||
HubConnection hubConnection = HubConnectionBuilder.create("http://example.com")
|
||||
.withTransportImplementation(transport)
|
||||
.withHttpClient(client)
|
||||
.withHeader("Authorization", "ExampleValue")
|
||||
.build();
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop().blockingAwait();
|
||||
assertEquals("ExampleValue", beforeRedirectHeader.get());
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
assertEquals("Bearer redirectToken", afterRedirectHeader.get());
|
||||
|
||||
// Making sure you can do this after restarting the HubConnection.
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
hubConnection.stop().blockingAwait();
|
||||
assertEquals("ExampleValue", beforeRedirectHeader.get());
|
||||
|
||||
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
|
||||
assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState());
|
||||
assertEquals("Bearer redirectToken", afterRedirectHeader.get());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sameHeaderSetTwiceGetsOverwritten() {
|
||||
AtomicReference<String> header = new AtomicReference<>();
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.*;
|
|||
import java.util.HashMap;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
import io.reactivex.Single;
|
||||
import org.junit.jupiter.params.ParameterizedTest;
|
||||
import org.junit.jupiter.params.provider.Arguments;
|
||||
import org.junit.jupiter.params.provider.MethodSource;
|
||||
|
|
@ -30,4 +31,4 @@ class WebSocketTransportUrlFormatTest {
|
|||
} catch (Exception e) {}
|
||||
assertEquals(expectedUrl, webSocketTransport.getUrl());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue