diff --git a/clients/ts/FunctionalTests/ts/Common.ts b/clients/ts/FunctionalTests/ts/Common.ts index 7c46dfcc29..376c3733bd 100644 --- a/clients/ts/FunctionalTests/ts/Common.ts +++ b/clients/ts/FunctionalTests/ts/Common.ts @@ -1,31 +1,31 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -import { IHubProtocol, JsonHubProtocol, TransportType } from "@aspnet/signalr"; +import { HttpTransportType, IHubProtocol, JsonHubProtocol } from "@aspnet/signalr"; import { MessagePackHubProtocol } from "@aspnet/signalr-protocol-msgpack"; export const ECHOENDPOINT_URL = "http://" + document.location.host + "/echo"; -export function getTransportTypes(): TransportType[] { +export function getHttpTransportTypes(): HttpTransportType[] { const transportTypes = []; if (typeof WebSocket !== "undefined") { - transportTypes.push(TransportType.WebSockets); + transportTypes.push(HttpTransportType.WebSockets); } if (typeof EventSource !== "undefined") { - transportTypes.push(TransportType.ServerSentEvents); + transportTypes.push(HttpTransportType.ServerSentEvents); } - transportTypes.push(TransportType.LongPolling); + transportTypes.push(HttpTransportType.LongPolling); return transportTypes; } -export function eachTransport(action: (transport: TransportType) => void) { - getTransportTypes().forEach((t) => { +export function eachTransport(action: (transport: HttpTransportType) => void) { + getHttpTransportTypes().forEach((t) => { return action(t); }); } -export function eachTransportAndProtocol(action: (transport: TransportType, protocol: IHubProtocol) => void) { +export function eachTransportAndProtocol(action: (transport: HttpTransportType, protocol: IHubProtocol) => void) { const protocols: IHubProtocol[] = [new JsonHubProtocol()]; // IE9 does not support XmlHttpRequest advanced features so disable for now // This can be enabled if we fix: https://github.com/aspnet/SignalR/issues/742 @@ -35,9 +35,9 @@ export function eachTransportAndProtocol(action: (transport: TransportType, prot // Everything works fine in the module protocols.push(new MessagePackHubProtocol()); } - getTransportTypes().forEach((t) => { + getHttpTransportTypes().forEach((t) => { return protocols.forEach((p) => { - if (t !== TransportType.ServerSentEvents || !(p instanceof MessagePackHubProtocol)) { + if (t !== HttpTransportType.ServerSentEvents || !(p instanceof MessagePackHubProtocol)) { return action(t, p); } }); diff --git a/clients/ts/FunctionalTests/ts/ConnectionTests.ts b/clients/ts/FunctionalTests/ts/ConnectionTests.ts index bfcb3d3252..05143c0938 100644 --- a/clients/ts/FunctionalTests/ts/ConnectionTests.ts +++ b/clients/ts/FunctionalTests/ts/ConnectionTests.ts @@ -1,7 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -import { HttpConnection, IHttpConnectionOptions, LogLevel, TransferFormat, TransportType } from "@aspnet/signalr"; +import { HttpConnection, HttpTransportType, IHttpConnectionOptions, LogLevel, TransferFormat } from "@aspnet/signalr"; import { eachTransport, ECHOENDPOINT_URL } from "./Common"; import { TestLogger } from "./TestLogger"; @@ -42,7 +42,7 @@ describe("connection", () => { }); eachTransport((transportType) => { - describe(`over ${TransportType[transportType]}`, () => { + describe(`over ${HttpTransportType[transportType]}`, () => { it("can send and receive messages", (done) => { const message = "Hello World!"; // the url should be resolved relative to the document.location.host diff --git a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index f99a05aedd..d00ada4438 100644 --- a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -1,7 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -import { HubConnection, IHubConnectionOptions, JsonHubProtocol, LogLevel, TransportType } from "@aspnet/signalr"; +import { DefaultHttpClient, HttpClient, HttpRequest, HttpResponse, HttpTransportType, HubConnection, IHubConnectionOptions, JsonHubProtocol, LogLevel } from "@aspnet/signalr"; import { MessagePackHubProtocol } from "@aspnet/signalr-protocol-msgpack"; import { eachTransport, eachTransportAndProtocol } from "./Common"; @@ -20,7 +20,7 @@ jasmine.DEFAULT_TIMEOUT_INTERVAL = 10 * 1000; describe("hubConnection", () => { eachTransportAndProtocol((transportType, protocol) => { - describe("using " + protocol.name + " over " + TransportType[transportType] + " transport", () => { + describe("using " + protocol.name + " over " + HttpTransportType[transportType] + " transport", async () => { it("can invoke server method and receive result", (done) => { const message = "你好,世界!"; @@ -505,7 +505,7 @@ describe("hubConnection", () => { }); eachTransport((transportType) => { - describe("over " + TransportType[transportType] + " transport", () => { + describe("over " + HttpTransportType[transportType] + " transport", () => { it("can connect to hub with authorization", async (done) => { const message = "你好,世界!"; @@ -562,7 +562,34 @@ describe("hubConnection", () => { } }); - if (transportType !== TransportType.LongPolling) { + it("can connect to hub with authorization using async token factory", async (done) => { + const message = "你好,世界!"; + + try { + const hubConnection = new HubConnection("/authorizedhub", { + accessTokenFactory: () => getJwtToken("http://" + document.location.host + "/generateJwtToken"), + ...commonOptions, + transport: transportType, + }); + hubConnection.onclose((error) => { + expect(error).toBe(undefined); + done(); + }); + await hubConnection.start(); + const response = await hubConnection.invoke("Echo", message); + + expect(response).toEqual(message); + + await hubConnection.stop(); + + done(); + } catch (err) { + fail(err); + done(); + } + }); + + if (transportType !== HttpTransportType.LongPolling) { it("terminates if no messages received within timeout interval", (done) => { const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, { ...commonOptions, @@ -631,7 +658,7 @@ describe("hubConnection", () => { }; const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, { - logger: LogLevel.Trace, + ...commonOptions, protocol: new JsonHubProtocol(), }); @@ -649,6 +676,44 @@ describe("hubConnection", () => { } }); + it("over LongPolling it sends DELETE request and waits for poll to terminate", async (done) => { + // Create an HTTP client to capture the poll + const defaultClient = new DefaultHttpClient(TestLogger.instance); + + class TestClient extends HttpClient { + public pollPromise: Promise; + + public send(request: HttpRequest): Promise { + if (request.method === "GET") { + this.pollPromise = defaultClient.send(request); + return this.pollPromise; + } + return defaultClient.send(request); + } + } + + const testClient = new TestClient(); + const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, { + ...commonOptions, + httpClient: testClient, + }); + try { + await hubConnection.start(); + + expect(testClient.pollPromise).not.toBeNull(); + + // Stop the connection and await the poll terminating + const stopPromise = hubConnection.stop(); + + await testClient.pollPromise; + await stopPromise; + } catch (e) { + fail(e); + } finally { + done(); + } + }); + function getJwtToken(url): Promise { return new Promise((resolve, reject) => { const xhr = new XMLHttpRequest(); diff --git a/clients/ts/signalr/spec/Common.ts b/clients/ts/signalr/spec/Common.ts index faf74045b8..fa5c397bd8 100644 --- a/clients/ts/signalr/spec/Common.ts +++ b/clients/ts/signalr/spec/Common.ts @@ -1,13 +1,13 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -import { ITransport, TransportType } from "../src/Transports"; +import { ITransport, HttpTransportType } from "../src/ITransport"; -export function eachTransport(action: (transport: TransportType) => void) { +export function eachTransport(action: (transport: HttpTransportType) => void) { const transportTypes = [ - TransportType.WebSockets, - TransportType.ServerSentEvents, - TransportType.LongPolling ]; + HttpTransportType.WebSockets, + HttpTransportType.ServerSentEvents, + HttpTransportType.LongPolling ]; transportTypes.forEach((t) => action(t)); } diff --git a/clients/ts/signalr/spec/HttpConnection.spec.ts b/clients/ts/signalr/spec/HttpConnection.spec.ts index f6d2fb665c..15184bd686 100644 --- a/clients/ts/signalr/spec/HttpConnection.spec.ts +++ b/clients/ts/signalr/spec/HttpConnection.spec.ts @@ -5,7 +5,7 @@ import { DataReceived, TransportClosed } from "../src/Common"; import { HttpConnection } from "../src/HttpConnection"; import { IHttpConnectionOptions } from "../src/HttpConnection"; import { HttpResponse } from "../src/index"; -import { ITransport, TransferFormat, TransportType } from "../src/Transports"; +import { ITransport, TransferFormat, HttpTransportType } from "../src/ITransport"; import { eachEndpointUrl, eachTransport } from "./Common"; import { TestHttpClient } from "./TestHttpClient"; @@ -223,12 +223,12 @@ describe("HttpConnection", () => { }); }); - eachTransport((requestedTransport: TransportType) => { + eachTransport((requestedTransport: HttpTransportType) => { // OPTIONS is not sent when WebSockets transport is explicitly requested - if (requestedTransport === TransportType.WebSockets) { + if (requestedTransport === HttpTransportType.WebSockets) { return; } - it(`cannot be started if requested ${TransportType[requestedTransport]} transport not available on server`, async (done) => { + it(`cannot be started if requested ${HttpTransportType[requestedTransport]} transport not available on server`, async (done) => { const options: IHttpConnectionOptions = { ...commonOptions, httpClient: new TestHttpClient() @@ -272,7 +272,7 @@ describe("HttpConnection", () => { const options: IHttpConnectionOptions = { ...commonOptions, httpClient: new TestHttpClient(), - transport: TransportType.WebSockets, + transport: HttpTransportType.WebSockets, } as IHttpConnectionOptions; const connection = new HttpConnection("http://tempuri.org", options); @@ -288,8 +288,29 @@ describe("HttpConnection", () => { } }); + it("sets inherentKeepAlive feature when using LongPolling", async (done) => { + const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] }; + + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })), + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + + try { + await connection.start(TransferFormat.Text); + expect(connection.features.inherentKeepAlive).toBe(true); + done(); + } catch (e) { + fail(e); + done(); + } + }); + it("does not select ServerSentEvents transport when not available in environment", async (done) => { - const serverSentEventsTransport = { transport: "ServerSentEvents", transferFormats: [ "Text" ] }; + const serverSentEventsTransport = { transport: "ServerSentEvents", transferFormats: ["Text"] }; const options: IHttpConnectionOptions = { ...commonOptions, @@ -312,7 +333,7 @@ describe("HttpConnection", () => { }); it("does not select WebSockets transport when not available in environment", async (done) => { - const webSocketsTransport = { transport: "WebSockets", transferFormats: [ "Text" ] }; + const webSocketsTransport = { transport: "WebSockets", transferFormats: ["Text"] }; const options: IHttpConnectionOptions = { ...commonOptions, diff --git a/clients/ts/signalr/spec/HubConnection.spec.ts b/clients/ts/signalr/spec/HubConnection.spec.ts index 630cd00f55..b1169daf30 100644 --- a/clients/ts/signalr/spec/HubConnection.spec.ts +++ b/clients/ts/signalr/spec/HubConnection.spec.ts @@ -8,7 +8,7 @@ import { HubMessage, IHubProtocol, MessageType } from "../src/IHubProtocol"; import { ILogger, LogLevel } from "../src/ILogger"; import { Observer } from "../src/Observable"; import { TextMessageFormat } from "../src/TextMessageFormat"; -import { ITransport, TransferFormat, TransportType } from "../src/Transports"; +import { ITransport, TransferFormat, HttpTransportType } from "../src/ITransport"; import { IHubConnectionOptions } from "../src/HubConnection"; import { asyncit as it, captureException, delay, PromiseSource } from "./Utils"; diff --git a/clients/ts/signalr/src/HttpClient.ts b/clients/ts/signalr/src/HttpClient.ts index 16808e99ee..347a904cb4 100644 --- a/clients/ts/signalr/src/HttpClient.ts +++ b/clients/ts/signalr/src/HttpClient.ts @@ -46,6 +46,16 @@ export abstract class HttpClient { }); } + public delete(url: string): Promise; + public delete(url: string, options: HttpRequest): Promise; + public delete(url: string, options?: HttpRequest): Promise { + return this.send({ + ...options, + method: "DELETE", + url, + }); + } + public abstract send(request: HttpRequest): Promise; } diff --git a/clients/ts/signalr/src/HttpConnection.ts b/clients/ts/signalr/src/HttpConnection.ts index 97e9078722..1abc478682 100644 --- a/clients/ts/signalr/src/HttpConnection.ts +++ b/clients/ts/signalr/src/HttpConnection.ts @@ -5,13 +5,16 @@ import { ConnectionClosed, DataReceived } from "./Common"; import { DefaultHttpClient, HttpClient } from "./HttpClient"; import { IConnection } from "./IConnection"; import { ILogger, LogLevel } from "./ILogger"; +import { HttpTransportType, ITransport, TransferFormat } from "./ITransport"; import { LoggerFactory } from "./Loggers"; -import { ITransport, LongPollingTransport, ServerSentEventsTransport, TransferFormat, TransportType, WebSocketTransport } from "./Transports"; +import { LongPollingTransport } from "./LongPollingTransport"; +import { ServerSentEventsTransport } from "./ServerSentEventsTransport"; import { Arg } from "./Utils"; +import { WebSocketTransport } from "./WebSocketTransport"; export interface IHttpConnectionOptions { httpClient?: HttpClient; - transport?: TransportType | ITransport; + transport?: HttpTransportType | ITransport; logger?: ILogger | LogLevel; accessTokenFactory?: () => string | Promise; logMessageContent?: boolean; @@ -29,7 +32,7 @@ interface INegotiateResponse { } interface IAvailableTransport { - transport: keyof typeof TransportType; + transport: keyof typeof HttpTransportType; transferFormats: Array; } @@ -43,6 +46,7 @@ export class HttpConnection implements IConnection { private transport: ITransport; private connectionId: string; private startPromise: Promise; + private stopError?: Error; public readonly features: any = {}; @@ -79,10 +83,10 @@ export class HttpConnection implements IConnection { private async startInternal(transferFormat: TransferFormat): Promise { try { - if (this.options.transport === TransportType.WebSockets) { + if (this.options.transport === HttpTransportType.WebSockets) { // No need to add a connection ID in this case this.url = this.baseUrl; - this.transport = this.constructTransport(TransportType.WebSockets); + this.transport = this.constructTransport(HttpTransportType.WebSockets); // We should just call connect directly in this case. // No fallback or negotiate in this case. await this.transport.connect(this.url, transferFormat); @@ -103,12 +107,12 @@ export class HttpConnection implements IConnection { await this.createTransport(this.options.transport, negotiateResponse, transferFormat, headers); } - if (typeof this.transport === typeof LongPollingTransport) { + if (this.transport instanceof LongPollingTransport) { this.features.inherentKeepAlive = true; } this.transport.onreceive = this.onreceive; - this.transport.onclose = (e) => this.stopConnection(true, e); + this.transport.onclose = (e) => this.stopConnection(e); // only change the state if we were connecting to not overwrite // the state if the connection is already marked as Disconnected @@ -141,7 +145,7 @@ export class HttpConnection implements IConnection { this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`; } - private async createTransport(requestedTransport: TransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise { + private async createTransport(requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise { this.updateConnectionId(negotiateResponse); if (this.isITransport(requestedTransport)) { this.logger.log(LogLevel.Trace, "Connection was provided an instance of ITransport, using that directly."); @@ -169,7 +173,7 @@ export class HttpConnection implements IConnection { this.changeState(ConnectionState.Connecting, ConnectionState.Connected); return; } catch (ex) { - this.logger.log(LogLevel.Error, `Failed to start the transport '${TransportType[transport]}': ${ex}`); + this.logger.log(LogLevel.Error, `Failed to start the transport '${HttpTransportType[transport]}': ${ex}`); this.connectionState = ConnectionState.Disconnected; negotiateResponse.connectionId = null; } @@ -179,39 +183,39 @@ export class HttpConnection implements IConnection { throw new Error("Unable to initialize any of the available transports."); } - private constructTransport(transport: TransportType) { + private constructTransport(transport: HttpTransportType) { switch (transport) { - case TransportType.WebSockets: + case HttpTransportType.WebSockets: return new WebSocketTransport(this.options.accessTokenFactory, this.logger, this.options.logMessageContent); - case TransportType.ServerSentEvents: + case HttpTransportType.ServerSentEvents: return new ServerSentEventsTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent); - case TransportType.LongPolling: + case HttpTransportType.LongPolling: return new LongPollingTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent); default: throw new Error(`Unknown transport: ${transport}.`); } } - private resolveTransport(endpoint: IAvailableTransport, requestedTransport: TransportType, requestedTransferFormat: TransferFormat): TransportType | null { - const transport = TransportType[endpoint.transport]; + private resolveTransport(endpoint: IAvailableTransport, requestedTransport: HttpTransportType, requestedTransferFormat: TransferFormat): HttpTransportType | null { + const transport = HttpTransportType[endpoint.transport]; if (transport === null || transport === undefined) { this.logger.log(LogLevel.Trace, `Skipping transport '${endpoint.transport}' because it is not supported by this client.`); } else { const transferFormats = endpoint.transferFormats.map((s) => TransferFormat[s]); if (!requestedTransport || transport === requestedTransport) { if (transferFormats.indexOf(requestedTransferFormat) >= 0) { - if ((transport === TransportType.WebSockets && typeof WebSocket === "undefined") || - (transport === TransportType.ServerSentEvents && typeof EventSource === "undefined")) { - this.logger.log(LogLevel.Trace, `Skipping transport '${TransportType[transport]}' because it is not supported in your environment.'`); + if ((transport === HttpTransportType.WebSockets && typeof WebSocket === "undefined") || + (transport === HttpTransportType.ServerSentEvents && typeof EventSource === "undefined")) { + this.logger.log(LogLevel.Trace, `Skipping transport '${HttpTransportType[transport]}' because it is not supported in your environment.'`); } else { - this.logger.log(LogLevel.Trace, `Selecting transport '${TransportType[transport]}'`); + this.logger.log(LogLevel.Trace, `Selecting transport '${HttpTransportType[transport]}'`); return transport; } } else { - this.logger.log(LogLevel.Trace, `Skipping transport '${TransportType[transport]}' because it does not support the requested transfer format '${TransferFormat[requestedTransferFormat]}'.`); + this.logger.log(LogLevel.Trace, `Skipping transport '${HttpTransportType[transport]}' because it does not support the requested transfer format '${TransferFormat[requestedTransferFormat]}'.`); } } else { - this.logger.log(LogLevel.Trace, `Skipping transport '${TransportType[transport]}' because it was disabled by the client.`); + this.logger.log(LogLevel.Trace, `Skipping transport '${HttpTransportType[transport]}' because it was disabled by the client.`); } } return null; @@ -238,7 +242,6 @@ export class HttpConnection implements IConnection { } public async stop(error?: Error): Promise { - const previousState = this.connectionState; this.connectionState = ConnectionState.Disconnected; try { @@ -246,14 +249,20 @@ export class HttpConnection implements IConnection { } catch (e) { // this exception is returned to the user as a rejected Promise from the start method } - this.stopConnection(/*raiseClosed*/ previousState === ConnectionState.Connected, error); - } - private stopConnection(raiseClosed: boolean, error?: Error) { + // The transport's onclose will trigger stopConnection which will run our onclose event. if (this.transport) { - this.transport.stop(); + this.stopError = error; + await this.transport.stop(); this.transport = null; } + } + + private async stopConnection(error?: Error): Promise { + this.transport = null; + + // If we have a stopError, it takes precedence over the error from the transport + error = this.stopError || error; if (error) { this.logger.log(LogLevel.Error, `Connection disconnected with error '${error}'.`); @@ -263,7 +272,7 @@ export class HttpConnection implements IConnection { this.connectionState = ConnectionState.Disconnected; - if (raiseClosed && this.onclose) { + if (this.onclose) { this.onclose(error); } } diff --git a/clients/ts/signalr/src/HubConnection.ts b/clients/ts/signalr/src/HubConnection.ts index 5d9be6173b..01c6421833 100644 --- a/clients/ts/signalr/src/HubConnection.ts +++ b/clients/ts/signalr/src/HubConnection.ts @@ -11,7 +11,6 @@ import { JsonHubProtocol } from "./JsonHubProtocol"; import { ConsoleLogger, LoggerFactory, NullLogger } from "./Loggers"; import { Observable, Subject } from "./Observable"; import { TextMessageFormat } from "./TextMessageFormat"; -import { TransferFormat, TransportType } from "./Transports"; export { JsonHubProtocol }; diff --git a/clients/ts/signalr/src/IConnection.ts b/clients/ts/signalr/src/IConnection.ts index d7e68c9f03..f462d0ed6e 100644 --- a/clients/ts/signalr/src/IConnection.ts +++ b/clients/ts/signalr/src/IConnection.ts @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { ConnectionClosed, DataReceived } from "./Common"; -import { ITransport, TransferFormat, TransportType } from "./Transports"; +import { TransferFormat } from "./ITransport"; export interface IConnection { readonly features: any; diff --git a/clients/ts/signalr/src/IHubProtocol.ts b/clients/ts/signalr/src/IHubProtocol.ts index 6a380c62c3..f98609f976 100644 --- a/clients/ts/signalr/src/IHubProtocol.ts +++ b/clients/ts/signalr/src/IHubProtocol.ts @@ -1,5 +1,5 @@ import { ILogger } from "./ILogger"; -import { TransferFormat } from "./Transports"; +import { TransferFormat } from "./ITransport"; // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. diff --git a/clients/ts/signalr/src/ITransport.ts b/clients/ts/signalr/src/ITransport.ts new file mode 100644 index 0000000000..157f28bbee --- /dev/null +++ b/clients/ts/signalr/src/ITransport.ts @@ -0,0 +1,24 @@ +import { DataReceived, TransportClosed } from "./Common"; +import { IConnection } from "./IConnection"; + +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +export enum HttpTransportType { + WebSockets, + ServerSentEvents, + LongPolling, +} + +export enum TransferFormat { + Text = 1, + Binary, +} + +export interface ITransport { + connect(url: string, transferFormat: TransferFormat): Promise; + send(data: any): Promise; + stop(): Promise; + onreceive: DataReceived; + onclose: TransportClosed; +} diff --git a/clients/ts/signalr/src/JsonHubProtocol.ts b/clients/ts/signalr/src/JsonHubProtocol.ts index b021b3a4b2..967964fdee 100644 --- a/clients/ts/signalr/src/JsonHubProtocol.ts +++ b/clients/ts/signalr/src/JsonHubProtocol.ts @@ -3,9 +3,9 @@ import { CloseMessage, CompletionMessage, HubMessage, IHubProtocol, InvocationMessage, MessageType, PingMessage, StreamItemMessage } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; +import { TransferFormat } from "./ITransport"; import { NullLogger } from "./Loggers"; import { TextMessageFormat } from "./TextMessageFormat"; -import { TransferFormat } from "./Transports"; export const JSON_HUB_PROTOCOL_NAME: string = "json"; diff --git a/clients/ts/signalr/src/LongPollingTransport.ts b/clients/ts/signalr/src/LongPollingTransport.ts new file mode 100644 index 0000000000..5f3e453277 --- /dev/null +++ b/clients/ts/signalr/src/LongPollingTransport.ts @@ -0,0 +1,169 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +import { AbortController } from "./AbortController"; +import { DataReceived, TransportClosed } from "./Common"; +import { HttpError, TimeoutError } from "./Errors"; +import { HttpClient, HttpRequest } from "./HttpClient"; +import { ILogger, LogLevel } from "./ILogger"; +import { ITransport, TransferFormat } from "./ITransport"; +import { Arg, getDataDetail, sendMessage } from "./Utils"; + +const SHUTDOWN_TIMEOUT = 5 * 1000; + +export class LongPollingTransport implements ITransport { + private readonly httpClient: HttpClient; + private readonly accessTokenFactory: () => string | Promise; + private readonly logger: ILogger; + private readonly logMessageContent: boolean; + + private url: string; + private pollXhr: XMLHttpRequest; + private pollAbort: AbortController; + private shutdownTimeout: number; + private running: boolean; + + constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { + this.httpClient = httpClient; + this.accessTokenFactory = accessTokenFactory || (() => null); + this.logger = logger; + this.pollAbort = new AbortController(); + this.logMessageContent = logMessageContent; + } + + public connect(url: string, transferFormat: TransferFormat): Promise { + Arg.isRequired(url, "url"); + Arg.isRequired(transferFormat, "transferFormat"); + Arg.isIn(transferFormat, TransferFormat, "transferFormat"); + + this.url = url; + + this.logger.log(LogLevel.Trace, "(LongPolling transport) Connecting"); + + if (transferFormat === TransferFormat.Binary && (typeof new XMLHttpRequest().responseType !== "string")) { + // This will work if we fix: https://github.com/aspnet/SignalR/issues/742 + throw new Error("Binary protocols over XmlHttpRequest not implementing advanced features are not supported."); + } + + this.poll(this.url, transferFormat); + return Promise.resolve(); + } + + private async poll(url: string, transferFormat: TransferFormat): Promise { + this.running = true; + + const pollOptions: HttpRequest = { + abortSignal: this.pollAbort.signal, + headers: {}, + timeout: 90000, + }; + + if (transferFormat === TransferFormat.Binary) { + pollOptions.responseType = "arraybuffer"; + } + + let closeError: Error; + try { + while (this.running) { + // We have to get the access token on each poll, in case it changes + const token = await this.accessTokenFactory(); + if (token) { + // tslint:disable-next-line:no-string-literal + pollOptions.headers["Authorization"] = `Bearer ${token}`; + } + + try { + const pollUrl = `${url}&_=${Date.now()}`; + this.logger.log(LogLevel.Trace, `(LongPolling transport) polling: ${pollUrl}`); + const response = await this.httpClient.get(pollUrl, pollOptions); + + if (response.statusCode === 204) { + this.logger.log(LogLevel.Information, "(LongPolling transport) Poll terminated by server"); + + // If we were on a timeout waiting for shutdown, unregister it. + clearTimeout(this.shutdownTimeout); + + this.running = false; + } else if (response.statusCode !== 200) { + this.logger.log(LogLevel.Error, `(LongPolling transport) Unexpected response code: ${response.statusCode}`); + + // Unexpected status code + closeError = new HttpError(response.statusText, response.statusCode); + this.running = false; + } else { + // Process the response + if (response.content) { + this.logger.log(LogLevel.Trace, `(LongPolling transport) data received. ${getDataDetail(response.content, this.logMessageContent)}`); + if (this.onreceive) { + this.onreceive(response.content); + } + } else { + // This is another way timeout manifest. + this.logger.log(LogLevel.Trace, "(LongPolling transport) Poll timed out, reissuing."); + } + } + } catch (e) { + if (!this.running) { + // Log but disregard errors that occur after we were stopped by DELETE + this.logger.log(LogLevel.Trace, `(LongPolling transport) Poll errored after shutdown: ${e.message}`); + } else { + if (e instanceof TimeoutError) { + // Ignore timeouts and reissue the poll. + this.logger.log(LogLevel.Trace, "(LongPolling transport) Poll timed out, reissuing."); + } else { + // Close the connection with the error as the result. + closeError = e; + this.running = false; + } + } + } + } + } finally { + // Fire our onclosed event + if (this.onclose) { + this.logger.log(LogLevel.Trace, `(LongPolling transport) Firing onclose event. Error: ${closeError || ""}`); + this.onclose(closeError); + } + + this.logger.log(LogLevel.Trace, "(LongPolling transport) Transport finished."); + } + } + + public async send(data: any): Promise { + if (!this.running) { + return Promise.reject(new Error("Cannot send until the transport is connected")); + } + return sendMessage(this.logger, "LongPolling", this.httpClient, this.url, this.accessTokenFactory, data, this.logMessageContent); + } + + public async stop(): Promise { + // Send a DELETE request to stop the poll + try { + this.running = false; + this.logger.log(LogLevel.Trace, `(LongPolling transport) sending DELETE request to ${this.url}.`); + + const deleteOptions: HttpRequest = {}; + const token = await this.accessTokenFactory(); + if (token) { + // tslint:disable-next-line:no-string-literal + deleteOptions.headers = { + ["Authorization"]: `Bearer ${token}`, + }; + } + const response = await this.httpClient.delete(this.url, deleteOptions); + + this.logger.log(LogLevel.Trace, "(LongPolling transport) DELETE request accepted."); + } finally { + // Abort the poll after 5 seconds if the server doesn't stop it. + if (!this.pollAbort.aborted) { + this.shutdownTimeout = setTimeout(SHUTDOWN_TIMEOUT, () => { + this.logger.log(LogLevel.Warning, "(LongPolling transport) server did not terminate within 5 seconds after DELETE request, cancelling poll."); + this.pollAbort.abort(); + }); + } + } + } + + public onreceive: DataReceived; + public onclose: TransportClosed; +} diff --git a/clients/ts/signalr/src/ServerSentEventsTransport.ts b/clients/ts/signalr/src/ServerSentEventsTransport.ts new file mode 100644 index 0000000000..c089c50729 --- /dev/null +++ b/clients/ts/signalr/src/ServerSentEventsTransport.ts @@ -0,0 +1,111 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +import { DataReceived, TransportClosed } from "./Common"; +import { HttpClient } from "./HttpClient"; +import { ILogger, LogLevel } from "./ILogger"; +import { ITransport, TransferFormat } from "./ITransport"; +import { Arg, getDataDetail, sendMessage } from "./Utils"; + +export class ServerSentEventsTransport implements ITransport { + private readonly httpClient: HttpClient; + private readonly accessTokenFactory: () => string | Promise; + private readonly logger: ILogger; + private readonly logMessageContent: boolean; + private eventSource: EventSource; + private url: string; + + constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { + this.httpClient = httpClient; + this.accessTokenFactory = accessTokenFactory || (() => null); + this.logger = logger; + this.logMessageContent = logMessageContent; + } + + public async connect(url: string, transferFormat: TransferFormat): Promise { + Arg.isRequired(url, "url"); + Arg.isRequired(transferFormat, "transferFormat"); + Arg.isIn(transferFormat, TransferFormat, "transferFormat"); + + if (typeof (EventSource) === "undefined") { + throw new Error("'EventSource' is not supported in your environment."); + } + + this.logger.log(LogLevel.Trace, "(SSE transport) Connecting"); + + const token = await this.accessTokenFactory(); + if (token) { + url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; + } + + this.url = url; + return new Promise((resolve, reject) => { + let opened = false; + if (transferFormat !== TransferFormat.Text) { + reject(new Error("The Server-Sent Events transport only supports the 'Text' transfer format")); + } + + const eventSource = new EventSource(url, { withCredentials: true }); + + try { + eventSource.onmessage = (e: MessageEvent) => { + if (this.onreceive) { + try { + this.logger.log(LogLevel.Trace, `(SSE transport) data received. ${getDataDetail(e.data, this.logMessageContent)}.`); + this.onreceive(e.data); + } catch (error) { + if (this.onclose) { + this.onclose(error); + } + return; + } + } + }; + + eventSource.onerror = (e: any) => { + const error = new Error(e.message || "Error occurred"); + if (opened) { + this.close(error); + } else { + reject(error); + } + }; + + eventSource.onopen = () => { + this.logger.log(LogLevel.Information, `SSE connected to ${this.url}`); + this.eventSource = eventSource; + opened = true; + resolve(); + }; + } catch (e) { + return Promise.reject(e); + } + }); + } + + public async send(data: any): Promise { + if (!this.eventSource) { + return Promise.reject(new Error("Cannot send until the transport is connected")); + } + return sendMessage(this.logger, "SSE", this.httpClient, this.url, this.accessTokenFactory, data, this.logMessageContent); + } + + public stop(): Promise { + this.close(); + return Promise.resolve(); + } + + private close(e?: Error) { + if (this.eventSource) { + this.eventSource.close(); + this.eventSource = null; + + if (this.onclose) { + this.onclose(e); + } + } + } + + public onreceive: DataReceived; + public onclose: TransportClosed; +} diff --git a/clients/ts/signalr/src/Transports.ts b/clients/ts/signalr/src/Transports.ts deleted file mode 100644 index 5823ad9369..0000000000 --- a/clients/ts/signalr/src/Transports.ts +++ /dev/null @@ -1,369 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -import { AbortController } from "./AbortController"; -import { DataReceived, TransportClosed } from "./Common"; -import { HttpError, TimeoutError } from "./Errors"; -import { HttpClient, HttpRequest } from "./HttpClient"; -import { ILogger, LogLevel } from "./ILogger"; -import { Arg } from "./Utils"; - -export enum TransportType { - WebSockets, - ServerSentEvents, - LongPolling, -} - -export enum TransferFormat { - Text = 1, - Binary, -} - -export interface ITransport { - connect(url: string, transferFormat: TransferFormat): Promise; - send(data: any): Promise; - stop(): Promise; - onreceive: DataReceived; - onclose: TransportClosed; -} - -export class WebSocketTransport implements ITransport { - private readonly logger: ILogger; - private readonly accessTokenFactory: () => string | Promise; - private readonly logMessageContent: boolean; - private webSocket: WebSocket; - - constructor(accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { - this.logger = logger; - this.accessTokenFactory = accessTokenFactory || (() => null); - this.logMessageContent = logMessageContent; - } - - public async connect(url: string, transferFormat: TransferFormat): Promise { - Arg.isRequired(url, "url"); - Arg.isRequired(transferFormat, "transferFormat"); - Arg.isIn(transferFormat, TransferFormat, "transferFormat"); - - if (typeof (WebSocket) === "undefined") { - throw new Error("'WebSocket' is not supported in your environment."); - } - - this.logger.log(LogLevel.Trace, "(WebSockets transport) Connecting"); - - const token = await this.accessTokenFactory(); - return new Promise((resolve, reject) => { - url = url.replace(/^http/, "ws"); - if (token) { - url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; - } - - const webSocket = new WebSocket(url); - if (transferFormat === TransferFormat.Binary) { - webSocket.binaryType = "arraybuffer"; - } - - webSocket.onopen = (event: Event) => { - this.logger.log(LogLevel.Information, `WebSocket connected to ${url}`); - this.webSocket = webSocket; - resolve(); - }; - - webSocket.onerror = (event: ErrorEvent) => { - reject(event.error); - }; - - webSocket.onmessage = (message: MessageEvent) => { - this.logger.log(LogLevel.Trace, `(WebSockets transport) data received. ${getDataDetail(message.data, this.logMessageContent)}.`); - if (this.onreceive) { - this.onreceive(message.data); - } - }; - - webSocket.onclose = (event: CloseEvent) => { - // webSocket will be null if the transport did not start successfully - if (this.onclose && this.webSocket) { - if (event.wasClean === false || event.code !== 1000) { - this.onclose(new Error(`Websocket closed with status code: ${event.code} (${event.reason})`)); - } else { - this.onclose(); - } - } - }; - }); - } - - public send(data: any): Promise { - if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) { - this.logger.log(LogLevel.Trace, `(WebSockets transport) sending data. ${getDataDetail(data, this.logMessageContent)}.`); - this.webSocket.send(data); - return Promise.resolve(); - } - - return Promise.reject("WebSocket is not in the OPEN state"); - } - - public stop(): Promise { - if (this.webSocket) { - this.webSocket.close(); - this.webSocket = null; - } - return Promise.resolve(); - } - - public onreceive: DataReceived; - public onclose: TransportClosed; -} - -export class ServerSentEventsTransport implements ITransport { - private readonly httpClient: HttpClient; - private readonly accessTokenFactory: () => string | Promise; - private readonly logger: ILogger; - private readonly logMessageContent: boolean; - private eventSource: EventSource; - private url: string; - - constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { - this.httpClient = httpClient; - this.accessTokenFactory = accessTokenFactory || (() => null); - this.logger = logger; - this.logMessageContent = logMessageContent; - } - - public async connect(url: string, transferFormat: TransferFormat): Promise { - Arg.isRequired(url, "url"); - Arg.isRequired(transferFormat, "transferFormat"); - Arg.isIn(transferFormat, TransferFormat, "transferFormat"); - - if (typeof (EventSource) === "undefined") { - throw new Error("'EventSource' is not supported in your environment."); - } - - this.logger.log(LogLevel.Trace, "(SSE transport) Connecting"); - - this.url = url; - const token = await this.accessTokenFactory(); - return new Promise((resolve, reject) => { - if (transferFormat !== TransferFormat.Text) { - reject(new Error("The Server-Sent Events transport only supports the 'Text' transfer format")); - } - - if (token) { - url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; - } - - const eventSource = new EventSource(url, { withCredentials: true }); - - try { - eventSource.onmessage = (e: MessageEvent) => { - if (this.onreceive) { - try { - this.logger.log(LogLevel.Trace, `(SSE transport) data received. ${getDataDetail(e.data, this.logMessageContent)}.`); - this.onreceive(e.data); - } catch (error) { - if (this.onclose) { - this.onclose(error); - } - return; - } - } - }; - - eventSource.onerror = (e: any) => { - reject(new Error(e.message || "Error occurred")); - - // don't report an error if the transport did not start successfully - if (this.eventSource && this.onclose) { - this.onclose(new Error(e.message || "Error occurred")); - } - }; - - eventSource.onopen = () => { - this.logger.log(LogLevel.Information, `SSE connected to ${this.url}`); - this.eventSource = eventSource; - // SSE is a text protocol - resolve(); - }; - } catch (e) { - return Promise.reject(e); - } - }); - } - - public async send(data: any): Promise { - return send(this.logger, "SSE", this.httpClient, this.url, this.accessTokenFactory, data, this.logMessageContent); - } - - public stop(): Promise { - if (this.eventSource) { - this.eventSource.close(); - this.eventSource = null; - } - return Promise.resolve(); - } - - public onreceive: DataReceived; - public onclose: TransportClosed; -} - -export class LongPollingTransport implements ITransport { - private readonly httpClient: HttpClient; - private readonly accessTokenFactory: () => string | Promise; - private readonly logger: ILogger; - private readonly logMessageContent: boolean; - - private url: string; - private pollXhr: XMLHttpRequest; - private pollAbort: AbortController; - - constructor(httpClient: HttpClient, accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { - this.httpClient = httpClient; - this.accessTokenFactory = accessTokenFactory || (() => null); - this.logger = logger; - this.pollAbort = new AbortController(); - this.logMessageContent = logMessageContent; - } - - public connect(url: string, transferFormat: TransferFormat): Promise { - Arg.isRequired(url, "url"); - Arg.isRequired(transferFormat, "transferFormat"); - Arg.isIn(transferFormat, TransferFormat, "transferFormat"); - - this.url = url; - - this.logger.log(LogLevel.Trace, "(LongPolling transport) Connecting"); - - if (transferFormat === TransferFormat.Binary && (typeof new XMLHttpRequest().responseType !== "string")) { - // This will work if we fix: https://github.com/aspnet/SignalR/issues/742 - throw new Error("Binary protocols over XmlHttpRequest not implementing advanced features are not supported."); - } - - this.poll(this.url, transferFormat); - return Promise.resolve(); - } - - private async poll(url: string, transferFormat: TransferFormat): Promise { - const pollOptions: HttpRequest = { - abortSignal: this.pollAbort.signal, - headers: {}, - timeout: 90000, - }; - - if (transferFormat === TransferFormat.Binary) { - pollOptions.responseType = "arraybuffer"; - } - - const token = await this.accessTokenFactory(); - if (token) { - // tslint:disable-next-line:no-string-literal - pollOptions.headers["Authorization"] = `Bearer ${token}`; - } - - while (!this.pollAbort.signal.aborted) { - try { - const pollUrl = `${url}&_=${Date.now()}`; - this.logger.log(LogLevel.Trace, `(LongPolling transport) polling: ${pollUrl}`); - const response = await this.httpClient.get(pollUrl, pollOptions); - if (response.statusCode === 204) { - this.logger.log(LogLevel.Information, "(LongPolling transport) Poll terminated by server"); - - // Poll terminated by server - if (this.onclose) { - this.onclose(); - } - this.pollAbort.abort(); - } else if (response.statusCode !== 200) { - this.logger.log(LogLevel.Error, `(LongPolling transport) Unexpected response code: ${response.statusCode}`); - - // Unexpected status code - if (this.onclose) { - this.onclose(new HttpError(response.statusText, response.statusCode)); - } - this.pollAbort.abort(); - } else { - // Process the response - if (response.content) { - this.logger.log(LogLevel.Trace, `(LongPolling transport) data received. ${getDataDetail(response.content, this.logMessageContent)}`); - if (this.onreceive) { - this.onreceive(response.content); - } - } else { - // This is another way timeout manifest. - this.logger.log(LogLevel.Trace, "(LongPolling transport) Poll timed out, reissuing."); - } - } - } catch (e) { - if (e instanceof TimeoutError) { - // Ignore timeouts and reissue the poll. - this.logger.log(LogLevel.Trace, "(LongPolling transport) Poll timed out, reissuing."); - } else { - // Close the connection with the error as the result. - if (this.onclose) { - this.onclose(e); - } - this.pollAbort.abort(); - } - } - } - } - - public async send(data: any): Promise { - return send(this.logger, "LongPolling", this.httpClient, this.url, this.accessTokenFactory, data, this.logMessageContent); - } - - public stop(): Promise { - this.pollAbort.abort(); - return Promise.resolve(); - } - - public onreceive: DataReceived; - public onclose: TransportClosed; -} - -function getDataDetail(data: any, includeContent: boolean): string { - let length: string = null; - if (data instanceof ArrayBuffer) { - length = `Binary data of length ${data.byteLength}.`; - if (includeContent) { - length += ` Content: '${formatArrayBuffer(data)}'.`; - } - } else if (typeof data === "string") { - length = `String data of length ${data.length}.`; - if (includeContent) { - length += ` Content: '${data}'.`; - } - } - return length; -} - -function formatArrayBuffer(data: ArrayBuffer): string { - const view = new Uint8Array(data); - - // Uint8Array.map only supports returning another Uint8Array? - let str = ""; - view.forEach((num) => { - const pad = num < 16 ? "0" : ""; - str += `0x${pad}${num.toString(16)} `; - }); - - // Trim of trailing space. - return str.substr(0, str.length - 1); -} - -async function send(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string | Promise, content: string | ArrayBuffer, logMessageContent: boolean): Promise { - let headers; - const token = await accessTokenFactory(); - if (token) { - headers = { - ["Authorization"]: `Bearer ${token}`, - }; - } - - logger.log(LogLevel.Trace, `(${transportName} transport) sending data. ${getDataDetail(content, logMessageContent)}.`); - - const response = await httpClient.post(url, { - content, - headers, - }); - - logger.log(LogLevel.Trace, `(${transportName} transport) request complete. Response status: ${response.statusCode}.`); -} diff --git a/clients/ts/signalr/src/Utils.ts b/clients/ts/signalr/src/Utils.ts index 72d0ad875a..a02dc05190 100644 --- a/clients/ts/signalr/src/Utils.ts +++ b/clients/ts/signalr/src/Utils.ts @@ -1,6 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +import { HttpClient } from "./HttpClient"; +import { ILogger, LogLevel } from "./ILogger"; + export class Arg { public static isRequired(val: any, name: string): void { if (val === null || val === undefined) { @@ -15,3 +18,52 @@ export class Arg { } } } + +export function getDataDetail(data: any, includeContent: boolean): string { + let length: string = null; + if (data instanceof ArrayBuffer) { + length = `Binary data of length ${data.byteLength}`; + if (includeContent) { + length += `. Content: '${formatArrayBuffer(data)}'`; + } + } else if (typeof data === "string") { + length = `String data of length ${data.length}`; + if (includeContent) { + length += `. Content: '${data}'.`; + } + } + return length; +} + +export function formatArrayBuffer(data: ArrayBuffer): string { + const view = new Uint8Array(data); + + // Uint8Array.map only supports returning another Uint8Array? + let str = ""; + view.forEach((num) => { + const pad = num < 16 ? "0" : ""; + str += `0x${pad}${num.toString(16)} `; + }); + + // Trim of trailing space. + return str.substr(0, str.length - 1); +} + +export async function sendMessage(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: () => string | Promise, content: string | ArrayBuffer, logMessageContent: boolean): Promise { + let headers; + const token = await accessTokenFactory(); + if (token) { + headers = { + ["Authorization"]: `Bearer ${token}`, + }; + } + + logger.log(LogLevel.Trace, `(${transportName} transport) sending data. ${getDataDetail(content, logMessageContent)}.`); + + const response = await httpClient.post(url, { + content, + headers, + }); + + logger.log(LogLevel.Trace, `(${transportName} transport) request complete. Response status: ${response.statusCode}.`); +} diff --git a/clients/ts/signalr/src/WebSocketTransport.ts b/clients/ts/signalr/src/WebSocketTransport.ts new file mode 100644 index 0000000000..6e549cfff3 --- /dev/null +++ b/clients/ts/signalr/src/WebSocketTransport.ts @@ -0,0 +1,95 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +import { DataReceived, TransportClosed } from "./Common"; +import { ILogger, LogLevel } from "./ILogger"; +import { ITransport, TransferFormat } from "./ITransport"; +import { Arg, getDataDetail } from "./Utils"; + +export class WebSocketTransport implements ITransport { + private readonly logger: ILogger; + private readonly accessTokenFactory: () => string | Promise; + private readonly logMessageContent: boolean; + private webSocket: WebSocket; + + constructor(accessTokenFactory: () => string | Promise, logger: ILogger, logMessageContent: boolean) { + this.logger = logger; + this.accessTokenFactory = accessTokenFactory || (() => null); + this.logMessageContent = logMessageContent; + } + + public async connect(url: string, transferFormat: TransferFormat): Promise { + Arg.isRequired(url, "url"); + Arg.isRequired(transferFormat, "transferFormat"); + Arg.isIn(transferFormat, TransferFormat, "transferFormat"); + + if (typeof (WebSocket) === "undefined") { + throw new Error("'WebSocket' is not supported in your environment."); + } + + this.logger.log(LogLevel.Trace, "(WebSockets transport) Connecting"); + + const token = await this.accessTokenFactory(); + if (token) { + url += (url.indexOf("?") < 0 ? "?" : "&") + `access_token=${encodeURIComponent(token)}`; + } + + return new Promise((resolve, reject) => { + url = url.replace(/^http/, "ws"); + const webSocket = new WebSocket(url); + if (transferFormat === TransferFormat.Binary) { + webSocket.binaryType = "arraybuffer"; + } + + webSocket.onopen = (event: Event) => { + this.logger.log(LogLevel.Information, `WebSocket connected to ${url}`); + this.webSocket = webSocket; + resolve(); + }; + + webSocket.onerror = (event: ErrorEvent) => { + reject(event.error); + }; + + webSocket.onmessage = (message: MessageEvent) => { + this.logger.log(LogLevel.Trace, `(WebSockets transport) data received. ${getDataDetail(message.data, this.logMessageContent)}.`); + if (this.onreceive) { + this.onreceive(message.data); + } + }; + + webSocket.onclose = (event: CloseEvent) => { + // webSocket will be null if the transport did not start successfully + this.logger.log(LogLevel.Trace, "(WebSockets transport) socket closed."); + if (this.onclose) { + if (event.wasClean === false || event.code !== 1000) { + this.onclose(new Error(`Websocket closed with status code: ${event.code} (${event.reason})`)); + } else { + this.onclose(); + } + } + }; + }); + } + + public send(data: any): Promise { + if (this.webSocket && this.webSocket.readyState === WebSocket.OPEN) { + this.logger.log(LogLevel.Trace, `(WebSockets transport) sending data. ${getDataDetail(data, this.logMessageContent)}.`); + this.webSocket.send(data); + return Promise.resolve(); + } + + return Promise.reject("WebSocket is not in the OPEN state"); + } + + public stop(): Promise { + if (this.webSocket) { + this.webSocket.close(); + this.webSocket = null; + } + return Promise.resolve(); + } + + public onreceive: DataReceived; + public onclose: TransportClosed; +} diff --git a/clients/ts/signalr/src/index.ts b/clients/ts/signalr/src/index.ts index 392dd08a62..724815fa4d 100644 --- a/clients/ts/signalr/src/index.ts +++ b/clients/ts/signalr/src/index.ts @@ -11,5 +11,5 @@ export * from "./IConnection"; export * from "./IHubProtocol"; export * from "./ILogger"; export * from "./Loggers"; -export * from "./Transports"; +export * from "./ITransport"; export * from "./Observable"; diff --git a/specs/TransportProtocols.md b/specs/TransportProtocols.md index db7ffbec2b..91e1f81a2a 100644 --- a/specs/TransportProtocols.md +++ b/specs/TransportProtocols.md @@ -96,6 +96,8 @@ In this transport, the client establishes an SSE connection to `[endpoint-base]` The Server-Sent Events transport only supports text data, because it is a text-based protocol. As a result, it is reported by the server as supporting only the `Text` transfer format. If a client wishes to send arbitrary binary data, it should skip the Server-Sent Events transport when selecting an appropriate transport. +When the client has finished with the connection, it can terminate the event stream connection (send a TCP reset). The server will clean up the necessary resources. + ## Long Polling (Server-to-Client only) Long Polling is a server-to-client half-transport, so it is always paired with HTTP Post. It requires a connection already be established using the `POST [endpoint-base]/negotiate` request. @@ -109,3 +111,5 @@ A Poll is established by sending an HTTP GET request to `[endpoint-base]` with t When data is available, the server responds with a body in one of the two formats below (depending upon the value of the `Accept` header). The response may be chunked, as per the chunked encoding part of the HTTP spec. If the `connectionId` parameter is missing, a `400 Bad Request` response is returned. If there is no connection with the ID specified in `connectionId`, a `404 Not Found` response is returned. + +When the client has finished with the connection, it can issue a `DELETE` request to `[endpoint-base]` (with the `connectionId` in the querystring) to gracefully terminate the connection. The server will complete the latest poll with `204` to indicate that it has shut down. diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.Log.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.Log.cs index 3d71c82c2d..ec407297e5 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.Log.cs @@ -44,6 +44,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal LoggerMessage.Define(LogLevel.Trace, new EventId(10, "PollResponseReceived"), "Poll response with status code {StatusCode} received from server. Content length: {ContentLength}."); + private static readonly Action _sendingDeleteRequest = + LoggerMessage.Define(LogLevel.Debug, new EventId(11, "SendingDeleteRequest"), "Sending DELETE request to '{PollUrl}'."); + + private static readonly Action _deleteRequestAccepted = + LoggerMessage.Define(LogLevel.Debug, new EventId(12, "DeleteRequestAccepted"), "DELETE request to '{PollUrl}' accepted."); + + private static readonly Action _errorSendingDeleteRequest = + LoggerMessage.Define(LogLevel.Error, new EventId(13, "ErrorSendingDeleteRequest"), "Error sending DELETE request to '{PollUrl}'."); + // EventIds 100 - 106 used in SendUtils public static void StartTransport(ILogger logger, TransferFormat transferFormat) @@ -99,6 +108,21 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal response.Content.Headers.ContentLength ?? -1, null); } } + + public static void SendingDeleteRequest(ILogger logger, Uri pollUrl) + { + _sendingDeleteRequest(logger, pollUrl, null); + } + + public static void DeleteRequestAccepted(ILogger logger, Uri pollUrl) + { + _deleteRequestAccepted(logger, pollUrl, null); + } + + public static void ErrorSendingDeleteRequest(ILogger logger, Uri pollUrl, Exception ex) + { + _errorSendingDeleteRequest(logger, pollUrl, ex); + } } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs index cd68844808..3678910d04 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs @@ -8,8 +8,6 @@ using System.Net.Http; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; -using Microsoft.AspNetCore.Connections.Features; -using Microsoft.AspNetCore.Http.Connections.Features; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -17,6 +15,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { public partial class LongPollingTransport : ITransport { + private static readonly TimeSpan DefaultShutdownTimeout = TimeSpan.FromSeconds(5); + private readonly HttpClient _httpClient; private readonly ILogger _logger; private IDuplexPipe _application; @@ -32,6 +32,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal public PipeWriter Output => _transport.Output; + internal TimeSpan ShutdownTimeout { get; set; } + public LongPollingTransport(HttpClient httpClient) : this(httpClient, null) { } @@ -40,6 +42,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { _httpClient = httpClient; _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); + ShutdownTimeout = DefaultShutdownTimeout; } public Task StartAsync(Uri url, TransferFormat transferFormat) @@ -74,6 +77,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal if (trigger == receiving) { + // We don't need to DELETE here because the poll completed, which means the server shut down already. + // We're waiting for the application to finish and there are 2 things it could be doing // 1. Waiting for application data // 2. Waiting for an outgoing send (this should be instantaneous) @@ -86,7 +91,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal // Set the sending error so we communicate that to the application _error = sending.IsFaulted ? sending.Exception.InnerException : null; - _transportCts.Cancel(); + // Send the DELETE request to clean-up the connection on the server. + // This will also cause the poll to return. + await SendDeleteRequest(url); + + // This timeout is only to ensure the poll is cleaned up despite a misbehaving server. + // It doesn't need to be configurable. + _transportCts.CancelAfter(ShutdownTimeout); // Cancel any pending flush so that we can quit _application.Output.CancelPendingFlush(); @@ -97,9 +108,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); - _transport.Output.Complete(); - _transport.Input.Complete(); - _application.Input.CancelPendingRead(); try @@ -112,6 +120,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal throw; } + _transport.Output.Complete(); + _transport.Input.Complete(); + Log.TransportStopped(_logger, null); } @@ -187,5 +198,20 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal Log.ReceiveStopped(_logger); } } + + private async Task SendDeleteRequest(Uri pollUrl) + { + try + { + Log.SendingDeleteRequest(_logger, pollUrl); + var response = await _httpClient.DeleteAsync(pollUrl); + response.EnsureSuccessStatusCode(); + Log.DeleteRequestAccepted(_logger, pollUrl); + } + catch (Exception ex) + { + Log.ErrorSendingDeleteRequest(_logger, pollUrl, ex); + } + } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs index 1f5c7ba61b..00d78bee4e 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.Log.cs @@ -40,6 +40,12 @@ namespace Microsoft.AspNetCore.Http.Connections private static readonly Action _negotiationRequest = LoggerMessage.Define(LogLevel.Debug, new EventId(10, "NegotiationRequest"), "Sending negotiation response."); + private static readonly Action _receivedDeleteRequestForUnsupportedTransport = + LoggerMessage.Define(LogLevel.Trace, new EventId(11, "ReceivedDeleteRequestForUnsupportedTransport"), "Received DELETE request for unsupported transport: {TransportType}."); + + private static readonly Action _terminatingConnection = + LoggerMessage.Define(LogLevel.Trace, new EventId(12, "TerminatingConection"), "Terminating Long Polling connection due to a DELETE request."); + public static void ConnectionDisposed(ILogger logger, string connectionId) { _connectionDisposed(logger, connectionId, null); @@ -89,6 +95,16 @@ namespace Microsoft.AspNetCore.Http.Connections { _negotiationRequest(logger, null); } + + public static void ReceivedDeleteRequestForUnsupportedTransport(ILogger logger, HttpTransportType transportType) + { + _receivedDeleteRequestForUnsupportedTransport(logger, transportType, null); + } + + public static void TerminatingConection(ILogger logger) + { + _terminatingConnection(logger, null); + } } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs index 35e2152706..1b0d1d501d 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/HttpConnectionDispatcher.cs @@ -78,6 +78,11 @@ namespace Microsoft.AspNetCore.Http.Connections // GET /{path} await ExecuteAsync(context, connectionDelegate, options, logScope); } + else if (HttpMethods.IsDelete(context.Request.Method)) + { + // DELETE /{path} + await ProcessDeleteAsync(context); + } else { context.Response.ContentType = "text/plain"; @@ -121,7 +126,7 @@ namespace Microsoft.AspNetCore.Http.Connections if (headers.Accept?.Contains(new Net.Http.Headers.MediaTypeHeaderValue("text/event-stream")) == true) { // Connection must already exist - var connection = await GetConnectionAsync(context, options); + var connection = await GetConnectionAsync(context); if (connection == null) { // No such connection, GetConnection already set the response status code @@ -171,7 +176,7 @@ namespace Microsoft.AspNetCore.Http.Connections // GET /{path} maps to long polling // Connection must already exist - var connection = await GetConnectionAsync(context, options); + var connection = await GetConnectionAsync(context); if (connection == null) { // No such connection, GetConnection already set the response status code @@ -240,7 +245,7 @@ namespace Microsoft.AspNetCore.Http.Connections context.Response.RegisterForDispose(timeoutSource); context.Response.RegisterForDispose(tokenSource); - var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, connection.ConnectionId, _loggerFactory); + var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory); // Start the transport connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); @@ -439,7 +444,7 @@ namespace Microsoft.AspNetCore.Http.Connections private async Task ProcessSend(HttpContext context, HttpConnectionOptions options) { - var connection = await GetConnectionAsync(context, options); + var connection = await GetConnectionAsync(context); if (connection == null) { // No such connection, GetConnection already set the response status code @@ -487,6 +492,36 @@ namespace Microsoft.AspNetCore.Http.Connections } } + private async Task ProcessDeleteAsync(HttpContext context) + { + var connection = await GetConnectionAsync(context); + if (connection == null) + { + // No such connection, GetConnection already set the response status code + return; + } + + // This end point only works for long polling + if (connection.TransportType != HttpTransportType.LongPolling) + { + Log.ReceivedDeleteRequestForUnsupportedTransport(_logger, connection.TransportType); + context.Response.StatusCode = StatusCodes.Status400BadRequest; + context.Response.ContentType = "text/plain"; + await context.Response.WriteAsync("Cannot terminate this connection using the DELETE endpoint."); + return; + } + + Log.TerminatingConection(_logger); + + // Complete the receiving end of the pipe + connection.Application.Output.Complete(); + + // Dispose the connection gracefully, but don't wait for it + _ = _manager.DisposeAndRemoveAsync(connection, closeGracefully: true); + context.Response.StatusCode = StatusCodes.Status202Accepted; + context.Response.ContentType = "text/plain"; + } + private async Task EnsureConnectionStateAsync(HttpConnectionContext connection, HttpContext context, HttpTransportType transportType, HttpTransportType supportedTransports, ConnectionLogScope logScope, HttpConnectionOptions options) { if ((supportedTransports & transportType) == 0) @@ -610,7 +645,7 @@ namespace Microsoft.AspNetCore.Http.Connections return newHttpContext; } - private async Task GetConnectionAsync(HttpContext context, HttpConnectionOptions options) + private async Task GetConnectionAsync(HttpContext context) { var connectionId = GetConnectionId(context); diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/LongPollingTransport.cs index 9ee5be2deb..0597c692e4 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/LongPollingTransport.cs @@ -16,13 +16,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports private readonly PipeReader _application; private readonly ILogger _logger; private readonly CancellationToken _timeoutToken; - private readonly string _connectionId; - public LongPollingTransport(CancellationToken timeoutToken, PipeReader application, string connectionId, ILoggerFactory loggerFactory) + public LongPollingTransport(CancellationToken timeoutToken, PipeReader application, ILoggerFactory loggerFactory) { _timeoutToken = timeoutToken; _application = application; - _connectionId = connectionId; _logger = loggerFactory.CreateLogger(); } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index ed06e7c19c..907ab5d225 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -1588,6 +1588,140 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } + [Theory] + [InlineData(HttpTransportType.ServerSentEvents)] + [InlineData(HttpTransportType.WebSockets)] + public async Task DeleteEndpointRejectsRequestToTerminateNonLongPollingTransport(HttpTransportType transportType) + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = transportType; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var context = MakeRequest("/foo", connection); + SetTransport(context, transportType); + + var serviceCollection = new ServiceCollection(); + serviceCollection.AddSingleton(); + var services = serviceCollection.BuildServiceProvider(); + var builder = new ConnectionBuilder(services); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionOptions(); + + _ = dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + // Issue the delete request + var deleteContext = new DefaultHttpContext(); + deleteContext.Request.Path = "/foo"; + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.Method = "DELETE"; + var ms = new MemoryStream(); + deleteContext.Response.Body = ms; + + await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); + + // Verify the response from the DELETE request + Assert.Equal(StatusCodes.Status400BadRequest, deleteContext.Response.StatusCode); + Assert.Equal("text/plain", deleteContext.Response.ContentType); + Assert.Equal("Cannot terminate this connection using the DELETE endpoint.", Encoding.UTF8.GetString(ms.ToArray())); + } + } + + [Fact] + public async Task DeleteEndpointGracefullyTerminatesLongPolling() + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionOptions(); + + var pollTask = dispatcher.ExecuteAsync(context, options, app); + + // Issue the delete request and make sure the poll completes + var deleteContext = new DefaultHttpContext(); + deleteContext.Request.Path = "/foo"; + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.Method = "DELETE"; + + Assert.False(pollTask.IsCompleted); + + await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); + + await pollTask.OrTimeout(); + + // Verify that everything shuts down + await connection.ApplicationTask.OrTimeout(); + await connection.TransportTask.OrTimeout(); + + // Verify the response from the DELETE request + Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); + Assert.Equal("text/plain", deleteContext.Response.ContentType); + + // Verify the connection was removed from the manager + Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); + } + } + + [Fact] + public async Task DeleteEndpointGracefullyTerminatesLongPollingEvenWhenBetweenPolls() + { + using (StartLog(out var loggerFactory, LogLevel.Debug)) + { + var manager = CreateConnectionManager(loggerFactory); + var connection = manager.CreateConnection(); + connection.TransportType = HttpTransportType.LongPolling; + + var dispatcher = new HttpConnectionDispatcher(manager, loggerFactory); + + var context = MakeRequest("/foo", connection); + + var services = new ServiceCollection(); + services.AddSingleton(); + var builder = new ConnectionBuilder(services.BuildServiceProvider()); + builder.UseConnectionHandler(); + var app = builder.Build(); + var options = new HttpConnectionOptions(); + options.LongPolling.PollTimeout = TimeSpan.FromMilliseconds(1); + + await dispatcher.ExecuteAsync(context, options, app).OrTimeout(); + + // Issue the delete request and make sure the poll completes + var deleteContext = new DefaultHttpContext(); + deleteContext.Request.Path = "/foo"; + deleteContext.Request.QueryString = new QueryString($"?id={connection.ConnectionId}"); + deleteContext.Request.Method = "DELETE"; + + await dispatcher.ExecuteAsync(deleteContext, options, app).OrTimeout(); + + // Verify that everything shuts down + await connection.ApplicationTask.OrTimeout(); + await connection.TransportTask.OrTimeout(); + + // Verify the response from the DELETE request + Assert.Equal(StatusCodes.Status202Accepted, deleteContext.Response.StatusCode); + Assert.Equal("text/plain", deleteContext.Response.ContentType); + + // Verify the connection was removed from the manager + Assert.False(manager.TryGetConnection(connection.ConnectionId, out _)); + } + } + [Fact] public async Task NegotiateDoesNotReturnWebSocketsWhenNotAvailable() { @@ -1747,7 +1881,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } - public class NerverEndingConnectionHandler : ConnectionHandler + public class NeverEndingConnectionHandler : ConnectionHandler { public override Task OnConnectedAsync(ConnectionContext connection) { @@ -1817,8 +1951,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class TestConnectionHandler : ConnectionHandler { + private TaskCompletionSource _startedTcs = new TaskCompletionSource(); + + public Task Started => _startedTcs.Task; + public override async Task OnConnectedAsync(ConnectionContext connection) { + _startedTcs.TrySetResult(null); + while (true) { var result = await connection.Transport.Input.ReadAsync(); diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs index 3712313a73..86c74a4752 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/LongPollingTests.cs @@ -24,7 +24,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var context = new DefaultHttpContext(); - var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, loggerFactory: new LoggerFactory()); connection.Transport.Output.Complete(); @@ -41,7 +41,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var context = new DefaultHttpContext(); var timeoutToken = new CancellationToken(true); - var poll = new LongPollingTransport(timeoutToken, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(timeoutToken, connection.Application.Input, loggerFactory: new LoggerFactory()); using (var cts = CancellationTokenSource.CreateLinkedTokenSource(timeoutToken, context.RequestAborted)) { @@ -59,7 +59,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; @@ -79,7 +79,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); var context = new DefaultHttpContext(); - var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + var poll = new LongPollingTransport(CancellationToken.None, connection.Application.Input, loggerFactory: new LoggerFactory()); var ms = new MemoryStream(); context.Response.Body = ms; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 901e3bc1f8..5090acdf76 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -871,6 +871,59 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Fact] + public async Task StopCausesPollToReturnImmediately() + { + using (StartLog(out var loggerFactory)) + { + PollTrackingMessageHandler pollTracker = null; + var hubConnection = new HubConnectionBuilder() + .WithLoggerFactory(loggerFactory) + .WithUrl(_serverFixture.Url + "/default", options => + { + options.Transports = HttpTransportType.LongPolling; + options.HttpMessageHandlerFactory = handler => + { + pollTracker = new PollTrackingMessageHandler(handler); + return pollTracker; + }; + }) + .Build(); + + await hubConnection.StartAsync(); + + Assert.NotNull(pollTracker); + Assert.NotNull(pollTracker.ActivePoll); + + var stopTask = hubConnection.StopAsync(); + + // Stop async and wait for the poll to shut down. It should do so very quickly because the DELETE will stop the poll! + await pollTracker.ActivePoll.OrTimeout(TimeSpan.FromMilliseconds(100)); + + await stopTask; + } + } + + private class PollTrackingMessageHandler : DelegatingHandler + { + public Task ActivePoll { get; private set; } + + public PollTrackingMessageHandler(HttpMessageHandler innerHandler) : base(innerHandler) + { + } + + protected override Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) + { + if (request.Method == HttpMethod.Get) + { + ActivePoll = base.SendAsync(request, cancellationToken); + return ActivePoll; + } + + return base.SendAsync(request, cancellationToken); + } + } + public static IEnumerable HubProtocolsAndTransportsAndHubPaths { get diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index ff1d32893f..5d7128392b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -5,10 +5,12 @@ using System; using System.Buffers; using System.Collections.Generic; using System.IO.Pipelines; +using System.Linq; using System.Net; using System.Net.Http; using System.Net.Http.Headers; using System.Reflection; +using System.Runtime.InteropServices.ComTypes; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -23,6 +25,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public class LongPollingTransportTests { + private static readonly Uri TestUri = new Uri("http://example.com/?id=1234"); + [Fact] public async Task LongPollingTransportStopsPollAndSendLoopsWhenTransportStopped() { @@ -43,7 +47,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); transportActiveTask = longPollingTransport.Running; @@ -76,7 +80,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); await longPollingTransport.Running.OrTimeout(); @@ -129,7 +133,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); var data = await longPollingTransport.Input.ReadAllAsync().OrTimeout(); await longPollingTransport.Running.OrTimeout(); @@ -159,7 +163,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); var exception = await Assert.ThrowsAsync(async () => @@ -183,16 +187,27 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [Fact] public async Task LongPollingTransportStopsWhenSendRequestFails() { + var stopped = false; var mockHttpHandler = new Mock(); mockHttpHandler.Protected() .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(async (request, cancellationToken) => { await Task.Yield(); - var statusCode = request.Method == HttpMethod.Post - ? HttpStatusCode.InternalServerError - : HttpStatusCode.OK; - return ResponseUtils.CreateResponse(statusCode); + switch (request.Method.Method) + { + case "DELETE": + stopped = true; + return ResponseUtils.CreateResponse(HttpStatusCode.Accepted); + case "GET" when stopped: + return ResponseUtils.CreateResponse(HttpStatusCode.NoContent); + case "GET": + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + case "POST": + return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError); + default: + throw new InvalidOperationException("Unexpected request"); + } }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -200,7 +215,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var longPollingTransport = new LongPollingTransport(httpClient); try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); await longPollingTransport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello World")); @@ -208,6 +223,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var exception = await Assert.ThrowsAsync(async () => await longPollingTransport.Input.ReadAllAsync().OrTimeout()); Assert.Contains(" 500 ", exception.Message); + + Assert.True(stopped); } finally { @@ -218,6 +235,49 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [Fact] public async Task LongPollingTransportShutsDownWhenChannelIsClosed() + { + var mockHttpHandler = new Mock(); + var stopped = false; + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (request.Method == HttpMethod.Delete) + { + stopped = true; + return ResponseUtils.CreateResponse(HttpStatusCode.Accepted); + } + else + { + return stopped + ? ResponseUtils.CreateResponse(HttpStatusCode.NoContent) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + try + { + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); + + longPollingTransport.Output.Complete(); + + await longPollingTransport.Running.OrTimeout(); + + await longPollingTransport.Input.ReadAllAsync().OrTimeout(); + } + finally + { + await longPollingTransport.StopAsync(); + } + } + } + + [Fact] + public async Task LongPollingTransportShutsDownAfterTimeoutEvenIfServerDoesntCompletePoll() { var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -231,9 +291,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { var longPollingTransport = new LongPollingTransport(httpClient); + longPollingTransport.ShutdownTimeout = TimeSpan.FromMilliseconds(1); + try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); longPollingTransport.Output.Complete(); @@ -279,7 +341,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); // Wait for the transport to finish await longPollingTransport.Running.OrTimeout(); @@ -325,7 +387,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); longPollingTransport.Output.Write(Encoding.UTF8.GetBytes("Hello")); longPollingTransport.Output.Write(Encoding.UTF8.GetBytes("World")); @@ -367,7 +429,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), transferFormat); + await longPollingTransport.StartAsync(TestUri, transferFormat); } finally { @@ -394,7 +456,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var longPollingTransport = new LongPollingTransport(httpClient); var exception = await Assert.ThrowsAsync(() => - longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), transferFormat)); + longPollingTransport.StartAsync(TestUri, transferFormat)); Assert.Contains($"The '{transferFormat}' transfer format is not supported by this transport.", exception.Message); Assert.Equal("transferFormat", exception.ParamName); @@ -429,7 +491,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests try { - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), TransferFormat.Binary); + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); var completedTask = await Task.WhenAny(completionTcs.Task, longPollingTransport.Running).OrTimeout(); Assert.Equal(completionTcs.Task, completedTask); @@ -440,5 +502,23 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } } } + + [Fact] + public async Task SendsDeleteRequestWhenTransportCompleted() + { + var handler = TestHttpMessageHandler.CreateDefault(); + + using (var httpClient = new HttpClient(handler)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + + await longPollingTransport.StartAsync(TestUri, TransferFormat.Binary); + await longPollingTransport.StopAsync(); + + var deleteRequest = handler.ReceivedRequests.SingleOrDefault(r => r.Method == HttpMethod.Delete); + Assert.NotNull(deleteRequest); + Assert.Equal(TestUri, deleteRequest.RequestUri); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs index c0b6d67bae..44be6be900 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs @@ -1,4 +1,5 @@ using System; +using System.Collections.Generic; using System.Net; using System.Net.Http; using System.Threading; @@ -8,11 +9,23 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { public class TestHttpMessageHandler : HttpMessageHandler { + private List _receivedRequests = new List(); private Func> _handler; + public IReadOnlyList ReceivedRequests + { + get + { + lock (_receivedRequests) + { + return _receivedRequests.ToArray(); + } + } + } + public TestHttpMessageHandler(bool autoNegotiate = true) { - _handler = (request, cancellationToken) => BaseHandler(request, cancellationToken); + _handler = BaseHandler; if (autoNegotiate) { @@ -24,6 +37,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await Task.Yield(); + lock (_receivedRequests) + { + _receivedRequests.Add(request); + } + return await _handler(request, cancellationToken); } @@ -31,17 +49,31 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var testHttpMessageHandler = new TestHttpMessageHandler(); + var deleteCts = new CancellationTokenSource(); + testHttpMessageHandler.OnSocketSend((_, __) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); testHttpMessageHandler.OnLongPoll(async cancellationToken => { + var cts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, deleteCts.Token); + // Just block until canceled var tcs = new TaskCompletionSource(); - using (cancellationToken.Register(() => tcs.TrySetResult(null))) + using (cts.Token.Register(() => tcs.TrySetResult(null))) { await tcs.Task; } return ResponseUtils.CreateResponse(HttpStatusCode.NoContent); }); + testHttpMessageHandler.OnRequest((request, next, cancellationToken) => + { + if (request.Method.Equals(HttpMethod.Delete) && request.RequestUri.PathAndQuery.StartsWith("/?id=")) + { + deleteCts.Cancel(); + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + } + + return next(); + }); return testHttpMessageHandler; }