SignalR Java Client LongPolling Transport (#6856)

This commit is contained in:
Mikael Mengistu 2019-02-13 10:27:07 -08:00 committed by GitHub
parent 2ac4619635
commit 3d3ad96206
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 733 additions and 110 deletions

View File

@ -8,79 +8,96 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import io.reactivex.Single;
import io.reactivex.subjects.SingleSubject;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Cookie;
import okhttp3.CookieJar;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okhttp3.*;
final class DefaultHttpClient extends HttpClient {
private final OkHttpClient client;
private OkHttpClient client = null;
public DefaultHttpClient() {
this.client = new OkHttpClient.Builder().cookieJar(new CookieJar() {
private List<Cookie> cookieList = new ArrayList<>();
private Lock cookieLock = new ReentrantLock();
this(0, null);
}
@Override
public void saveFromResponse(HttpUrl url, List<Cookie> cookies) {
cookieLock.lock();
try {
for (Cookie cookie : cookies) {
boolean replacedCookie = false;
for (int i = 0; i < cookieList.size(); i++) {
Cookie innerCookie = cookieList.get(i);
if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) {
// We have a new cookie that matches an older one so we replace the older one.
cookieList.set(i, innerCookie);
replacedCookie = true;
break;
public DefaultHttpClient cloneWithTimeOut(int timeoutInMilliseconds) {
OkHttpClient newClient = client.newBuilder().readTimeout(timeoutInMilliseconds, TimeUnit.MILLISECONDS)
.build();
return new DefaultHttpClient(timeoutInMilliseconds, newClient);
}
public DefaultHttpClient(int timeoutInMilliseconds, OkHttpClient client) {
if (client != null) {
this.client = client;
} else {
OkHttpClient.Builder builder = new OkHttpClient.Builder().cookieJar(new CookieJar() {
private List<Cookie> cookieList = new ArrayList<>();
private Lock cookieLock = new ReentrantLock();
@Override
public void saveFromResponse(HttpUrl url, List<Cookie> cookies) {
cookieLock.lock();
try {
for (Cookie cookie : cookies) {
boolean replacedCookie = false;
for (int i = 0; i < cookieList.size(); i++) {
Cookie innerCookie = cookieList.get(i);
if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) {
// We have a new cookie that matches an older one so we replace the older one.
cookieList.set(i, innerCookie);
replacedCookie = true;
break;
}
}
if (!replacedCookie) {
cookieList.add(cookie);
}
}
if (!replacedCookie) {
cookieList.add(cookie);
}
} finally {
cookieLock.unlock();
}
} finally {
cookieLock.unlock();
}
}
@Override
public List<Cookie> loadForRequest(HttpUrl url) {
cookieLock.lock();
try {
List<Cookie> matchedCookies = new ArrayList<>();
List<Cookie> expiredCookies = new ArrayList<>();
for (Cookie cookie : cookieList) {
if (cookie.expiresAt() < System.currentTimeMillis()) {
expiredCookies.add(cookie);
} else if (cookie.matches(url)) {
matchedCookies.add(cookie);
@Override
public List<Cookie> loadForRequest(HttpUrl url) {
cookieLock.lock();
try {
List<Cookie> matchedCookies = new ArrayList<>();
List<Cookie> expiredCookies = new ArrayList<>();
for (Cookie cookie : cookieList) {
if (cookie.expiresAt() < System.currentTimeMillis()) {
expiredCookies.add(cookie);
} else if (cookie.matches(url)) {
matchedCookies.add(cookie);
}
}
}
cookieList.removeAll(expiredCookies);
return matchedCookies;
} finally {
cookieLock.unlock();
cookieList.removeAll(expiredCookies);
return matchedCookies;
} finally {
cookieLock.unlock();
}
}
});
if (timeoutInMilliseconds > 0) {
builder.readTimeout(timeoutInMilliseconds, TimeUnit.MILLISECONDS);
}
}).build();
this.client = builder.build();
}
}
@Override
public Single<HttpResponse> send(HttpRequest httpRequest) {
return send(httpRequest, null);
}
@Override
public Single<HttpResponse> send(HttpRequest httpRequest, String bodyContent) {
Request.Builder requestBuilder = new Request.Builder().url(httpRequest.getUrl());
switch (httpRequest.getMethod()) {
@ -88,7 +105,13 @@ final class DefaultHttpClient extends HttpClient {
requestBuilder.get();
break;
case "POST":
RequestBody body = RequestBody.create(null, new byte[]{});
RequestBody body;
if (bodyContent != null) {
body = RequestBody.create(MediaType.parse("text/plain"), bodyContent);
} else {
body = RequestBody.create(null, new byte[]{});
}
requestBuilder.post(body);
break;
case "DELETE":

View File

@ -95,6 +95,12 @@ abstract class HttpClient {
return this.send(request);
}
public Single<HttpResponse> post(String url, String body, HttpRequest options) {
options.setUrl(url);
options.setMethod("POST");
return this.send(options, body);
}
public Single<HttpResponse> post(String url, HttpRequest options) {
options.setUrl(url);
options.setMethod("POST");
@ -116,5 +122,9 @@ abstract class HttpClient {
public abstract Single<HttpResponse> send(HttpRequest request);
public abstract Single<HttpResponse> send(HttpRequest request, String body);
public abstract WebSocketWrapper createWebSocket(String url, Map<String, String> headers);
}
public abstract HttpClient cloneWithTimeOut(int timeoutInMilliseconds);
}

View File

@ -19,19 +19,26 @@ public class HttpHubConnectionBuilder {
private Single<String> accessTokenProvider;
private long handshakeResponseTimeout = 0;
private Map<String, String> headers;
private TransportEnum transportEnum;
HttpHubConnectionBuilder(String url) {
this.url = url;
}
//For testing purposes. The Transport interface isn't public.
HttpHubConnectionBuilder withTransportImplementation(Transport transport) {
this.transport = transport;
return this;
}
/**
* Sets the transport to be used by the {@link HubConnection}.
* Sets the transport type to indicate which transport to be used by the {@link HubConnection}.
*
* @param transport The transport to be used.
* @param transportEnum The type of transport to be used.
* @return This instance of the HttpHubConnectionBuilder.
*/
HttpHubConnectionBuilder withTransport(Transport transport) {
this.transport = transport;
public HttpHubConnectionBuilder withTransport(TransportEnum transportEnum) {
this.transportEnum = transportEnum;
return this;
}
@ -112,6 +119,6 @@ public class HttpHubConnectionBuilder {
* @return A new instance of {@link HubConnection}.
*/
public HubConnection build() {
return new HubConnection(url, transport, skipNegotiate, httpClient, accessTokenProvider, handshakeResponseTimeout, headers);
return new HubConnection(url, transport, skipNegotiate, httpClient, accessTokenProvider, handshakeResponseTimeout, headers, transportEnum);
}
}

View File

@ -3,14 +3,7 @@
package com.microsoft.signalr;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Timer;
import java.util.TimerTask;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
@ -46,7 +39,7 @@ public class HubConnection {
private Single<String> accessTokenProvider;
private final Map<String, String> headers = new HashMap<>();
private ConnectionState connectionState = null;
private final HttpClient httpClient;
private HttpClient httpClient;
private String stopError;
private Timer pingTimer = null;
private final AtomicLong nextServerTimeout = new AtomicLong();
@ -56,6 +49,7 @@ public class HubConnection {
private long tickRate = 1000;
private CompletableSubject handshakeResponseSubject;
private long handshakeResponseTimeout = 15*1000;
private TransportEnum transportEnum = TransportEnum.ALL;
private final Logger logger = LoggerFactory.getLogger(HubConnection.class);
/**
@ -100,7 +94,7 @@ public class HubConnection {
}
HubConnection(String url, Transport transport, boolean skipNegotiate, HttpClient httpClient,
Single<String> accessTokenProvider, long handshakeResponseTimeout, Map<String, String> headers) {
Single<String> accessTokenProvider, long handshakeResponseTimeout, Map<String, String> headers, TransportEnum transportEnum) {
if (url == null || url.isEmpty()) {
throw new IllegalArgumentException("A valid url is required.");
}
@ -122,6 +116,8 @@ public class HubConnection {
if (transport != null) {
this.transport = transport;
} else if (transportEnum != null) {
this.transportEnum = transportEnum;
}
if (handshakeResponseTimeout > 0) {
@ -301,7 +297,13 @@ public class HubConnection {
negotiate.flatMapCompletable(url -> {
logger.debug("Starting HubConnection.");
if (transport == null) {
transport = new WebSocketTransport(headers, httpClient);
switch (transportEnum) {
case LONG_POLLING:
transport = new LongPollingTransport(headers, httpClient, accessTokenProvider);
break;
default:
transport = new WebSocketTransport(headers, httpClient);
}
}
transport.setOnReceive(this.callback);
@ -311,37 +313,20 @@ public class HubConnection {
String handshake = HandshakeProtocol.createHandshakeRequestMessage(
new HandshakeRequestMessage(protocol.getName(), protocol.getVersion()));
connectionState = new ConnectionState(this);
return transport.send(handshake).andThen(Completable.defer(() -> {
timeoutHandshakeResponse(handshakeResponseTimeout, TimeUnit.MILLISECONDS);
return handshakeResponseSubject.andThen(Completable.defer(() -> {
hubConnectionStateLock.lock();
try {
connectionState = new ConnectionState(this);
hubConnectionState = HubConnectionState.CONNECTED;
logger.info("HubConnection started.");
resetServerTimeout();
this.pingTimer = new Timer();
this.pingTimer.schedule(new TimerTask() {
@Override
public void run() {
try {
if (System.currentTimeMillis() > nextServerTimeout.get()) {
stop("Server timeout elapsed without receiving a message from the server.");
return;
}
if (System.currentTimeMillis() > nextPingActivation.get()) {
sendHubMessage(PingMessage.getInstance());
}
} catch (Exception e) {
logger.warn("Error sending ping: {}.", e.getMessage());
// The connection is probably in a bad or closed state now, cleanup the timer so
// it stops triggering
pingTimer.cancel();
}
}
}, new Date(0), tickRate);
//Don't send pings if we're using long polling.
if (transportEnum != TransportEnum.LONG_POLLING) {
activatePingTimer();
}
} finally {
hubConnectionStateLock.unlock();
}
@ -356,6 +341,30 @@ public class HubConnection {
return start;
}
private void activatePingTimer() {
this.pingTimer = new Timer();
this.pingTimer.schedule(new TimerTask() {
@Override
public void run() {
try {
if (System.currentTimeMillis() > nextServerTimeout.get()) {
stop("Server timeout elapsed without receiving a message from the server.");
return;
}
if (System.currentTimeMillis() > nextPingActivation.get()) {
sendHubMessage(PingMessage.getInstance());
}
} catch (Exception e) {
logger.warn("Error sending ping: {}.", e.getMessage());
// The connection is probably in a bad or closed state now, cleanup the timer so
// it stops triggering
pingTimer.cancel();
}
}
}, new Date(0), tickRate);
}
private Single<String> startNegotiate(String url, int negotiateAttempts) {
if (hubConnectionState != HubConnectionState.DISCONNECTED) {
return Single.just(null);
@ -367,7 +376,10 @@ public class HubConnection {
}
if (response.getRedirectUrl() == null) {
if (!response.getAvailableTransports().contains("WebSockets")) {
Set<String> transports = response.getAvailableTransports();
if ((this.transportEnum == TransportEnum.ALL && !(transports.contains("WebSockets") || transports.contains("LongPolling"))) ||
(this.transportEnum == TransportEnum.WEBSOCKETS && !transports.contains("WebSockets")) ||
(this.transportEnum == TransportEnum.LONG_POLLING && !transports.contains("LongPolling"))) {
throw new RuntimeException("There were no compatible transports on the server.");
}
@ -563,7 +575,7 @@ public class HubConnection {
} else {
logger.debug("Sending {} message.", message.getMessageType().name());
}
transport.send(serializedMessage);
transport.send(serializedMessage).subscribeWith(CompletableSubject.create());
resetKeepAlive();
}

View File

@ -37,7 +37,10 @@ class JsonHubProtocol implements HubProtocol {
@Override
public HubMessage[] parseMessages(String payload, InvocationBinder binder) {
if (payload != null && !payload.substring(payload.length() - 1).equals(RECORD_SEPARATOR)) {
if (payload.length() == 0) {
return new HubMessage[]{};
}
if (!(payload.substring(payload.length() - 1).equals(RECORD_SEPARATOR))) {
throw new RuntimeException("Message is incomplete.");
}

View File

@ -0,0 +1,168 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
package com.microsoft.signalr;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import io.reactivex.Completable;
import io.reactivex.Single;
import io.reactivex.subjects.CompletableSubject;
class LongPollingTransport implements Transport {
private OnReceiveCallBack onReceiveCallBack;
private TransportOnClosedCallback onClose;
private String url;
private final HttpClient client;
private final HttpClient pollingClient;
private final Map<String, String> headers;
private static final int POLL_TIMEOUT = 100*1000;
private volatile Boolean active = false;
private String pollUrl;
private String closeError;
private Single<String> accessTokenProvider;
private CompletableSubject receiveLoop = CompletableSubject.create();
private ExecutorService threadPool;
private AtomicBoolean stopCalled = new AtomicBoolean(false);
private final Logger logger = LoggerFactory.getLogger(LongPollingTransport.class);
public LongPollingTransport(Map<String, String> headers, HttpClient client, Single<String> accessTokenProvider) {
this.headers = headers;
this.client = client;
this.pollingClient = client.cloneWithTimeOut(POLL_TIMEOUT);
this.accessTokenProvider = accessTokenProvider;
}
//Package private active accessor for testing.
boolean isActive() {
return this.active;
}
private Single updateHeaderToken() {
return this.accessTokenProvider.flatMap((token) -> {
if (!token.isEmpty()) {
this.headers.put("Authorization", "Bearer " + token);
}
return Single.just("");
});
}
@Override
public Completable start(String url) {
this.active = true;
logger.debug("Starting LongPolling transport.");
this.url = url;
pollUrl = url + "&_=" + System.currentTimeMillis();
logger.debug("Polling {}.", pollUrl);
return this.updateHeaderToken().flatMapCompletable((r) -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
return this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
if (response.getStatusCode() != 200) {
logger.error("Unexpected response code {}.", response.getStatusCode());
this.active = false;
return Completable.error(new Exception("Failed to connect."));
} else {
this.active = true;
}
this.threadPool = Executors.newCachedThreadPool();
threadPool.execute(() -> poll(url).subscribeWith(receiveLoop));
return Completable.complete();
});
});
}
private Completable poll(String url) {
if (this.active) {
pollUrl = url + "&_=" + System.currentTimeMillis();
logger.debug("Polling {}.", pollUrl);
return this.updateHeaderToken().flatMapCompletable((x) -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
Completable pollingCompletable = this.pollingClient.get(pollUrl, request).flatMapCompletable(response -> {
if (response.getStatusCode() == 204) {
logger.info("LongPolling transport terminated by server.");
this.active = false;
} else if (response.getStatusCode() != 200) {
logger.error("Unexpected response code {}.", response.getStatusCode());
this.active = false;
this.closeError = "Unexpected response code " + response.getStatusCode() + ".";
} else {
if (response.getContent() != null) {
logger.debug("Message received.");
threadPool.execute(() -> this.onReceive(response.getContent()));
} else {
logger.debug("Poll timed out, reissuing.");
}
}
return poll(url);
});
return pollingCompletable;
});
} else {
logger.debug("Long Polling transport polling complete.");
receiveLoop.onComplete();
if (!stopCalled.get()) {
return this.stop();
}
return Completable.complete();
}
}
@Override
public Completable send(String message) {
if (!this.active) {
return Completable.error(new Exception("Cannot send unless the transport is active."));
}
return this.updateHeaderToken().flatMapCompletable((x) -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
return Completable.fromSingle(this.client.post(url, message, request));
});
}
@Override
public void setOnReceive(OnReceiveCallBack callback) {
this.onReceiveCallBack = callback;
}
@Override
public void onReceive(String message) {
this.onReceiveCallBack.invoke(message);
logger.debug("OnReceived callback has been invoked.");
}
@Override
public void setOnClose(TransportOnClosedCallback onCloseCallback) {
this.onClose = onCloseCallback;
}
@Override
public Completable stop() {
if (!stopCalled.get()) {
this.stopCalled.set(true);
this.active = false;
return this.updateHeaderToken().flatMapCompletable((x) -> {
HttpRequest request = new HttpRequest();
request.addHeaders(headers);
this.pollingClient.delete(this.url, request);
CompletableSubject stopCompletableSubject = CompletableSubject.create();
return this.receiveLoop.andThen(Completable.defer(() -> {
logger.info("LongPolling transport stopped.");
this.onClose.invoke(this.closeError);
return Completable.complete();
})).subscribeWith(stopCompletableSubject);
});
}
return Completable.complete();
}
}

View File

@ -0,0 +1,10 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
package com.microsoft.signalr;
public enum TransportEnum {
ALL,
WEBSOCKETS,
LONG_POLLING
}

View File

@ -8,7 +8,6 @@ import static org.junit.jupiter.api.Assertions.*;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.CancellationException;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
@ -66,7 +65,7 @@ class HubConnectionTest {
public void checkHubConnectionStateNoHandShakeResponse() {
MockTransport mockTransport = new MockTransport(false);
HubConnection hubConnection = HubConnectionBuilder.create("http://example.com")
.withTransport(mockTransport)
.withTransportImplementation(mockTransport)
.withHttpClient(new TestHttpClient())
.shouldSkipNegotiate(true)
.withHandshakeResponseTimeout(100)
@ -1179,7 +1178,7 @@ class HubConnectionTest {
}
@Test
public void afterSuccessfulNegotiateConnectsWithTransport() {
public void afterSuccessfulNegotiateConnectsWithWebsocketsTransport() {
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate",
(req) -> Single.just(new HttpResponse(200, "",
"{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
@ -1188,7 +1187,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.build();
@ -1199,6 +1198,47 @@ class HubConnectionTest {
assertEquals("{\"protocol\":\"json\",\"version\":1}" + RECORD_SEPARATOR, sentMessages[0]);
}
@Test
public void afterSuccessfulNegotiateConnectsWithLongPollingTransport() {
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate",
(req) -> Single.just(new HttpResponse(200, "",
"{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"LongPolling\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")));
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransportImplementation(transport)
.withHttpClient(client)
.build();
hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait();
String[] sentMessages = transport.getSentMessages();
assertEquals(1, sentMessages.length);
assertEquals("{\"protocol\":\"json\",\"version\":1}" + RECORD_SEPARATOR, sentMessages[0]);
}
@Test
public void receivingServerSentEventsTransportFromNegotiateFails() {
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate",
(req) -> Single.just(new HttpResponse(200, "",
"{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\""
+ "availableTransports\":[{\"transport\":\"ServerSentEvents\",\"transferFormats\":[\"Text\"]}]}")));
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransportImplementation(transport)
.withHttpClient(client)
.build();
RuntimeException exception = assertThrows(RuntimeException.class,
() -> hubConnection.start().timeout(1, TimeUnit.SECONDS).blockingAwait());
assertEquals(exception.getMessage(), "There were no compatible transports on the server.");
}
@Test
public void negotiateThatReturnsErrorThrowsFromStart() {
TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate",
@ -1208,7 +1248,7 @@ class HubConnectionTest {
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withHttpClient(client)
.withTransport(transport)
.withTransportImplementation(transport)
.build();
RuntimeException exception = assertThrows(RuntimeException.class,
@ -1227,7 +1267,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.build();
@ -1250,7 +1290,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.withAccessTokenProvider(Single.just("secretToken"))
.build();
@ -1275,7 +1315,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport(true);
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.withAccessTokenProvider(Single.just("secretToken"))
.build();
@ -1335,7 +1375,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport();
HubConnection hubConnection = HubConnectionBuilder.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.withHeader("ExampleHeader", "ExampleValue")
.build();
@ -1360,7 +1400,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport();
HubConnection hubConnection = HubConnectionBuilder.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.withHeader("ExampleHeader", "ExampleValue")
.withHeader("ExampleHeader", "New Value")
@ -1377,7 +1417,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport();
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.shouldSkipNegotiate(true)
.build();
@ -1401,7 +1441,7 @@ class HubConnectionTest {
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(mockTransport)
.withTransportImplementation(mockTransport)
.withHttpClient(client)
.build();
@ -1424,7 +1464,7 @@ class HubConnectionTest {
MockTransport transport = new MockTransport();
HubConnection hubConnection = HubConnectionBuilder
.create("http://example.com")
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.build();

View File

@ -0,0 +1,330 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
package com.microsoft.signalr;
import static org.junit.jupiter.api.Assertions.*;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.jupiter.api.Test;
import io.reactivex.Single;
import io.reactivex.subjects.CompletableSubject;
public class LongPollingTransportTest {
@Test
public void LongPollingFailsToConnectWith404Response() {
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> Single.just(new HttpResponse(404, "", "")));
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
Throwable exception = assertThrows(RuntimeException.class, () -> transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait());
assertEquals(Exception.class, exception.getCause().getClass());
assertEquals("Failed to connect.", exception.getCause().getMessage());
assertFalse(transport.isActive());
}
@Test
public void LongPollingTransportCantSendBeforeStart() {
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> Single.just(new HttpResponse(404, "", "")));
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
Throwable exception = assertThrows(RuntimeException.class, () -> transport.send("First").timeout(1, TimeUnit.SECONDS).blockingAwait());
assertEquals(Exception.class, exception.getCause().getClass());
assertEquals("Cannot send unless the transport is active.", exception.getCause().getMessage());
assertFalse(transport.isActive());
}
@Test
public void StatusCode204StopsLongPollingTriggersOnClosed() {
AtomicBoolean firstPoll = new AtomicBoolean(true);
CompletableSubject block = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (firstPoll.get()) {
firstPoll.set(false);
return Single.just(new HttpResponse(200, "", ""));
}
return Single.just(new HttpResponse(204, "", ""));
});
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicBoolean onClosedRan = new AtomicBoolean(false);
transport.setOnClose((error) -> {
onClosedRan.set(true);
block.onComplete();
});
assertFalse(onClosedRan.get());
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(block.blockingAwait(1, TimeUnit.SECONDS));
assertTrue(onClosedRan.get());
assertFalse(transport.isActive());
}
@Test
public void LongPollingFailsWhenReceivingUnexpectedErrorCode() {
AtomicBoolean firstPoll = new AtomicBoolean(true);
CompletableSubject blocker = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (firstPoll.get()) {
firstPoll.set(false);
return Single.just(new HttpResponse(200, "", ""));
}
return Single.just(new HttpResponse(999, "", ""));
});
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicBoolean onClosedRan = new AtomicBoolean(false);
transport.setOnClose((error) -> {
onClosedRan.set(true);
assertEquals("Unexpected response code 999.", error);
blocker.onComplete();
});
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(blocker.blockingAwait(1, TimeUnit.SECONDS));
assertFalse(transport.isActive());
assertTrue(onClosedRan.get());
}
@Test
public void CanSetAndTriggerOnReceive() {
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> Single.just(new HttpResponse(200, "", "")));
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicBoolean onReceivedRan = new AtomicBoolean(false);
transport.setOnReceive((message) -> {
onReceivedRan.set(true);
assertEquals("TEST", message);
});
// The transport doesn't need to be active to trigger onReceive for the case
// when we are handling the last outstanding poll.
transport.onReceive("TEST");
assertTrue(onReceivedRan.get());
}
@Test
public void LongPollingTransportOnReceiveGetsCalled() {
AtomicInteger requestCount = new AtomicInteger();
CompletableSubject block = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (requestCount.get() == 0) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", ""));
} else if (requestCount.get() == 1) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", "TEST"));
}
return Single.just(new HttpResponse(204, "", ""));
});
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicBoolean onReceiveCalled = new AtomicBoolean(false);
AtomicReference<String> message = new AtomicReference<>();
transport.setOnReceive((msg -> {
onReceiveCalled.set(true);
message.set(msg);
block.onComplete();
}) );
transport.setOnClose((error) -> {});
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(block.blockingAwait(1,TimeUnit.SECONDS));
assertTrue(onReceiveCalled.get());
assertEquals("TEST", message.get());
}
@Test
public void LongPollingTransportOnReceiveGetsCalledMultipleTimes() {
AtomicInteger requestCount = new AtomicInteger();
CompletableSubject blocker = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (requestCount.get() == 0) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", ""));
} else if (requestCount.get() == 1) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", "FIRST"));
} else if (requestCount.get() == 2) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", "SECOND"));
}
return Single.just(new HttpResponse(204, "", ""));
});
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicBoolean onReceiveCalled = new AtomicBoolean(false);
AtomicReference<String> message = new AtomicReference<>("");
AtomicInteger messageCount = new AtomicInteger();
transport.setOnReceive((msg) -> {
onReceiveCalled.set(true);
message.set(message.get() + msg);
if (messageCount.incrementAndGet() == 2) {
blocker.onComplete();
}
});
transport.setOnClose((error) -> {});
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(blocker.blockingAwait(1, TimeUnit.SECONDS));
assertTrue(onReceiveCalled.get());
assertEquals("FIRSTSECOND", message.get());
}
@Test
public void LongPollingTransportSendsHeaders() {
AtomicInteger requestCount = new AtomicInteger();
AtomicReference<String> headerValue = new AtomicReference<>();
CompletableSubject close = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (requestCount.get() == 0) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", ""));
}
assertTrue(close.blockingAwait(1, TimeUnit.SECONDS));
return Single.just(new HttpResponse(204, "", ""));
}).on("POST", (req) -> {
assertFalse(req.getHeaders().isEmpty());
headerValue.set(req.getHeaders().get("KEY"));
return Single.just(new HttpResponse(200, "", ""));
});
Map<String, String> headers = new HashMap<>();
headers.put("KEY", "VALUE");
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
transport.setOnClose((error) -> {});
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(transport.send("TEST").blockingAwait(1, TimeUnit.SECONDS));
close.onComplete();
assertEquals(headerValue.get(), "VALUE");
}
@Test
public void LongPollingTransportSetsAuthorizationHeader() {
AtomicInteger requestCount = new AtomicInteger();
AtomicReference<String> headerValue = new AtomicReference<>();
CompletableSubject close = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (requestCount.get() == 0) {
requestCount.incrementAndGet();
return Single.just(new HttpResponse(200, "", ""));
}
assertTrue(close.blockingAwait(1, TimeUnit.SECONDS));
return Single.just(new HttpResponse(204, "", ""));
})
.on("POST", (req) -> {
assertFalse(req.getHeaders().isEmpty());
headerValue.set(req.getHeaders().get("Authorization"));
return Single.just(new HttpResponse(200, "", ""));
});
Map<String, String> headers = new HashMap<>();
Single<String> tokenProvider = Single.just("TOKEN");
LongPollingTransport transport = new LongPollingTransport(headers, client, tokenProvider);
transport.setOnClose((error) -> {});
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(transport.send("TEST").blockingAwait(1, TimeUnit.SECONDS));
assertEquals(headerValue.get(), "Bearer TOKEN");
close.onComplete();
}
@Test
public void After204StopDoesNotTriggerOnClose() {
AtomicBoolean firstPoll = new AtomicBoolean(true);
CompletableSubject block = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (firstPoll.get()) {
firstPoll.set(false);
return Single.just(new HttpResponse(200, "", ""));
}
return Single.just(new HttpResponse(204, "", ""));
});
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicBoolean onClosedRan = new AtomicBoolean(false);
AtomicInteger onCloseCount = new AtomicInteger(0);
transport.setOnClose((error) -> {
onClosedRan.set(true);
onCloseCount.incrementAndGet();
block.onComplete();
});
assertFalse(onClosedRan.get());
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(block.blockingAwait(1, TimeUnit.SECONDS));
assertEquals(1, onCloseCount.get());
assertTrue(onClosedRan.get());
assertFalse(transport.isActive());
assertTrue(transport.stop().blockingAwait(1, TimeUnit.SECONDS));
assertEquals(1, onCloseCount.get());
}
@Test
public void StoppingTransportRunsCloseHandlersOnce() {
AtomicBoolean firstPoll = new AtomicBoolean(true);
CompletableSubject block = CompletableSubject.create();
TestHttpClient client = new TestHttpClient()
.on("GET", (req) -> {
if (firstPoll.get()) {
firstPoll.set(false);
return Single.just(new HttpResponse(200, "", ""));
} else {
assertTrue(block.blockingAwait(1, TimeUnit.SECONDS));
return Single.just(new HttpResponse(204, "", ""));
}
})
.on("DELETE", (req) ->{
//Unblock the last poll when we sent the DELETE request.
block.onComplete();
return Single.just(new HttpResponse(200, "", ""));
});
Map<String, String> headers = new HashMap<>();
LongPollingTransport transport = new LongPollingTransport(headers, client, Single.just(""));
AtomicInteger onCloseCount = new AtomicInteger(0);
transport.setOnClose((error) -> {
onCloseCount.incrementAndGet();
});
assertEquals(0, onCloseCount.get());
transport.start("http://example.com").timeout(1, TimeUnit.SECONDS).blockingAwait();
assertTrue(transport.stop().blockingAwait(1, TimeUnit.SECONDS));
assertEquals(1, onCloseCount.get());
assertFalse(transport.isActive());
}
}

View File

@ -22,6 +22,11 @@ class TestHttpClient extends HttpClient {
@Override
public Single<HttpResponse> send(HttpRequest request) {
return send(request, null);
}
@Override
public Single<HttpResponse> send(HttpRequest request, String body) {
this.sentRequests.add(request);
return this.handler.invoke(request);
}
@ -66,7 +71,12 @@ class TestHttpClient extends HttpClient {
throw new RuntimeException("WebSockets isn't supported in testing currently.");
}
@Override
public HttpClient cloneWithTimeOut(int timeoutInMilliseconds) {
return this;
}
interface TestHttpRequestHandler {
Single<HttpResponse> invoke(HttpRequest request);
}
}
}

View File

@ -14,10 +14,10 @@ class TestUtils {
static HubConnection createHubConnection(String url, Transport transport, boolean skipNegotiate, HttpClient client) {
HttpHubConnectionBuilder builder = HubConnectionBuilder.create(url)
.withTransport(transport)
.withTransportImplementation(transport)
.withHttpClient(client)
.shouldSkipNegotiate(skipNegotiate);
return builder.build();
}
}
}

View File

@ -41,10 +41,20 @@ class WebSocketTransportTest {
return null;
}
@Override
public Single<HttpResponse> send(HttpRequest request, String body) {
return null;
}
@Override
public WebSocketWrapper createWebSocket(String url, Map<String, String> headers) {
return new TestWrapper();
}
@Override
public HttpClient cloneWithTimeOut(int timeoutInMilliseconds) {
return null;
}
}
class TestWrapper extends WebSocketWrapper {