From 364018238ad8f2846132c0b84c0733bdcd82c32a Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Thu, 10 Jan 2019 21:52:28 -0800 Subject: [PATCH] Refactor streaming from client to server (#4559) --- .../DefaultHubDispatcherBenchmark.cs | 76 ++++++-- .../src/MessagePackHubProtocol.ts | 53 +++--- .../tests/MessagePackHubProtocol.test.ts | 4 + .../clients/ts/signalr/src/HubConnection.ts | 67 ++++--- .../clients/ts/signalr/src/IHubProtocol.ts | 34 +--- src/SignalR/clients/ts/signalr/src/index.ts | 2 +- .../ts/signalr/tests/HubConnection.test.ts | 32 ++-- .../ts/signalr/tests/JsonHubProtocol.test.ts | 1 + src/SignalR/specs/HubProtocol.md | 77 ++++++-- .../HubConnection.cs | 44 +++-- .../Protocol/HubMethodInvocationMessage.cs | 74 +++++++- .../Protocol/HubProtocolConstants.cs | 10 - .../Protocol/StreamCompleteMessage.cs | 41 ---- .../Protocol/StreamDataMessage.cs | 33 ---- .../Protocol/StreamPlaceholder.cs | 25 --- .../Internal/DefaultHubDispatcher.Log.cs | 12 +- .../Internal/DefaultHubDispatcher.cs | 63 ++++--- .../Internal/HubMethodDescriptor.cs | 27 ++- .../StreamTracker.cs | 16 +- .../Protocol/MessagePackHubProtocol.cs | 147 +++++++-------- .../Protocol/NewtonsoftJsonHubProtocol.cs | 133 +++++-------- .../HubConnectionTests.cs | 2 +- .../Hubs.cs | 14 +- .../HubConnectionTests.cs | 20 +- .../Internal/Protocol/JsonHubProtocolTests.cs | 8 +- .../Protocol/MessagePackHubProtocolTests.cs | 76 ++++---- .../Internal/Protocol/TestBinder.cs | 3 - .../TestHubMessageEqualityComparer.cs | 67 ++++--- .../TestClient.cs | 22 ++- .../HubConnectionHandlerTestUtils/Hubs.cs | 6 + .../HubConnectionHandlerTests.cs | 177 +++++++++++++++--- 31 files changed, 772 insertions(+), 594 deletions(-) delete mode 100644 src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamCompleteMessage.cs delete mode 100644 src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamDataMessage.cs delete mode 100644 src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamPlaceholder.cs diff --git a/src/SignalR/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/src/SignalR/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index d7431ed28b..c28babcd19 100644 --- a/src/SignalR/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/src/SignalR/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -80,6 +80,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks public class NoErrorHubConnectionContext : HubConnectionContext { + public TaskCompletionSource ReceivedCompleted = new TaskCompletionSource(); + public NoErrorHubConnectionContext(ConnectionContext connectionContext, TimeSpan keepAliveInterval, ILoggerFactory loggerFactory) : base(connectionContext, keepAliveInterval, loggerFactory) { } @@ -88,6 +90,8 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { if (message is CompletionMessage completionMessage) { + ReceivedCompleted.TrySetResult(null); + if (!string.IsNullOrEmpty(completionMessage.Error)) { throw new Exception("Error invoking hub method: " + completionMessage.Error); @@ -163,72 +167,116 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks return channel.Reader; } + + public async Task UploadStream(ChannelReader channelReader) + { + while (await channelReader.WaitToReadAsync()) + { + while (channelReader.TryRead(out var item)) + { + } + } + } } [Benchmark] public Task Invocation() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "Invocation", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "Invocation", Array.Empty())); } [Benchmark] public Task InvocationAsync() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationAsync", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationAsync", Array.Empty())); } [Benchmark] public Task InvocationReturnValue() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnValue", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnValue", Array.Empty())); } [Benchmark] public Task InvocationReturnAsync() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnAsync", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationReturnAsync", Array.Empty())); } [Benchmark] public Task InvocationValueTaskAsync() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationValueTaskAsync", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationValueTaskAsync", Array.Empty())); } [Benchmark] public Task StreamChannelReader() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReader", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReader", Array.Empty())); } [Benchmark] public Task StreamChannelReaderAsync() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderAsync", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderAsync", Array.Empty())); } [Benchmark] public Task StreamChannelReaderValueTaskAsync() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderValueTaskAsync", Array.Empty())); + return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderValueTaskAsync", Array.Empty())); } [Benchmark] - public Task StreamChannelReaderCount_Zero() + public async Task StreamChannelReaderCount_Zero() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", new object[] { 0 })); + await _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", new object[] { 0 })); + + await (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted.Task; + (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted = new TaskCompletionSource(); } [Benchmark] - public Task StreamChannelReaderCount_One() + public async Task StreamChannelReaderCount_One() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", new object[] { 1 })); + await _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", new object[] { 1 })); + + await (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted.Task; + (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted = new TaskCompletionSource(); } [Benchmark] - public Task StreamChannelReaderCount_Thousand() + public async Task StreamChannelReaderCount_Thousand() { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", new object[] { 1000 })); + await _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", new object[] { 1000 })); + + await (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted.Task; + (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted = new TaskCompletionSource(); + } + + [Benchmark] + public async Task UploadStream_One() + { + await _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", nameof(TestHub.UploadStream), Array.Empty(), streamIds: new string[] { "1" })); + await _dispatcher.DispatchMessageAsync(_connectionContext, new StreamItemMessage("1", "test")); + await _dispatcher.DispatchMessageAsync(_connectionContext, CompletionMessage.Empty("1")); + + await (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted.Task; + (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted = new TaskCompletionSource(); + } + + [Benchmark] + public async Task UploadStream_Thousand() + { + await _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", nameof(TestHub.UploadStream), Array.Empty(), streamIds: new string[] { "1" })); + for (var i = 0; i < 1000; ++i) + { + await _dispatcher.DispatchMessageAsync(_connectionContext, new StreamItemMessage("1", "test")); + } + await _dispatcher.DispatchMessageAsync(_connectionContext, CompletionMessage.Empty("1")); + + await (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted.Task; + (_connectionContext as NoErrorHubConnectionContext).ReceivedCompleted = new TaskCompletionSource(); } } } diff --git a/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts b/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts index 78609f89b7..50d3c19d5b 100644 --- a/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts +++ b/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts @@ -4,7 +4,7 @@ import { Buffer } from "buffer"; import * as msgpack5 from "msgpack5"; -import { CompletionMessage, HubMessage, IHubProtocol, ILogger, InvocationMessage, LogLevel, MessageHeaders, MessageType, NullLogger, StreamCompleteMessage, StreamDataMessage, StreamInvocationMessage, StreamItemMessage, TransferFormat } from "@aspnet/signalr"; +import { CompletionMessage, HubMessage, IHubProtocol, ILogger, InvocationMessage, LogLevel, MessageHeaders, MessageType, NullLogger, StreamInvocationMessage, StreamItemMessage, TransferFormat } from "@aspnet/signalr"; import { BinaryMessageFormat } from "./BinaryMessageFormat"; import { isArrayBuffer } from "./Utils"; @@ -25,6 +25,10 @@ export class MessagePackHubProtocol implements IHubProtocol { /** The TransferFormat of the protocol. */ public readonly transferFormat: TransferFormat = TransferFormat.Binary; + private readonly errorResult = 1; + private readonly voidResult = 2; + private readonly nonVoidResult = 3; + /** Creates an array of HubMessage objects from the specified serialized representation. * * @param {ArrayBuffer | Buffer} input An ArrayBuffer containing the serialized representation. @@ -65,15 +69,12 @@ export class MessagePackHubProtocol implements IHubProtocol { return this.writeInvocation(message as InvocationMessage); case MessageType.StreamInvocation: return this.writeStreamInvocation(message as StreamInvocationMessage); - case MessageType.StreamData: - return this.writeStreamData(message as StreamDataMessage); case MessageType.StreamItem: + return this.writeStreamItem(message as StreamItemMessage); case MessageType.Completion: - throw new Error(`Writing messages of type '${message.type}' is not supported.`); + return this.writeCompletion(message as CompletionMessage); case MessageType.Ping: return BinaryMessageFormat.write(SERIALIZED_PING_MESSAGE); - case MessageType.StreamComplete: - return this.writeStreamComplete(message as StreamCompleteMessage); default: throw new Error("Invalid message type."); } @@ -147,6 +148,7 @@ export class MessagePackHubProtocol implements IHubProtocol { arguments: properties[4], headers, invocationId, + streamIds: [], target: properties[3] as string, type: MessageType.Invocation, }; @@ -154,6 +156,7 @@ export class MessagePackHubProtocol implements IHubProtocol { return { arguments: properties[4], headers, + streamIds: [], target: properties[3], type: MessageType.Invocation, }; @@ -181,13 +184,9 @@ export class MessagePackHubProtocol implements IHubProtocol { throw new Error("Invalid payload for Completion message."); } - const errorResult = 1; - const voidResult = 2; - const nonVoidResult = 3; - const resultKind = properties[3]; - if (resultKind !== voidResult && properties.length < 5) { + if (resultKind !== this.voidResult && properties.length < 5) { throw new Error("Invalid payload for Completion message."); } @@ -195,10 +194,10 @@ export class MessagePackHubProtocol implements IHubProtocol { let result: any; switch (resultKind) { - case errorResult: + case this.errorResult: error = properties[4]; break; - case nonVoidResult: + case this.nonVoidResult: result = properties[4]; break; } @@ -217,7 +216,7 @@ export class MessagePackHubProtocol implements IHubProtocol { private writeInvocation(invocationMessage: InvocationMessage): ArrayBuffer { const msgpack = msgpack5(); const payload = msgpack.encode([MessageType.Invocation, invocationMessage.headers || {}, invocationMessage.invocationId || null, - invocationMessage.target, invocationMessage.arguments]); + invocationMessage.target, invocationMessage.arguments, invocationMessage.streamIds]); return BinaryMessageFormat.write(payload.slice()); } @@ -225,23 +224,35 @@ export class MessagePackHubProtocol implements IHubProtocol { private writeStreamInvocation(streamInvocationMessage: StreamInvocationMessage): ArrayBuffer { const msgpack = msgpack5(); const payload = msgpack.encode([MessageType.StreamInvocation, streamInvocationMessage.headers || {}, streamInvocationMessage.invocationId, - streamInvocationMessage.target, streamInvocationMessage.arguments]); + streamInvocationMessage.target, streamInvocationMessage.arguments, streamInvocationMessage.streamIds]); return BinaryMessageFormat.write(payload.slice()); } - private writeStreamData(streamDataMessage: StreamDataMessage): ArrayBuffer { + private writeStreamItem(streamItemMessage: StreamItemMessage): ArrayBuffer { const msgpack = msgpack5(); - const payload = msgpack.encode([MessageType.StreamData, streamDataMessage.streamId, - streamDataMessage.item]); + const payload = msgpack.encode([MessageType.StreamItem, streamItemMessage.headers || {}, streamItemMessage.invocationId, + streamItemMessage.item]); return BinaryMessageFormat.write(payload.slice()); } - private writeStreamComplete(streamCompleteMessage: StreamCompleteMessage): ArrayBuffer { + private writeCompletion(completionMessage: CompletionMessage): ArrayBuffer { const msgpack = msgpack5(); - const payload = msgpack.encode([MessageType.StreamComplete, streamCompleteMessage.streamId, - streamCompleteMessage.error || null]); + const resultKind = completionMessage.error ? this.errorResult : completionMessage.result ? this.nonVoidResult : this.voidResult; + + let payload: any; + switch (resultKind) { + case this.errorResult: + payload = msgpack.encode([MessageType.Completion, completionMessage.headers || {}, completionMessage.invocationId, resultKind, completionMessage.error]); + break; + case this.voidResult: + payload = msgpack.encode([MessageType.Completion, completionMessage.headers || {}, completionMessage.invocationId, resultKind]); + break; + case this.nonVoidResult: + payload = msgpack.encode([MessageType.Completion, completionMessage.headers || {}, completionMessage.invocationId, resultKind, completionMessage.result]); + break; + } return BinaryMessageFormat.write(payload.slice()); } diff --git a/src/SignalR/clients/ts/signalr-protocol-msgpack/tests/MessagePackHubProtocol.test.ts b/src/SignalR/clients/ts/signalr-protocol-msgpack/tests/MessagePackHubProtocol.test.ts index e22dc33fe5..e406614fa0 100644 --- a/src/SignalR/clients/ts/signalr-protocol-msgpack/tests/MessagePackHubProtocol.test.ts +++ b/src/SignalR/clients/ts/signalr-protocol-msgpack/tests/MessagePackHubProtocol.test.ts @@ -9,6 +9,7 @@ describe("MessagePackHubProtocol", () => { const invocation = { arguments: [42, true, "test", ["x1", "y2"], null], headers: {}, + streamIds: [], target: "myMethod", type: MessageType.Invocation, } as InvocationMessage; @@ -22,6 +23,7 @@ describe("MessagePackHubProtocol", () => { const invocation = { arguments: [new Date(Date.UTC(2018, 1, 1, 12, 34, 56))], headers: {}, + streamIds: [], target: "mymethod", type: MessageType.Invocation, } as InvocationMessage; @@ -37,6 +39,7 @@ describe("MessagePackHubProtocol", () => { headers: { foo: "bar", }, + streamIds: [], target: "myMethod", type: MessageType.Invocation, } as InvocationMessage; @@ -51,6 +54,7 @@ describe("MessagePackHubProtocol", () => { arguments: [42, true, "test", ["x1", "y2"], null], headers: {}, invocationId: "123", + streamIds: [], target: "myMethod", type: MessageType.Invocation, } as InvocationMessage; diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index 05d7731beb..9ad3e7032b 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -3,7 +3,7 @@ import { HandshakeProtocol, HandshakeRequestMessage, HandshakeResponseMessage } from "./HandshakeProtocol"; import { IConnection } from "./IConnection"; -import { CancelInvocationMessage, CompletionMessage, IHubProtocol, InvocationMessage, MessageType, StreamCompleteMessage, StreamDataMessage, StreamInvocationMessage, StreamItemMessage } from "./IHubProtocol"; +import { CancelInvocationMessage, CompletionMessage, IHubProtocol, InvocationMessage, MessageType, StreamInvocationMessage, StreamItemMessage } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { IStreamResult } from "./Stream"; import { Subject } from "./Subject"; @@ -30,7 +30,6 @@ export class HubConnection { private callbacks: { [invocationId: string]: (invocationEvent: StreamItemMessage | CompletionMessage | null, error?: Error) => void }; private methods: { [name: string]: Array<(...args: any[]) => void> }; private invocationId: number; - private streamId: number; private closedCallbacks: Array<(error?: Error) => void>; private receivedHandshakeResponse: boolean; private handshakeResolver!: (value?: PromiseLike<{}>) => void; @@ -86,7 +85,6 @@ export class HubConnection { this.methods = {}; this.closedCallbacks = []; this.invocationId = 0; - this.streamId = 0; this.receivedHandshakeResponse = false; this.connectionState = HubConnectionState.Disconnected; @@ -155,8 +153,8 @@ export class HubConnection { * @returns {IStreamResult} An object that yields results from the server as they are received. */ public stream(methodName: string, ...args: any[]): IStreamResult { - const streams = this.replaceStreamingParams(args); - const invocationDescriptor = this.createStreamInvocation(methodName, args); + const [streams, streamIds] = this.replaceStreamingParams(args); + const invocationDescriptor = this.createStreamInvocation(methodName, args, streamIds); const subject = new Subject(); subject.cancelCallback = () => { @@ -219,19 +217,14 @@ export class HubConnection { * @returns {Promise} A Promise that resolves when the invocation has been successfully sent, or rejects with an error. */ public send(methodName: string, ...args: any[]): Promise { - const streams = this.replaceStreamingParams(args); - const sendPromise = this.sendWithProtocol(this.createInvocation(methodName, args, true)); + const [streams, streamIds] = this.replaceStreamingParams(args); + const sendPromise = this.sendWithProtocol(this.createInvocation(methodName, args, true, streamIds)); this.launchStreams(streams, sendPromise); return sendPromise; } - private nextStreamId(): string { - this.streamId += 1; - return this.streamId.toString(); - } - /** Invokes a hub method on the server using the specified name and arguments. * * The Promise returned by this method resolves when the server indicates it has finished invoking the method. When the promise @@ -244,8 +237,8 @@ export class HubConnection { * @returns {Promise} A Promise that resolves with the result of the server method (if any), or rejects with an error. */ public invoke(methodName: string, ...args: any[]): Promise { - const streams = this.replaceStreamingParams(args); - const invocationDescriptor = this.createInvocation(methodName, args, false); + const [streams, streamIds] = this.replaceStreamingParams(args); + const invocationDescriptor = this.createInvocation(methodName, args, false, streamIds); const p = new Promise((resolve, reject) => { // invocationId will always have a value for a non-blocking invocation @@ -519,10 +512,11 @@ export class HubConnection { } } - private createInvocation(methodName: string, args: any[], nonblocking: boolean): InvocationMessage { + private createInvocation(methodName: string, args: any[], nonblocking: boolean, streamIds: string[]): InvocationMessage { if (nonblocking) { return { arguments: args, + streamIds, target: methodName, type: MessageType.Invocation, }; @@ -533,6 +527,7 @@ export class HubConnection { return { arguments: args, invocationId: invocationId.toString(), + streamIds, target: methodName, type: MessageType.Invocation, }; @@ -554,7 +549,7 @@ export class HubConnection { for (const streamId in streams) { streams[streamId].subscribe({ complete: () => { - promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamCompleteMessage(streamId))); + promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createCompletionMessage(streamId))); }, error: (err) => { let message: string; @@ -566,31 +561,33 @@ export class HubConnection { message = "Unknown error"; } - promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamCompleteMessage(streamId, message))); + promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createCompletionMessage(streamId, message))); }, next: (item) => { - promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamDataMessage(streamId, item))); + promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamItemMessage(streamId, item))); }, }); } } - private replaceStreamingParams(args: any[]): Array> { + private replaceStreamingParams(args: any[]): [Array>, string[]] { const streams: Array> = []; + const streamIds: string[] = []; for (let i = 0; i < args.length; i++) { const argument = args[i]; if (this.isObservable(argument)) { - const streamId = this.nextStreamId(); + const streamId = this.invocationId; + this.invocationId++; // Store the stream for later use streams[streamId] = argument; - // Replace the stream with a placeholder - // Use capitalized StreamId because the MessagePack-CSharp library expects exact case for arguments - // Json allows case-insensitive arguments by default - args[i] = { StreamId: streamId }; + streamIds.push(streamId.toString()); + + // remove stream from args + args.splice(i, 1); } } - return streams; + return [streams, streamIds]; } private isObservable(arg: any): arg is IStreamResult { @@ -598,13 +595,14 @@ export class HubConnection { return arg.subscribe && typeof arg.subscribe === "function"; } - private createStreamInvocation(methodName: string, args: any[]): StreamInvocationMessage { + private createStreamInvocation(methodName: string, args: any[], streamIds: string[]): StreamInvocationMessage { const invocationId = this.invocationId; this.invocationId++; return { arguments: args, invocationId: invocationId.toString(), + streamIds, target: methodName, type: MessageType.StreamInvocation, }; @@ -617,26 +615,27 @@ export class HubConnection { }; } - private createStreamDataMessage(id: string, item: any): StreamDataMessage { + private createStreamItemMessage(id: string, item: any): StreamItemMessage { return { + invocationId: id, item, - streamId: id, - type: MessageType.StreamData, + type: MessageType.StreamItem, }; } - private createStreamCompleteMessage(id: string, error?: string): StreamCompleteMessage { + private createCompletionMessage(id: string, error?: any, result?: any): CompletionMessage { if (error) { return { error, - streamId: id, - type: MessageType.StreamComplete, + invocationId: id, + type: MessageType.Completion, }; } return { - streamId: id, - type: MessageType.StreamComplete, + invocationId: id, + result, + type: MessageType.Completion, }; } } diff --git a/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts b/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts index d5c39e5e90..7ef743c3d8 100644 --- a/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts +++ b/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts @@ -20,10 +20,6 @@ export enum MessageType { Ping = 6, /** Indicates the message is a Close message and implements the {@link @aspnet/signalr.CloseMessage} interface. */ Close = 7, - /** Indicates the message is a StreamComplete message and implements the {@link StreamCompleteMessage} interface */ - StreamComplete = 8, - /** Indicates the message is a ParamterStreaming message and implements the {@link StreamDataMessage} interface */ - StreamData = 9, } /** Defines a dictionary of string keys and string values representing headers attached to a Hub message. */ @@ -40,9 +36,7 @@ export type HubMessage = CompletionMessage | CancelInvocationMessage | PingMessage | - CloseMessage | - StreamCompleteMessage | - StreamDataMessage; + CloseMessage; /** Defines properties common to all Hub messages. */ export interface HubMessageBase { @@ -70,6 +64,8 @@ export interface InvocationMessage extends HubInvocationMessage { readonly target: string; /** The target method arguments. */ readonly arguments: any[]; + /** The target methods stream IDs. */ + readonly streamIds: string[]; } /** A hub message representing a streaming invocation. */ @@ -83,6 +79,8 @@ export interface StreamInvocationMessage extends HubInvocationMessage { readonly target: string; /** The target method arguments. */ readonly arguments: any[]; + /** The target methods stream IDs. */ + readonly streamIds: string[]; } /** A hub message representing a single item produced as part of a result stream. */ @@ -97,18 +95,6 @@ export interface StreamItemMessage extends HubInvocationMessage { readonly item?: any; } -/** A hub message representing a single stream item, transferred through a streaming parameter. */ -export interface StreamDataMessage extends HubMessageBase { - /** @inheritDoc */ - readonly type: MessageType.StreamData; - - /** The streamId. */ - readonly streamId: string; - - /** The item produced by the client. */ - readonly item?: any; -} - /** A hub message representing the result of an invocation. */ export interface CompletionMessage extends HubInvocationMessage { /** @inheritDoc */ @@ -155,16 +141,6 @@ export interface CancelInvocationMessage extends HubInvocationMessage { readonly invocationId: string; } -/** A hub message sent to indicate the end of stream items for a streaming parameter. */ -export interface StreamCompleteMessage extends HubMessageBase { - /** @inheritDoc */ - readonly type: MessageType.StreamComplete; - /** The stream ID of the stream to be completed. */ - readonly streamId: string; - /** The error that triggered completion, if any. */ - readonly error?: string; -} - /** A protocol abstraction for communicating with SignalR Hubs. */ export interface IHubProtocol { /** The name of the protocol. This is used by SignalR to resolve the protocol between the client and server. */ diff --git a/src/SignalR/clients/ts/signalr/src/index.ts b/src/SignalR/clients/ts/signalr/src/index.ts index f770e9049e..e1479e72e5 100644 --- a/src/SignalR/clients/ts/signalr/src/index.ts +++ b/src/SignalR/clients/ts/signalr/src/index.ts @@ -14,7 +14,7 @@ export { IHttpConnectionOptions } from "./IHttpConnectionOptions"; export { HubConnection, HubConnectionState } from "./HubConnection"; export { HubConnectionBuilder } from "./HubConnectionBuilder"; export { MessageType, MessageHeaders, HubMessage, HubMessageBase, HubInvocationMessage, InvocationMessage, StreamInvocationMessage, StreamItemMessage, CompletionMessage, - PingMessage, CloseMessage, CancelInvocationMessage, IHubProtocol, StreamDataMessage, StreamCompleteMessage } from "./IHubProtocol"; + PingMessage, CloseMessage, CancelInvocationMessage, IHubProtocol } from "./IHubProtocol"; export { ILogger, LogLevel } from "./ILogger"; export { HttpTransportType, TransferFormat, ITransport } from "./ITransport"; export { IStreamSubscriber, IStreamResult, ISubscription } from "./Stream"; diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index 6f08420c10..c70d51baad 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -113,6 +113,7 @@ describe("HubConnection", () => { "arg", 42, ], + streamIds: [], target: "testMethod", type: MessageType.Invocation, }); @@ -144,6 +145,7 @@ describe("HubConnection", () => { 42, ], invocationId: connection.lastInvocationId, + streamIds: [], target: "testMethod", type: MessageType.Invocation, }); @@ -342,8 +344,9 @@ describe("HubConnection", () => { const invokePromise = hubConnection.invoke("testMethod", "arg", subject); expect(JSON.parse(connection.sentData[1])).toEqual({ - arguments: ["arg", {StreamId: "1"}], - invocationId: "0", + arguments: ["arg"], + invocationId: "1", + streamIds: ["0"], target: "testMethod", type: MessageType.Invocation, }); @@ -353,12 +356,12 @@ describe("HubConnection", () => { setTimeout(resolve, 50); }); expect(JSON.parse(connection.sentData[2])).toEqual({ + invocationId: "0", item: "item numero uno", - streamId: "1", - type: MessageType.StreamData, + type: MessageType.StreamItem, }); - connection.receive({ type: MessageType.Completion, invocationId: connection.lastInvocationId, result: "foo" }); + connection.receive({ type: MessageType.Completion, invocationId: "1", result: "foo" }); expect(await invokePromise).toBe("foo"); } finally { @@ -378,7 +381,8 @@ describe("HubConnection", () => { await hubConnection.send("testMethod", "arg", subject); expect(JSON.parse(connection.sentData[1])).toEqual({ - arguments: ["arg", { StreamId: "1" }], + arguments: ["arg"], + streamIds: ["0"], target: "testMethod", type: MessageType.Invocation, }); @@ -388,9 +392,9 @@ describe("HubConnection", () => { setTimeout(resolve, 50); }); expect(JSON.parse(connection.sentData[2])).toEqual({ + invocationId: "0", item: "item numero uno", - streamId: "1", - type: MessageType.StreamData, + type: MessageType.StreamItem, }); } finally { await hubConnection.stop(); @@ -420,8 +424,9 @@ describe("HubConnection", () => { }); expect(JSON.parse(connection.sentData[1])).toEqual({ - arguments: ["arg", { StreamId: "1" }], - invocationId: "0", + arguments: ["arg"], + invocationId: "1", + streamIds: ["0"], target: "testMethod", type: MessageType.StreamInvocation, }); @@ -431,12 +436,12 @@ describe("HubConnection", () => { setTimeout(resolve, 50); }); expect(JSON.parse(connection.sentData[2])).toEqual({ + invocationId: "0", item: "item numero uno", - streamId: "1", - type: MessageType.StreamData, + type: MessageType.StreamItem, }); - connection.receive({ type: MessageType.StreamItem, invocationId: connection.lastInvocationId, item: "foo" }); + connection.receive({ type: MessageType.StreamItem, invocationId: "1", item: "foo" }); expect(streamItem).toEqual("foo"); expect(streamError).toBe(null); @@ -891,6 +896,7 @@ describe("HubConnection", () => { 42, ], invocationId: connection.lastInvocationId, + streamIds: [], target: "testStream", type: MessageType.StreamInvocation, }); diff --git a/src/SignalR/clients/ts/signalr/tests/JsonHubProtocol.test.ts b/src/SignalR/clients/ts/signalr/tests/JsonHubProtocol.test.ts index 2c3dc16828..e5fd7aaab5 100644 --- a/src/SignalR/clients/ts/signalr/tests/JsonHubProtocol.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/JsonHubProtocol.test.ts @@ -47,6 +47,7 @@ describe("JsonHubProtocol", () => { headers: { foo: "bar", }, + streamIds: [], target: "myMethod", type: MessageType.Invocation, } as InvocationMessage; diff --git a/src/SignalR/specs/HubProtocol.md b/src/SignalR/specs/HubProtocol.md index 448a938740..d0a94886e4 100644 --- a/src/SignalR/specs/HubProtocol.md +++ b/src/SignalR/specs/HubProtocol.md @@ -27,8 +27,8 @@ In the SignalR protocol, the following types of messages can be sent: | `Close` | Callee, Caller | Sent by the server when a connection is closed. Contains an error if the connection was closed because of an error. | | `Invocation` | Caller | Indicates a request to invoke a particular method (the Target) with provided Arguments on the remote endpoint. | | `StreamInvocation` | Caller | Indicates a request to invoke a streaming method (the Target) with provided Arguments on the remote endpoint. | -| `StreamItem` | Callee | Indicates individual items of streamed response data from a previous `StreamInvocation` message. | -| `Completion` | Callee | Indicates a previous `Invocation` or `StreamInvocation` has completed. Contains an error if the invocation concluded with an error or the result of a non-streaming method invocation. The result will be absent for `void` methods. In case of streaming invocations no further `StreamItem` messages will be received. | +| `StreamItem` | Callee, Caller | Indicates individual items of streamed response data from a previous `StreamInvocation` message or streamed uploads from an invocation with streamIds. | +| `Completion` | Callee, Caller | Indicates a previous `Invocation` or `StreamInvocation` has completed or a stream in an `Invocation` or `StreamInvocation` has completed. Contains an error if the invocation concluded with an error or the result of a non-streaming method invocation. The result will be absent for `void` methods. In case of streaming invocations no further `StreamItem` messages will be received. | | `CancelInvocation` | Caller | Sent by the client to cancel a streaming invocation on the server. | | `Ping` | Caller, Callee | Sent by either party to check if the connection is active. | @@ -101,6 +101,10 @@ On the Callee side, it is up to the Callee's Binder to determine if a method cal On the Caller side, the user code which performs the invocation indicates how it would like to receive the results and it is up the Caller's Binder to handle the result. If the Caller expects only a single result, but multiple results are returned, or if the caller expects multiple results but only one result is returned, the Caller's Binder should yield an error. If the Caller wants to stop receiving `StreamItem` messages before the Callee sends a `Completion` message, the Caller can send a `CancelInvocation` message with the same `Invocation ID` used for the `StreamInvocation` message that started the stream. When the Callee receives a `CancelInvocation` message it will stop sending `StreamItem` messages and will send a `Completion` message. The Caller is free to ignore any `StreamItem` messages as well as the `Completion` message after sending `CancelInvocation`. +## Upload streaming + +The Caller can send streaming data to the Callee, they can begin such a process by making an `Invocation` or `StreamInvocation` and adding a "StreamIds" property with an array of IDs that will represent the stream(s) associated with the invocation. The IDs must be unique from any other stream IDs used by the same Caller. The Caller then sends `StreamItem` messages with the "InvocationId" property set to the ID for the stream they are sending over. The Caller can end the stream by sending a `Completion` message with the ID of the stream they are completing. If the Callee sends a `Completion` the Caller should stop sending `StreamItem` and `Completion` messages, and the Callee is free to ignore any `StreamItem` and `Completion` messages that are sent after the invocation has completed. + ## Completion and results An Invocation is only considered completed when the `Completion` message is received. Receiving **any** message using the same `Invocation ID` after a `Completion` message has been received for that invocation is considered a protocol error and the recipient may immediately terminate the connection. @@ -180,6 +184,20 @@ public void NonBlocking(string caller) { _callers.Add(caller); } + +public async Task AddStream(ChannelReader stream) +{ + int sum = 0; + while (await stream.WaitToReadAsync()) + { + while (stream.TryRead(out var item)) + { + sum += item; + } + } + + return sum; +} ``` In each of the below examples, lines starting `C->S` indicate messages sent from the Caller ("Client") to the Callee ("Server"), and lines starting `S->C` indicate messages sent from the Callee ("Server") back to the Caller ("Client"). Message syntax is just a pseudo-code and is not intended to match any particular encoding. @@ -269,6 +287,17 @@ S->C: Completion { Id = 42 } // This can be ignored C->S: Invocation { Target = "NonBlocking", Arguments = [ "foo" ] } ``` +### Stream from Client to Server (`AddStream` example above) + +``` +C->S: Invocation { Id = 42, Target = "AddStream", Arguments = [ ], StreamIds = [ 1 ] } +C->S: StreamItem { Id = 1, Item = 1 } +C->S: StreamItem { Id = 1, Item = 2 } +C->S: StreamItem { Id = 1, Item = 3 } +C->S: Completion { Id = 1 } +S->C: Completion { Id = 42, Result = 6 } +``` + ### Ping ``` @@ -289,6 +318,7 @@ An `Invocation` message is a JSON object with the following properties: * `invocationId` - An optional `String` encoding the `Invocation ID` for a message. * `target` - A `String` encoding the `Target` name, as expected by the Callee's Binder * `arguments` - An `Array` containing arguments to apply to the method referred to in Target. This is a sequence of JSON `Token`s, encoded as indicated below in the "JSON Payload Encoding" section +* `streamIds` - An optional `Array` of strings representing unique ids for streams coming from the Caller to the Callee and being consumed by the method referred to in Target. Example: @@ -316,6 +346,22 @@ Example (Non-Blocking): } ``` +Example (Invocation with stream from Caller): + +```json +{ + "type": 1, + "invocationId": "123", + "target": "Send", + "arguments": [ + 42 + ], + "streamIds": [ + "1" + ] +} +``` + ### StreamInvocation Message Encoding A `StreamInvocation` message is a JSON object with the following properties: @@ -324,6 +370,7 @@ A `StreamInvocation` message is a JSON object with the following properties: * `invocationId` - A `String` encoding the `Invocation ID` for a message. * `target` - A `String` encoding the `Target` name, as expected by the Callee's Binder. * `arguments` - An `Array` containing arguments to apply to the method referred to in Target. This is a sequence of JSON `Token`s, encoded as indicated below in the "JSON Payload Encoding" section. +* `streamIds` - An optional `Array` of strings representing unique ids for streams coming from the Caller to the Callee and being consumed by the method referred to in Target. Example: @@ -490,7 +537,7 @@ MessagePack uses different formats to encode values. Refer to the [MsgPack forma `Invocation` messages have the following structure: ``` -[1, Headers, InvocationId, NonBlocking, Target, [Arguments]] +[1, Headers, InvocationId, NonBlocking, Target, [Arguments], [StreamIds]] ``` * `1` - Message Type - `1` indicates this is an `Invocation` message. @@ -500,18 +547,19 @@ MessagePack uses different formats to encode values. Refer to the [MsgPack forma * A `String` encoding the Invocation ID for the message. * Target - A `String` encoding the Target name, as expected by the Callee's Binder. * Arguments - An Array containing arguments to apply to the method referred to in Target. +* StreamIds - An `Array` of strings representing unique ids for streams coming from the Caller to the Callee and being consumed by the method referred to in Target. #### Example: The following payload ``` -0x94 0x01 0x80 0xa3 0x78 0x79 0x7a 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a +0x96 0x01 0x80 0xa3 0x78 0x79 0x7a 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a 0x90 ``` is decoded as follows: -* `0x95` - 5-element array +* `0x96` - 6-element array * `0x01` - `1` (Message Type - `Invocation` message) * `0x80` - Map of length 0 (Headers) * `0xa3` - string of length 3 (InvocationId) @@ -527,17 +575,18 @@ is decoded as follows: * `0x64` - `d` * `0x91` - 1-element array (Arguments) * `0x2a` - `42` (Argument value) +* `0x90` - 0-element array (StreamIds) #### Non-Blocking Example: The following payload ``` -0x95 0x01 0x80 0xc0 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a +0x96 0x01 0x80 0xc0 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a 0x90 ``` is decoded as follows: -* `0x95` - 5-element array +* `0x96` - 6-element array * `0x01` - `1` (Message Type - `Invocation` message) * `0x80` - Map of length 0 (Headers) * `0xc0` - `nil` (Invocation ID) @@ -550,13 +599,14 @@ is decoded as follows: * `0x64` - `d` * `0x91` - 1-element array (Arguments) * `0x2a` - `42` (Argument value) +* `0x90` - 0-element array (StreamIds) ### StreamInvocation Message Encoding `StreamInvocation` messages have the following structure: ``` -[4, Headers, InvocationId, Target, [Arguments]] +[4, Headers, InvocationId, Target, [Arguments], [StreamIds]] ``` * `4` - Message Type - `4` indicates this is a `StreamInvocation` message. @@ -564,18 +614,19 @@ is decoded as follows: * InvocationId - A `String` encoding the Invocation ID for the message. * Target - A `String` encoding the Target name, as expected by the Callee's Binder. * Arguments - An Array containing arguments to apply to the method referred to in Target. +* StreamIds - An `Array` of strings representing unique ids for streams coming from the Caller to the Callee and being consumed by the method referred to in Target. Example: The following payload ``` -0x95 0x04 0x80 0xa3 0x78 0x79 0x7a 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a +0x96 0x04 0x80 0xa3 0x78 0x79 0x7a 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a 0x90 ``` is decoded as follows: -* `0x95` - 5-element array +* `0x96` - 6-element array * `0x04` - `4` (Message Type - `StreamInvocation` message) * `0x80` - Map of length 0 (Headers) * `0xa3` - string of length 3 (InvocationId) @@ -591,6 +642,7 @@ is decoded as follows: * `0x64` - `d` * `0x91` - 1-element array (Arguments) * `0x2a` - `42` (Argument value) +* `0x90` - 0-element array (StreamIds) ### StreamItem Message Encoding @@ -795,12 +847,12 @@ Headers are not valid in a Ping message. The Ping message is **always exactly en Below shows an example encoding of a message containing headers: ``` -0x95 0x01 0x82 0xa1 0x78 0xa1 0x79 0xa1 0x7a 0xa1 0x7a 0xa3 0x78 0x79 0x7a 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a +0x96 0x01 0x82 0xa1 0x78 0xa1 0x79 0xa1 0x7a 0xa1 0x7a 0xa3 0x78 0x79 0x7a 0xa6 0x6d 0x65 0x74 0x68 0x6f 0x64 0x91 0x2a 0x90 ``` and is decoded as follows: -* `0x95` - 5-element array +* `0x96` - 6-element array * `0x01` - `1` (Message Type - `Invocation` message) * `0x82` - Map of length 2 * `0xa1` - string of length 1 (Key) @@ -824,6 +876,7 @@ and is decoded as follows: * `0x64` - `d` * `0x91` - 1-element array (Arguments) * `0x2a` - `42` (Argument value) +* `0x90` - 0-element array (StreamIds) and interpreted as an Invocation message with headers: `'x' = 'y'` and `'z' = 'z'`. diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index 46ac5c4ca6..ce525f204f 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -420,7 +420,7 @@ namespace Microsoft.AspNetCore.SignalR.Client irq.Dispose(); } - var readers = PackageStreamingParams(args); + var readers = PackageStreamingParams(ref args, out var streamIds); CheckDisposed(); await WaitConnectionLockAsync(); @@ -434,7 +434,7 @@ namespace Microsoft.AspNetCore.SignalR.Client // I just want an excuse to use 'irq' as a variable name... var irq = InvocationRequest.Stream(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out channel); - await InvokeStreamCore(methodName, irq, args, cancellationToken); + await InvokeStreamCore(methodName, irq, args, streamIds?.ToArray(), cancellationToken); if (cancellationToken.CanBeCanceled) { @@ -451,10 +451,12 @@ namespace Microsoft.AspNetCore.SignalR.Client return channel; } - private Dictionary PackageStreamingParams(object[] args) + private Dictionary PackageStreamingParams(ref object[] args, out List streamIds) { // lazy initialized, to avoid allocating unecessary dictionaries Dictionary readers = null; + streamIds = null; + var newArgs = new List(args.Length); for (var i = 0; i < args.Length; i++) { @@ -465,14 +467,26 @@ namespace Microsoft.AspNetCore.SignalR.Client readers = new Dictionary(); } - var id = _connectionState.GetNextStreamId(); + var id = _connectionState.GetNextId(); readers[id] = args[i]; - args[i] = new StreamPlaceholder(id); + + if (streamIds == null) + { + streamIds = new List(); + } + + streamIds.Add(id); Log.StartingStream(_logger, id); } + else + { + newArgs.Add(args[i]); + } } + args = newArgs.ToArray(); + return readers; } @@ -510,7 +524,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item)) { - await SendWithLock(new StreamDataMessage(streamId, item)); + await SendWithLock(new StreamItemMessage(streamId, item)); Log.SendingStreamItem(_logger, streamId); } } @@ -522,12 +536,12 @@ namespace Microsoft.AspNetCore.SignalR.Client } Log.CompletingStream(_logger, streamId); - await SendWithLock(new StreamCompleteMessage(streamId, responseError)); + await SendWithLock(CompletionMessage.WithError(streamId, responseError)); } private async Task InvokeCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken) { - var readers = PackageStreamingParams(args); + var readers = PackageStreamingParams(ref args, out var streamIds); CheckDisposed(); await WaitConnectionLockAsync(); @@ -539,7 +553,7 @@ namespace Microsoft.AspNetCore.SignalR.Client CheckConnectionActive(nameof(InvokeCoreAsync)); var irq = InvocationRequest.Invoke(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out invocationTask); - await InvokeCore(methodName, irq, args, cancellationToken); + await InvokeCore(methodName, irq, args, streamIds?.ToArray(), cancellationToken); } finally { @@ -552,12 +566,12 @@ namespace Microsoft.AspNetCore.SignalR.Client return await invocationTask; } - private async Task InvokeCore(string methodName, InvocationRequest irq, object[] args, CancellationToken cancellationToken) + private async Task InvokeCore(string methodName, InvocationRequest irq, object[] args, string[] streams, CancellationToken cancellationToken) { Log.PreparingBlockingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length); // Client invocations are always blocking - var invocationMessage = new InvocationMessage(irq.InvocationId, methodName, args); + var invocationMessage = new InvocationMessage(irq.InvocationId, methodName, args, streams); Log.RegisteringInvocation(_logger, invocationMessage.InvocationId); _connectionState.AddInvocation(irq); @@ -577,13 +591,13 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private async Task InvokeStreamCore(string methodName, InvocationRequest irq, object[] args, CancellationToken cancellationToken) + private async Task InvokeStreamCore(string methodName, InvocationRequest irq, object[] args, string[] streams, CancellationToken cancellationToken) { AssertConnectionValid(); Log.PreparingStreamingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length); - var invocationMessage = new StreamInvocationMessage(irq.InvocationId, methodName, args); + var invocationMessage = new StreamInvocationMessage(irq.InvocationId, methodName, args, streams); Log.RegisteringInvocation(_logger, invocationMessage.InvocationId); @@ -622,10 +636,10 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task SendCoreAsyncCore(string methodName, object[] args, CancellationToken cancellationToken) { - var readers = PackageStreamingParams(args); + var readers = PackageStreamingParams(ref args, out var streamIds); Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length); - var invocationMessage = new InvocationMessage(null, methodName, args); + var invocationMessage = new InvocationMessage(null, methodName, args, streamIds?.ToArray()); await SendWithLock(invocationMessage, callerName: nameof(SendCoreAsync)); LaunchStreams(readers, cancellationToken); diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubMethodInvocationMessage.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubMethodInvocationMessage.cs index 4b0b0ad079..a2cb0c58de 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubMethodInvocationMessage.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubMethodInvocationMessage.cs @@ -21,6 +21,24 @@ namespace Microsoft.AspNetCore.SignalR.Protocol /// public object[] Arguments { get; } + /// + /// The target methods stream IDs. + /// + public string[] StreamIds { get; } + + /// + /// Initializes a new instance of the class. + /// + /// The invocation ID. + /// The target method name. + /// The target method arguments. + /// The target methods stream IDs. + protected HubMethodInvocationMessage(string invocationId, string target, object[] arguments, string[] streamIds) + : this(invocationId, target, arguments) + { + StreamIds = streamIds; + } + /// /// Initializes a new instance of the class. /// @@ -32,7 +50,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { if (string.IsNullOrEmpty(target)) { - throw new ArgumentNullException(nameof(target)); + throw new ArgumentException(nameof(target)); } Target = target; @@ -66,10 +84,23 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { } + /// + /// Initializes a new instance of the class. + /// + /// The invocation ID. + /// The target method name. + /// The target method arguments. + /// The target methods stream IDs. + public InvocationMessage(string invocationId, string target, object[] arguments, string[] streamIds) + : base(invocationId, target, arguments, streamIds) + { + } + /// public override string ToString() { string args; + string streamIds; try { args = string.Join(", ", Arguments?.Select(a => a?.ToString())); @@ -78,7 +109,17 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { args = $"Error: {ex.Message}"; } - return $"InvocationMessage {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {args} ] }}"; + + try + { + streamIds = string.Join(", ", StreamIds != null ? StreamIds.Select(id => id?.ToString()) : Array.Empty()); + } + catch (Exception ex) + { + streamIds = $"Error: {ex.Message}"; + } + + return $"InvocationMessage {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {args} ], {nameof(StreamIds)}: [ {streamIds} ] }}"; } } @@ -96,16 +137,25 @@ namespace Microsoft.AspNetCore.SignalR.Protocol public StreamInvocationMessage(string invocationId, string target, object[] arguments) : base(invocationId, target, arguments) { - if (string.IsNullOrEmpty(invocationId)) - { - throw new ArgumentNullException(nameof(invocationId)); - } + } + + /// + /// Initializes a new instance of the class. + /// + /// The invocation ID. + /// The target method name. + /// The target method arguments. + /// The target methods stream IDs. + public StreamInvocationMessage(string invocationId, string target, object[] arguments, string[] streamIds) + : base(invocationId, target, arguments, streamIds) + { } /// public override string ToString() { string args; + string streamIds; try { args = string.Join(", ", Arguments?.Select(a => a?.ToString())); @@ -114,7 +164,17 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { args = $"Error: {ex.Message}"; } - return $"StreamInvocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {args} ] }}"; + + try + { + streamIds = string.Join(", ", StreamIds != null ? StreamIds.Select(id => id?.ToString()) : Array.Empty()); + } + catch (Exception ex) + { + streamIds = $"Error: {ex.Message}"; + } + + return $"StreamInvocation {{ {nameof(InvocationId)}: \"{InvocationId}\", {nameof(Target)}: \"{Target}\", {nameof(Arguments)}: [ {args} ], {nameof(StreamIds)}: [ {streamIds} ] }}"; } } } diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubProtocolConstants.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubProtocolConstants.cs index 25fbf6dbbc..ce1e3cbfd5 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubProtocolConstants.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/HubProtocolConstants.cs @@ -42,15 +42,5 @@ namespace Microsoft.AspNetCore.SignalR.Protocol /// Represents the close message type. /// public const int CloseMessageType = 7; - - /// - /// Represents the stream complete message type. - /// - public const int StreamCompleteMessageType = 8; - - /// - /// Same as StreamItemMessage, except - /// - public const int StreamDataMessageType = 9; } } diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamCompleteMessage.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamCompleteMessage.cs deleted file mode 100644 index 587764aa72..0000000000 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamCompleteMessage.cs +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.AspNetCore.SignalR.Protocol -{ - /// - /// A message for indicating that a particular stream has ended. - /// - public class StreamCompleteMessage : HubMessage - { - /// - /// Gets the stream id. - /// - public string StreamId { get; } - - /// - /// Gets the error. Will be null if there is no error. - /// - public string Error { get; } - - /// - /// Whether the message has an error. - /// - public bool HasError { get => Error != null; } - - /// - /// Initializes a new instance of - /// - /// The streamId of the stream to complete. - /// An optional error field. - public StreamCompleteMessage(string streamId, string error = null) - { - StreamId = streamId; - Error = error; - } - } -} diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamDataMessage.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamDataMessage.cs deleted file mode 100644 index 6862ed96a2..0000000000 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamDataMessage.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -namespace Microsoft.AspNetCore.SignalR.Protocol -{ - /// - /// Sent to parameter streams. - /// Similar to , except the data is sent to a parameter stream, rather than in response to an invocation. - /// - public class StreamDataMessage : HubMessage - { - /// - /// The piece of data this message carries. - /// - public object Item { get; } - - /// - /// The stream to which to deliver data. - /// - public string StreamId { get; } - - public StreamDataMessage(string streamId, object item) - { - StreamId = streamId; - Item = item; - } - - public override string ToString() - { - return $"StreamDataMessage {{ {nameof(StreamId)}: \"{StreamId}\", {nameof(Item)}: {Item ?? "<>"} }}"; - } - } -} diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamPlaceholder.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamPlaceholder.cs deleted file mode 100644 index f111e90cba..0000000000 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Common/Protocol/StreamPlaceholder.cs +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.Text; - -namespace Microsoft.AspNetCore.SignalR.Protocol -{ - /// - /// Used by protocol serializers/deserializers to transfer information about streaming parameters. - /// Is packed as an argument in the form `{"streamId": "42"}`, and sent over wire. - /// Is then unpacked on the other side, and a new channel is created and saved under the streamId. - /// Then, each is routed to the appropiate channel based on streamId. - /// - public class StreamPlaceholder - { - public string StreamId { get; private set; } - - public StreamPlaceholder(string streamId) - { - StreamId = streamId; - } - } -} diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs index 10fb4da955..f532a94e27 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs @@ -149,9 +149,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal _invalidReturnValueFromStreamingMethod(logger, hubMethod, null); } - public static void ReceivedStreamItem(ILogger logger, StreamDataMessage message) + public static void ReceivedStreamItem(ILogger logger, StreamItemMessage message) { - _receivedStreamItem(logger, message.StreamId, null); + _receivedStreamItem(logger, message.InvocationId, null); } public static void StartingParameterStream(ILogger logger, string streamId) @@ -159,14 +159,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal _startingParameterStream(logger, streamId, null); } - public static void CompletingStream(ILogger logger, StreamCompleteMessage message) + public static void CompletingStream(ILogger logger, CompletionMessage message) { - _completingStream(logger, message.StreamId, null); + _completingStream(logger, message.InvocationId, null); } - public static void ClosingStreamWithBindingError(ILogger logger, StreamCompleteMessage message) + public static void ClosingStreamWithBindingError(ILogger logger, CompletionMessage message) { - _closingStreamWithBindingError(logger, message.StreamId, message.Error, null); + _closingStreamWithBindingError(logger, message.InvocationId, message.Error, null); } } } diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index d65b33ad07..7f8706bdd1 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -113,11 +113,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal connection.StartClientTimeout(); break; - case StreamDataMessage streamItem: + case StreamItemMessage streamItem: Log.ReceivedStreamItem(_logger, streamItem); return ProcessStreamItem(connection, streamItem); - case StreamCompleteMessage streamCompleteMessage: + case CompletionMessage streamCompleteMessage: // closes channels, removes from Lookup dict // user's method can see the channel is complete and begin wrapping up Log.CompletingStream(_logger, streamCompleteMessage); @@ -149,14 +149,14 @@ namespace Microsoft.AspNetCore.SignalR.Internal "Failed to bind Stream message.", bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); - var message = new StreamCompleteMessage(bindingFailureMessage.Id, errorString); + var message = CompletionMessage.WithError(bindingFailureMessage.Id, errorString); Log.ClosingStreamWithBindingError(_logger, message); connection.StreamTracker.Complete(message); return Task.CompletedTask; } - private Task ProcessStreamItem(HubConnectionContext connection, StreamDataMessage message) + private Task ProcessStreamItem(HubConnectionContext connection, StreamItemMessage message) { Log.ReceivedStreamItem(_logger, message); return connection.StreamTracker.ProcessItem(message); @@ -174,7 +174,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal } else { - bool isStreamCall = descriptor.HasStreamingParameters; + bool isStreamCall = descriptor.StreamingParameters != null; return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall); } } @@ -206,26 +206,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal hubActivator = scope.ServiceProvider.GetRequiredService>(); hub = hubActivator.Create(); - if (isStreamCall) - { - // swap out placeholders for channels - var args = hubMethodInvocationMessage.Arguments; - for (int i = 0; i < args.Length; i++) - { - var placeholder = args[i] as StreamPlaceholder; - if (placeholder == null) - { - continue; - } - - Log.StartingParameterStream(_logger, placeholder.StreamId); - var itemType = methodExecutor.MethodParameters[i].ParameterType.GetGenericArguments()[0]; - args[i] = connection.StreamTracker.AddStream(placeholder.StreamId, itemType); - } - } - try { + var clientStreamLength = hubMethodInvocationMessage.StreamIds?.Length ?? 0; + var serverStreamLength = descriptor.StreamingParameters?.Count ?? 0; + if (clientStreamLength != serverStreamLength) + { + throw new HubException($"Client sent {clientStreamLength} stream(s), Hub method expects {serverStreamLength}."); + } + InitializeHub(hub, connection); Task invocation = null; @@ -236,6 +225,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal // In order to add the synthetic arguments we need a new array because the invocation array is too small (it doesn't know about synthetic arguments) arguments = new object[descriptor.OriginalParameterTypes.Count]; + var streamPointer = 0; var hubInvocationArgumentPointer = 0; for (var parameterPointer = 0; parameterPointer < arguments.Length; parameterPointer++) { @@ -248,12 +238,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal } else { - // This is the only synthetic argument type we currently support if (descriptor.OriginalParameterTypes[parameterPointer] == typeof(CancellationToken)) { cts = CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted); arguments[parameterPointer] = cts.Token; } + else if (isStreamCall && ReflectionHelper.IsStreamingType(descriptor.OriginalParameterTypes[parameterPointer], mustBeDirectType: true)) + { + Log.StartingParameterStream(_logger, hubMethodInvocationMessage.StreamIds[streamPointer]); + var itemType = descriptor.StreamingParameters[streamPointer]; + arguments[parameterPointer] = connection.StreamTracker.AddStream(hubMethodInvocationMessage.StreamIds[streamPointer], itemType); + streamPointer++; + } else { // This should never happen @@ -302,6 +298,25 @@ namespace Microsoft.AspNetCore.SignalR.Internal ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); return; } + finally + { + // Stream response handles cleanup in StreamResultsAsync + // And normal invocations handle cleanup below in the finally + if (isStreamCall) + { + hubActivator?.Release(hub); + scope.Dispose(); + foreach (var stream in hubMethodInvocationMessage.StreamIds) + { + try + { + connection.StreamTracker.Complete(CompletionMessage.Empty(stream)); + } + // ignore failures, it means the client already completed the streams + catch { } + } + } + } await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); } diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs index 95421780e8..dec2e67aaf 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs @@ -9,7 +9,6 @@ using System.Reflection; using System.Threading; using System.Threading.Channels; using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.SignalR.Protocol; using Microsoft.Extensions.Internal; namespace Microsoft.AspNetCore.SignalR.Internal @@ -43,8 +42,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal HasSyntheticArguments = true; return false; } + else if (ReflectionHelper.IsStreamingType(p.ParameterType, mustBeDirectType: true)) + { + if (StreamingParameters == null) + { + StreamingParameters = new List(); + } + + StreamingParameters.Add(p.ParameterType.GetGenericArguments()[0]); + HasSyntheticArguments = true; + return false; + } return true; - }).Select(GetParameterType).ToArray(); + }).Select(p => p.ParameterType).ToArray(); if (HasSyntheticArguments) { @@ -54,7 +64,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal Policies = policies.ToArray(); } - public bool HasStreamingParameters { get; private set; } + public List StreamingParameters { get; private set; } private Func> _convertToEnumerator; @@ -76,17 +86,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal public bool HasSyntheticArguments { get; private set; } - private Type GetParameterType(ParameterInfo p) - { - var type = p.ParameterType; - if (ReflectionHelper.IsStreamingType(type, mustBeDirectType: true)) - { - HasStreamingParameters = true; - return typeof(StreamPlaceholder); - } - return type; - } - private static bool IsChannelType(Type type, out Type payloadType) { var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>)); diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/StreamTracker.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/StreamTracker.cs index 3d36e38c0b..2d7432ebd6 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/StreamTracker.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/StreamTracker.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -39,24 +39,24 @@ namespace Microsoft.AspNetCore.SignalR } } - public Task ProcessItem(StreamDataMessage message) + public Task ProcessItem(StreamItemMessage message) { - return TryGetConverter(message.StreamId).WriteToStream(message.Item); + return TryGetConverter(message.InvocationId).WriteToStream(message.Item); } - + public Type GetStreamItemType(string streamId) { return TryGetConverter(streamId).GetItemType(); } - public void Complete(StreamCompleteMessage message) + public void Complete(CompletionMessage message) { - _lookup.TryRemove(message.StreamId, out var converter); + _lookup.TryRemove(message.InvocationId, out var converter); if (converter == null) { - throw new KeyNotFoundException($"No stream with id '{message.StreamId}' could be found."); + throw new KeyNotFoundException($"No stream with id '{message.InvocationId}' could be found."); } - converter.TryComplete(message.HasError ? new Exception(message.Error) : null); + converter.TryComplete(message.HasResult || message.Error == null ? null : new Exception(message.Error)); } private static IStreamConverter BuildStream() diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs index adeac3278c..7eb2fcd651 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs @@ -30,7 +30,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol private static readonly string ProtocolName = "messagepack"; private static readonly int ProtocolVersion = 1; private static readonly int ProtocolMinorVersion = 0; - + /// public string Name => ProtocolName; @@ -121,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver) { - MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize); + var itemCount = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize); startOffset += readSize; var messageType = ReadInt32(input, ref startOffset, "messageType"); @@ -129,11 +129,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol switch (messageType) { case HubProtocolConstants.InvocationMessageType: - return CreateInvocationMessage(input, ref startOffset, binder, resolver); + return CreateInvocationMessage(input, ref startOffset, binder, resolver, itemCount); case HubProtocolConstants.StreamInvocationMessageType: - return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver); - case HubProtocolConstants.StreamDataMessageType: - return CreateStreamDataMessage(input, ref startOffset, binder, resolver); + return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver, itemCount); case HubProtocolConstants.StreamItemMessageType: return CreateStreamItemMessage(input, ref startOffset, binder, resolver); case HubProtocolConstants.CompletionMessageType: @@ -144,15 +142,13 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return PingMessage.Instance; case HubProtocolConstants.CloseMessageType: return CreateCloseMessage(input, ref startOffset); - case HubProtocolConstants.StreamCompleteMessageType: - return CreateStreamCompleteMessage(input, ref startOffset); default: // Future protocol changes can add message types, old clients can ignore them return null; } } - private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount) { var headers = ReadHeaders(input, ref offset); var invocationId = ReadInvocationId(input, ref offset); @@ -166,42 +162,52 @@ namespace Microsoft.AspNetCore.SignalR.Protocol var target = ReadString(input, ref offset, "target"); + object[] arguments = null; try { var parameterTypes = binder.GetParameterTypes(target); - var arguments = BindArguments(input, ref offset, parameterTypes, resolver); - return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments)); + arguments = BindArguments(input, ref offset, parameterTypes, resolver); } catch (Exception ex) { return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex)); } + + string[] streams = null; + // Previous clients will send 5 items, so we check if they sent a stream array or not + if (itemCount > 5) + { + streams = ReadStreamIds(input, ref offset); + } + + return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams)); } - private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount) { var headers = ReadHeaders(input, ref offset); var invocationId = ReadInvocationId(input, ref offset); var target = ReadString(input, ref offset, "target"); + object[] arguments = null; try { var parameterTypes = binder.GetParameterTypes(target); - var arguments = BindArguments(input, ref offset, parameterTypes, resolver); - return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments)); + arguments = BindArguments(input, ref offset, parameterTypes, resolver); } catch (Exception ex) { return new InvocationBindingFailureMessage(invocationId, target, ExceptionDispatchInfo.Capture(ex)); } - } - private static StreamDataMessage CreateStreamDataMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) - { - var streamId = ReadString(input, ref offset, "streamId"); - var itemType = binder.GetStreamItemType(streamId); - var value = DeserializeObject(input, ref offset, itemType, "item", resolver); - return new StreamDataMessage(streamId, value); + string[] streams = null; + // Previous clients will send 5 items, so we check if they sent a stream array or not + if (itemCount > 5) + { + streams = ReadStreamIds(input, ref offset); + } + + return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams)); } private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) @@ -256,17 +262,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return new CloseMessage(error); } - private static StreamCompleteMessage CreateStreamCompleteMessage(byte[] input, ref int offset) - { - var streamId = ReadString(input, ref offset, "streamId"); - var error = ReadString(input, ref offset, "error"); - if (string.IsNullOrEmpty(error)) - { - error = null; - } - return new StreamCompleteMessage(streamId, error); - } - private static Dictionary ReadHeaders(byte[] input, ref int offset) { var headerCount = ReadMapLength(input, ref offset, "headers"); @@ -289,6 +284,24 @@ namespace Microsoft.AspNetCore.SignalR.Protocol } } + private static string[] ReadStreamIds(byte[] input, ref int offset) + { + var streamIdCount = ReadArrayLength(input, ref offset, "streamIds"); + List streams = null; + + if (streamIdCount > 0) + { + streams = new List(); + for (var i = 0; i < streamIdCount; i++) + { + streams.Add(MessagePackBinary.ReadString(input, offset, out var read)); + offset += read; + } + } + + return streams?.ToArray(); + } + private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyList parameterTypes, IFormatterResolver resolver) { var argumentCount = ReadArrayLength(input, ref offset, "arguments"); @@ -384,9 +397,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol case StreamInvocationMessage streamInvocationMessage: WriteStreamInvocationMessage(streamInvocationMessage, packer); break; - case StreamDataMessage streamDataMessage: - WriteStreamDataMessage(streamDataMessage, packer); - break; case StreamItemMessage streamItemMessage: WriteStreamingItemMessage(streamItemMessage, packer); break; @@ -402,9 +412,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol case CloseMessage closeMessage: WriteCloseMessage(closeMessage, packer); break; - case StreamCompleteMessage m: - WriteStreamCompleteMessage(m, packer); - break; default: throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); } @@ -412,7 +419,8 @@ namespace Microsoft.AspNetCore.SignalR.Protocol private void WriteInvocationMessage(InvocationMessage message, Stream packer) { - MessagePackBinary.WriteArrayHeader(packer, 5); + MessagePackBinary.WriteArrayHeader(packer, 6); + MessagePackBinary.WriteInt32(packer, HubProtocolConstants.InvocationMessageType); PackHeaders(packer, message.Headers); if (string.IsNullOrEmpty(message.InvocationId)) @@ -429,11 +437,14 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { WriteArgument(arg, packer); } + + WriteStreamIds(message.StreamIds, packer); } private void WriteStreamInvocationMessage(StreamInvocationMessage message, Stream packer) { - MessagePackBinary.WriteArrayHeader(packer, 5); + MessagePackBinary.WriteArrayHeader(packer, 6); + MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamInvocationMessageType); PackHeaders(packer, message.Headers); MessagePackBinary.WriteString(packer, message.InvocationId); @@ -444,14 +455,8 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { WriteArgument(arg, packer); } - } - private void WriteStreamDataMessage(StreamDataMessage message, Stream packer) - { - MessagePackBinary.WriteArrayHeader(packer, 3); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamDataMessageType); - MessagePackBinary.WriteString(packer, message.StreamId); - WriteArgument(message.Item, packer); + WriteStreamIds(message.StreamIds, packer); } private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer) @@ -475,6 +480,22 @@ namespace Microsoft.AspNetCore.SignalR.Protocol } } + private void WriteStreamIds(string[] streamIds, Stream packer) + { + if (streamIds != null) + { + MessagePackBinary.WriteArrayHeader(packer, streamIds.Length); + foreach (var streamId in streamIds) + { + MessagePackBinary.WriteString(packer, streamId); + } + } + else + { + MessagePackBinary.WriteArrayHeader(packer, 0); + } + } + private void WriteCompletionMessage(CompletionMessage message, Stream packer) { var resultKind = @@ -506,21 +527,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol MessagePackBinary.WriteString(packer, message.InvocationId); } - private void WriteStreamCompleteMessage(StreamCompleteMessage message, Stream packer) - { - MessagePackBinary.WriteArrayHeader(packer, 3); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamCompleteMessageType); - MessagePackBinary.WriteString(packer, message.StreamId); - if (message.HasError) - { - MessagePackBinary.WriteString(packer, message.Error); - } - else - { - MessagePackBinary.WriteNil(packer); - } - } - private void WriteCloseMessage(CloseMessage message, Stream packer) { MessagePackBinary.WriteArrayHeader(packer, 2); @@ -600,23 +606,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException); } - private static bool ReadBoolean(byte[] input, ref int offset, string field) - { - Exception msgPackException = null; - try - { - var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize); - offset += readSize; - return readBool; - } - catch (Exception e) - { - msgPackException = e; - } - - throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException); - } - private static long ReadMapLength(byte[] input, ref int offset, string field) { Exception msgPackException = null; diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs index c7d96e1d39..09a9491587 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs @@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol private const string ResultPropertyName = "result"; private const string ItemPropertyName = "item"; private const string InvocationIdPropertyName = "invocationId"; - private const string StreamIdPropertyName = "streamId"; + private const string StreamIdsPropertyName = "streamIds"; private const string TypePropertyName = "type"; private const string ErrorPropertyName = "error"; private const string TargetPropertyName = "target"; @@ -120,7 +120,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol int? type = null; string invocationId = null; - string streamId = null; string target = null; string error = null; var hasItem = false; @@ -131,6 +130,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol JToken resultToken = null; var hasArguments = false; object[] arguments = null; + string[] streamIds = null; JArray argumentsToken = null; ExceptionDispatchInfo argumentBindingException = null; Dictionary headers = null; @@ -167,8 +167,23 @@ namespace Microsoft.AspNetCore.SignalR.Protocol case InvocationIdPropertyName: invocationId = JsonUtils.ReadAsString(reader, InvocationIdPropertyName); break; - case StreamIdPropertyName: - streamId = JsonUtils.ReadAsString(reader, StreamIdPropertyName); + case StreamIdsPropertyName: + JsonUtils.CheckRead(reader); + + if (reader.TokenType != JsonToken.StartArray) + { + throw new InvalidDataException($"Expected '{ArgumentsPropertyName}' to be of type {JTokenType.Array}."); + } + + var newStreamIds = new List(); + reader.Read(); + while (reader.TokenType != JsonToken.EndArray) + { + newStreamIds.Add(reader.Value?.ToString()); + reader.Read(); + } + + streamIds = newStreamIds.ToArray(); break; case TargetPropertyName: target = JsonUtils.ReadAsString(reader, TargetPropertyName); @@ -210,13 +225,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { id = invocationId; } - else if (!string.IsNullOrEmpty(streamId)) - { - id = streamId; - } else { - // If we don't have an id yetmthen we need to store it as a JToken to parse later + // If we don't have an id yet then we need to store it as a JToken to parse later itemToken = JToken.Load(reader); break; } @@ -310,7 +321,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol message = argumentBindingException != null ? new InvocationBindingFailureMessage(invocationId, target, argumentBindingException) - : BindInvocationMessage(invocationId, target, arguments, hasArguments, binder); + : BindInvocationMessage(invocationId, target, arguments, hasArguments, streamIds, binder); } break; case HubProtocolConstants.StreamInvocationMessageType: @@ -331,25 +342,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol message = argumentBindingException != null ? new InvocationBindingFailureMessage(invocationId, target, argumentBindingException) - : BindStreamInvocationMessage(invocationId, target, arguments, hasArguments, binder); + : BindStreamInvocationMessage(invocationId, target, arguments, hasArguments, streamIds, binder); } break; - case HubProtocolConstants.StreamDataMessageType: - if (itemToken != null) - { - try - { - var itemType = binder.GetStreamItemType(streamId); - item = itemToken.ToObject(itemType, PayloadSerializer); - } - catch (Exception ex) - { - message = new StreamBindingFailureMessage(streamId, ExceptionDispatchInfo.Capture(ex)); - break; - } - } - message = BindParamStreamMessage(streamId, item, hasItem, binder); - break; case HubProtocolConstants.StreamItemMessageType: if (itemToken != null) { @@ -383,9 +378,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return PingMessage.Instance; case HubProtocolConstants.CloseMessageType: return BindCloseMessage(error); - case HubProtocolConstants.StreamCompleteMessageType: - message = BindStreamCompleteMessage(streamId, error); - break; case null: throw new InvalidDataException($"Missing required property '{TypePropertyName}'."); default: @@ -456,10 +448,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol WriteHeaders(writer, m); WriteStreamInvocationMessage(m, writer); break; - case StreamDataMessage m: - WriteMessageType(writer, HubProtocolConstants.StreamDataMessageType); - WriteStreamDataMessage(m, writer); - break; case StreamItemMessage m: WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType); WriteHeaders(writer, m); @@ -482,10 +470,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol WriteMessageType(writer, HubProtocolConstants.CloseMessageType); WriteCloseMessage(m, writer); break; - case StreamCompleteMessage m: - WriteMessageType(writer, HubProtocolConstants.StreamCompleteMessageType); - WriteStreamCompleteMessage(m, writer); - break; default: throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}"); } @@ -534,18 +518,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol WriteInvocationId(message, writer); } - private void WriteStreamCompleteMessage(StreamCompleteMessage message, JsonTextWriter writer) - { - writer.WritePropertyName(StreamIdPropertyName); - writer.WriteValue(message.StreamId); - - if (message.Error != null) - { - writer.WritePropertyName(ErrorPropertyName); - writer.WriteValue(message.Error); - } - } - private void WriteStreamItemMessage(StreamItemMessage message, JsonTextWriter writer) { WriteInvocationId(message, writer); @@ -553,14 +525,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol PayloadSerializer.Serialize(writer, message.Item); } - private void WriteStreamDataMessage(StreamDataMessage message, JsonTextWriter writer) - { - writer.WritePropertyName(StreamIdPropertyName); - writer.WriteValue(message.StreamId); - writer.WritePropertyName(ItemPropertyName); - PayloadSerializer.Serialize(writer, message.Item); - } - private void WriteInvocationMessage(InvocationMessage message, JsonTextWriter writer) { WriteInvocationId(message, writer); @@ -568,6 +532,8 @@ namespace Microsoft.AspNetCore.SignalR.Protocol writer.WriteValue(message.Target); WriteArguments(message.Arguments, writer); + + WriteStreamIds(message.StreamIds, writer); } private void WriteStreamInvocationMessage(StreamInvocationMessage message, JsonTextWriter writer) @@ -577,6 +543,8 @@ namespace Microsoft.AspNetCore.SignalR.Protocol writer.WriteValue(message.Target); WriteArguments(message.Arguments, writer); + + WriteStreamIds(message.StreamIds, writer); } private void WriteCloseMessage(CloseMessage message, JsonTextWriter writer) @@ -599,6 +567,22 @@ namespace Microsoft.AspNetCore.SignalR.Protocol writer.WriteEndArray(); } + private void WriteStreamIds(string[] streamIds, JsonTextWriter writer) + { + if (streamIds == null) + { + return; + } + + writer.WritePropertyName(StreamIdsPropertyName); + writer.WriteStartArray(); + foreach (var streamId in streamIds) + { + writer.WriteValue(streamId); + } + writer.WriteEndArray(); + } + private static void WriteInvocationId(HubInvocationMessage message, JsonTextWriter writer) { if (!string.IsNullOrEmpty(message.InvocationId)) @@ -624,17 +608,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return new CancelInvocationMessage(invocationId); } - private HubMessage BindStreamCompleteMessage(string streamId, string error) - { - if (string.IsNullOrEmpty(streamId)) - { - throw new InvalidDataException($"Missing required property '{StreamIdPropertyName}'."); - } - - // note : if the stream completes normally, the error should be `null` - return new StreamCompleteMessage(streamId, error); - } - private HubMessage BindCompletionMessage(string invocationId, string error, object result, bool hasResult, IInvocationBinder binder) { if (string.IsNullOrEmpty(invocationId)) @@ -655,20 +628,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return new CompletionMessage(invocationId, error, result: null, hasResult: false); } - private HubMessage BindParamStreamMessage(string streamId, object item, bool hasItem, IInvocationBinder binder) - { - if (string.IsNullOrEmpty(streamId)) - { - throw new InvalidDataException($"Missing required property '{StreamIdPropertyName}"); - } - if (!hasItem) - { - throw new InvalidDataException($"Missing required property '{ItemPropertyName}"); - } - - return new StreamDataMessage(streamId, item); - } - private HubMessage BindStreamItemMessage(string invocationId, object item, bool hasItem, IInvocationBinder binder) { if (string.IsNullOrEmpty(invocationId)) @@ -684,7 +643,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return new StreamItemMessage(invocationId, item); } - private HubMessage BindStreamInvocationMessage(string invocationId, string target, object[] arguments, bool hasArguments, IInvocationBinder binder) + private HubMessage BindStreamInvocationMessage(string invocationId, string target, object[] arguments, bool hasArguments, string[] streamIds, IInvocationBinder binder) { if (string.IsNullOrEmpty(invocationId)) { @@ -701,10 +660,10 @@ namespace Microsoft.AspNetCore.SignalR.Protocol throw new InvalidDataException($"Missing required property '{TargetPropertyName}'."); } - return new StreamInvocationMessage(invocationId, target, arguments); + return new StreamInvocationMessage(invocationId, target, arguments, streamIds); } - private HubMessage BindInvocationMessage(string invocationId, string target, object[] arguments, bool hasArguments, IInvocationBinder binder) + private HubMessage BindInvocationMessage(string invocationId, string target, object[] arguments, bool hasArguments, string[] streamIds, IInvocationBinder binder) { if (string.IsNullOrEmpty(target)) { @@ -716,7 +675,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol throw new InvalidDataException($"Missing required property '{ArgumentsPropertyName}'."); } - return new InvocationMessage(invocationId, target, arguments); + return new InvocationMessage(invocationId, target, arguments, streamIds); } private bool ReadArgumentAsType(JsonTextReader reader, IReadOnlyList paramTypes, int paramIndex) diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index d41acbe350..108c3ffe4f 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -855,7 +855,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests // List will be looked at to replace with a StreamPlaceholder and should be skipped, so an error will be thrown from the // protocol on the server when it tries to match List with a StreamPlaceholder var hubException = await Assert.ThrowsAsync(() => connection.InvokeAsync("StreamEcho", new List { "1", "2" }).OrTimeout()); - Assert.Equal("Failed to invoke 'StreamEcho' due to an error on the server. InvalidDataException: Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", + Assert.Equal("Failed to invoke 'StreamEcho' due to an error on the server. InvalidDataException: Invocation provides 1 argument(s) but target expects 0.", hubException.Message); await connection.DisposeAsync().OrTimeout(); } diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 6a28816246..7df3419539 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -179,14 +179,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { var output = Channel.CreateUnbounded(); _ = Task.Run(async () => { - while (await source.WaitToReadAsync()) + try { - while (source.TryRead(out var item)) + while (await source.WaitToReadAsync()) { - await output.Writer.WriteAsync(item); + while (source.TryRead(out var item)) + { + await output.Writer.WriteAsync(item); + } } } - output.Writer.TryComplete(); + finally + { + output.Writer.TryComplete(); + } }); return output.Reader; diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 66d711b2da..a5379cb25e 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -218,21 +218,21 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var invocation = await connection.ReadSentJsonAsync().OrTimeout(); Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]); Assert.Equal("SomeMethod", invocation["target"]); - var streamId = invocation["arguments"][0]["streamId"]; + var streamId = invocation["streamIds"][0]; foreach (var number in new[] { 42, 43, 322, 3145, -1234 }) { await channel.Writer.WriteAsync(number).AsTask().OrTimeout(); var item = await connection.ReadSentJsonAsync().OrTimeout(); - Assert.Equal(HubProtocolConstants.StreamDataMessageType, item["type"]); + Assert.Equal(HubProtocolConstants.StreamItemMessageType, item["type"]); Assert.Equal(number, item["item"]); - Assert.Equal(streamId, item["streamId"]); + Assert.Equal(streamId, item["invocationId"]); } channel.Writer.TryComplete(); var completion = await connection.ReadSentJsonAsync().OrTimeout(); - Assert.Equal(HubProtocolConstants.StreamCompleteMessageType, completion["type"]); + Assert.Equal(HubProtocolConstants.CompletionMessageType, completion["type"]); await connection.ReceiveJsonMessage( new { type = HubProtocolConstants.CompletionMessageType, invocationId = invocation["invocationId"], result = 42 } @@ -259,16 +259,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]); Assert.Equal("SomeMethod", invocation["target"]); Assert.Null(invocation["invocationId"]); - var streamId = invocation["arguments"][0]["streamId"]; + var streamId = invocation["streamIds"][0]; foreach (var item in new[] { 2, 3, 10, 5 }) { await channel.Writer.WriteAsync(item); var received = await connection.ReadSentJsonAsync().OrTimeout(); - Assert.Equal(HubProtocolConstants.StreamDataMessageType, received["type"]); + Assert.Equal(HubProtocolConstants.StreamItemMessageType, received["type"]); Assert.Equal(item, received["item"]); - Assert.Equal(streamId, received["streamId"]); + Assert.Equal(streamId, received["invocationId"]); } } } @@ -297,14 +297,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await channel.Writer.WriteAsync(item); var received = await connection.ReadSentJsonAsync().OrTimeout(); - Assert.Equal(HubProtocolConstants.StreamDataMessageType, received["type"]); + Assert.Equal(HubProtocolConstants.StreamItemMessageType, received["type"]); Assert.Equal(item.Foo, received["item"]["foo"]); Assert.Equal(item.Bar, received["item"]["bar"]); } channel.Writer.TryComplete(); var completion = await connection.ReadSentJsonAsync().OrTimeout(); - Assert.Equal(HubProtocolConstants.StreamCompleteMessageType, completion["type"]); + Assert.Equal(HubProtocolConstants.CompletionMessageType, completion["type"]); var expected = new SampleObject("oof", 14); await connection.ReceiveJsonMessage( @@ -345,7 +345,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // the next sent message should be a completion message var complete = await connection.ReadSentJsonAsync().OrTimeout(); - Assert.Equal(HubProtocolConstants.StreamCompleteMessageType, complete["type"]); + Assert.Equal(HubProtocolConstants.CompletionMessageType, complete["type"]); Assert.EndsWith("canceled by client.", ((string)complete["error"])); } } diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs index 1da37780ea..5942a4834f 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/JsonHubProtocolTests.cs @@ -40,7 +40,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new JsonProtocolTestData("InvocationMessage_HasCustomArgumentWithNullValueIgnore", new InvocationMessage(null, "Target", new object[] { new CustomObject() }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"byteArrProp\":\"AQID\"}]}"), new JsonProtocolTestData("InvocationMessage_HasCustomArgumentWithNullValueIgnoreAndNoCamelCase", new InvocationMessage(null, "Target", new object[] { new CustomObject() }), false, NullValueHandling.Include, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00Z\",\"NullProp\":null,\"ByteArrProp\":\"AQID\"}]}"), new JsonProtocolTestData("InvocationMessage_HasCustomArgumentWithNullValueInclude", new InvocationMessage(null, "Target", new object[] { new CustomObject() }), true, NullValueHandling.Include, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"nullProp\":null,\"byteArrProp\":\"AQID\"}]}"), - new JsonProtocolTestData("InvocationMessage_HasStreamPlaceholder", new InvocationMessage(null, "Target", new object[] { new StreamPlaceholder("__test_id__")}), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"streamId\":\"__test_id__\"}]}"), + new JsonProtocolTestData("InvocationMessage_HasStreamArgument", new InvocationMessage(null, "Target", Array.Empty(), new string[] { "__test_id__" }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[],\"streamIds\":[\"__test_id__\"]}"), + new JsonProtocolTestData("InvocationMessage_HasStreamAndNormalArgument", new InvocationMessage(null, "Target", new object[] { 42 }, new string[] { "__test_id__" }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[42],\"streamIds\":[\"__test_id__\"]}"), + new JsonProtocolTestData("InvocationMessage_HasMultipleStreams", new InvocationMessage(null, "Target", Array.Empty(), new string[] { "__test_id__", "__test_id2__" }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[],\"streamIds\":[\"__test_id__\",\"__test_id2__\"]}"), new JsonProtocolTestData("InvocationMessage_HasHeaders", AddHeaders(TestHeaders, new InvocationMessage("123", "Target", new object[] { 1, "Foo", 2.0f })), true, NullValueHandling.Ignore, "{\"type\":1," + SerializedHeaders + ",\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}"), new JsonProtocolTestData("InvocationMessage_StringIsoDateArgument", new InvocationMessage("Method", new object[] { "2016-05-10T13:51:20+12:34" }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Method\",\"arguments\":[\"2016-05-10T13:51:20+12:34\"]}"), new JsonProtocolTestData("InvocationMessage_DateTimeOffsetArgument", new InvocationMessage("Method", new object[] { DateTimeOffset.Parse("2016-05-10T13:51:20+12:34") }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Method\",\"arguments\":[\"2016-05-10T13:51:20+12:34\"]}"), @@ -75,6 +77,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new JsonProtocolTestData("StreamInvocationMessage_HasFloatArgument", new StreamInvocationMessage("123", "Target", new object[] { 1, "Foo", 2.0f }), true, NullValueHandling.Ignore, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}"), new JsonProtocolTestData("StreamInvocationMessage_HasBoolArgument", new StreamInvocationMessage("123", "Target", new object[] { true }), true, NullValueHandling.Ignore, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[true]}"), new JsonProtocolTestData("StreamInvocationMessage_HasNullArgument", new StreamInvocationMessage("123", "Target", new object[] { null }), true, NullValueHandling.Ignore, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[null]}"), + new JsonProtocolTestData("StreamInvocationMessage_HasStreamArgument", new StreamInvocationMessage("123", "Target", Array.Empty(), new string[] { "__test_id__" }), true, NullValueHandling.Ignore, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[],\"streamIds\":[\"__test_id__\"]}"), new JsonProtocolTestData("StreamInvocationMessage_HasCustomArgumentWithNoCamelCase", new StreamInvocationMessage("123", "Target", new object[] { new CustomObject() }), false, NullValueHandling.Ignore, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00Z\",\"ByteArrProp\":\"AQID\"}]}"), new JsonProtocolTestData("StreamInvocationMessage_HasCustomArgumentWithNullValueIgnore", new StreamInvocationMessage("123", "Target", new object[] { new CustomObject() }), true, NullValueHandling.Ignore, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"byteArrProp\":\"AQID\"}]}"), new JsonProtocolTestData("StreamInvocationMessage_HasCustomArgumentWithNullValueIgnoreAndNoCamelCase", new StreamInvocationMessage("123", "Target", new object[] { new CustomObject() }), false, NullValueHandling.Include, "{\"type\":4,\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00Z\",\"NullProp\":null,\"ByteArrProp\":\"AQID\"}]}"), @@ -91,9 +94,6 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new JsonProtocolTestData("CloseMessage_HasErrorWithCamelCase", new CloseMessage("Error!"), true, NullValueHandling.Ignore, "{\"type\":7,\"error\":\"Error!\"}"), new JsonProtocolTestData("CloseMessage_HasErrorEmptyString", new CloseMessage(""), false, NullValueHandling.Ignore, "{\"type\":7,\"error\":\"\"}"), - new JsonProtocolTestData("StreamCompleteMessage", new StreamCompleteMessage("123"), true, NullValueHandling.Ignore, "{\"type\":8,\"streamId\":\"123\"}"), - new JsonProtocolTestData("StreamCompleteMessageWithError", new StreamCompleteMessage("123", "zoinks"), true, NullValueHandling.Ignore, "{\"type\":8,\"streamId\":\"123\",\"error\":\"zoinks\"}"), - }.ToDictionary(t => t.Name); public static IEnumerable ProtocolTestDataNames => ProtocolTestData.Keys.Select(name => new object[] { name }); diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index 20ff61ae33..35fe3c04fc 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -58,44 +58,51 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new ProtocolTestData( name: "InvocationWithNoHeadersAndNoArgs", message: new InvocationMessage("xyz", "method", Array.Empty()), - binary: "lQGAo3h5eqZtZXRob2SQ"), + binary: "lgGAo3h5eqZtZXRob2SQkA=="), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdAndNoArgs", message: new InvocationMessage("method", Array.Empty()), - binary: "lQGAwKZtZXRob2SQ"), + binary: "lgGAwKZtZXRob2SQkA=="), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdAndSingleNullArg", message: new InvocationMessage("method", new object[] { null }), - binary: "lQGAwKZtZXRob2SRwA=="), + binary: "lgGAwKZtZXRob2SRwJA="), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdAndSingleIntArg", message: new InvocationMessage("method", new object[] { 42 }), - binary: "lQGAwKZtZXRob2SRKg=="), + binary: "lgGAwKZtZXRob2SRKpA="), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdIntAndStringArgs", message: new InvocationMessage("method", new object[] { 42, "string" }), - binary: "lQGAwKZtZXRob2SSKqZzdHJpbmc="), + binary: "lgGAwKZtZXRob2SSKqZzdHJpbmeQ"), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdIntAndEnumArgs", message: new InvocationMessage("method", new object[] { 42, TestEnum.One }), - binary: "lQGAwKZtZXRob2SSKqNPbmU="), + binary: "lgGAwKZtZXRob2SSKqNPbmWQ"), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdAndCustomObjectArg", message: new InvocationMessage("method", new object[] { 42, "string", new CustomObject() }), - binary: "lQGAwKZtZXRob2STKqZzdHJpbmeGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM="), + binary: "lgGAwKZtZXRob2STKqZzdHJpbmeGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOQ"), new ProtocolTestData( name: "InvocationWithNoHeadersNoIdAndArrayOfCustomObjectArgs", message: new InvocationMessage("method", new object[] { new CustomObject(), new CustomObject() }), - binary: "lQGAwKZtZXRob2SShqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQIDhqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQID"), + binary: "lgGAwKZtZXRob2SShqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQIDhqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQIDkA=="), new ProtocolTestData( name: "InvocationWithHeadersNoIdAndArrayOfCustomObjectArgs", message: AddHeaders(TestHeaders, new InvocationMessage("method", new object[] { new CustomObject(), new CustomObject() })), - binary: "lQGDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmXApm1ldGhvZJKGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM="), + binary: "lgGDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmXApm1ldGhvZJKGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOQ"), new ProtocolTestData( - name: "InvocationWithStreamPlaceholderObject", - message: new InvocationMessage(null, "Target", new object[] { new StreamPlaceholder("__test_id__")}), - binary: "lQGAwKZUYXJnZXSRgahTdHJlYW1JZKtfX3Rlc3RfaWRfXw==" - ), + name: "InvocationWithStreamArgument", + message: new InvocationMessage(null, "Target", Array.Empty(), new string[] { "__test_id__" }), + binary: "lgGAwKZUYXJnZXSQkatfX3Rlc3RfaWRfXw=="), + new ProtocolTestData( + name: "InvocationWithStreamAndNormalArgument", + message: new InvocationMessage(null, "Target", new object[] { 42 }, new string[] { "__test_id__" }), + binary: "lgGAwKZUYXJnZXSRKpGrX190ZXN0X2lkX18="), + new ProtocolTestData( + name: "InvocationWithMulitpleStreams", + message: new InvocationMessage(null, "Target", Array.Empty(), new string[] { "__test_id__", "__test_id2__" }), + binary: "lgGAwKZUYXJnZXSQkqtfX3Rlc3RfaWRfX6xfX3Rlc3RfaWQyX18="), // StreamItem Messages new ProtocolTestData( @@ -193,35 +200,43 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndNoArgs", message: new StreamInvocationMessage("xyz", "method", Array.Empty()), - binary: "lQSAo3h5eqZtZXRob2SQ"), + binary: "lgSAo3h5eqZtZXRob2SQkA=="), new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndNullArg", message: new StreamInvocationMessage("xyz", "method", new object[] { null }), - binary: "lQSAo3h5eqZtZXRob2SRwA=="), + binary: "lgSAo3h5eqZtZXRob2SRwJA="), new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndIntArg", message: new StreamInvocationMessage("xyz", "method", new object[] { 42 }), - binary: "lQSAo3h5eqZtZXRob2SRKg=="), + binary: "lgSAo3h5eqZtZXRob2SRKpA="), new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndEnumArg", message: new StreamInvocationMessage("xyz", "method", new object[] { TestEnum.One }), - binary: "lQSAo3h5eqZtZXRob2SRo09uZQ=="), + binary: "lgSAo3h5eqZtZXRob2SRo09uZZA="), + new ProtocolTestData( + name: "StreamInvocationWithStreamArgument", + message: new StreamInvocationMessage("xyz", "method", Array.Empty(), new string[] { "__test_id__" }), + binary: "lgSAo3h5eqZtZXRob2SQkatfX3Rlc3RfaWRfXw=="), + new ProtocolTestData( + name: "StreamInvocationWithStreamAndNormalArgument", + message: new StreamInvocationMessage("xyz", "method", new object[] { 42 }, new string[] { "__test_id__" }), + binary: "lgSAo3h5eqZtZXRob2SRKpGrX190ZXN0X2lkX18="), new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndIntAndStringArgs", message: new StreamInvocationMessage("xyz", "method", new object[] { 42, "string" }), - binary: "lQSAo3h5eqZtZXRob2SSKqZzdHJpbmc="), + binary: "lgSAo3h5eqZtZXRob2SSKqZzdHJpbmeQ"), new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndIntStringAndCustomObjectArgs", message: new StreamInvocationMessage("xyz", "method", new object[] { 42, "string", new CustomObject() }), - binary: "lQSAo3h5eqZtZXRob2STKqZzdHJpbmeGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM="), + binary: "lgSAo3h5eqZtZXRob2STKqZzdHJpbmeGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOQ"), new ProtocolTestData( name: "StreamInvocationWithNoHeadersAndCustomObjectArrayArg", message: new StreamInvocationMessage("xyz", "method", new object[] { new CustomObject(), new CustomObject() }), - binary: "lQSAo3h5eqZtZXRob2SShqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQIDhqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQID"), + binary: "lgSAo3h5eqZtZXRob2SShqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQIDhqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQIDkA=="), new ProtocolTestData( name: "StreamInvocationWithHeadersAndCustomObjectArrayArg", message: AddHeaders(TestHeaders, new StreamInvocationMessage("xyz", "method", new object[] { new CustomObject(), new CustomObject() })), - binary: "lQSDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmWjeHl6pm1ldGhvZJKGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM="), + binary: "lgSDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmWjeHl6pm1ldGhvZJKGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOQ"), // CancelInvocation Messages new ProtocolTestData( @@ -233,27 +248,11 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol message: AddHeaders(TestHeaders, new CancelInvocationMessage("xyz")), binary: "kwWDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmWjeHl6"), - // StreamComplete Messages - new ProtocolTestData( - name: "StreamComplete", - message: new StreamCompleteMessage("xyz"), - binary: "kwijeHl6wA=="), - new ProtocolTestData( - name: "StreamCompleteWithError", - message: new StreamCompleteMessage("xyz", "zoinks"), - binary: "kwijeHl6pnpvaW5rcw=="), - // Ping Messages new ProtocolTestData( name: "Ping", message: PingMessage.Instance, binary: "kQY="), - - // StreamData Messages - new ProtocolTestData( - name: "StreamData", - message: new StreamDataMessage("xyz", new CustomObject()), - binary: "kwmjeHl6hqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQID"), }.ToDictionary(t => t.Name); [Theory] @@ -285,7 +284,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol 0x80, StringBytes(3), (byte)'x', (byte)'y', (byte)'z', StringBytes(6), (byte)'m', (byte)'e', (byte)'t', (byte)'h', (byte)'o', (byte)'d', - ArrayBytes(0), + ArrayBytes(0), // Arguments + ArrayBytes(0), // Streams 0xc3, StringBytes(2), (byte)'e', (byte)'x' }; diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs index c12028411a..d7282c42a1 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs @@ -29,9 +29,6 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol case CompletionMessage c: _returnType = c.Result?.GetType() ?? typeof(object); break; - case StreamDataMessage sd: - _returnType = sd.Item.GetType() ?? typeof(object); - break; } } diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs index c67b208b23..a51d6b0e39 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs @@ -39,10 +39,6 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return true; case CloseMessage closeMessage: return string.Equals(closeMessage.Error, ((CloseMessage) y).Error); - case StreamCompleteMessage streamCompleteMessage: - return StreamCompleteMessagesEqual(streamCompleteMessage, (StreamCompleteMessage)y); - case StreamDataMessage streamDataMessage: - return StreamDataMessagesEqual(streamDataMessage, (StreamDataMessage)y); default: throw new InvalidOperationException($"Unknown message type: {x.GetType().FullName}"); } @@ -69,7 +65,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return SequenceEqual(x.Headers, y.Headers) && string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal) && string.Equals(x.Target, y.Target, StringComparison.Ordinal) - && ArgumentListsEqual(x.Arguments, y.Arguments); + && ArgumentListsEqual(x.Arguments, y.Arguments) + && StringArrayEqual(x.StreamIds, y.StreamIds); } private bool StreamInvocationMessagesEqual(StreamInvocationMessage x, StreamInvocationMessage y) @@ -77,19 +74,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return SequenceEqual(x.Headers, y.Headers) && string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal) && string.Equals(x.Target, y.Target, StringComparison.Ordinal) - && ArgumentListsEqual(x.Arguments, y.Arguments); - } - - private bool StreamCompleteMessagesEqual(StreamCompleteMessage x, StreamCompleteMessage y) - { - return x.StreamId == y.StreamId - && x.Error == y.Error; - } - - private bool StreamDataMessagesEqual(StreamDataMessage x, StreamDataMessage y) - { - return x.StreamId == y.StreamId - && (Equals(x.Item, y.Item) || SequenceEqual(x.Item, y.Item)); + && ArgumentListsEqual(x.Arguments, y.Arguments) + && StringArrayEqual(x.StreamIds, y.StreamIds); } private bool ArgumentListsEqual(object[] left, object[] right) @@ -106,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol for (var i = 0; i < left.Length; i++) { - if (!(Equals(left[i], right[i]) || SequenceEqual(left[i], right[i]) || PlaceholdersEqual(left[i], right[i]))) + if (!(Equals(left[i], right[i]) || SequenceEqual(left[i], right[i]))) { return false; } @@ -114,21 +100,6 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return true; } - private bool PlaceholdersEqual(object left, object right) - { - if (left.GetType() != right.GetType()) - { - return false; - } - switch(left) - { - case StreamPlaceholder leftPlaceholder: - return leftPlaceholder.StreamId == (right as StreamPlaceholder).StreamId; - default: - return false; - } - } - private bool SequenceEqual(object left, object right) { if (left == null && right == null) @@ -158,6 +129,34 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return !leftMoved && !rightMoved; } + private bool StringArrayEqual(string[] left, string[] right) + { + if (left == null && right == null) + { + return true; + } + + if (left == null || right == null) + { + return false; + } + + if (left.Length != right.Length) + { + return false; + } + + for (var i = 0; i < left.Length; i++) + { + if (!string.Equals(left[i], right[i])) + { + return false; + } + } + + return true; + } + public int GetHashCode(HubMessage obj) { // We never use these in a hash-table diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs index 5ac4eab33c..6183691ad4 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests.Utils/TestClient.cs @@ -95,9 +95,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests return connection; } - public async Task> StreamAsync(string methodName, params object[] args) + public Task> StreamAsync(string methodName, params object[] args) { - var invocationId = await SendStreamInvocationAsync(methodName, args); + return StreamAsync(methodName, streamIds: null, args); + } + + public async Task> StreamAsync(string methodName, string[] streamIds, params object[] args) + { + var invocationId = await SendStreamInvocationAsync(methodName, streamIds, args); var messages = new List(); while (true) @@ -174,13 +179,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests public Task SendStreamInvocationAsync(string methodName, params object[] args) { - var invocationId = GetInvocationId(); - return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName, args)); + return SendStreamInvocationAsync(methodName, streamIds: null, args); } - public Task BeginUploadStreamAsync(string invocationId, string methodName, params object[] args) + public Task SendStreamInvocationAsync(string methodName, string[] streamIds, params object[] args) { - var message = new InvocationMessage(invocationId, methodName, args); + var invocationId = GetInvocationId(); + return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName, args, streamIds)); + } + + public Task BeginUploadStreamAsync(string invocationId, string methodName, string[] streamIds, params object[] args) + { + var message = new InvocationMessage(invocationId, methodName, args, streamIds); return SendHubMessageAsync(message); } diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs index 1d05e22e01..fb0c3ee00c 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTestUtils/Hubs.cs @@ -252,6 +252,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests return false; } + + public Task UploadIgnoreItems(ChannelReader source) + { + // Wait for an item to appear first then return from the hub method to end the invocation + return source.WaitToReadAsync().AsTask(); + } } public abstract class TestHub : Hub diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index ca0c1e8c3c..35741b02ca 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -2609,7 +2609,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); // Long running hub invocation to test that other invocations will not run until it is completed - var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream)).OrTimeout(); + var streamInvocationId = await client.SendStreamInvocationAsync(nameof(LongRunningHub.LongRunningStream), null).OrTimeout(); // Wait for the long running method to start await tcsService.StartedMethod.Task.OrTimeout(); @@ -2698,14 +2698,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new StreamPlaceholder("id")); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new[] { "id" }, Array.Empty()); foreach (var letter in new[] { "B", "E", "A", "N", "E", "D" }) { - await client.SendHubMessageAsync(new StreamDataMessage("id", letter)).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("id", letter)).OrTimeout(); } - await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout(); var result = (CompletionMessage)await client.ReadAsync().OrTimeout(); Assert.Equal("BEANED", result.Result); @@ -2721,15 +2721,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.UploadArray), new StreamPlaceholder("id")); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.UploadArray), new[] { "id" }, Array.Empty()); var objects = new[] { new SampleObject("solo", 322), new SampleObject("ggez", 3145) }; foreach (var thing in objects) { - await client.SendHubMessageAsync(new StreamDataMessage("id", thing)).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("id", thing)).OrTimeout(); } - await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout(); var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); var result = ((JArray)response.Result).ToArray(); @@ -2753,7 +2753,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests foreach (string id in ids) { - await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new StreamPlaceholder(id)); + await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new[] { id }, Array.Empty()); } var words = new[] { "zygapophyses", "qwerty", "abcd" }; @@ -2762,13 +2762,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests foreach (var spot in order) { - await client.SendHubMessageAsync(new StreamDataMessage(spot.ToString(), words[spot][pos[spot]])).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage(spot.ToString(), words[spot][pos[spot]])).OrTimeout(); pos[spot] += 1; } foreach (string id in new[] { "0", "2", "1" }) { - await client.SendHubMessageAsync(new StreamCompleteMessage(id)).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty(id)).OrTimeout(); var response = await client.ReadAsync().OrTimeout(); Debug.Write(response); Assert.Equal(words[int.Parse(id)], ((CompletionMessage)response).Result); @@ -2819,13 +2819,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new StreamPlaceholder("id")).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), streamIds: new[] { "id" }, Array.Empty()).OrTimeout(); // send integers that are then cast to strings - await client.SendHubMessageAsync(new StreamDataMessage("id", 5)).OrTimeout(); - await client.SendHubMessageAsync(new StreamDataMessage("id", 10)).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("id", 5)).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("id", 10)).OrTimeout(); - await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout(); var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); Assert.Equal("510", response.Result); @@ -2871,7 +2871,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocationId", nameof(MethodHub.TestTypeCastingErrors), new StreamPlaceholder("channelId")).OrTimeout(); + await client.BeginUploadStreamAsync("invocationId", nameof(MethodHub.TestTypeCastingErrors), new[] { "channelId" }, Array.Empty()).OrTimeout(); // client is running wild, sending strings not ints. // this error should be propogated to the user's HubMethod code @@ -2934,7 +2934,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.SendHubMessageAsync(new StreamCompleteMessage("fake_id")).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty("fake_id")).OrTimeout(); // Client is breaking protocol by sending an invalid id, and should be closed. var message = client.TryRead(); @@ -2957,8 +2957,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var client = new TestClient()) { await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.TestCustomErrorPassing), new StreamPlaceholder("id")).OrTimeout(); - await client.SendHubMessageAsync(new StreamCompleteMessage("id", CustomErrorMessage)).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.TestCustomErrorPassing), streamIds: new[] { "id" }, args: Array.Empty()).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.WithError("id", CustomErrorMessage)).OrTimeout(); var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); Assert.True((bool)response.Result); @@ -2966,6 +2966,141 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task UploadStreamWithTooManyStreamsFails() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" && + writeContext.EventId.Name == "FailedInvokingHubMethod"; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), streamIds: new[] { "id", "id2" }, args: Array.Empty()).OrTimeout(); + + var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); + Assert.Equal("An unexpected error occurred invoking 'StreamingConcat' on the server. HubException: Client sent 2 stream(s), Hub method expects 1.", response.Error); + } + } + } + + [Fact] + public async Task UploadStreamWithTooFewStreamsFails() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" && + writeContext.EventId.Name == "FailedInvokingHubMethod"; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), streamIds: Array.Empty(), args: Array.Empty()).OrTimeout(); + + var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); + Assert.Equal("An unexpected error occurred invoking 'StreamingConcat' on the server. HubException: Client sent 0 stream(s), Hub method expects 1.", response.Error); + } + } + } + + [Fact] + public async Task UploadStreamReleasesHubActivatorOnceComplete() + { + using (StartVerifiableLog()) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(builder => + { + builder.AddSingleton(typeof(IHubActivator<>), typeof(CustomHubActivator<>)); + }, LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), streamIds: new[] { "id" }, args: Array.Empty()).OrTimeout(); + + await client.SendHubMessageAsync(new StreamItemMessage("id", "hello")).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("id", " world")).OrTimeout(); + await client.SendHubMessageAsync(CompletionMessage.Empty("id")).OrTimeout(); + var result = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(result); + Assert.Equal("hello world", simpleCompletion.Result); + + var hubActivator = serviceProvider.GetService>() as CustomHubActivator; + + // OnConnectedAsync and StreamingConcat hubs have been disposed + Assert.Equal(2, hubActivator.ReleaseCount); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + } + + [Fact] + public async Task UploadStreamClosesStreamsOnServerWhenMethodCompletes() + { + bool errorLogged = false; + bool ExpectedErrors(WriteContext writeContext) + { + if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && + writeContext.EventId.Name == "ErrorProcessingRequest") + { + errorLogged = true; + return true; + } + + return false; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.UploadIgnoreItems), streamIds: new[] { "id" }, args: Array.Empty()).OrTimeout(); + + await client.SendHubMessageAsync(new StreamItemMessage("id", "ignored")).OrTimeout(); + var result = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(result); + Assert.Null(simpleCompletion.Result); + + // This will log an error on the server as the hub method has completed and will complete all associated streams + await client.SendHubMessageAsync(new StreamItemMessage("id", "error!")).OrTimeout(); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + // Check that the stream has been completed by noting the existance of an error + Assert.True(errorLogged); + } + [Theory] [InlineData(nameof(LongRunningHub.CancelableStream))] [InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)] @@ -3088,14 +3223,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).OrTimeout(); var streamId = "sample_id"; - var messagePromise = client.StreamAsync(nameof(StreamingHub.StreamEcho), new StreamPlaceholder(streamId)).OrTimeout(); + var messagePromise = client.StreamAsync(nameof(StreamingHub.StreamEcho), new[] { streamId }, Array.Empty()).OrTimeout(); var phrases = new[] { "asdf", "qwer", "zxcv" }; foreach (var phrase in phrases) { - await client.SendHubMessageAsync(new StreamDataMessage(streamId, phrase)); + await client.SendHubMessageAsync(new StreamItemMessage(streamId, phrase)); } - await client.SendHubMessageAsync(new StreamCompleteMessage(streamId)); + await client.SendHubMessageAsync(CompletionMessage.Empty(streamId)); var messages = await messagePromise;