From 3d3ad96206fc504dba954c0c6403139e3438263e Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Wed, 13 Feb 2019 10:27:07 -0800 Subject: [PATCH] SignalR Java Client LongPolling Transport (#6856) --- .../microsoft/signalr/DefaultHttpClient.java | 125 ++++--- .../com/microsoft/signalr/HttpClient.java | 12 +- .../signalr/HttpHubConnectionBuilder.java | 17 +- .../com/microsoft/signalr/HubConnection.java | 84 +++-- .../microsoft/signalr/JsonHubProtocol.java | 5 +- .../signalr/LongPollingTransport.java | 168 +++++++++ .../com/microsoft/signalr/TransportEnum.java | 10 + .../microsoft/signalr/HubConnectionTest.java | 66 +++- .../signalr/LongPollingTransportTest.java | 330 ++++++++++++++++++ .../com/microsoft/signalr/TestHttpClient.java | 12 +- .../java/com/microsoft/signalr/TestUtils.java | 4 +- .../signalr/WebSocketTransportTest.java | 10 + 12 files changed, 733 insertions(+), 110 deletions(-) create mode 100644 src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java create mode 100644 src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/TransportEnum.java create mode 100644 src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/DefaultHttpClient.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/DefaultHttpClient.java index d8eee27c87..a264f5d09f 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/DefaultHttpClient.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/DefaultHttpClient.java @@ -8,79 +8,96 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import io.reactivex.Single; import io.reactivex.subjects.SingleSubject; -import okhttp3.Call; -import okhttp3.Callback; -import okhttp3.Cookie; -import okhttp3.CookieJar; -import okhttp3.HttpUrl; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.RequestBody; -import okhttp3.Response; -import okhttp3.ResponseBody; +import okhttp3.*; final class DefaultHttpClient extends HttpClient { - private final OkHttpClient client; + private OkHttpClient client = null; public DefaultHttpClient() { - this.client = new OkHttpClient.Builder().cookieJar(new CookieJar() { - private List cookieList = new ArrayList<>(); - private Lock cookieLock = new ReentrantLock(); + this(0, null); + } - @Override - public void saveFromResponse(HttpUrl url, List cookies) { - cookieLock.lock(); - try { - for (Cookie cookie : cookies) { - boolean replacedCookie = false; - for (int i = 0; i < cookieList.size(); i++) { - Cookie innerCookie = cookieList.get(i); - if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) { - // We have a new cookie that matches an older one so we replace the older one. - cookieList.set(i, innerCookie); - replacedCookie = true; - break; + public DefaultHttpClient cloneWithTimeOut(int timeoutInMilliseconds) { + OkHttpClient newClient = client.newBuilder().readTimeout(timeoutInMilliseconds, TimeUnit.MILLISECONDS) + .build(); + return new DefaultHttpClient(timeoutInMilliseconds, newClient); + } + + public DefaultHttpClient(int timeoutInMilliseconds, OkHttpClient client) { + if (client != null) { + this.client = client; + } else { + + OkHttpClient.Builder builder = new OkHttpClient.Builder().cookieJar(new CookieJar() { + private List cookieList = new ArrayList<>(); + private Lock cookieLock = new ReentrantLock(); + + @Override + public void saveFromResponse(HttpUrl url, List cookies) { + cookieLock.lock(); + try { + for (Cookie cookie : cookies) { + boolean replacedCookie = false; + for (int i = 0; i < cookieList.size(); i++) { + Cookie innerCookie = cookieList.get(i); + if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) { + // We have a new cookie that matches an older one so we replace the older one. + cookieList.set(i, innerCookie); + replacedCookie = true; + break; + } + } + if (!replacedCookie) { + cookieList.add(cookie); } } - if (!replacedCookie) { - cookieList.add(cookie); - } + } finally { + cookieLock.unlock(); } - } finally { - cookieLock.unlock(); } - } - @Override - public List loadForRequest(HttpUrl url) { - cookieLock.lock(); - try { - List matchedCookies = new ArrayList<>(); - List expiredCookies = new ArrayList<>(); - for (Cookie cookie : cookieList) { - if (cookie.expiresAt() < System.currentTimeMillis()) { - expiredCookies.add(cookie); - } else if (cookie.matches(url)) { - matchedCookies.add(cookie); + @Override + public List loadForRequest(HttpUrl url) { + cookieLock.lock(); + try { + List matchedCookies = new ArrayList<>(); + List expiredCookies = new ArrayList<>(); + for (Cookie cookie : cookieList) { + if (cookie.expiresAt() < System.currentTimeMillis()) { + expiredCookies.add(cookie); + } else if (cookie.matches(url)) { + matchedCookies.add(cookie); + } } - } - cookieList.removeAll(expiredCookies); - return matchedCookies; - } finally { - cookieLock.unlock(); + cookieList.removeAll(expiredCookies); + return matchedCookies; + } finally { + cookieLock.unlock(); + } } + }); + + if (timeoutInMilliseconds > 0) { + builder.readTimeout(timeoutInMilliseconds, TimeUnit.MILLISECONDS); } - }).build(); + this.client = builder.build(); + } } @Override public Single send(HttpRequest httpRequest) { + return send(httpRequest, null); + } + + @Override + public Single send(HttpRequest httpRequest, String bodyContent) { Request.Builder requestBuilder = new Request.Builder().url(httpRequest.getUrl()); switch (httpRequest.getMethod()) { @@ -88,7 +105,13 @@ final class DefaultHttpClient extends HttpClient { requestBuilder.get(); break; case "POST": - RequestBody body = RequestBody.create(null, new byte[]{}); + RequestBody body; + if (bodyContent != null) { + body = RequestBody.create(MediaType.parse("text/plain"), bodyContent); + } else { + body = RequestBody.create(null, new byte[]{}); + } + requestBuilder.post(body); break; case "DELETE": diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpClient.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpClient.java index 9b7a1b4fde..534457367f 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpClient.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpClient.java @@ -95,6 +95,12 @@ abstract class HttpClient { return this.send(request); } + public Single post(String url, String body, HttpRequest options) { + options.setUrl(url); + options.setMethod("POST"); + return this.send(options, body); + } + public Single post(String url, HttpRequest options) { options.setUrl(url); options.setMethod("POST"); @@ -116,5 +122,9 @@ abstract class HttpClient { public abstract Single send(HttpRequest request); + public abstract Single send(HttpRequest request, String body); + public abstract WebSocketWrapper createWebSocket(String url, Map headers); -} \ No newline at end of file + + public abstract HttpClient cloneWithTimeOut(int timeoutInMilliseconds); +} diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java index e2a8cceccc..d91e382ed9 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java @@ -19,19 +19,26 @@ public class HttpHubConnectionBuilder { private Single accessTokenProvider; private long handshakeResponseTimeout = 0; private Map headers; + private TransportEnum transportEnum; HttpHubConnectionBuilder(String url) { this.url = url; } + //For testing purposes. The Transport interface isn't public. + HttpHubConnectionBuilder withTransportImplementation(Transport transport) { + this.transport = transport; + return this; + } + /** - * Sets the transport to be used by the {@link HubConnection}. + * Sets the transport type to indicate which transport to be used by the {@link HubConnection}. * - * @param transport The transport to be used. + * @param transportEnum The type of transport to be used. * @return This instance of the HttpHubConnectionBuilder. */ - HttpHubConnectionBuilder withTransport(Transport transport) { - this.transport = transport; + public HttpHubConnectionBuilder withTransport(TransportEnum transportEnum) { + this.transportEnum = transportEnum; return this; } @@ -112,6 +119,6 @@ public class HttpHubConnectionBuilder { * @return A new instance of {@link HubConnection}. */ public HubConnection build() { - return new HubConnection(url, transport, skipNegotiate, httpClient, accessTokenProvider, handshakeResponseTimeout, headers); + return new HubConnection(url, transport, skipNegotiate, httpClient, accessTokenProvider, handshakeResponseTimeout, headers, transportEnum); } } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index d8c9b931ed..6bac938c72 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -3,14 +3,7 @@ package com.microsoft.signalr; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Date; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Timer; -import java.util.TimerTask; +import java.util.*; import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -46,7 +39,7 @@ public class HubConnection { private Single accessTokenProvider; private final Map headers = new HashMap<>(); private ConnectionState connectionState = null; - private final HttpClient httpClient; + private HttpClient httpClient; private String stopError; private Timer pingTimer = null; private final AtomicLong nextServerTimeout = new AtomicLong(); @@ -56,6 +49,7 @@ public class HubConnection { private long tickRate = 1000; private CompletableSubject handshakeResponseSubject; private long handshakeResponseTimeout = 15*1000; + private TransportEnum transportEnum = TransportEnum.ALL; private final Logger logger = LoggerFactory.getLogger(HubConnection.class); /** @@ -100,7 +94,7 @@ public class HubConnection { } HubConnection(String url, Transport transport, boolean skipNegotiate, HttpClient httpClient, - Single accessTokenProvider, long handshakeResponseTimeout, Map headers) { + Single accessTokenProvider, long handshakeResponseTimeout, Map headers, TransportEnum transportEnum) { if (url == null || url.isEmpty()) { throw new IllegalArgumentException("A valid url is required."); } @@ -122,6 +116,8 @@ public class HubConnection { if (transport != null) { this.transport = transport; + } else if (transportEnum != null) { + this.transportEnum = transportEnum; } if (handshakeResponseTimeout > 0) { @@ -301,7 +297,13 @@ public class HubConnection { negotiate.flatMapCompletable(url -> { logger.debug("Starting HubConnection."); if (transport == null) { - transport = new WebSocketTransport(headers, httpClient); + switch (transportEnum) { + case LONG_POLLING: + transport = new LongPollingTransport(headers, httpClient, accessTokenProvider); + break; + default: + transport = new WebSocketTransport(headers, httpClient); + } } transport.setOnReceive(this.callback); @@ -311,37 +313,20 @@ public class HubConnection { String handshake = HandshakeProtocol.createHandshakeRequestMessage( new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); + connectionState = new ConnectionState(this); + return transport.send(handshake).andThen(Completable.defer(() -> { timeoutHandshakeResponse(handshakeResponseTimeout, TimeUnit.MILLISECONDS); return handshakeResponseSubject.andThen(Completable.defer(() -> { hubConnectionStateLock.lock(); try { - connectionState = new ConnectionState(this); hubConnectionState = HubConnectionState.CONNECTED; logger.info("HubConnection started."); - resetServerTimeout(); - this.pingTimer = new Timer(); - this.pingTimer.schedule(new TimerTask() { - @Override - public void run() { - try { - if (System.currentTimeMillis() > nextServerTimeout.get()) { - stop("Server timeout elapsed without receiving a message from the server."); - return; - } - - if (System.currentTimeMillis() > nextPingActivation.get()) { - sendHubMessage(PingMessage.getInstance()); - } - } catch (Exception e) { - logger.warn("Error sending ping: {}.", e.getMessage()); - // The connection is probably in a bad or closed state now, cleanup the timer so - // it stops triggering - pingTimer.cancel(); - } - } - }, new Date(0), tickRate); + //Don't send pings if we're using long polling. + if (transportEnum != TransportEnum.LONG_POLLING) { + activatePingTimer(); + } } finally { hubConnectionStateLock.unlock(); } @@ -356,6 +341,30 @@ public class HubConnection { return start; } + private void activatePingTimer() { + this.pingTimer = new Timer(); + this.pingTimer.schedule(new TimerTask() { + @Override + public void run() { + try { + if (System.currentTimeMillis() > nextServerTimeout.get()) { + stop("Server timeout elapsed without receiving a message from the server."); + return; + } + + if (System.currentTimeMillis() > nextPingActivation.get()) { + sendHubMessage(PingMessage.getInstance()); + } + } catch (Exception e) { + logger.warn("Error sending ping: {}.", e.getMessage()); + // The connection is probably in a bad or closed state now, cleanup the timer so + // it stops triggering + pingTimer.cancel(); + } + } + }, new Date(0), tickRate); + } + private Single startNegotiate(String url, int negotiateAttempts) { if (hubConnectionState != HubConnectionState.DISCONNECTED) { return Single.just(null); @@ -367,7 +376,10 @@ public class HubConnection { } if (response.getRedirectUrl() == null) { - if (!response.getAvailableTransports().contains("WebSockets")) { + Set transports = response.getAvailableTransports(); + if ((this.transportEnum == TransportEnum.ALL && !(transports.contains("WebSockets") || transports.contains("LongPolling"))) || + (this.transportEnum == TransportEnum.WEBSOCKETS && !transports.contains("WebSockets")) || + (this.transportEnum == TransportEnum.LONG_POLLING && !transports.contains("LongPolling"))) { throw new RuntimeException("There were no compatible transports on the server."); } @@ -563,7 +575,7 @@ public class HubConnection { } else { logger.debug("Sending {} message.", message.getMessageType().name()); } - transport.send(serializedMessage); + transport.send(serializedMessage).subscribeWith(CompletableSubject.create()); resetKeepAlive(); } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/JsonHubProtocol.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/JsonHubProtocol.java index 57bb5d1a7b..fa4d321273 100644 --- a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/JsonHubProtocol.java +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/JsonHubProtocol.java @@ -37,7 +37,10 @@ class JsonHubProtocol implements HubProtocol { @Override public HubMessage[] parseMessages(String payload, InvocationBinder binder) { - if (payload != null && !payload.substring(payload.length() - 1).equals(RECORD_SEPARATOR)) { + if (payload.length() == 0) { + return new HubMessage[]{}; + } + if (!(payload.substring(payload.length() - 1).equals(RECORD_SEPARATOR))) { throw new RuntimeException("Message is incomplete."); } diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java new file mode 100644 index 0000000000..087867bace --- /dev/null +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/LongPollingTransport.java @@ -0,0 +1,168 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +package com.microsoft.signalr; + +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import io.reactivex.Completable; +import io.reactivex.Single; +import io.reactivex.subjects.CompletableSubject; + +class LongPollingTransport implements Transport { + private OnReceiveCallBack onReceiveCallBack; + private TransportOnClosedCallback onClose; + private String url; + private final HttpClient client; + private final HttpClient pollingClient; + private final Map headers; + private static final int POLL_TIMEOUT = 100*1000; + private volatile Boolean active = false; + private String pollUrl; + private String closeError; + private Single accessTokenProvider; + private CompletableSubject receiveLoop = CompletableSubject.create(); + private ExecutorService threadPool; + private AtomicBoolean stopCalled = new AtomicBoolean(false); + + private final Logger logger = LoggerFactory.getLogger(LongPollingTransport.class); + + public LongPollingTransport(Map headers, HttpClient client, Single accessTokenProvider) { + this.headers = headers; + this.client = client; + this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT); + this.accessTokenProvider = accessTokenProvider; + } + + //Package private active accessor for testing. + boolean isActive() { + return this.active; + } + + private Single updateHeaderToken() { + return this.accessTokenProvider.flatMap((token) -> { + if (!token.isEmpty()) { + this.headers.put("Authorization", "Bearer " + token); + } + return Single.just(""); + }); + } + + @Override + public Completable start(String url) { + this.active = true; + logger.debug("Starting LongPolling transport."); + this.url = url; + pollUrl = url + "&_=" + System.currentTimeMillis(); + logger.debug("Polling {}.", pollUrl); + return this.updateHeaderToken().flatMapCompletable((r) -> { + HttpRequest request = new HttpRequest(); + request.addHeaders(headers); + return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> { + if (response.getStatusCode() != 200) { + logger.error("Unexpected response code {}.", response.getStatusCode()); + this.active = false; + return Completable.error(new Exception("Failed to connect.")); + } else { + this.active = true; + } + this.threadPool = Executors.newCachedThreadPool(); + threadPool.execute(() -> poll(url).subscribeWith(receiveLoop)); + + return Completable.complete(); + }); + }); + } + + private Completable poll(String url) { + if (this.active) { + pollUrl = url + "&_=" + System.currentTimeMillis(); + logger.debug("Polling {}.", pollUrl); + return this.updateHeaderToken().flatMapCompletable((x) -> { + HttpRequest request = new HttpRequest(); + request.addHeaders(headers); + Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> { + if (response.getStatusCode() == 204) { + logger.info("LongPolling transport terminated by server."); + this.active = false; + } else if (response.getStatusCode() != 200) { + logger.error("Unexpected response code {}.", response.getStatusCode()); + this.active = false; + this.closeError = "Unexpected response code " + response.getStatusCode() + "."; + } else { + if (response.getContent() != null) { + logger.debug("Message received."); + threadPool.execute(() -> this.onReceive(response.getContent())); + } else { + logger.debug("Poll timed out, reissuing."); + } + } + return poll(url); + }); + + return pollingCompletable; + }); + } else { + logger.debug("Long Polling transport polling complete."); + receiveLoop.onComplete(); + if (!stopCalled.get()) { + return this.stop(); + } + return Completable.complete(); + } + } + + @Override + public Completable send(String message) { + if (!this.active) { + return Completable.error(new Exception("Cannot send unless the transport is active.")); + } + return this.updateHeaderToken().flatMapCompletable((x) -> { + HttpRequest request = new HttpRequest(); + request.addHeaders(headers); + return Completable.fromSingle(this.client.post(url, message, request)); + }); + } + + @Override + public void setOnReceive(OnReceiveCallBack callback) { + this.onReceiveCallBack = callback; + } + + @Override + public void onReceive(String message) { + this.onReceiveCallBack.invoke(message); + logger.debug("OnReceived callback has been invoked."); + } + + @Override + public void setOnClose(TransportOnClosedCallback onCloseCallback) { + this.onClose = onCloseCallback; + } + + @Override + public Completable stop() { + if (!stopCalled.get()) { + this.stopCalled.set(true); + this.active = false; + return this.updateHeaderToken().flatMapCompletable((x) -> { + HttpRequest request = new HttpRequest(); + request.addHeaders(headers); + this.pollingClient.delete(this.url, request); + CompletableSubject stopCompletableSubject = CompletableSubject.create(); + return this.receiveLoop.andThen(Completable.defer(() -> { + logger.info("LongPolling transport stopped."); + this.onClose.invoke(this.closeError); + return Completable.complete(); + })).subscribeWith(stopCompletableSubject); + }); + } + return Completable.complete(); + } +} diff --git a/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/TransportEnum.java b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/TransportEnum.java new file mode 100644 index 0000000000..129ced21b5 --- /dev/null +++ b/src/SignalR/clients/java/signalr/src/main/java/com/microsoft/signalr/TransportEnum.java @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +package com.microsoft.signalr; + +public enum TransportEnum { + ALL, + WEBSOCKETS, + LONG_POLLING +} diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index a5590d0381..ea6e4742f7 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -8,7 +8,6 @@ import static org.junit.jupiter.api.Assertions.*; import java.util.Iterator; import java.util.List; import java.util.concurrent.CancellationException; -import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; @@ -66,7 +65,7 @@ class HubConnectionTest { public void checkHubConnectionStateNoHandShakeResponse() { MockTransport mockTransport = new MockTransport(false); HubConnection hubConnection = HubConnectionBuilder.create("http://example.com") - .withTransport(mockTransport) + .withTransportImplementation(mockTransport) .withHttpClient(new TestHttpClient()) .shouldSkipNegotiate(true) .withHandshakeResponseTimeout(100) @@ -1179,7 +1178,7 @@ class HubConnectionTest { } @Test - public void afterSuccessfulNegotiateConnectsWithTransport() { + public void afterSuccessfulNegotiateConnectsWithWebsocketsTransport() { TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", (req) -> Single.just(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" @@ -1188,7 +1187,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .build(); @@ -1199,6 +1198,47 @@ class HubConnectionTest { assertEquals("{\"protocol\":\"json\",\"version\":1}" + RECORD_SEPARATOR, sentMessages[0]); } + @Test + public void afterSuccessfulNegotiateConnectsWithLongPollingTransport() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait(); + + String[] sentMessages = transport.getSentMessages(); + assertEquals(1, sentMessages.length); + assertEquals("{\"protocol\":\"json\",\"version\":1}" + RECORD_SEPARATOR, sentMessages[0]); + } + + @Test + public void receivingServerSentEventsTransportFromNegotiateFails() { + TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", + (req) -> Single.just(new HttpResponse(200, "", + "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + + "availableTransports\":[{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}]}"))); + + MockTransport transport = new MockTransport(true); + HubConnection hubConnection = HubConnectionBuilder + .create("http://example.com") + .withTransportImplementation(transport) + .withHttpClient(client) + .build(); + + RuntimeException exception = assertThrows(RuntimeException.class, + () -> hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait()); + + assertEquals(exception.getMessage(), "There were no compatible transports on the server."); + } + @Test public void negotiateThatReturnsErrorThrowsFromStart() { TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", @@ -1208,7 +1248,7 @@ class HubConnectionTest { HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withHttpClient(client) - .withTransport(transport) + .withTransportImplementation(transport) .build(); RuntimeException exception = assertThrows(RuntimeException.class, @@ -1227,7 +1267,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .build(); @@ -1250,7 +1290,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .withAccessTokenProvider(Single.just("secretToken")) .build(); @@ -1275,7 +1315,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .withAccessTokenProvider(Single.just("secretToken")) .build(); @@ -1335,7 +1375,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(); HubConnection hubConnection = HubConnectionBuilder.create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .withHeader("ExampleHeader", "ExampleValue") .build(); @@ -1360,7 +1400,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(); HubConnection hubConnection = HubConnectionBuilder.create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .withHeader("ExampleHeader", "ExampleValue") .withHeader("ExampleHeader", "New Value") @@ -1377,7 +1417,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .shouldSkipNegotiate(true) .build(); @@ -1401,7 +1441,7 @@ class HubConnectionTest { HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(mockTransport) + .withTransportImplementation(mockTransport) .withHttpClient(client) .build(); @@ -1424,7 +1464,7 @@ class HubConnectionTest { MockTransport transport = new MockTransport(); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .build(); diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java new file mode 100644 index 0000000000..f2a0076dd6 --- /dev/null +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/LongPollingTransportTest.java @@ -0,0 +1,330 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +package com.microsoft.signalr; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.Test; + +import io.reactivex.Single; +import io.reactivex.subjects.CompletableSubject; + +public class LongPollingTransportTest { + + @Test + public void LongPollingFailsToConnectWith404Response() { + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> Single.just(new HttpResponse(404, "", ""))); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + Throwable exception = assertThrows(RuntimeException.class, () -> transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait()); + assertEquals(Exception.class, exception.getCause().getClass()); + assertEquals("Failed to connect.", exception.getCause().getMessage()); + assertFalse(transport.isActive()); + } + + @Test + public void LongPollingTransportCantSendBeforeStart() { + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> Single.just(new HttpResponse(404, "", ""))); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + Throwable exception = assertThrows(RuntimeException.class, () -> transport.send("First").timeout(1, TimeUnit.SECONDS).blockingAwait()); + assertEquals(Exception.class, exception.getCause().getClass()); + assertEquals("Cannot send unless the transport is active.", exception.getCause().getMessage()); + assertFalse(transport.isActive()); + } + + @Test + public void StatusCode204StopsLongPollingTriggersOnClosed() { + AtomicBoolean firstPoll = new AtomicBoolean(true); + CompletableSubject block = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (firstPoll.get()) { + firstPoll.set(false); + return Single.just(new HttpResponse(200, "", "")); + } + return Single.just(new HttpResponse(204, "", "")); + }); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + AtomicBoolean onClosedRan = new AtomicBoolean(false); + transport.setOnClose((error) -> { + onClosedRan.set(true); + block.onComplete(); + }); + + assertFalse(onClosedRan.get()); + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(block.blockingAwait(1, TimeUnit.SECONDS)); + assertTrue(onClosedRan.get()); + assertFalse(transport.isActive()); + } + + @Test + public void LongPollingFailsWhenReceivingUnexpectedErrorCode() { + AtomicBoolean firstPoll = new AtomicBoolean(true); + CompletableSubject blocker = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (firstPoll.get()) { + firstPoll.set(false); + return Single.just(new HttpResponse(200, "", "")); + } + return Single.just(new HttpResponse(999, "", "")); + }); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + AtomicBoolean onClosedRan = new AtomicBoolean(false); + transport.setOnClose((error) -> { + onClosedRan.set(true); + assertEquals("Unexpected response code 999.", error); + blocker.onComplete(); + }); + + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(blocker.blockingAwait(1, TimeUnit.SECONDS)); + assertFalse(transport.isActive()); + assertTrue(onClosedRan.get()); + } + + @Test + public void CanSetAndTriggerOnReceive() { + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> Single.just(new HttpResponse(200, "", ""))); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + + AtomicBoolean onReceivedRan = new AtomicBoolean(false); + transport.setOnReceive((message) -> { + onReceivedRan.set(true); + assertEquals("TEST", message); + }); + + // The transport doesn't need to be active to trigger onReceive for the case + // when we are handling the last outstanding poll. + transport.onReceive("TEST"); + assertTrue(onReceivedRan.get()); + } + + @Test + public void LongPollingTransportOnReceiveGetsCalled() { + AtomicInteger requestCount = new AtomicInteger(); + CompletableSubject block = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (requestCount.get() == 0) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "")); + } else if (requestCount.get() == 1) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "TEST")); + } + + return Single.just(new HttpResponse(204, "", "")); + }); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + + AtomicBoolean onReceiveCalled = new AtomicBoolean(false); + AtomicReference message = new AtomicReference<>(); + transport.setOnReceive((msg -> { + onReceiveCalled.set(true); + message.set(msg); + block.onComplete(); + }) ); + + transport.setOnClose((error) -> {}); + + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(block.blockingAwait(1,TimeUnit.SECONDS)); + assertTrue(onReceiveCalled.get()); + assertEquals("TEST", message.get()); + } + + @Test + public void LongPollingTransportOnReceiveGetsCalledMultipleTimes() { + AtomicInteger requestCount = new AtomicInteger(); + CompletableSubject blocker = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (requestCount.get() == 0) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "")); + } else if (requestCount.get() == 1) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "FIRST")); + } else if (requestCount.get() == 2) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "SECOND")); + } + + return Single.just(new HttpResponse(204, "", "")); + }); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + + AtomicBoolean onReceiveCalled = new AtomicBoolean(false); + AtomicReference message = new AtomicReference<>(""); + AtomicInteger messageCount = new AtomicInteger(); + transport.setOnReceive((msg) -> { + onReceiveCalled.set(true); + message.set(message.get() + msg); + if (messageCount.incrementAndGet() == 2) { + blocker.onComplete(); + } + }); + + transport.setOnClose((error) -> {}); + + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(blocker.blockingAwait(1, TimeUnit.SECONDS)); + assertTrue(onReceiveCalled.get()); + assertEquals("FIRSTSECOND", message.get()); + } + + @Test + public void LongPollingTransportSendsHeaders() { + AtomicInteger requestCount = new AtomicInteger(); + AtomicReference headerValue = new AtomicReference<>(); + CompletableSubject close = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (requestCount.get() == 0) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "")); + } + assertTrue(close.blockingAwait(1, TimeUnit.SECONDS)); + return Single.just(new HttpResponse(204, "", "")); + }).on("POST", (req) -> { + assertFalse(req.getHeaders().isEmpty()); + headerValue.set(req.getHeaders().get("KEY")); + return Single.just(new HttpResponse(200, "", "")); + }); + + Map headers = new HashMap<>(); + headers.put("KEY", "VALUE"); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + transport.setOnClose((error) -> {}); + + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(transport.send("TEST").blockingAwait(1, TimeUnit.SECONDS)); + close.onComplete(); + assertEquals(headerValue.get(), "VALUE"); + } + + @Test + public void LongPollingTransportSetsAuthorizationHeader() { + AtomicInteger requestCount = new AtomicInteger(); + AtomicReference headerValue = new AtomicReference<>(); + CompletableSubject close = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (requestCount.get() == 0) { + requestCount.incrementAndGet(); + return Single.just(new HttpResponse(200, "", "")); + } + assertTrue(close.blockingAwait(1, TimeUnit.SECONDS)); + return Single.just(new HttpResponse(204, "", "")); + }) + .on("POST", (req) -> { + assertFalse(req.getHeaders().isEmpty()); + headerValue.set(req.getHeaders().get("Authorization")); + return Single.just(new HttpResponse(200, "", "")); + }); + + Map headers = new HashMap<>(); + Single tokenProvider = Single.just("TOKEN"); + LongPollingTransport transport = new LongPollingTransport(headers, client, tokenProvider); + transport.setOnClose((error) -> {}); + + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(transport.send("TEST").blockingAwait(1, TimeUnit.SECONDS)); + assertEquals(headerValue.get(), "Bearer TOKEN"); + close.onComplete(); + } + + @Test + public void After204StopDoesNotTriggerOnClose() { + AtomicBoolean firstPoll = new AtomicBoolean(true); + CompletableSubject block = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (firstPoll.get()) { + firstPoll.set(false); + return Single.just(new HttpResponse(200, "", "")); + } + return Single.just(new HttpResponse(204, "", "")); + }); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + AtomicBoolean onClosedRan = new AtomicBoolean(false); + AtomicInteger onCloseCount = new AtomicInteger(0); + transport.setOnClose((error) -> { + onClosedRan.set(true); + onCloseCount.incrementAndGet(); + block.onComplete(); + }); + + assertFalse(onClosedRan.get()); + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(block.blockingAwait(1, TimeUnit.SECONDS)); + assertEquals(1, onCloseCount.get()); + assertTrue(onClosedRan.get()); + assertFalse(transport.isActive()); + + assertTrue(transport.stop().blockingAwait(1, TimeUnit.SECONDS)); + assertEquals(1, onCloseCount.get()); + } + + @Test + public void StoppingTransportRunsCloseHandlersOnce() { + AtomicBoolean firstPoll = new AtomicBoolean(true); + CompletableSubject block = CompletableSubject.create(); + TestHttpClient client = new TestHttpClient() + .on("GET", (req) -> { + if (firstPoll.get()) { + firstPoll.set(false); + return Single.just(new HttpResponse(200, "", "")); + } else { + assertTrue(block.blockingAwait(1, TimeUnit.SECONDS)); + return Single.just(new HttpResponse(204, "", "")); + } + }) + .on("DELETE", (req) ->{ + //Unblock the last poll when we sent the DELETE request. + block.onComplete(); + return Single.just(new HttpResponse(200, "", "")); + }); + + Map headers = new HashMap<>(); + LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just("")); + AtomicInteger onCloseCount = new AtomicInteger(0); + transport.setOnClose((error) -> { + onCloseCount.incrementAndGet(); + }); + + assertEquals(0, onCloseCount.get()); + transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait(); + assertTrue(transport.stop().blockingAwait(1, TimeUnit.SECONDS)); + assertEquals(1, onCloseCount.get()); + assertFalse(transport.isActive()); + } +} diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestHttpClient.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestHttpClient.java index 7c95f25c00..eea0453504 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestHttpClient.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestHttpClient.java @@ -22,6 +22,11 @@ class TestHttpClient extends HttpClient { @Override public Single send(HttpRequest request) { + return send(request, null); + } + + @Override + public Single send(HttpRequest request, String body) { this.sentRequests.add(request); return this.handler.invoke(request); } @@ -66,7 +71,12 @@ class TestHttpClient extends HttpClient { throw new RuntimeException("WebSockets isn't supported in testing currently."); } + @Override + public HttpClient cloneWithTimeOut(int timeoutInMilliseconds) { + return this; + } + interface TestHttpRequestHandler { Single invoke(HttpRequest request); } -} \ No newline at end of file +} diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java index 9026c392b6..795df060e4 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java @@ -14,10 +14,10 @@ class TestUtils { static HubConnection createHubConnection(String url, Transport transport, boolean skipNegotiate, HttpClient client) { HttpHubConnectionBuilder builder = HubConnectionBuilder.create(url) - .withTransport(transport) + .withTransportImplementation(transport) .withHttpClient(client) .shouldSkipNegotiate(skipNegotiate); return builder.build(); } -} \ No newline at end of file +} diff --git a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportTest.java b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportTest.java index a68e69195d..a3a3595550 100644 --- a/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportTest.java +++ b/src/SignalR/clients/java/signalr/src/test/java/com/microsoft/signalr/WebSocketTransportTest.java @@ -41,10 +41,20 @@ class WebSocketTransportTest { return null; } + @Override + public Single send(HttpRequest request, String body) { + return null; + } + @Override public WebSocketWrapper createWebSocket(String url, Map headers) { return new TestWrapper(); } + + @Override + public HttpClient cloneWithTimeOut(int timeoutInMilliseconds) { + return null; + } } class TestWrapper extends WebSocketWrapper {