[Java] Observe accessTokenProvider on error (#24344)

This commit is contained in:
Brennan 2020-08-20 16:27:20 -07:00 committed by GitHub
parent 9b5999f340
commit 04a704c929
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 190 additions and 31 deletions

View File

@ -358,6 +358,8 @@ public class HubConnection implements AutoCloseable {
this.localHeaders.put("Authorization", "Bearer " + token);
}
tokenCompletable.onComplete();
}, error -> {
tokenCompletable.onError(error);
});
stopError = null;

View File

@ -25,10 +25,10 @@ class LongPollingTransport implements Transport {
private final HttpClient pollingClient;
private final Map<String, String> headers;
private static final int POLL_TIMEOUT = 100*1000;
private final Single<String> accessTokenProvider;
private volatile Boolean active = false;
private String pollUrl;
private String closeError;
private Single<String> accessTokenProvider;
private CompletableSubject receiveLoop = CompletableSubject.create();
private ExecutorService threadPool;
private ExecutorService onReceiveThread;
@ -41,7 +41,6 @@ class LongPollingTransport implements Transport {
this.client = client;
this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT);
this.accessTokenProvider = accessTokenProvider;
this.onReceiveThread = Executors.newSingleThreadExecutor();
}
//Package private active accessor for testing.
@ -49,13 +48,12 @@ class LongPollingTransport implements Transport {
return this.active;
}
private Single updateHeaderToken() {
return this.accessTokenProvider.flatMap((token) -> {
private Completable updateHeaderToken() {
return this.accessTokenProvider.doOnSuccess((token) -> {
if (!token.isEmpty()) {
this.headers.put("Authorization", "Bearer " + token);
}
return Single.just("");
});
}).ignoreElement();
}
@Override
@ -65,7 +63,7 @@ class LongPollingTransport implements Transport {
this.url = url;
pollUrl = url + "&_=" + System.currentTimeMillis();
logger.debug("Polling {}.", pollUrl);
return this.updateHeaderToken().flatMapCompletable((r) -> {
return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
@ -77,18 +75,26 @@ class LongPollingTransport implements Transport {
this.active = true;
}
this.threadPool = Executors.newCachedThreadPool();
threadPool.execute(() -> poll(url).subscribeWith(receiveLoop));
threadPool.execute(() -> {
this.onReceiveThread = Executors.newSingleThreadExecutor();
receiveLoop.subscribe(() -> {
this.stop().onErrorComplete().subscribe();
}, e -> {
this.stop().onErrorComplete().subscribe();
});
poll(url).subscribeWith(receiveLoop);
});
return Completable.complete();
});
});
}));
}
private Completable poll(String url) {
if (this.active) {
pollUrl = url + "&_=" + System.currentTimeMillis();
logger.debug("Polling {}.", pollUrl);
return this.updateHeaderToken().flatMapCompletable((x) -> {
return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
@ -111,13 +117,10 @@ class LongPollingTransport implements Transport {
});
return pollingCompletable;
});
}));
} else {
logger.debug("Long Polling transport polling complete.");
receiveLoop.onComplete();
if (!stopCalled.get()) {
return this.stop();
}
return Completable.complete();
}
}
@ -127,11 +130,11 @@ class LongPollingTransport implements Transport {
if (!this.active) {
return Completable.error(new Exception("Cannot send unless the transport is active."));
}
return this.updateHeaderToken().flatMapCompletable((x) -> {
return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
return Completable.fromSingle(this.client.post(url, message, request));
});
return this.client.post(url, message, request).ignoreElement();
}));
}
@Override
@ -152,23 +155,31 @@ class LongPollingTransport implements Transport {
@Override
public Completable stop() {
if (!stopCalled.get()) {
this.stopCalled.set(true);
if (stopCalled.compareAndSet(false, true)) {
this.active = false;
return this.updateHeaderToken().flatMapCompletable((x) -> {
return this.updateHeaderToken().andThen(Completable.defer(() -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
this.pollingClient.delete(this.url, request);
CompletableSubject stopCompletableSubject = CompletableSubject.create();
return this.receiveLoop.andThen(Completable.defer(() -> {
logger.info("LongPolling transport stopped.");
this.onReceiveThread.shutdown();
this.threadPool.shutdown();
this.onClose.invoke(this.closeError);
return Completable.complete();
})).subscribeWith(stopCompletableSubject);
return this.pollingClient.delete(this.url, request).ignoreElement()
.andThen(receiveLoop)
.doOnComplete(() -> {
cleanup(this.closeError);
});
})).doOnError(e -> {
cleanup(e.getMessage());
});
}
return Completable.complete();
}
private void cleanup(String error) {
logger.info("LongPolling transport stopped.");
if (this.onReceiveThread != null) {
this.onReceiveThread.shutdown();
}
if (this.threadPool != null) {
this.threadPool.shutdown();
}
this.onClose.invoke(error);
}
}

View File

@ -2913,7 +2913,8 @@ class HubConnectionTest {
}
assertTrue(close.blockingAwait(5, TimeUnit.SECONDS));
return Single.just(new HttpResponse(204, "", TestUtils.emptyByteBuffer));
});
})
.on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))));
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
@ -2969,6 +2970,135 @@ class HubConnectionTest {
assertEquals(exception.getMessage(), "There were no compatible transports on the server.");
}
@Test
public void LongPollingTransportAccessTokenProviderThrowsOnInitialPoll() {
TestHttpClient client = new TestHttpClient()
.on("POST", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
})
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
(req) -> Single.just(new HttpResponse(200, "",
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
.on("GET", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
});
AtomicInteger accessTokenCount = new AtomicInteger(0);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(TransportEnum.LONG_POLLING)
.withHttpClient(client)
.withAccessTokenProvider(Single.defer(() -> {
if (accessTokenCount.getAndIncrement() < 1) {
return Single.just("");
}
return Single.error(new RuntimeException("Error from accessTokenProvider"));
}))
.build();
try {
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(false);
} catch (RuntimeException ex) {
assertEquals("Error from accessTokenProvider", ex.getMessage());
}
}
@Test
public void LongPollingTransportAccessTokenProviderThrowsAfterHandshakeClosesConnection() {
AtomicInteger requestCount = new AtomicInteger(0);
CompletableSubject blockGet = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
(req) -> Single.just(new HttpResponse(200, "",
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
.on("GET", (req) -> {
if (requestCount.getAndIncrement() > 1) {
blockGet.blockingAwait();
}
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
})
.on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
});
AtomicInteger accessTokenCount = new AtomicInteger(0);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(TransportEnum.LONG_POLLING)
.withHttpClient(client)
.withAccessTokenProvider(Single.defer(() -> {
if (accessTokenCount.getAndIncrement() < 5) {
return Single.just("");
}
return Single.error(new RuntimeException("Error from accessTokenProvider"));
}))
.build();
CompletableSubject closed = CompletableSubject.create();
hubConnection.onClosed((e) -> {
closed.onComplete();
});
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
blockGet.onComplete();
closed.timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
}
@Test
public void LongPollingTransportAccessTokenProviderThrowsDuringStop() {
AtomicInteger requestCount = new AtomicInteger(0);
CompletableSubject blockGet = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("POST", "http://example.com/negotiate?negotiateVersion=1",
(req) -> Single.just(new HttpResponse(200, "",
TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))))
.on("GET", (req) -> {
if (requestCount.getAndIncrement() > 1) {
blockGet.blockingAwait();
}
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR)));
})
.on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> {
return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")));
});
AtomicInteger accessTokenCount = new AtomicInteger(0);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(TransportEnum.LONG_POLLING)
.withHttpClient(client)
.withAccessTokenProvider(Single.defer(() -> {
if (accessTokenCount.getAndIncrement() < 5) {
return Single.just("");
}
return Single.error(new RuntimeException("Error from accessTokenProvider"));
}))
.build();
CompletableSubject closed = CompletableSubject.create();
hubConnection.onClosed((e) -> {
closed.onComplete();
});
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
try {
hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(false);
} catch (Exception ex) {
assertEquals("Error from accessTokenProvider", ex.getMessage());
}
blockGet.onComplete();
closed.timeout(1, TimeUnit.SECONDS).blockingAwait();
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
}
@Test
public void receivingServerSentEventsTransportFromNegotiateFails() {
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1",
@ -3265,6 +3395,21 @@ class HubConnectionTest {
assertEquals("Bearer secondRedirectToken", token.get());
}
@Test
public void ErrorInAccessTokenProviderThrowsFromStart() {
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withAccessTokenProvider(Single.defer(() -> Single.error(new RuntimeException("Error from accessTokenProvider"))))
.build();
try {
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(false);
} catch (RuntimeException ex) {
assertEquals("Error from accessTokenProvider", ex.getMessage());
}
}
@Test
public void connectionTimesOutIfServerDoesNotSendMessage() {
HubConnection hubConnection = TestUtils.createHubConnection("http://example.com");

View File

@ -86,7 +86,8 @@ public class LongPollingTransportTest {
return Single.just(new HttpResponse(200, "", TestUtils.emptyByteBuffer));
}
return Single.just(new HttpResponse(999, "", TestUtils.emptyByteBuffer));
});
})
.on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))));
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));