From 2a1ba9e4ff8a31ebf7906b35ae1d0b44bfc2427b Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Fri, 28 Sep 2018 14:20:58 -0700 Subject: [PATCH] Change websockets library (#3012) --- .../aspnet/signalr/HubConnection.java | 68 ++++++++++- .../microsoft/aspnet/signalr/Negotiate.java | 9 +- .../aspnet/signalr/WebSocketTransport.java | 108 +++++++++++------- .../signalr/WebSocketTransportTest.java | 2 +- 4 files changed, 133 insertions(+), 54 deletions(-) diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java index 72abfea42a..b3fc059470 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java @@ -4,16 +4,18 @@ package com.microsoft.aspnet.signalr; import java.io.IOException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; +import okhttp3.Cookie; +import okhttp3.CookieJar; +import okhttp3.HttpUrl; +import okhttp3.OkHttpClient; + public class HubConnection { private String url; private Transport transport; @@ -31,6 +33,7 @@ public class HubConnection { private String accessToken; private Map headers = new HashMap<>(); private ConnectionState connectionState = null; + private OkHttpClient httpClient; private static ArrayList> emptyArray = new ArrayList<>(); private static int MAX_NEGOTIATE_ATTEMPTS = 100; @@ -54,6 +57,59 @@ public class HubConnection { } this.skipNegotiate = skipNegotiate; + + this.httpClient = new OkHttpClient.Builder() + .cookieJar(new CookieJar() { + private List cookieList = new ArrayList<>(); + private Lock cookieLock = new ReentrantLock(); + + @Override + public void saveFromResponse(HttpUrl url, List cookies) { + cookieLock.lock(); + try { + for (Cookie cookie : cookies) { + boolean replacedCookie = false; + for (int i = 0; i < cookieList.size(); i++) { + Cookie innerCookie = cookieList.get(i); + if (cookie.name().equals(innerCookie.name()) && innerCookie.matches(url)) { + // We have a new cookie that matches an older one so we replace the older one. + cookieList.set(i, innerCookie); + replacedCookie = true; + break; + } + } + if (!replacedCookie) { + cookieList.add(cookie); + } + } + } finally { + cookieLock.unlock(); + } + } + + @Override + public List loadForRequest(HttpUrl url) { + cookieLock.lock(); + try { + List matchedCookies = new ArrayList<>(); + List expiredCookies = new ArrayList<>(); + for (Cookie cookie : cookieList) { + if (cookie.expiresAt() < System.currentTimeMillis()) { + expiredCookies.add(cookie); + } else if (cookie.matches(url)) { + matchedCookies.add(cookie); + } + } + + cookieList.removeAll(expiredCookies); + return matchedCookies; + } finally { + cookieLock.unlock(); + } + } + }) + .build(); + this.callback = (payload) -> { if (!handshakeReceived) { @@ -120,7 +176,7 @@ public class HubConnection { private NegotiateResponse handleNegotiate() throws IOException, HubException { accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); - negotiateResponse = Negotiate.processNegotiate(url, accessToken); + negotiateResponse = Negotiate.processNegotiate(url, httpClient, accessToken); if (negotiateResponse.getError() != null) { throw new HubException(negotiateResponse.getError()); @@ -176,7 +232,7 @@ public class HubConnection { logger.log(LogLevel.Debug, "Starting HubConnection"); if (transport == null) { - transport = new WebSocketTransport(url, logger, headers); + transport = new WebSocketTransport(url, logger, headers, httpClient); } transport.setOnReceive(this.callback); diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Negotiate.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Negotiate.java index 7788a06b1c..a90b71c476 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Negotiate.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Negotiate.java @@ -12,13 +12,12 @@ import okhttp3.Response; class Negotiate { - public static NegotiateResponse processNegotiate(String url) throws IOException { - return processNegotiate(url, null); + public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient) throws IOException { + return processNegotiate(url, httpClient, null); } - public static NegotiateResponse processNegotiate(String url, String accessTokenHeader) throws IOException { + public static NegotiateResponse processNegotiate(String url, OkHttpClient httpClient,String accessTokenHeader) throws IOException { url = resolveNegotiateUrl(url); - OkHttpClient client = new OkHttpClient(); RequestBody body = RequestBody.create(null, new byte[]{}); Request.Builder requestBuilder = new Request.Builder() .url(url) @@ -30,7 +29,7 @@ class Negotiate { Request request = requestBuilder.build(); - Response response = client.newCall(request).execute(); + Response response = httpClient.newCall(request).execute(); String result = response.body().string(); return new NegotiateResponse(result); } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java index 43281465f8..35270deb04 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java @@ -8,15 +8,17 @@ import java.net.URISyntaxException; import java.util.Map; import java.util.concurrent.CompletableFuture; -import org.java_websocket.client.WebSocketClient; -import org.java_websocket.handshake.ServerHandshake; +import okhttp3.*; class WebSocketTransport implements Transport { - private WebSocketClient webSocketClient; + private WebSocket websocketClient; + private SignalRWebSocketListener webSocketListener; private OnReceiveCallBack onReceiveCallBack; private URI url; private Logger logger; private Map headers; + private OkHttpClient httpClient; + private CompletableFuture startFuture = new CompletableFuture<>(); private static final String HTTP = "http"; private static final String HTTPS = "https"; @@ -27,6 +29,14 @@ class WebSocketTransport implements Transport { this.url = formatUrl(url); this.logger = logger; this.headers = headers; + this.httpClient = new OkHttpClient(); + } + + public WebSocketTransport(String url, Logger logger, Map headers, OkHttpClient httpClient) throws URISyntaxException { + this.url = formatUrl(url); + this.logger = logger; + this.headers = headers; + this.httpClient = httpClient; } public URI getUrl() { @@ -45,27 +55,15 @@ class WebSocketTransport implements Transport { @Override public CompletableFuture start() { - return CompletableFuture.runAsync(() -> { logger.log(LogLevel.Debug, "Starting Websocket connection."); - webSocketClient = createWebSocket(headers); - try { - if (!webSocketClient.connectBlocking()) { - String errorMessage = "There was an error starting the Websockets transport."; - logger.log(LogLevel.Debug, errorMessage); - throw new RuntimeException(errorMessage); - } - } catch (InterruptedException e) { - String interruptedExMessage = "Connecting the Websockets transport was interrupted."; - logger.log(LogLevel.Debug, interruptedExMessage); - throw new RuntimeException(interruptedExMessage); - } - logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI()); - }); + webSocketListener = new SignalRWebSocketListener(); + websocketClient = createUpdatedWebSocket(webSocketListener); + return startFuture; } @Override public CompletableFuture send(String message) { - return CompletableFuture.runAsync(() -> webSocketClient.send(message)); + return CompletableFuture.runAsync(() -> websocketClient.send(message)); } @Override @@ -82,36 +80,62 @@ class WebSocketTransport implements Transport { @Override public CompletableFuture stop() { return CompletableFuture.runAsync(() -> { - webSocketClient.closeConnection(0, "HubConnection Stopped"); + websocketClient.close(1000, "HubConnection stopped."); logger.log(LogLevel.Information, "WebSocket connection stopped"); }); } - private WebSocketClient createWebSocket(Map headers) { - return new WebSocketClient(url, headers) { - @Override - public void onOpen(ServerHandshake handshakedata) { - System.out.println("Connected to " + url); - } + private WebSocket createUpdatedWebSocket(WebSocketListener webSocketListener) { + Headers.Builder headerBuilder = new Headers.Builder(); + for (String key: headers.keySet()) { + headerBuilder.add(key, headers.get(key)); + } + Request request = new Request.Builder().url(url.toString()) + .headers(headerBuilder.build()) + .build(); - @Override - public void onMessage(String message) { - try { - onReceive(message); - } catch (Exception e) { - e.printStackTrace(); - } - } + return this.httpClient.newWebSocket(request, webSocketListener); + } - @Override - public void onClose(int code, String reason, boolean remote) { - System.out.println("Connection Closed"); - } - @Override - public void onError(Exception ex) { - System.out.println("Error: " + ex.getMessage()); + private class SignalRWebSocketListener extends WebSocketListener { + @Override + public void onOpen(WebSocket webSocket, Response response) { + startFuture.complete(null); + logger.log(LogLevel.Information, "WebSocket transport connected to: %s", websocketClient.request().url()); + } + + @Override + public void onMessage(WebSocket webSocket, String message) { + try { + onReceive(message); + } catch (Exception e) { + e.printStackTrace(); } - }; + } + + @Override + public void onClosing(WebSocket webSocket, int code, String reason) { + logger.log(LogLevel.Information, "WebSocket connection stopping with " + + "code %d and reason %s", code, reason); + // If the start future hasn't completed yet, then we need to complete it exceptionally. + checkStartFailure(); + } + + @Override + public void onFailure(WebSocket webSocket, Throwable t, Response response) { + logger.log(LogLevel.Error, "Error : %d", t.getMessage()); + // If the start future hasn't completed yet, then we need to complete it exceptionally. + checkStartFailure(); + } + } + + private void checkStartFailure() { + // If the start future hasn't completed yet, then we need to complete it exceptionally. + if (!startFuture.isDone()) { + String errorMessage = "There was an error starting the Websockets transport."; + logger.log(LogLevel.Debug, errorMessage); + startFuture.completeExceptionally(new RuntimeException(errorMessage)); + } } } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java index ebed559872..4a40f3bdff 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java @@ -13,7 +13,7 @@ import org.junit.jupiter.api.Test; class WebSocketTransportTest { @Test public void WebsocketThrowsIfItCantConnect() throws Exception { - Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger(), new HashMap<>()); + Transport transport = new WebSocketTransport("http://www.notarealurl12345.fake", new NullLogger(), new HashMap<>()); Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS)); assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage()); }