diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java index 9d85555040..6c79508b24 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java @@ -33,6 +33,7 @@ public class HubConnection { private Map headers = new HashMap<>(); private ConnectionState connectionState = null; private HttpClient httpClient; + private String stopError; private static ArrayList> emptyArray = new ArrayList<>(); private static int MAX_NEGOTIATE_ATTEMPTS = 100; @@ -194,6 +195,7 @@ public class HubConnection { } }); + stopError = null; CompletableFuture negotiate = null; if (!skipNegotiate) { negotiate = tokenFuture.thenCompose((v) -> startNegotiate(baseUrl, 0)); @@ -208,6 +210,7 @@ public class HubConnection { } transport.setOnReceive(this.callback); + transport.setOnClose((message) -> stopConnection(message)); try { return transport.start(url).thenCompose((future) -> { @@ -278,6 +281,7 @@ public class HubConnection { } if (errorMessage != null) { + stopError = errorMessage; logger.log(LogLevel.Error, "HubConnection disconnected with an error: %s.", errorMessage); } else { logger.log(LogLevel.Debug, "Stopping HubConnection."); @@ -286,30 +290,7 @@ public class HubConnection { hubConnectionStateLock.unlock(); } - return transport.stop().whenComplete((i, t) -> { - HubException hubException = null; - hubConnectionStateLock.lock(); - try { - if (errorMessage != null) { - hubException = new HubException(errorMessage); - } else if (t != null) { - hubException = new HubException(t.getMessage()); - } - connectionState.cancelOutstandingInvocations(hubException); - connectionState = null; - logger.log(LogLevel.Information, "HubConnection stopped."); - hubConnectionState = HubConnectionState.DISCONNECTED; - } finally { - hubConnectionStateLock.unlock(); - } - - // Do not run these callbacks inside the hubConnectionStateLock - if (onClosedCallbackList != null) { - for (Consumer callback : onClosedCallbackList) { - callback.accept(hubException); - } - } - }); + return transport.stop(); } /** @@ -319,6 +300,35 @@ public class HubConnection { return stop(null); } + private void stopConnection(String errorMessage) { + RuntimeException exception = null; + hubConnectionStateLock.lock(); + try { + // errorMessage gets passed in from the transport. An already existing stopError value + // should take precedence. + if (stopError != null) { + errorMessage = stopError; + } + if (errorMessage != null) { + exception = new RuntimeException(errorMessage); + logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage); + } + connectionState.cancelOutstandingInvocations(exception); + connectionState = null; + logger.log(LogLevel.Information, "HubConnection stopped."); + hubConnectionState = HubConnectionState.DISCONNECTED; + } finally { + hubConnectionStateLock.unlock(); + } + + // Do not run these callbacks inside the hubConnectionStateLock + if (onClosedCallbackList != null) { + for (Consumer callback : onClosedCallbackList) { + callback.accept(exception); + } + } + } + /** * Invokes a hub method on the server using the specified method name. * Does not wait for a response from the receiver. diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/OkHttpWebSocketWrapper.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/OkHttpWebSocketWrapper.java index 1fb21fcce1..00b3686bff 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/OkHttpWebSocketWrapper.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/OkHttpWebSocketWrapper.java @@ -97,6 +97,7 @@ class OkHttpWebSocketWrapper extends WebSocketWrapper { public void onFailure(WebSocket webSocket, Throwable t, Response response) { logger.log(LogLevel.Error, "Websocket closed from an error: %s.", t.getMessage()); closeFuture.completeExceptionally(new RuntimeException(t)); + onClose.accept(null, t.getMessage()); checkStartFailure(); } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java index 24f4d2a31b..3d870a343e 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java @@ -4,11 +4,13 @@ package com.microsoft.aspnet.signalr; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; interface Transport { CompletableFuture start(String url); CompletableFuture send(String message); void setOnReceive(OnReceiveCallBack callback); void onReceive(String message) throws Exception; + void setOnClose(Consumer onCloseCallback); CompletableFuture stop(); } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java index 9aec3f15dd..f338f208c9 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java @@ -5,10 +5,12 @@ package com.microsoft.aspnet.signalr; import java.util.Map; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; class WebSocketTransport implements Transport { private WebSocketWrapper webSocketClient; private OnReceiveCallBack onReceiveCallBack; + private Consumer onClose; private String url; private Logger logger; private HttpClient client; @@ -45,7 +47,12 @@ class WebSocketTransport implements Transport { logger.log(LogLevel.Debug, "Starting Websocket connection."); this.webSocketClient = client.createWebSocket(this.url, this.headers); this.webSocketClient.setOnReceive((message) -> onReceive(message)); - this.webSocketClient.setOnClose((code, reason) -> onClose(code, reason)); + this.webSocketClient.setOnClose((code, reason) -> { + if (onClose != null) { + onClose(code, reason); + } + }); + return webSocketClient.start().thenRun(() -> logger.log(LogLevel.Information, "WebSocket transport connected to: %s.", this.url)); } @@ -65,6 +72,11 @@ class WebSocketTransport implements Transport { this.onReceiveCallBack.invoke(message); } + @Override + public void setOnClose(Consumer onCloseCallback) { + this.onClose = onCloseCallback; + } + @Override public CompletableFuture stop() { return webSocketClient.stop().whenComplete((i, j) -> logger.log(LogLevel.Information, "WebSocket connection stopped.")); @@ -73,5 +85,11 @@ class WebSocketTransport implements Transport { void onClose(int code, String reason) { logger.log(LogLevel.Information, "WebSocket connection stopping with " + "code %d and reason '%s'.", code, reason); + if (code != 1000) { + onClose.accept(reason); + } + else { + onClose.accept(null); + } } } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java index 8988279ae2..979cf6e323 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java @@ -29,6 +29,35 @@ class HubConnectionTest { assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); } + @Test + public void transportCloseTriggersStopInHubConnection() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); + hubConnection.start().get(1000, TimeUnit.MILLISECONDS); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + mockTransport.stop(); + + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + } + + @Test + public void transportCloseWithErrorTriggersStopInHubConnection() throws Exception { + MockTransport mockTransport = new MockTransport(); + AtomicReference message = new AtomicReference<>(); + HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); + String errorMessage = "Example transport error."; + + hubConnection.onClosed((error) -> { + message.set(error.getMessage()); + }); + + hubConnection.start().get(1000, TimeUnit.MILLISECONDS); + assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + mockTransport.stopWithError(errorMessage); + assertEquals(errorMessage, message.get()); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + } + @Test public void constructHubConnectionWithHttpConnectionOptions() throws Exception { Transport mockTransport = new MockTransport(); diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java index 1c7eeb053c..6aafcbaab9 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java @@ -5,11 +5,13 @@ package com.microsoft.aspnet.signalr; import java.util.ArrayList; import java.util.concurrent.CompletableFuture; +import java.util.function.Consumer; class MockTransport implements Transport { private OnReceiveCallBack onReceiveCallBack; private ArrayList sentMessages = new ArrayList<>(); private String url; + private Consumer onClose; @Override public CompletableFuture start(String url) { @@ -33,11 +35,21 @@ class MockTransport implements Transport { this.onReceiveCallBack.invoke(message); } + @Override + public void setOnClose(Consumer onCloseCallback) { + this.onClose = onCloseCallback; + } + @Override public CompletableFuture stop() { + onClose.accept(null); return CompletableFuture.completedFuture(null); } + public void stopWithError(String errorMessage) { + onClose.accept(errorMessage); + } + public void receiveMessage(String message) throws Exception { this.onReceive(message); } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java index 3da2ec49cd..0eda0f30f1 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java @@ -5,7 +5,6 @@ package com.microsoft.aspnet.signalr; import static org.junit.jupiter.api.Assertions.*; -import java.net.URISyntaxException; import java.util.HashMap; import java.util.stream.Stream;