diff --git a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index 0ae87d7de0..faff3b1fab 100644 --- a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -544,6 +544,33 @@ describe("hubConnection", () => { } }); + it("transport falls back from WebSockets to SSE or LongPolling", async (done) => { + // Replace Websockets with a function that just + // throws to force fallback. + const oldWebSocket = (window as any).WebSocket; + (window as any).WebSocket = () => { + throw new Error("Kick rocks"); + }; + + const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, { + logger: LogLevel.Trace, + protocol: new JsonHubProtocol(), + }); + + try { + await hubConnection.start(); + + // Make sure that we connect with SSE or LongPolling after Websockets fail + const transportName = await hubConnection.invoke("GetActiveTransportName"); + expect(transportName === "ServerSentEvents" || transportName === "LongPolling").toBe(true); + } catch (e) { + fail(e); + } finally { + (window as any).WebSocket = oldWebSocket; + done(); + } + }); + function getJwtToken(url): Promise { return new Promise((resolve, reject) => { const xhr = new XMLHttpRequest(); diff --git a/clients/ts/signalr/spec/HttpConnection.spec.ts b/clients/ts/signalr/spec/HttpConnection.spec.ts index 1cd521798f..3d06eaf8de 100644 --- a/clients/ts/signalr/spec/HttpConnection.spec.ts +++ b/clients/ts/signalr/spec/HttpConnection.spec.ts @@ -134,6 +134,24 @@ describe("HttpConnection", () => { done(); }); + it("start throws after all transports fail", async (done) => { + const options: IHttpConnectionOptions = { + httpClient: new TestHttpClient() + .on("POST", (r) => ({ connectionId: "42", availableTransports: [] })) + .on("GET", (r) => { throw new Error("fail"); }), + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org?q=myData", options); + try { + await connection.start(TransferFormat.Text); + fail(); + done(); + } catch (e) { + expect(e.message).toBe("Unable to initialize any of the available transports."); + } + done(); + }); + it("preserves user's query string", async (done) => { let connectUrl: string; const fakeTransport: ITransport = { diff --git a/clients/ts/signalr/src/HttpConnection.ts b/clients/ts/signalr/src/HttpConnection.ts index 614e51b416..febd61c21c 100644 --- a/clients/ts/signalr/src/HttpConnection.ts +++ b/clients/ts/signalr/src/HttpConnection.ts @@ -79,39 +79,29 @@ export class HttpConnection implements IConnection { // No need to add a connection ID in this case this.url = this.baseUrl; this.transport = this.constructTransport(TransportType.WebSockets); + // We should just call connect directly in this case. + // No fallback or negotiate in this case. + await this.transport.connect(this.url, transferFormat, this); } else { - let headers; const token = this.options.accessTokenFactory(); + let headers; if (token) { headers = { ["Authorization"]: `Bearer ${token}`, }; } - const negotiatePayload = await this.httpClient.post(this.resolveNegotiateUrl(this.baseUrl), { - content: "", - headers, - }); - - const negotiateResponse: INegotiateResponse = JSON.parse(negotiatePayload.content as string); - this.connectionId = negotiateResponse.connectionId; - + const negotiateResponse = await this.getNegotiationResponse(headers); // the user tries to stop the the connection when it is being started if (this.connectionState === ConnectionState.Disconnected) { return; } - - if (this.connectionId) { - this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`; - this.transport = this.createTransport(this.options.transport, negotiateResponse.availableTransports, transferFormat); - } + await this.createTransport(this.options.transport, negotiateResponse, transferFormat, headers); } this.transport.onreceive = this.onreceive; this.transport.onclose = (e) => this.stopConnection(true, e); - await this.transport.connect(this.url, transferFormat, this); - // only change the state if we were connecting to not overwrite // the state if the connection is already marked as Disconnected this.changeState(ConnectionState.Connecting, ConnectionState.Connected); @@ -123,16 +113,51 @@ export class HttpConnection implements IConnection { } } - private createTransport(requestedTransport: TransportType | ITransport, availableTransports: IAvailableTransport[], requestedTransferFormat: TransferFormat): ITransport { + private async getNegotiationResponse(headers: any): Promise { + const response = await this.httpClient.post(this.resolveNegotiateUrl(this.baseUrl), { + content: "", + headers, + }); + return JSON.parse(response.content as string); + } + + private updateConnectionId(negotiateResponse: INegotiateResponse) { + this.connectionId = negotiateResponse.connectionId; + this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`; + } + + private async createTransport(requestedTransport: TransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise { + this.updateConnectionId(negotiateResponse); if (this.isITransport(requestedTransport)) { this.logger.log(LogLevel.Trace, "Connection was provided an instance of ITransport, using that directly."); - return requestedTransport; + this.transport = requestedTransport; + await this.transport.connect(this.url, requestedTransferFormat, this); + + // only change the state if we were connecting to not overwrite + // the state if the connection is already marked as Disconnected + this.changeState(ConnectionState.Connecting, ConnectionState.Connected); + return; } - for (const endpoint of availableTransports) { + const transports = negotiateResponse.availableTransports; + for (const endpoint of transports) { + this.connectionState = ConnectionState.Connecting; const transport = this.resolveTransport(endpoint, requestedTransport, requestedTransferFormat); if (typeof transport === "number") { - return this.constructTransport(transport); + this.transport = this.constructTransport(transport); + if (negotiateResponse.connectionId === null) { + negotiateResponse = await this.getNegotiationResponse(headers); + this.updateConnectionId(negotiateResponse); + } + try { + await this.transport.connect(this.url, requestedTransferFormat, this); + this.changeState(ConnectionState.Connecting, ConnectionState.Connected); + return; + } catch (ex) { + this.logger.log(LogLevel.Error, `Failed to start the transport' ${TransportType[transport]}:' transport'${ex}'`); + this.connectionState = ConnectionState.Disconnected; + negotiateResponse.connectionId = null; + } } }