diff --git a/build/repo.targets b/build/repo.targets index 51077aa751..150d34afcc 100644 --- a/build/repo.targets +++ b/build/repo.targets @@ -6,6 +6,6 @@ - + diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Common.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Common.ts new file mode 100644 index 0000000000..e40b74eda6 --- /dev/null +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Common.ts @@ -0,0 +1,9 @@ +import { ITransport, TransportType } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports" + +export function eachTransport(action: (transport: TransportType) => void) { + let transportTypes = [ + TransportType.WebSockets, + TransportType.ServerSentEvents, + TransportType.LongPolling ]; + transportTypes.forEach(t => action(t)); +}; \ No newline at end of file diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Connection.spec.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Connection.spec.ts index e8348aa0cb..c2c26d7292 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Connection.spec.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/Connection.spec.ts @@ -2,7 +2,8 @@ import { IHttpClient } from "../Microsoft.AspNetCore.SignalR.Client.TS/HttpClien import { Connection } from "../Microsoft.AspNetCore.SignalR.Client.TS/Connection" import { ISignalROptions } from "../Microsoft.AspNetCore.SignalR.Client.TS/ISignalROptions" import { DataReceived, TransportClosed } from "../Microsoft.AspNetCore.SignalR.Client.TS/Common" -import { ITransport } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports" +import { ITransport, TransportType } from "../Microsoft.AspNetCore.SignalR.Client.TS/Transports" +import { eachTransport } from "./Common"; describe("Connection", () => { @@ -102,7 +103,7 @@ describe("Connection", () => { httpClient: { options(url: string): Promise { connection.stop(); - return Promise.resolve(""); + return Promise.resolve("{}"); }, get(url: string): Promise { connection.stop(); @@ -133,7 +134,7 @@ describe("Connection", () => { let options: ISignalROptions = { httpClient: { options(url: string): Promise { - return Promise.resolve("42"); + return Promise.resolve("{ \"connectionId\": \"42\" }"); }, get(url: string): Promise { return Promise.resolve(""); @@ -168,4 +169,54 @@ describe("Connection", () => { expect(connectUrl).toBe("http://tempuri.org?q=myData&id=42"); done(); }); + + eachTransport((requestedTransport: TransportType) => { + it(`Connection cannot be started if requested ${TransportType[requestedTransport]} transport not available on server`, async done => { + let options: ISignalROptions = { + httpClient: { + options(url: string): Promise { + return Promise.resolve("{ \"connectionId\": \"42\", \"availableTransports\": [] }"); + }, + get(url: string): Promise { + return Promise.resolve(""); + } + } + } as ISignalROptions; + + var connection = new Connection("http://tempuri.org", options); + try { + await connection.start(requestedTransport); + fail(); + done(); + } + catch (e) { + expect(e.message).toBe("No available transports found."); + done(); + } + }); + }); + + it(`Connection cannot be started if no transport available on server and no transport requested`, async done => { + let options: ISignalROptions = { + httpClient: { + options(url: string): Promise { + return Promise.resolve("{ \"connectionId\": \"42\", \"availableTransports\": [] }"); + }, + get(url: string): Promise { + return Promise.resolve(""); + } + } + } as ISignalROptions; + + var connection = new Connection("http://tempuri.org", options); + try { + await connection.start(); + fail(); + done(); + } + catch (e) { + expect(e.message).toBe("No available transports found."); + done(); + } + }); }); diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Connection.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Connection.ts index e4d228131a..b502aca087 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Connection.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/Connection.ts @@ -11,6 +11,11 @@ enum ConnectionState { Disconnected } +interface INegotiateResponse { + connectionId: string + availableTransports: string[] +} + export class Connection implements IConnection { private connectionState: ConnectionState; private url: string; @@ -25,7 +30,7 @@ export class Connection implements IConnection { this.connectionState = ConnectionState.Initial; } - async start(transport: TransportType | ITransport = TransportType.WebSockets): Promise { + async start(transport?: TransportType | ITransport): Promise { if (this.connectionState != ConnectionState.Initial) { return Promise.reject(new Error("Cannot start a connection that is not in the 'Initial' state.")); } @@ -38,7 +43,9 @@ export class Connection implements IConnection { private async startInternal(transportType: TransportType | ITransport): Promise { try { - this.connectionId = await this.httpClient.options(this.url); + let negotiatePayload = await this.httpClient.options(this.url); + let negotiateResponse: INegotiateResponse = JSON.parse(negotiatePayload); + this.connectionId = negotiateResponse.connectionId; // the user tries to stop the the connection when it is being started if (this.connectionState == ConnectionState.Disconnected) { @@ -47,7 +54,7 @@ export class Connection implements IConnection { this.url += (this.url.indexOf("?") == -1 ? "?" : "&") + `id=${this.connectionId}`; - this.transport = this.createTransport(transportType); + this.transport = this.createTransport(transportType, negotiateResponse.availableTransports); this.transport.onDataReceived = this.onDataReceived; this.transport.onClosed = e => this.stopConnection(true, e); await this.transport.connect(this.url); @@ -56,21 +63,24 @@ export class Connection implements IConnection { this.changeState(ConnectionState.Connecting, ConnectionState.Connected); } catch (e) { - console.log("Failed to start the connection. " + e) + console.log("Failed to start the connection. " + e); this.connectionState = ConnectionState.Disconnected; this.transport = null; throw e; }; } - private createTransport(transport: TransportType | ITransport): ITransport { - if (transport === TransportType.WebSockets) { + private createTransport(transport: TransportType | ITransport, availableTransports: string[]): ITransport { + if (!transport && availableTransports.length > 0) { + transport = TransportType[availableTransports[0]]; + } + if (transport === TransportType.WebSockets && availableTransports.indexOf(TransportType[transport]) >= 0) { return new WebSocketTransport(); } - if (transport === TransportType.ServerSentEvents) { + if (transport === TransportType.ServerSentEvents && availableTransports.indexOf(TransportType[transport]) >= 0) { return new ServerSentEventsTransport(this.httpClient); } - if (transport === TransportType.LongPolling) { + if (transport === TransportType.LongPolling && availableTransports.indexOf(TransportType[transport]) >= 0) { return new LongPollingTransport(this.httpClient); } @@ -78,11 +88,11 @@ export class Connection implements IConnection { return transport; } - throw new Error("No valid transports requested."); + throw new Error("No available transports found."); } private isITransport(transport: any): transport is ITransport { - return "connect" in transport; + return typeof(transport) === "object" && "connect" in transport; } private changeState(from: ConnectionState, to: ConnectionState): Boolean { diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/connectionTests.js b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/connectionTests.js index 88d0f419bf..13a1a21011 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/connectionTests.js +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/connectionTests.js @@ -1,4 +1,31 @@ describe('connection', () => { + it(`can connect to the server without specifying transport explicitly`, done => { + const message = "Hello World!"; + let connection = new signalR.Connection(ECHOENDPOINT_URL); + + let received = ""; + connection.onDataReceived = data => { + received += data; + if (data == message) { + connection.stop(); + } + } + + connection.onClosed = error => { + expect(error).toBeUndefined(); + done(); + } + + connection.start() + .then(() => { + connection.send(message); + }) + .catch(e => { + fail(); + done(); + }); + }); + eachTransport(transportType => { it(`over ${signalR.TransportType[transportType]} can send and receive messages`, done => { const message = "Hello World!"; diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js index eea1907676..6e8df4cbf2 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js +++ b/client-ts/Microsoft.AspNetCore.SignalR.Test.Server/wwwroot/js/hubConnectionTests.js @@ -47,11 +47,11 @@ describe('hubConnection', () => { }, error: (err) => { fail(err); - done(); + hubConnection.stop(); }, complete: () => { expect(received).toEqual(["a", "b", "c"]); - done(); + hubConnection.stop(); } }); }) diff --git a/specs/TransportProtocols.md b/specs/TransportProtocols.md index fdebf8f54d..12bb934513 100644 --- a/specs/TransportProtocols.md +++ b/specs/TransportProtocols.md @@ -16,6 +16,17 @@ Throughout this document, the term `[endpoint-base]` is used to refer to the rou **NOTE on errors:** In all error cases, by default, the detailed exception message is **never** provided; a short description string may be provided. However, an application developer may elect to allow detailed exception messages to be emitted, which should only be used in the `Development` environment. Unexpected errors are communicated by HTTP `500 Server Error` status codes or WebSockets `1008 Policy Violation` close frames; in these cases the connection should be considered to be terminated. +## `OPTIONS [endpoint-base]` request + +The `OPTIONS [endpoint-base]` request is used to establish connection between the client and the server. The response to the `OPTIONS [endpoint-base]` request contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server. The content type of the response is `application/json`. The following is a sample response to the `OPTIONS [endpoint-base]` request + +``` +{ + "connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d", + "availableTransports":["WebSockets","ServerSentEvents","LongPolling"] +} +``` + ## WebSockets (Full Duplex) The WebSockets transport is unique in that it is full duplex, and a persistent connection that can be established in a single operation. As a result, the client is not required to use the `OPTIONS [endpoint-base]` request to establish a connection in advance. It also includes all the necessary metadata in it's own frame metadata. diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 3aeb9f1313..cf46cc503d 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -4,9 +4,6 @@ using System; using System.Buffers; using System.Collections.Generic; -using System.IO; -using System.IO.Pipelines; -using System.IO.Pipelines.Text.Primitives; using System.Linq; using System.Reflection; using System.Text; diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs index 52e2aea69b..7fb903393b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/Connection.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; using System.Net.Http; using System.Threading; using System.Threading.Tasks; @@ -10,6 +11,7 @@ using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; +using Newtonsoft.Json; namespace Microsoft.AspNetCore.Sockets.Client { @@ -103,7 +105,7 @@ namespace Microsoft.AspNetCore.Sockets.Client try { - var connectUrl = await GetConnectUrl(Url, httpClient, _logger); + var negotiationResponse = await Negotiate(Url, httpClient, _logger); // Connection is being stopped while start was in progress if (_connectionState == ConnectionState.Disconnected) @@ -112,9 +114,9 @@ namespace Microsoft.AspNetCore.Sockets.Client return; } - // TODO: Available server transports should be sent by the server in the negotiation response - _transport = transportFactory.CreateTransport(TransportType.All); + _transport = transportFactory.CreateTransport(GetAvailableServerTransports(negotiationResponse)); + var connectUrl = CreateConnectUrl(Url, negotiationResponse); _logger.LogDebug("Starting transport '{0}' with Url: {1}", _transport.GetType().Name, connectUrl); await StartTransport(connectUrl); } @@ -179,32 +181,75 @@ namespace Microsoft.AspNetCore.Sockets.Client } } - private static async Task GetConnectUrl(Uri url, HttpClient httpClient, ILogger logger) - { - var connectionId = await GetConnectionId(url, httpClient, logger); - return Utils.AppendQueryString(url, "id=" + connectionId); - } - - private static async Task GetConnectionId(Uri url, HttpClient httpClient, ILogger logger) + private async static Task Negotiate(Uri url, HttpClient httpClient, ILogger logger) { try { // Get a connection ID from the server logger.LogDebug("Establishing Connection at: {0}", url); - var request = new HttpRequestMessage(HttpMethod.Options, url); - var response = await httpClient.SendAsync(request); - response.EnsureSuccessStatusCode(); - var connectionId = await response.Content.ReadAsStringAsync(); - logger.LogDebug("Connection Id: {0}", connectionId); - return connectionId; + using (var request = new HttpRequestMessage(HttpMethod.Options, url)) + using (var response = await httpClient.SendAsync(request)) + { + response.EnsureSuccessStatusCode(); + return await ParseNegotiateResponse(response, logger); + } } catch (Exception ex) { - logger.LogError("Failed to start connection. Error getting connection id from '{0}': {1}", url, ex); + logger.LogError("Failed to start connection. Error getting negotiation response from '{0}': {1}", url, ex); throw; } } + private static async Task ParseNegotiateResponse(HttpResponseMessage response, ILogger logger) + { + NegotiationResponse negotiationResponse; + using (var reader = new JsonTextReader(new StreamReader(await response.Content.ReadAsStreamAsync()))) + { + try + { + negotiationResponse = new JsonSerializer().Deserialize(reader); + } + catch (Exception ex) + { + throw new FormatException("Invalid negotiation response received.", ex); + } + } + + if (negotiationResponse == null) + { + throw new FormatException("Invalid negotiation response received."); + } + + return negotiationResponse; + } + + private TransportType GetAvailableServerTransports(NegotiationResponse negotiationResponse) + { + if (negotiationResponse.AvailableTransports == null) + { + throw new FormatException("No transports returned in negotiation response."); + } + + var availableServerTransports = (TransportType)0; + foreach (var t in negotiationResponse.AvailableTransports) + { + availableServerTransports |= t; + } + + return availableServerTransports; + } + + private static Uri CreateConnectUrl(Uri url, NegotiationResponse negotiationResponse) + { + if (string.IsNullOrWhiteSpace(negotiationResponse.ConnectionId)) + { + throw new FormatException("Invalid connection id returned in negotiation response."); + } + + return Utils.AppendQueryString(url, "id=" + negotiationResponse.ConnectionId); + } + private async Task StartTransport(Uri connectUrl) { var applicationToTransport = Channel.CreateUnbounded(); @@ -352,5 +397,11 @@ namespace Microsoft.AspNetCore.Sockets.Client public const int Connected = 2; public const int Disconnected = 3; } + + private class NegotiationResponse + { + public string ConnectionId { get; set; } + public TransportType[] AvailableTransports { get; set; } + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj index f0812bd108..b58e94a281 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj @@ -22,6 +22,7 @@ + diff --git a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs index 01a8b166f0..a1e3381371 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/HttpConnectionDispatcher.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Transports; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; +using Newtonsoft.Json; namespace Microsoft.AspNetCore.Sockets { @@ -314,17 +315,46 @@ namespace Microsoft.AspNetCore.Sockets // Set the allowed headers for this resource context.Response.Headers.AppendCommaSeparatedValues("Allow", "GET", "POST", "OPTIONS"); - context.Response.ContentType = "text/plain"; + context.Response.ContentType = "application/json"; // Establish the connection var connection = _manager.CreateConnection(); // Get the bytes for the connection id - var connectionIdBuffer = Encoding.UTF8.GetBytes(connection.ConnectionId); + var negotiateResponseBuffer = Encoding.UTF8.GetBytes(GetNegotiatePayload(connection.ConnectionId, options)); // Write it out to the response with the right content length - context.Response.ContentLength = connectionIdBuffer.Length; - return context.Response.Body.WriteAsync(connectionIdBuffer, 0, connectionIdBuffer.Length); + context.Response.ContentLength = negotiateResponseBuffer.Length; + return context.Response.Body.WriteAsync(negotiateResponseBuffer, 0, negotiateResponseBuffer.Length); + } + + private static string GetNegotiatePayload(string connectionId, HttpSocketOptions options) + { + var sb = new StringBuilder(); + using (var jsonWriter = new JsonTextWriter(new StringWriter(sb))) + { + jsonWriter.WriteStartObject(); + jsonWriter.WritePropertyName("connectionId"); + jsonWriter.WriteValue(connectionId); + jsonWriter.WritePropertyName("availableTransports"); + jsonWriter.WriteStartArray(); + if ((options.Transports & TransportType.WebSockets) != 0) + { + jsonWriter.WriteValue(nameof(TransportType.WebSockets)); + } + if ((options.Transports & TransportType.ServerSentEvents) != 0) + { + jsonWriter.WriteValue(nameof(TransportType.ServerSentEvents)); + } + if ((options.Transports & TransportType.LongPolling) != 0) + { + jsonWriter.WriteValue(nameof(TransportType.LongPolling)); + } + jsonWriter.WriteEndArray(); + jsonWriter.WriteEndObject(); + } + + return sb.ToString(); } private async Task ProcessSend(HttpContext context) diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj index e60b0807fd..f3ad26b1df 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Http/Microsoft.AspNetCore.Sockets.Http.csproj @@ -25,6 +25,7 @@ + diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs index 0577fb320e..149504ba0e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs @@ -51,7 +51,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -81,7 +84,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -132,7 +137,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests // allow DisposeAsync to continue once we know we are past the connection state check allowDisposeTcs.SetResult(null); await releaseNegotiateTcs.Task; - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -174,7 +181,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -199,7 +208,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -230,7 +241,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); var mockTransport = new Mock(); @@ -267,7 +280,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return ResponseUtils.CreateResponse(HttpStatusCode.OK); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -294,11 +309,12 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - if (request.Method == HttpMethod.Get) - { - return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) }; - } - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + + return request.Method == HttpMethod.Get + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -328,7 +344,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); var mockTransport = new Mock(); @@ -370,7 +388,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); var mockTransport = new Mock(); @@ -438,7 +458,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return ResponseUtils.CreateResponse(HttpStatusCode.OK); + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -477,7 +500,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests { sendTcs.SetResult(await request.Content.ReadAsByteArrayAsync()); } - return ResponseUtils.CreateResponse(HttpStatusCode.OK); + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -523,7 +549,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests { content = "T2:T:42;"; } - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(content) }; + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -549,11 +578,12 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - if (request.Method == HttpMethod.Post) - { - return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError); - } - return ResponseUtils.CreateResponse(HttpStatusCode.OK); + + return request.Method == HttpMethod.Post + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -585,7 +615,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests { content = "42"; } - return ResponseUtils.CreateResponse(HttpStatusCode.OK, content); + + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK, content); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -627,11 +660,12 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - if (request.Method == HttpMethod.Get) - { - return new HttpResponseMessage(HttpStatusCode.InternalServerError) { Content = new StringContent(string.Empty) }; - } - return new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(string.Empty) }; + + return request.Method == HttpMethod.Get + ? ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError) + : request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -658,5 +692,100 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } } } + + [Theory] + [InlineData("")] + [InlineData("Not Json")] + public async Task StartThrowsFormatExceptionIfNegotiationResponseIsInvalid(string negotiatePayload) + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, negotiatePayload); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var connection = new Connection(new Uri("http://fakeuri.org/")); + var exception = await Assert.ThrowsAsync( + () => connection.StartAsync(TransportType.LongPolling, httpClient)); + + Assert.Equal("Invalid negotiation response received.", exception.Message); + } + } + + [Fact] + public async Task StartThrowsFormatExceptionIfNegotiationResponseHasNoConnectionId() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + ResponseUtils.CreateNegotiationResponse(connectionId: null)); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var connection = new Connection(new Uri("http://fakeuri.org/")); + var exception = await Assert.ThrowsAsync( + () => connection.StartAsync(TransportType.LongPolling, httpClient)); + + Assert.Equal("Invalid connection id returned in negotiation response.", exception.Message); + } + } + + [Fact] + public async Task StartThrowsFormatExceptionIfNegotiationResponseHasNoTransports() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + ResponseUtils.CreateNegotiationResponse(transportTypes: null)); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var connection = new Connection(new Uri("http://fakeuri.org/")); + var exception = await Assert.ThrowsAsync( + () => connection.StartAsync(TransportType.LongPolling, httpClient)); + + Assert.Equal("No transports returned in negotiation response.", exception.Message); + } + } + + [Theory] + [InlineData((TransportType)0)] + [InlineData(TransportType.ServerSentEvents)] + public async Task ConnectionCannotBeStartedIfNoCommonTransportsBetweenClientAndServer(TransportType serverTransports) + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + ResponseUtils.CreateNegotiationResponse(transportTypes: serverTransports)); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var connection = new Connection(new Uri("http://fakeuri.org/")); + var exception = await Assert.ThrowsAsync( + () => connection.StartAsync(TransportType.LongPolling, httpClient)); + + Assert.Equal("No requested transports available on the server.", exception.Message); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 02889ebf03..5b7d718872 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -47,7 +47,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -78,7 +80,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -139,7 +143,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) @@ -170,7 +176,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Returns(async (request, cancellationToken) => { await Task.Yield(); - return ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(System.Net.HttpStatusCode.OK); }); using (var httpClient = new HttpClient(mockHttpHandler.Object)) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs index 06b3f1f87c..6651cf24bc 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs @@ -3,7 +3,9 @@ using System.Net; using System.Net.Http; -using System.Net.Http.Headers; +using System.Text; + +using SocketsTransportType = Microsoft.AspNetCore.Sockets.TransportType; namespace Microsoft.AspNetCore.Client.Tests { @@ -25,5 +27,37 @@ namespace Microsoft.AspNetCore.Client.Tests Content = payload }; } + + public static string CreateNegotiationResponse(string connectionId = "00000000-0000-0000-0000-000000000000", + SocketsTransportType? transportTypes = SocketsTransportType.All) + { + var sb = new StringBuilder("{ "); + if (connectionId != null) + { + sb.Append($"\"connectionId\": \"{connectionId}\","); + } + if (transportTypes != null) + { + sb.Append($"\"availableTransports\": [ "); + if ((transportTypes & SocketsTransportType.WebSockets) == SocketsTransportType.WebSockets) + { + sb.Append($"\"{nameof(SocketsTransportType.WebSockets)}\","); + } + if ((transportTypes & SocketsTransportType.ServerSentEvents) == SocketsTransportType.ServerSentEvents) + { + sb.Append($"\"{nameof(SocketsTransportType.ServerSentEvents)}\","); + } + if ((transportTypes & SocketsTransportType.LongPolling) == SocketsTransportType.LongPolling) + { + sb.Append($"\"{nameof(SocketsTransportType.LongPolling)}\","); + } + sb.Length--; + sb.Append("],"); + } + sb.Length--; + sb.Append("}"); + + return sb.ToString(); + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/Formatters/ServerSentEventsParserTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/Formatters/ServerSentEventsParserTests.cs index 7dcc88cade..17c5dc9181 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/Formatters/ServerSentEventsParserTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/Formatters/ServerSentEventsParserTests.cs @@ -15,7 +15,6 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters { [Theory] [InlineData("\r\n", "")] - [InlineData("\r\n", "")] [InlineData("\r\n:\r\n", "")] [InlineData("\r\n:comment\r\n", "")] [InlineData("data: \r\r\n\r\n", "\r")] @@ -99,7 +98,6 @@ namespace Microsoft.AspNetCore.Sockets.Common.Tests.Internal.Formatters [InlineData(new[] { "dat", "a: Hello, World\r\n\r\n" }, "Hello, World")] [InlineData(new[] { "data", ": Hello, World\r\n\r\n" }, "Hello, World")] [InlineData(new[] { "data:", " Hello, World\r\n\r\n" }, "Hello, World")] - [InlineData(new[] { "data: ", "Hello, World\r\n\r\n" }, "Hello, World")] [InlineData(new[] { "data: Hello, World", "\r\n\r\n" }, "Hello, World")] [InlineData(new[] { "data: Hello, World\r\n", "\r\n" }, "Hello, World")] [InlineData(new[] { "data: ", "Hello, World\r\n\r\n" }, "Hello, World")] diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index 056db76b9c..a039ede684 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -18,6 +18,8 @@ using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Primitives; +using Newtonsoft.Json; +using Newtonsoft.Json.Linq; using Xunit; namespace Microsoft.AspNetCore.Sockets.Tests @@ -45,11 +47,41 @@ namespace Microsoft.AspNetCore.Sockets.Tests builder.UseEndPoint(); var app = builder.Build(); await dispatcher.ExecuteAsync(context, new HttpSocketOptions(), app); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + var connectionId = negotiateResponse.Value("connectionId"); + Assert.True(manager.TryGetConnection(connectionId, out var connectionContext)); + Assert.Equal(connectionId, connectionContext.ConnectionId); + } - var id = Encoding.UTF8.GetString(ms.ToArray()); + [Theory] + [InlineData(TransportType.All)] + [InlineData((TransportType)0)] + [InlineData(TransportType.LongPolling | TransportType.WebSockets)] + public async Task NegotiateReturnsAvailableTransports(TransportType transports) + { + var manager = CreateConnectionManager(); + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddEndPoint(); + services.AddOptions(); + var ms = new MemoryStream(); + context.Request.Path = "/foo"; + context.Request.Method = "OPTIONS"; + context.Response.Body = ms; + var builder = new SocketBuilder(services.BuildServiceProvider()); + builder.UseEndPoint(); + var app = builder.Build(); + await dispatcher.ExecuteAsync(context, new HttpSocketOptions { Transports = transports }, app); - Assert.True(manager.TryGetConnection(id, out var connection)); - Assert.Equal(id, connection.ConnectionId); + var negotiateResponse = JsonConvert.DeserializeObject(Encoding.UTF8.GetString(ms.ToArray())); + var availableTransports = (TransportType)0; + foreach (var transport in negotiateResponse["availableTransports"]) + { + availableTransports |= (TransportType)Enum.Parse(typeof(TransportType), transport.Value()); + } + + Assert.Equal(transports, availableTransports); } [Theory] diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj index 41d1342cb5..e17e197247 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj @@ -20,6 +20,7 @@ +