Change websockets library (#3012)

This commit is contained in:
Mikael Mengistu 2018-09-28 14:20:58 -07:00 committed by GitHub
parent 75ac1a60f7
commit 2a1ba9e4ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 54 deletions

View File

@ -4,16 +4,18 @@
package com.microsoft.aspnet.signalr;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import okhttp3.Cookie;
import okhttp3.CookieJar;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
public class HubConnection {
private String url;
private Transport transport;
@ -31,6 +33,7 @@ public class HubConnection {
private String accessToken;
private Map<String, String> headers = new HashMap<>();
private ConnectionState connectionState = null;
private OkHttpClient httpClient;
private static ArrayList<Class<?>> emptyArray = new ArrayList<>();
private static int MAX_NEGOTIATE_ATTEMPTS = 100;
@ -54,6 +57,59 @@ public class HubConnection {
}
this.skipNegotiate = skipNegotiate;
this.httpClient = new OkHttpClient.Builder()
.cookieJar(new CookieJar() {
private List<Cookie> cookieList = new ArrayList<>();
private Lock cookieLock = new ReentrantLock();
@Override
public void saveFromResponse(HttpUrl url, List<Cookie> cookies) {
cookieLock.lock();
try {
for (Cookie cookie : cookies) {
boolean replacedCookie = false;
for (int i = 0; i < cookieList.size(); i++) {
Cookie innerCookie = cookieList.get(i);
if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) {
// We have a new cookie that matches an older one so we replace the older one.
cookieList.set(i, innerCookie);
replacedCookie = true;
break;
}
}
if (!replacedCookie) {
cookieList.add(cookie);
}
}
} finally {
cookieLock.unlock();
}
}
@Override
public List<Cookie> loadForRequest(HttpUrl url) {
cookieLock.lock();
try {
List<Cookie> matchedCookies = new ArrayList<>();
List<Cookie> expiredCookies = new ArrayList<>();
for (Cookie cookie : cookieList) {
if (cookie.expiresAt() < System.currentTimeMillis()) {
expiredCookies.add(cookie);
} else if (cookie.matches(url)) {
matchedCookies.add(cookie);
}
}
cookieList.removeAll(expiredCookies);
return matchedCookies;
} finally {
cookieLock.unlock();
}
}
})
.build();
this.callback = (payload) -> {
if (!handshakeReceived) {
@ -120,7 +176,7 @@ public class HubConnection {
private NegotiateResponse handleNegotiate() throws IOException, HubException {
accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken();
negotiateResponse = Negotiate.processNegotiate(url, accessToken);
negotiateResponse = Negotiate.processNegotiate(url, httpClient, accessToken);
if (negotiateResponse.getError() != null) {
throw new HubException(negotiateResponse.getError());
@ -176,7 +232,7 @@ public class HubConnection {
logger.log(LogLevel.Debug, "Starting HubConnection");
if (transport == null) {
transport = new WebSocketTransport(url, logger, headers);
transport = new WebSocketTransport(url, logger, headers, httpClient);
}
transport.setOnReceive(this.callback);

View File

@ -12,13 +12,12 @@ import okhttp3.Response;
class Negotiate {
public static NegotiateResponse processNegotiate(String url) throws IOException {
return processNegotiate(url, null);
public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient) throws IOException {
return processNegotiate(url, httpClient, null);
}
public static NegotiateResponse processNegotiate(String url, String accessTokenHeader) throws IOException {
public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient,String accessTokenHeader) throws IOException {
url = resolveNegotiateUrl(url);
OkHttpClient client = new OkHttpClient();
RequestBody body = RequestBody.create(null, new byte[]{});
Request.Builder requestBuilder = new Request.Builder()
.url(url)
@ -30,7 +29,7 @@ class Negotiate {
Request request = requestBuilder.build();
Response response = client.newCall(request).execute();
Response response = httpClient.newCall(request).execute();
String result = response.body().string();
return new NegotiateResponse(result);
}

View File

@ -8,15 +8,17 @@ import java.net.URISyntaxException;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake;
import okhttp3.*;
class WebSocketTransport implements Transport {
private WebSocketClient webSocketClient;
private WebSocket websocketClient;
private SignalRWebSocketListener webSocketListener;
private OnReceiveCallBack onReceiveCallBack;
private URI url;
private Logger logger;
private Map<String, String> headers;
private OkHttpClient httpClient;
private CompletableFuture<Void> startFuture = new CompletableFuture<>();
private static final String HTTP = "http";
private static final String HTTPS = "https";
@ -27,6 +29,14 @@ class WebSocketTransport implements Transport {
this.url = formatUrl(url);
this.logger = logger;
this.headers = headers;
this.httpClient = new OkHttpClient();
}
public WebSocketTransport(String url, Logger logger, Map<String, String> headers, OkHttpClient httpClient) throws URISyntaxException {
this.url = formatUrl(url);
this.logger = logger;
this.headers = headers;
this.httpClient = httpClient;
}
public URI getUrl() {
@ -45,27 +55,15 @@ class WebSocketTransport implements Transport {
@Override
public CompletableFuture start() {
return CompletableFuture.runAsync(() -> {
logger.log(LogLevel.Debug, "Starting Websocket connection.");
webSocketClient = createWebSocket(headers);
try {
if (!webSocketClient.connectBlocking()) {
String errorMessage = "There was an error starting the Websockets transport.";
logger.log(LogLevel.Debug, errorMessage);
throw new RuntimeException(errorMessage);
}
} catch (InterruptedException e) {
String interruptedExMessage = "Connecting the Websockets transport was interrupted.";
logger.log(LogLevel.Debug, interruptedExMessage);
throw new RuntimeException(interruptedExMessage);
}
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI());
});
webSocketListener = new SignalRWebSocketListener();
websocketClient = createUpdatedWebSocket(webSocketListener);
return startFuture;
}
@Override
public CompletableFuture send(String message) {
return CompletableFuture.runAsync(() -> webSocketClient.send(message));
return CompletableFuture.runAsync(() -> websocketClient.send(message));
}
@Override
@ -82,36 +80,62 @@ class WebSocketTransport implements Transport {
@Override
public CompletableFuture stop() {
return CompletableFuture.runAsync(() -> {
webSocketClient.closeConnection(0, "HubConnection Stopped");
websocketClient.close(1000, "HubConnection stopped.");
logger.log(LogLevel.Information, "WebSocket connection stopped");
});
}
private WebSocketClient createWebSocket(Map<String, String> headers) {
return new WebSocketClient(url, headers) {
@Override
public void onOpen(ServerHandshake handshakedata) {
System.out.println("Connected to " + url);
}
private WebSocket createUpdatedWebSocket(WebSocketListener webSocketListener) {
Headers.Builder headerBuilder = new Headers.Builder();
for (String key: headers.keySet()) {
headerBuilder.add(key, headers.get(key));
}
Request request = new Request.Builder().url(url.toString())
.headers(headerBuilder.build())
.build();
@Override
public void onMessage(String message) {
try {
onReceive(message);
} catch (Exception e) {
e.printStackTrace();
}
}
return this.httpClient.newWebSocket(request, webSocketListener);
}
@Override
public void onClose(int code, String reason, boolean remote) {
System.out.println("Connection Closed");
}
@Override
public void onError(Exception ex) {
System.out.println("Error: " + ex.getMessage());
private class SignalRWebSocketListener extends WebSocketListener {
@Override
public void onOpen(WebSocket webSocket, Response response) {
startFuture.complete(null);
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", websocketClient.request().url());
}
@Override
public void onMessage(WebSocket webSocket, String message) {
try {
onReceive(message);
} catch (Exception e) {
e.printStackTrace();
}
};
}
@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
logger.log(LogLevel.Information, "WebSocket connection stopping with " +
"code %d and reason %s", code, reason);
// If the start future hasn't completed yet, then we need to complete it exceptionally.
checkStartFailure();
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
logger.log(LogLevel.Error, "Error : %d", t.getMessage());
// If the start future hasn't completed yet, then we need to complete it exceptionally.
checkStartFailure();
}
}
private void checkStartFailure() {
// If the start future hasn't completed yet, then we need to complete it exceptionally.
if (!startFuture.isDone()) {
String errorMessage = "There was an error starting the Websockets transport.";
logger.log(LogLevel.Debug, errorMessage);
startFuture.completeExceptionally(new RuntimeException(errorMessage));
}
}
}

View File

@ -13,7 +13,7 @@ import org.junit.jupiter.api.Test;
class WebSocketTransportTest {
@Test
public void WebsocketThrowsIfItCantConnect() throws Exception {
Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger(), new HashMap<>());
Transport transport = new WebSocketTransport("http://www.notarealurl12345.fake", new NullLogger(), new HashMap<>());
Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS));
assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage());
}