From f27df1d61e39d5c9995d65edd2212fa36e24ea94 Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Wed, 19 Sep 2018 10:14:35 -0700 Subject: [PATCH] Java Async APIs (#2971) --- .../aspnet/signalr/HubConnection.java | 101 +++++++++++------- .../microsoft/aspnet/signalr/Transport.java | 8 +- .../aspnet/signalr/WebSocketTransport.java | 40 ++++--- .../aspnet/signalr/HubConnectionTest.java | 12 ++- .../signalr/WebSocketTransportTest.java | 7 +- 5 files changed, 107 insertions(+), 61 deletions(-) diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java index 417cd3f743..edcbf90558 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java @@ -3,10 +3,14 @@ package com.microsoft.aspnet.signalr; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; public class HubConnection { @@ -18,6 +22,7 @@ public class HubConnection { private Boolean handshakeReceived = false; private static final String RECORD_SEPARATOR = "\u001e"; private HubConnectionState hubConnectionState = HubConnectionState.DISCONNECTED; + private Lock hubConnectionStateLock = new ReentrantLock(); private Logger logger; private List> onClosedCallbackList; private boolean skipNegotiate = false; @@ -99,6 +104,29 @@ public class HubConnection { } } + private NegotiateResponse handleNegotiate() throws IOException { + accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); + negotiateResponse = Negotiate.processNegotiate(url, accessToken); + + if (negotiateResponse.getConnectionId() != null) { + if (url.contains("?")) { + url = url + "&id=" + negotiateResponse.getConnectionId(); + } else { + url = url + "?id=" + negotiateResponse.getConnectionId(); + } + } + + if (negotiateResponse.getAccessToken() != null) { + this.headers.put("Authorization", "Bearer " + negotiateResponse.getAccessToken()); + } + + if (negotiateResponse.getRedirectUrl() != null) { + this.url = this.negotiateResponse.getRedirectUrl(); + } + + return negotiateResponse; + } + public HubConnection(String url, Transport transport, Logger logger) { this(url, transport, logger, false); } @@ -150,32 +178,15 @@ public class HubConnection { * * @throws Exception An error occurred while connecting. */ - public void start() throws Exception { + public CompletableFuture start() throws Exception { if (hubConnectionState != HubConnectionState.DISCONNECTED) { - return; + return CompletableFuture.completedFuture(null); } if (!skipNegotiate) { int negotiateAttempts = 0; do { accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); - negotiateResponse = Negotiate.processNegotiate(url, accessToken); - - if (negotiateResponse.getConnectionId() != null) { - if (url.contains("?")) { - url = url + "&id=" + negotiateResponse.getConnectionId(); - } else { - url = url + "?id=" + negotiateResponse.getConnectionId(); - } - } - - if (negotiateResponse.getAccessToken() != null) { - this.headers.put("Authorization", "Bearer " + negotiateResponse.getAccessToken()); - } - - if (negotiateResponse.getRedirectUrl() != null) { - url = this.negotiateResponse.getRedirectUrl(); - } - + negotiateResponse = handleNegotiate(); negotiateAttempts++; } while (negotiateResponse.getRedirectUrl() != null && negotiateAttempts < MAX_NEGOTIATE_ATTEMPTS); if (!negotiateResponse.getAvailableTransports().contains("WebSockets")) { @@ -189,32 +200,46 @@ public class HubConnection { } transport.setOnReceive(this.callback); - transport.start(); - String handshake = HandshakeProtocol.createHandshakeRequestMessage(new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); - transport.send(handshake); - hubConnectionState = HubConnectionState.CONNECTED; - connectionState = new ConnectionState(this); - logger.log(LogLevel.Information, "HubConnected started."); + return transport.start().thenCompose((future) -> { + String handshake = HandshakeProtocol.createHandshakeRequestMessage(new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); + return transport.send(handshake).thenRun(() -> { + hubConnectionStateLock.lock(); + try { + hubConnectionState = HubConnectionState.CONNECTED; + connectionState = new ConnectionState(this); + logger.log(LogLevel.Information, "HubConnected started."); + } finally { + hubConnectionStateLock.unlock(); + } + }); + }); + } /** * Stops a connection to the server. */ private void stop(String errorMessage) { - if (hubConnectionState == HubConnectionState.DISCONNECTED) { - return; + hubConnectionStateLock.lock(); + try { + if (hubConnectionState == HubConnectionState.DISCONNECTED) { + return; + } + + if (errorMessage != null) { + logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage); + } else { + logger.log(LogLevel.Debug, "Stopping HubConnection."); + } + + transport.stop(); + hubConnectionState = HubConnectionState.DISCONNECTED; + connectionState = null; + logger.log(LogLevel.Information, "HubConnection stopped."); + } finally { + hubConnectionStateLock.unlock(); } - if (errorMessage != null) { - logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage); - } else { - logger.log(LogLevel.Debug, "Stopping HubConnection."); - } - - transport.stop(); - hubConnectionState = HubConnectionState.DISCONNECTED; - connectionState = null; - logger.log(LogLevel.Information, "HubConnection stopped."); if (onClosedCallbackList != null) { HubException hubException = new HubException(errorMessage); for (Consumer callback : onClosedCallbackList) { diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java index 79e093682c..eafb865465 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java @@ -3,10 +3,12 @@ package com.microsoft.aspnet.signalr; +import java.util.concurrent.CompletableFuture; + interface Transport { - void start() throws Exception; - void send(String message) throws Exception; + CompletableFuture start() throws Exception; + CompletableFuture send(String message); void setOnReceive(OnReceiveCallBack callback); void onReceive(String message) throws Exception; - void stop(); + CompletableFuture stop(); } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java index e20ef50ac8..91cd290898 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java @@ -6,6 +6,7 @@ package com.microsoft.aspnet.signalr; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.java_websocket.client.WebSocketClient; import org.java_websocket.handshake.ServerHandshake; @@ -47,21 +48,28 @@ class WebSocketTransport implements Transport { } @Override - public void start() throws Exception { - logger.log(LogLevel.Debug, "Starting Websocket connection."); - webSocketClient = createWebSocket(headers); - - if (!webSocketClient.connectBlocking()) { - String errorMessage = "There was an error starting the Websockets transport."; - logger.log(LogLevel.Debug, errorMessage); - throw new Exception(errorMessage); - } - logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI()); + public CompletableFuture start() { + return CompletableFuture.runAsync(() -> { + logger.log(LogLevel.Debug, "Starting Websocket connection."); + webSocketClient = createWebSocket(headers); + try { + if (!webSocketClient.connectBlocking()) { + String errorMessage = "There was an error starting the Websockets transport."; + logger.log(LogLevel.Debug, errorMessage); + throw new RuntimeException(errorMessage); + } + } catch (InterruptedException e) { + String interruptedExMessage = "Connecting the Websockets transport was interrupted."; + logger.log(LogLevel.Debug, interruptedExMessage); + throw new RuntimeException(interruptedExMessage); + } + logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI()); + }); } @Override - public void send(String message) { - webSocketClient.send(message); + public CompletableFuture send(String message) { + return CompletableFuture.runAsync(() -> webSocketClient.send(message)); } @Override @@ -76,9 +84,11 @@ class WebSocketTransport implements Transport { } @Override - public void stop() { - webSocketClient.closeConnection(0, "HubConnection Stopped"); - logger.log(LogLevel.Information, "WebSocket connection stopped"); + public CompletableFuture stop() { + return CompletableFuture.runAsync(() -> { + webSocketClient.closeConnection(0, "HubConnection Stopped"); + logger.log(LogLevel.Information, "WebSocket connection stopped"); + }); } private WebSocketClient createWebSocket(Map headers) { diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java index 50531c8194..1371770950 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java @@ -6,6 +6,7 @@ package com.microsoft.aspnet.signalr; import static org.junit.jupiter.api.Assertions.*; import java.util.ArrayList; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicReference; import org.junit.jupiter.api.Test; @@ -745,11 +746,14 @@ public class HubConnectionTest { private ArrayList sentMessages = new ArrayList<>(); @Override - public void start() {} + public CompletableFuture start() { + return CompletableFuture.completedFuture(null); + } @Override - public void send(String message) { + public CompletableFuture send(String message) { sentMessages.add(message); + return CompletableFuture.completedFuture(null); } @Override @@ -763,7 +767,9 @@ public class HubConnectionTest { } @Override - public void stop() {} + public CompletableFuture stop() { + return CompletableFuture.completedFuture(null); + } public void receiveMessage(String message) throws Exception { this.onReceive(message); diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java index dc16efa4a9..31116d4d65 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java @@ -4,13 +4,16 @@ package com.microsoft.aspnet.signalr; import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.TimeUnit; + import org.junit.jupiter.api.Test; public class WebSocketTransportTest { @Test public void WebsocketThrowsIfItCantConnect() throws Exception { Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger()); - Throwable exception = assertThrows(Exception.class, () -> transport.start()); - assertEquals("There was an error starting the Websockets transport.", exception.getMessage()); + Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS)); + assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage()); } }