Replacing StreamCompletion with StreamInvocation

This commit is contained in:
Pawel Kadluczka 2017-10-25 16:46:19 -07:00 committed by Pawel Kadluczka
parent 1a21fd49b1
commit ff12b9b20c
26 changed files with 597 additions and 484 deletions

View File

@ -123,32 +123,6 @@ describe("HubConnection", () => {
let ex = await captureException(async () => await invokePromise);
expect(ex.message).toBe("Connection lost");
});
it("rejects streaming results made using 'invoke'", async () => {
let connection = new TestConnection();
let hubConnection = new HubConnection(connection);
let invokePromise = hubConnection.invoke("testMethod");
connection.receive({ type: MessageType.Result, invocationId: connection.lastInvocationId, item: null });
connection.onclose();
let ex = await captureException(async () => await invokePromise);
expect(ex.message).toBe("Streaming methods must be invoked using the 'HubConnection.stream()' method.");
});
it("rejects streaming completions made using 'invoke'", async () => {
let connection = new TestConnection();
let hubConnection = new HubConnection(connection);
let invokePromise = hubConnection.invoke("testMethod");
connection.receive({ type: MessageType.StreamCompletion, invocationId: connection.lastInvocationId });
connection.onclose();
let ex = await captureException(async () => await invokePromise);
expect(ex.message).toBe("Streaming methods must be invoked using the 'HubConnection.stream()' method.");
});
});
describe("on", () => {
@ -301,10 +275,9 @@ describe("HubConnection", () => {
// Verify the message is sent
expect(connection.sentData.length).toBe(1);
expect(JSON.parse(connection.sentData[0])).toEqual({
type: MessageType.Invocation,
type: MessageType.StreamInvocation,
invocationId: connection.lastInvocationId,
target: "testStream",
nonblocking: false,
arguments: [
"arg",
42
@ -323,7 +296,7 @@ describe("HubConnection", () => {
hubConnection.stream<any>("testMethod", "arg", 42)
.subscribe(observer);
connection.receive({ type: MessageType.StreamCompletion, invocationId: connection.lastInvocationId, error: "foo" });
connection.receive({ type: MessageType.Completion, invocationId: connection.lastInvocationId, error: "foo" });
let ex = await captureException(async () => await observer.completed);
expect(ex.message).toEqual("Error: foo");
@ -337,7 +310,7 @@ describe("HubConnection", () => {
hubConnection.stream<any>("testMethod", "arg", 42)
.subscribe(observer);
connection.receive({ type: MessageType.StreamCompletion, invocationId: connection.lastInvocationId });
connection.receive({ type: MessageType.Completion, invocationId: connection.lastInvocationId });
expect(await observer.completed).toEqual([]);
});
@ -370,20 +343,6 @@ describe("HubConnection", () => {
expect(ex.message).toEqual("Error: Connection lost");
});
it("rejects completion responses", async () => {
let connection = new TestConnection();
let hubConnection = new HubConnection(connection);
let observer = new TestObserver();
hubConnection.stream<any>("testMethod")
.subscribe(observer);
connection.receive({ type: MessageType.Completion, invocationId: connection.lastInvocationId, result: "foo" });
let ex = await captureException(async () => await observer.completed);
expect(ex.message).toEqual("Error: Hub methods must be invoked using the 'HubConnection.invoke()' method.");
});
it("yields items as they arrive", async () => {
let connection = new TestConnection();
@ -401,7 +360,7 @@ describe("HubConnection", () => {
connection.receive({ type: MessageType.Result, invocationId: connection.lastInvocationId, item: 3 });
expect(observer.itemsReceived).toEqual([1, 2, 3]);
connection.receive({ type: MessageType.StreamCompletion, invocationId: connection.lastInvocationId });
connection.receive({ type: MessageType.Completion, invocationId: connection.lastInvocationId });
expect(await observer.completed).toEqual([1, 2, 3]);
});

View File

@ -6,7 +6,7 @@ import { IConnection } from "./IConnection"
import { HttpConnection} from "./HttpConnection"
import { TransportType, TransferMode } from "./Transports"
import { Subject, Observable } from "./Observable"
import { IHubProtocol, ProtocolType, MessageType, HubMessage, CompletionMessage, StreamCompletionMessage, ResultMessage, InvocationMessage, NegotiationMessage } from "./IHubProtocol";
import { IHubProtocol, ProtocolType, MessageType, HubMessage, CompletionMessage, ResultMessage, InvocationMessage, StreamInvocationMessage, NegotiationMessage } from "./IHubProtocol";
import { JsonHubProtocol } from "./JsonHubProtocol";
import { TextMessageFormat } from "./Formatters"
import { Base64EncodedHubProtocol } from "./Base64EncodedHubProtocol"
@ -63,10 +63,9 @@ export class HubConnection {
break;
case MessageType.Result:
case MessageType.Completion:
case MessageType.StreamCompletion:
let callback = this.callbacks.get(message.invocationId);
if (callback != null) {
if (message.type == MessageType.Completion || message.type == MessageType.StreamCompletion) {
if (message.type === MessageType.Completion) {
this.callbacks.delete(message.invocationId);
}
callback(message);
@ -127,7 +126,7 @@ export class HubConnection {
}
stream<T>(methodName: string, ...args: any[]): Observable<T> {
let invocationDescriptor = this.createInvocation(methodName, args, false);
let invocationDescriptor = this.createStreamInvocation(methodName, args);
let subject = new Subject<T>();
@ -137,22 +136,17 @@ export class HubConnection {
return;
}
switch (invocationEvent.type) {
case MessageType.StreamCompletion:
let completionMessage = <StreamCompletionMessage>invocationEvent;
if (completionMessage.error) {
subject.error(new Error(completionMessage.error));
}
else {
subject.complete();
}
break;
case MessageType.Result:
subject.next(<T>(<ResultMessage>invocationEvent).item);
break;
default:
subject.error(new Error("Hub methods must be invoked using the 'HubConnection.invoke()' method."));
break;
if (invocationEvent.type === MessageType.Completion) {
let completionMessage = <CompletionMessage>invocationEvent;
if (completionMessage.error) {
subject.error(new Error(completionMessage.error));
}
else {
subject.complete();
}
}
else {
subject.next(<T>(<ResultMessage>invocationEvent).item);
}
});
@ -194,7 +188,7 @@ export class HubConnection {
}
}
else {
reject(new Error("Streaming methods must be invoked using the 'HubConnection.stream()' method."));
reject(new Error(`Unexpected message type: ${invocationEvent.type}`));
}
});
@ -257,4 +251,16 @@ export class HubConnection {
nonblocking: nonblocking
};
}
private createStreamInvocation(methodName: string, args: any[]): StreamInvocationMessage {
let id = this.id;
this.id++;
return <StreamInvocationMessage>{
type: MessageType.StreamInvocation,
invocationId: id.toString(),
target: methodName,
arguments: args,
};
}
}

View File

@ -5,7 +5,7 @@ export const enum MessageType {
Invocation = 1,
Result,
Completion,
StreamCompletion
StreamInvocation
}
export interface HubMessage {
@ -19,12 +19,13 @@ export interface InvocationMessage extends HubMessage {
readonly nonblocking?: boolean;
}
export interface ResultMessage extends HubMessage {
readonly item?: any;
export interface StreamInvocationMessage extends HubMessage {
readonly target: string;
readonly arguments: Array<any>
}
export interface StreamCompletionMessage extends HubMessage {
readonly error?: string;
export interface ResultMessage extends HubMessage {
readonly item?: any;
}
export interface CompletionMessage extends HubMessage {

View File

@ -1,7 +1,7 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
import { IHubProtocol, ProtocolType, MessageType, HubMessage, InvocationMessage, ResultMessage, CompletionMessage, StreamCompletionMessage } from "./IHubProtocol";
import { IHubProtocol, ProtocolType, MessageType, HubMessage, InvocationMessage, ResultMessage, CompletionMessage, StreamInvocationMessage } from "./IHubProtocol";
import { BinaryMessageFormat } from "./Formatters"
import * as msgpack5 from "msgpack5"
@ -34,8 +34,6 @@ export class MessagePackHubProtocol implements IHubProtocol {
return this.createStreamItemMessage(properties);
case MessageType.Completion:
return this.createCompletionMessage(properties);
case MessageType.StreamCompletion:
return this.createStreamCompletionMessage(properties);
default:
throw new Error("Invalid message type.");
}
@ -102,22 +100,12 @@ export class MessagePackHubProtocol implements IHubProtocol {
return completionMessage as CompletionMessage;
}
private createStreamCompletionMessage(properties: any[]): StreamCompletionMessage {
if (properties.length < 2) {
throw new Error("Invalid payload for Completion message.");
}
return <StreamCompletionMessage>{
type: MessageType.StreamCompletion,
invocationId: properties[1],
error: properties.length == 3 ? properties[2] : null,
};
}
writeMessage(message: HubMessage): ArrayBuffer {
switch (message.type) {
case MessageType.Invocation:
return this.writeInvocation(message as InvocationMessage);
case MessageType.StreamInvocation:
return this.writeStreamInvocation(message as StreamInvocationMessage);
case MessageType.Result:
case MessageType.Completion:
throw new Error(`Writing messages of type '${message.type}' is not supported.`);
@ -133,4 +121,12 @@ export class MessagePackHubProtocol implements IHubProtocol {
return BinaryMessageFormat.write(payload.slice());
}
private writeStreamInvocation(streamInvocationMessage: StreamInvocationMessage): ArrayBuffer {
let msgpack = msgpack5();
let payload = msgpack.encode([ MessageType.StreamInvocation, streamInvocationMessage.invocationId,
streamInvocationMessage.target, streamInvocationMessage.arguments]);
return BinaryMessageFormat.write(payload.slice());
}
}

View File

@ -163,7 +163,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
// The stream invocation will be canceled by the CancelInvocationMessage, connection closing, or channel finishing.
using (cancellationToken.Register(token => ((CancellationTokenSource)token).Cancel(), invokeCts))
{
await InvokeCore(methodName, irq, args);
await InvokeStreamCore(methodName, irq, args);
}
if (cancellationToken.CanBeCanceled)
@ -178,7 +178,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
if (invocationReq.HubConnection.TryRemoveInvocation(invocationReq.InvocationId, out _))
{
invocationReq.Complete(new StreamCompletionMessage(irq.InvocationId, error: null));
invocationReq.Complete(CompletionMessage.Empty(irq.InvocationId));
}
invocationReq.Dispose();
@ -224,6 +224,27 @@ namespace Microsoft.AspNetCore.SignalR.Client
return SendHubMessage(invocationMessage, irq);
}
private Task InvokeStreamCore(string methodName, InvocationRequest irq, object[] args)
{
ThrowIfConnectionTerminated(irq.InvocationId);
_logger.PreparingStreamingInvocation(irq.InvocationId, methodName, irq.ResultType.FullName, args.Length);
var invocationMessage = new StreamInvocationMessage(irq.InvocationId, methodName,
argumentBindingException: null, arguments: args);
// I just want an excuse to use 'irq' as a variable name...
_logger.RegisterInvocation(invocationMessage.InvocationId);
AddInvocation(irq);
// Trace the full invocation
_logger.IssueInvocation(invocationMessage.InvocationId, irq.ResultType.FullName, methodName, args);
// We don't need to wait for this to complete. It will signal back to the invocation request.
return SendHubMessage(invocationMessage, irq);
}
private async Task SendHubMessage(HubMessage hubMessage, InvocationRequest irq)
{
try
@ -306,17 +327,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
DispatchInvocationStreamItemAsync(streamItem, irq);
break;
case StreamCompletionMessage streamCompletion:
if (!TryRemoveInvocation(streamCompletion.InvocationId, out irq))
{
_logger.DropStreamCompletionMessage(streamCompletion.InvocationId);
return;
}
DispatchStreamCompletion(streamCompletion, irq);
irq.Dispose();
break;
default:
throw new InvalidOperationException($"Unknown message type: {message.GetType().FullName}");
throw new InvalidOperationException($"Unexpected message type: {message.GetType().FullName}");
}
}
}
@ -412,20 +424,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
private void DispatchStreamCompletion(StreamCompletionMessage completion, InvocationRequest irq)
{
_logger.ReceivedStreamCompletion(completion.InvocationId);
if (irq.CancellationToken.IsCancellationRequested)
{
_logger.CancelingStreamCompletion(irq.InvocationId);
}
else
{
irq.Complete(completion);
}
}
private void ThrowIfConnectionTerminated(string invocationId)
{
if (_connectionActive.Token.IsCancellationRequested)

View File

@ -114,6 +114,5 @@ namespace Microsoft.AspNetCore.SignalR.Client
return outputChannel.In;
}
}
}

View File

@ -40,53 +40,50 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
private static readonly Action<ILogger, string, Exception> _dropStreamMessage =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(9, nameof(DropStreamMessage)), "Dropped unsolicited StreamItem message for invocation '{invocationId}'.");
private static readonly Action<ILogger, string, Exception> _dropStreamCompletionMessage =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(10, nameof(DropStreamCompletionMessage)), "Dropped unsolicited Stream Completion message for invocation '{invocationId}'.");
private static readonly Action<ILogger, Exception> _shutdownConnection =
LoggerMessage.Define(LogLevel.Trace, new EventId(11, nameof(ShutdownConnection)), "Shutting down connection.");
LoggerMessage.Define(LogLevel.Trace, new EventId(10, nameof(ShutdownConnection)), "Shutting down connection.");
private static readonly Action<ILogger, Exception> _shutdownWithError =
LoggerMessage.Define(LogLevel.Error, new EventId(12, nameof(ShutdownWithError)), "Connection is shutting down due to an error.");
LoggerMessage.Define(LogLevel.Error, new EventId(11, nameof(ShutdownWithError)), "Connection is shutting down due to an error.");
private static readonly Action<ILogger, string, Exception> _removeInvocation =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(13, nameof(RemoveInvocation)), "Removing pending invocation {invocationId}.");
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(12, nameof(RemoveInvocation)), "Removing pending invocation {invocationId}.");
private static readonly Action<ILogger, string, Exception> _missingHandler =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(14, nameof(MissingHandler)), "Failed to find handler for '{target}' method.");
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(13, nameof(MissingHandler)), "Failed to find handler for '{target}' method.");
private static readonly Action<ILogger, string, Exception> _receivedStreamItem =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(15, nameof(ReceivedStreamItem)), "Received StreamItem for Invocation {invocationId}.");
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(14, nameof(ReceivedStreamItem)), "Received StreamItem for Invocation {invocationId}.");
private static readonly Action<ILogger, string, Exception> _cancelingStreamItem =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(16, nameof(CancelingStreamItem)), "Canceling dispatch of StreamItem message for Invocation {invocationId}. The invocation was canceled.");
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(15, nameof(CancelingStreamItem)), "Canceling dispatch of StreamItem message for Invocation {invocationId}. The invocation was canceled.");
private static readonly Action<ILogger, string, Exception> _receivedStreamItemAfterClose =
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(17, nameof(ReceivedStreamItemAfterClose)), "Invocation {invocationId} received stream item after channel was closed.");
LoggerMessage.Define<string>(LogLevel.Warning, new EventId(16, nameof(ReceivedStreamItemAfterClose)), "Invocation {invocationId} received stream item after channel was closed.");
private static readonly Action<ILogger, string, Exception> _receivedInvocationCompletion =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(18, nameof(ReceivedInvocationCompletion)), "Received Completion for Invocation {invocationId}.");
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(17, nameof(ReceivedInvocationCompletion)), "Received Completion for Invocation {invocationId}.");
private static readonly Action<ILogger, string, Exception> _cancelingInvocationCompletion =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(19, nameof(CancelingInvocationCompletion)), "Canceling dispatch of Completion message for Invocation {invocationId}. The invocation was canceled.");
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(18, nameof(CancelingInvocationCompletion)), "Canceling dispatch of Completion message for Invocation {invocationId}. The invocation was canceled.");
private static readonly Action<ILogger, string, Exception> _receivedStreamCompletion =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(20, nameof(ReceivedStreamCompletion)), "Received StreamCompletion for Invocation {invocationId}.");
private static readonly Action<ILogger, string, Exception> _cancelingStreamCompletion =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(21, nameof(CancelingStreamCompletion)), "Canceling dispatch of StreamCompletion message for Invocation {invocationId}. The invocation was canceled.");
private static readonly Action<ILogger, string, Exception> _cancelingCompletion =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(19, nameof(CancelingCompletion)), "Canceling dispatch of Completion message for Invocation {invocationId}. The invocation was canceled.");
private static readonly Action<ILogger, string, Exception> _invokeAfterTermination =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(22, nameof(InvokeAfterTermination)), "Invoke for Invocation '{invocationId}' was called after the connection was terminated.");
LoggerMessage.Define<string>(LogLevel.Error, new EventId(20, nameof(InvokeAfterTermination)), "Invoke for Invocation '{invocationId}' was called after the connection was terminated.");
private static readonly Action<ILogger, string, Exception> _invocationAlreadyInUse =
LoggerMessage.Define<string>(LogLevel.Critical, new EventId(23, nameof(InvocationAlreadyInUse)), "Invocation ID '{invocationId}' is already in use.");
LoggerMessage.Define<string>(LogLevel.Critical, new EventId(21, nameof(InvocationAlreadyInUse)), "Invocation ID '{invocationId}' is already in use.");
private static readonly Action<ILogger, string, Exception> _receivedUnexpectedResponse =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(24, nameof(ReceivedUnexpectedResponse)), "Unsolicited response received for invocation '{invocationId}'.");
LoggerMessage.Define<string>(LogLevel.Error, new EventId(22, nameof(ReceivedUnexpectedResponse)), "Unsolicited response received for invocation '{invocationId}'.");
private static readonly Action<ILogger, string, Exception> _hubProtocol =
LoggerMessage.Define<string>(LogLevel.Information, new EventId(25, nameof(HubProtocol)), "Using HubProtocol '{protocol}'.");
LoggerMessage.Define<string>(LogLevel.Information, new EventId(23, nameof(HubProtocol)), "Using HubProtocol '{protocol}'.");
private static readonly Action<ILogger, string, string, string, int, Exception> _preparingStreamingInvocation =
LoggerMessage.Define<string, string, string, int>(LogLevel.Trace, new EventId(24, nameof(PreparingStreamingInvocation)), "Preparing streaming invocation '{invocationId}' of '{target}', with return type '{returnType}' and {argumentCount} argument(s).");
// Category: Streaming and NonStreaming
private static readonly Action<ILogger, string, Exception> _invocationCreated =
@ -105,8 +102,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
private static readonly Action<ILogger, string, Exception> _errorWritingStreamItem =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(4, nameof(ErrorWritingStreamItem)), "Invocation {invocationId} caused an error trying to write a stream item.");
private static readonly Action<ILogger, string, string, Exception> _receivedUnexpectedMessageTypeForStreamCompletion =
LoggerMessage.Define<string, string>(LogLevel.Error, new EventId(5, nameof(ReceivedUnexpectedMessageTypeForStreamCompletion)), "Invocation {invocationId} was invoked as a streaming hub method but completed with '{messageType}' message.");
private static readonly Action<ILogger, string, Exception> _receivedUnexpectedComplete =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(5, nameof(ReceivedUnexpectedComplete)), "Invocation {invocationId} received a completion result, but was invoked as a streaming invocation.");
// Category: NonStreaming
private static readonly Action<ILogger, string, Exception> _streamItemOnNonStreamInvocation =
@ -115,9 +112,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
private static readonly Action<ILogger, string, Exception> _errorInvokingClientSideMethod =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(5, nameof(ErrorInvokingClientSideMethod)), "Invoking client side method '{methodName}' failed.");
private static readonly Action<ILogger, string, string, Exception> _receivedUnexpectedMessageTypeForInvokeCompletion =
LoggerMessage.Define<string, string>(LogLevel.Error, new EventId(6, nameof(ReceivedUnexpectedMessageTypeForInvokeCompletion)), "Invocation {invocationId} was invoked as a non-streaming hub method but completed with '{messageType}' message.");
public static void PreparingNonBlockingInvocation(this ILogger logger, string invocationId, string target, int count)
{
_preparingNonBlockingInvocation(logger, invocationId, target, count, null);
@ -128,6 +122,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
_preparingBlockingInvocation(logger, invocationId, target, returnType, count, null);
}
public static void PreparingStreamingInvocation(this ILogger logger, string invocationId, string target, string returnType, int count)
{
_preparingStreamingInvocation(logger, invocationId, target, returnType, count, null);
}
public static void RegisterInvocation(this ILogger logger, string invocationId)
{
_registerInvocation(logger, invocationId, null);
@ -176,11 +175,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
_dropStreamMessage(logger, invocationId, null);
}
public static void DropStreamCompletionMessage(this ILogger logger, string invocationId)
{
_dropStreamCompletionMessage(logger, invocationId, null);
}
public static void ShutdownConnection(this ILogger logger)
{
_shutdownConnection(logger, null);
@ -226,14 +220,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
_cancelingInvocationCompletion(logger, invocationId, null);
}
public static void ReceivedStreamCompletion(this ILogger logger, string invocationId)
public static void CancelingCompletion(this ILogger logger, string invocationId)
{
_receivedStreamCompletion(logger, invocationId, null);
}
public static void CancelingStreamCompletion(this ILogger logger, string invocationId)
{
_cancelingStreamCompletion(logger, invocationId, null);
_cancelingCompletion(logger, invocationId, null);
}
public static void InvokeAfterTermination(this ILogger logger, string invocationId)
@ -281,6 +270,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
_errorWritingStreamItem(logger, invocationId, exception);
}
public static void ReceivedUnexpectedComplete(this ILogger logger, string invocationId)
{
_receivedUnexpectedComplete(logger, invocationId, null);
}
public static void StreamItemOnNonStreamInvocation(this ILogger logger, string invocationId)
{
_streamItemOnNonStreamInvocation(logger, invocationId, null);
@ -290,15 +284,5 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
{
_errorInvokingClientSideMethod(logger, methodName, exception);
}
public static void ReceivedUnexpectedMessageTypeForStreamCompletion(this ILogger logger, string invocationId, string messageType)
{
_receivedUnexpectedMessageTypeForStreamCompletion(logger, invocationId, messageType, null);
}
public static void ReceivedUnexpectedMessageTypeForInvokeCompletion(this ILogger logger, string invocationId, string messageType)
{
_receivedUnexpectedMessageTypeForStreamCompletion(logger, invocationId, messageType, null);
}
}
}

View File

@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
@ -52,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
public abstract void Fail(Exception exception);
public abstract void Complete(HubMessage message);
public abstract void Complete(CompletionMessage message);
public abstract ValueTask<bool> StreamItem(object item);
protected abstract void Cancel();
@ -78,26 +77,21 @@ namespace Microsoft.AspNetCore.SignalR.Client
public ReadableChannel<object> Result => _channel.In;
public override void Complete(HubMessage message)
public override void Complete(CompletionMessage completionMessage)
{
Debug.Assert(message != null, "message is null");
if (!(message is StreamCompletionMessage streamCompletionMessage))
{
Logger.ReceivedUnexpectedMessageTypeForStreamCompletion(InvocationId, message.GetType().Name);
// This is not 100% accurate but it is the only case that can be encountered today when running end-to-end
// and this is the most useful message to show to the user.
Fail(new InvalidOperationException($"Streaming hub methods must be invoked with the '{nameof(HubConnection)}.{nameof(HubConnection.StreamAsync)}' method."));
return;
}
if (!string.IsNullOrEmpty(streamCompletionMessage.Error))
{
Fail(new HubException(streamCompletionMessage.Error));
return;
}
Logger.InvocationCompleted(InvocationId);
if (completionMessage.Result != null)
{
Logger.ReceivedUnexpectedComplete(InvocationId);
_channel.Out.TryComplete(new InvalidOperationException("Server provided a result in a completion response to a streamed invocation."));
}
if (!string.IsNullOrEmpty(completionMessage.Error))
{
Fail(new HubException(completionMessage.Error));
return;
}
_channel.Out.TryComplete();
}
@ -143,20 +137,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Task<object> Result => _completionSource.Task;
public override void Complete(HubMessage message)
public override void Complete(CompletionMessage completionMessage)
{
Debug.Assert(message != null, "message is null");
if (!(message is CompletionMessage completionMessage))
{
Logger.ReceivedUnexpectedMessageTypeForStreamCompletion(InvocationId, message.GetType().Name);
// This is not 100% accurate but it is the only case that can be encountered today when running end-to-end
// and this is the most useful message to show to the user.
Fail(new InvalidOperationException(
$"Non-streaming hub methods must be invoked with the '{nameof(HubConnection)}.{nameof(HubConnection.InvokeAsync)}' method."));
return;
}
if (!string.IsNullOrEmpty(completionMessage.Error))
{
Fail(new HubException(completionMessage.Error));

View File

@ -7,7 +7,7 @@ using System.Runtime.ExceptionServices;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public class InvocationMessage : HubMessage
public abstract class HubMethodInvocationMessage : HubMessage
{
private readonly ExceptionDispatchInfo _argumentBindingException;
private readonly object[] _arguments;
@ -35,9 +35,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
}
public bool NonBlocking { get; }
public bool NonBlocking { get; protected set; }
public InvocationMessage(string invocationId, bool nonBlocking, string target, ExceptionDispatchInfo argumentBindingException, params object[] arguments)
public HubMethodInvocationMessage(string invocationId, string target, ExceptionDispatchInfo argumentBindingException, object[] arguments)
: base(invocationId)
{
if (string.IsNullOrEmpty(invocationId))
@ -58,12 +58,32 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
Target = target;
_arguments = arguments;
_argumentBindingException = argumentBindingException;
}
}
public class InvocationMessage : HubMethodInvocationMessage
{
public InvocationMessage(string invocationId, bool nonBlocking, string target, ExceptionDispatchInfo argumentBindingException, params object[] arguments)
: base(invocationId, target, argumentBindingException, arguments)
{
NonBlocking = nonBlocking;
}
public override string ToString()
{
return $"Invocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(NonBlocking)}: {NonBlocking}, {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {string.Join(", ", Arguments.Select(a => a?.ToString()))} ] }}";
return $"InvocationMessage {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(NonBlocking)}: {NonBlocking}, {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {string.Join(", ", Arguments?.Select(a => a?.ToString())) ?? string.Empty } ] }}";
}
}
public class StreamInvocationMessage : HubMethodInvocationMessage
{
public StreamInvocationMessage(string invocationId, string target, ExceptionDispatchInfo argumentBindingException, params object[] arguments)
: base(invocationId, target, argumentBindingException, arguments)
{ }
public override string ToString()
{
return $"StreamInvocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {string.Join(", ", Arguments?.Select(a => a?.ToString())) ?? string.Empty} ] }}";
}
}
}

View File

@ -26,7 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
private const int InvocationMessageType = 1;
private const int ResultMessageType = 2;
private const int CompletionMessageType = 3;
private const int StreamCompletionMessageType = 4;
private const int StreamInvocationMessageType = 4;
private const int CancelInvocationMessageType = 5;
// ONLY to be used for application payloads (args, return values, etc.)
@ -110,12 +110,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
case InvocationMessageType:
return BindInvocationMessage(json, binder);
case StreamInvocationMessageType:
return BindStreamInvocationMessage(json, binder);
case ResultMessageType:
return BindResultMessage(json, binder);
case CompletionMessageType:
return BindCompletionMessage(json, binder);
case StreamCompletionMessageType:
return BindStreamCompletionMessage(json);
case CancelInvocationMessageType:
return BindCancelInvocationMessage(json);
default:
@ -138,15 +138,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
case InvocationMessage m:
WriteInvocationMessage(m, writer);
break;
case StreamInvocationMessage m:
WriteStreamInvocationMessage(m, writer);
break;
case StreamItemMessage m:
WriteStreamItemMessage(m, writer);
break;
case CompletionMessage m:
WriteCompletionMessage(m, writer);
break;
case StreamCompletionMessage m:
WriteStreamCompletionMessage(m, writer);
break;
case CancelInvocationMessage m:
WriteCancelInvocationMessage(m, writer);
break;
@ -173,18 +173,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
writer.WriteEndObject();
}
private void WriteStreamCompletionMessage(StreamCompletionMessage message, JsonTextWriter writer)
{
writer.WriteStartObject();
WriteHubMessageCommon(message, writer, StreamCompletionMessageType);
if (!string.IsNullOrEmpty(message.Error))
{
writer.WritePropertyName(ErrorPropertyName);
writer.WriteValue(message.Error);
}
writer.WriteEndObject();
}
private void WriteCancelInvocationMessage(CancelInvocationMessage message, JsonTextWriter writer)
{
writer.WriteStartObject();
@ -214,15 +202,32 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
writer.WriteValue(message.NonBlocking);
}
WriteArguments(message.Arguments, writer);
writer.WriteEndObject();
}
private void WriteStreamInvocationMessage(StreamInvocationMessage message, JsonTextWriter writer)
{
writer.WriteStartObject();
WriteHubMessageCommon(message, writer, StreamInvocationMessageType);
writer.WritePropertyName(TargetPropertyName);
writer.WriteValue(message.Target);
WriteArguments(message.Arguments, writer);
writer.WriteEndObject();
}
private void WriteArguments(object[] arguments, JsonTextWriter writer)
{
writer.WritePropertyName(ArgumentsPropertyName);
writer.WriteStartArray();
foreach (var argument in message.Arguments)
foreach (var argument in arguments)
{
_payloadSerializer.Serialize(writer, argument);
}
writer.WriteEndArray();
writer.WriteEndObject();
}
private static void WriteHubMessageCommon(HubMessage message, JsonTextWriter writer, int type)
@ -254,6 +259,26 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
}
private StreamInvocationMessage BindStreamInvocationMessage(JObject json, IInvocationBinder binder)
{
var invocationId = JsonUtils.GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
var target = JsonUtils.GetRequiredProperty<string>(json, TargetPropertyName, JTokenType.String);
var args = JsonUtils.GetRequiredProperty<JArray>(json, ArgumentsPropertyName, JTokenType.Array);
var paramTypes = binder.GetParameterTypes(target);
try
{
var arguments = BindArguments(args, paramTypes);
return new StreamInvocationMessage(invocationId, target, argumentBindingException: null, arguments: arguments);
}
catch (Exception ex)
{
return new StreamInvocationMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
}
}
private object[] BindArguments(JArray args, Type[] paramTypes)
{
var arguments = new object[args.Count];
@ -308,13 +333,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return new CompletionMessage(invocationId, error, result: payload, hasResult: true);
}
private StreamCompletionMessage BindStreamCompletionMessage(JObject json)
{
var invocationId = JsonUtils.GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);
var error = JsonUtils.GetOptionalProperty<string>(json, ErrorPropertyName, JTokenType.String);
return new StreamCompletionMessage(invocationId, error);
}
private CancelInvocationMessage BindCancelInvocationMessage(JObject json)
{
var invocationId = JsonUtils.GetRequiredProperty<string>(json, InvocationIdPropertyName, JTokenType.String);

View File

@ -3,7 +3,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Runtime.ExceptionServices;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
@ -17,7 +16,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
private const int InvocationMessageType = 1;
private const int StreamItemMessageType = 2;
private const int CompletionMessageType = 3;
private const int StreamCompletionMessageType = 4;
private const int StreamInvocationMessageType = 4;
private const int CancelInvocationMessageType = 5;
private const int ErrorResult = 1;
@ -58,19 +57,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
using (var unpacker = Unpacker.Create(input))
{
var arraySize = ReadArrayLength(unpacker, "elementCount");
_ = ReadArrayLength(unpacker, "elementCount");
var messageType = ReadInt32(unpacker, "messageType");
switch (messageType)
{
case InvocationMessageType:
return CreateInvocationMessage(unpacker, binder);
case StreamInvocationMessageType:
return CreateStreamInvocationMessage(unpacker, binder);
case StreamItemMessageType:
return CreateStreamItemMessage(unpacker, binder);
case CompletionMessageType:
return CreateCompletionMessage(unpacker, binder);
case StreamCompletionMessageType:
return CreateStreamCompletionMessage(unpacker, arraySize, binder);
case CancelInvocationMessageType:
return CreateCancelInvocationMessage(unpacker);
default:
@ -97,6 +96,22 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
}
private static StreamInvocationMessage CreateStreamInvocationMessage(Unpacker unpacker, IInvocationBinder binder)
{
var invocationId = ReadInvocationId(unpacker);
var target = ReadString(unpacker, "target");
var parameterTypes = binder.GetParameterTypes(target);
try
{
var arguments = BindArguments(unpacker, parameterTypes);
return new StreamInvocationMessage(invocationId, target, argumentBindingException: null, arguments: arguments);
}
catch (Exception ex)
{
return new StreamInvocationMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex));
}
}
private static object[] BindArguments(Unpacker unpacker, Type[] parameterTypes)
{
var argumentCount = ReadArrayLength(unpacker, "arguments");
@ -160,16 +175,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
return new CompletionMessage(invocationId, error, result, hasResult);
}
private static StreamCompletionMessage CreateStreamCompletionMessage(Unpacker unpacker, long arraySize, IInvocationBinder binder)
{
Debug.Assert(arraySize == 2 || arraySize == 3, "Unexpected item count");
var invocationId = ReadInvocationId(unpacker);
// Error is optional so StreamCompletion without error has 2 items, StreamCompletion with error has 3 items
var error = arraySize == 3 ? ReadString(unpacker, "error") : null;
return new StreamCompletionMessage(invocationId, error);
}
private static CancelInvocationMessage CreateCancelInvocationMessage(Unpacker unpacker)
{
var invocationId = ReadInvocationId(unpacker);
@ -195,15 +200,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
case InvocationMessage invocationMessage:
WriteInvocationMessage(invocationMessage, packer);
break;
case StreamInvocationMessage streamInvocationMessage:
WriteStreamInvocationMessage(streamInvocationMessage, packer);
break;
case StreamItemMessage streamItemMessage:
WriteStreamingItemMessage(streamItemMessage, packer);
break;
case CompletionMessage completionMessage:
WriteCompletionMessage(completionMessage, packer);
break;
case StreamCompletionMessage streamCompletionMessage:
WriteStreamCompletionMessage(streamCompletionMessage, packer);
break;
case CancelInvocationMessage cancelInvocationMessage:
WriteCancelInvocationMessage(cancelInvocationMessage, packer);
break;
@ -222,6 +227,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
packer.PackObject(invocationMessage.Arguments, _serializationContext);
}
private void WriteStreamInvocationMessage(StreamInvocationMessage streamInvocationMessage, Packer packer)
{
packer.PackArrayHeader(4);
packer.Pack(StreamInvocationMessageType);
packer.PackString(streamInvocationMessage.InvocationId);
packer.PackString(streamInvocationMessage.Target);
packer.PackObject(streamInvocationMessage.Arguments, _serializationContext);
}
private void WriteStreamingItemMessage(StreamItemMessage streamItemMessage, Packer packer)
{
packer.PackArrayHeader(3);
@ -253,19 +267,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
}
}
private void WriteStreamCompletionMessage(StreamCompletionMessage streamCompletionMessage, Packer packer)
{
var hasError = !string.IsNullOrEmpty(streamCompletionMessage.Error);
packer.PackArrayHeader(2 + (hasError ? 1 : 0));
packer.Pack(StreamCompletionMessageType);
packer.PackString(streamCompletionMessage.InvocationId);
if (hasError)
{
packer.PackString(streamCompletionMessage.Error);
}
}
private void WriteCancelInvocationMessage(CancelInvocationMessage cancelInvocationMessage, Packer packer)
{
packer.PackArrayHeader(2);

View File

@ -1,22 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public class StreamCompletionMessage : HubMessage
{
public string Error { get; }
public StreamCompletionMessage(string invocationId, string error)
: base(invocationId)
{
Error = error;
}
public override string ToString()
{
var errorStr = Error == null ? "<<null>>" : $"\"{Error}\"";
return $"StreamCompletion {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Error)}: {errorStr} }}";
}
}
}

View File

@ -277,7 +277,15 @@ namespace Microsoft.AspNetCore.SignalR
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
_ = ProcessInvocation(connection, invocationMessage);
_ = ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false);
break;
case StreamInvocationMessage streamInvocationMessage:
_logger.ReceivedStreamHubInvocation(streamInvocationMessage);
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
_ = ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true);
break;
case CancelInvocationMessage cancelInvocationMessage:
@ -312,13 +320,24 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task ProcessInvocation(HubConnectionContext connection, InvocationMessage invocationMessage)
private async Task ProcessInvocation(HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamedInvocation)
{
try
{
// If an unexpected exception occurs then we want to kill the entire connection
// by ending the processing loop
await Execute(connection, invocationMessage);
if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor))
{
// Send an error to the client. Then let the normal completion process occur
_logger.UnknownHubMethod(hubMethodInvocationMessage.Target);
await SendMessageAsync(connection, CompletionMessage.WithError(
hubMethodInvocationMessage.InvocationId, $"Unknown hub method '{hubMethodInvocationMessage.Target}'"));
}
else
{
await Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamedInvocation);
}
}
catch (Exception ex)
{
@ -327,20 +346,6 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task Execute(HubConnectionContext connection, InvocationMessage invocationMessage)
{
if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor))
{
// Send an error to the client. Then let the normal completion process occur
_logger.UnknownHubMethod(invocationMessage.Target);
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
}
else
{
await Invoke(descriptor, connection, invocationMessage);
}
}
private async Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
{
while (await connection.Output.WaitToWriteAsync())
@ -356,7 +361,8 @@ namespace Microsoft.AspNetCore.SignalR
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
}
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, InvocationMessage invocationMessage)
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamedInvocation)
{
var methodExecutor = descriptor.MethodExecutor;
@ -364,11 +370,14 @@ namespace Microsoft.AspNetCore.SignalR
{
if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies))
{
_logger.HubMethodNotAuthorized(invocationMessage.Target);
if (!invocationMessage.NonBlocking)
{
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized"));
}
_logger.HubMethodNotAuthorized(hubMethodInvocationMessage.Target);
await SendInvocationError(hubMethodInvocationMessage, connection,
$"Failed to invoke '{hubMethodInvocationMessage.Target}' because user is unauthorized");
return;
}
if (!await ValidateInvocationMode(methodExecutor.MethodReturnType, isStreamedInvocation, hubMethodInvocationMessage, connection))
{
return;
}
@ -379,45 +388,29 @@ namespace Microsoft.AspNetCore.SignalR
{
InitializeHub(hub, connection);
object result = null;
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
// ReadableChannel is awaitable but we don't want to await it.
if (methodExecutor.IsMethodAsync && !IsChannel(methodExecutor.MethodReturnType, out _))
if (isStreamedInvocation)
{
if (methodExecutor.MethodReturnType == typeof(Task))
{
await (Task)methodExecutor.Execute(hub, invocationMessage.Arguments);
}
else
{
result = await methodExecutor.ExecuteAsync(hub, invocationMessage.Arguments);
}
var enumerator = GetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, methodExecutor, result, methodExecutor.MethodReturnType);
_logger.StreamingResult(hubMethodInvocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator);
}
else
else if (!hubMethodInvocationMessage.NonBlocking)
{
result = methodExecutor.Execute(hub, invocationMessage.Arguments);
}
if (IsStreamed(connection, invocationMessage.InvocationId, methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator))
{
_logger.StreamingResult(invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await StreamResultsAsync(invocationMessage.InvocationId, connection, enumerator);
}
else if (!invocationMessage.NonBlocking)
{
_logger.SendingResult(invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await SendMessageAsync(connection, CompletionMessage.WithResult(invocationMessage.InvocationId, result));
_logger.SendingResult(hubMethodInvocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await SendMessageAsync(connection, CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
}
}
catch (TargetInvocationException ex)
{
_logger.FailedInvokingHubMethod(invocationMessage.Target, ex);
await SendInvocationError(invocationMessage, connection, methodExecutor.MethodReturnType, ex.InnerException);
_logger.FailedInvokingHubMethod(hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage, connection, ex.InnerException.Message);
}
catch (Exception ex)
{
_logger.FailedInvokingHubMethod(invocationMessage.Target, ex);
await SendInvocationError(invocationMessage, connection, methodExecutor.MethodReturnType, ex);
_logger.FailedInvokingHubMethod(hubMethodInvocationMessage.Target, ex);
await SendInvocationError(hubMethodInvocationMessage, connection, ex.Message);
}
finally
{
@ -426,19 +419,37 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task SendInvocationError(InvocationMessage invocationMessage, HubConnectionContext connection, Type returnType, Exception ex)
private static async Task<object> ExecuteHubMethod(ObjectMethodExecutor methodExecutor, THub hub, object[] arguments)
{
if (!invocationMessage.NonBlocking)
// ReadableChannel is awaitable but we don't want to await it.
if (methodExecutor.IsMethodAsync && !IsChannel(methodExecutor.MethodReturnType, out _))
{
if (IsIObservable(returnType) || IsChannel(returnType, out _))
if (methodExecutor.MethodReturnType == typeof(Task))
{
await SendMessageAsync(connection, new StreamCompletionMessage(invocationMessage.InvocationId, ex.Message));
await (Task)methodExecutor.Execute(hub, arguments);
}
else
{
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message));
return await methodExecutor.ExecuteAsync(hub, arguments);
}
}
else
{
return methodExecutor.Execute(hub, arguments);
}
return null;
}
private async Task SendInvocationError(HubMethodInvocationMessage hubMethodInvocationMessage,
HubConnectionContext connection, string errorMessage)
{
if (hubMethodInvocationMessage.NonBlocking)
{
return;
}
await SendMessageAsync(connection, CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, errorMessage));
}
private void InitializeHub(THub hub, HubConnectionContext connection)
@ -448,7 +459,7 @@ namespace Microsoft.AspNetCore.SignalR
hub.Groups = _hubContext.Groups;
}
private bool IsChannel(Type type, out Type payloadType)
private static bool IsChannel(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ReadableChannel<>));
if (channelType == null)
@ -486,7 +497,7 @@ namespace Microsoft.AspNetCore.SignalR
}
finally
{
await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: error));
await SendMessageAsync(connection, new CompletionMessage(invocationId, error: error, result: null, hasResult: false));
if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts))
{
@ -495,34 +506,74 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private bool IsStreamed(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, Type resultType, out IAsyncEnumerator<object> enumerator)
private async Task<bool> ValidateInvocationMode(Type resultType, bool isStreamedInvocation,
HubMethodInvocationMessage hubMethodInvocationMessage, HubConnectionContext connection)
{
if (result == null)
var isStreamedResult = IsStreamed(resultType);
if (isStreamedResult && !isStreamedInvocation)
{
enumerator = null;
if (!hubMethodInvocationMessage.NonBlocking)
{
_logger.StreamingMethodCalledWithInvoke(hubMethodInvocationMessage);
await SendMessageAsync(connection, CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
$"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method in a non-streaming fashion."));
}
return false;
}
if (!isStreamedResult && isStreamedInvocation)
{
_logger.NonStreamingMethodCalledWithStream(hubMethodInvocationMessage);
await SendMessageAsync(connection, CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
$"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method in a streaming fashion."));
return false;
}
return true;
}
private static bool IsStreamed(Type resultType)
{
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
if (observableInterface != null)
{
enumerator = AsyncEnumeratorAdapters.FromObservable(result, observableInterface, CreateCancellation());
return true;
}
else if (IsChannel(resultType, out var payloadType))
if (IsChannel(resultType, out _))
{
enumerator = AsyncEnumeratorAdapters.FromChannel(result, payloadType, CreateCancellation());
return true;
}
else
return false;
}
private IAsyncEnumerator<object> GetStreamingEnumerator(HubConnectionContext connection, string invocationId, ObjectMethodExecutor methodExecutor, object result, Type resultType)
{
if (result != null)
{
// Not streamed
enumerator = null;
return false;
var observableInterface = IsIObservable(resultType) ?
resultType :
resultType.GetInterfaces().FirstOrDefault(IsIObservable);
if (observableInterface != null)
{
return AsyncEnumeratorAdapters.FromObservable(result, observableInterface, CreateCancellation());
}
if (IsChannel(resultType, out var payloadType))
{
return AsyncEnumeratorAdapters.FromChannel(result, payloadType, CreateCancellation());
}
}
_logger.InvalidReturnValueFromStreamingMethod(methodExecutor.MethodInfo.Name);
throw new InvalidOperationException($"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is null, does not implement the IObservable<> interface or is not a ReadableChannel<>.");
CancellationToken CreateCancellation()
{
var streamCts = new CancellationTokenSource();

View File

@ -58,6 +58,18 @@ namespace Microsoft.AspNetCore.SignalR.Core.Internal
private static readonly Action<ILogger, Exception> _abortFailed =
LoggerMessage.Define(LogLevel.Trace, new EventId(15, nameof(AbortFailed)), "Abort callback failed.");
private static readonly Action<ILogger, StreamInvocationMessage, Exception> _receivedStreamHubInvocation =
LoggerMessage.Define<StreamInvocationMessage>(LogLevel.Debug, new EventId(16, nameof(ReceivedStreamHubInvocation)), "Received stream hub invocation: {invocationMessage}.");
private static readonly Action<ILogger, HubMethodInvocationMessage, Exception> _streamingMethodCalledWithInvoke =
LoggerMessage.Define<HubMethodInvocationMessage>(LogLevel.Error, new EventId(17, nameof(StreamingMethodCalledWithInvoke)), "A streaming method was invoked in the non-streaming fashion : {invocationMessage}.");
private static readonly Action<ILogger, HubMethodInvocationMessage, Exception> _nonStreamingMethodCalledWithStream =
LoggerMessage.Define<HubMethodInvocationMessage>(LogLevel.Error, new EventId(18, nameof(NonStreamingMethodCalledWithStream)), "A non-streaming method was invoked in the streaming fashion : {invocationMessage}.");
private static readonly Action<ILogger, string, Exception> _invalidReturnValueFromStreamingMethod =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(19, nameof(InvalidReturnValueFromStreamingMethod)), "A streaming method returned a value that cannot be used to build enumerator {hubMethod}.");
public static void UsingHubProtocol(this ILogger logger, string hubProtocol)
{
_usingHubProtocol(logger, hubProtocol, null);
@ -137,5 +149,25 @@ namespace Microsoft.AspNetCore.SignalR.Core.Internal
{
_abortFailed(logger, exception);
}
public static void ReceivedStreamHubInvocation(this ILogger logger, StreamInvocationMessage invocationMessage)
{
_receivedStreamHubInvocation(logger, invocationMessage, null);
}
public static void StreamingMethodCalledWithInvoke(this ILogger logger, HubMethodInvocationMessage invocationMessage)
{
_streamingMethodCalledWithInvoke(logger, invocationMessage, null);
}
public static void NonStreamingMethodCalledWithStream(this ILogger logger, HubMethodInvocationMessage invocationMessage)
{
_nonStreamingMethodCalledWithStream(logger, invocationMessage, null);
}
public static void InvalidReturnValueFromStreamingMethod(this ILogger logger, string hubMethod)
{
_invalidReturnValueFromStreamingMethod(logger, hubMethod, null);
}
}
}

View File

@ -597,7 +597,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public class RedisExcludeClientsMessage : InvocationMessage
private class RedisExcludeClientsMessage : InvocationMessage
{
public IReadOnlyList<string> ExcludedIds;

View File

@ -66,7 +66,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public async Task<IList<HubMessage>> StreamAsync(string methodName, params object[] args)
{
var invocationId = await SendInvocationAsync(methodName, nonBlocking: false, args: args);
var invocationId = await SendStreamInvocationAsync(methodName, args);
var messages = new List<HubMessage>();
while (true)
@ -89,7 +89,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
messages.Add(message);
break;
case CompletionMessage _:
case StreamCompletionMessage _:
messages.Add(message);
return messages;
default:
@ -140,6 +139,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
argumentBindingException: null, arguments: args));
}
public Task<string> SendStreamInvocationAsync(string methodName, params object[] args)
{
var invocationId = GetInvocationId();
return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName,
argumentBindingException: null, arguments: args));
}
public async Task<string> SendHubMessageAsync(HubMessage message)
{
var payload = _protocolReaderWriter.WriteMessage(message);

View File

@ -389,7 +389,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory(Skip="https://github.com/aspnet/SignalR/issues/1053")]
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task ServerThrowsHubExceptionIfStreamingHubMethodCannotBeResolved(IHubProtocol hubProtocol, TransportType transportType, string hubPath)
{
@ -474,6 +474,88 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task ServerThrowsHubExceptionIfNonStreamMethodInvokedWithStreamAsync(IHubProtocol hubProtocol, TransportType transportType, string hubPath)
{
using (StartLog(out var loggerFactory))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + hubPath), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory);
try
{
await connection.StartAsync().OrTimeout();
var channel = await connection.StreamAsync<int>("HelloWorld").OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAllAsync()).OrTimeout();
Assert.Equal("The client attempted to invoke the non-streaming 'HelloWorld' method in a streaming fashion.", ex.Message);
}
catch (Exception ex)
{
loggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "Exception from test");
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task ServerThrowsHubExceptionIfStreamMethodInvokedWithInvoke(IHubProtocol hubProtocol, TransportType transportType, string hubPath)
{
using (StartLog(out var loggerFactory))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + hubPath), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory);
try
{
await connection.StartAsync().OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => connection.InvokeAsync("Stream", 3)).OrTimeout();
Assert.Equal("The client attempted to invoke the streaming 'Stream' method in a non-streaming fashion.", ex.Message);
}
catch (Exception ex)
{
loggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "Exception from test");
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task ServerThrowsHubExceptionIfBuildingAsyncEnumeratorIsNotPossible(IHubProtocol hubProtocol, TransportType transportType, string hubPath)
{
using (StartLog(out var loggerFactory))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + hubPath), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, hubProtocol, loggerFactory);
try
{
await connection.StartAsync().OrTimeout();
var channel = await connection.StreamAsync<int>("StreamBroken").OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAllAsync()).OrTimeout();
Assert.Equal("The value returned by the streaming method 'StreamBroken' is null, does not implement the IObservable<> interface or is not a ReadableChannel<>.", ex.Message);
}
catch (Exception ex)
{
loggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "Exception from test");
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
public static IEnumerable<object[]> HubProtocolsAndTransportsAndHubPaths
{
get

View File

@ -19,6 +19,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public ReadableChannel<int> StreamException() => TestHubMethodsImpl.StreamException();
public ReadableChannel<string> StreamBroken() => TestHubMethodsImpl.StreamBroken();
public async Task CallEcho(string message)
{
await Clients.Client(Context.ConnectionId).InvokeAsync("Echo", message);
@ -40,6 +42,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public ReadableChannel<int> StreamException() => TestHubMethodsImpl.StreamException();
public ReadableChannel<string> StreamBroken() => TestHubMethodsImpl.StreamBroken();
public async Task CallEcho(string message)
{
await Clients.Client(Context.ConnectionId).Echo(message);
@ -61,6 +65,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public ReadableChannel<int> StreamException() => TestHubMethodsImpl.StreamException();
public ReadableChannel<string> StreamBroken() => TestHubMethodsImpl.StreamBroken();
public async Task CallEcho(string message)
{
await Clients.Client(Context.ConnectionId).Echo(message);
@ -95,6 +101,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
throw new InvalidOperationException("Error occurred while streaming.");
}
public static ReadableChannel<string> StreamBroken() => null;
}
public interface ITestHub

View File

@ -104,10 +104,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
await connection.ReadSentTextMessageAsync().OrTimeout();
var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout();
Assert.Equal("{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"arguments\":[]}\u001e", invokeMessage);
Assert.Equal("{\"invocationId\":\"1\",\"type\":4,\"target\":\"Foo\",\"arguments\":[]}\u001e", invokeMessage);
// Complete the channel
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
await channel.Completion;
}
finally
@ -150,7 +150,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var channel = await hubConnection.StreamAsync<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
Assert.Empty(await channel.ReadAllAsync());
}
@ -183,52 +183,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
[Fact]
public async Task StreamFailsIfCompletionMessageIsNotStreamCompletionMessage()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var channel = await hubConnection.StreamAsync<string>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () => await channel.ReadAllAsync().OrTimeout());
Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsync' method.", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task StreamFailsIfErrorCompletionMessageIsNotStreamCompletionMessage()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var channel = await hubConnection.StreamAsync<string>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "error" }).OrTimeout();
var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () => await channel.ReadAllAsync().OrTimeout());
Assert.Equal("Streaming hub methods must be invoked with the 'HubConnection.StreamAsync' method.", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task InvokeFailsWithExceptionWhenCompletionWithErrorReceived()
{
@ -252,6 +206,29 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
[Fact]
public async Task StreamFailsIfCompletionMessageHasPayload()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var channel = await hubConnection.StreamAsync<string>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, result = "Oops" }).OrTimeout();
var ex = await Assert.ThrowsAsync<InvalidOperationException>(async () => await channel.ReadAllAsync().OrTimeout());
Assert.Equal("Server provided a result in a completion response to a streamed invocation.", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task StreamFailsWithExceptionWhenCompletionWithErrorReceived()
{
@ -263,7 +240,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var channel = await hubConnection.StreamAsync<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4, error = "An error occurred" }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3, error = "An error occurred" }).OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(async () => await channel.ReadAllAsync().OrTimeout());
Assert.Equal("An error occurred", ex.Message);
@ -298,52 +275,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
[Fact]
public async Task InvokeFailsWithErrorWhenStreamCompletionReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.InvokeAsync<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout();
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => invokeTask).OrTimeout();
Assert.Equal("Non-streaming hub methods must be invoked with the 'HubConnection.InvokeAsync' method.", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task InvokeFailsWithErrorWhenErrorStreamCompletionReceived()
{
var connection = new TestConnection();
var hubConnection = new HubConnection(connection, new JsonHubProtocol(), new LoggerFactory());
try
{
await hubConnection.StartAsync();
var invokeTask = hubConnection.InvokeAsync<int>("Foo");
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4, error = "error" }).OrTimeout();
var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => invokeTask).OrTimeout();
Assert.Equal("Non-streaming hub methods must be invoked with the 'HubConnection.InvokeAsync' method.", ex.Message);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task StreamYieldsItemsAsTheyArrive()
{
@ -358,7 +289,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "1" }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "2" }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 2, item = "3" }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 4 }).OrTimeout();
await connection.ReceiveJsonMessage(new { invocationId = "1", type = 3 }).OrTimeout();
var notifications = await channel.ReadAllAsync().OrTimeout();

View File

@ -46,6 +46,17 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"byteArrProp\":\"AQID\"}}" },
new object[] { CompletionMessage.WithResult("123", new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null,\"ByteArrProp\":\"AQID\"}}" },
new object[] { CompletionMessage.WithResult("123", new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":3,\"result\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null,\"byteArrProp\":\"AQID\"}}" },
new object[] { new StreamInvocationMessage("123", "Target", null, 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, 1, "Foo", 2.0f), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, true), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[true]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, new object[] { null }), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[null]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, new CustomObject()), false, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"ByteArrProp\":\"AQID\"}]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, new CustomObject()), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"byteArrProp\":\"AQID\"}]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, new CustomObject()), false, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00\",\"NullProp\":null,\"ByteArrProp\":\"AQID\"}]}" },
new object[] { new StreamInvocationMessage("123", "Target", null, new CustomObject()), true, NullValueHandling.Include, "{\"invocationId\":\"123\",\"type\":4,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00\",\"nullProp\":null,\"byteArrProp\":\"AQID\"}]}" },
new object[] { new CancelInvocationMessage("123"), true, NullValueHandling.Ignore, "{\"invocationId\":\"123\",\"type\":5}" }
};
[Theory]
@ -112,6 +123,12 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[InlineData("{'type':3,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
[InlineData("{'type':3,'invocationId':'42','error':[]}", "Expected 'error' to be of type String.")]
[InlineData("{'type':4}", "Missing required property 'invocationId'.")]
[InlineData("{'type':4,'invocationId':42}", "Expected 'invocationId' to be of type String.")]
[InlineData("{'type':4,'invocationId':'42','target':42}", "Expected 'target' to be of type String.")]
[InlineData("{'type':4,'invocationId':'42','target':'foo'}", "Missing required property 'arguments'.")]
[InlineData("{'type':4,'invocationId':'42','target':'foo','arguments':{}}", "Expected 'arguments' to be of type Array.")]
[InlineData("{'type':9}", "Unknown message type: 9")]
[InlineData("{'type':'foo'}", "Expected 'type' to be of type Integer.")]
@ -130,6 +147,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
[Theory]
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[]}", "Invocation provides 0 argument(s) but target expects 2.")]
[InlineData("{'type':1,'invocationId':'42','target':'foo','arguments':[ 'abc', 'xyz']}", "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.")]
[InlineData("{'type':4,'invocationId':'42','target':'foo','arguments':[]}", "Invocation provides 0 argument(s) but target expects 2.")]
[InlineData("{'type':4,'invocationId':'42','target':'foo','arguments':[ 'abc', 'xyz']}", "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.")]
public void ArgumentBindingErrors(string input, string expectedMessage)
{
input = Frame(input);
@ -137,7 +156,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(paramTypes: new[] { typeof(int), typeof(string) }, returnType: typeof(bool));
var protocol = new JsonHubProtocol();
protocol.TryParseMessages(Encoding.UTF8.GetBytes(input), binder, out var messages);
var ex = Assert.Throws<FormatException>(() => ((InvocationMessage)messages[0]).Arguments);
var ex = Assert.Throws<FormatException>(() => ((HubMethodInvocationMessage)messages[0]).Arguments);
Assert.Equal(expectedMessage, ex.Message);
}

View File

@ -35,9 +35,6 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new object[] { new[] { new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true) } },
new object[] { new[] { new CompletionMessage("xyz", error: null, result: new[] { new CustomObject(), new CustomObject() }, hasResult: true) } },
new object[] { new[] { new StreamCompletionMessage("xyz", error: null) } },
new object[] { new[] { new StreamCompletionMessage("xyz", error: "Error not found!") } },
new object[] { new[] { new StreamItemMessage("xyz", null) } },
new object[] { new[] { new StreamItemMessage("xyz", 42) } },
new object[] { new[] { new StreamItemMessage("xyz", 42.0f) } },
@ -46,6 +43,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new object[] { new[] { new StreamItemMessage("xyz", new CustomObject()) } },
new object[] { new[] { new StreamItemMessage("xyz", new[] { new CustomObject(), new CustomObject() }) } },
new object[] { new[] { new StreamInvocationMessage("xyz", "method", null) } },
new object[] { new[] { new StreamInvocationMessage("xyz", "method", null, new object[] { null }) } },
new object[] { new[] { new StreamInvocationMessage("xyz", "method", null, 42) } },
new object[] { new[] { new StreamInvocationMessage("xyz", "method", null, 42, "string") } },
new object[] { new[] { new StreamInvocationMessage("xyz", "method", null, 42, "string", new CustomObject()) } },
new object[] { new[] { new StreamInvocationMessage("xyz", "method", null, new[] { new CustomObject(), new CustomObject() }) } },
new object[] { new[] { new CancelInvocationMessage("xyz") } },
new object[]
@ -55,8 +59,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new InvocationMessage("xyz", /*nonBlocking*/ true, "method", null, 42, "string", new CustomObject()),
new CompletionMessage("xyz", error: null, result: 42, hasResult: true),
new StreamItemMessage("xyz", null),
new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true),
new StreamCompletionMessage("xyz", error: null),
new StreamInvocationMessage("xyz", "method", null, 42, "string", new CustomObject()),
new CompletionMessage("xyz", error: null, result: new CustomObject(), hasResult: true)
}
}
};
@ -110,6 +114,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new object[] { new byte[] { 0x93, 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x03 }, "Deserializing object of the `String` type for 'argument' failed." }, // non void result but result missing
new object[] { new byte[] { 0x93, 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x03, 0xa9 }, "Deserializing object of the `String` type for 'argument' failed." }, // result is cut
new object[] { new byte[] { 0x93, 0x03, 0xa3, 0x78, 0x79, 0x7a, 0x03, 0x00 }, "Deserializing object of the `String` type for 'argument' failed." }, // return type mismatch
// StreamInvocationMessage
new object[] { new byte[] { 0x95, 0x04 }, "Reading 'invocationId' as String failed." }, // invocationId missing
new object[] { new byte[] { 0x95, 0x04, 0xc2 }, "Reading 'invocationId' as String failed." }, // 0xc2 is Bool false
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a }, "Reading 'target' as String failed." }, // target missing
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0x00 }, "Reading 'target' as String failed." }, // 0x00 is Int
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1 }, "Reading 'target' as String failed." }, // string is cut
};
[Theory]
@ -132,12 +143,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
public static IEnumerable<object[]> ArgumentBindingErrors => new[]
{
// InvocationMessage
new object[] { new byte[] { 0x95, 0x01, 0xa3, 0x78, 0x79, 0x7a, 0xc2, 0xa1, 0x78 }, "Reading array length for 'arguments' failed." }, // array is missing
new object[] { new byte[] { 0x95, 0x01, 0xa3, 0x78, 0x79, 0x7a, 0xc2, 0xa1, 0x78, 0x00 }, "Reading array length for 'arguments' failed." }, // 0x00 is not array marker
new object[] { new byte[] { 0x95, 0x01, 0xa3, 0x78, 0x79, 0x7a, 0xc2, 0xa1, 0x78, 0x91 }, "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked." }, // array is missing elements
new object[] { new byte[] { 0x95, 0x01, 0xa3, 0x78, 0x79, 0x7a, 0xc2, 0xa1, 0x78, 0x91, 0xa2, 0x78 }, "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked." }, // array element is cut
new object[] { new byte[] { 0x95, 0x01, 0xa3, 0x78, 0x79, 0x7a, 0xc2, 0xa1, 0x78, 0x92, 0xa0, 0x00 }, "Invocation provides 2 argument(s) but target expects 1." }, // argument count does not match binder argument count
new object[] { new byte[] { 0x95, 0x01, 0xa3, 0x78, 0x79, 0x7a, 0xc2, 0xa1, 0x78, 0x91, 0x00 }, "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked." }, // argument type mismatch
// StreamInvocationMessage
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1, 0x78 }, "Reading array length for 'arguments' failed." }, // array is missing
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1, 0x78, 0x00 }, "Reading array length for 'arguments' failed." }, // 0x00 is not array marker
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1, 0x78, 0x91 }, "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked." }, // array is missing elements
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1, 0x78, 0x91, 0xa2, 0x78 }, "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked." }, // array element is cut
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1, 0x78, 0x92, 0xa0, 0x00 }, "Invocation provides 2 argument(s) but target expects 1." }, // argument count does not match binder argument count
new object[] { new byte[] { 0x95, 0x04, 0xa3, 0x78, 0x79, 0x7a, 0xa1, 0x78, 0x91, 0x00 }, "Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked." }, // argument type mismatch
};
[Theory]
@ -154,7 +174,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var binder = new TestBinder(new[] { typeof(string) }, typeof(string));
_hubProtocol.TryParseMessages(buffer, binder, out var messages);
var exception = Assert.Throws<FormatException>(() => ((InvocationMessage)messages[0]).Arguments);
var exception = Assert.Throws<FormatException>(() => ((HubMethodInvocationMessage)messages[0]).Arguments);
Assert.Equal(expectedExceptionMessage, exception.Message);
}
@ -227,6 +247,32 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x50, 0x72, 0x6f, 0x70, 0xa8, 0x53, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x52, 0x21
}
},
new object[]
{
new StreamInvocationMessage("0", "A", null, 1, new CustomObject()),
new byte[]
{
0x6b, 0x94, 0x04, 0xa1, 0x30, 0xa1, 0x41,
0x92, // argument array
0x01, // 1 - first argument
// 0x86 - a map of 6 items (properties)
0x86, 0xab, 0x42, 0x79, 0x74, 0x65, 0x41, 0x72, 0x72, 0x50, 0x72, 0x6f, 0x70, 0xc4, 0x03, 0x01,
0x02, 0x03, 0xac, 0x44, 0x61, 0x74, 0x65, 0x54, 0x69, 0x6d, 0x65, 0x50, 0x72, 0x6f, 0x70, 0xd3,
0x08, 0xd4, 0x80, 0x6d, 0xb2, 0x76, 0xc0, 0x00, 0xaa, 0x44, 0x6f, 0x75, 0x62, 0x6c, 0x65, 0x50,
0x72, 0x6f, 0x70, 0xcb, 0x40, 0x19, 0x21, 0xfb, 0x54, 0x42, 0xcf, 0x12, 0xa7, 0x49, 0x6e, 0x74,
0x50, 0x72, 0x6f, 0x70, 0x2a, 0xa8, 0x4e, 0x75, 0x6c, 0x6c, 0x50, 0x72, 0x6f, 0x70, 0xc0, 0xaa,
0x53, 0x74, 0x72, 0x69, 0x6e, 0x67, 0x50, 0x72, 0x6f, 0x70, 0xa8, 0x53, 0x69, 0x67, 0x6e, 0x61,
0x6c, 0x52, 0x21
}
},
new object[]
{
new CancelInvocationMessage("0"),
new byte[]
{
0x04, 0x92, 0x05, 0xa1, 0x30
}
}
};

View File

@ -17,6 +17,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
switch (expectedMessage)
{
case StreamInvocationMessage i:
_paramTypes = i.Arguments?.Select(a => a?.GetType() ?? typeof(object))?.ToArray();
break;
case InvocationMessage i:
_paramTypes = i.Arguments?.Select(a => a?.GetType() ?? typeof(object))?.ToArray();
break;

View File

@ -22,7 +22,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
}
return InvocationMessagesEqual(x, y) || StreamItemMessagesEqual(x, y) || CompletionMessagesEqual(x, y)
|| StreamCompletionMessagesEqual(x, y) || CancelInvocationMessagesEqual(x, y);
|| StreamInvocationMessagesEqual(x, y) || CancelInvocationMessagesEqual(x, y);
}
private bool CompletionMessagesEqual(HubMessage x, HubMessage y)
@ -33,12 +33,6 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
(Equals(left.Result, right.Result) || SequenceEqual(left.Result, right.Result));
}
private bool StreamCompletionMessagesEqual(HubMessage x, HubMessage y)
{
return x is StreamCompletionMessage left && y is StreamCompletionMessage right &&
string.Equals(left.Error, right.Error, StringComparison.Ordinal);
}
private bool StreamItemMessagesEqual(HubMessage x, HubMessage y)
{
return x is StreamItemMessage left && y is StreamItemMessage right &&
@ -53,6 +47,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
left.NonBlocking == right.NonBlocking;
}
private bool StreamInvocationMessagesEqual(HubMessage x, HubMessage y)
{
return x is StreamInvocationMessage left && y is StreamInvocationMessage right &&
string.Equals(left.Target, right.Target, StringComparison.Ordinal) &&
ArgumentListsEqual(left.Arguments, right.Arguments) &&
left.NonBlocking == right.NonBlocking;
}
private bool CancelInvocationMessagesEqual(HubMessage x, HubMessage y)
{
return x is CancelInvocationMessage && y is CancelInvocationMessage;

View File

@ -574,8 +574,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
private void AssertMessage(Channel<HubMessage> channel)
{
Assert.True(channel.In.TryRead(out var item));
var message = item as InvocationMessage;
Assert.NotNull(message);
var message = Assert.IsType<InvocationMessage>(item);
Assert.Equal("Hello", message.Target);
Assert.Single(message.Arguments);
Assert.Equal("World", (string)message.Arguments[0]);
@ -583,7 +582,6 @@ namespace Microsoft.AspNetCore.SignalR.Redis.Tests
private class MyHub : Hub
{
}
}
}

View File

@ -30,15 +30,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage;
Assert.NotNull(message);
var message = Assert.IsType<InvocationMessage>(item);
Assert.Equal("Hello", message.Target);
Assert.Single(message.Arguments);
Assert.Equal("World", (string)message.Arguments[0]);
Assert.True(output2.In.TryRead(out item));
message = item as InvocationMessage;
Assert.NotNull(message);
message = Assert.IsType<InvocationMessage>(item);
Assert.Equal("Hello", message.Target);
Assert.Single(message.Arguments);
Assert.Equal("World", (string)message.Arguments[0]);
@ -66,8 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await manager.InvokeAllAsync("Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage;
Assert.NotNull(message);
var message = Assert.IsType<InvocationMessage>(item);
Assert.Equal("Hello", message.Target);
Assert.Single(message.Arguments);
Assert.Equal("World", (string)message.Arguments[0]);
@ -97,8 +94,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await manager.InvokeGroupAsync("gunit", "Hello", new object[] { "World" }).OrTimeout();
Assert.True(output1.In.TryRead(out var item));
var message = item as InvocationMessage;
Assert.NotNull(message);
var message = Assert.IsType<InvocationMessage>(item);
Assert.Equal("Hello", message.Target);
Assert.Single(message.Arguments);
Assert.Equal("World", (string)message.Arguments[0]);
@ -121,8 +117,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await manager.InvokeConnectionAsync(connection.ConnectionId, "Hello", new object[] { "World" }).OrTimeout();
Assert.True(output.In.TryRead(out var item));
var message = item as InvocationMessage;
Assert.NotNull(message);
var message = Assert.IsType<InvocationMessage>(item);
Assert.Equal("Hello", message.Target);
Assert.Single(message.Arguments);
Assert.Equal("World", (string)message.Arguments[0]);

View File

@ -232,7 +232,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var endPointTask = endPoint.OnConnectedAsync(client.Connection);
var invocationId = await client.SendInvocationAsync(nameof(ObservableHub.Subscribe), nonBlocking: false).OrTimeout();
var invocationId = await client.SendStreamInvocationAsync(nameof(ObservableHub.Subscribe)).OrTimeout();
await waitForSubscribe.Task.OrTimeout();
@ -1015,7 +1015,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
AssertHubMessage(new StreamItemMessage(string.Empty, "1"), messages[1]);
AssertHubMessage(new StreamItemMessage(string.Empty, "2"), messages[2]);
AssertHubMessage(new StreamItemMessage(string.Empty, "3"), messages[3]);
AssertHubMessage(new StreamCompletionMessage(string.Empty, error: null), messages[4]);
AssertHubMessage(CompletionMessage.Empty(string.Empty), messages[4]);
client.Dispose();
@ -1035,11 +1035,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await client.Connected.OrTimeout();
var invocationId = await client.SendInvocationAsync(nameof(StreamingHub.BlockingStream)).OrTimeout();
var invocationId = Guid.NewGuid().ToString("N");
await client.SendHubMessageAsync(new StreamInvocationMessage(invocationId, nameof(StreamingHub.BlockingStream),
argumentBindingException: null));
// cancel the Streaming method
await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout();
var hubMessage = Assert.IsType<StreamCompletionMessage>(await client.ReadAsync().OrTimeout());
var hubMessage = Assert.IsType<CompletionMessage>(await client.ReadAsync().OrTimeout());
Assert.Equal(invocationId, hubMessage.InvocationId);
Assert.Null(hubMessage.Error);
@ -1221,7 +1224,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
await client.SendInvocationAsync(nameof(MethodHub.BroadcastItem)).OrTimeout();
var message = await client.ReadAsync().OrTimeout() as InvocationMessage;
var message = Assert.IsType<InvocationMessage>(await client.ReadAsync().OrTimeout());
var msgPackObject = Assert.IsType<MessagePackObject>(message.Arguments[0]);
// Custom serialization - object was serialized as an array and not a map
@ -1317,10 +1320,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Assert.Equal(expectedCompletion.HasResult, actualCompletion.HasResult);
Assert.Equal(expectedCompletion.Result, actualCompletion.Result);
break;
case StreamCompletionMessage expectedStreamCompletion:
var actualStreamCompletion = Assert.IsType<StreamCompletionMessage>(actual);
Assert.Equal(expectedStreamCompletion.Error, actualStreamCompletion.Error);
break;
case StreamItemMessage expectedStreamItem:
var actualStreamItem = Assert.IsType<StreamItemMessage>(actual);
Assert.Equal(expectedStreamItem.Item, actualStreamItem.Item);