diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java index 97dcc5d412..9d85555040 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java @@ -186,6 +186,7 @@ public class HubConnection { return CompletableFuture.completedFuture(null); } + handshakeReceived = false; CompletableFuture tokenFuture = accessTokenProvider.get() .thenAccept((token) -> { if (token != null) { @@ -203,13 +204,13 @@ public class HubConnection { return negotiate.thenCompose((url) -> { logger.log(LogLevel.Debug, "Starting HubConnection."); if (transport == null) { - transport = new WebSocketTransport(url, headers, httpClient, logger); + transport = new WebSocketTransport(headers, httpClient, logger); } transport.setOnReceive(this.callback); try { - return transport.start().thenCompose((future) -> { + return transport.start(url).thenCompose((future) -> { String handshake = HandshakeProtocol.createHandshakeRequestMessage( new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); return transport.send(handshake).thenRun(() -> { @@ -289,8 +290,6 @@ public class HubConnection { HubException hubException = null; hubConnectionStateLock.lock(); try { - hubConnectionState = HubConnectionState.DISCONNECTED; - if (errorMessage != null) { hubException = new HubException(errorMessage); } else if (t != null) { @@ -299,6 +298,7 @@ public class HubConnection { connectionState.cancelOutstandingInvocations(hubException); connectionState = null; logger.log(LogLevel.Information, "HubConnection stopped."); + hubConnectionState = HubConnectionState.DISCONNECTED; } finally { hubConnectionStateLock.unlock(); } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java index 37a09acc92..9ad9bdf390 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java @@ -37,6 +37,10 @@ class JsonHubProtocol implements HubProtocol { @Override public HubMessage[] parseMessages(String payload, InvocationBinder binder) throws Exception { + if (payload != null && !payload.substring(payload.length() - 1).equals(RECORD_SEPARATOR)) { + throw new RuntimeException("Message is incomplete."); + } + String[] messages = payload.split(RECORD_SEPARATOR); List hubMessages = new ArrayList<>(); for (String str : messages) { @@ -48,7 +52,6 @@ class JsonHubProtocol implements HubProtocol { JsonArray argumentsToken = null; Object result = null; JsonElement resultToken = null; - JsonReader reader = new JsonReader(new StringReader(str)); reader.beginObject(); diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java index 29efaf4050..24f4d2a31b 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java @@ -6,7 +6,7 @@ package com.microsoft.aspnet.signalr; import java.util.concurrent.CompletableFuture; interface Transport { - CompletableFuture start() throws Exception; + CompletableFuture start(String url); CompletableFuture send(String message); void setOnReceive(OnReceiveCallBack callback); void onReceive(String message) throws Exception; diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java index dcd8801537..9aec3f15dd 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java @@ -19,8 +19,7 @@ class WebSocketTransport implements Transport { private static final String WS = "ws"; private static final String WSS = "wss"; - public WebSocketTransport(String url, Map headers, HttpClient client, Logger logger) { - this.url = formatUrl(url); + public WebSocketTransport(Map headers, HttpClient client, Logger logger) { this.logger = logger; this.client = client; this.headers = headers; @@ -41,7 +40,8 @@ class WebSocketTransport implements Transport { } @Override - public CompletableFuture start() { + public CompletableFuture start(String url) { + this.url = formatUrl(url); logger.log(LogLevel.Debug, "Starting Websocket connection."); this.webSocketClient = client.createWebSocket(this.url, this.headers); this.webSocketClient.setOnReceive((message) -> onReceive(message)); diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java index 56cfa0e978..8988279ae2 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java @@ -1019,6 +1019,7 @@ class HubConnectionTest { hubConnection.start().get(1000, TimeUnit.MILLISECONDS); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); + assertEquals("http://testexample.com/?id=bVOiRPG8-6YiJ6d7ZcTOVQ", transport.getUrl()); hubConnection.stop(); assertEquals("Bearer newToken", token.get()); } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java index b637058f15..18bfc9393b 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java @@ -246,6 +246,16 @@ class JsonHubProtocolTest { assertEquals("Invocation provides 1 argument(s) but target expects 2.", exception.getMessage()); } + @Test + public void errorWhileParsingIncompleteMessage() throws Exception { + String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":"; + TestBinder binder = new TestBinder(new InvocationMessage(null, "test", new Object[] { 42, 24 })); + + RuntimeException exception = assertThrows(RuntimeException.class, + () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder)); + assertEquals("Message is incomplete.", exception.getMessage()); + } + private class TestBinder implements InvocationBinder { private Class[] paramTypes = null; private Class returnType = null; diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java index 56099005a6..1c7eeb053c 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/MockTransport.java @@ -9,9 +9,11 @@ import java.util.concurrent.CompletableFuture; class MockTransport implements Transport { private OnReceiveCallBack onReceiveCallBack; private ArrayList sentMessages = new ArrayList<>(); + private String url; @Override - public CompletableFuture start() { + public CompletableFuture start(String url) { + this.url = url; return CompletableFuture.completedFuture(null); } @@ -43,4 +45,8 @@ class MockTransport implements Transport { public String[] getSentMessages() { return sentMessages.toArray(new String[sentMessages.size()]); } + + public String getUrl() { + return this.url; + } } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java index 9be145f536..76e89cfd24 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java @@ -14,8 +14,8 @@ import org.junit.jupiter.api.Test; class WebSocketTransportTest { @Test public void WebsocketThrowsIfItCantConnect() throws Exception { - Transport transport = new WebSocketTransport("http://www.notarealurl12345.fake", new HashMap<>(), new DefaultHttpClient(new NullLogger()), new NullLogger()); - ExecutionException exception = assertThrows(ExecutionException.class, () -> transport.start().get(1, TimeUnit.SECONDS)); + Transport transport = new WebSocketTransport(new HashMap<>(), new DefaultHttpClient(new NullLogger()), new NullLogger()); + ExecutionException exception = assertThrows(ExecutionException.class, () -> transport.start("http://www.example.com").get(1, TimeUnit.SECONDS)); assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage()); } } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java index 9e215af183..3da2ec49cd 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java @@ -24,8 +24,11 @@ class WebSocketTransportUrlFormatTest { @ParameterizedTest @MethodSource("protocols") - public void checkWebsocketUrlProtocol(String url, String expectedUrl) throws URISyntaxException { - WebSocketTransport webSocketTransport = new WebSocketTransport(url, new HashMap<>(), new TestHttpClient(), new NullLogger()); + public void checkWebsocketUrlProtocol(String url, String expectedUrl) { + WebSocketTransport webSocketTransport = new WebSocketTransport(new HashMap<>(), new TestHttpClient(), new NullLogger()); + try { + webSocketTransport.start(url); + } catch (Exception e) {} assertEquals(expectedUrl, webSocketTransport.getUrl()); } } \ No newline at end of file