diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts index a2ec536e1f..aa7d05b871 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts @@ -387,6 +387,31 @@ describe("HubConnection", () => { // Expectation is connection.receive will not to throw connection.receive({ type: MessageType.Completion, invocationId: connection.lastInvocationId }); }); + + it("can be canceled", () => { + let connection = new TestConnection(); + + let hubConnection = new HubConnection(connection, { logger: null }); + let observer = new TestObserver(); + let subscription = hubConnection.stream("testMethod") + .subscribe(observer); + + connection.receive({ type: MessageType.StreamItem, invocationId: connection.lastInvocationId, item: 1 }); + expect(observer.itemsReceived).toEqual([1]); + + subscription.dispose(); + + connection.receive({ type: MessageType.StreamItem, invocationId: connection.lastInvocationId, item: 2 }); + // Observer should no longer receive messages + expect(observer.itemsReceived).toEqual([1]); + + // Verify the cancel is sent + expect(connection.sentData.length).toBe(2); + expect(JSON.parse(connection.sentData[1])).toEqual({ + type: MessageType.CancelInvocation, + invocationId: connection.lastInvocationId + }); + }); }); describe("onClose", () => { diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts index a4a8451c90..7c4d819b05 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts @@ -6,7 +6,7 @@ import { IConnection } from "./IConnection" import { HttpConnection } from "./HttpConnection" import { TransportType, TransferMode } from "./Transports" import { Subject, Observable } from "./Observable" -import { IHubProtocol, ProtocolType, MessageType, HubMessage, CompletionMessage, ResultMessage, InvocationMessage, StreamInvocationMessage, NegotiationMessage } from "./IHubProtocol"; +import { IHubProtocol, ProtocolType, MessageType, HubMessage, CompletionMessage, ResultMessage, InvocationMessage, StreamInvocationMessage, NegotiationMessage, CancelInvocation } from "./IHubProtocol"; import { JsonHubProtocol } from "./JsonHubProtocol"; import { TextMessageFormat } from "./Formatters" import { Base64EncodedHubProtocol } from "./Base64EncodedHubProtocol" @@ -167,7 +167,14 @@ export class HubConnection { stream(methodName: string, ...args: any[]): Observable { let invocationDescriptor = this.createStreamInvocation(methodName, args); - let subject = new Subject(); + let subject = new Subject(() => { + let cancelInvocation: CancelInvocation = this.createCancelInvocation(invocationDescriptor.invocationId); + let message: any = this.protocol.writeMessage(cancelInvocation); + + this.callbacks.delete(invocationDescriptor.invocationId); + + return this.connection.send(message); + }); this.callbacks.set(invocationDescriptor.invocationId, (invocationEvent: HubMessage, error?: Error) => { if (error) { @@ -280,7 +287,7 @@ export class HubConnection { private createInvocation(methodName: string, args: any[], nonblocking: boolean): InvocationMessage { if (nonblocking) { - return { + return { type: MessageType.Invocation, target: methodName, arguments: args, @@ -290,7 +297,7 @@ export class HubConnection { let id = this.id; this.id++; - return { + return { type: MessageType.Invocation, invocationId: id.toString(), target: methodName, @@ -303,11 +310,18 @@ export class HubConnection { let id = this.id; this.id++; - return { + return { type: MessageType.StreamInvocation, invocationId: id.toString(), target: methodName, arguments: args, }; } + + private createCancelInvocation(id: string): CancelInvocation { + return { + type: MessageType.CancelInvocation, + invocationId: id, + }; + } } diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/IHubProtocol.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/IHubProtocol.ts index 01fc2831dd..03859bbaaa 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/IHubProtocol.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/IHubProtocol.ts @@ -41,6 +41,10 @@ export interface NegotiationMessage { readonly protocol: string; } +export interface CancelInvocation extends HubMessage { + readonly invocationId: string; +} + export const enum ProtocolType { Text = 1, Binary diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Observable.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Observable.ts index e2ea7cb11e..eeaaf0c626 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Observable.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Observable.ts @@ -10,16 +10,38 @@ export interface Observer { complete?: () => void; } +export class Subscription { + subject: Subject; + observer: Observer; + + constructor(subject: Subject, observer: Observer) { + this.subject = subject; + this.observer = observer; + } + + public dispose(): void { + let index: number = this.subject.observers.indexOf(this.observer); + if (index > -1) { + this.subject.observers.splice(index, 1); + } + + if (this.subject.observers.length === 0) { + this.subject.cancelCallback().catch((_) => { }); + } + } +} + export interface Observable { - // TODO: Return a Subscription so the caller can unsubscribe? IDisposable in System.IObservable - subscribe(observer: Observer): void; + subscribe(observer: Observer): Subscription; } export class Subject implements Observable { observers: Observer[]; + cancelCallback: () => Promise; - constructor() { + constructor(cancelCallback: () => Promise) { this.observers = []; + this.cancelCallback = cancelCallback; } public next(item: T): void { @@ -44,7 +66,8 @@ export class Subject implements Observable { } } - public subscribe(observer: Observer): void { + public subscribe(observer: Observer): Subscription { this.observers.push(observer); + return new Subscription(this, observer); } }