From f52882b6aa3b3e0b96ab060a4354fd3654fe62dc Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Wed, 21 Nov 2018 09:59:31 -0800 Subject: [PATCH] Changing streaming things --- clients/ts/signalr/src/UploadStream.ts | 27 -- .../clients/ts/FunctionalTests/TestHub.cs | 16 ++ .../ts/FunctionalTests/package-lock.json | 43 ++-- .../clients/ts/FunctionalTests/package.json | 1 + .../FunctionalTests/ts/HubConnectionTests.ts | 42 +++ .../src/MessagePackHubProtocol.ts | 22 +- .../clients/ts/signalr/src/HubConnection.ts | 133 ++++++++-- .../clients/ts/signalr/src/IHubProtocol.ts | 26 +- src/SignalR/clients/ts/signalr/src/Subject.ts | 45 ++++ src/SignalR/clients/ts/signalr/src/Utils.ts | 43 +--- src/SignalR/clients/ts/signalr/src/index.ts | 4 +- .../ts/signalr/tests/HubConnection.test.ts | 99 +++++++- .../samples/ClientSample/UploadSample.cs | 51 +--- .../samples/SignalRSamples/Hubs/UploadHub.cs | 24 +- .../wwwroot/channelParameters.html | 29 +-- src/SignalR/src/Common/ReflectionHelper.cs | 24 +- .../Internal/DefaultHubDispatcher.cs | 18 +- .../Internal/HubMethodDescriptor.cs | 6 +- .../Protocol/MessagePackHubProtocol.cs | 21 ++ .../Protocol/NewtonsoftJsonHubProtocol.cs | 17 +- .../HubConnectionTests.cs | 62 +++++ .../Hubs.cs | 39 +-- .../Protocol/MessagePackHubProtocolTests.cs | 6 + .../Internal/Protocol/TestBinder.cs | 3 + .../TestHubMessageEqualityComparer.cs | 10 +- .../HubConnectionHandlerTests.cs | 239 ++++++++++-------- 26 files changed, 674 insertions(+), 376 deletions(-) delete mode 100644 clients/ts/signalr/src/UploadStream.ts create mode 100644 src/SignalR/clients/ts/signalr/src/Subject.ts rename {samples => src/SignalR/samples}/SignalRSamples/wwwroot/channelParameters.html (83%) diff --git a/clients/ts/signalr/src/UploadStream.ts b/clients/ts/signalr/src/UploadStream.ts deleted file mode 100644 index eb1deafc36..0000000000 --- a/clients/ts/signalr/src/UploadStream.ts +++ /dev/null @@ -1,27 +0,0 @@ -import { HubConnection } from "./HubConnection"; -import { MessageType } from "./IHubProtocol"; - -export class UploadStream { - private connection: HubConnection; - - public readonly streamId: string; - public readonly placeholder: object; - - constructor(connection: HubConnection) { - this.connection = connection; - this.streamId = connection.nextStreamId(); - this.placeholder = {streamId: this.streamId}; - } - - public write(item: any): Promise { - return this.connection.sendWithProtocol(this.connection.createStreamDataMessage(this.streamId, item)); - } - - public complete(error?: string): Promise { - if (error) { - return this.connection.sendWithProtocol({ type: MessageType.StreamComplete, streamId: this.streamId, error }); - } else { - return this.connection.sendWithProtocol({ type: MessageType.StreamComplete, streamId: this.streamId }); - } - } -} diff --git a/src/SignalR/clients/ts/FunctionalTests/TestHub.cs b/src/SignalR/clients/ts/FunctionalTests/TestHub.cs index 9c39d97d6e..e6bea7e82b 100644 --- a/src/SignalR/clients/ts/FunctionalTests/TestHub.cs +++ b/src/SignalR/clients/ts/FunctionalTests/TestHub.cs @@ -3,6 +3,7 @@ using System; using System.Reactive.Linq; +using System.Text; using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Connections; @@ -50,6 +51,21 @@ namespace FunctionalTests return channel.Reader; } + public async Task StreamingConcat(ChannelReader stream) + { + var sb = new StringBuilder(); + + while (await stream.WaitToReadAsync()) + { + while (stream.TryRead(out var item)) + { + sb.Append(item); + } + } + + return sb.ToString(); + } + public ChannelReader EmptyStream() { var channel = Channel.CreateUnbounded(); diff --git a/src/SignalR/clients/ts/FunctionalTests/package-lock.json b/src/SignalR/clients/ts/FunctionalTests/package-lock.json index e2b24fb2bb..e3a4aa8f71 100644 --- a/src/SignalR/clients/ts/FunctionalTests/package-lock.json +++ b/src/SignalR/clients/ts/FunctionalTests/package-lock.json @@ -1827,14 +1827,12 @@ "balanced-match": { "version": "1.0.0", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "brace-expansion": { "version": "1.1.11", "bundled": true, "dev": true, - "optional": true, "requires": { "balanced-match": "^1.0.0", "concat-map": "0.0.1" @@ -1849,20 +1847,17 @@ "code-point-at": { "version": "1.1.0", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "concat-map": { "version": "0.0.1", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "console-control-strings": { "version": "1.1.0", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "core-util-is": { "version": "1.0.2", @@ -1979,8 +1974,7 @@ "inherits": { "version": "2.0.3", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "ini": { "version": "1.3.5", @@ -1992,7 +1986,6 @@ "version": "1.0.0", "bundled": true, "dev": true, - "optional": true, "requires": { "number-is-nan": "^1.0.0" } @@ -2007,7 +2000,6 @@ "version": "3.0.4", "bundled": true, "dev": true, - "optional": true, "requires": { "brace-expansion": "^1.1.7" } @@ -2015,14 +2007,12 @@ "minimist": { "version": "0.0.8", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "minipass": { "version": "2.2.4", "bundled": true, "dev": true, - "optional": true, "requires": { "safe-buffer": "^5.1.1", "yallist": "^3.0.0" @@ -2041,7 +2031,6 @@ "version": "0.5.1", "bundled": true, "dev": true, - "optional": true, "requires": { "minimist": "0.0.8" } @@ -2122,8 +2111,7 @@ "number-is-nan": { "version": "1.0.1", "bundled": true, - "dev": true, - "optional": true + "dev": true }, "object-assign": { "version": "4.1.1", @@ -2135,7 +2123,6 @@ "version": "1.4.0", "bundled": true, "dev": true, - "optional": true, "requires": { "wrappy": "1" } @@ -2257,7 +2244,6 @@ "version": "1.0.2", "bundled": true, "dev": true, - "optional": true, "requires": { "code-point-at": "^1.0.0", "is-fullwidth-code-point": "^1.0.0", @@ -3443,6 +3429,15 @@ "glob": "^7.0.5" } }, + "rxjs": { + "version": "6.3.3", + "resolved": "https://registry.npmjs.org/rxjs/-/rxjs-6.3.3.tgz", + "integrity": "sha512-JTWmoY9tWCs7zvIk/CvRjhjGaOd+OVBM987mxFo+OW66cGpdKjZcpmc74ES1sB//7Kl/PAe8+wEakuhG4pcgOw==", + "dev": true, + "requires": { + "tslib": "^1.9.0" + } + }, "safe-buffer": { "version": "5.1.1", "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.1.1.tgz", @@ -3989,6 +3984,12 @@ "strip-json-comments": "^2.0.0" } }, + "tslib": { + "version": "1.9.3", + "resolved": "https://registry.npmjs.org/tslib/-/tslib-1.9.3.tgz", + "integrity": "sha512-4krF8scpejhaOgqzBEcGM7yDIEfi0/8+8zDRZhNZZ2kjmHJ4hv3zCbQWxoJGz1iw5U0Jl0nma13xzHXcncMavQ==", + "dev": true + }, "tunnel-agent": { "version": "0.6.0", "resolved": "https://registry.npmjs.org/tunnel-agent/-/tunnel-agent-0.6.0.tgz", diff --git a/src/SignalR/clients/ts/FunctionalTests/package.json b/src/SignalR/clients/ts/FunctionalTests/package.json index 0d8e28af8a..6b3128210e 100644 --- a/src/SignalR/clients/ts/FunctionalTests/package.json +++ b/src/SignalR/clients/ts/FunctionalTests/package.json @@ -30,6 +30,7 @@ "karma-sauce-launcher": "^1.2.0", "karma-sourcemap-loader": "^0.3.7", "karma-summary-reporter": "^1.5.0", + "rxjs": "^6.3.3", "ts-node": "^4.1.0", "typescript": "^3.0.1", "ws": " ^6.0.0" diff --git a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index 1fda5080cd..99944c1da2 100644 --- a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -11,6 +11,8 @@ import { eachTransport, eachTransportAndProtocol, ENDPOINT_BASE_HTTPS_URL, ENDPO import "./LogBannerReporter"; import { TestLogger } from "./TestLogger"; +import * as RX from "rxjs"; + const TESTHUBENDPOINT_URL = ENDPOINT_BASE_URL + "/testhub"; const TESTHUBENDPOINT_HTTPS_URL = ENDPOINT_BASE_HTTPS_URL ? (ENDPOINT_BASE_HTTPS_URL + "/testhub") : undefined; @@ -531,6 +533,46 @@ describe("hubConnection", () => { done(); }); }); + + it("can stream from client to server with rxjs", async (done) => { + const hubConnection = getConnectionBuilder(transportType) + .withHubProtocol(protocol) + .build(); + + await hubConnection.start(); + const subject = new RX.Subject(); + const resultPromise = hubConnection.invoke("StreamingConcat", subject.asObservable()); + subject.next("Hello "); + subject.next("world"); + subject.next("!"); + subject.complete(); + expect(await resultPromise).toBe("Hello world!"); + await hubConnection.stop(); + done(); + }); + + it("can stream from client to server and close with error with rxjs", async (done) => { + const hubConnection = getConnectionBuilder(transportType) + .withHubProtocol(protocol) + .build(); + + await hubConnection.start(); + const subject = new RX.Subject(); + const resultPromise = hubConnection.invoke("StreamingConcat", subject.asObservable()); + subject.next("Hello "); + subject.next("world"); + subject.next("!"); + subject.error(new Error("Something bad")); + try { + await resultPromise; + expect(false).toBe(true); + } catch (err) { + expect(err.message).toEqual("An unexpected error occurred invoking 'StreamingConcat' on the server. Exception: Something bad"); + } finally { + await hubConnection.stop(); + } + done(); + }); }); }); diff --git a/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts b/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts index 38564d14fb..78609f89b7 100644 --- a/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts +++ b/src/SignalR/clients/ts/signalr-protocol-msgpack/src/MessagePackHubProtocol.ts @@ -4,7 +4,7 @@ import { Buffer } from "buffer"; import * as msgpack5 from "msgpack5"; -import { CompletionMessage, HubMessage, IHubProtocol, ILogger, InvocationMessage, LogLevel, MessageHeaders, MessageType, NullLogger, StreamInvocationMessage, StreamItemMessage, TransferFormat } from "@aspnet/signalr"; +import { CompletionMessage, HubMessage, IHubProtocol, ILogger, InvocationMessage, LogLevel, MessageHeaders, MessageType, NullLogger, StreamCompleteMessage, StreamDataMessage, StreamInvocationMessage, StreamItemMessage, TransferFormat } from "@aspnet/signalr"; import { BinaryMessageFormat } from "./BinaryMessageFormat"; import { isArrayBuffer } from "./Utils"; @@ -65,11 +65,15 @@ export class MessagePackHubProtocol implements IHubProtocol { return this.writeInvocation(message as InvocationMessage); case MessageType.StreamInvocation: return this.writeStreamInvocation(message as StreamInvocationMessage); + case MessageType.StreamData: + return this.writeStreamData(message as StreamDataMessage); case MessageType.StreamItem: case MessageType.Completion: throw new Error(`Writing messages of type '${message.type}' is not supported.`); case MessageType.Ping: return BinaryMessageFormat.write(SERIALIZED_PING_MESSAGE); + case MessageType.StreamComplete: + return this.writeStreamComplete(message as StreamCompleteMessage); default: throw new Error("Invalid message type."); } @@ -226,6 +230,22 @@ export class MessagePackHubProtocol implements IHubProtocol { return BinaryMessageFormat.write(payload.slice()); } + private writeStreamData(streamDataMessage: StreamDataMessage): ArrayBuffer { + const msgpack = msgpack5(); + const payload = msgpack.encode([MessageType.StreamData, streamDataMessage.streamId, + streamDataMessage.item]); + + return BinaryMessageFormat.write(payload.slice()); + } + + private writeStreamComplete(streamCompleteMessage: StreamCompleteMessage): ArrayBuffer { + const msgpack = msgpack5(); + const payload = msgpack.encode([MessageType.StreamComplete, streamCompleteMessage.streamId, + streamCompleteMessage.error || null]); + + return BinaryMessageFormat.write(payload.slice()); + } + private readHeaders(properties: any): MessageHeaders { const headers: MessageHeaders = properties[1] as MessageHeaders; if (typeof headers !== "object") { diff --git a/src/SignalR/clients/ts/signalr/src/HubConnection.ts b/src/SignalR/clients/ts/signalr/src/HubConnection.ts index e0b73b5fe5..05d7731beb 100644 --- a/src/SignalR/clients/ts/signalr/src/HubConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HubConnection.ts @@ -1,14 +1,13 @@ -import { UploadStream } from "./UploadStream"; - // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { HandshakeProtocol, HandshakeRequestMessage, HandshakeResponseMessage } from "./HandshakeProtocol"; import { IConnection } from "./IConnection"; -import { CancelInvocationMessage, CompletionMessage, IHubProtocol, InvocationMessage, MessageType, StreamDataMessage, StreamInvocationMessage, StreamItemMessage } from "./IHubProtocol"; +import { CancelInvocationMessage, CompletionMessage, IHubProtocol, InvocationMessage, MessageType, StreamCompleteMessage, StreamDataMessage, StreamInvocationMessage, StreamItemMessage } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { IStreamResult } from "./Stream"; -import { Arg, Subject } from "./Utils"; +import { Subject } from "./Subject"; +import { Arg } from "./Utils"; const DEFAULT_TIMEOUT_IN_MS: number = 30 * 1000; const DEFAULT_PING_INTERVAL_IN_MS: number = 15 * 1000; @@ -30,7 +29,7 @@ export class HubConnection { private handshakeProtocol: HandshakeProtocol; private callbacks: { [invocationId: string]: (invocationEvent: StreamItemMessage | CompletionMessage | null, error?: Error) => void }; private methods: { [name: string]: Array<(...args: any[]) => void> }; - private id: number; + private invocationId: number; private streamId: number; private closedCallbacks: Array<(error?: Error) => void>; private receivedHandshakeResponse: boolean; @@ -86,7 +85,7 @@ export class HubConnection { this.callbacks = {}; this.methods = {}; this.closedCallbacks = []; - this.id = 0; + this.invocationId = 0; this.streamId = 0; this.receivedHandshakeResponse = false; this.connectionState = HubConnectionState.Disconnected; @@ -126,7 +125,7 @@ export class HubConnection { this.logger.log(LogLevel.Information, `Using HubProtocol '${this.protocol.name}'.`); - // defensively cleanup timeout in case we receive a message export from the server before we finish start + // defensively cleanup timeout in case we receive a message from the server before we finish start this.cleanupTimeout(); this.resetTimeoutPeriod(); this.resetKeepAliveInterval(); @@ -156,15 +155,17 @@ export class HubConnection { * @returns {IStreamResult} An object that yields results from the server as they are received. */ public stream(methodName: string, ...args: any[]): IStreamResult { + const streams = this.replaceStreamingParams(args); const invocationDescriptor = this.createStreamInvocation(methodName, args); - const subject = new Subject(() => { + const subject = new Subject(); + subject.cancelCallback = () => { const cancelInvocation: CancelInvocationMessage = this.createCancelInvocation(invocationDescriptor.invocationId); delete this.callbacks[invocationDescriptor.invocationId]; return this.sendWithProtocol(cancelInvocation); - }); + }; this.callbacks[invocationDescriptor.invocationId] = (invocationEvent: CompletionMessage | StreamItemMessage | null, error?: Error) => { if (error) { @@ -184,12 +185,14 @@ export class HubConnection { } }; - this.sendWithProtocol(invocationDescriptor) + const promiseQueue = this.sendWithProtocol(invocationDescriptor) .catch((e) => { subject.error(e); delete this.callbacks[invocationDescriptor.invocationId]; }); + this.launchStreams(streams, promiseQueue); + return subject; } @@ -202,7 +205,7 @@ export class HubConnection { * Sends a js object to the server. * @param message The js object to serialize and send. */ - public sendWithProtocol(message: any) { + private sendWithProtocol(message: any) { return this.sendMessage(this.protocol.writeMessage(message)); } @@ -216,18 +219,19 @@ export class HubConnection { * @returns {Promise} A Promise that resolves when the invocation has been successfully sent, or rejects with an error. */ public send(methodName: string, ...args: any[]): Promise { - return this.sendWithProtocol(this.createInvocation(methodName, args, true)); + const streams = this.replaceStreamingParams(args); + const sendPromise = this.sendWithProtocol(this.createInvocation(methodName, args, true)); + + this.launchStreams(streams, sendPromise); + + return sendPromise; } - public nextStreamId(): string { + private nextStreamId(): string { this.streamId += 1; return this.streamId.toString(); } - public newUploadStream(): UploadStream { - return new UploadStream(this); - } - /** Invokes a hub method on the server using the specified name and arguments. * * The Promise returned by this method resolves when the server indicates it has finished invoking the method. When the promise @@ -240,10 +244,11 @@ export class HubConnection { * @returns {Promise} A Promise that resolves with the result of the server method (if any), or rejects with an error. */ public invoke(methodName: string, ...args: any[]): Promise { + const streams = this.replaceStreamingParams(args); const invocationDescriptor = this.createInvocation(methodName, args, false); const p = new Promise((resolve, reject) => { - // invocationId will always have a value for a non-blocking inexport vocation + // invocationId will always have a value for a non-blocking invocation this.callbacks[invocationDescriptor.invocationId!] = (invocationEvent: StreamItemMessage | CompletionMessage | null, error?: Error) => { if (error) { reject(error); @@ -262,12 +267,14 @@ export class HubConnection { } }; - this.sendWithProtocol(invocationDescriptor) + const promiseQueue = this.sendWithProtocol(invocationDescriptor) .catch((e) => { reject(e); // invocationId will always have a value for a non-blocking invocation delete this.callbacks[invocationDescriptor.invocationId!]; }); + + this.launchStreams(streams, promiseQueue); }); return p; @@ -520,25 +527,84 @@ export class HubConnection { type: MessageType.Invocation, }; } else { - const id = this.id; - this.id++; + const invocationId = this.invocationId; + this.invocationId++; return { arguments: args, - invocationId: id.toString(), + invocationId: invocationId.toString(), target: methodName, type: MessageType.Invocation, }; } } + private launchStreams(streams: Array>, promiseQueue: Promise): void { + if (streams.length === 0) { + return; + } + + // Synchronize stream data so they arrive in-order on the server + if (!promiseQueue) { + promiseQueue = Promise.resolve(); + } + + // We want to iterate over the keys, since the keys are the stream ids + // tslint:disable-next-line:forin + for (const streamId in streams) { + streams[streamId].subscribe({ + complete: () => { + promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamCompleteMessage(streamId))); + }, + error: (err) => { + let message: string; + if (err instanceof Error) { + message = err.message; + } else if (err && err.toString) { + message = err.toString(); + } else { + message = "Unknown error"; + } + + promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamCompleteMessage(streamId, message))); + }, + next: (item) => { + promiseQueue = promiseQueue.then(() => this.sendWithProtocol(this.createStreamDataMessage(streamId, item))); + }, + }); + } + } + + private replaceStreamingParams(args: any[]): Array> { + const streams: Array> = []; + for (let i = 0; i < args.length; i++) { + const argument = args[i]; + if (this.isObservable(argument)) { + const streamId = this.nextStreamId(); + // Store the stream for later use + streams[streamId] = argument; + // Replace the stream with a placeholder + // Use capitalized StreamId because the MessagePack-CSharp library expects exact case for arguments + // Json allows case-insensitive arguments by default + args[i] = { StreamId: streamId }; + } + } + + return streams; + } + + private isObservable(arg: any): arg is IStreamResult { + // This allows other stream implementations to just work (like rxjs) + return arg.subscribe && typeof arg.subscribe === "function"; + } + private createStreamInvocation(methodName: string, args: any[]): StreamInvocationMessage { - const id = this.id; - this.id++; + const invocationId = this.invocationId; + this.invocationId++; return { arguments: args, - invocationId: id.toString(), + invocationId: invocationId.toString(), target: methodName, type: MessageType.StreamInvocation, }; @@ -551,11 +617,26 @@ export class HubConnection { }; } - public createStreamDataMessage(id: string, item: any): StreamDataMessage { + private createStreamDataMessage(id: string, item: any): StreamDataMessage { return { item, streamId: id, type: MessageType.StreamData, }; } + + private createStreamCompleteMessage(id: string, error?: string): StreamCompleteMessage { + if (error) { + return { + error, + streamId: id, + type: MessageType.StreamComplete, + }; + } + + return { + streamId: id, + type: MessageType.StreamComplete, + }; + } } diff --git a/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts b/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts index d7c45eec19..d5c39e5e90 100644 --- a/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts +++ b/src/SignalR/clients/ts/signalr/src/IHubProtocol.ts @@ -33,14 +33,16 @@ export interface MessageHeaders { } /** Union type of all known Hub messages. */ -export type HubMessage = InvocationMessage - | StreamInvocationMessage - | StreamItemMessage - | CompletionMessage - | CancelInvocationMessage - | PingMessage - | CloseMessage - | StreamDataMessage; +export type HubMessage = + InvocationMessage | + StreamInvocationMessage | + StreamItemMessage | + CompletionMessage | + CancelInvocationMessage | + PingMessage | + CloseMessage | + StreamCompleteMessage | + StreamDataMessage; /** Defines properties common to all Hub messages. */ export interface HubMessageBase { @@ -100,10 +102,10 @@ export interface StreamDataMessage extends HubMessageBase { /** @inheritDoc */ readonly type: MessageType.StreamData; - /** The streamId */ + /** The streamId. */ readonly streamId: string; - /** The item produced by the client */ + /** The item produced by the client. */ readonly item?: any; } @@ -153,13 +155,13 @@ export interface CancelInvocationMessage extends HubInvocationMessage { readonly invocationId: string; } -/** A hub message send to indicate the end of stream items for a streaming parameter. */ +/** A hub message sent to indicate the end of stream items for a streaming parameter. */ export interface StreamCompleteMessage extends HubMessageBase { /** @inheritDoc */ readonly type: MessageType.StreamComplete; /** The stream ID of the stream to be completed. */ readonly streamId: string; - /** The error that trigger completion, if any. */ + /** The error that triggered completion, if any. */ readonly error?: string; } diff --git a/src/SignalR/clients/ts/signalr/src/Subject.ts b/src/SignalR/clients/ts/signalr/src/Subject.ts new file mode 100644 index 0000000000..5a6b6bd6fa --- /dev/null +++ b/src/SignalR/clients/ts/signalr/src/Subject.ts @@ -0,0 +1,45 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +import { IStreamResult, IStreamSubscriber, ISubscription } from "./Stream"; +import { SubjectSubscription } from "./Utils"; + +/** Stream implementation to stream items to the server. */ +export class Subject implements IStreamResult { + /** @internal */ + public observers: Array>; + + /** @internal */ + public cancelCallback?: () => Promise; + + constructor() { + this.observers = []; + } + + public next(item: T): void { + for (const observer of this.observers) { + observer.next(item); + } + } + + public error(err: any): void { + for (const observer of this.observers) { + if (observer.error) { + observer.error(err); + } + } + } + + public complete(): void { + for (const observer of this.observers) { + if (observer.complete) { + observer.complete(); + } + } + } + + public subscribe(observer: IStreamSubscriber): ISubscription { + this.observers.push(observer); + return new SubjectSubscription(this, observer); + } +} diff --git a/src/SignalR/clients/ts/signalr/src/Utils.ts b/src/SignalR/clients/ts/signalr/src/Utils.ts index 6e0ae91862..23780eafa8 100644 --- a/src/SignalR/clients/ts/signalr/src/Utils.ts +++ b/src/SignalR/clients/ts/signalr/src/Utils.ts @@ -4,7 +4,8 @@ import { HttpClient } from "./HttpClient"; import { ILogger, LogLevel } from "./ILogger"; import { NullLogger } from "./Loggers"; -import { IStreamResult, IStreamSubscriber, ISubscription } from "./Stream"; +import { IStreamSubscriber, ISubscription } from "./Stream"; +import { Subject } from "./Subject"; /** @private */ export class Arg { @@ -104,44 +105,6 @@ export function createLogger(logger?: ILogger | LogLevel) { return new ConsoleLogger(logger as LogLevel); } -/** @private */ -export class Subject implements IStreamResult { - public observers: Array>; - public cancelCallback: () => Promise; - - constructor(cancelCallback: () => Promise) { - this.observers = []; - this.cancelCallback = cancelCallback; - } - - public next(item: T): void { - for (const observer of this.observers) { - observer.next(item); - } - } - - public error(err: any): void { - for (const observer of this.observers) { - if (observer.error) { - observer.error(err); - } - } - } - - public complete(): void { - for (const observer of this.observers) { - if (observer.complete) { - observer.complete(); - } - } - } - - public subscribe(observer: IStreamSubscriber): ISubscription { - this.observers.push(observer); - return new SubjectSubscription(this, observer); - } -} - /** @private */ export class SubjectSubscription implements ISubscription { private subject: Subject; @@ -158,7 +121,7 @@ export class SubjectSubscription implements ISubscription { this.subject.observers.splice(index, 1); } - if (this.subject.observers.length === 0) { + if (this.subject.observers.length === 0 && this.subject.cancelCallback) { this.subject.cancelCallback().catch((_) => { }); } } diff --git a/src/SignalR/clients/ts/signalr/src/index.ts b/src/SignalR/clients/ts/signalr/src/index.ts index 5ede8dab75..f770e9049e 100644 --- a/src/SignalR/clients/ts/signalr/src/index.ts +++ b/src/SignalR/clients/ts/signalr/src/index.ts @@ -13,9 +13,11 @@ export { DefaultHttpClient } from "./DefaultHttpClient"; export { IHttpConnectionOptions } from "./IHttpConnectionOptions"; export { HubConnection, HubConnectionState } from "./HubConnection"; export { HubConnectionBuilder } from "./HubConnectionBuilder"; -export { MessageType, MessageHeaders, HubMessage, HubMessageBase, HubInvocationMessage, InvocationMessage, StreamInvocationMessage, StreamItemMessage, CompletionMessage, PingMessage, CloseMessage, CancelInvocationMessage, IHubProtocol } from "./IHubProtocol"; +export { MessageType, MessageHeaders, HubMessage, HubMessageBase, HubInvocationMessage, InvocationMessage, StreamInvocationMessage, StreamItemMessage, CompletionMessage, + PingMessage, CloseMessage, CancelInvocationMessage, IHubProtocol, StreamDataMessage, StreamCompleteMessage } from "./IHubProtocol"; export { ILogger, LogLevel } from "./ILogger"; export { HttpTransportType, TransferFormat, ITransport } from "./ITransport"; export { IStreamSubscriber, IStreamResult, ISubscription } from "./Stream"; export { NullLogger } from "./Loggers"; export { JsonHubProtocol } from "./JsonHubProtocol"; +export { Subject } from "./Subject"; diff --git a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts index 85cd6e107d..6f08420c10 100644 --- a/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HubConnection.test.ts @@ -9,6 +9,7 @@ import { TransferFormat } from "../src/ITransport"; import { JsonHubProtocol } from "../src/JsonHubProtocol"; import { NullLogger } from "../src/Loggers"; import { IStreamSubscriber } from "../src/Stream"; +import { Subject } from "../src/Subject"; import { TextMessageFormat } from "../src/TextMessageFormat"; import { VerifyLogger } from "./Common"; @@ -330,25 +331,28 @@ describe("HubConnection", () => { }); }); - it("is able to send stream items to server", async () => { + it("is able to send stream items to server with invoke", async () => { await VerifyLogger.run(async (logger) => { const connection = new TestConnection(); const hubConnection = createHubConnection(connection, logger); try { - connection.receiveHandshakeResponse(); + await hubConnection.start(); - const stream = hubConnection.newUploadStream(); - const invokePromise = hubConnection.invoke("testMethod", "arg", stream.placeholder); + const subject = new Subject(); + const invokePromise = hubConnection.invoke("testMethod", "arg", subject); - expect(JSON.parse(connection.sentData[0])).toEqual({ - arguments: ["arg", {streamId: "1"}], + expect(JSON.parse(connection.sentData[1])).toEqual({ + arguments: ["arg", {StreamId: "1"}], invocationId: "0", target: "testMethod", type: MessageType.Invocation, }); - await stream.write("item numero uno"); - expect(JSON.parse(connection.sentData[1])).toEqual({ + subject.next("item numero uno"); + await new Promise((resolve) => { + setTimeout(resolve, 50); + }); + expect(JSON.parse(connection.sentData[2])).toEqual({ item: "item numero uno", streamId: "1", type: MessageType.StreamData, @@ -363,6 +367,85 @@ describe("HubConnection", () => { }); }); + it("is able to send stream items to server with send", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + const subject = new Subject(); + await hubConnection.send("testMethod", "arg", subject); + + expect(JSON.parse(connection.sentData[1])).toEqual({ + arguments: ["arg", { StreamId: "1" }], + target: "testMethod", + type: MessageType.Invocation, + }); + + subject.next("item numero uno"); + await new Promise((resolve) => { + setTimeout(resolve, 50); + }); + expect(JSON.parse(connection.sentData[2])).toEqual({ + item: "item numero uno", + streamId: "1", + type: MessageType.StreamData, + }); + } finally { + await hubConnection.stop(); + } + }); + }); + + it("is able to send stream items to server with stream", async () => { + await VerifyLogger.run(async (logger) => { + const connection = new TestConnection(); + const hubConnection = createHubConnection(connection, logger); + try { + await hubConnection.start(); + + let streamItem = ""; + let streamError: any = null; + const subject = new Subject(); + hubConnection.stream("testMethod", "arg", subject).subscribe({ + complete: () => { + }, + error: (e) => { + streamError = e; + }, + next: (item) => { + streamItem = item; + }, + }); + + expect(JSON.parse(connection.sentData[1])).toEqual({ + arguments: ["arg", { StreamId: "1" }], + invocationId: "0", + target: "testMethod", + type: MessageType.StreamInvocation, + }); + + subject.next("item numero uno"); + await new Promise((resolve) => { + setTimeout(resolve, 50); + }); + expect(JSON.parse(connection.sentData[2])).toEqual({ + item: "item numero uno", + streamId: "1", + type: MessageType.StreamData, + }); + + connection.receive({ type: MessageType.StreamItem, invocationId: connection.lastInvocationId, item: "foo" }); + expect(streamItem).toEqual("foo"); + + expect(streamError).toBe(null); + } finally { + await hubConnection.stop(); + } + }); + }); + it("completes pending invocations when stopped", async () => { await VerifyLogger.run(async (logger) => { const connection = new TestConnection(); diff --git a/src/SignalR/samples/ClientSample/UploadSample.cs b/src/SignalR/samples/ClientSample/UploadSample.cs index a6ce78f7ca..efdddb6d1c 100644 --- a/src/SignalR/samples/ClientSample/UploadSample.cs +++ b/src/SignalR/samples/ClientSample/UploadSample.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System.Collections.Generic; @@ -19,7 +19,7 @@ namespace ClientSample { cmd.Description = "Tests a streaming invocation from client to hub"; - CommandArgument baseUrlArgument = cmd.Argument("", "The URL to the Chat Hub to test"); + var baseUrlArgument = cmd.Argument("", "The URL to the Chat Hub to test"); cmd.OnExecute(() => ExecuteAsync(baseUrlArgument.Value)); }); @@ -34,7 +34,6 @@ namespace ClientSample //await BasicInvoke(connection); //await ScoreTrackerExample(connection); - //await FileUploadExample(connection); await StreamingEcho(connection); return 0; @@ -58,8 +57,6 @@ namespace ClientSample public static async Task ScoreTrackerExample(HubConnection connection) { - // Andrew please add the updated code from your laptop here - var channel_one = Channel.CreateBounded(2); var channel_two = Channel.CreateBounded(2); _ = WriteItemsAsync(channel_one.Writer, new[] { 2, 2, 3 }); @@ -68,7 +65,6 @@ namespace ClientSample var result = await connection.InvokeAsync("ScoreTracker", channel_one.Reader, channel_two.Reader); Debug.WriteLine(result); - async Task WriteItemsAsync(ChannelWriter source, IEnumerable scores) { await Task.Delay(1000); @@ -78,53 +74,12 @@ namespace ClientSample await Task.Delay(250); } - // tryComplete triggers the end of this upload's relayLoop + // TryComplete triggers the end of this upload's relayLoop // which sends a StreamComplete to the server source.TryComplete(); } } - public static async Task FileUploadExample(HubConnection connection) - { - var fileNameSource = @"C:\Users\t-dygray\Pictures\weeg.jpg"; - var fileNameDest = @"C:\Users\t-dygray\Pictures\TargetFolder\weeg.jpg"; - - var channel = Channel.CreateUnbounded(); - var invocation = connection.InvokeAsync("UploadFile", fileNameDest, channel.Reader); - - using (var file = new FileStream(fileNameSource, FileMode.Open, FileAccess.Read)) - { - foreach (var chunk in GetChunks(file, kilobytesPerChunk: 5)) - { - await channel.Writer.WriteAsync(chunk); - } - } - channel.Writer.TryComplete(); - - Debug.WriteLine(await invocation); - } - - public static IEnumerable GetChunks(FileStream fileStream, double kilobytesPerChunk) - { - var chunkSize = (int)kilobytesPerChunk * 1024; - - var position = 0; - while (true) - { - if (position + chunkSize > fileStream.Length) - { - var lastChunk = new byte[fileStream.Length - position]; - fileStream.Read(lastChunk, 0, lastChunk.Length); - yield return lastChunk; - break; - } - - var chunk = new byte[chunkSize]; - position += fileStream.Read(chunk, 0, chunk.Length); - yield return chunk; - } - } - public static async Task StreamingEcho(HubConnection connection) { var channel = Channel.CreateUnbounded(); diff --git a/src/SignalR/samples/SignalRSamples/Hubs/UploadHub.cs b/src/SignalR/samples/SignalRSamples/Hubs/UploadHub.cs index a0d5a912ea..eee36eb2cc 100644 --- a/src/SignalR/samples/SignalRSamples/Hubs/UploadHub.cs +++ b/src/SignalR/samples/SignalRSamples/Hubs/UploadHub.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -29,7 +29,7 @@ namespace SignalRSamples.Hubs // receiving a StreamCompleteMessage should cause this WaitToRead to return false while (await source.WaitToReadAsync()) { - while (source.TryRead(out string item)) + while (source.TryRead(out var item)) { Debug.WriteLine($"received: {item}"); Console.WriteLine($"received: {item}"); @@ -55,7 +55,7 @@ namespace SignalRSamples.Hubs while (await reader.WaitToReadAsync()) { - while (reader.TryRead(out int item)) + while (reader.TryRead(out var item)) { Debug.WriteLine($"got score {item}"); score += item; @@ -66,24 +66,6 @@ namespace SignalRSamples.Hubs } } - public async Task UploadFile(string filepath, ChannelReader source) - { - var result = Enumerable.Empty(); - var chunk = 1; - - while (await source.WaitToReadAsync()) - { - while (source.TryRead(out var item)) - { - Debug.WriteLine($"received chunk #{chunk++}"); - result = result.Concat(item); // atrocious - await Task.Delay(50); - } - } - - File.WriteAllBytes(filepath, result.ToArray()); - } - public ChannelReader StreamEcho(ChannelReader source) { var output = Channel.CreateUnbounded(); diff --git a/samples/SignalRSamples/wwwroot/channelParameters.html b/src/SignalR/samples/SignalRSamples/wwwroot/channelParameters.html similarity index 83% rename from samples/SignalRSamples/wwwroot/channelParameters.html rename to src/SignalR/samples/SignalRSamples/wwwroot/channelParameters.html index 464186a670..159cc48bda 100644 --- a/samples/SignalRSamples/wwwroot/channelParameters.html +++ b/src/SignalR/samples/SignalRSamples/wwwroot/channelParameters.html @@ -1,4 +1,4 @@ - + @@ -87,32 +87,25 @@ }); async function run(method) { - - //let id = invocationCounter; - //invocationCounter += 1; - - //alert("invoking " + method); - if (method == "Echo") { var promise = connection.invoke(method, "hello?"); promise.then(function (result) { - alert("received response -- " + result); + addLine('resultsList', "received " + result); }); } else if (method == "Sum") { - // var data = new Blob(['D', 'R', 'E', 'A', 'M'], { type: 'plain/text', endings: 'native' }); - var stream = connection.newUploadStream(); - var promise = connection.invoke("UploadWord", stream); + var subject = new signalR.Subject(); + var promise = connection.invoke("UploadWord", subject); - await stream.write("Z"); - await stream.write("O"); - await stream.write("O"); - await stream.write("P"); - await stream.write("!"); - await stream.complete(); + subject.next("Z"); + subject.next("o"); + subject.next("o"); + subject.next("p"); + subject.next("!"); + subject.complete(); promise.then(function (result) { - alert("received response -- " + result); + addLine('resultsList', "received " + result); }); } else { diff --git a/src/SignalR/src/Common/ReflectionHelper.cs b/src/SignalR/src/Common/ReflectionHelper.cs index 654371b97a..0692da5713 100644 --- a/src/SignalR/src/Common/ReflectionHelper.cs +++ b/src/SignalR/src/Common/ReflectionHelper.cs @@ -1,30 +1,19 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Text; using System.Threading.Channels; namespace Microsoft.AspNetCore.SignalR { internal static class ReflectionHelper { - public static bool IsStreamingType(Type type) + // mustBeDirectType - Hub methods must use the base 'stream' type and not be a derived class that just implements the 'stream' type + // and 'stream' types from the client are allowed to inherit from accepted 'stream' types + public static bool IsStreamingType(Type type, bool mustBeDirectType = false) { - // IMPORTANT !! - // All valid types must be generic - // because HubConnectionContext gets the generic argument and uses it to determine the expected item type of the stream - // The long-term solution is making a (streaming type => expected item type) method. - - if (!type.IsGenericType) - { - return false; - } - - // walk up inheritance chain, until parent is either null or a ChannelReader // TODO #2594 - add Streams here, to make sending files easy - while (type != null) + do { if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(ChannelReader<>)) { @@ -32,7 +21,8 @@ namespace Microsoft.AspNetCore.SignalR } type = type.BaseType; - } + } while (mustBeDirectType == false && type != null); + return false; } } diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index af77e90e72..d65b33ad07 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -146,7 +146,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage) { var errorString = ErrorMessageHelper.BuildErrorMessage( - $"Failed to bind Stream Item arguments to proper type.", + "Failed to bind Stream message.", bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors); var message = new StreamCompleteMessage(bindingFailureMessage.Id, errorString); @@ -160,7 +160,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { Log.ReceivedStreamItem(_logger, message); return connection.StreamTracker.ProcessItem(message); - } private Task ProcessInvocation(HubConnectionContext connection, @@ -291,8 +290,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal // Invoke Async, one reponse expected async Task ExecuteInvocation() { - var result = await ExecuteHubMethod(methodExecutor, hub, arguments); - Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); + object result; + try + { + result = await ExecuteHubMethod(methodExecutor, hub, arguments); + Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); + } + catch (Exception ex) + { + await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection, + ErrorMessageHelper.BuildErrorMessage($"An unexpected error occurred invoking '{hubMethodInvocationMessage.Target}' on the server.", ex, _enableDetailedErrors)); + return; + } + await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result)); } invocation = ExecuteInvocation(); diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs index fe22be662a..95421780e8 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -79,7 +79,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal private Type GetParameterType(ParameterInfo p) { var type = p.ParameterType; - if (ReflectionHelper.IsStreamingType(type)) + if (ReflectionHelper.IsStreamingType(type, mustBeDirectType: true)) { HasStreamingParameters = true; return typeof(StreamPlaceholder); @@ -141,4 +141,4 @@ namespace Microsoft.AspNetCore.SignalR.Internal return lambda.Compile(); } } -} \ No newline at end of file +} diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs index f74cb7d1ff..adeac3278c 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.MessagePack/Protocol/MessagePackHubProtocol.cs @@ -132,6 +132,8 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return CreateInvocationMessage(input, ref startOffset, binder, resolver); case HubProtocolConstants.StreamInvocationMessageType: return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver); + case HubProtocolConstants.StreamDataMessageType: + return CreateStreamDataMessage(input, ref startOffset, binder, resolver); case HubProtocolConstants.StreamItemMessageType: return CreateStreamItemMessage(input, ref startOffset, binder, resolver); case HubProtocolConstants.CompletionMessageType: @@ -194,6 +196,14 @@ namespace Microsoft.AspNetCore.SignalR.Protocol } } + private static StreamDataMessage CreateStreamDataMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + { + var streamId = ReadString(input, ref offset, "streamId"); + var itemType = binder.GetStreamItemType(streamId); + var value = DeserializeObject(input, ref offset, itemType, "item", resolver); + return new StreamDataMessage(streamId, value); + } + private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) { var headers = ReadHeaders(input, ref offset); @@ -374,6 +384,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol case StreamInvocationMessage streamInvocationMessage: WriteStreamInvocationMessage(streamInvocationMessage, packer); break; + case StreamDataMessage streamDataMessage: + WriteStreamDataMessage(streamDataMessage, packer); + break; case StreamItemMessage streamItemMessage: WriteStreamingItemMessage(streamItemMessage, packer); break; @@ -433,6 +446,14 @@ namespace Microsoft.AspNetCore.SignalR.Protocol } } + private void WriteStreamDataMessage(StreamDataMessage message, Stream packer) + { + MessagePackBinary.WriteArrayHeader(packer, 3); + MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamDataMessageType); + MessagePackBinary.WriteString(packer, message.StreamId); + WriteArgument(message.Item, packer); + } + private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer) { MessagePackBinary.WriteArrayHeader(packer, 4); diff --git a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs index 22bca86813..c7d96e1d39 100644 --- a/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs +++ b/src/SignalR/src/Microsoft.AspNetCore.SignalR.Protocols.NewtonsoftJson/Protocol/NewtonsoftJsonHubProtocol.cs @@ -221,13 +221,12 @@ namespace Microsoft.AspNetCore.SignalR.Protocol break; } - Type itemType = binder.GetStreamItemType(id); - try { + var itemType = binder.GetStreamItemType(id); item = PayloadSerializer.Deserialize(reader, itemType); } - catch (JsonSerializationException ex) + catch (Exception ex) { return new StreamBindingFailureMessage(id, ExceptionDispatchInfo.Capture(ex)); } @@ -338,14 +337,15 @@ namespace Microsoft.AspNetCore.SignalR.Protocol case HubProtocolConstants.StreamDataMessageType: if (itemToken != null) { - var itemType = binder.GetStreamItemType(streamId); try { + var itemType = binder.GetStreamItemType(streamId); item = itemToken.ToObject(itemType, PayloadSerializer); } - catch (JsonSerializationException ex) + catch (Exception ex) { - return new StreamBindingFailureMessage(streamId, ExceptionDispatchInfo.Capture(ex)); + message = new StreamBindingFailureMessage(streamId, ExceptionDispatchInfo.Capture(ex)); + break; } } message = BindParamStreamMessage(streamId, item, hasItem, binder); @@ -353,14 +353,15 @@ namespace Microsoft.AspNetCore.SignalR.Protocol case HubProtocolConstants.StreamItemMessageType: if (itemToken != null) { - var returnType = binder.GetStreamItemType(invocationId); + var returnType = binder.GetReturnType(invocationId); try { item = itemToken.ToObject(returnType, PayloadSerializer); } catch (JsonSerializationException ex) { - return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex)); + message = new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex)); + break; }; } diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index efcc5f1a13..d41acbe350 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -353,6 +353,43 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [LogLevel(LogLevel.Trace)] + public async Task CanStreamToAndFromClientInSameInvocation(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + using (StartServer(out var server)) + { + var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + + var channelWriter = Channel.CreateBounded(5); + var channel = await connection.StreamAsChannelAsync("StreamEcho", channelWriter.Reader).OrTimeout(); + + await channelWriter.Writer.WriteAsync("1").AsTask().OrTimeout(); + Assert.Equal("1", await channel.ReadAsync().AsTask().OrTimeout()); + await channelWriter.Writer.WriteAsync("2").AsTask().OrTimeout(); + Assert.Equal("2", await channel.ReadAsync().AsTask().OrTimeout()); + channelWriter.Writer.Complete(); + + var results = await channel.ReadAllAsync().OrTimeout(); + Assert.Empty(results); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] [LogLevel(LogLevel.Trace)] @@ -799,6 +836,31 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task RandomGenericIsNotTreatedAsStream() + { + bool ExpectedErrors(WriteContext writeContext) + { + return "Microsoft.AspNetCore.SignalR.Internal.DefaultHubDispatcher" == writeContext.LoggerName && + "FailedInvokingHubMethod" == writeContext.EventId.Name; + } + var hubPath = HubPaths[0]; + var hubProtocol = HubProtocols.First().Value; + var transportType = TransportTypes().First().Cast().First(); + + using (StartServer(out var server, ExpectedErrors)) + { + var connection = CreateHubConnection(server.Url, hubPath, transportType, hubProtocol, LoggerFactory); + await connection.StartAsync().OrTimeout(); + // List will be looked at to replace with a StreamPlaceholder and should be skipped, so an error will be thrown from the + // protocol on the server when it tries to match List with a StreamPlaceholder + var hubException = await Assert.ThrowsAsync(() => connection.InvokeAsync("StreamEcho", new List { "1", "2" }).OrTimeout()); + Assert.Equal("Failed to invoke 'StreamEcho' due to an error on the server. InvalidDataException: Error binding arguments. Make sure that the types of the provided values match the types of the hub method being invoked.", + hubException.Message); + await connection.DisposeAsync().OrTimeout(); + } + } + [Theory] [MemberData(nameof(TransportTypes))] public async Task ClientCanUseJwtBearerTokenForAuthentication(HttpTransportType transportType) diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 16e43f36e0..6a28816246 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -9,7 +9,6 @@ using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Authentication.JwtBearer; using Microsoft.AspNetCore.Authorization; -using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.AspNetCore.Http.Features; @@ -37,6 +36,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).SendAsync("NoClientHandler"); } + public ChannelReader StreamEcho(ChannelReader source) => TestHubMethodsImpl.StreamEcho(source); + public string GetUserIdentifier() { return Context.UserIdentifier; @@ -108,6 +109,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { await Clients.Client(Context.ConnectionId).NoClientHandler(); } + + public ChannelReader StreamEcho(ChannelReader source) => TestHubMethodsImpl.StreamEcho(source); } public class TestHubT : Hub @@ -132,22 +135,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).NoClientHandler(); } - public ChannelReader IncrementEach(ChannelReader source) - { - var output = Channel.CreateUnbounded(); - _ = Task.Run(async () => { - while (await source.WaitToReadAsync()) - { - while (source.TryRead(out var item)) - { - await output.Writer.WriteAsync(item + 1); - } - } - output.Writer.TryComplete(); - }); - - return output.Reader; - } + public ChannelReader StreamEcho(ChannelReader source) => TestHubMethodsImpl.StreamEcho(source); } internal static class TestHubMethodsImpl @@ -186,6 +174,23 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } public static ChannelReader StreamBroken() => null; + + public static ChannelReader StreamEcho(ChannelReader source) + { + var output = Channel.CreateUnbounded(); + _ = Task.Run(async () => { + while (await source.WaitToReadAsync()) + { + while (source.TryRead(out var item)) + { + await output.Writer.WriteAsync(item); + } + } + output.Writer.TryComplete(); + }); + + return output.Reader; + } } public interface ITestHub diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs index 1f8bf8c009..20ff61ae33 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -248,6 +248,12 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol name: "Ping", message: PingMessage.Instance, binary: "kQY="), + + // StreamData Messages + new ProtocolTestData( + name: "StreamData", + message: new StreamDataMessage("xyz", new CustomObject()), + binary: "kwmjeHl6hqpTdHJpbmdQcm9wqFNpZ25hbFIhqkRvdWJsZVByb3DLQBkh+1RCzxKnSW50UHJvcCqsRGF0ZVRpbWVQcm9w1v9Y7ByAqE51bGxQcm9wwKtCeXRlQXJyUHJvcMQDAQID"), }.ToDictionary(t => t.Name); [Theory] diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs index d7282c42a1..c12028411a 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestBinder.cs @@ -29,6 +29,9 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol case CompletionMessage c: _returnType = c.Result?.GetType() ?? typeof(object); break; + case StreamDataMessage sd: + _returnType = sd.Item.GetType() ?? typeof(object); + break; } } diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs index 0682c75640..c67b208b23 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/TestHubMessageEqualityComparer.cs @@ -41,6 +41,8 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol return string.Equals(closeMessage.Error, ((CloseMessage) y).Error); case StreamCompleteMessage streamCompleteMessage: return StreamCompleteMessagesEqual(streamCompleteMessage, (StreamCompleteMessage)y); + case StreamDataMessage streamDataMessage: + return StreamDataMessagesEqual(streamDataMessage, (StreamDataMessage)y); default: throw new InvalidOperationException($"Unknown message type: {x.GetType().FullName}"); } @@ -81,7 +83,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol private bool StreamCompleteMessagesEqual(StreamCompleteMessage x, StreamCompleteMessage y) { return x.StreamId == y.StreamId - && y.Error == y.Error; + && x.Error == y.Error; + } + + private bool StreamDataMessagesEqual(StreamDataMessage x, StreamDataMessage y) + { + return x.StreamId == y.StreamId + && (Equals(x.Item, y.Item) || SequenceEqual(x.Item, y.Item)); } private bool ArgumentListsEqual(object[] left, object[] right) diff --git a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index 679aa4f84c..ca0c1e8c3c 100644 --- a/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/src/SignalR/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -2809,111 +2809,138 @@ namespace Microsoft.AspNetCore.SignalR.Tests [Fact] public async Task UploadStreamItemInvalidTypeAutoCasts() { - // NOTE -- json.net is flexible here, and casts for us - - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); - var connectionHandler = serviceProvider.GetService>(); - - using (var client = new TestClient()) + using (StartVerifiableLog()) { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new StreamPlaceholder("id")).OrTimeout(); + // NOTE -- json.net is flexible here, and casts for us - // send integers that are then cast to strings - await client.SendHubMessageAsync(new StreamDataMessage("id", 5)).OrTimeout(); - await client.SendHubMessageAsync(new StreamDataMessage("id", 10)).OrTimeout(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); - await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout(); - var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); - - Assert.Equal("510", response.Result); + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new StreamPlaceholder("id")).OrTimeout(); + + // send integers that are then cast to strings + await client.SendHubMessageAsync(new StreamDataMessage("id", 5)).OrTimeout(); + await client.SendHubMessageAsync(new StreamDataMessage("id", 10)).OrTimeout(); + + await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout(); + var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); + + Assert.Equal("510", response.Result); + } } } [Fact] public async Task ServerReportsProtocolMinorVersion() { - var testProtocol = new Mock(); - testProtocol.Setup(m => m.Name).Returns("CustomProtocol"); - testProtocol.Setup(m => m.MinorVersion).Returns(112); - testProtocol.Setup(m => m.IsVersionSupported(It.IsAny())).Returns(true); - testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary); - - var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), - (services) => services.AddSingleton(testProtocol.Object)); - - using (var client = new TestClient(protocol: testProtocol.Object)) + using (StartVerifiableLog()) { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + var testProtocol = new Mock(); + testProtocol.Setup(m => m.Name).Returns("CustomProtocol"); + testProtocol.Setup(m => m.MinorVersion).Returns(112); + testProtocol.Setup(m => m.IsVersionSupported(It.IsAny())).Returns(true); + testProtocol.Setup(m => m.TransferFormat).Returns(TransferFormat.Binary); - Assert.NotNull(client.HandshakeResponseMessage); - Assert.Equal(112, client.HandshakeResponseMessage.MinorVersion); + var connectionHandler = HubConnectionHandlerTestUtils.GetHubConnectionHandler(typeof(HubT), + (services) => services.AddSingleton(testProtocol.Object), LoggerFactory); - client.Dispose(); - await connectionHandlerTask.OrTimeout(); + using (var client = new TestClient(protocol: testProtocol.Object)) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + Assert.NotNull(client.HandshakeResponseMessage); + Assert.Equal(112, client.HandshakeResponseMessage.MinorVersion); + + client.Dispose(); + await connectionHandlerTask.OrTimeout(); + } } } [Fact] public async Task UploadStreamItemInvalidType() { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); - var connectionHandler = serviceProvider.GetService>(); - - using (var client = new TestClient()) + using (StartVerifiableLog()) { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocationId", nameof(MethodHub.TestTypeCastingErrors), new StreamPlaceholder("channelId")).OrTimeout(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); - // client is running wild, sending strings not ints. - // this error should be propogated to the user's HubMethod code - await client.SendHubMessageAsync(new StreamItemMessage("channelId", "not a number")).OrTimeout(); - var response = await client.ReadAsync().OrTimeout(); + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocationId", nameof(MethodHub.TestTypeCastingErrors), new StreamPlaceholder("channelId")).OrTimeout(); - Assert.Equal(typeof(CompletionMessage), response.GetType()); - Assert.Equal("error identified and caught", (string)((CompletionMessage)response).Result); + // client is running wild, sending strings not ints. + // this error should be propogated to the user's HubMethod code + await client.SendHubMessageAsync(new StreamItemMessage("channelId", "not a number")).OrTimeout(); + var response = await client.ReadAsync().OrTimeout(); + + Assert.Equal(typeof(CompletionMessage), response.GetType()); + Assert.Equal("error identified and caught", (string)((CompletionMessage)response).Result); + } } } [Fact] public async Task UploadStreamItemInvalidId() { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + bool ExpectedErrors(WriteContext writeContext) { - services.AddSignalR(options => options.EnableDetailedErrors = true); - }); - var connectionHandler = serviceProvider.GetService>(); + return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && + writeContext.EventId.Name == "ErrorProcessingRequest"; + } - using (var client = new TestClient()) + using (StartVerifiableLog(ExpectedErrors)) { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.SendHubMessageAsync(new StreamItemMessage("fake_id", "not a number")).OrTimeout(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => options.EnableDetailedErrors = true); + }, loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); - // Client is breaking protocol by sending an invalid id, and should be closed. - var message = client.TryRead(); - Assert.IsType(message); - Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error); + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.SendHubMessageAsync(new StreamItemMessage("fake_id", "not a number")).OrTimeout(); + + // Client is breaking protocol by sending an invalid id, and should be closed. + var message = client.TryRead(); + Assert.IsType(message); + Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error); + } } } [Fact] public async Task UploadStreamCompleteInvalidId() { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + bool ExpectedErrors(WriteContext writeContext) { - services.AddSignalR(options => options.EnableDetailedErrors = true); - }); - var connectionHandler = serviceProvider.GetService>(); + return writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && + writeContext.EventId.Name == "ErrorProcessingRequest"; + } - using (var client = new TestClient()) + using (StartVerifiableLog(ExpectedErrors)) { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.SendHubMessageAsync(new StreamCompleteMessage("fake_id")).OrTimeout(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services => + { + services.AddSignalR(options => options.EnableDetailedErrors = true); + }, loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); - // Client is breaking protocol by sending an invalid id, and should be closed. - var message = client.TryRead(); - Assert.IsType(message); - Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error); + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.SendHubMessageAsync(new StreamCompleteMessage("fake_id")).OrTimeout(); + + // Client is breaking protocol by sending an invalid id, and should be closed. + var message = client.TryRead(); + Assert.IsType(message); + Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error); + } } } @@ -2922,17 +2949,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests [Fact] public async Task UploadStreamCompleteWithError() { - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); - var connectionHandler = serviceProvider.GetService>(); - - using (var client = new TestClient()) + using (StartVerifiableLog()) { - await client.ConnectAsync(connectionHandler).OrTimeout(); - await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.TestCustomErrorPassing), new StreamPlaceholder("id")).OrTimeout(); - await client.SendHubMessageAsync(new StreamCompleteMessage("id", CustomErrorMessage)).OrTimeout(); + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); - var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); - Assert.True((bool)response.Result); + using (var client = new TestClient()) + { + await client.ConnectAsync(connectionHandler).OrTimeout(); + await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.TestCustomErrorPassing), new StreamPlaceholder("id")).OrTimeout(); + await client.SendHubMessageAsync(new StreamCompleteMessage("id", CustomErrorMessage)).OrTimeout(); + + var response = (CompletionMessage)await client.ReadAsync().OrTimeout(); + Assert.True((bool)response.Result); + } } } @@ -3043,40 +3073,43 @@ namespace Microsoft.AspNetCore.SignalR.Tests [Fact] public async Task CanPassStreamingParameterToStreamHubMethod() { - IServiceProvider serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(); - HubConnectionHandler connectionHandler = serviceProvider.GetService>(); - Mock invocationBinder = new Mock(); - invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); - - using (TestClient client = new TestClient(invocationBinder: invocationBinder.Object)) + using (StartVerifiableLog()) { - Task connectionHandlerTask = await client.ConnectAsync(connectionHandler); + IServiceProvider serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + HubConnectionHandler connectionHandler = serviceProvider.GetService>(); + Mock invocationBinder = new Mock(); + invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny())).Returns(typeof(string)); - // Wait for a connection, or for the endpoint to fail. - await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).OrTimeout(); - - var streamId = "sample_id"; - var messagePromise = client.StreamAsync("StreamEcho", new StreamPlaceholder(streamId)).OrTimeout(); - - var phrases = new[] { "asdf", "qwer", "zxcv" }; - foreach (var phrase in phrases) + using (TestClient client = new TestClient(invocationBinder: invocationBinder.Object)) { - await client.SendHubMessageAsync(new StreamDataMessage(streamId, phrase)); + Task connectionHandlerTask = await client.ConnectAsync(connectionHandler); + + // Wait for a connection, or for the endpoint to fail. + await client.Connected.OrThrowIfOtherFails(connectionHandlerTask).OrTimeout(); + + var streamId = "sample_id"; + var messagePromise = client.StreamAsync(nameof(StreamingHub.StreamEcho), new StreamPlaceholder(streamId)).OrTimeout(); + + var phrases = new[] { "asdf", "qwer", "zxcv" }; + foreach (var phrase in phrases) + { + await client.SendHubMessageAsync(new StreamDataMessage(streamId, phrase)); + } + await client.SendHubMessageAsync(new StreamCompleteMessage(streamId)); + + var messages = await messagePromise; + + // add one because this includes the completion + Assert.Equal(phrases.Count() + 1, messages.Count); + for (var i = 0; i < phrases.Count(); i++) + { + Assert.Equal("echo:" + phrases[i], ((StreamItemMessage)messages[i]).Item); + } + + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); } - await client.SendHubMessageAsync(new StreamCompleteMessage(streamId)); - - var messages = await messagePromise; - - // add one because this includes the completion - Assert.Equal(phrases.Count() + 1, messages.Count); - for (var i = 0; i < phrases.Count(); i++) - { - Assert.Equal("echo:" + phrases[i], ((StreamItemMessage)messages[i]).Item); - } - - client.Dispose(); - - await connectionHandlerTask.OrTimeout(); } }