Merge branch 'release/2.2'

This commit is contained in:
BrennanConroy 2018-09-20 12:29:13 -07:00
commit e683b81dfe
32 changed files with 849 additions and 296 deletions

2
.vscode/launch.json vendored
View File

@ -8,7 +8,7 @@
"cwd": "${workspaceFolder}/clients/java/", "cwd": "${workspaceFolder}/clients/java/",
"console": "externalTerminal", "console": "externalTerminal",
"stopOnEntry": false, "stopOnEntry": false,
"mainClass": "com.microsoft.aspnet.signalr.Chat", "mainClass": "com.microsoft.aspnet.signalr.sample.Chat",
"args": "" "args": ""
}, },
{ {

View File

@ -16,7 +16,9 @@ repositories {
} }
dependencies { dependencies {
testImplementation group: 'junit', name: 'junit', version: '4.12' testImplementation 'org.junit.jupiter:junit-jupiter-api:5.3.1'
testCompile 'org.junit.jupiter:junit-jupiter-params:5.3.1'
testRuntime 'org.junit.jupiter:junit-jupiter-engine:5.3.1'
implementation "org.java-websocket:Java-WebSocket:1.3.8" implementation "org.java-websocket:Java-WebSocket:1.3.8"
implementation 'com.google.code.gson:gson:2.8.5' implementation 'com.google.code.gson:gson:2.8.5'
implementation 'com.squareup.okhttp3:okhttp:3.11.0' implementation 'com.squareup.okhttp3:okhttp:3.11.0'
@ -41,6 +43,10 @@ spotless {
} }
} }
test {
useJUnitPlatform()
}
task sourceJar(type: Jar) { task sourceJar(type: Jar) {
classifier "sources" classifier "sources"
from sourceSets.main.allJava from sourceSets.main.allJava

View File

@ -5,9 +5,9 @@ package com.microsoft.aspnet.signalr;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.List; import java.util.List;
import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentHashMap;
import java.util.Collections;
class CallbackMap { class CallbackMap {
private ConcurrentHashMap<String, List<InvocationHandler>> handlers = new ConcurrentHashMap<>(); private ConcurrentHashMap<String, List<InvocationHandler>> handlers = new ConcurrentHashMap<>();

View File

@ -4,7 +4,7 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
class CloseMessage extends HubMessage { class CloseMessage extends HubMessage {
String error; private String error;
@Override @Override
public HubMessageType getMessageType() { public HubMessageType getMessageType() {

View File

@ -0,0 +1,38 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
package com.microsoft.aspnet.signalr;
class CompletionMessage extends HubMessage {
private int type = HubMessageType.COMPLETION.value;
private String invocationId;
private Object result;
private String error;
public CompletionMessage(String invocationId, Object result, String error) {
if (error != null && result != null)
{
throw new IllegalArgumentException("Expected either 'error' or 'result' to be provided, but not both");
}
this.invocationId = invocationId;
this.result = result;
this.error = error;
}
public Object getResult() {
return result;
}
public String getError() {
return error;
}
public String getInvocationId() {
return invocationId;
}
@Override
public HubMessageType getMessageType() {
return HubMessageType.values()[type - 1];
}
}

View File

@ -3,10 +3,15 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer; import java.util.function.Consumer;
public class HubConnection { public class HubConnection {
@ -18,6 +23,7 @@ public class HubConnection {
private Boolean handshakeReceived = false; private Boolean handshakeReceived = false;
private static final String RECORD_SEPARATOR = "\u001e"; private static final String RECORD_SEPARATOR = "\u001e";
private HubConnectionState hubConnectionState = HubConnectionState.DISCONNECTED; private HubConnectionState hubConnectionState = HubConnectionState.DISCONNECTED;
private Lock hubConnectionStateLock = new ReentrantLock();
private Logger logger; private Logger logger;
private List<Consumer<Exception>> onClosedCallbackList; private List<Consumer<Exception>> onClosedCallbackList;
private boolean skipNegotiate = false; private boolean skipNegotiate = false;
@ -66,13 +72,13 @@ public class HubConnection {
switch (message.getMessageType()) { switch (message.getMessageType()) {
case INVOCATION: case INVOCATION:
InvocationMessage invocationMessage = (InvocationMessage) message; InvocationMessage invocationMessage = (InvocationMessage) message;
List<InvocationHandler> handlers = this.handlers.get(invocationMessage.target); List<InvocationHandler> handlers = this.handlers.get(invocationMessage.getTarget());
if (handlers != null) { if (handlers != null) {
for (InvocationHandler handler : handlers) { for (InvocationHandler handler : handlers) {
handler.getAction().invoke(invocationMessage.arguments); handler.getAction().invoke(invocationMessage.getArguments());
} }
} else { } else {
logger.log(LogLevel.Warning, "Failed to find handler for %s method.", invocationMessage.target); logger.log(LogLevel.Warning, "Failed to find handler for %s method.", invocationMessage.getMessageType());
} }
break; break;
case CLOSE: case CLOSE:
@ -83,10 +89,18 @@ public class HubConnection {
case PING: case PING:
// We don't need to do anything in the case of a ping message. // We don't need to do anything in the case of a ping message.
break; break;
case COMPLETION:
CompletionMessage completionMessage = (CompletionMessage)message;
InvocationRequest irq = connectionState.tryRemoveInvocation(completionMessage.getInvocationId());
if (irq == null) {
logger.log(LogLevel.Warning, "Dropped unsolicited Completion message for invocation '%s'.", completionMessage.getInvocationId());
continue;
}
irq.complete(completionMessage);
break;
case STREAM_INVOCATION: case STREAM_INVOCATION:
case STREAM_ITEM: case STREAM_ITEM:
case CANCEL_INVOCATION: case CANCEL_INVOCATION:
case COMPLETION:
logger.log(LogLevel.Error, "This client does not support %s messages.", message.getMessageType()); logger.log(LogLevel.Error, "This client does not support %s messages.", message.getMessageType());
throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", message.getMessageType())); throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", message.getMessageType()));
@ -99,6 +113,29 @@ public class HubConnection {
} }
} }
private NegotiateResponse handleNegotiate() throws IOException {
accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken();
negotiateResponse = Negotiate.processNegotiate(url, accessToken);
if (negotiateResponse.getConnectionId() != null) {
if (url.contains("?")) {
url = url + "&id=" + negotiateResponse.getConnectionId();
} else {
url = url + "?id=" + negotiateResponse.getConnectionId();
}
}
if (negotiateResponse.getAccessToken() != null) {
this.headers.put("Authorization", "Bearer " + negotiateResponse.getAccessToken());
}
if (negotiateResponse.getRedirectUrl() != null) {
this.url = this.negotiateResponse.getRedirectUrl();
}
return negotiateResponse;
}
public HubConnection(String url, Transport transport, Logger logger) { public HubConnection(String url, Transport transport, Logger logger) {
this(url, transport, logger, false); this(url, transport, logger, false);
} }
@ -150,32 +187,15 @@ public class HubConnection {
* *
* @throws Exception An error occurred while connecting. * @throws Exception An error occurred while connecting.
*/ */
public void start() throws Exception { public CompletableFuture start() throws Exception {
if (hubConnectionState != HubConnectionState.DISCONNECTED) { if (hubConnectionState != HubConnectionState.DISCONNECTED) {
return; return CompletableFuture.completedFuture(null);
} }
if (!skipNegotiate) { if (!skipNegotiate) {
int negotiateAttempts = 0; int negotiateAttempts = 0;
do { do {
accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken(); accessToken = (negotiateResponse == null) ? null : negotiateResponse.getAccessToken();
negotiateResponse = Negotiate.processNegotiate(url, accessToken); negotiateResponse = handleNegotiate();
if (negotiateResponse.getConnectionId() != null) {
if (url.contains("?")) {
url = url + "&id=" + negotiateResponse.getConnectionId();
} else {
url = url + "?id=" + negotiateResponse.getConnectionId();
}
}
if (negotiateResponse.getAccessToken() != null) {
this.headers.put("Authorization", "Bearer " + negotiateResponse.getAccessToken());
}
if (negotiateResponse.getRedirectUrl() != null) {
url = this.negotiateResponse.getRedirectUrl();
}
negotiateAttempts++; negotiateAttempts++;
} while (negotiateResponse.getRedirectUrl() != null && negotiateAttempts < MAX_NEGOTIATE_ATTEMPTS); } while (negotiateResponse.getRedirectUrl() != null && negotiateAttempts < MAX_NEGOTIATE_ATTEMPTS);
if (!negotiateResponse.getAvailableTransports().contains("WebSockets")) { if (!negotiateResponse.getAvailableTransports().contains("WebSockets")) {
@ -189,34 +209,53 @@ public class HubConnection {
} }
transport.setOnReceive(this.callback); transport.setOnReceive(this.callback);
transport.start(); return transport.start().thenCompose((future) -> {
String handshake = HandshakeProtocol.createHandshakeRequestMessage(new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); String handshake = HandshakeProtocol.createHandshakeRequestMessage(new HandshakeRequestMessage(protocol.getName(), protocol.getVersion()));
transport.send(handshake); return transport.send(handshake).thenRun(() -> {
hubConnectionState = HubConnectionState.CONNECTED; hubConnectionStateLock.lock();
connectionState = new ConnectionState(this); try {
logger.log(LogLevel.Information, "HubConnected started."); hubConnectionState = HubConnectionState.CONNECTED;
connectionState = new ConnectionState(this);
logger.log(LogLevel.Information, "HubConnected started.");
} finally {
hubConnectionStateLock.unlock();
}
});
});
} }
/** /**
* Stops a connection to the server. * Stops a connection to the server.
*/ */
private void stop(String errorMessage) { private void stop(String errorMessage) {
if (hubConnectionState == HubConnectionState.DISCONNECTED) { HubException hubException = null;
return; hubConnectionStateLock.lock();
try {
if (hubConnectionState == HubConnectionState.DISCONNECTED) {
return;
}
if (errorMessage != null) {
logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage);
} else {
logger.log(LogLevel.Debug, "Stopping HubConnection.");
}
transport.stop();
hubConnectionState = HubConnectionState.DISCONNECTED;
if (errorMessage != null) {
hubException = new HubException(errorMessage);
}
connectionState.cancelOutstandingInvocations(hubException);
connectionState = null;
logger.log(LogLevel.Information, "HubConnection stopped.");
} finally {
hubConnectionStateLock.unlock();
} }
if (errorMessage != null) {
logger.log(LogLevel.Error, "HubConnection disconnected with an error %s.", errorMessage);
} else {
logger.log(LogLevel.Debug, "Stopping HubConnection.");
}
transport.stop();
hubConnectionState = HubConnectionState.DISCONNECTED;
connectionState = null;
logger.log(LogLevel.Information, "HubConnection stopped.");
if (onClosedCallbackList != null) { if (onClosedCallbackList != null) {
HubException hubException = new HubException(errorMessage);
for (Consumer<Exception> callback : onClosedCallbackList) { for (Consumer<Exception> callback : onClosedCallbackList) {
callback.accept(hubException); callback.accept(hubException);
} }
@ -243,10 +282,67 @@ public class HubConnection {
throw new HubException("The 'send' method cannot be called if the connection is not active"); throw new HubException("The 'send' method cannot be called if the connection is not active");
} }
InvocationMessage invocationMessage = new InvocationMessage(method, args); InvocationMessage invocationMessage = new InvocationMessage(null, method, args);
String message = protocol.writeMessage(invocationMessage); sendHubMessage(invocationMessage);
logger.log(LogLevel.Debug, "Sending message"); }
transport.send(message);
public <T> CompletableFuture<T> invoke(Class<T> returnType, String method, Object... args) throws Exception {
String id = connectionState.getNextInvocationId();
InvocationMessage invocationMessage = new InvocationMessage(id, method, args);
CompletableFuture<T> future = new CompletableFuture<>();
InvocationRequest irq = new InvocationRequest(returnType, id);
connectionState.addInvocation(irq);
// forward the invocation result or error to the user
// run continuations on a separate thread
CompletableFuture<Object> pendingCall = irq.getPendingCall();
pendingCall.whenCompleteAsync((result, error) -> {
if (error == null) {
// Primitive types can't be cast with the Class cast function
if (returnType.isPrimitive()) {
future.complete((T)result);
} else {
future.complete(returnType.cast(result));
}
} else {
future.completeExceptionally(error);
}
});
// Make sure the actual send is after setting up the future otherwise there is a race
// where the map doesn't have the future yet when the response is returned
sendHubMessage(invocationMessage);
return future;
}
private void sendHubMessage(HubMessage message) throws Exception {
String serializedMessage = protocol.writeMessage(message);
if (message.getMessageType() == HubMessageType.INVOCATION) {
logger.log(LogLevel.Debug, "Sending %d message '%s'.", message.getMessageType().value, ((InvocationMessage)message).getInvocationId());
} else {
logger.log(LogLevel.Debug, "Sending %d message.", message.getMessageType().value);
}
transport.send(serializedMessage);
}
/**
* Removes all handlers associated with the method with the specified method name.
*
* @param name The name of the hub method from which handlers are being removed.
*/
public void remove(String name) {
handlers.remove(name);
logger.log(LogLevel.Trace, "Removing handlers for client method %s", name);
}
public void onClosed(Consumer<Exception> callback) {
if (onClosedCallbackList == null) {
onClosedCallbackList = new ArrayList<>();
}
onClosedCallbackList.add(callback);
} }
/** /**
@ -515,34 +611,80 @@ public class HubConnection {
return new Subscription(handlers, handler, target); return new Subscription(handlers, handler, target);
} }
/**
* Removes all handlers associated with the method with the specified method name.
*
* @param name The name of the hub method from which handlers are being removed.
*/
public void remove(String name) {
handlers.remove(name);
logger.log(LogLevel.Trace, "Removing handlers for client method %s", name);
}
public void onClosed(Consumer<Exception> callback) {
if (onClosedCallbackList == null) {
onClosedCallbackList = new ArrayList<>();
}
onClosedCallbackList.add(callback);
}
private class ConnectionState implements InvocationBinder { private class ConnectionState implements InvocationBinder {
HubConnection connection; private HubConnection connection;
private AtomicInteger nextId = new AtomicInteger(0);
private HashMap<String, InvocationRequest> pendingInvocations = new HashMap<>();
private Lock lock = new ReentrantLock();
public ConnectionState(HubConnection connection) { public ConnectionState(HubConnection connection) {
this.connection = connection; this.connection = connection;
} }
public String getNextInvocationId() {
int i = nextId.incrementAndGet();
return Integer.toString(i);
}
public void cancelOutstandingInvocations(Exception ex) {
lock.lock();
try {
pendingInvocations.forEach((key, irq) -> {
if (ex == null) {
irq.cancel();
} else {
irq.fail(ex);
}
});
pendingInvocations.clear();
} finally {
lock.unlock();
}
}
public void addInvocation(InvocationRequest irq) {
lock.lock();
try {
pendingInvocations.compute(irq.getInvocationId(), (key, value) -> {
if (value != null) {
// This should never happen
throw new IllegalStateException("Invocation Id is already used");
}
return irq;
});
} finally {
lock.unlock();
}
}
public InvocationRequest getInvocation(String id) {
lock.lock();
try {
return pendingInvocations.get(id);
} finally {
lock.unlock();
}
}
public InvocationRequest tryRemoveInvocation(String id) {
lock.lock();
try {
return pendingInvocations.remove(id);
} finally {
lock.unlock();
}
}
@Override @Override
public Class<?> getReturnType(String invocationId) { public Class<?> getReturnType(String invocationId) {
return null; InvocationRequest irq = getInvocation(invocationId);
if (irq == null) {
return null;
}
return irq.getReturnType();
} }
@Override @Override

View File

@ -3,8 +3,6 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import java.io.IOException;
/** /**
* A protocol abstraction for communicating with SignalR hubs. * A protocol abstraction for communicating with SignalR hubs.
*/ */

View File

@ -5,11 +5,12 @@ package com.microsoft.aspnet.signalr;
class InvocationMessage extends HubMessage { class InvocationMessage extends HubMessage {
int type = HubMessageType.INVOCATION.value; int type = HubMessageType.INVOCATION.value;
String invocationId; protected String invocationId;
String target; private String target;
Object[] arguments; private Object[] arguments;
public InvocationMessage(String target, Object[] args) { public InvocationMessage(String invocationId, String target, Object[] args) {
this.invocationId = invocationId;
this.target = target; this.target = target;
this.arguments = args; this.arguments = args;
} }
@ -18,10 +19,6 @@ class InvocationMessage extends HubMessage {
return invocationId; return invocationId;
} }
public void setInvocationId(String invocationId) {
this.invocationId = invocationId;
}
public String getTarget() { public String getTarget() {
return target; return target;
} }

View File

@ -0,0 +1,45 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
package com.microsoft.aspnet.signalr;
import java.util.concurrent.CompletableFuture;
class InvocationRequest {
private Class<?> returnType;
private CompletableFuture<Object> pendingCall = new CompletableFuture<>();
private String invocationId;
InvocationRequest(Class<?> returnType, String invocationId) {
this.returnType = returnType;
this.invocationId = invocationId;
}
public void complete(CompletionMessage completion) {
if (completion.getResult() != null) {
pendingCall.complete(completion.getResult());
} else {
pendingCall.completeExceptionally(new HubException(completion.getError()));
}
}
public void fail(Exception ex) {
pendingCall.completeExceptionally(ex);
}
public void cancel() {
pendingCall.cancel(false);
}
public CompletableFuture<Object> getPendingCall() {
return pendingCall;
}
public Class<?> getReturnType() {
return returnType;
}
public String getInvocationId() {
return invocationId;
}
}

View File

@ -3,16 +3,15 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import java.io.IOException;
import java.io.StringReader; import java.io.StringReader;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
import com.google.gson.Gson; import com.google.gson.Gson;
import com.google.gson.JsonArray; import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser; import com.google.gson.JsonParser;
import com.google.gson.stream.JsonReader; import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonToken;
class JsonHubProtocol implements HubProtocol { class JsonHubProtocol implements HubProtocol {
private final JsonParser jsonParser = new JsonParser(); private final JsonParser jsonParser = new JsonParser();
@ -31,7 +30,7 @@ class JsonHubProtocol implements HubProtocol {
@Override @Override
public TransferFormat getTransferFormat() { public TransferFormat getTransferFormat() {
return TransferFormat.Text; return TransferFormat.TEXT;
} }
@Override @Override
@ -45,6 +44,8 @@ class JsonHubProtocol implements HubProtocol {
String error = null; String error = null;
ArrayList<Object> arguments = null; ArrayList<Object> arguments = null;
JsonArray argumentsToken = null; JsonArray argumentsToken = null;
Object result = null;
JsonElement resultToken = null;
JsonReader reader = new JsonReader(new StringReader(str)); JsonReader reader = new JsonReader(new StringReader(str));
reader.beginObject(); reader.beginObject();
@ -65,7 +66,11 @@ class JsonHubProtocol implements HubProtocol {
error = reader.nextString(); error = reader.nextString();
break; break;
case "result": case "result":
reader.skipValue(); if (invocationId == null) {
resultToken = jsonParser.parse(reader);
} else {
result = gson.fromJson(reader, binder.getReturnType(invocationId));
}
break; break;
case "item": case "item":
reader.skipValue(); reader.skipValue();
@ -109,18 +114,23 @@ class JsonHubProtocol implements HubProtocol {
} }
} }
if (arguments == null) { if (arguments == null) {
hubMessages.add(new InvocationMessage(target, new Object[0])); hubMessages.add(new InvocationMessage(invocationId, target, new Object[0]));
} else { } else {
hubMessages.add(new InvocationMessage(target, arguments.toArray())); hubMessages.add(new InvocationMessage(invocationId, target, arguments.toArray()));
} }
break; break;
case COMPLETION:
if (resultToken != null) {
result = gson.fromJson(resultToken, binder.getReturnType(invocationId));
}
hubMessages.add(new CompletionMessage(invocationId, result, error));
break;
case STREAM_INVOCATION: case STREAM_INVOCATION:
case STREAM_ITEM: case STREAM_ITEM:
case COMPLETION:
case CANCEL_INVOCATION: case CANCEL_INVOCATION:
throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", messageType)); throw new UnsupportedOperationException(String.format("The message type %s is not supported yet.", messageType));
case PING: case PING:
hubMessages.add(new PingMessage()); hubMessages.add(PingMessage.getInstance());
break; break;
case CLOSE: case CLOSE:
if (error != null) { if (error != null) {

View File

@ -3,12 +3,18 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
class PingMessage extends HubMessage { class PingMessage extends HubMessage
{
private static PingMessage instance = new PingMessage();
int type = HubMessageType.PING.value; private PingMessage()
{
}
public static PingMessage getInstance() {return instance;}
@Override @Override
public HubMessageType getMessageType() { public HubMessageType getMessageType() {
return HubMessageType.PING; return HubMessageType.PING;
} }
} }

View File

@ -8,8 +8,7 @@ class StreamInvocationMessage extends InvocationMessage {
int type = HubMessageType.STREAM_INVOCATION.value; int type = HubMessageType.STREAM_INVOCATION.value;
public StreamInvocationMessage(String invocationId, String target, Object[] arguments) { public StreamInvocationMessage(String invocationId, String target, Object[] arguments) {
super(target, arguments); super(invocationId, target, arguments);
this.invocationId = invocationId;
} }
@Override @Override

View File

@ -4,6 +4,6 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
public enum TransferFormat { public enum TransferFormat {
Text, TEXT,
Binary BINARY
} }

View File

@ -3,10 +3,12 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import java.util.concurrent.CompletableFuture;
interface Transport { interface Transport {
void start() throws Exception; CompletableFuture start() throws Exception;
void send(String message) throws Exception; CompletableFuture send(String message);
void setOnReceive(OnReceiveCallBack callback); void setOnReceive(OnReceiveCallBack callback);
void onReceive(String message) throws Exception; void onReceive(String message) throws Exception;
void stop(); CompletableFuture stop();
} }

View File

@ -6,6 +6,7 @@ package com.microsoft.aspnet.signalr;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture;
import org.java_websocket.client.WebSocketClient; import org.java_websocket.client.WebSocketClient;
import org.java_websocket.handshake.ServerHandshake; import org.java_websocket.handshake.ServerHandshake;
@ -47,21 +48,28 @@ class WebSocketTransport implements Transport {
} }
@Override @Override
public void start() throws Exception { public CompletableFuture start() {
logger.log(LogLevel.Debug, "Starting Websocket connection."); return CompletableFuture.runAsync(() -> {
webSocketClient = createWebSocket(headers); logger.log(LogLevel.Debug, "Starting Websocket connection.");
webSocketClient = createWebSocket(headers);
if (!webSocketClient.connectBlocking()) { try {
String errorMessage = "There was an error starting the Websockets transport."; if (!webSocketClient.connectBlocking()) {
logger.log(LogLevel.Debug, errorMessage); String errorMessage = "There was an error starting the Websockets transport.";
throw new Exception(errorMessage); logger.log(LogLevel.Debug, errorMessage);
} throw new RuntimeException(errorMessage);
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI()); }
} catch (InterruptedException e) {
String interruptedExMessage = "Connecting the Websockets transport was interrupted.";
logger.log(LogLevel.Debug, interruptedExMessage);
throw new RuntimeException(interruptedExMessage);
}
logger.log(LogLevel.Information, "WebSocket transport connected to: %s", webSocketClient.getURI());
});
} }
@Override @Override
public void send(String message) { public CompletableFuture send(String message) {
webSocketClient.send(message); return CompletableFuture.runAsync(() -> webSocketClient.send(message));
} }
@Override @Override
@ -76,9 +84,11 @@ class WebSocketTransport implements Transport {
} }
@Override @Override
public void stop() { public CompletableFuture stop() {
webSocketClient.closeConnection(0, "HubConnection Stopped"); return CompletableFuture.runAsync(() -> {
logger.log(LogLevel.Information, "WebSocket connection stopped"); webSocketClient.closeConnection(0, "HubConnection Stopped");
logger.log(LogLevel.Information, "WebSocket connection stopped");
});
} }
private WebSocketClient createWebSocket(Map<String, String> headers) { private WebSocketClient createWebSocket(Map<String, String> headers) {

View File

@ -3,9 +3,9 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
public class HandshakeProtocolTest { public class HandshakeProtocolTest {

View File

@ -3,22 +3,20 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import org.junit.Rule; import org.junit.jupiter.api.Test;
import org.junit.Test;
import org.junit.rules.ExpectedException;
public class HubConnectionTest { public class HubConnectionTest {
private static final String RECORD_SEPARATOR = "\u001e"; private static final String RECORD_SEPARATOR = "\u001e";
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
@Test @Test
public void checkHubConnectionState() throws Exception { public void checkHubConnectionState() throws Exception {
Transport mockTransport = new MockTransport(); Transport mockTransport = new MockTransport();
@ -47,14 +45,12 @@ public class HubConnectionTest {
@Test @Test
public void hubConnectionReceiveHandshakeResponseWithError() throws Exception { public void hubConnectionReceiveHandshakeResponseWithError() throws Exception {
exceptionRule.expect(HubException.class);
exceptionRule.expectMessage("Requested protocol 'messagepack' is not available.");
MockTransport mockTransport = new MockTransport(); MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.start(); hubConnection.start();
mockTransport.receiveMessage("{\"error\":\"Requested protocol 'messagepack' is not available.\"}" + RECORD_SEPARATOR); Throwable exception = assertThrows(HubException.class, () -> mockTransport.receiveMessage("{\"error\":\"Requested protocol 'messagepack' is not available.\"}" + RECORD_SEPARATOR));
assertEquals("Error in handshake Requested protocol 'messagepack' is not available.", exception.getMessage());
} }
@Test @Test
@ -67,7 +63,7 @@ public class HubConnectionTest {
hubConnection.on("inc", action); hubConnection.on("inc", action);
hubConnection.on("inc", action); hubConnection.on("inc", action);
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
@ -80,7 +76,7 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(2, value.get(), 0); assertEquals(Double.valueOf(2), value.get());
} }
@Test @Test
@ -92,7 +88,7 @@ public class HubConnectionTest {
hubConnection.on("inc", action); hubConnection.on("inc", action);
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
String message = mockTransport.getSentMessages()[0]; String message = mockTransport.getSentMessages()[0];
@ -104,10 +100,10 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
hubConnection.remove("inc"); hubConnection.remove("inc");
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
} }
@Test @Test
@ -120,7 +116,7 @@ public class HubConnectionTest {
hubConnection.on("inc", action); hubConnection.on("inc", action);
hubConnection.remove("inc"); hubConnection.remove("inc");
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
String message = mockTransport.getSentMessages()[0]; String message = mockTransport.getSentMessages()[0];
@ -132,7 +128,7 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that the handler was removed. // Confirming that the handler was removed.
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
} }
@Test @Test
@ -146,7 +142,7 @@ public class HubConnectionTest {
hubConnection.on("inc", action); hubConnection.on("inc", action);
hubConnection.on("inc", secondAction); hubConnection.on("inc", secondAction);
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
String message = mockTransport.getSentMessages()[0]; String message = mockTransport.getSentMessages()[0];
@ -157,14 +153,14 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
assertEquals(3, value.get(), 0); assertEquals(Double.valueOf(3), value.get());
hubConnection.remove("inc"); hubConnection.remove("inc");
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirm that another invocation doesn't change anything because the handlers have been removed. // Confirm that another invocation doesn't change anything because the handlers have been removed.
assertEquals(3, value.get(), 0); assertEquals(Double.valueOf(3), value.get());
} }
@Test @Test
@ -176,7 +172,7 @@ public class HubConnectionTest {
Subscription subscription = hubConnection.on("inc", action); Subscription subscription = hubConnection.on("inc", action);
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
String message = mockTransport.getSentMessages()[0]; String message = mockTransport.getSentMessages()[0];
@ -188,7 +184,7 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
subscription.unsubscribe(); subscription.unsubscribe();
try { try {
@ -197,7 +193,7 @@ public class HubConnectionTest {
assertEquals("There are no callbacks registered for the method 'inc'.", ex.getMessage()); assertEquals("There are no callbacks registered for the method 'inc'.", ex.getMessage());
} }
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
} }
@Test @Test
@ -209,7 +205,7 @@ public class HubConnectionTest {
Subscription subscription = hubConnection.on("inc", action); Subscription subscription = hubConnection.on("inc", action);
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
String message = mockTransport.getSentMessages()[0]; String message = mockTransport.getSentMessages()[0];
@ -221,7 +217,7 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
subscription.unsubscribe(); subscription.unsubscribe();
subscription.unsubscribe(); subscription.unsubscribe();
@ -231,7 +227,7 @@ public class HubConnectionTest {
assertEquals("There are no callbacks registered for the method 'inc'.", ex.getMessage()); assertEquals("There are no callbacks registered for the method 'inc'.", ex.getMessage());
} }
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
} }
@Test @Test
@ -245,7 +241,7 @@ public class HubConnectionTest {
Subscription subscription = hubConnection.on("inc", action); Subscription subscription = hubConnection.on("inc", action);
Subscription secondSubscription = hubConnection.on("inc", secondAction); Subscription secondSubscription = hubConnection.on("inc", secondAction);
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
String message = mockTransport.getSentMessages()[0]; String message = mockTransport.getSentMessages()[0];
@ -256,12 +252,12 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(3, value.get(), 0); assertEquals(Double.valueOf(3), value.get());
// This removes the first handler so when "inc" is invoked secondAction should still run. // This removes the first handler so when "inc" is invoked secondAction should still run.
subscription.unsubscribe(); subscription.unsubscribe();
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
assertEquals(5, value.get(), 0); assertEquals(Double.valueOf(5), value.get());
} }
@Test @Test
@ -274,7 +270,7 @@ public class HubConnectionTest {
Subscription sub = hubConnection.on("inc", action); Subscription sub = hubConnection.on("inc", action);
sub.unsubscribe(); sub.unsubscribe();
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
@ -286,7 +282,7 @@ public class HubConnectionTest {
} }
// Confirming that the handler was removed. // Confirming that the handler was removed.
assertEquals(0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
} }
@Test @Test
@ -300,25 +296,129 @@ public class HubConnectionTest {
hubConnection.on("add", action, Double.class); hubConnection.on("add", action, Double.class);
hubConnection.on("add", action, Double.class); hubConnection.on("add", action, Double.class);
assertEquals(0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
hubConnection.start(); hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
mockTransport.receiveMessage("{\"type\":1,\"target\":\"add\",\"arguments\":[12]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"add\",\"arguments\":[12]}" + RECORD_SEPARATOR);
hubConnection.send("add", 12);
// Confirming that our handler was called and the correct message was passed in. // Confirming that our handler was called and the correct message was passed in.
assertEquals(24, value.get(), 0); assertEquals(Double.valueOf(24), value.get());
}
@Test
public void invokeWaitsForCompletionMessage() throws Exception {
MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
CompletableFuture<Integer> result = hubConnection.invoke(Integer.class, "echo", "message");
assertEquals("{\"type\":1,\"invocationId\":\"1\",\"target\":\"echo\",\"arguments\":[\"message\"]}" + RECORD_SEPARATOR, mockTransport.sentMessages.get(1));
assertFalse(result.isDone());
mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"result\":42}" + RECORD_SEPARATOR);
assertEquals(Integer.valueOf(42), result.get(1000L, TimeUnit.MILLISECONDS));
}
@Test
public void multipleInvokesWaitForOwnCompletionMessage() throws Exception {
MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
CompletableFuture<Integer> result = hubConnection.invoke(Integer.class, "echo", "message");
CompletableFuture<String> result2 = hubConnection.invoke(String.class, "echo", "message");
assertEquals("{\"type\":1,\"invocationId\":\"1\",\"target\":\"echo\",\"arguments\":[\"message\"]}" + RECORD_SEPARATOR, mockTransport.sentMessages.get(1));
assertEquals("{\"type\":1,\"invocationId\":\"2\",\"target\":\"echo\",\"arguments\":[\"message\"]}" + RECORD_SEPARATOR, mockTransport.sentMessages.get(2));
assertFalse(result.isDone());
assertFalse(result2.isDone());
mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"2\",\"result\":\"message\"}" + RECORD_SEPARATOR);
assertEquals("message", result2.get(1000L, TimeUnit.MILLISECONDS));
assertFalse(result.isDone());
mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"result\":42}" + RECORD_SEPARATOR);
assertEquals(Integer.valueOf(42), result.get(1000L, TimeUnit.MILLISECONDS));
}
@Test
public void invokeWorksForPrimitiveTypes() throws Exception {
MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
// int.class is a primitive type and since we use Class.cast to cast an Object to the expected return type
// which does not work for primitives we have to write special logic for that case.
CompletableFuture<Integer> result = hubConnection.invoke(int.class, "echo", "message");
assertFalse(result.isDone());
mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"result\":42}" + RECORD_SEPARATOR);
assertEquals(Integer.valueOf(42), result.get(1000L, TimeUnit.MILLISECONDS));
}
@Test
public void completionMessageCanHaveError() throws Exception {
MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
CompletableFuture<Integer> result = hubConnection.invoke(int.class, "echo", "message");
assertFalse(result.isDone());
mockTransport.receiveMessage("{\"type\":3,\"invocationId\":\"1\",\"error\":\"There was an error\"}" + RECORD_SEPARATOR);
String exceptionMessage = null;
try {
result.get(1000L, TimeUnit.MILLISECONDS);
assertFalse(true);
} catch (Exception ex) {
exceptionMessage = ex.getMessage();
}
assertEquals("com.microsoft.aspnet.signalr.HubException: There was an error", exceptionMessage);
}
@Test
public void stopCancelsActiveInvokes() throws Exception {
MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.start();
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR);
CompletableFuture<Integer> result = hubConnection.invoke(int.class, "echo", "message");
assertFalse(result.isDone());
hubConnection.stop();
boolean hasException = false;
try {
result.get(1000L, TimeUnit.MILLISECONDS);
assertFalse(true);
} catch (CancellationException ex) {
hasException = true;
}
assertTrue(hasException);
} }
// We're using AtomicReference<Double> in the send tests instead of int here because Gson has trouble deserializing to Integer
@Test @Test
public void sendWithNoParamsTriggersOnHandler() throws Exception { public void sendWithNoParamsTriggersOnHandler() throws Exception {
AtomicReference<Double> value = new AtomicReference<Double>(0.0); AtomicReference<Integer> value = new AtomicReference<>(0);
MockTransport mockTransport = new MockTransport(); MockTransport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.on("inc", () ->{ hubConnection.on("inc", () ->{
assertEquals(0.0, value.get(), 0); assertEquals(Integer.valueOf(0), value.get());
value.getAndUpdate((val) -> val + 1); value.getAndUpdate((val) -> val + 1);
}); });
@ -327,7 +427,7 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(1, value.get(), 0); assertEquals(Integer.valueOf(1), value.get());
} }
@Test @Test
@ -373,7 +473,7 @@ public class HubConnectionTest {
// Confirming that our handler was called and the correct message was passed in. // Confirming that our handler was called and the correct message was passed in.
assertEquals("Hello World", value1.get()); assertEquals("Hello World", value1.get());
assertEquals(12, value2.get(), 0); assertEquals(Double.valueOf(12), value2.get());
} }
@Test @Test
@ -473,7 +573,7 @@ public class HubConnectionTest {
assertEquals("B", value2.get()); assertEquals("B", value2.get());
assertEquals("C", value3.get()); assertEquals("C", value3.get());
assertTrue(value4.get()); assertTrue(value4.get());
assertEquals(12, value5.get(), 0); assertEquals(Double.valueOf(12), value5.get());
} }
@Test @Test
@ -513,7 +613,7 @@ public class HubConnectionTest {
assertEquals("B", value2.get()); assertEquals("B", value2.get());
assertEquals("C", value3.get()); assertEquals("C", value3.get());
assertTrue(value4.get()); assertTrue(value4.get());
assertEquals(12, value5.get(), 0); assertEquals(Double.valueOf(12), value5.get());
assertEquals("D", value6.get()); assertEquals("D", value6.get());
} }
@ -557,7 +657,7 @@ public class HubConnectionTest {
assertEquals("B", value2.get()); assertEquals("B", value2.get());
assertEquals("C", value3.get()); assertEquals("C", value3.get());
assertTrue(value4.get()); assertTrue(value4.get());
assertEquals(12, value5.get(), 0); assertEquals(Double.valueOf(12), value5.get());
assertEquals("D", value6.get()); assertEquals("D", value6.get());
assertEquals("E", value7.get()); assertEquals("E", value7.get());
} }
@ -604,7 +704,7 @@ public class HubConnectionTest {
assertEquals("B", value2.get()); assertEquals("B", value2.get());
assertEquals("C", value3.get()); assertEquals("C", value3.get());
assertTrue(value4.get()); assertTrue(value4.get());
assertEquals(12, value5.get(), 0); assertEquals(Double.valueOf(12), value5.get());
assertEquals("D", value6.get()); assertEquals("D", value6.get());
assertEquals("E", value7.get()); assertEquals("E", value7.get());
assertEquals("F", value8.get()); assertEquals("F", value8.get());
@ -649,7 +749,7 @@ public class HubConnectionTest {
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport, true);
hubConnection.on("inc", () ->{ hubConnection.on("inc", () ->{
assertEquals(0.0, value.get(), 0); assertEquals(Double.valueOf(0), value.get());
value.getAndUpdate((val) -> val + 1); value.getAndUpdate((val) -> val + 1);
}); });
@ -661,7 +761,7 @@ public class HubConnectionTest {
mockTransport.receiveMessage("{}" + RECORD_SEPARATOR + "{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR); mockTransport.receiveMessage("{}" + RECORD_SEPARATOR + "{\"type\":1,\"target\":\"inc\",\"arguments\":[]}" + RECORD_SEPARATOR);
// Confirming that our handler was called and that the counter property was incremented. // Confirming that our handler was called and that the counter property was incremented.
assertEquals(1, value.get(), 0); assertEquals(Double.valueOf(1), value.get());
} }
@Test @Test
@ -740,14 +840,12 @@ public class HubConnectionTest {
@Test @Test
public void cannotSendBeforeStart() throws Exception { public void cannotSendBeforeStart() throws Exception {
exceptionRule.expect(HubException.class);
exceptionRule.expectMessage("The 'send' method cannot be called if the connection is not active");
Transport mockTransport = new MockTransport(); Transport mockTransport = new MockTransport();
HubConnection hubConnection = new HubConnection("http://example.com", mockTransport); HubConnection hubConnection = new HubConnection("http://example.com", mockTransport);
assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState());
hubConnection.send("inc"); Throwable exception = assertThrows(HubException.class, () -> hubConnection.send("inc"));
assertEquals("The 'send' method cannot be called if the connection is not active", exception.getMessage());
} }
private class MockTransport implements Transport { private class MockTransport implements Transport {
@ -755,11 +853,14 @@ public class HubConnectionTest {
private ArrayList<String> sentMessages = new ArrayList<>(); private ArrayList<String> sentMessages = new ArrayList<>();
@Override @Override
public void start() {} public CompletableFuture start() {
return CompletableFuture.completedFuture(null);
}
@Override @Override
public void send(String message) { public CompletableFuture send(String message) {
sentMessages.add(message); sentMessages.add(message);
return CompletableFuture.completedFuture(null);
} }
@Override @Override
@ -773,7 +874,9 @@ public class HubConnectionTest {
} }
@Override @Override
public void stop() {} public CompletableFuture stop() {
return CompletableFuture.completedFuture(null);
}
public void receiveMessage(String message) throws Exception { public void receiveMessage(String message) throws Exception {
this.onReceive(message); this.onReceive(message);

View File

@ -3,9 +3,9 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import org.junit.Test; import org.junit.jupiter.api.Test;
public class HubExceptionTest { public class HubExceptionTest {
@Test @Test

View File

@ -3,18 +3,14 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
import java.util.concurrent.PriorityBlockingQueue;
import org.junit.Rule; import org.junit.jupiter.api.Test;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import com.google.gson.JsonArray;
public class JsonHubProtocolTest { public class JsonHubProtocolTest {
private JsonHubProtocol jsonHubProtocol = new JsonHubProtocol(); private JsonHubProtocol jsonHubProtocol = new JsonHubProtocol();
@ -31,12 +27,12 @@ public class JsonHubProtocolTest {
@Test @Test
public void checkTransferFormat() { public void checkTransferFormat() {
assertEquals(TransferFormat.Text, jsonHubProtocol.getTransferFormat()); assertEquals(TransferFormat.TEXT, jsonHubProtocol.getTransferFormat());
} }
@Test @Test
public void verifyWriteMessage() { public void verifyWriteMessage() {
InvocationMessage invocationMessage = new InvocationMessage("test", new Object[] {"42"}); InvocationMessage invocationMessage = new InvocationMessage(null, "test", new Object[] {"42"});
String result = jsonHubProtocol.writeMessage(invocationMessage); String result = jsonHubProtocol.writeMessage(invocationMessage);
String expectedResult = "{\"type\":1,\"target\":\"test\",\"arguments\":[\"42\"]}\u001E"; String expectedResult = "{\"type\":1,\"target\":\"test\",\"arguments\":[\"42\"]}\u001E";
assertEquals(expectedResult, result); assertEquals(expectedResult, result);
@ -45,7 +41,7 @@ public class JsonHubProtocolTest {
@Test @Test
public void parsePingMessage() throws Exception { public void parsePingMessage() throws Exception {
String stringifiedMessage = "{\"type\":6}\u001E"; String stringifiedMessage = "{\"type\":6}\u001E";
TestBinder binder = new TestBinder(new PingMessage()); TestBinder binder = new TestBinder(PingMessage.getInstance());
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder);
@ -93,7 +89,7 @@ public class JsonHubProtocolTest {
@Test @Test
public void parseSingleMessage() throws Exception { public void parseSingleMessage() throws Exception {
String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":[42]}\u001E"; String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":[42]}\u001E";
TestBinder binder = new TestBinder(new InvocationMessage("test", new Object[] { 42 })); TestBinder binder = new TestBinder(new InvocationMessage("1", "test", new Object[] { 42 }));
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder);
@ -112,53 +108,37 @@ public class JsonHubProtocolTest {
assertEquals(42, messageResult); assertEquals(42, messageResult);
} }
@Rule
public ExpectedException exceptionRule = ExpectedException.none();
@Test @Test
public void parseSingleUnsupportedStreamItemMessage() throws Exception { public void parseSingleUnsupportedStreamItemMessage() throws Exception {
exceptionRule.expect(UnsupportedOperationException.class);
exceptionRule.expectMessage("The message type STREAM_ITEM is not supported yet.");
String stringifiedMessage = "{\"type\":2,\"Id\":1,\"Item\":42}\u001E"; String stringifiedMessage = "{\"type\":2,\"Id\":1,\"Item\":42}\u001E";
TestBinder binder = new TestBinder(null); TestBinder binder = new TestBinder(null);
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); Throwable exception = assertThrows(UnsupportedOperationException.class, () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder));
assertEquals("The message type STREAM_ITEM is not supported yet.", exception.getMessage());
} }
@Test @Test
public void parseSingleUnsupportedStreamInvocationMessage() throws Exception { public void parseSingleUnsupportedStreamInvocationMessage() throws Exception {
exceptionRule.expect(UnsupportedOperationException.class);
exceptionRule.expectMessage("The message type STREAM_INVOCATION is not supported yet.");
String stringifiedMessage = "{\"type\":4,\"Id\":1,\"target\":\"test\",\"arguments\":[42]}\u001E"; String stringifiedMessage = "{\"type\":4,\"Id\":1,\"target\":\"test\",\"arguments\":[42]}\u001E";
TestBinder binder = new TestBinder(new StreamInvocationMessage("1", "test", new Object[] { 42 })); TestBinder binder = new TestBinder(new StreamInvocationMessage("1", "test", new Object[] { 42 }));
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); Throwable exception = assertThrows(UnsupportedOperationException.class, () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder));
assertEquals("The message type STREAM_INVOCATION is not supported yet.", exception.getMessage());
} }
@Test @Test
public void parseSingleUnsupportedCancelInvocationMessage() throws Exception { public void parseSingleUnsupportedCancelInvocationMessage() throws Exception {
exceptionRule.expect(UnsupportedOperationException.class);
exceptionRule.expectMessage("The message type CANCEL_INVOCATION is not supported yet.");
String stringifiedMessage = "{\"type\":5,\"invocationId\":123}\u001E"; String stringifiedMessage = "{\"type\":5,\"invocationId\":123}\u001E";
TestBinder binder = new TestBinder(null); TestBinder binder = new TestBinder(null);
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); Throwable exception = assertThrows(UnsupportedOperationException.class, () -> jsonHubProtocol.parseMessages(stringifiedMessage, binder));
} assertEquals("The message type CANCEL_INVOCATION is not supported yet.", exception.getMessage());
@Test
public void parseSingleUnsupportedCompletionMessage() throws Exception {
exceptionRule.expect(UnsupportedOperationException.class);
exceptionRule.expectMessage("The message type COMPLETION is not supported yet.");
String stringifiedMessage = "{\"type\":3,\"invocationId\":123}\u001E";
TestBinder binder = new TestBinder(null);
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder);
} }
@Test @Test
public void parseTwoMessages() throws Exception { public void parseTwoMessages() throws Exception {
String twoMessages = "{\"type\":1,\"target\":\"one\",\"arguments\":[42]}\u001E{\"type\":1,\"target\":\"two\",\"arguments\":[43]}\u001E"; String twoMessages = "{\"type\":1,\"target\":\"one\",\"arguments\":[42]}\u001E{\"type\":1,\"target\":\"two\",\"arguments\":[43]}\u001E";
TestBinder binder = new TestBinder(new InvocationMessage("one", new Object[] { 42 })); TestBinder binder = new TestBinder(new InvocationMessage("1", "one", new Object[] { 42 }));
HubMessage[] messages = jsonHubProtocol.parseMessages(twoMessages, binder); HubMessage[] messages = jsonHubProtocol.parseMessages(twoMessages, binder);
assertEquals(2, messages.length); assertEquals(2, messages.length);
@ -189,7 +169,7 @@ public class JsonHubProtocolTest {
@Test @Test
public void parseSingleMessageMutipleArgs() throws Exception { public void parseSingleMessageMutipleArgs() throws Exception {
String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":[42, 24]}\u001E"; String stringifiedMessage = "{\"type\":1,\"target\":\"test\",\"arguments\":[42, 24]}\u001E";
TestBinder binder = new TestBinder(new InvocationMessage("test", new Object[] { 42, 24 })); TestBinder binder = new TestBinder(new InvocationMessage("1", "test", new Object[] { 42, 24 }));
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder);
@ -208,7 +188,7 @@ public class JsonHubProtocolTest {
@Test @Test
public void parseMessageWithOutOfOrderProperties() throws Exception { public void parseMessageWithOutOfOrderProperties() throws Exception {
String stringifiedMessage = "{\"arguments\":[42, 24],\"type\":1,\"target\":\"test\"}\u001E"; String stringifiedMessage = "{\"arguments\":[42, 24],\"type\":1,\"target\":\"test\"}\u001E";
TestBinder binder = new TestBinder(new InvocationMessage("test", new Object[] { 42, 24 })); TestBinder binder = new TestBinder(new InvocationMessage("1", "test", new Object[] { 42, 24 }));
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder); HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder);
@ -224,8 +204,24 @@ public class JsonHubProtocolTest {
assertEquals(24, messageResult2); assertEquals(24, messageResult2);
} }
@Test
public void parseCompletionMessageWithOutOfOrderProperties() throws Exception {
String stringifiedMessage = "{\"type\":3,\"result\":42,\"invocationId\":\"1\"}\u001E";
TestBinder binder = new TestBinder(new CompletionMessage("1", 42, null));
HubMessage[] messages = jsonHubProtocol.parseMessages(stringifiedMessage, binder);
// We know it's only one message
assertEquals(HubMessageType.COMPLETION, messages[0].getMessageType());
CompletionMessage message = (CompletionMessage) messages[0];
assertEquals(null, message.getError());
assertEquals(42 , message.getResult());
}
private class TestBinder implements InvocationBinder { private class TestBinder implements InvocationBinder {
private Class<?>[] paramTypes = null; private Class<?>[] paramTypes = null;
private Class<?> returnType = null;
public TestBinder(HubMessage expectedMessage) { public TestBinder(HubMessage expectedMessage) {
if (expectedMessage == null) { if (expectedMessage == null) {
@ -249,6 +245,9 @@ public class JsonHubProtocolTest {
break; break;
case STREAM_ITEM: case STREAM_ITEM:
break; break;
case COMPLETION:
returnType = ((CompletionMessage)expectedMessage).getResult().getClass();
break;
default: default:
break; break;
} }
@ -256,7 +255,7 @@ public class JsonHubProtocolTest {
@Override @Override
public Class<?> getReturnType(String invocationId) { public Class<?> getReturnType(String invocationId) {
return null; return returnType;
} }
@Override @Override

View File

@ -3,9 +3,9 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
import org.junit.Test; import org.junit.jupiter.api.Test;
public class NegotiateResponseTest { public class NegotiateResponseTest {

View File

@ -3,39 +3,28 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import java.util.Arrays; import java.util.stream.Stream;
import java.util.Collection;
import org.junit.Test; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runner.RunWith; import org.junit.jupiter.params.provider.Arguments;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
@RunWith(Parameterized.class)
public class ResolveNegotiateUrlTest { public class ResolveNegotiateUrlTest {
private String url; private static Stream<Arguments> protocols() {
private String resolvedUrl; return Stream.of(
Arguments.of("http://example.com/hub/", "http://example.com/hub/negotiate"),
public ResolveNegotiateUrlTest(String url, String resolvedUrl) { Arguments.of("http://example.com/hub", "http://example.com/hub/negotiate"),
this.url = url; Arguments.of("http://example.com/endpoint?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"),
this.resolvedUrl = resolvedUrl; Arguments.of("http://example.com/endpoint/?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"),
Arguments.of("http://example.com/endpoint/path/more?q=my/Data", "http://example.com/endpoint/path/more/negotiate?q=my/Data"));
} }
@Parameterized.Parameters @ParameterizedTest
public static Collection protocols() { @MethodSource("protocols")
return Arrays.asList(new String[][]{ public void checkNegotiateUrl(String url, String resolvedUrl) {
{"http://example.com/hub/", "http://example.com/hub/negotiate"}, String urlResult = Negotiate.resolveNegotiateUrl(url);
{"http://example.com/hub", "http://example.com/hub/negotiate"}, assertEquals(resolvedUrl, urlResult);
{"http://example.com/endpoint?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"},
{"http://example.com/endpoint/?q=my/Data", "http://example.com/endpoint/negotiate?q=my/Data"},
{"http://example.com/endpoint/path/more?q=my/Data", "http://example.com/endpoint/path/more/negotiate?q=my/Data"},});
}
@Test
public void checkNegotiateUrl() {
String urlResult = Negotiate.resolveNegotiateUrl(this.url);
assertEquals(this.resolvedUrl, urlResult);
} }
} }

View File

@ -3,20 +3,17 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import org.junit.Rule; import static org.junit.jupiter.api.Assertions.*;
import org.junit.Test;
import org.junit.rules.ExpectedException; import java.util.concurrent.TimeUnit;
import org.junit.jupiter.api.Test;
public class WebSocketTransportTest { public class WebSocketTransportTest {
@Rule
public ExpectedException expectedEx = ExpectedException.none();
@Test @Test
public void WebsocketThrowsIfItCantConnect() throws Exception { public void WebsocketThrowsIfItCantConnect() throws Exception {
expectedEx.expect(Exception.class);
expectedEx.expectMessage("There was an error starting the Websockets transport");
Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger()); Transport transport = new WebSocketTransport("www.notarealurl12345.fake", new NullLogger());
transport.start(); Throwable exception = assertThrows(Exception.class, () -> transport.start().get(1,TimeUnit.SECONDS));
assertEquals("There was an error starting the Websockets transport.", exception.getCause().getMessage());
} }
} }

View File

@ -3,39 +3,28 @@
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr;
import static org.junit.Assert.*; import static org.junit.jupiter.api.Assertions.*;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.Arrays; import java.util.stream.Stream;
import java.util.Collection;
import org.junit.Test; import org.junit.jupiter.params.ParameterizedTest;
import org.junit.runner.RunWith; import org.junit.jupiter.params.provider.Arguments;
import org.junit.runners.Parameterized; import org.junit.jupiter.params.provider.MethodSource;
@RunWith(Parameterized.class)
public class WebSocketTransportUrlFormatTest { public class WebSocketTransportUrlFormatTest {
private String url; private static Stream<Arguments> protocols() {
private String expectedUrl; return Stream.of(
Arguments.of("http://example.com", "ws://example.com"),
public WebSocketTransportUrlFormatTest(String url, String expectedProtocol) { Arguments.of("https://example.com", "wss://example.com"),
this.url = url; Arguments.of("ws://example.com", "ws://example.com"),
this.expectedUrl = expectedProtocol; Arguments.of("wss://example.com", "wss://example.com"));
} }
@Parameterized.Parameters @ParameterizedTest
public static Collection protocols() { @MethodSource("protocols")
return Arrays.asList(new String[][]{ public void checkWebsocketUrlProtocol(String url, String expectedUrl) throws URISyntaxException {
{"http://example.com", "ws://example.com"}, WebSocketTransport webSocketTransport = new WebSocketTransport(url, new NullLogger());
{"https://example.com", "wss://example.com"}, assertEquals(expectedUrl, webSocketTransport.getUrl().toString());
{"ws://example.com", "ws://example.com"},
{"wss://example.com", "wss://example.com"}});
}
@Test
public void checkWebsocketUrlProtocol() throws URISyntaxException {
WebSocketTransport webSocketTransport = new WebSocketTransport(this.url, new NullLogger());
assertEquals(this.expectedUrl, webSocketTransport.getUrl().toString());
} }
} }

View File

@ -1,10 +1,14 @@
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
package com.microsoft.aspnet.signalr; package com.microsoft.aspnet.signalr.sample;
import java.util.Scanner; import java.util.Scanner;
import com.microsoft.aspnet.signalr.HubConnection;
import com.microsoft.aspnet.signalr.HubConnectionBuilder;
import com.microsoft.aspnet.signalr.LogLevel;
public class Chat { public class Chat {
public static void main(String[] args) throws Exception { public static void main(String[] args) throws Exception {
System.out.println("Enter the URL of the SignalR Chat you want to join"); System.out.println("Enter the URL of the SignalR Chat you want to join");

View File

@ -234,11 +234,45 @@ namespace Microsoft.AspNetCore.SignalR.Internal
InitializeHub(hub, connection); InitializeHub(hub, connection);
Task invocation = null; Task invocation = null;
CancellationTokenSource cts = null;
var arguments = hubMethodInvocationMessage.Arguments;
if (descriptor.HasSyntheticArguments)
{
// In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments)
arguments = new object[descriptor.OriginalParameterTypes.Count];
var hubInvocationArgumentPointer = 0;
for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++)
{
if (hubMethodInvocationMessage.Arguments.Length > hubInvocationArgumentPointer &&
hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer].GetType() == descriptor.OriginalParameterTypes[parameterPointer])
{
// The types match so it isn't a synthetic argument, just copy it into the arguments array
arguments[parameterPointer] = hubMethodInvocationMessage.Arguments[hubInvocationArgumentPointer];
hubInvocationArgumentPointer++;
}
else
{
// This is the only synthetic argument type we currently support
if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken))
{
cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
arguments[parameterPointer] = cts.Token;
}
else
{
// This should never happen
Debug.Assert(false, $"Failed to bind argument of type '{descriptor.OriginalParameterTypes[parameterPointer].Name}' for hub method '{methodExecutor.MethodInfo.Name}'.");
}
}
}
}
if (isStreamResponse) if (isStreamResponse)
{ {
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments); var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts)) if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, ref cts))
{ {
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
@ -247,13 +281,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal
} }
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts); _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts);
} }
else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{ {
// Send Async, no response expected // Send Async, no response expected
invocation = ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments); invocation = ExecuteHubMethod(methodExecutor, hub, arguments);
} }
else else
@ -261,7 +295,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
// Invoke Async, one reponse expected // Invoke Async, one reponse expected
async Task ExecuteInvocation() async Task ExecuteInvocation()
{ {
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments); var result = await ExecuteHubMethod(methodExecutor, hub, arguments);
Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
} }
@ -443,29 +477,24 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true; return true;
} }
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts) private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, ref CancellationTokenSource streamCts)
{ {
if (result != null) if (result != null)
{ {
if (hubMethodDescriptor.IsChannel) if (hubMethodDescriptor.IsChannel)
{ {
streamCts = CreateCancellation(); if (streamCts == null)
{
streamCts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted);
}
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token); enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
return true; return true;
} }
} }
streamCts = null;
enumerator = null; enumerator = null;
return false; return false;
CancellationTokenSource CreateCancellation()
{
var userCts = new CancellationTokenSource();
connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);
return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
}
} }
private void DiscoverHubMethods() private void DiscoverHubMethods()

View File

@ -23,8 +23,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies) public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{ {
MethodExecutor = methodExecutor; MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(GetParameterType).ToArray();
Policies = policies.ToArray();
NonAsyncReturnType = (MethodExecutor.IsMethodAsync) NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
? MethodExecutor.AsyncResultType ? MethodExecutor.AsyncResultType
@ -35,6 +33,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal
IsChannel = true; IsChannel = true;
StreamReturnType = channelItemType; StreamReturnType = channelItemType;
} }
// Take out synthetic arguments that will be provided by the server, this list will be given to the protocol parsers
ParameterTypes = methodExecutor.MethodParameters.Where(p =>
{
// Only streams can take CancellationTokens currently
if (IsStreamable && p.ParameterType == typeof(CancellationToken))
{
HasSyntheticArguments = true;
return false;
}
return true;
}).Select(GetParameterType).ToArray();
if (HasSyntheticArguments)
{
OriginalParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
}
Policies = policies.ToArray();
} }
public bool HasStreamingParameters { get; private set; } public bool HasStreamingParameters { get; private set; }
@ -45,6 +62,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public IReadOnlyList<Type> ParameterTypes { get; } public IReadOnlyList<Type> ParameterTypes { get; }
public IReadOnlyList<Type> OriginalParameterTypes { get; }
public Type NonAsyncReturnType { get; } public Type NonAsyncReturnType { get; }
public bool IsChannel { get; } public bool IsChannel { get; }
@ -55,6 +74,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public IList<IAuthorizeData> Policies { get; } public IList<IAuthorizeData> Policies { get; }
public bool HasSyntheticArguments { get; private set; }
private Type GetParameterType(ParameterInfo p) private Type GetParameterType(ParameterInfo p)
{ {
var type = p.ParameterType; var type = p.ParameterType;

View File

@ -20,7 +20,7 @@ namespace Microsoft.Extensions.DependencyInjection
/// <returns>An <see cref="ISignalRServerBuilder"/> that can be used to further configure the SignalR services.</returns> /// <returns>An <see cref="ISignalRServerBuilder"/> that can be used to further configure the SignalR services.</returns>
public static ISignalRServerBuilder AddSignalRCore(this IServiceCollection services) public static ISignalRServerBuilder AddSignalRCore(this IServiceCollection services)
{ {
services.AddSingleton<SignalRCoreMarkerService>(); services.TryAddSingleton<SignalRCoreMarkerService>();
services.TryAddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>)); services.TryAddSingleton(typeof(HubLifetimeManager<>), typeof(DefaultHubLifetimeManager<>));
services.TryAddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver)); services.TryAddSingleton(typeof(IHubProtocolResolver), typeof(DefaultHubProtocolResolver));
services.TryAddSingleton(typeof(IHubContext<>), typeof(HubContext<>)); services.TryAddSingleton(typeof(IHubContext<>), typeof(HubContext<>));

View File

@ -4,6 +4,7 @@
using System; using System;
using Microsoft.AspNetCore.SignalR; using Microsoft.AspNetCore.SignalR;
using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
namespace Microsoft.Extensions.DependencyInjection namespace Microsoft.Extensions.DependencyInjection
@ -35,8 +36,8 @@ namespace Microsoft.Extensions.DependencyInjection
public static ISignalRServerBuilder AddSignalR(this IServiceCollection services) public static ISignalRServerBuilder AddSignalR(this IServiceCollection services)
{ {
services.AddConnections(); services.AddConnections();
services.AddSingleton<SignalRMarkerService>(); services.TryAddSingleton<SignalRMarkerService>();
services.AddSingleton<IConfigureOptions<HubOptions>, HubOptionsSetup>(); services.TryAddEnumerable(ServiceDescriptor.Singleton<IConfigureOptions<HubOptions>, HubOptionsSetup>());
return services.AddSignalRCore(); return services.AddSignalRCore();
} }

View File

@ -0,0 +1,21 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Tests
{
public static class CancellationTokenExtensions
{
public static Task WaitForCancellationAsync(this CancellationToken token)
{
var tcs = new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);
token.Register((t) =>
{
((TaskCompletionSource<object>)t).SetResult(null);
}, tcs);
return tcs.Task;
}
}
}

View File

@ -4,8 +4,10 @@
using System.Collections.Generic; using System.Collections.Generic;
using System.Threading; using System.Threading;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Options;
using Xunit; using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests namespace Microsoft.AspNetCore.SignalR.Tests
@ -17,12 +19,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{ {
var serviceCollection = new ServiceCollection(); var serviceCollection = new ServiceCollection();
var markerService = new SignalRCoreMarkerService();
serviceCollection.AddSingleton(markerService);
serviceCollection.AddSingleton<IUserIdProvider, CustomIdProvider>(); serviceCollection.AddSingleton<IUserIdProvider, CustomIdProvider>();
serviceCollection.AddSingleton(typeof(HubLifetimeManager<>), typeof(CustomHubLifetimeManager<>)); serviceCollection.AddSingleton(typeof(HubLifetimeManager<>), typeof(CustomHubLifetimeManager<>));
serviceCollection.AddSingleton<IHubProtocolResolver, CustomHubProtocolResolver>(); serviceCollection.AddSingleton<IHubProtocolResolver, CustomHubProtocolResolver>();
serviceCollection.AddScoped(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); serviceCollection.AddScoped(typeof(IHubActivator<>), typeof(CustomHubActivator<>));
serviceCollection.AddSingleton(typeof(IHubContext<>), typeof(CustomHubContext<>)); serviceCollection.AddSingleton(typeof(IHubContext<>), typeof(CustomHubContext<>));
serviceCollection.AddSingleton(typeof(IHubContext<,>), typeof(CustomHubContext<,>)); serviceCollection.AddSingleton(typeof(IHubContext<,>), typeof(CustomHubContext<,>));
var hubOptions = new HubOptionsSetup(new List<IHubProtocol>());
serviceCollection.AddSingleton<IConfigureOptions<HubOptions>>(hubOptions);
serviceCollection.AddSignalR(); serviceCollection.AddSignalR();
var serviceProvider = serviceCollection.BuildServiceProvider(); var serviceProvider = serviceCollection.BuildServiceProvider();
@ -33,6 +39,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.IsType<CustomHubContext<CustomHub>>(serviceProvider.GetRequiredService<IHubContext<CustomHub>>()); Assert.IsType<CustomHubContext<CustomHub>>(serviceProvider.GetRequiredService<IHubContext<CustomHub>>());
Assert.IsType<CustomHubContext<CustomTHub, string>>(serviceProvider.GetRequiredService<IHubContext<CustomTHub, string>>()); Assert.IsType<CustomHubContext<CustomTHub, string>>(serviceProvider.GetRequiredService<IHubContext<CustomTHub, string>>());
Assert.IsType<CustomHubContext<CustomDynamicHub>>(serviceProvider.GetRequiredService<IHubContext<CustomDynamicHub>>()); Assert.IsType<CustomHubContext<CustomDynamicHub>>(serviceProvider.GetRequiredService<IHubContext<CustomDynamicHub>>());
Assert.Equal(hubOptions, serviceProvider.GetRequiredService<IConfigureOptions<HubOptions>>());
Assert.Equal(markerService, serviceProvider.GetRequiredService<SignalRCoreMarkerService>());
} }
[Fact] [Fact]

View File

@ -4,6 +4,7 @@
using System; using System;
using System.Collections.Generic; using System.Collections.Generic;
using System.Text; using System.Text;
using System.Threading;
using System.Threading.Channels; using System.Threading.Channels;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization; using Microsoft.AspNetCore.Authorization;
@ -166,6 +167,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef()); return Clients.Caller.SendAsync("Send", new string('x', 3000), new SelfRef());
} }
public void InvalidArgument(CancellationToken token)
{
}
private class SelfRef private class SelfRef
{ {
public SelfRef() public SelfRef()
@ -620,6 +625,51 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return Channel.CreateUnbounded<string>().Reader; return Channel.CreateUnbounded<string>().Reader;
} }
public ChannelReader<int> CancelableStream(CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});
return channel.Reader;
}
public ChannelReader<int> CancelableStream2(int ignore, int ignore2, CancellationToken token)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});
return channel.Reader;
}
public ChannelReader<int> CancelableStreamMiddle(int ignore, CancellationToken token, int ignore2)
{
var channel = Channel.CreateBounded<int>(10);
Task.Run(async () =>
{
_tcsService.StartedMethod.SetResult(null);
await token.WaitForCancellationAsync();
channel.Writer.TryComplete();
_tcsService.EndMethod.SetResult(null);
});
return channel.Reader;
}
public int SimpleMethod() public int SimpleMethod()
{ {
return 21; return 21;

View File

@ -2600,6 +2600,95 @@ namespace Microsoft.AspNetCore.SignalR.Tests
} }
} }
[Theory]
[InlineData(nameof(LongRunningHub.CancelableStream))]
[InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]
[InlineData(nameof(LongRunningHub.CancelableStreamMiddle), 1, 2)]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnCancellation(string methodName, params object[] args)
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(methodName, args).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();
// Cancel the stream which should trigger the CancellationToken in the hub method
await client.SendHubMessageAsync(new CancelInvocationMessage(streamInvocationId)).OrTimeout();
var result = await client.ReadAsync().OrTimeout();
var simpleCompletion = Assert.IsType<CompletionMessage>(result);
Assert.Null(simpleCompletion.Result);
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.OrTimeout();
// Shut down
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
[Fact]
public async Task StreamHubMethodCanAcceptCancellationTokenAsArgumentAndBeTriggeredOnConnectionAborted()
{
var tcsService = new TcsService();
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder =>
{
builder.AddSingleton(tcsService);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<LongRunningHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.CancelableStream)).OrTimeout();
// Wait for the stream method to start
await tcsService.StartedMethod.Task.OrTimeout();
// Shut down the client which should trigger the CancellationToken in the hub method
client.Dispose();
// CancellationToken passed to hub method will allow EndMethod to be triggered if it is canceled.
await tcsService.EndMethod.Task.OrTimeout();
await connectionHandlerTask.OrTimeout();
}
}
[Fact]
public async Task InvokeHubMethodCannotAcceptCancellationTokenAsArgument()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var invocationId = await client.SendInvocationAsync(nameof(MethodHub.InvalidArgument)).OrTimeout();
var completion = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
Assert.Equal("Failed to invoke 'InvalidArgument' due to an error on the server.", completion.Error);
client.Dispose();
await connectionHandlerTask.OrTimeout();
}
}
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
{ {
public int ReleaseCount; public int ReleaseCount;