Change websockets library (#3012)
This commit is contained in:
parent
75ac1a60f7
commit
2a1ba9e4ff
|
|
@ -4,16 +4,18 @@
|
|||
package com.microsoft.aspnet.signalr;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.*;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.function.Consumer;
|
||||
|
||||
import okhttp3.Cookie;
|
||||
import okhttp3.CookieJar;
|
||||
import okhttp3.HttpUrl;
|
||||
import okhttp3.OkHttpClient;
|
||||
|
||||
public class HubConnection {
|
||||
private String url;
|
||||
private Transport transport;
|
||||
|
|
@ -31,6 +33,7 @@ public class HubConnection {
|
|||
private String accessToken;
|
||||
private Map<String, String> headers = new HashMap<>();
|
||||
private ConnectionState connectionState = null;
|
||||
private OkHttpClient httpClient;
|
||||
|
||||
private static ArrayList<Class<?>> emptyArray = new ArrayList<>();
|
||||
private static int MAX_NEGOTIATE_ATTEMPTS = 100;
|
||||
|
|
@ -54,6 +57,59 @@ public class HubConnection {
|
|||
}
|
||||
|
||||
this.skipNegotiate = skipNegotiate;
|
||||
|
||||
this.httpClient = new OkHttpClient.Builder()
|
||||
.cookieJar(new CookieJar() {
|
||||
private List<Cookie> cookieList = new ArrayList<>();
|
||||
private Lock cookieLock = new ReentrantLock();
|
||||
|
||||
@Override
|
||||
public void saveFromResponse(HttpUrl url, List<Cookie> cookies) {
|
||||
cookieLock.lock();
|
||||
try {
|
||||
for (Cookie cookie : cookies) {
|
||||
boolean replacedCookie = false;
|
||||
for (int i = 0; i < cookieList.size(); i++) {
|
||||
Cookie innerCookie = cookieList.get(i);
|
||||
if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) {
|
||||
// We have a new cookie that matches an older one so we replace the older one.
|
||||
cookieList.set(i, innerCookie);
|
||||
replacedCookie = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (!replacedCookie) {
|
||||
cookieList.add(cookie);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
cookieLock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<Cookie> loadForRequest(HttpUrl url) {
|
||||
cookieLock.lock();
|
||||
try {
|
||||
List<Cookie> matchedCookies = new ArrayList<>();
|
||||
List<Cookie> expiredCookies = new ArrayList<>();
|
||||
for (Cookie cookie : cookieList) {
|
||||
if (cookie.expiresAt() < System.currentTimeMillis()) {
|
||||
expiredCookies.add(cookie);
|
||||
} else if (cookie.matches(url)) {
|
||||
matchedCookies.add(cookie);
|
||||
}
|
||||
}
|
||||
|
||||
cookieList.removeAll(expiredCookies);
|
||||
return matchedCookies;
|
||||
} finally {
|
||||
cookieLock.unlock();
|
||||
}
|
||||
}
|
||||
})
|
||||
.build();
|
||||
|
||||
this.callback = (payload) -> {
|
||||
|
||||
if (!handshakeReceived) {
|
||||
|
|
@ -120,7 +176,7 @@ public class HubConnection {
|
|||
|
||||
private NegotiateResponse handleNegotiate() throws IOException, HubException {
|
||||
accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken();
|
||||
negotiateResponse = Negotiate.processNegotiate(url, accessToken);
|
||||
negotiateResponse = Negotiate.processNegotiate(url, httpClient, accessToken);
|
||||
|
||||
if (negotiateResponse.getError() != null) {
|
||||
throw new HubException(negotiateResponse.getError());
|
||||
|
|
@ -176,7 +232,7 @@ public class HubConnection {
|
|||
|
||||
logger.log(LogLevel.Debug, "Starting HubConnection");
|
||||
if (transport == null) {
|
||||
transport = new WebSocketTransport(url, logger, headers);
|
||||
transport = new WebSocketTransport(url, logger, headers, httpClient);
|
||||
}
|
||||
|
||||
transport.setOnReceive(this.callback);
|
||||
|
|
|
|||
|
|
@ -12,13 +12,12 @@ import okhttp3.Response;
|
|||
|
||||
class Negotiate {
|
||||
|
||||
public static NegotiateResponse processNegotiate(String url) throws IOException {
|
||||
return processNegotiate(url, null);
|
||||
public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient) throws IOException {
|
||||
return processNegotiate(url, httpClient, null);
|
||||
}
|
||||
|
||||
public static NegotiateResponse processNegotiate(String url, String accessTokenHeader) throws IOException {
|
||||
public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient,String accessTokenHeader) throws IOException {
|
||||
url = resolveNegotiateUrl(url);
|
||||
OkHttpClient client = new OkHttpClient();
|
||||
RequestBody body = RequestBody.create(null, new byte[]{});
|
||||
Request.Builder requestBuilder = new Request.Builder()
|
||||
.url(url)
|
||||
|
|
@ -30,7 +29,7 @@ class Negotiate {
|
|||
|
||||
Request request = requestBuilder.build();
|
||||
|
||||
Response response = client.newCall(request).execute();
|
||||
Response response = httpClient.newCall(request).execute();
|
||||
String result = response.body().string();
|
||||
return new NegotiateResponse(result);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,15 +8,17 @@ import java.net.URISyntaxException;
|
|||
import java.util.Map;
|
||||
import java.util.concurrent.CompletableFuture;
|
||||
|
||||
import org.java_websocket.client.WebSocketClient;
|
||||
import org.java_websocket.handshake.ServerHandshake;
|
||||
import okhttp3.*;
|
||||
|
||||
class WebSocketTransport implements Transport {
|
||||
private WebSocketClient webSocketClient;
|
||||
private WebSocket websocketClient;
|
||||
private SignalRWebSocketListener webSocketListener;
|
||||
private OnReceiveCallBack onReceiveCallBack;
|
||||
private URI url;
|
||||
private Logger logger;
|
||||
private Map<String, String> headers;
|
||||
private OkHttpClient httpClient;
|
||||
private CompletableFuture<Void> startFuture = new CompletableFuture<>();
|
||||
|
||||
private static final String HTTP = "http";
|
||||
private static final String HTTPS = "https";
|
||||
|
|
@ -27,6 +29,14 @@ class WebSocketTransport implements Transport {
|
|||
this.url = formatUrl(url);
|
||||
this.logger = logger;
|
||||
this.headers = headers;
|
||||
this.httpClient = new OkHttpClient();
|
||||
}
|
||||
|
||||
public WebSocketTransport(String url, Logger logger, Map<String, String> headers, OkHttpClient httpClient) throws URISyntaxException {
|
||||
this.url = formatUrl(url);
|
||||
this.logger = logger;
|
||||
this.headers = headers;
|
||||
this.httpClient = httpClient;
|
||||
}
|
||||
|
||||
public URI getUrl() {
|
||||
|
|
@ -45,27 +55,15 @@ class WebSocketTransport implements Transport {
|
|||
|
||||
@Override
|
||||
public CompletableFuture start() {
|
||||
return CompletableFuture.runAsync(() -> {
|
||||
logger.log(LogLevel.Debug, "Starting Websocket connection.");
|
||||
webSocketClient = createWebSocket(headers);
|
||||
try {
|
||||
if (!webSocketClient.connectBlocking()) {
|
||||
String errorMessage = "There was an error starting the Websockets transport.";
|
||||
logger.log(LogLevel.Debug, errorMessage);
|
||||
throw new RuntimeException(errorMessage);
|
||||
}
|
||||
} catch (InterruptedException e) {
|
||||
String interruptedExMessage = "Connecting the Websockets transport was interrupted.";
|
||||
logger.log(LogLevel.Debug, interruptedExMessage);
|
||||
throw new RuntimeException(interruptedExMessage);
|
||||
}
|
||||
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI());
|
||||
});
|
||||
webSocketListener = new SignalRWebSocketListener();
|
||||
websocketClient = createUpdatedWebSocket(webSocketListener);
|
||||
return startFuture;
|
||||
}
|
||||
|
||||
@Override
|
||||
public CompletableFuture send(String message) {
|
||||
return CompletableFuture.runAsync(() -> webSocketClient.send(message));
|
||||
return CompletableFuture.runAsync(() -> websocketClient.send(message));
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -82,36 +80,62 @@ class WebSocketTransport implements Transport {
|
|||
@Override
|
||||
public CompletableFuture stop() {
|
||||
return CompletableFuture.runAsync(() -> {
|
||||
webSocketClient.closeConnection(0, "HubConnection Stopped");
|
||||
websocketClient.close(1000, "HubConnection stopped.");
|
||||
logger.log(LogLevel.Information, "WebSocket connection stopped");
|
||||
});
|
||||
}
|
||||
|
||||
private WebSocketClient createWebSocket(Map<String, String> headers) {
|
||||
return new WebSocketClient(url, headers) {
|
||||
@Override
|
||||
public void onOpen(ServerHandshake handshakedata) {
|
||||
System.out.println("Connected to " + url);
|
||||
}
|
||||
private WebSocket createUpdatedWebSocket(WebSocketListener webSocketListener) {
|
||||
Headers.Builder headerBuilder = new Headers.Builder();
|
||||
for (String key: headers.keySet()) {
|
||||
headerBuilder.add(key, headers.get(key));
|
||||
}
|
||||
Request request = new Request.Builder().url(url.toString())
|
||||
.headers(headerBuilder.build())
|
||||
.build();
|
||||
|
||||
@Override
|
||||
public void onMessage(String message) {
|
||||
try {
|
||||
onReceive(message);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
return this.httpClient.newWebSocket(request, webSocketListener);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClose(int code, String reason, boolean remote) {
|
||||
System.out.println("Connection Closed");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onError(Exception ex) {
|
||||
System.out.println("Error: " + ex.getMessage());
|
||||
private class SignalRWebSocketListener extends WebSocketListener {
|
||||
@Override
|
||||
public void onOpen(WebSocket webSocket, Response response) {
|
||||
startFuture.complete(null);
|
||||
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", websocketClient.request().url());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onMessage(WebSocket webSocket, String message) {
|
||||
try {
|
||||
onReceive(message);
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onClosing(WebSocket webSocket, int code, String reason) {
|
||||
logger.log(LogLevel.Information, "WebSocket connection stopping with " +
|
||||
"code %d and reason %s", code, reason);
|
||||
// If the start future hasn't completed yet, then we need to complete it exceptionally.
|
||||
checkStartFailure();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void onFailure(WebSocket webSocket, Throwable t, Response response) {
|
||||
logger.log(LogLevel.Error, "Error : %d", t.getMessage());
|
||||
// If the start future hasn't completed yet, then we need to complete it exceptionally.
|
||||
checkStartFailure();
|
||||
}
|
||||
}
|
||||
|
||||
private void checkStartFailure() {
|
||||
// If the start future hasn't completed yet, then we need to complete it exceptionally.
|
||||
if (!startFuture.isDone()) {
|
||||
String errorMessage = "There was an error starting the Websockets transport.";
|
||||
logger.log(LogLevel.Debug, errorMessage);
|
||||
startFuture.completeExceptionally(new RuntimeException(errorMessage));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import org.junit.jupiter.api.Test;
|
|||
class WebSocketTransportTest {
|
||||
@Test
|
||||
public void WebsocketThrowsIfItCantConnect() throws Exception {
|
||||
Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger(), new HashMap<>());
|
||||
Transport transport = new WebSocketTransport("http://www.notarealurl12345.fake", new NullLogger(), new HashMap<>());
|
||||
Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS));
|
||||
assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage());
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue