[Java] Observe accessTokenProvider on error (#24344)

This commit is contained in:
Brennan 2020-08-20 16:27:20 -07:00 committed by GitHub
parent 9b5999f340
commit 04a704c929
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 190 additions and 31 deletions

View File

@ -358,6 +358,8 @@ public class HubConnection implements AutoCloseable {
this.localHeaders.put("Authorization", "Bearer " + token); this.localHeaders.put("Authorization", "Bearer " + token);
} }
tokenCompletable.onComplete(); tokenCompletable.onComplete();
}, error -> {
tokenCompletable.onError(error);
}); });
stopError = null; stopError = null;

View File

@ -25,10 +25,10 @@ class LongPollingTransport implements Transport {
private final HttpClient pollingClient; private final HttpClient pollingClient;
private final Map<String, String> headers; private final Map<String, String> headers;
private static final int POLL_TIMEOUT = 100*1000; private static final int POLL_TIMEOUT = 100*1000;
private final Single<String> accessTokenProvider;
private volatile Boolean active = false; private volatile Boolean active = false;
private String pollUrl; private String pollUrl;
private String closeError; private String closeError;
private Single<String> accessTokenProvider;
private CompletableSubject receiveLoop = CompletableSubject.create(); private CompletableSubject receiveLoop = CompletableSubject.create();
private ExecutorService threadPool; private ExecutorService threadPool;
private ExecutorService onReceiveThread; private ExecutorService onReceiveThread;
@ -41,7 +41,6 @@ class LongPollingTransport implements Transport {
this.client = client; this.client = client;
this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT); this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT);
this.accessTokenProvider = accessTokenProvider; this.accessTokenProvider = accessTokenProvider;
this.onReceiveThread = Executors.newSingleThreadExecutor();
} }
//Package private active accessor for testing. //Package private active accessor for testing.
@ -49,13 +48,12 @@ class LongPollingTransport implements Transport {
return this.active; return this.active;
} }
private Single updateHeaderToken() { private Completable updateHeaderToken() {
return this.accessTokenProvider.flatMap((token) -> { return this.accessTokenProvider.doOnSuccess((token) -> {
if (!token.isEmpty()) { if (!token.isEmpty()) {
this.headers.put("Authorization", "Bearer " + token); this.headers.put("Authorization", "Bearer " + token);
} }
return Single.just(""); }).ignoreElement();
});
} }
@Override @Override
@ -65,7 +63,7 @@ class LongPollingTransport implements Transport {
this.url = url; this.url = url;
pollUrl = url + "&_=" + System.currentTimeMillis(); pollUrl = url + "&_=" + System.currentTimeMillis();
logger.debug("Polling {}.", pollUrl); logger.debug("Polling {}.", pollUrl);
return this.updateHeaderToken().flatMapCompletable((r) -> { return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest(); HttpRequest request = new HttpRequest();
request.addHeaders(headers); request.addHeaders(headers);
return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> { return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
@ -77,18 +75,26 @@ class LongPollingTransport implements Transport {
this.active = true; this.active = true;
} }
this.threadPool = Executors.newCachedThreadPool(); this.threadPool = Executors.newCachedThreadPool();
threadPool.execute(() -> poll(url).subscribeWith(receiveLoop)); threadPool.execute(() -> {
this.onReceiveThread = Executors.newSingleThreadExecutor();
receiveLoop.subscribe(() -> {
this.stop().onErrorComplete().subscribe();
}, e -> {
this.stop().onErrorComplete().subscribe();
});
poll(url).subscribeWith(receiveLoop);
});
return Completable.complete(); return Completable.complete();
}); });
}); }));
} }
private Completable poll(String url) { private Completable poll(String url) {
if (this.active) { if (this.active) {
pollUrl = url + "&_=" + System.currentTimeMillis(); pollUrl = url + "&_=" + System.currentTimeMillis();
logger.debug("Polling {}.", pollUrl); logger.debug("Polling {}.", pollUrl);
return this.updateHeaderToken().flatMapCompletable((x) -> { return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest(); HttpRequest request = new HttpRequest();
request.addHeaders(headers); request.addHeaders(headers);
Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> { Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
@ -111,13 +117,10 @@ class LongPollingTransport implements Transport {
}); });
return pollingCompletable; return pollingCompletable;
}); }));
} else { } else {
logger.debug("Long Polling transport polling complete."); logger.debug("Long Polling transport polling complete.");
receiveLoop.onComplete(); receiveLoop.onComplete();
if (!stopCalled.get()) {
return this.stop();
}
return Completable.complete(); return Completable.complete();
} }
} }
@ -127,11 +130,11 @@ class LongPollingTransport implements Transport {
if (!this.active) { if (!this.active) {
return Completable.error(new Exception("Cannot send unless the transport is active.")); return Completable.error(new Exception("Cannot send unless the transport is active."));
} }
return this.updateHeaderToken().flatMapCompletable((x) -> { return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest(); HttpRequest request = new HttpRequest();
request.addHeaders(headers); request.addHeaders(headers);
return Completable.fromSingle(this.client.post(url, message, request)); return this.client.post(url, message, request).ignoreElement();
}); }));
} }
@Override @Override
@ -152,23 +155,31 @@ class LongPollingTransport implements Transport {
@Override @Override
public Completable stop() { public Completable stop() {
if (!stopCalled.get()) { if (stopCalled.compareAndSet(false, true)) {
this.stopCalled.set(true);
this.active = false; this.active = false;
return this.updateHeaderToken().flatMapCompletable((x) -> { return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest(); HttpRequest request = new HttpRequest();
request.addHeaders(headers); request.addHeaders(headers);
this.pollingClient.delete(this.url, request); return this.pollingClient.delete(this.url, request).ignoreElement()
CompletableSubject stopCompletableSubject = CompletableSubject.create(); .andThen(receiveLoop)
return this.receiveLoop.andThen(Completable.defer(() -> { .doOnComplete(() -> {
logger.info("LongPolling transport stopped."); cleanup(this.closeError);
this.onReceiveThread.shutdown(); });
this.threadPool.shutdown(); })).doOnError(e -> {
this.onClose.invoke(this.closeError); cleanup(e.getMessage());
return Completable.complete();
})).subscribeWith(stopCompletableSubject);
}); });
} }
return Completable.complete(); return Completable.complete();
} }
private void cleanup(String error) {
logger.info("LongPolling transport stopped.");
if (this.onReceiveThread != null) {
this.onReceiveThread.shutdown();
}
if (this.threadPool != null) {
this.threadPool.shutdown();
}
this.onClose.invoke(error);
}
} }

View File

@ -2913,7 +2913,8 @@ class HubConnectionTest {
} }
assertTrue(close.blockingAwait(5, TimeUnit.SECONDS)); assertTrue(close.blockingAwait(5, TimeUnit.SECONDS));
return Single.just(new HttpResponse(204, "", TestUtils.emptyByteBuffer)); return Single.just(new HttpResponse(204, "", TestUtils.emptyByteBuffer));
}); })
.on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))));
HubConnection hubConnection = HubConnectionBuilder HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com") .create("http://example.com")
@ -2969,6 +2970,135 @@ class HubConnectionTest {
assertEquals(exception.getMessage(), "There were no compatible transports on the server."); assertEquals(exception.getMessage(), "There were no compatible transports on the server.");
} }
@Test
public void LongPollingTransportAccessTokenProviderThrowsOnInitialPoll() {
TestHttpClient client = new TestHttpClient()
.on("POST", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
})
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
(req) -> Single.just(new HttpResponse(200, "",
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
.on("GET", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
});
AtomicInteger accessTokenCount = new AtomicInteger(0);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(TransportEnum.LONG_POLLING)
.withHttpClient(client)
.withAccessTokenProvider(Single.defer(() -> {
if (accessTokenCount.getAndIncrement() < 1) {
return Single.just("");
}
return Single.error(new RuntimeException("Error from accessTokenProvider"));
}))
.build();
try {
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(false);
} catch (RuntimeException ex) {
assertEquals("Error from accessTokenProvider", ex.getMessage());
}
}
@Test
public void LongPollingTransportAccessTokenProviderThrowsAfterHandshakeClosesConnection() {
AtomicInteger requestCount = new AtomicInteger(0);
CompletableSubject blockGet = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
(req) -> Single.just(new HttpResponse(200, "",
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
.on("GET", (req) -> {
if (requestCount.getAndIncrement() > 1) {
blockGet.blockingAwait();
}
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
})
.on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
});
AtomicInteger accessTokenCount = new AtomicInteger(0);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(TransportEnum.LONG_POLLING)
.withHttpClient(client)
.withAccessTokenProvider(Single.defer(() -> {
if (accessTokenCount.getAndIncrement() < 5) {
return Single.just("");
}
return Single.error(new RuntimeException("Error from accessTokenProvider"));
}))
.build();
CompletableSubject closed = CompletableSubject.create();
hubConnection.onClosed((e) -> {
closed.onComplete();
});
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
blockGet.onComplete();
closed.timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
}
@Test
public void LongPollingTransportAccessTokenProviderThrowsDuringStop() {
AtomicInteger requestCount = new AtomicInteger(0);
CompletableSubject blockGet = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
(req) -> Single.just(new HttpResponse(200, "",
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
.on("GET", (req) -> {
if (requestCount.getAndIncrement() > 1) {
blockGet.blockingAwait();
}
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
})
.on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
});
AtomicInteger accessTokenCount = new AtomicInteger(0);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(TransportEnum.LONG_POLLING)
.withHttpClient(client)
.withAccessTokenProvider(Single.defer(() -> {
if (accessTokenCount.getAndIncrement() < 5) {
return Single.just("");
}
return Single.error(new RuntimeException("Error from accessTokenProvider"));
}))
.build();
CompletableSubject closed = CompletableSubject.create();
hubConnection.onClosed((e) -> {
closed.onComplete();
});
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
try {
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(false);
} catch (Exception ex) {
assertEquals("Error from accessTokenProvider", ex.getMessage());
}
blockGet.onComplete();
closed.timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
}
@Test @Test
public void receivingServerSentEventsTransportFromNegotiateFails() { public void receivingServerSentEventsTransportFromNegotiateFails() {
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1",
@ -3265,6 +3395,21 @@ class HubConnectionTest {
assertEquals("Bearer secondRedirectToken", token.get()); assertEquals("Bearer secondRedirectToken", token.get());
} }
@Test
public void ErrorInAccessTokenProviderThrowsFromStart() {
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withAccessTokenProvider(Single.defer(() -> Single.error(new RuntimeException("Error from accessTokenProvider"))))
.build();
try {
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(false);
} catch (RuntimeException ex) {
assertEquals("Error from accessTokenProvider", ex.getMessage());
}
}
@Test @Test
public void connectionTimesOutIfServerDoesNotSendMessage() { public void connectionTimesOutIfServerDoesNotSendMessage() {
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com"); HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");

View File

@ -86,7 +86,8 @@ public class LongPollingTransportTest {
return Single.just(new HttpResponse(200, "", TestUtils.emptyByteBuffer)); return Single.just(new HttpResponse(200, "", TestUtils.emptyByteBuffer));
} }
return Single.just(new HttpResponse(999, "", TestUtils.emptyByteBuffer)); return Single.just(new HttpResponse(999, "", TestUtils.emptyByteBuffer));
}); })
.on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))));
Map<String, String> headers = new HashMap<>(); Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));