diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index 54f92c9064..12349152dc 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -358,6 +358,8 @@ public class HubConnection implements AutoCloseable { this.localHeaders.put("Authorization", "Bearer " + token); } tokenCompletable.onComplete(); + }, error -> { + tokenCompletable.onError(error); }); stopError = null; diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java index d17c202148..00dc46f97b 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java @@ -25,10 +25,10 @@ class LongPollingTransport implements Transport { private final HttpClient pollingClient; private final Map headers; private static final int POLL_TIMEOUT = 100*1000; + private final Single accessTokenProvider; private volatile Boolean active = false; private String pollUrl; private String closeError; - private Single accessTokenProvider; private CompletableSubject receiveLoop = CompletableSubject.create(); private ExecutorService threadPool; private ExecutorService onReceiveThread; @@ -41,7 +41,6 @@ class LongPollingTransport implements Transport { this.client = client; this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT); this.accessTokenProvider = accessTokenProvider; - this.onReceiveThread = Executors.newSingleThreadExecutor(); } //Package private active accessor for testing. @@ -49,13 +48,12 @@ class LongPollingTransport implements Transport { return this.active; } - private Single updateHeaderToken() { - return this.accessTokenProvider.flatMap((token) -> { + private Completable updateHeaderToken() { + return this.accessTokenProvider.doOnSuccess((token) -> { if (!token.isEmpty()) { this.headers.put("Authorization", "Bearer " + token); } - return Single.just(""); - }); + }).ignoreElement(); } @Override @@ -65,7 +63,7 @@ class LongPollingTransport implements Transport { this.url = url; pollUrl = url + "&_=" + System.currentTimeMillis(); logger.debug("Polling {}.", pollUrl); - return this.updateHeaderToken().flatMapCompletable((r) -> { + return this.updateHeaderToken().andThen(Completable.defer(() -> { HttpRequest request = new HttpRequest(); request.addHeaders(headers); return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> { @@ -77,18 +75,26 @@ class LongPollingTransport implements Transport { this.active = true; } this.threadPool = Executors.newCachedThreadPool(); - threadPool.execute(() -> poll(url).subscribeWith(receiveLoop)); + threadPool.execute(() -> { + this.onReceiveThread = Executors.newSingleThreadExecutor(); + receiveLoop.subscribe(() -> { + this.stop().onErrorComplete().subscribe(); + }, e -> { + this.stop().onErrorComplete().subscribe(); + }); + poll(url).subscribeWith(receiveLoop); + }); return Completable.complete(); }); - }); + })); } private Completable poll(String url) { if (this.active) { pollUrl = url + "&_=" + System.currentTimeMillis(); logger.debug("Polling {}.", pollUrl); - return this.updateHeaderToken().flatMapCompletable((x) -> { + return this.updateHeaderToken().andThen(Completable.defer(() -> { HttpRequest request = new HttpRequest(); request.addHeaders(headers); Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> { @@ -111,13 +117,10 @@ class LongPollingTransport implements Transport { }); return pollingCompletable; - }); + })); } else { logger.debug("Long Polling transport polling complete."); receiveLoop.onComplete(); - if (!stopCalled.get()) { - return this.stop(); - } return Completable.complete(); } } @@ -127,11 +130,11 @@ class LongPollingTransport implements Transport { if (!this.active) { return Completable.error(new Exception("Cannot send unless the transport is active.")); } - return this.updateHeaderToken().flatMapCompletable((x) -> { + return this.updateHeaderToken().andThen(Completable.defer(() -> { HttpRequest request = new HttpRequest(); request.addHeaders(headers); - return Completable.fromSingle(this.client.post(url, message, request)); - }); + return this.client.post(url, message, request).ignoreElement(); + })); } @Override @@ -152,23 +155,31 @@ class LongPollingTransport implements Transport { @Override public Completable stop() { - if (!stopCalled.get()) { - this.stopCalled.set(true); + if (stopCalled.compareAndSet(false, true)) { this.active = false; - return this.updateHeaderToken().flatMapCompletable((x) -> { + return this.updateHeaderToken().andThen(Completable.defer(() -> { HttpRequest request = new HttpRequest(); request.addHeaders(headers); - this.pollingClient.delete(this.url, request); - CompletableSubject stopCompletableSubject = CompletableSubject.create(); - return this.receiveLoop.andThen(Completable.defer(() -> { - logger.info("LongPolling transport stopped."); - this.onReceiveThread.shutdown(); - this.threadPool.shutdown(); - this.onClose.invoke(this.closeError); - return Completable.complete(); - })).subscribeWith(stopCompletableSubject); + return this.pollingClient.delete(this.url, request).ignoreElement() + .andThen(receiveLoop) + .doOnComplete(() -> { + cleanup(this.closeError); + }); + })).doOnError(e -> { + cleanup(e.getMessage()); }); } return Completable.complete(); } + + private void cleanup(String error) { + logger.info("LongPolling transport stopped."); + if (this.onReceiveThread != null) { + this.onReceiveThread.shutdown(); + } + if (this.threadPool != null) { + this.threadPool.shutdown(); + } + this.onClose.invoke(error); + } } diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index 8bc833239e..efee55e3ef 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -2913,7 +2913,8 @@ class HubConnectionTest { } assertTrue(close.blockingAwait(5, TimeUnit.SECONDS)); return Single.just(new HttpResponse(204, "", TestUtils.emptyByteBuffer)); - }); + }) + .on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")))); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") @@ -2969,6 +2970,135 @@ class HubConnectionTest { assertEquals(exception.getMessage(), "There were no compatible transports on the server."); } + @Test + public void LongPollingTransportAccessTokenProviderThrowsOnInitialPoll() { + TestHttpClient client = new TestHttpClient() + .on("POST", (req) -> { + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))); + }) + .on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", + TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")))) + .on("GET", (req) -> { + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR))); + }); + + AtomicInteger accessTokenCount = new AtomicInteger(0); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransport(TransportEnum.LONG_POLLING) + .withHttpClient(client) + .withAccessTokenProvider(Single.defer(() -> { + if (accessTokenCount.getAndIncrement() < 1) { + return Single.just(""); + } + return Single.error(new RuntimeException("Error from accessTokenProvider")); + })) + .build(); + + try { + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(false); + } catch (RuntimeException ex) { + assertEquals("Error from accessTokenProvider", ex.getMessage()); + } + } + + @Test + public void LongPollingTransportAccessTokenProviderThrowsAfterHandshakeClosesConnection() { + AtomicInteger requestCount = new AtomicInteger(0); + CompletableSubject blockGet = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", + TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")))) + .on("GET", (req) -> { + if (requestCount.getAndIncrement() > 1) { + blockGet.blockingAwait(); + } + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR))); + }) + .on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> { + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))); + }); + + AtomicInteger accessTokenCount = new AtomicInteger(0); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransport(TransportEnum.LONG_POLLING) + .withHttpClient(client) + .withAccessTokenProvider(Single.defer(() -> { + if (accessTokenCount.getAndIncrement() < 5) { + return Single.just(""); + } + return Single.error(new RuntimeException("Error from accessTokenProvider")); + })) + .build(); + + CompletableSubject closed = CompletableSubject.create(); + hubConnection.onClosed((e) -> { + closed.onComplete(); + }); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + blockGet.onComplete(); + + closed.timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + } + + @Test + public void LongPollingTransportAccessTokenProviderThrowsDuringStop() { + AtomicInteger requestCount = new AtomicInteger(0); + CompletableSubject blockGet = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("POST", "http://example.com/negotiate?negotiateVersion=1", + (req) -> Single.just(new HttpResponse(200, "", + TestUtils.stringToByteBuffer("{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")))) + .on("GET", (req) -> { + if (requestCount.getAndIncrement() > 1) { + blockGet.blockingAwait(); + } + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("{}" + RECORD_SEPARATOR))); + }) + .on("POST", "http://example.com?id=bVOiRPG8-6YiJ6d7ZcTOVQ", (req) -> { + return Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer(""))); + }); + + AtomicInteger accessTokenCount = new AtomicInteger(0); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransport(TransportEnum.LONG_POLLING) + .withHttpClient(client) + .withAccessTokenProvider(Single.defer(() -> { + if (accessTokenCount.getAndIncrement() < 5) { + return Single.just(""); + } + return Single.error(new RuntimeException("Error from accessTokenProvider")); + })) + .build(); + + CompletableSubject closed = CompletableSubject.create(); + hubConnection.onClosed((e) -> { + closed.onComplete(); + }); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + + try { + hubConnection.stop().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(false); + } catch (Exception ex) { + assertEquals("Error from accessTokenProvider", ex.getMessage()); + } + blockGet.onComplete(); + closed.timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + } + @Test public void receivingServerSentEventsTransportFromNegotiateFails() { TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate?negotiateVersion=1", @@ -3265,6 +3395,21 @@ class HubConnectionTest { assertEquals("Bearer secondRedirectToken", token.get()); } + @Test + public void ErrorInAccessTokenProviderThrowsFromStart() { + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withAccessTokenProvider(Single.defer(() -> Single.error(new RuntimeException("Error from accessTokenProvider")))) + .build(); + + try { + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(false); + } catch (RuntimeException ex) { + assertEquals("Error from accessTokenProvider", ex.getMessage()); + } + } + @Test public void connectionTimesOutIfServerDoesNotSendMessage() { HubConnection hubConnection = TestUtils.createHubConnection("http://example.com"); diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java index ae9cb17759..0e65390d39 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java @@ -86,7 +86,8 @@ public class LongPollingTransportTest { return Single.just(new HttpResponse(200, "", TestUtils.emptyByteBuffer)); } return Single.just(new HttpResponse(999, "", TestUtils.emptyByteBuffer)); - }); + }) + .on("DELETE", (req) -> Single.just(new HttpResponse(200, "", TestUtils.stringToByteBuffer("")))); Map headers = new HashMap<>(); LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));