Change websockets library (#3012)

This commit is contained in:
Mikael Mengistu 2018-09-28 14:20:58 -07:00 committed by GitHub
parent 75ac1a60f7
commit 2a1ba9e4ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 133 additions and 54 deletions

View File

@ -4,16 +4,18 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.*;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock; import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer; import java.util.function.Consumer;
import okhttp3.Cookie;
import okhttp3.CookieJar;
import okhttp3.HttpUrl;
import okhttp3.OkHttpClient;
public class HubConnection { public class HubConnection {
private String url; private String url;
private Transport transport; private Transport transport;
@ -31,6 +33,7 @@ public class HubConnection {
private String accessToken; private String accessToken;
private Map<String, String> headers = new HashMap<>(); private Map<String, String> headers = new HashMap<>();
private ConnectionState connectionState = null; private ConnectionState connectionState = null;
private OkHttpClient httpClient;
private static ArrayList<Class<?>> emptyArray = new ArrayList<>(); private static ArrayList<Class<?>> emptyArray = new ArrayList<>();
private static int MAX_NEGOTIATE_ATTEMPTS = 100; private static int MAX_NEGOTIATE_ATTEMPTS = 100;
@ -54,6 +57,59 @@ public class HubConnection {
} }
this.skipNegotiate = skipNegotiate; this.skipNegotiate = skipNegotiate;
this.httpClient = new OkHttpClient.Builder()
.cookieJar(new CookieJar() {
private List<Cookie> cookieList = new ArrayList<>();
private Lock cookieLock = new ReentrantLock();
@Override
public void saveFromResponse(HttpUrl url, List<Cookie> cookies) {
cookieLock.lock();
try {
for (Cookie cookie : cookies) {
boolean replacedCookie = false;
for (int i = 0; i < cookieList.size(); i++) {
Cookie innerCookie = cookieList.get(i);
if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) {
// We have a new cookie that matches an older one so we replace the older one.
cookieList.set(i, innerCookie);
replacedCookie = true;
break;
}
}
if (!replacedCookie) {
cookieList.add(cookie);
}
}
} finally {
cookieLock.unlock();
}
}
@Override
public List<Cookie> loadForRequest(HttpUrl url) {
cookieLock.lock();
try {
List<Cookie> matchedCookies = new ArrayList<>();
List<Cookie> expiredCookies = new ArrayList<>();
for (Cookie cookie : cookieList) {
if (cookie.expiresAt() < System.currentTimeMillis()) {
expiredCookies.add(cookie);
} else if (cookie.matches(url)) {
matchedCookies.add(cookie);
}
}
cookieList.removeAll(expiredCookies);
return matchedCookies;
} finally {
cookieLock.unlock();
}
}
})
.build();
this.callback = (payload) -> { this.callback = (payload) -> {
if (!handshakeReceived) { if (!handshakeReceived) {
@ -120,7 +176,7 @@ public class HubConnection {
private NegotiateResponse handleNegotiate() throws IOException, HubException { private NegotiateResponse handleNegotiate() throws IOException, HubException {
accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken();
negotiateResponse = Negotiate.processNegotiate(url, accessToken); negotiateResponse = Negotiate.processNegotiate(url, httpClient, accessToken);
if (negotiateResponse.getError() != null) { if (negotiateResponse.getError() != null) {
throw new HubException(negotiateResponse.getError()); throw new HubException(negotiateResponse.getError());
@ -176,7 +232,7 @@ public class HubConnection {
logger.log(LogLevel.Debug, "Starting HubConnection"); logger.log(LogLevel.Debug, "Starting HubConnection");
if (transport == null) { if (transport == null) {
transport = new WebSocketTransport(url, logger, headers); transport = new WebSocketTransport(url, logger, headers, httpClient);
} }
transport.setOnReceive(this.callback); transport.setOnReceive(this.callback);

View File

@ -12,13 +12,12 @@ import okhttp3.Response;
class Negotiate { class Negotiate {
public static NegotiateResponse processNegotiate(String url) throws IOException { public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient) throws IOException {
return processNegotiate(url, null); return processNegotiate(url, httpClient, null);
} }
public static NegotiateResponse processNegotiate(String url, String accessTokenHeader) throws IOException { public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient,String accessTokenHeader) throws IOException {
url = resolveNegotiateUrl(url); url = resolveNegotiateUrl(url);
OkHttpClient client = new OkHttpClient();
RequestBody body = RequestBody.create(null, new byte[]{}); RequestBody body = RequestBody.create(null, new byte[]{});
Request.Builder requestBuilder = new Request.Builder() Request.Builder requestBuilder = new Request.Builder()
.url(url) .url(url)
@ -30,7 +29,7 @@ class Negotiate {
Request request = requestBuilder.build(); Request request = requestBuilder.build();
Response response = client.newCall(request).execute(); Response response = httpClient.newCall(request).execute();
String result = response.body().string(); String result = response.body().string();
return new NegotiateResponse(result); return new NegotiateResponse(result);
} }

View File

@ -8,15 +8,17 @@ import java.net.URISyntaxException;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import org.java_websocket.client.WebSocketClient; import okhttp3.*;
import org.java_websocket.handshake.ServerHandshake;
class WebSocketTransport implements Transport { class WebSocketTransport implements Transport {
private WebSocketClient webSocketClient; private WebSocket websocketClient;
private SignalRWebSocketListener webSocketListener;
private OnReceiveCallBack onReceiveCallBack; private OnReceiveCallBack onReceiveCallBack;
private URI url; private URI url;
private Logger logger; private Logger logger;
private Map<String, String> headers; private Map<String, String> headers;
private OkHttpClient httpClient;
private CompletableFuture<Void> startFuture = new CompletableFuture<>();
private static final String HTTP = "http"; private static final String HTTP = "http";
private static final String HTTPS = "https"; private static final String HTTPS = "https";
@ -27,6 +29,14 @@ class WebSocketTransport implements Transport {
this.url = formatUrl(url); this.url = formatUrl(url);
this.logger = logger; this.logger = logger;
this.headers = headers; this.headers = headers;
this.httpClient = new OkHttpClient();
}
public WebSocketTransport(String url, Logger logger, Map<String, String> headers, OkHttpClient httpClient) throws URISyntaxException {
this.url = formatUrl(url);
this.logger = logger;
this.headers = headers;
this.httpClient = httpClient;
} }
public URI getUrl() { public URI getUrl() {
@ -45,27 +55,15 @@ class WebSocketTransport implements Transport {
@Override @Override
public CompletableFuture start() { public CompletableFuture start() {
return CompletableFuture.runAsync(() -> {
logger.log(LogLevel.Debug, "Starting Websocket connection."); logger.log(LogLevel.Debug, "Starting Websocket connection.");
webSocketClient = createWebSocket(headers); webSocketListener = new SignalRWebSocketListener();
try { websocketClient = createUpdatedWebSocket(webSocketListener);
if (!webSocketClient.connectBlocking()) { return startFuture;
String errorMessage = "There was an error starting the Websockets transport.";
logger.log(LogLevel.Debug, errorMessage);
throw new RuntimeException(errorMessage);
}
} catch (InterruptedException e) {
String interruptedExMessage = "Connecting the Websockets transport was interrupted.";
logger.log(LogLevel.Debug, interruptedExMessage);
throw new RuntimeException(interruptedExMessage);
}
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI());
});
} }
@Override @Override
public CompletableFuture send(String message) { public CompletableFuture send(String message) {
return CompletableFuture.runAsync(() -> webSocketClient.send(message)); return CompletableFuture.runAsync(() -> websocketClient.send(message));
} }
@Override @Override
@ -82,36 +80,62 @@ class WebSocketTransport implements Transport {
@Override @Override
public CompletableFuture stop() { public CompletableFuture stop() {
return CompletableFuture.runAsync(() -> { return CompletableFuture.runAsync(() -> {
webSocketClient.closeConnection(0, "HubConnection Stopped"); websocketClient.close(1000, "HubConnection stopped.");
logger.log(LogLevel.Information, "WebSocket connection stopped"); logger.log(LogLevel.Information, "WebSocket connection stopped");
}); });
} }
private WebSocketClient createWebSocket(Map<String, String> headers) { private WebSocket createUpdatedWebSocket(WebSocketListener webSocketListener) {
return new WebSocketClient(url, headers) { Headers.Builder headerBuilder = new Headers.Builder();
@Override for (String key: headers.keySet()) {
public void onOpen(ServerHandshake handshakedata) { headerBuilder.add(key, headers.get(key));
System.out.println("Connected to " + url); }
} Request request = new Request.Builder().url(url.toString())
.headers(headerBuilder.build())
.build();
@Override return this.httpClient.newWebSocket(request, webSocketListener);
public void onMessage(String message) { }
try {
onReceive(message);
} catch (Exception e) {
e.printStackTrace();
}
}
@Override
public void onClose(int code, String reason, boolean remote) {
System.out.println("Connection Closed");
}
@Override private class SignalRWebSocketListener extends WebSocketListener {
public void onError(Exception ex) { @Override
System.out.println("Error: " + ex.getMessage()); public void onOpen(WebSocket webSocket, Response response) {
startFuture.complete(null);
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", websocketClient.request().url());
}
@Override
public void onMessage(WebSocket webSocket, String message) {
try {
onReceive(message);
} catch (Exception e) {
e.printStackTrace();
} }
}; }
@Override
public void onClosing(WebSocket webSocket, int code, String reason) {
logger.log(LogLevel.Information, "WebSocket connection stopping with " +
"code %d and reason %s", code, reason);
// If the start future hasn't completed yet, then we need to complete it exceptionally.
checkStartFailure();
}
@Override
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
logger.log(LogLevel.Error, "Error : %d", t.getMessage());
// If the start future hasn't completed yet, then we need to complete it exceptionally.
checkStartFailure();
}
}
private void checkStartFailure() {
// If the start future hasn't completed yet, then we need to complete it exceptionally.
if (!startFuture.isDone()) {
String errorMessage = "There was an error starting the Websockets transport.";
logger.log(LogLevel.Debug, errorMessage);
startFuture.completeExceptionally(new RuntimeException(errorMessage));
}
} }
} }

View File

@ -13,7 +13,7 @@ import org.junit.jupiter.api.Test;
class WebSocketTransportTest { class WebSocketTransportTest {
@Test @Test
public void WebsocketThrowsIfItCantConnect() throws Exception { public void WebsocketThrowsIfItCantConnect() throws Exception {
Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger(), new HashMap<>()); Transport transport = new WebSocketTransport("http://www.notarealurl12345.fake", new NullLogger(), new HashMap<>());
Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS)); Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS));
assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage()); assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage());
} }