diff --git a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index 93275f5336..41c54074a0 100644 --- a/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/src/SignalR/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -1123,6 +1123,34 @@ describe("hubConnection", () => { fail(e); } }); + + it("overwrites library headers with user headers", async (done) => { + const [name] = getUserAgentHeader(); + const headers = { [name]: "Custom Agent", "X-HEADER": "VALUE" }; + const hubConnection = getConnectionBuilder(t, TESTHUBENDPOINT_URL, { headers }) + .withHubProtocol(new JsonHubProtocol()) + .build(); + + try { + await hubConnection.start(); + + const customUserHeader = await hubConnection.invoke("GetHeader", "X-HEADER"); + const headerValue = await hubConnection.invoke("GetHeader", name); + + if ((t === HttpTransportType.ServerSentEvents || t === HttpTransportType.WebSockets) && !Platform.isNode) { + expect(headerValue).toBeNull(); + expect(customUserHeader).toBeNull(); + } else { + expect(headerValue).toEqual("Custom Agent"); + expect(customUserHeader).toEqual("VALUE"); + } + + await hubConnection.stop(); + done(); + } catch (e) { + fail(e); + } + }); }); function getJwtToken(url: string): Promise { diff --git a/src/SignalR/clients/ts/signalr/README.md b/src/SignalR/clients/ts/signalr/README.md index 921556949e..4339b739c3 100644 --- a/src/SignalR/clients/ts/signalr/README.md +++ b/src/SignalR/clients/ts/signalr/README.md @@ -36,7 +36,7 @@ To use the client in a NodeJS application, install the package to your `node_mod ### Example (Browser) -```JavaScript +```javascript let connection = new signalR.HubConnectionBuilder() .withUrl("/chat") .build(); @@ -51,8 +51,7 @@ connection.start() ### Example (WebWorker) - -```JavaScript +```javascript importScripts('signalr.js'); let connection = new signalR.HubConnectionBuilder() @@ -70,7 +69,7 @@ connection.start() ### Example (NodeJS) -```JavaScript +```javascript const signalR = require("@microsoft/signalr"); let connection = new signalR.HubConnectionBuilder() diff --git a/src/SignalR/clients/ts/signalr/src/HttpClient.ts b/src/SignalR/clients/ts/signalr/src/HttpClient.ts index 57614bb86b..8ae85722e9 100644 --- a/src/SignalR/clients/ts/signalr/src/HttpClient.ts +++ b/src/SignalR/clients/ts/signalr/src/HttpClient.ts @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { AbortSignal } from "./AbortController"; +import { MessageHeaders } from "./IHubProtocol"; /** Represents an HTTP request. */ export interface HttpRequest { @@ -15,7 +16,7 @@ export interface HttpRequest { content?: string | ArrayBuffer; /** An object describing headers to apply to the request. */ - headers?: { [key: string]: string }; + headers?: MessageHeaders; /** The XMLHttpRequestResponseType to apply to the request. */ responseType?: XMLHttpRequestResponseType; diff --git a/src/SignalR/clients/ts/signalr/src/HttpConnection.ts b/src/SignalR/clients/ts/signalr/src/HttpConnection.ts index e439ba159b..a53fa15b69 100644 --- a/src/SignalR/clients/ts/signalr/src/HttpConnection.ts +++ b/src/SignalR/clients/ts/signalr/src/HttpConnection.ts @@ -315,7 +315,7 @@ export class HttpConnection implements IConnection { try { const response = await this.httpClient.post(negotiateUrl, { content: "", - headers, + headers: { ...headers, ...this.options.headers }, withCredentials: this.options.withCredentials, }); @@ -403,14 +403,14 @@ export class HttpConnection implements IConnection { if (!this.options.WebSocket) { throw new Error("'WebSocket' is not supported in your environment."); } - return new WebSocketTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent || false, this.options.WebSocket); + return new WebSocketTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent || false, this.options.WebSocket, this.options.headers || {}); case HttpTransportType.ServerSentEvents: if (!this.options.EventSource) { throw new Error("'EventSource' is not supported in your environment."); } - return new ServerSentEventsTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent || false, this.options.EventSource, this.options.withCredentials!); + return new ServerSentEventsTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent || false, this.options.EventSource, this.options.withCredentials!, this.options.headers || {}); case HttpTransportType.LongPolling: - return new LongPollingTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent || false, this.options.withCredentials!); + return new LongPollingTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent || false, this.options.withCredentials!, this.options.headers || {}); default: throw new Error(`Unknown transport: ${transport}.`); } diff --git a/src/SignalR/clients/ts/signalr/src/IHttpConnectionOptions.ts b/src/SignalR/clients/ts/signalr/src/IHttpConnectionOptions.ts index 128872722b..181097ee39 100644 --- a/src/SignalR/clients/ts/signalr/src/IHttpConnectionOptions.ts +++ b/src/SignalR/clients/ts/signalr/src/IHttpConnectionOptions.ts @@ -2,12 +2,16 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { HttpClient } from "./HttpClient"; +import { MessageHeaders } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { HttpTransportType, ITransport } from "./ITransport"; import { EventSourceConstructor, WebSocketConstructor } from "./Polyfills"; /** Options provided to the 'withUrl' method on {@link @microsoft/signalr.HubConnectionBuilder} to configure options for the HTTP-based transports. */ export interface IHttpConnectionOptions { + /** {@link @microsoft/signalr.MessageHeaders} containing custom headers to be sent with every HTTP request. Note, setting headers in the browser will not work for WebSockets or the ServerSentEvents stream. */ + headers?: MessageHeaders; + /** An {@link @microsoft/signalr.HttpClient} that will be used to make HTTP requests. */ httpClient?: HttpClient; diff --git a/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts b/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts index f7a7da1784..615c1f3a54 100644 --- a/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts +++ b/src/SignalR/clients/ts/signalr/src/LongPollingTransport.ts @@ -4,6 +4,7 @@ import { AbortController } from "./AbortController"; import { HttpError, TimeoutError } from "./Errors"; import { HttpClient, HttpRequest } from "./HttpClient"; +import { MessageHeaders } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { ITransport, TransferFormat } from "./ITransport"; import { Arg, getDataDetail, getUserAgentHeader, sendMessage } from "./Utils"; @@ -17,6 +18,7 @@ export class LongPollingTransport implements ITransport { private readonly logMessageContent: boolean; private readonly withCredentials: boolean; private readonly pollAbort: AbortController; + private readonly headers: MessageHeaders; private url?: string; private running: boolean; @@ -31,13 +33,14 @@ export class LongPollingTransport implements ITransport { return this.pollAbort.aborted; } - constructor(httpClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined, logger: ILogger, logMessageContent: boolean, withCredentials: boolean) { + constructor(httpClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined, logger: ILogger, logMessageContent: boolean, withCredentials: boolean, headers: MessageHeaders) { this.httpClient = httpClient; this.accessTokenFactory = accessTokenFactory; this.logger = logger; this.pollAbort = new AbortController(); this.logMessageContent = logMessageContent; this.withCredentials = withCredentials; + this.headers = headers; this.running = false; @@ -60,9 +63,8 @@ export class LongPollingTransport implements ITransport { throw new Error("Binary protocols over XmlHttpRequest not implementing advanced features are not supported."); } - const headers = {}; const [name, value] = getUserAgentHeader(); - headers[name] = value; + const headers = { [name]: value, ...this.headers }; const pollOptions: HttpRequest = { abortSignal: this.pollAbort.signal, @@ -185,7 +187,7 @@ export class LongPollingTransport implements ITransport { if (!this.running) { return Promise.reject(new Error("Cannot send until the transport is connected")); } - return sendMessage(this.logger, "LongPolling", this.httpClient, this.url!, this.accessTokenFactory, data, this.logMessageContent, this.withCredentials); + return sendMessage(this.logger, "LongPolling", this.httpClient, this.url!, this.accessTokenFactory, data, this.logMessageContent, this.withCredentials, this.headers); } public async stop(): Promise { @@ -206,7 +208,7 @@ export class LongPollingTransport implements ITransport { headers[name] = value; const deleteOptions: HttpRequest = { - headers, + headers: { ...headers, ...this.headers }, withCredentials: this.withCredentials, }; const token = await this.getAccessToken(); diff --git a/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts b/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts index de4bf3e2b7..2bed7f9b29 100644 --- a/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts +++ b/src/SignalR/clients/ts/signalr/src/ServerSentEventsTransport.ts @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { HttpClient } from "./HttpClient"; +import { MessageHeaders } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { ITransport, TransferFormat } from "./ITransport"; import { EventSourceConstructor } from "./Polyfills"; @@ -17,18 +18,20 @@ export class ServerSentEventsTransport implements ITransport { private readonly eventSourceConstructor: EventSourceConstructor; private eventSource?: EventSource; private url?: string; + private headers: MessageHeaders; public onreceive: ((data: string | ArrayBuffer) => void) | null; public onclose: ((error?: Error) => void) | null; constructor(httpClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined, logger: ILogger, - logMessageContent: boolean, eventSourceConstructor: EventSourceConstructor, withCredentials: boolean) { + logMessageContent: boolean, eventSourceConstructor: EventSourceConstructor, withCredentials: boolean, headers: MessageHeaders) { this.httpClient = httpClient; this.accessTokenFactory = accessTokenFactory; this.logger = logger; this.logMessageContent = logMessageContent; this.withCredentials = withCredentials; this.eventSourceConstructor = eventSourceConstructor; + this.headers = headers; this.onreceive = null; this.onclose = null; @@ -64,13 +67,12 @@ export class ServerSentEventsTransport implements ITransport { } else { // Non-browser passes cookies via the dictionary const cookies = this.httpClient.getCookieString(url); - const headers = { - Cookie: cookies, - }; + const headers: MessageHeaders = {}; + headers.Cookie = cookies; const [name, value] = getUserAgentHeader(); headers[name] = value; - eventSource = new this.eventSourceConstructor(url, { withCredentials: this.withCredentials, headers } as EventSourceInit); + eventSource = new this.eventSourceConstructor(url, { withCredentials: this.withCredentials, headers: { ...headers, ...this.headers} } as EventSourceInit); } try { @@ -112,7 +114,7 @@ export class ServerSentEventsTransport implements ITransport { if (!this.eventSource) { return Promise.reject(new Error("Cannot send until the transport is connected")); } - return sendMessage(this.logger, "SSE", this.httpClient, this.url!, this.accessTokenFactory, data, this.logMessageContent, this.withCredentials); + return sendMessage(this.logger, "SSE", this.httpClient, this.url!, this.accessTokenFactory, data, this.logMessageContent, this.withCredentials, this.headers); } public stop(): Promise { diff --git a/src/SignalR/clients/ts/signalr/src/Utils.ts b/src/SignalR/clients/ts/signalr/src/Utils.ts index 3f7318cd68..dae2404807 100644 --- a/src/SignalR/clients/ts/signalr/src/Utils.ts +++ b/src/SignalR/clients/ts/signalr/src/Utils.ts @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { HttpClient } from "./HttpClient"; +import { MessageHeaders } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { NullLogger } from "./Loggers"; import { IStreamSubscriber, ISubscription } from "./Stream"; @@ -85,7 +86,7 @@ export function isArrayBuffer(val: any): val is ArrayBuffer { /** @private */ export async function sendMessage(logger: ILogger, transportName: string, httpClient: HttpClient, url: string, accessTokenFactory: (() => string | Promise) | undefined, - content: string | ArrayBuffer, logMessageContent: boolean, withCredentials: boolean): Promise { + content: string | ArrayBuffer, logMessageContent: boolean, withCredentials: boolean, defaultHeaders: MessageHeaders): Promise { let headers = {}; if (accessTokenFactory) { const token = await accessTokenFactory(); @@ -104,7 +105,7 @@ export async function sendMessage(logger: ILogger, transportName: string, httpCl const responseType = isArrayBuffer(content) ? "arraybuffer" : "text"; const response = await httpClient.post(url, { content, - headers, + headers: { ...headers, ...defaultHeaders}, responseType, withCredentials, }); diff --git a/src/SignalR/clients/ts/signalr/src/WebSocketTransport.ts b/src/SignalR/clients/ts/signalr/src/WebSocketTransport.ts index 4eb723f76e..5e3f2a6d79 100644 --- a/src/SignalR/clients/ts/signalr/src/WebSocketTransport.ts +++ b/src/SignalR/clients/ts/signalr/src/WebSocketTransport.ts @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. import { HttpClient } from "./HttpClient"; +import { MessageHeaders } from "./IHubProtocol"; import { ILogger, LogLevel } from "./ILogger"; import { ITransport, TransferFormat } from "./ITransport"; import { WebSocketConstructor } from "./Polyfills"; @@ -15,12 +16,13 @@ export class WebSocketTransport implements ITransport { private readonly webSocketConstructor: WebSocketConstructor; private readonly httpClient: HttpClient; private webSocket?: WebSocket; + private headers: MessageHeaders; public onreceive: ((data: string | ArrayBuffer) => void) | null; public onclose: ((error?: Error) => void) | null; constructor(httpClient: HttpClient, accessTokenFactory: (() => string | Promise) | undefined, logger: ILogger, - logMessageContent: boolean, webSocketConstructor: WebSocketConstructor) { + logMessageContent: boolean, webSocketConstructor: WebSocketConstructor, headers: MessageHeaders) { this.logger = logger; this.accessTokenFactory = accessTokenFactory; this.logMessageContent = logMessageContent; @@ -29,6 +31,7 @@ export class WebSocketTransport implements ITransport { this.onreceive = null; this.onclose = null; + this.headers = headers; } public async connect(url: string, transferFormat: TransferFormat): Promise { @@ -59,9 +62,9 @@ export class WebSocketTransport implements ITransport { headers[`Cookie`] = `${cookies}`; } - // Only pass cookies when in non-browser environments + // Only pass headers when in non-browser environments webSocket = new this.webSocketConstructor(url, undefined, { - headers, + headers: { ...headers, ...this.headers }, }); } diff --git a/src/SignalR/clients/ts/signalr/tests/HttpClient.test.ts b/src/SignalR/clients/ts/signalr/tests/HttpClient.test.ts index f7a3b71034..4b72ec59c4 100644 --- a/src/SignalR/clients/ts/signalr/tests/HttpClient.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HttpClient.test.ts @@ -44,9 +44,11 @@ describe("HttpClient", () => { }); await testClient.get("http://localhost", { + headers: { "X-HEADER": "VALUE"}, timeout: 42, }); expect(request.timeout).toEqual(42); + expect(request.headers).toEqual({ "X-HEADER": "VALUE"}); }); }); @@ -86,9 +88,11 @@ describe("HttpClient", () => { }); await testClient.post("http://localhost", { + headers: { "X-HEADER": "VALUE"}, timeout: 42, }); expect(request.timeout).toEqual(42); + expect(request.headers).toEqual({ "X-HEADER": "VALUE"}); }); }); }); diff --git a/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts b/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts index 67d889974e..94c0ba1818 100644 --- a/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/HttpConnection.test.ts @@ -1152,6 +1152,30 @@ describe("HttpConnection", () => { }, "Failed to start the connection: Error: nope"); }); + it("overwrites library headers with user headers on negotiate", async () => { + await VerifyLogger.run(async (logger) => { + const headers = { "User-Agent": "Custom Agent", "X-HEADER": "VALUE" }; + const options: IHttpConnectionOptions = { + ...commonOptions, + headers, + httpClient: new TestHttpClient() + .on("POST", (r) => { + expect(r.headers).toEqual(headers); + return new HttpResponse(200, "", "{\"error\":\"nope\"}"); + }), + logger, + }; + + const connection = new HttpConnection("http://tempuri.org", options); + try { + await connection.start(TransferFormat.Text); + } catch { + } finally { + await connection.stop(); + } + }, "Failed to start the connection: Error: nope"); + }); + it("logMessageContent displays correctly with binary data", async () => { await VerifyLogger.run(async (logger) => { const availableTransport = { transport: "LongPolling", transferFormats: ["Text", "Binary"] }; diff --git a/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts b/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts index 853b3527ec..f75debc5d2 100644 --- a/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/LongPollingTransport.test.ts @@ -24,7 +24,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(200); } else { // Turn 'onabort' into a promise. - const abort = new Promise((resolve, reject) => { + const abort = new Promise((resolve) => { if (r.abortSignal!.aborted) { resolve(); } else { @@ -40,7 +40,7 @@ describe("LongPollingTransport", () => { } }) .on("DELETE", () => new HttpResponse(202)); - const transport = new LongPollingTransport(client, undefined, logger, false, true); + const transport = new LongPollingTransport(client, undefined, logger, false, true, {}); await transport.connect("http://example.com", TransferFormat.Text); const stopPromise = transport.stop(); @@ -64,7 +64,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(204); } }); - const transport = new LongPollingTransport(client, undefined, logger, false, true); + const transport = new LongPollingTransport(client, undefined, logger, false, true, {}); const stopPromise = makeClosedPromise(transport); @@ -82,7 +82,7 @@ describe("LongPollingTransport", () => { const pollingPromiseSource = new PromiseSource(); const deleteSyncPoint = new SyncPoint(); const httpClient = new TestHttpClient() - .on("GET", async (r) => { + .on("GET", async () => { if (firstPoll) { firstPoll = false; return new HttpResponse(200); @@ -91,13 +91,13 @@ describe("LongPollingTransport", () => { return new HttpResponse(204); } }) - .on("DELETE", async (r) => { + .on("DELETE", async () => { deleteSent = true; await deleteSyncPoint.waitToContinue(); return new HttpResponse(202); }); - const transport = new LongPollingTransport(httpClient, undefined, logger, false, true); + const transport = new LongPollingTransport(httpClient, undefined, logger, false, true, {}); await transport.connect("http://tempuri.org", TransferFormat.Text); @@ -146,7 +146,7 @@ describe("LongPollingTransport", () => { return new HttpResponse(202); }); - const transport = new LongPollingTransport(httpClient, undefined, logger, false, true); + const transport = new LongPollingTransport(httpClient, undefined, logger, false, true, {}); await transport.connect("http://tempuri.org", TransferFormat.Text); @@ -165,6 +165,67 @@ describe("LongPollingTransport", () => { expect(secondPollUserAgent).toEqual(value); }); }); + + it("overwrites library headers with user headers", async () => { + await VerifyLogger.run(async (logger) => { + const headers = { "User-Agent": "Custom Agent", "X-HEADER": "VALUE" }; + let firstPoll = true; + let firstPollUserAgent = ""; + let firstUserHeader = ""; + let secondPollUserAgent = ""; + let secondUserHeader = ""; + let deleteUserAgent = ""; + let deleteUserHeader = ""; + const pollingPromiseSource = new PromiseSource(); + const httpClient = new TestHttpClient() + .on("POST", async (r) => { + expect(r.content).toEqual({ message: "hello" }); + expect(r.headers).toEqual(headers); + expect(r.method).toEqual("POST"); + expect(r.url).toEqual("http://tempuri.org"); + }) + .on("GET", async (r) => { + if (firstPoll) { + firstPoll = false; + firstPollUserAgent = r.headers!["User-Agent"]; + firstUserHeader = r.headers!["X-HEADER"]; + return new HttpResponse(200); + } else { + secondPollUserAgent = r.headers!["User-Agent"]; + secondUserHeader = r.headers!["X-HEADER"]; + await pollingPromiseSource.promise; + return new HttpResponse(204); + } + }) + .on("DELETE", async (r) => { + deleteUserAgent = r.headers!["User-Agent"]; + deleteUserHeader = r.headers!["X-HEADER"]; + return new HttpResponse(202); + }); + + const transport = new LongPollingTransport(httpClient, undefined, logger, false, true, headers); + + await transport.connect("http://tempuri.org", TransferFormat.Text); + + await transport.send({ message: "hello" }); + + // Begin stopping transport + const stopPromise = transport.stop(); + + // Allow polling to complete + pollingPromiseSource.resolve(); + + // Wait for stop to complete + await stopPromise; + + expect(firstPollUserAgent).toEqual("Custom Agent"); + expect(deleteUserAgent).toEqual("Custom Agent"); + expect(secondPollUserAgent).toEqual("Custom Agent"); + expect(firstUserHeader).toEqual("VALUE"); + expect(secondUserHeader).toEqual("VALUE"); + expect(deleteUserHeader).toEqual("VALUE"); + }); + }); }); function makeClosedPromise(transport: LongPollingTransport): Promise { diff --git a/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts b/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts index 9fb038fa0f..5cad66a621 100644 --- a/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/ServerSentEventsTransport.test.ts @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +import { MessageHeaders } from "../src/IHubProtocol"; import { TransferFormat } from "../src/ITransport"; import { HttpClient, HttpRequest } from "../src/HttpClient"; @@ -17,7 +18,7 @@ registerUnhandledRejectionHandler(); describe("ServerSentEventsTransport", () => { it("does not allow non-text formats", async () => { await VerifyLogger.run(async (logger) => { - const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true); + const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true, {}); await expect(sse.connect("", TransferFormat.Binary)) .rejects @@ -27,7 +28,7 @@ describe("ServerSentEventsTransport", () => { it("connect waits for EventSource to be connected", async () => { await VerifyLogger.run(async (logger) => { - const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true); + const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true, {}); let connectComplete: boolean = false; const connectPromise = (async () => { @@ -48,7 +49,7 @@ describe("ServerSentEventsTransport", () => { it("connect failure does not call onclose handler", async () => { await VerifyLogger.run(async (logger) => { - const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true); + const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true, {}); let closeCalled = false; sse.onclose = () => closeCalled = true; @@ -169,7 +170,7 @@ describe("ServerSentEventsTransport", () => { it("send throws if not connected", async () => { await VerifyLogger.run(async (logger) => { - const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true); + const sse = new ServerSentEventsTransport(new TestHttpClient(), undefined, logger, true, TestEventSource, true, {}); await expect(sse.send("")) .rejects @@ -221,10 +222,31 @@ describe("ServerSentEventsTransport", () => { expect(request!.url).toBe("http://example.com"); }); }); + + it("overwrites library headers with user headers", async () => { + await VerifyLogger.run(async (logger) => { + let request: HttpRequest; + const httpClient = new TestHttpClient().on((r) => { + request = r; + return ""; + }); + + const headers = { "User-Agent": "Custom Agent", "X-HEADER": "VALUE" }; + const sse = await createAndStartSSE(logger, "http://example.com", undefined, httpClient, headers); + + expect((TestEventSource.eventSource.eventSourceInitDict as any).headers["User-Agent"]).toEqual("Custom Agent"); + expect((TestEventSource.eventSource.eventSourceInitDict as any).headers["X-HEADER"]).toEqual("VALUE"); + await sse.send(""); + + expect(request!.headers!["User-Agent"]).toEqual("Custom Agent"); + expect(request!.headers!["X-HEADER"]).toEqual("VALUE"); + expect(request!.url).toBe("http://example.com"); + }); + }); }); -async function createAndStartSSE(logger: ILogger, url?: string, accessTokenFactory?: (() => string | Promise), httpClient?: HttpClient): Promise { - const sse = new ServerSentEventsTransport(httpClient || new TestHttpClient(), accessTokenFactory, logger, true, TestEventSource, true); +async function createAndStartSSE(logger: ILogger, url?: string, accessTokenFactory?: (() => string | Promise), httpClient?: HttpClient, headers?: MessageHeaders): Promise { + const sse = new ServerSentEventsTransport(httpClient || new TestHttpClient(), accessTokenFactory, logger, true, TestEventSource, true, headers || {}); const connectPromise = sse.connect(url || "http://example.com", TransferFormat.Text); await TestEventSource.eventSource.openSet; diff --git a/src/SignalR/clients/ts/signalr/tests/WebSocketTransport.test.ts b/src/SignalR/clients/ts/signalr/tests/WebSocketTransport.test.ts index cb4455a517..d3984cd673 100644 --- a/src/SignalR/clients/ts/signalr/tests/WebSocketTransport.test.ts +++ b/src/SignalR/clients/ts/signalr/tests/WebSocketTransport.test.ts @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +import { MessageHeaders } from "../src/IHubProtocol"; import { ILogger } from "../src/ILogger"; import { TransferFormat } from "../src/ITransport"; import { getUserAgentHeader } from "../src/Utils"; @@ -24,7 +25,7 @@ describe("WebSocketTransport", () => { it("connect waits for WebSocket to be connected", async () => { await VerifyLogger.run(async (logger) => { - const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket); + const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket, {}); let connectComplete: boolean = false; const connectPromise = (async () => { @@ -46,7 +47,7 @@ describe("WebSocketTransport", () => { it("connect fails if there is error during connect", async () => { await VerifyLogger.run(async (logger) => { (global as any).ErrorEvent = TestErrorEvent; - const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket); + const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket, {}); let connectComplete: boolean = false; const connectPromise = (async () => { @@ -70,7 +71,7 @@ describe("WebSocketTransport", () => { it("connect failure does not call onclose handler", async () => { await VerifyLogger.run(async (logger) => { (global as any).ErrorEvent = TestErrorEvent; - const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket); + const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket, {}); let closeCalled = false; webSocket.onclose = () => closeCalled = true; @@ -257,6 +258,33 @@ describe("WebSocketTransport", () => { }); }); + it("overwrites library headers with user headers", async () => { + await VerifyLogger.run(async (logger) => { + (global as any).ErrorEvent = TestEvent; + const headers = { "User-Agent": "Custom Agent", "X-HEADER": "VALUE" }; + const webSocket = await createAndStartWebSocket(logger, undefined, undefined, undefined, headers); + + let closeCalled: boolean = false; + let error: Error; + webSocket.onclose = (e) => { + closeCalled = true; + error = e!; + }; + + expect(TestWebSocket.webSocket.options!.headers[`User-Agent`]).toEqual("Custom Agent"); + expect(TestWebSocket.webSocket.options!.headers[`X-HEADER`]).toEqual("VALUE"); + + await webSocket.stop(); + + expect(closeCalled).toBe(true); + expect(error!).toBeUndefined(); + + await expect(webSocket.send("")) + .rejects + .toBe("WebSocket is not in the OPEN state"); + }); + }); + it("is closed from 'onreceive' callback throwing", async () => { await VerifyLogger.run(async (logger) => { (global as any).ErrorEvent = TestEvent; @@ -270,7 +298,7 @@ describe("WebSocketTransport", () => { }; const receiveError = new Error("callback error"); - webSocket.onreceive = (data) => { + webSocket.onreceive = () => { throw receiveError; }; @@ -290,7 +318,7 @@ describe("WebSocketTransport", () => { it("does not run onclose callback if Transport does not fully connect and exits", async () => { await VerifyLogger.run(async (logger) => { (global as any).ErrorEvent = TestErrorEvent; - const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket); + const webSocket = new WebSocketTransport(new TestHttpClient(), undefined, logger, true, TestWebSocket, {}); const connectPromise = webSocket.connect("http://example.com", TransferFormat.Text); @@ -318,8 +346,8 @@ describe("WebSocketTransport", () => { }); }); -async function createAndStartWebSocket(logger: ILogger, url?: string, accessTokenFactory?: (() => string | Promise), format?: TransferFormat): Promise { - const webSocket = new WebSocketTransport(new TestHttpClient(), accessTokenFactory, logger, true, TestWebSocket); +async function createAndStartWebSocket(logger: ILogger, url?: string, accessTokenFactory?: (() => string | Promise), format?: TransferFormat, headers?: MessageHeaders): Promise { + const webSocket = new WebSocketTransport(new TestHttpClient(), accessTokenFactory, logger, true, TestWebSocket, headers || {}); const connectPromise = webSocket.connect(url || "http://example.com", format || TransferFormat.Text);