diff --git a/.vscode/launch.json b/.vscode/launch.json index 06a988616d..8a35acc3ed 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -8,7 +8,7 @@ "cwd": "${workspaceFolder}/clients/java/", "console": "externalTerminal", "stopOnEntry": false, - "mainClass": "com.microsoft.aspnet.signalr.Chat", + "mainClass": "com.microsoft.aspnet.signalr.sample.Chat", "args": "" }, { diff --git a/clients/java/signalr/build.gradle b/clients/java/signalr/build.gradle index b688ee869e..244695cedc 100644 --- a/clients/java/signalr/build.gradle +++ b/clients/java/signalr/build.gradle @@ -16,7 +16,9 @@ repositories { } dependencies { - testImplementation group: 'junit', name: 'junit', version: '4.12' + testImplementation 'org.junit.jupiter:junit-jupiter-api:5.3.1' + testCompile 'org.junit.jupiter:junit-jupiter-params:5.3.1' + testRuntime 'org.junit.jupiter:junit-jupiter-engine:5.3.1' implementation "org.java-websocket:Java-WebSocket:1.3.8" implementation 'com.google.code.gson:gson:2.8.5' implementation 'com.squareup.okhttp3:okhttp:3.11.0' @@ -41,6 +43,10 @@ spotless { } } +test { + useJUnitPlatform() +} + task sourceJar(type: Jar) { classifier "sources" from sourceSets.main.allJava diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CallbackMap.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CallbackMap.java index 528aa8d327..8dbd13488c 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CallbackMap.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CallbackMap.java @@ -5,9 +5,9 @@ package com.microsoft.aspnet.signalr; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.concurrent.ConcurrentHashMap; -import java.util.Collections; class CallbackMap { private ConcurrentHashMap> handlers = new ConcurrentHashMap<>(); diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CloseMessage.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CloseMessage.java index 920eb2eabc..1931f49ad4 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CloseMessage.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CloseMessage.java @@ -4,7 +4,7 @@ package com.microsoft.aspnet.signalr; class CloseMessage extends HubMessage { - String error; + private String error; @Override public HubMessageType getMessageType() { diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CompletionMessage.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CompletionMessage.java new file mode 100644 index 0000000000..50a8e7f78a --- /dev/null +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/CompletionMessage.java @@ -0,0 +1,38 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +package com.microsoft.aspnet.signalr; + +class CompletionMessage extends HubMessage { + private int type = HubMessageType.COMPLETION.value; + private String invocationId; + private Object result; + private String error; + + public CompletionMessage(String invocationId, Object result, String error) { + if (error != null && result != null) + { + throw new IllegalArgumentException("Expected either 'error' or 'result' to be provided, but not both"); + } + this.invocationId = invocationId; + this.result = result; + this.error = error; + } + + public Object getResult() { + return result; + } + + public String getError() { + return error; + } + + public String getInvocationId() { + return invocationId; + } + + @Override + public HubMessageType getMessageType() { + return HubMessageType.values()[type - 1]; + } +} \ No newline at end of file diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java index 361920280c..06c3b34b55 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubConnection.java @@ -3,10 +3,15 @@ package com.microsoft.aspnet.signalr; +import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.function.Consumer; public class HubConnection { @@ -18,6 +23,7 @@ public class HubConnection { private Boolean handshakeReceived = false; private static final String RECORD_SEPARATOR = "\u001e"; private HubConnectionState hubConnectionState = HubConnectionState.DISCONNECTED; + private Lock hubConnectionStateLock = new ReentrantLock(); private Logger logger; private List> onClosedCallbackList; private boolean skipNegotiate = false; @@ -66,13 +72,13 @@ public class HubConnection { switch (message.getMessageType()) { case INVOCATION: InvocationMessage invocationMessage = (InvocationMessage) message; - List handlers = this.handlers.get(invocationMessage.target); + List handlers = this.handlers.get(invocationMessage.getTarget()); if (handlers != null) { for (InvocationHandler handler : handlers) { - handler.getAction().invoke(invocationMessage.arguments); + handler.getAction().invoke(invocationMessage.getArguments()); } } else { - logger.log(LogLevel.Warning, "Failed to find handler for %s method.", invocationMessage.target); + logger.log(LogLevel.Warning, "Failed to find handler for %s method.", invocationMessage.getMessageType()); } break; case CLOSE: @@ -83,10 +89,18 @@ public class HubConnection { case PING: // We don't need to do anything in the case of a ping message. break; + case COMPLETION: + CompletionMessage completionMessage = (CompletionMessage)message; + InvocationRequest irq = connectionState.tryRemoveInvocation(completionMessage.getInvocationId()); + if (irq == null) { + logger.log(LogLevel.Warning, "Dropped unsolicited Completion message for invocation '%s'.", completionMessage.getInvocationId()); + continue; + } + irq.complete(completionMessage); + break; case STREAM_INVOCATION: case STREAM_ITEM: case CANCEL_INVOCATION: - case COMPLETION: logger.log(LogLevel.Error, "This client does not support %s messages.", message.getMessageType()); throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", message.getMessageType())); @@ -99,6 +113,29 @@ public class HubConnection { } } + private NegotiateResponse handleNegotiate() throws IOException { + accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); + negotiateResponse = Negotiate.processNegotiate(url, accessToken); + + if (negotiateResponse.getConnectionId() != null) { + if (url.contains("?")) { + url = url + "&id=" + negotiateResponse.getConnectionId(); + } else { + url = url + "?id=" + negotiateResponse.getConnectionId(); + } + } + + if (negotiateResponse.getAccessToken() != null) { + this.headers.put("Authorization", "Bearer " + negotiateResponse.getAccessToken()); + } + + if (negotiateResponse.getRedirectUrl() != null) { + this.url = this.negotiateResponse.getRedirectUrl(); + } + + return negotiateResponse; + } + public HubConnection(String url, Transport transport, Logger logger) { this(url, transport, logger, false); } @@ -150,32 +187,15 @@ public class HubConnection { * * @throws Exception An error occurred while connecting. */ - public void start() throws Exception { + public CompletableFuture start() throws Exception { if (hubConnectionState != HubConnectionState.DISCONNECTED) { - return; + return CompletableFuture.completedFuture(null); } if (!skipNegotiate) { int negotiateAttempts = 0; do { accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); - negotiateResponse = Negotiate.processNegotiate(url, accessToken); - - if (negotiateResponse.getConnectionId() != null) { - if (url.contains("?")) { - url = url + "&id=" + negotiateResponse.getConnectionId(); - } else { - url = url + "?id=" + negotiateResponse.getConnectionId(); - } - } - - if (negotiateResponse.getAccessToken() != null) { - this.headers.put("Authorization", "Bearer " + negotiateResponse.getAccessToken()); - } - - if (negotiateResponse.getRedirectUrl() != null) { - url = this.negotiateResponse.getRedirectUrl(); - } - + negotiateResponse = handleNegotiate(); negotiateAttempts++; } while (negotiateResponse.getRedirectUrl() != null && negotiateAttempts < MAX_NEGOTIATE_ATTEMPTS); if (!negotiateResponse.getAvailableTransports().contains("WebSockets")) { @@ -189,34 +209,53 @@ public class HubConnection { } transport.setOnReceive(this.callback); - transport.start(); - String handshake = HandshakeProtocol.createHandshakeRequestMessage(new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); - transport.send(handshake); - hubConnectionState = HubConnectionState.CONNECTED; - connectionState = new ConnectionState(this); - logger.log(LogLevel.Information, "HubConnected started."); + return transport.start().thenCompose((future) -> { + String handshake = HandshakeProtocol.createHandshakeRequestMessage(new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); + return transport.send(handshake).thenRun(() -> { + hubConnectionStateLock.lock(); + try { + hubConnectionState = HubConnectionState.CONNECTED; + connectionState = new ConnectionState(this); + logger.log(LogLevel.Information, "HubConnected started."); + } finally { + hubConnectionStateLock.unlock(); + } + }); + }); + } /** * Stops a connection to the server. */ private void stop(String errorMessage) { - if (hubConnectionState == HubConnectionState.DISCONNECTED) { - return; + HubException hubException = null; + hubConnectionStateLock.lock(); + try { + if (hubConnectionState == HubConnectionState.DISCONNECTED) { + return; + } + + if (errorMessage != null) { + logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage); + } else { + logger.log(LogLevel.Debug, "Stopping HubConnection."); + } + + transport.stop(); + hubConnectionState = HubConnectionState.DISCONNECTED; + + if (errorMessage != null) { + hubException = new HubException(errorMessage); + } + connectionState.cancelOutstandingInvocations(hubException); + connectionState = null; + logger.log(LogLevel.Information, "HubConnection stopped."); + } finally { + hubConnectionStateLock.unlock(); } - if (errorMessage != null) { - logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage); - } else { - logger.log(LogLevel.Debug, "Stopping HubConnection."); - } - - transport.stop(); - hubConnectionState = HubConnectionState.DISCONNECTED; - connectionState = null; - logger.log(LogLevel.Information, "HubConnection stopped."); if (onClosedCallbackList != null) { - HubException hubException = new HubException(errorMessage); for (Consumer callback : onClosedCallbackList) { callback.accept(hubException); } @@ -243,10 +282,67 @@ public class HubConnection { throw new HubException("The 'send' method cannot be called if the connection is not active"); } - InvocationMessage invocationMessage = new InvocationMessage(method, args); - String message = protocol.writeMessage(invocationMessage); - logger.log(LogLevel.Debug, "Sending message"); - transport.send(message); + InvocationMessage invocationMessage = new InvocationMessage(null, method, args); + sendHubMessage(invocationMessage); + } + + public CompletableFuture invoke(Class returnType, String method, Object... args) throws Exception { + String id = connectionState.getNextInvocationId(); + InvocationMessage invocationMessage = new InvocationMessage(id, method, args); + + CompletableFuture future = new CompletableFuture<>(); + InvocationRequest irq = new InvocationRequest(returnType, id); + connectionState.addInvocation(irq); + + // forward the invocation result or error to the user + // run continuations on a separate thread + CompletableFuture pendingCall = irq.getPendingCall(); + pendingCall.whenCompleteAsync((result, error) -> { + if (error == null) { + // Primitive types can't be cast with the Class cast function + if (returnType.isPrimitive()) { + future.complete((T)result); + } else { + future.complete(returnType.cast(result)); + } + } else { + future.completeExceptionally(error); + } + }); + + // Make sure the actual send is after setting up the future otherwise there is a race + // where the map doesn't have the future yet when the response is returned + sendHubMessage(invocationMessage); + + return future; + } + + private void sendHubMessage(HubMessage message) throws Exception { + String serializedMessage = protocol.writeMessage(message); + if (message.getMessageType() == HubMessageType.INVOCATION) { + logger.log(LogLevel.Debug, "Sending %d message '%s'.", message.getMessageType().value, ((InvocationMessage)message).getInvocationId()); + } else { + logger.log(LogLevel.Debug, "Sending %d message.", message.getMessageType().value); + } + transport.send(serializedMessage); + } + + /** + * Removes all handlers associated with the method with the specified method name. + * + * @param name The name of the hub method from which handlers are being removed. + */ + public void remove(String name) { + handlers.remove(name); + logger.log(LogLevel.Trace, "Removing handlers for client method %s", name); + } + + public void onClosed(Consumer callback) { + if (onClosedCallbackList == null) { + onClosedCallbackList = new ArrayList<>(); + } + + onClosedCallbackList.add(callback); } /** @@ -515,34 +611,80 @@ public class HubConnection { return new Subscription(handlers, handler, target); } - /** - * Removes all handlers associated with the method with the specified method name. - * - * @param name The name of the hub method from which handlers are being removed. - */ - public void remove(String name) { - handlers.remove(name); - logger.log(LogLevel.Trace, "Removing handlers for client method %s", name); - } - - public void onClosed(Consumer callback) { - if (onClosedCallbackList == null) { - onClosedCallbackList = new ArrayList<>(); - } - - onClosedCallbackList.add(callback); - } - private class ConnectionState implements InvocationBinder { - HubConnection connection; + private HubConnection connection; + private AtomicInteger nextId = new AtomicInteger(0); + private HashMap pendingInvocations = new HashMap<>(); + private Lock lock = new ReentrantLock(); public ConnectionState(HubConnection connection) { this.connection = connection; } + public String getNextInvocationId() { + int i = nextId.incrementAndGet(); + return Integer.toString(i); + } + + public void cancelOutstandingInvocations(Exception ex) { + lock.lock(); + try { + pendingInvocations.forEach((key, irq) -> { + if (ex == null) { + irq.cancel(); + } else { + irq.fail(ex); + } + }); + + pendingInvocations.clear(); + } finally { + lock.unlock(); + } + } + + public void addInvocation(InvocationRequest irq) { + lock.lock(); + try { + pendingInvocations.compute(irq.getInvocationId(), (key, value) -> { + if (value != null) { + // This should never happen + throw new IllegalStateException("Invocation Id is already used"); + } + + return irq; + }); + } finally { + lock.unlock(); + } + } + + public InvocationRequest getInvocation(String id) { + lock.lock(); + try { + return pendingInvocations.get(id); + } finally { + lock.unlock(); + } + } + + public InvocationRequest tryRemoveInvocation(String id) { + lock.lock(); + try { + return pendingInvocations.remove(id); + } finally { + lock.unlock(); + } + } + @Override public Class getReturnType(String invocationId) { - return null; + InvocationRequest irq = getInvocation(invocationId); + if (irq == null) { + return null; + } + + return irq.getReturnType(); } @Override diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubProtocol.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubProtocol.java index 82244730fd..a1d2ee4d92 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubProtocol.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/HubProtocol.java @@ -3,8 +3,6 @@ package com.microsoft.aspnet.signalr; -import java.io.IOException; - /** * A protocol abstraction for communicating with SignalR hubs. */ diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationMessage.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationMessage.java index cde7f9feac..d7fe2e0f04 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationMessage.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationMessage.java @@ -5,11 +5,12 @@ package com.microsoft.aspnet.signalr; class InvocationMessage extends HubMessage { int type = HubMessageType.INVOCATION.value; - String invocationId; - String target; - Object[] arguments; + protected String invocationId; + private String target; + private Object[] arguments; - public InvocationMessage(String target, Object[] args) { + public InvocationMessage(String invocationId, String target, Object[] args) { + this.invocationId = invocationId; this.target = target; this.arguments = args; } @@ -18,10 +19,6 @@ class InvocationMessage extends HubMessage { return invocationId; } - public void setInvocationId(String invocationId) { - this.invocationId = invocationId; - } - public String getTarget() { return target; } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationRequest.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationRequest.java new file mode 100644 index 0000000000..5eae9374a2 --- /dev/null +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/InvocationRequest.java @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +package com.microsoft.aspnet.signalr; + +import java.util.concurrent.CompletableFuture; + +class InvocationRequest { + private Class returnType; + private CompletableFuture pendingCall = new CompletableFuture<>(); + private String invocationId; + + InvocationRequest(Class returnType, String invocationId) { + this.returnType = returnType; + this.invocationId = invocationId; + } + + public void complete(CompletionMessage completion) { + if (completion.getResult() != null) { + pendingCall.complete(completion.getResult()); + } else { + pendingCall.completeExceptionally(new HubException(completion.getError())); + } + } + + public void fail(Exception ex) { + pendingCall.completeExceptionally(ex); + } + + public void cancel() { + pendingCall.cancel(false); + } + + public CompletableFuture getPendingCall() { + return pendingCall; + } + + public Class getReturnType() { + return returnType; + } + + public String getInvocationId() { + return invocationId; + } +} \ No newline at end of file diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java index 29ed3a5a89..4acc831226 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/JsonHubProtocol.java @@ -3,16 +3,15 @@ package com.microsoft.aspnet.signalr; -import java.io.IOException; import java.io.StringReader; import java.util.ArrayList; import java.util.List; import com.google.gson.Gson; import com.google.gson.JsonArray; +import com.google.gson.JsonElement; import com.google.gson.JsonParser; import com.google.gson.stream.JsonReader; -import com.google.gson.stream.JsonToken; class JsonHubProtocol implements HubProtocol { private final JsonParser jsonParser = new JsonParser(); @@ -31,7 +30,7 @@ class JsonHubProtocol implements HubProtocol { @Override public TransferFormat getTransferFormat() { - return TransferFormat.Text; + return TransferFormat.TEXT; } @Override @@ -45,6 +44,8 @@ class JsonHubProtocol implements HubProtocol { String error = null; ArrayList arguments = null; JsonArray argumentsToken = null; + Object result = null; + JsonElement resultToken = null; JsonReader reader = new JsonReader(new StringReader(str)); reader.beginObject(); @@ -65,7 +66,11 @@ class JsonHubProtocol implements HubProtocol { error = reader.nextString(); break; case "result": - reader.skipValue(); + if (invocationId == null) { + resultToken = jsonParser.parse(reader); + } else { + result = gson.fromJson(reader, binder.getReturnType(invocationId)); + } break; case "item": reader.skipValue(); @@ -109,18 +114,23 @@ class JsonHubProtocol implements HubProtocol { } } if (arguments == null) { - hubMessages.add(new InvocationMessage(target, new Object[0])); + hubMessages.add(new InvocationMessage(invocationId, target, new Object[0])); } else { - hubMessages.add(new InvocationMessage(target, arguments.toArray())); + hubMessages.add(new InvocationMessage(invocationId, target, arguments.toArray())); } break; + case COMPLETION: + if (resultToken != null) { + result = gson.fromJson(resultToken, binder.getReturnType(invocationId)); + } + hubMessages.add(new CompletionMessage(invocationId, result, error)); + break; case STREAM_INVOCATION: case STREAM_ITEM: - case COMPLETION: case CANCEL_INVOCATION: throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", messageType)); case PING: - hubMessages.add(new PingMessage()); + hubMessages.add(PingMessage.getInstance()); break; case CLOSE: if (error != null) { diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/PingMessage.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/PingMessage.java index 755f6fba88..e4e00fd365 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/PingMessage.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/PingMessage.java @@ -3,12 +3,18 @@ package com.microsoft.aspnet.signalr; -class PingMessage extends HubMessage { +class PingMessage extends HubMessage +{ + private static PingMessage instance = new PingMessage(); - int type = HubMessageType.PING.value; + private PingMessage() + { + } + + public static PingMessage getInstance() {return instance;} @Override public HubMessageType getMessageType() { return HubMessageType.PING; } -} +} \ No newline at end of file diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/StreamInvocationMessage.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/StreamInvocationMessage.java index ab4f07983c..cf2d111a30 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/StreamInvocationMessage.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/StreamInvocationMessage.java @@ -8,8 +8,7 @@ class StreamInvocationMessage extends InvocationMessage { int type = HubMessageType.STREAM_INVOCATION.value; public StreamInvocationMessage(String invocationId, String target, Object[] arguments) { - super(target, arguments); - this.invocationId = invocationId; + super(invocationId, target, arguments); } @Override diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/TransferFormat.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/TransferFormat.java index 32b8956bfb..cbda56e06f 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/TransferFormat.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/TransferFormat.java @@ -4,6 +4,6 @@ package com.microsoft.aspnet.signalr; public enum TransferFormat { - Text, - Binary + TEXT, + BINARY } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java index 79e093682c..eafb865465 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Transport.java @@ -3,10 +3,12 @@ package com.microsoft.aspnet.signalr; +import java.util.concurrent.CompletableFuture; + interface Transport { - void start() throws Exception; - void send(String message) throws Exception; + CompletableFuture start() throws Exception; + CompletableFuture send(String message); void setOnReceive(OnReceiveCallBack callback); void onReceive(String message) throws Exception; - void stop(); + CompletableFuture stop(); } diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java index e20ef50ac8..91cd290898 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java +++ b/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/WebSocketTransport.java @@ -6,6 +6,7 @@ package com.microsoft.aspnet.signalr; import java.net.URI; import java.net.URISyntaxException; import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.java_websocket.client.WebSocketClient; import org.java_websocket.handshake.ServerHandshake; @@ -47,21 +48,28 @@ class WebSocketTransport implements Transport { } @Override - public void start() throws Exception { - logger.log(LogLevel.Debug, "Starting Websocket connection."); - webSocketClient = createWebSocket(headers); - - if (!webSocketClient.connectBlocking()) { - String errorMessage = "There was an error starting the Websockets transport."; - logger.log(LogLevel.Debug, errorMessage); - throw new Exception(errorMessage); - } - logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI()); + public CompletableFuture start() { + return CompletableFuture.runAsync(() -> { + logger.log(LogLevel.Debug, "Starting Websocket connection."); + webSocketClient = createWebSocket(headers); + try { + if (!webSocketClient.connectBlocking()) { + String errorMessage = "There was an error starting the Websockets transport."; + logger.log(LogLevel.Debug, errorMessage); + throw new RuntimeException(errorMessage); + } + } catch (InterruptedException e) { + String interruptedExMessage = "Connecting the Websockets transport was interrupted."; + logger.log(LogLevel.Debug, interruptedExMessage); + throw new RuntimeException(interruptedExMessage); + } + logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI()); + }); } @Override - public void send(String message) { - webSocketClient.send(message); + public CompletableFuture send(String message) { + return CompletableFuture.runAsync(() -> webSocketClient.send(message)); } @Override @@ -76,9 +84,11 @@ class WebSocketTransport implements Transport { } @Override - public void stop() { - webSocketClient.closeConnection(0, "HubConnection Stopped"); - logger.log(LogLevel.Information, "WebSocket connection stopped"); + public CompletableFuture stop() { + return CompletableFuture.runAsync(() -> { + webSocketClient.closeConnection(0, "HubConnection Stopped"); + logger.log(LogLevel.Information, "WebSocket connection stopped"); + }); } private WebSocketClient createWebSocket(Map headers) { diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HandshakeProtocolTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HandshakeProtocolTest.java index df168fefc8..ff2ea82184 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HandshakeProtocolTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HandshakeProtocolTest.java @@ -3,9 +3,9 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class HandshakeProtocolTest { diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java index 655cb11506..47acc2178b 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubConnectionTest.java @@ -3,22 +3,20 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import java.util.ArrayList; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import org.junit.jupiter.api.Test; public class HubConnectionTest { private static final String RECORD_SEPARATOR = "\u001e"; - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - @Test public void checkHubConnectionState() throws Exception { Transport mockTransport = new MockTransport(); @@ -47,14 +45,12 @@ public class HubConnectionTest { @Test public void hubConnectionReceiveHandshakeResponseWithError() throws Exception { - exceptionRule.expect(HubException.class); - exceptionRule.expectMessage("Requested protocol 'messagepack' is not available."); - MockTransport mockTransport = new MockTransport(); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); hubConnection.start(); - mockTransport.receiveMessage("{\"error\":\"Requested protocol 'messagepack' is not available.\"}" + RECORD_SEPARATOR); + Throwable exception = assertThrows(HubException.class, () -> mockTransport.receiveMessage("{\"error\":\"Requested protocol 'messagepack' is not available.\"}" + RECORD_SEPARATOR)); + assertEquals("Error in handshake Requested protocol 'messagepack' is not available.", exception.getMessage()); } @Test @@ -67,7 +63,7 @@ public class HubConnectionTest { hubConnection.on("inc", action); hubConnection.on("inc", action); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); @@ -80,7 +76,7 @@ public class HubConnectionTest { mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(2, value.get(), 0); + assertEquals(Double.valueOf(2), value.get()); } @Test @@ -92,7 +88,7 @@ public class HubConnectionTest { hubConnection.on("inc", action); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); String message = mockTransport.getSentMessages()[0]; @@ -104,10 +100,10 @@ public class HubConnectionTest { mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); hubConnection.remove("inc"); - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); } @Test @@ -120,7 +116,7 @@ public class HubConnectionTest { hubConnection.on("inc", action); hubConnection.remove("inc"); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); String message = mockTransport.getSentMessages()[0]; @@ -132,7 +128,7 @@ public class HubConnectionTest { mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that the handler was removed. - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); } @Test @@ -146,7 +142,7 @@ public class HubConnectionTest { hubConnection.on("inc", action); hubConnection.on("inc", secondAction); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); String message = mockTransport.getSentMessages()[0]; @@ -157,14 +153,14 @@ public class HubConnectionTest { mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); - assertEquals(3, value.get(), 0); + assertEquals(Double.valueOf(3), value.get()); hubConnection.remove("inc"); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirm that another invocation doesn't change anything because the handlers have been removed. - assertEquals(3, value.get(), 0); + assertEquals(Double.valueOf(3), value.get()); } @Test @@ -176,7 +172,7 @@ public class HubConnectionTest { Subscription subscription = hubConnection.on("inc", action); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); String message = mockTransport.getSentMessages()[0]; @@ -188,7 +184,7 @@ public class HubConnectionTest { mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); subscription.unsubscribe(); try { @@ -197,7 +193,7 @@ public class HubConnectionTest { assertEquals("There are no callbacks registered for the method 'inc'.", ex.getMessage()); } - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); } @Test @@ -209,7 +205,7 @@ public class HubConnectionTest { Subscription subscription = hubConnection.on("inc", action); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); String message = mockTransport.getSentMessages()[0]; @@ -221,7 +217,7 @@ public class HubConnectionTest { mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); subscription.unsubscribe(); subscription.unsubscribe(); @@ -231,7 +227,7 @@ public class HubConnectionTest { assertEquals("There are no callbacks registered for the method 'inc'.", ex.getMessage()); } - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); } @Test @@ -245,7 +241,7 @@ public class HubConnectionTest { Subscription subscription = hubConnection.on("inc", action); Subscription secondSubscription = hubConnection.on("inc", secondAction); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); String message = mockTransport.getSentMessages()[0]; @@ -256,12 +252,12 @@ public class HubConnectionTest { mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(3, value.get(), 0); + assertEquals(Double.valueOf(3), value.get()); // This removes the first handler so when "inc" is invoked secondAction should still run. subscription.unsubscribe(); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); - assertEquals(5, value.get(), 0); + assertEquals(Double.valueOf(5), value.get()); } @Test @@ -274,7 +270,7 @@ public class HubConnectionTest { Subscription sub = hubConnection.on("inc", action); sub.unsubscribe(); - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); @@ -286,7 +282,7 @@ public class HubConnectionTest { } // Confirming that the handler was removed. - assertEquals(0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); } @Test @@ -300,25 +296,129 @@ public class HubConnectionTest { hubConnection.on("add", action, Double.class); hubConnection.on("add", action, Double.class); - assertEquals(0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); hubConnection.start(); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"add\",\"arguments\":[12]}" + RECORD_SEPARATOR); - hubConnection.send("add", 12); // Confirming that our handler was called and the correct message was passed in. - assertEquals(24, value.get(), 0); + assertEquals(Double.valueOf(24), value.get()); + } + + @Test + public void invokeWaitsForCompletionMessage() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); + + hubConnection.start(); + mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); + + CompletableFuture result = hubConnection.invoke(Integer.class, "echo", "message"); + assertEquals("{\"type\":1,\"invocationId\":\"1\",\"target\":\"echo\",\"arguments\":[\"message\"]}" + RECORD_SEPARATOR, mockTransport.sentMessages.get(1)); + assertFalse(result.isDone()); + + mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"result\":42}" + RECORD_SEPARATOR); + + assertEquals(Integer.valueOf(42), result.get(1000L, TimeUnit.MILLISECONDS)); + } + + @Test + public void multipleInvokesWaitForOwnCompletionMessage() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); + + hubConnection.start(); + mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); + + CompletableFuture result = hubConnection.invoke(Integer.class, "echo", "message"); + CompletableFuture result2 = hubConnection.invoke(String.class, "echo", "message"); + assertEquals("{\"type\":1,\"invocationId\":\"1\",\"target\":\"echo\",\"arguments\":[\"message\"]}" + RECORD_SEPARATOR, mockTransport.sentMessages.get(1)); + assertEquals("{\"type\":1,\"invocationId\":\"2\",\"target\":\"echo\",\"arguments\":[\"message\"]}" + RECORD_SEPARATOR, mockTransport.sentMessages.get(2)); + assertFalse(result.isDone()); + assertFalse(result2.isDone()); + + mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"2\",\"result\":\"message\"}" + RECORD_SEPARATOR); + assertEquals("message", result2.get(1000L, TimeUnit.MILLISECONDS)); + assertFalse(result.isDone()); + + mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"result\":42}" + RECORD_SEPARATOR); + assertEquals(Integer.valueOf(42), result.get(1000L, TimeUnit.MILLISECONDS)); + } + + @Test + public void invokeWorksForPrimitiveTypes() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); + + hubConnection.start(); + mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); + + // int.class is a primitive type and since we use Class.cast to cast an Object to the expected return type + // which does not work for primitives we have to write special logic for that case. + CompletableFuture result = hubConnection.invoke(int.class, "echo", "message"); + assertFalse(result.isDone()); + + mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"result\":42}" + RECORD_SEPARATOR); + + assertEquals(Integer.valueOf(42), result.get(1000L, TimeUnit.MILLISECONDS)); + } + + @Test + public void completionMessageCanHaveError() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); + + hubConnection.start(); + mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); + + CompletableFuture result = hubConnection.invoke(int.class, "echo", "message"); + assertFalse(result.isDone()); + + mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"error\":\"There was an error\"}" + RECORD_SEPARATOR); + + String exceptionMessage = null; + try { + result.get(1000L, TimeUnit.MILLISECONDS); + assertFalse(true); + } catch (Exception ex) { + exceptionMessage = ex.getMessage(); + } + + assertEquals("com.microsoft.aspnet.signalr.HubException: There was an error", exceptionMessage); + } + + @Test + public void stopCancelsActiveInvokes() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); + + hubConnection.start(); + mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); + + CompletableFuture result = hubConnection.invoke(int.class, "echo", "message"); + assertFalse(result.isDone()); + + hubConnection.stop(); + + boolean hasException = false; + try { + result.get(1000L, TimeUnit.MILLISECONDS); + assertFalse(true); + } catch (CancellationException ex) { + hasException = true; + } + + assertTrue(hasException); } - // We're using AtomicReference in the send tests instead of int here because Gson has trouble deserializing to Integer @Test public void sendWithNoParamsTriggersOnHandler() throws Exception { - AtomicReference value = new AtomicReference(0.0); + AtomicReference value = new AtomicReference<>(0); MockTransport mockTransport = new MockTransport(); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); hubConnection.on("inc", () ->{ - assertEquals(0.0, value.get(), 0); + assertEquals(Integer.valueOf(0), value.get()); value.getAndUpdate((val) -> val + 1); }); @@ -327,7 +427,7 @@ public class HubConnectionTest { mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(1, value.get(), 0); + assertEquals(Integer.valueOf(1), value.get()); } @Test @@ -373,7 +473,7 @@ public class HubConnectionTest { // Confirming that our handler was called and the correct message was passed in. assertEquals("Hello World", value1.get()); - assertEquals(12, value2.get(), 0); + assertEquals(Double.valueOf(12), value2.get()); } @Test @@ -473,7 +573,7 @@ public class HubConnectionTest { assertEquals("B", value2.get()); assertEquals("C", value3.get()); assertTrue(value4.get()); - assertEquals(12, value5.get(), 0); + assertEquals(Double.valueOf(12), value5.get()); } @Test @@ -513,7 +613,7 @@ public class HubConnectionTest { assertEquals("B", value2.get()); assertEquals("C", value3.get()); assertTrue(value4.get()); - assertEquals(12, value5.get(), 0); + assertEquals(Double.valueOf(12), value5.get()); assertEquals("D", value6.get()); } @@ -557,7 +657,7 @@ public class HubConnectionTest { assertEquals("B", value2.get()); assertEquals("C", value3.get()); assertTrue(value4.get()); - assertEquals(12, value5.get(), 0); + assertEquals(Double.valueOf(12), value5.get()); assertEquals("D", value6.get()); assertEquals("E", value7.get()); } @@ -604,7 +704,7 @@ public class HubConnectionTest { assertEquals("B", value2.get()); assertEquals("C", value3.get()); assertTrue(value4.get()); - assertEquals(12, value5.get(), 0); + assertEquals(Double.valueOf(12), value5.get()); assertEquals("D", value6.get()); assertEquals("E", value7.get()); assertEquals("F", value8.get()); @@ -649,7 +749,7 @@ public class HubConnectionTest { HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); hubConnection.on("inc", () ->{ - assertEquals(0.0, value.get(), 0); + assertEquals(Double.valueOf(0), value.get()); value.getAndUpdate((val) -> val + 1); }); @@ -661,7 +761,7 @@ public class HubConnectionTest { mockTransport.receiveMessage("{}" + RECORD_SEPARATOR + "{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); // Confirming that our handler was called and that the counter property was incremented. - assertEquals(1, value.get(), 0); + assertEquals(Double.valueOf(1), value.get()); } @Test @@ -740,14 +840,12 @@ public class HubConnectionTest { @Test public void cannotSendBeforeStart() throws Exception { - exceptionRule.expect(HubException.class); - exceptionRule.expectMessage("The 'send' method cannot be called if the connection is not active"); - Transport mockTransport = new MockTransport(); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport); assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); - hubConnection.send("inc"); + Throwable exception = assertThrows(HubException.class, () -> hubConnection.send("inc")); + assertEquals("The 'send' method cannot be called if the connection is not active", exception.getMessage()); } private class MockTransport implements Transport { @@ -755,11 +853,14 @@ public class HubConnectionTest { private ArrayList sentMessages = new ArrayList<>(); @Override - public void start() {} + public CompletableFuture start() { + return CompletableFuture.completedFuture(null); + } @Override - public void send(String message) { + public CompletableFuture send(String message) { sentMessages.add(message); + return CompletableFuture.completedFuture(null); } @Override @@ -773,7 +874,9 @@ public class HubConnectionTest { } @Override - public void stop() {} + public CompletableFuture stop() { + return CompletableFuture.completedFuture(null); + } public void receiveMessage(String message) throws Exception { this.onReceive(message); diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubExceptionTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubExceptionTest.java index 9914fb4cd1..0e306e5890 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubExceptionTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/HubExceptionTest.java @@ -3,9 +3,9 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class HubExceptionTest { @Test diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java index 2c8f6e8d67..08fe5c1065 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/JsonHubProtocolTest.java @@ -3,18 +3,14 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import java.util.ArrayList; import java.util.Arrays; import java.util.List; -import java.util.concurrent.PriorityBlockingQueue; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import org.junit.jupiter.api.Test; -import com.google.gson.JsonArray; public class JsonHubProtocolTest { private JsonHubProtocol jsonHubProtocol = new JsonHubProtocol(); @@ -31,12 +27,12 @@ public class JsonHubProtocolTest { @Test public void checkTransferFormat() { - assertEquals(TransferFormat.Text, jsonHubProtocol.getTransferFormat()); + assertEquals(TransferFormat.TEXT, jsonHubProtocol.getTransferFormat()); } @Test public void verifyWriteMessage() { - InvocationMessage invocationMessage = new InvocationMessage("test", new Object[] {"42"}); + InvocationMessage invocationMessage = new InvocationMessage(null, "test", new Object[] {"42"}); String result = jsonHubProtocol.writeMessage(invocationMessage); String expectedResult = "{\"type\":1,\"target\":\"test\",\"arguments\":[\"42\"]}\u001E"; assertEquals(expectedResult, result); @@ -45,7 +41,7 @@ public class JsonHubProtocolTest { @Test public void parsePingMessage() throws Exception { String stringifiedMessage = "{\"type\":6}\u001E"; - TestBinder binder = new TestBinder(new PingMessage()); + TestBinder binder = new TestBinder(PingMessage.getInstance()); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); @@ -93,7 +89,7 @@ public class JsonHubProtocolTest { @Test public void parseSingleMessage() throws Exception { String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":[42]}\u001E"; - TestBinder binder = new TestBinder(new InvocationMessage("test", new Object[] { 42 })); + TestBinder binder = new TestBinder(new InvocationMessage("1", "test", new Object[] { 42 })); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); @@ -112,53 +108,37 @@ public class JsonHubProtocolTest { assertEquals(42, messageResult); } - @Rule - public ExpectedException exceptionRule = ExpectedException.none(); - @Test public void parseSingleUnsupportedStreamItemMessage() throws Exception { - exceptionRule.expect(UnsupportedOperationException.class); - exceptionRule.expectMessage("The message type STREAM_ITEM is not supported yet."); String stringifiedMessage = "{\"type\":2,\"Id\":1,\"Item\":42}\u001E"; TestBinder binder = new TestBinder(null); - HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); + Throwable exception = assertThrows(UnsupportedOperationException.class, () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder)); + assertEquals("The message type STREAM_ITEM is not supported yet.", exception.getMessage()); } @Test public void parseSingleUnsupportedStreamInvocationMessage() throws Exception { - exceptionRule.expect(UnsupportedOperationException.class); - exceptionRule.expectMessage("The message type STREAM_INVOCATION is not supported yet."); String stringifiedMessage = "{\"type\":4,\"Id\":1,\"target\":\"test\",\"arguments\":[42]}\u001E"; TestBinder binder = new TestBinder(new StreamInvocationMessage("1", "test", new Object[] { 42 })); - HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); + Throwable exception = assertThrows(UnsupportedOperationException.class, () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder)); + assertEquals("The message type STREAM_INVOCATION is not supported yet.", exception.getMessage()); } @Test public void parseSingleUnsupportedCancelInvocationMessage() throws Exception { - exceptionRule.expect(UnsupportedOperationException.class); - exceptionRule.expectMessage("The message type CANCEL_INVOCATION is not supported yet."); String stringifiedMessage = "{\"type\":5,\"invocationId\":123}\u001E"; TestBinder binder = new TestBinder(null); - HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); - } - - @Test - public void parseSingleUnsupportedCompletionMessage() throws Exception { - exceptionRule.expect(UnsupportedOperationException.class); - exceptionRule.expectMessage("The message type COMPLETION is not supported yet."); - String stringifiedMessage = "{\"type\":3,\"invocationId\":123}\u001E"; - TestBinder binder = new TestBinder(null); - - HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); + Throwable exception = assertThrows(UnsupportedOperationException.class, () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder)); + assertEquals("The message type CANCEL_INVOCATION is not supported yet.", exception.getMessage()); } @Test public void parseTwoMessages() throws Exception { String twoMessages = "{\"type\":1,\"target\":\"one\",\"arguments\":[42]}\u001E{\"type\":1,\"target\":\"two\",\"arguments\":[43]}\u001E"; - TestBinder binder = new TestBinder(new InvocationMessage("one", new Object[] { 42 })); + TestBinder binder = new TestBinder(new InvocationMessage("1", "one", new Object[] { 42 })); HubMessage[] messages = jsonHubProtocol.parseMessages(twoMessages, binder); assertEquals(2, messages.length); @@ -189,7 +169,7 @@ public class JsonHubProtocolTest { @Test public void parseSingleMessageMutipleArgs() throws Exception { String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":[42, 24]}\u001E"; - TestBinder binder = new TestBinder(new InvocationMessage("test", new Object[] { 42, 24 })); + TestBinder binder = new TestBinder(new InvocationMessage("1", "test", new Object[] { 42, 24 })); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); @@ -208,7 +188,7 @@ public class JsonHubProtocolTest { @Test public void parseMessageWithOutOfOrderProperties() throws Exception { String stringifiedMessage = "{\"arguments\":[42, 24],\"type\":1,\"target\":\"test\"}\u001E"; - TestBinder binder = new TestBinder(new InvocationMessage("test", new Object[] { 42, 24 })); + TestBinder binder = new TestBinder(new InvocationMessage("1", "test", new Object[] { 42, 24 })); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); @@ -224,8 +204,24 @@ public class JsonHubProtocolTest { assertEquals(24, messageResult2); } + @Test + public void parseCompletionMessageWithOutOfOrderProperties() throws Exception { + String stringifiedMessage = "{\"type\":3,\"result\":42,\"invocationId\":\"1\"}\u001E"; + TestBinder binder = new TestBinder(new CompletionMessage("1", 42, null)); + + HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); + + // We know it's only one message + assertEquals(HubMessageType.COMPLETION, messages[0].getMessageType()); + + CompletionMessage message = (CompletionMessage) messages[0]; + assertEquals(null, message.getError()); + assertEquals(42 , message.getResult()); + } + private class TestBinder implements InvocationBinder { private Class[] paramTypes = null; + private Class returnType = null; public TestBinder(HubMessage expectedMessage) { if (expectedMessage == null) { @@ -249,6 +245,9 @@ public class JsonHubProtocolTest { break; case STREAM_ITEM: break; + case COMPLETION: + returnType = ((CompletionMessage)expectedMessage).getResult().getClass(); + break; default: break; } @@ -256,7 +255,7 @@ public class JsonHubProtocolTest { @Override public Class getReturnType(String invocationId) { - return null; + return returnType; } @Override diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/NegotiateResponseTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/NegotiateResponseTest.java index c83e19eb01..7c2be3b9fc 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/NegotiateResponseTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/NegotiateResponseTest.java @@ -3,9 +3,9 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; -import org.junit.Test; +import org.junit.jupiter.api.Test; public class NegotiateResponseTest { diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/ResolveNegotiateUrlTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/ResolveNegotiateUrlTest.java index 21daaae8bb..9f852c7725 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/ResolveNegotiateUrlTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/ResolveNegotiateUrlTest.java @@ -3,39 +3,28 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.assertEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; -import java.util.Arrays; -import java.util.Collection; +import java.util.stream.Stream; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; - -@RunWith(Parameterized.class) public class ResolveNegotiateUrlTest { - private String url; - private String resolvedUrl; - - public ResolveNegotiateUrlTest(String url, String resolvedUrl) { - this.url = url; - this.resolvedUrl = resolvedUrl; + private static Stream protocols() { + return Stream.of( + Arguments.of("http://example.com/hub/", "http://example.com/hub/negotiate"), + Arguments.of("http://example.com/hub", "http://example.com/hub/negotiate"), + Arguments.of("http://example.com/endpoint?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"), + Arguments.of("http://example.com/endpoint/?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"), + Arguments.of("http://example.com/endpoint/path/more?q=my/Data", "http://example.com/endpoint/path/more/negotiate?q=my/Data")); } - @Parameterized.Parameters - public static Collection protocols() { - return Arrays.asList(new String[][]{ - {"http://example.com/hub/", "http://example.com/hub/negotiate"}, - {"http://example.com/hub", "http://example.com/hub/negotiate"}, - {"http://example.com/endpoint?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"}, - {"http://example.com/endpoint/?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"}, - {"http://example.com/endpoint/path/more?q=my/Data", "http://example.com/endpoint/path/more/negotiate?q=my/Data"},}); - } - - @Test - public void checkNegotiateUrl() { - String urlResult = Negotiate.resolveNegotiateUrl(this.url); - assertEquals(this.resolvedUrl, urlResult); + @ParameterizedTest + @MethodSource("protocols") + public void checkNegotiateUrl(String url, String resolvedUrl) { + String urlResult = Negotiate.resolveNegotiateUrl(url); + assertEquals(resolvedUrl, urlResult); } } \ No newline at end of file diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java index a29246d812..31116d4d65 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportTest.java @@ -3,20 +3,17 @@ package com.microsoft.aspnet.signalr; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.ExpectedException; +import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; public class WebSocketTransportTest { - - @Rule - public ExpectedException expectedEx = ExpectedException.none(); - @Test public void WebsocketThrowsIfItCantConnect() throws Exception { - expectedEx.expect(Exception.class); - expectedEx.expectMessage("There was an error starting the Websockets transport"); Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger()); - transport.start(); + Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS)); + assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage()); } } diff --git a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java index 0e2a93284d..f7df0d956f 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/WebSocketTransportUrlFormatTest.java @@ -3,39 +3,28 @@ package com.microsoft.aspnet.signalr; -import static org.junit.Assert.*; +import static org.junit.jupiter.api.Assertions.*; import java.net.URISyntaxException; -import java.util.Arrays; -import java.util.Collection; +import java.util.stream.Stream; -import org.junit.Test; -import org.junit.runner.RunWith; -import org.junit.runners.Parameterized; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; - -@RunWith(Parameterized.class) public class WebSocketTransportUrlFormatTest { - private String url; - private String expectedUrl; - - public WebSocketTransportUrlFormatTest(String url, String expectedProtocol) { - this.url = url; - this.expectedUrl = expectedProtocol; + private static Stream protocols() { + return Stream.of( + Arguments.of("http://example.com", "ws://example.com"), + Arguments.of("https://example.com", "wss://example.com"), + Arguments.of("ws://example.com", "ws://example.com"), + Arguments.of("wss://example.com", "wss://example.com")); } - @Parameterized.Parameters - public static Collection protocols() { - return Arrays.asList(new String[][]{ - {"http://example.com", "ws://example.com"}, - {"https://example.com", "wss://example.com"}, - {"ws://example.com", "ws://example.com"}, - {"wss://example.com", "wss://example.com"}}); - } - - @Test - public void checkWebsocketUrlProtocol() throws URISyntaxException { - WebSocketTransport webSocketTransport = new WebSocketTransport(this.url, new NullLogger()); - assertEquals(this.expectedUrl, webSocketTransport.getUrl().toString()); + @ParameterizedTest + @MethodSource("protocols") + public void checkWebsocketUrlProtocol(String url, String expectedUrl) throws URISyntaxException { + WebSocketTransport webSocketTransport = new WebSocketTransport(url, new NullLogger()); + assertEquals(expectedUrl, webSocketTransport.getUrl().toString()); } } \ No newline at end of file diff --git a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Chat.java b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/sample/Chat.java similarity index 87% rename from clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Chat.java rename to clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/sample/Chat.java index c306c26eef..cec5fe913c 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/aspnet/signalr/Chat.java +++ b/clients/java/signalr/src/test/java/com/microsoft/aspnet/signalr/sample/Chat.java @@ -1,10 +1,14 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -package com.microsoft.aspnet.signalr; +package com.microsoft.aspnet.signalr.sample; import java.util.Scanner; +import com.microsoft.aspnet.signalr.HubConnection; +import com.microsoft.aspnet.signalr.HubConnectionBuilder; +import com.microsoft.aspnet.signalr.LogLevel; + public class Chat { public static void main(String[] args) throws Exception { System.out.println("Enter the URL of the SignalR Chat you want to join"); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index de775e8f8a..6d78877e09 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -234,11 +234,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal InitializeHub(hub, connection); Task invocation = null; + CancellationTokenSource cts = null; + var arguments = hubMethodInvocationMessage.Arguments; + if (descriptor.HasSyntheticArguments) + { + // In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments) + arguments = new object[descriptor.OriginalParameterTypes.Count]; + + var hubInvocationArgumentPointer = 0; + for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) + { + if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer && + hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer]) + { + // The types match so it isn't a synthetic argument, just copy it into the arguments array + arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer]; + hubInvocationArgumentPointer++; + } + else + { + // This is the only synthetic argument type we currently support + if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) + { + cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + arguments[parameterPointer] = cts.Token; + } + else + { + // This should never happen + Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'."); + } + } + } + } + if (isStreamResponse) { - var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments); + var result = await ExecuteHubMethod(methodExecutor, hub, arguments); - if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts)) + if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts)) { Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, @@ -247,13 +281,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal } Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts); + _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts); } else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) { // Send Async, no response expected - invocation = ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments); + invocation = ExecuteHubMethod(methodExecutor, hub, arguments); } else @@ -261,7 +295,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal // Invoke Async, one reponse expected async Task ExecuteInvocation() { - var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments); + var result = await ExecuteHubMethod(methodExecutor, hub, arguments); Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); } @@ -443,29 +477,24 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; } - private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator enumerator, out CancellationTokenSource streamCts) + private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator enumerator, ref CancellationTokenSource streamCts) { if (result != null) { if (hubMethodDescriptor.IsChannel) { - streamCts = CreateCancellation(); + if (streamCts == null) + { + streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); + } + connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token); return true; } } - streamCts = null; enumerator = null; return false; - - CancellationTokenSource CreateCancellation() - { - var userCts = new CancellationTokenSource(); - connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts); - - return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token); - } } private void DiscoverHubMethods() diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs index b942279e46..fe22be662a 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs @@ -23,8 +23,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable policies) { MethodExecutor = methodExecutor; - ParameterTypes = methodExecutor.MethodParameters.Select(GetParameterType).ToArray(); - Policies = policies.ToArray(); NonAsyncReturnType = (MethodExecutor.IsMethodAsync) ? MethodExecutor.AsyncResultType @@ -35,6 +33,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal IsChannel = true; StreamReturnType = channelItemType; } + + // Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers + ParameterTypes = methodExecutor.MethodParameters.Where(p => + { + // Only streams can take CancellationTokens currently + if (IsStreamable && p.ParameterType == typeof(CancellationToken)) + { + HasSyntheticArguments = true; + return false; + } + return true; + }).Select(GetParameterType).ToArray(); + + if (HasSyntheticArguments) + { + OriginalParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray(); + } + + Policies = policies.ToArray(); } public bool HasStreamingParameters { get; private set; } @@ -45,6 +62,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal public IReadOnlyList ParameterTypes { get; } + public IReadOnlyList OriginalParameterTypes { get; } + public Type NonAsyncReturnType { get; } public bool IsChannel { get; } @@ -55,6 +74,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal public IList Policies { get; } + public bool HasSyntheticArguments { get; private set; } + private Type GetParameterType(ParameterInfo p) { var type = p.ParameterType; diff --git a/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs index 3271883a0a..efdeda8d18 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/SignalRDependencyInjectionExtensions.cs @@ -20,7 +20,7 @@ namespace Microsoft.Extensions.DependencyInjection /// An that can be used to further configure the SignalR services. public static ISignalRServerBuilder AddSignalRCore(this IServiceCollection services) { - services.AddSingleton(); + services.TryAddSingleton(); services.TryAddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>)); services.TryAddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver)); services.TryAddSingleton(typeof(IHubContext<>), typeof(HubContext<>)); diff --git a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs index dfc3c9e644..8974bd094f 100644 --- a/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR/SignalRDependencyInjectionExtensions.cs @@ -4,6 +4,7 @@ using System; using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.Extensions.DependencyInjection.Extensions; using Microsoft.Extensions.Options; namespace Microsoft.Extensions.DependencyInjection @@ -35,8 +36,8 @@ namespace Microsoft.Extensions.DependencyInjection public static ISignalRServerBuilder AddSignalR(this IServiceCollection services) { services.AddConnections(); - services.AddSingleton(); - services.AddSingleton, HubOptionsSetup>(); + services.TryAddSingleton(); + services.TryAddEnumerable(ServiceDescriptor.Singleton, HubOptionsSetup>()); return services.AddSignalRCore(); } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests.Utils/CancellationTokenExtensions.cs b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/CancellationTokenExtensions.cs new file mode 100644 index 0000000000..fadb9626bf --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests.Utils/CancellationTokenExtensions.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + public static class CancellationTokenExtensions + { + public static Task WaitForCancellationAsync(this CancellationToken token) + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + token.Register((t) => + { + ((TaskCompletionSource)t).SetResult(null); + }, tcs); + return tcs.Task; + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/AddSignalRTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/AddSignalRTests.cs index 6d8360d745..f5e15fb0b0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/AddSignalRTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/AddSignalRTests.cs @@ -4,8 +4,10 @@ using System.Collections.Generic; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Options; using Xunit; namespace Microsoft.AspNetCore.SignalR.Tests @@ -17,12 +19,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var serviceCollection = new ServiceCollection(); + var markerService = new SignalRCoreMarkerService(); + serviceCollection.AddSingleton(markerService); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(typeof(HubLifetimeManager<>), typeof(CustomHubLifetimeManager<>)); serviceCollection.AddSingleton(); serviceCollection.AddScoped(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); serviceCollection.AddSingleton(typeof(IHubContext<>), typeof(CustomHubContext<>)); serviceCollection.AddSingleton(typeof(IHubContext<,>), typeof(CustomHubContext<,>)); + var hubOptions = new HubOptionsSetup(new List()); + serviceCollection.AddSingleton>(hubOptions); serviceCollection.AddSignalR(); var serviceProvider = serviceCollection.BuildServiceProvider(); @@ -33,6 +39,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.IsType>(serviceProvider.GetRequiredService>()); Assert.IsType>(serviceProvider.GetRequiredService>()); Assert.IsType>(serviceProvider.GetRequiredService>()); + Assert.Equal(hubOptions, serviceProvider.GetRequiredService>()); + Assert.Equal(markerService, serviceProvider.GetRequiredService()); } [Fact] diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index c1d30c5cbc..5e08a5f501 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Text; +using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Authorization; @@ -166,6 +167,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef()); } + public void InvalidArgument(CancellationToken token) + { + } + private class SelfRef { public SelfRef() @@ -620,6 +625,51 @@ namespace Microsoft.AspNetCore.SignalR.Tests return Channel.CreateUnbounded().Reader; } + public ChannelReader CancelableStream(CancellationToken token) + { + var channel = Channel.CreateBounded(10); + + Task.Run(async () => + { + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + channel.Writer.TryComplete(); + _tcsService.EndMethod.SetResult(null); + }); + + return channel.Reader; + } + + public ChannelReader CancelableStream2(int ignore, int ignore2, CancellationToken token) + { + var channel = Channel.CreateBounded(10); + + Task.Run(async () => + { + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + channel.Writer.TryComplete(); + _tcsService.EndMethod.SetResult(null); + }); + + return channel.Reader; + } + + public ChannelReader CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2) + { + var channel = Channel.CreateBounded(10); + + Task.Run(async () => + { + _tcsService.StartedMethod.SetResult(null); + await token.WaitForCancellationAsync(); + channel.Writer.TryComplete(); + _tcsService.EndMethod.SetResult(null); + }); + + return channel.Reader; + } + public int SimpleMethod() { return 21; diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index dffc255428..b27a73414b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -2600,6 +2600,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Theory] + [InlineData(nameof(LongRunningHub.CancelableStream))] + [InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)] + [InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)] + public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args) + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + var streamInvocationId = await client.SendStreamInvocationAsync(methodName, args).OrTimeout(); + // Wait for the stream method to start + await tcsService.StartedMethod.Task.OrTimeout(); + + // Cancel the stream which should trigger the CancellationToken in the hub method + await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout(); + + var result = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(result); + Assert.Null(simpleCompletion.Result); + + // CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled. + await tcsService.EndMethod.Task.OrTimeout(); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + [Fact] + public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted() + { + var tcsService = new TcsService(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(tcsService); + }); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout(); + // Wait for the stream method to start + await tcsService.StartedMethod.Task.OrTimeout(); + + // Shut down the client which should trigger the CancellationToken in the hub method + client.Dispose(); + + // CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled. + await tcsService.EndMethod.Task.OrTimeout(); + + await connectionHandlerTask.OrTimeout(); + } + } + + [Fact] + public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument() + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).OrTimeout(); + + var completion = Assert.IsType(await client.ReadAsync().OrTimeout()); + + Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error); + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + private class CustomHubActivator : IHubActivator where THub : Hub { public int ReleaseCount;