diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj
index 69db719b11..646477a011 100644
--- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj
+++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks.csproj
@@ -2,7 +2,7 @@
Exe
- netcoreapp2.0
+ netcoreapp2.1
diff --git a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts
index 0ae87d7de0..faff3b1fab 100644
--- a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts
+++ b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts
@@ -544,6 +544,33 @@ describe("hubConnection", () => {
}
});
+ it("transport falls back from WebSockets to SSE or LongPolling", async (done) => {
+ // Replace Websockets with a function that just
+ // throws to force fallback.
+ const oldWebSocket = (window as any).WebSocket;
+ (window as any).WebSocket = () => {
+ throw new Error("Kick rocks");
+ };
+
+ const hubConnection = new HubConnection(TESTHUBENDPOINT_URL, {
+ logger: LogLevel.Trace,
+ protocol: new JsonHubProtocol(),
+ });
+
+ try {
+ await hubConnection.start();
+
+ // Make sure that we connect with SSE or LongPolling after Websockets fail
+ const transportName = await hubConnection.invoke("GetActiveTransportName");
+ expect(transportName === "ServerSentEvents" || transportName === "LongPolling").toBe(true);
+ } catch (e) {
+ fail(e);
+ } finally {
+ (window as any).WebSocket = oldWebSocket;
+ done();
+ }
+ });
+
function getJwtToken(url): Promise {
return new Promise((resolve, reject) => {
const xhr = new XMLHttpRequest();
diff --git a/clients/ts/signalr/spec/HttpConnection.spec.ts b/clients/ts/signalr/spec/HttpConnection.spec.ts
index 1cd521798f..3d06eaf8de 100644
--- a/clients/ts/signalr/spec/HttpConnection.spec.ts
+++ b/clients/ts/signalr/spec/HttpConnection.spec.ts
@@ -134,6 +134,24 @@ describe("HttpConnection", () => {
done();
});
+ it("start throws after all transports fail", async (done) => {
+ const options: IHttpConnectionOptions = {
+ httpClient: new TestHttpClient()
+ .on("POST", (r) => ({ connectionId: "42", availableTransports: [] }))
+ .on("GET", (r) => { throw new Error("fail"); }),
+ } as IHttpConnectionOptions;
+
+ const connection = new HttpConnection("http://tempuri.org?q=myData", options);
+ try {
+ await connection.start(TransferFormat.Text);
+ fail();
+ done();
+ } catch (e) {
+ expect(e.message).toBe("Unable to initialize any of the available transports.");
+ }
+ done();
+ });
+
it("preserves user's query string", async (done) => {
let connectUrl: string;
const fakeTransport: ITransport = {
diff --git a/clients/ts/signalr/src/HttpConnection.ts b/clients/ts/signalr/src/HttpConnection.ts
index 614e51b416..febd61c21c 100644
--- a/clients/ts/signalr/src/HttpConnection.ts
+++ b/clients/ts/signalr/src/HttpConnection.ts
@@ -79,39 +79,29 @@ export class HttpConnection implements IConnection {
// No need to add a connection ID in this case
this.url = this.baseUrl;
this.transport = this.constructTransport(TransportType.WebSockets);
+ // We should just call connect directly in this case.
+ // No fallback or negotiate in this case.
+ await this.transport.connect(this.url, transferFormat, this);
} else {
- let headers;
const token = this.options.accessTokenFactory();
+ let headers;
if (token) {
headers = {
["Authorization"]: `Bearer ${token}`,
};
}
- const negotiatePayload = await this.httpClient.post(this.resolveNegotiateUrl(this.baseUrl), {
- content: "",
- headers,
- });
-
- const negotiateResponse: INegotiateResponse = JSON.parse(negotiatePayload.content as string);
- this.connectionId = negotiateResponse.connectionId;
-
+ const negotiateResponse = await this.getNegotiationResponse(headers);
// the user tries to stop the the connection when it is being started
if (this.connectionState === ConnectionState.Disconnected) {
return;
}
-
- if (this.connectionId) {
- this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`;
- this.transport = this.createTransport(this.options.transport, negotiateResponse.availableTransports, transferFormat);
- }
+ await this.createTransport(this.options.transport, negotiateResponse, transferFormat, headers);
}
this.transport.onreceive = this.onreceive;
this.transport.onclose = (e) => this.stopConnection(true, e);
- await this.transport.connect(this.url, transferFormat, this);
-
// only change the state if we were connecting to not overwrite
// the state if the connection is already marked as Disconnected
this.changeState(ConnectionState.Connecting, ConnectionState.Connected);
@@ -123,16 +113,51 @@ export class HttpConnection implements IConnection {
}
}
- private createTransport(requestedTransport: TransportType | ITransport, availableTransports: IAvailableTransport[], requestedTransferFormat: TransferFormat): ITransport {
+ private async getNegotiationResponse(headers: any): Promise {
+ const response = await this.httpClient.post(this.resolveNegotiateUrl(this.baseUrl), {
+ content: "",
+ headers,
+ });
+ return JSON.parse(response.content as string);
+ }
+
+ private updateConnectionId(negotiateResponse: INegotiateResponse) {
+ this.connectionId = negotiateResponse.connectionId;
+ this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`;
+ }
+
+ private async createTransport(requestedTransport: TransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise {
+ this.updateConnectionId(negotiateResponse);
if (this.isITransport(requestedTransport)) {
this.logger.log(LogLevel.Trace, "Connection was provided an instance of ITransport, using that directly.");
- return requestedTransport;
+ this.transport = requestedTransport;
+ await this.transport.connect(this.url, requestedTransferFormat, this);
+
+ // only change the state if we were connecting to not overwrite
+ // the state if the connection is already marked as Disconnected
+ this.changeState(ConnectionState.Connecting, ConnectionState.Connected);
+ return;
}
- for (const endpoint of availableTransports) {
+ const transports = negotiateResponse.availableTransports;
+ for (const endpoint of transports) {
+ this.connectionState = ConnectionState.Connecting;
const transport = this.resolveTransport(endpoint, requestedTransport, requestedTransferFormat);
if (typeof transport === "number") {
- return this.constructTransport(transport);
+ this.transport = this.constructTransport(transport);
+ if (negotiateResponse.connectionId === null) {
+ negotiateResponse = await this.getNegotiationResponse(headers);
+ this.updateConnectionId(negotiateResponse);
+ }
+ try {
+ await this.transport.connect(this.url, requestedTransferFormat, this);
+ this.changeState(ConnectionState.Connecting, ConnectionState.Connected);
+ return;
+ } catch (ex) {
+ this.logger.log(LogLevel.Error, `Failed to start the transport' ${TransportType[transport]}:' transport'${ex}'`);
+ this.connectionState = ConnectionState.Disconnected;
+ negotiateResponse.connectionId = null;
+ }
}
}
diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs
index bc6168be43..8efbee6a26 100644
--- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HandshakeProtocol.cs
@@ -58,7 +58,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
private static JsonTextReader CreateJsonTextReader(ReadOnlyMemory payload)
{
- var textReader = new Utf8BufferTextReader(payload);
+ var textReader = new Utf8BufferTextReader();
+ textReader.SetBuffer(payload);
var reader = new JsonTextReader(textReader);
reader.ArrayPool = JsonArrayPool.Shared;
diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs
index 14ba3ea78d..c1ebe09cf8 100644
--- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/HubMessage.cs
@@ -12,9 +12,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
}
- // Initialize with capacity 2 for the 2 built in protocols
private object _lock = new object();
- private readonly List _serializedMessages = new List(2);
+ private List _serializedMessages;
public byte[] WriteMessage(IHubProtocol protocol)
{
@@ -25,7 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
lock (_lock)
{
- for (var i = 0; i < _serializedMessages.Count; i++)
+ for (var i = 0; i < _serializedMessages?.Count; i++)
{
if (_serializedMessages[i].Protocol.Equals(protocol))
{
@@ -35,6 +34,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
var bytes = protocol.WriteToArray(this);
+ if (_serializedMessages == null)
+ {
+ // Initialize with capacity 2 for the 2 built in protocols
+ _serializedMessages = new List(2);
+ }
+
// We don't want to balloon memory if someone writes a poor IHubProtocolResolver
// So we cap how many caches we store and worst case just serialize the message for every connection
if (_serializedMessages.Count < 10)
diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs
index df6a664f2f..80d7e56440 100644
--- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs
+++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs
@@ -58,11 +58,19 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
while (TextMessageParser.TryParseMessage(ref input, out var payload))
{
- var textReader = new Utf8BufferTextReader(payload);
- var message = ParseMessage(textReader, binder);
- if (message != null)
+ var textReader = Utf8BufferTextReader.Get(payload);
+
+ try
{
- messages.Add(message);
+ var message = ParseMessage(textReader, binder);
+ if (message != null)
+ {
+ messages.Add(message);
+ }
+ }
+ finally
+ {
+ Utf8BufferTextReader.Return(textReader);
}
}
@@ -103,6 +111,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
using (var reader = new JsonTextReader(textReader))
{
reader.ArrayPool = JsonArrayPool.Shared;
+ reader.CloseInput = false;
JsonUtils.CheckRead(reader);
@@ -559,7 +568,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
private object[] BindArguments(JsonTextReader reader, IReadOnlyList paramTypes)
{
- var arguments = new object[paramTypes.Count];
+ object[] arguments = null;
var paramIndex = 0;
var argumentsCount = 0;
@@ -572,7 +581,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
throw new InvalidDataException($"Invocation provides {argumentsCount} argument(s) but target expects {paramTypes.Count}.");
}
- return arguments;
+ return arguments ?? Array.Empty