From 903fe1e902829af6b635c90d998fc4540656e06d Mon Sep 17 00:00:00 2001 From: David Fowler Date: Wed, 18 Apr 2018 14:22:45 -0700 Subject: [PATCH] Added support for negotiate response to redirect the client to another SignalR endpoint (#2070) - Add a SkipNegotiation flag to the .NET and ts client to allow skipping the negotiation phase. Don't infer it based on the transport type. - Updated the negotiate protocol to support returning a redirect url - Added support to .NET client to handle redirect negotiations - Handle poorly written endpoints that sends infinite redirects - Added access token support and an infinite redirect guard - Add delete handler for stopping the transport --- .../ts/signalr/spec/HttpConnection.spec.ts | 147 ++++++++++++++- clients/ts/signalr/spec/TestHttpClient.ts | 5 +- clients/ts/signalr/src/HttpConnection.ts | 112 ++++++++---- specs/TransportProtocols.md | 63 ++++--- .../HttpConnection.cs | 79 ++++++-- .../HttpConnectionOptions.cs | 3 + .../Internal/AccessTokenHttpMessageHandler.cs | 14 +- .../NegotiateProtocol.cs | 49 ++++- .../NegotiationResponse.cs | 2 + .../NegotiateProtocolTests.cs | 21 ++- .../HttpConnectionTests.Negotiate.cs | 170 +++++++++++++++++- .../TestHttpMessageHandler.cs | 12 +- .../EndToEndTests.cs | 95 ++++++++-- 13 files changed, 655 insertions(+), 117 deletions(-) diff --git a/clients/ts/signalr/spec/HttpConnection.spec.ts b/clients/ts/signalr/spec/HttpConnection.spec.ts index 69347332d5..6c45e76213 100644 --- a/clients/ts/signalr/spec/HttpConnection.spec.ts +++ b/clients/ts/signalr/spec/HttpConnection.spec.ts @@ -267,10 +267,11 @@ describe("HttpConnection", () => { } }); - it("does not send negotiate request if WebSockets transport requested explicitly", async (done) => { + it("does not send negotiate request if WebSockets transport requested explicitly and skipNegotiation is true", async (done) => { const options: IHttpConnectionOptions = { ...commonOptions, httpClient: new TestHttpClient(), + skipNegotiation: true, transport: HttpTransportType.WebSockets, } as IHttpConnectionOptions; @@ -287,6 +288,150 @@ describe("HttpConnection", () => { } }); + it("does not start non WebSockets transport requested explicitly and skipNegotiation is true", async (done) => { + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient(), + skipNegotiation: true, + transport: HttpTransportType.LongPolling, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + try { + await connection.start(TransferFormat.Text); + fail(); + done(); + } catch (e) { + // WebSocket is created when the transport is connecting which happens after + // negotiate request would be sent. No better/easier way to test this. + expect(e.message).toBe("Negotiation can only be skipped when using the WebSocket transport directly."); + done(); + } + }); + + it("redirects to url when negotiate returns it", async (done) => { + let firstNegotiate = true; + let firstPoll = true; + const httpClient = new TestHttpClient() + .on("POST", /negotiate$/, (r) => { + if (firstNegotiate) { + firstNegotiate = false; + return { url: "https://another.domain.url/chat" }; + } + return { + availableTransports: [{ transport: "LongPolling", transferFormats: ["Text"] }], + connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00", + }; + }) + .on("GET", (r) => { + if (firstPoll) { + firstPoll = false; + return ""; + } + return new HttpResponse(204, "No Content", ""); + }); + + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient, + transport: HttpTransportType.LongPolling, + } as IHttpConnectionOptions; + + try { + const connection = new HttpConnection("http://tempuri.org", options); + await connection.start(TransferFormat.Text); + } catch (e) { + fail(e); + done(); + } + + expect(httpClient.sentRequests.length).toBe(4); + expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate"); + expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate"); + expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); + expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); + done(); + }); + + it("fails to start if negotiate redirects more than 100 times", async (done) => { + const httpClient = new TestHttpClient() + .on("POST", /negotiate$/, (r) => ({ url: "https://another.domain.url/chat" })); + + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient, + transport: HttpTransportType.LongPolling, + } as IHttpConnectionOptions; + + try { + const connection = new HttpConnection("http://tempuri.org", options); + await connection.start(TransferFormat.Text); + fail(); + } catch (e) { + expect(e.message).toBe("Negotiate redirection limit exceeded."); + done(); + } + }); + + it("redirects to url when negotiate returns it with access token", async (done) => { + let firstNegotiate = true; + let firstPoll = true; + const httpClient = new TestHttpClient() + .on("POST", /negotiate$/, (r) => { + if (firstNegotiate) { + firstNegotiate = false; + + if (r.headers && r.headers.Authorization !== "Bearer firstSecret") { + return new HttpResponse(401, "Unauthorized", ""); + } + + return { url: "https://another.domain.url/chat", accessToken: "secondSecret" }; + } + + if (r.headers && r.headers.Authorization !== "Bearer secondSecret") { + return new HttpResponse(401, "Unauthorized", ""); + } + + return { + availableTransports: [{ transport: "LongPolling", transferFormats: ["Text"] }], + connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00", + }; + }) + .on("GET", (r) => { + if (r.headers && r.headers.Authorization !== "Bearer secondSecret") { + return new HttpResponse(401, "Unauthorized", ""); + } + + if (firstPoll) { + firstPoll = false; + return ""; + } + return new HttpResponse(204, "No Content", ""); + }); + + const options: IHttpConnectionOptions = { + ...commonOptions, + accessTokenFactory: () => "firstSecret", + httpClient, + transport: HttpTransportType.LongPolling, + } as IHttpConnectionOptions; + + try { + const connection = new HttpConnection("http://tempuri.org", options); + await connection.start(TransferFormat.Text); + } catch (e) { + fail(e); + done(); + } + + expect(httpClient.sentRequests.length).toBe(4); + expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate"); + expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate"); + expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); + expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i); + done(); + }); + it("authorization header removed when token factory returns null and using LongPolling", async (done) => { const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] }; diff --git a/clients/ts/signalr/spec/TestHttpClient.ts b/clients/ts/signalr/spec/TestHttpClient.ts index 90194f2efc..b591c1bebd 100644 --- a/clients/ts/signalr/spec/TestHttpClient.ts +++ b/clients/ts/signalr/spec/TestHttpClient.ts @@ -8,15 +8,18 @@ export type TestHttpHandler = (request: HttpRequest, next?: (request: HttpReques export class TestHttpClient extends HttpClient { private handler: (request: HttpRequest) => Promise; + public sentRequests: HttpRequest[]; constructor() { super(); + this.sentRequests = []; this.handler = (request: HttpRequest) => Promise.reject(`Request has no handler: ${request.method} ${request.url}`); } public send(request: HttpRequest): Promise { + this.sentRequests.push(request); return this.handler(request); } @@ -59,7 +62,7 @@ export class TestHttpClient extends HttpClient { if (typeof val === "string") { // string payload return new HttpResponse(200, "OK", val); - } else if(typeof val === "object" && val.statusCode) { + } else if (typeof val === "object" && val.statusCode) { // HttpResponse payload return val as HttpResponse; } else { diff --git a/clients/ts/signalr/src/HttpConnection.ts b/clients/ts/signalr/src/HttpConnection.ts index 5ef88647cd..3a9e905d4a 100644 --- a/clients/ts/signalr/src/HttpConnection.ts +++ b/clients/ts/signalr/src/HttpConnection.ts @@ -16,6 +16,7 @@ export interface IHttpConnectionOptions { logger?: ILogger | LogLevel; accessTokenFactory?: () => string | Promise; logMessageContent?: boolean; + skipNegotiation?: boolean; } const enum ConnectionState { @@ -27,6 +28,8 @@ const enum ConnectionState { interface INegotiateResponse { connectionId: string; availableTransports: IAvailableTransport[]; + url: string; + accessToken: string; } interface IAvailableTransport { @@ -34,17 +37,18 @@ interface IAvailableTransport { transferFormats: Array; } +const MAX_REDIRECTS = 100; + export class HttpConnection implements IConnection { private connectionState: ConnectionState; private baseUrl: string; - private url: string; private readonly httpClient: HttpClient; private readonly logger: ILogger; private readonly options: IHttpConnectionOptions; private transport: ITransport; - private connectionId: string; private startPromise: Promise; private stopError?: Error; + private accessTokenFactory?: () => string | Promise; public readonly features: any = {}; public onreceive: (data: string | ArrayBuffer) => void; @@ -110,29 +114,53 @@ export class HttpConnection implements IConnection { } private async startInternal(transferFormat: TransferFormat): Promise { + // Store the original base url and the access token factory since they may change + // as part of negotiating + let url = this.baseUrl; + this.accessTokenFactory = this.options.accessTokenFactory; + try { - if (this.options.transport === HttpTransportType.WebSockets) { - // No need to add a connection ID in this case - this.url = this.baseUrl; - this.transport = this.constructTransport(HttpTransportType.WebSockets); - // We should just call connect directly in this case. - // No fallback or negotiate in this case. - await this.transport.connect(this.url, transferFormat); + if (this.options.skipNegotiation) { + if (this.options.transport === HttpTransportType.WebSockets) { + // No need to add a connection ID in this case + this.transport = this.constructTransport(HttpTransportType.WebSockets); + // We should just call connect directly in this case. + // No fallback or negotiate in this case. + await this.transport.connect(url, transferFormat); + } else { + throw Error("Negotiation can only be skipped when using the WebSocket transport directly."); + } } else { - const token = await this.options.accessTokenFactory(); - let headers; - if (token) { - headers = { - ["Authorization"]: `Bearer ${token}`, - }; + let negotiateResponse: INegotiateResponse = null; + let redirects = 0; + + do { + negotiateResponse = await this.getNegotiationResponse(url); + // the user tries to stop the connection when it is being started + if (this.connectionState === ConnectionState.Disconnected) { + return; + } + + if (negotiateResponse.url) { + url = negotiateResponse.url; + } + + if (negotiateResponse.accessToken) { + // Replace the current access token factory with one that uses + // the returned access token + const accessToken = negotiateResponse.accessToken; + this.accessTokenFactory = () => accessToken; + } + + redirects++; + } + while (negotiateResponse.url && redirects < MAX_REDIRECTS); + + if (redirects === MAX_REDIRECTS && negotiateResponse.url) { + throw Error("Negotiate redirection limit exceeded."); } - const negotiateResponse = await this.getNegotiationResponse(headers); - // the user tries to stop the the connection when it is being started - if (this.connectionState === ConnectionState.Disconnected) { - return; - } - await this.createTransport(this.options.transport, negotiateResponse, transferFormat, headers); + await this.createTransport(url, this.options.transport, negotiateResponse, transferFormat); } if (this.transport instanceof LongPollingTransport) { @@ -153,32 +181,44 @@ export class HttpConnection implements IConnection { } } - private async getNegotiationResponse(headers: any): Promise { - const negotiateUrl = this.resolveNegotiateUrl(this.baseUrl); + private async getNegotiationResponse(url: string): Promise { + const token = await this.accessTokenFactory(); + let headers; + if (token) { + headers = { + ["Authorization"]: `Bearer ${token}`, + }; + } + + const negotiateUrl = this.resolveNegotiateUrl(url); this.logger.log(LogLevel.Debug, `Sending negotiation request: ${negotiateUrl}`); try { const response = await this.httpClient.post(negotiateUrl, { content: "", headers, }); - return JSON.parse(response.content as string); + + if (response.statusCode !== 200) { + throw Error(`Unexpected status code returned from negotiate ${response.statusCode}`); + } + + return JSON.parse(response.content as string) as INegotiateResponse; } catch (e) { this.logger.log(LogLevel.Error, "Failed to complete negotiation with the server: " + e); throw e; } } - private updateConnectionId(negotiateResponse: INegotiateResponse) { - this.connectionId = negotiateResponse.connectionId; - this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`; + private createConnectUrl(url: string, connectionId: string) { + return url + (url.indexOf("?") === -1 ? "?" : "&") + `id=${connectionId}`; } - private async createTransport(requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise { - this.updateConnectionId(negotiateResponse); + private async createTransport(url: string, requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat): Promise { + let connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId); if (this.isITransport(requestedTransport)) { this.logger.log(LogLevel.Debug, "Connection was provided an instance of ITransport, using that directly."); this.transport = requestedTransport; - await this.transport.connect(this.url, requestedTransferFormat); + await this.transport.connect(connectUrl, requestedTransferFormat); // only change the state if we were connecting to not overwrite // the state if the connection is already marked as Disconnected @@ -193,11 +233,11 @@ export class HttpConnection implements IConnection { if (typeof transport === "number") { this.transport = this.constructTransport(transport); if (negotiateResponse.connectionId === null) { - negotiateResponse = await this.getNegotiationResponse(headers); - this.updateConnectionId(negotiateResponse); + negotiateResponse = await this.getNegotiationResponse(url); + connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId); } try { - await this.transport.connect(this.url, requestedTransferFormat); + await this.transport.connect(connectUrl, requestedTransferFormat); this.changeState(ConnectionState.Connecting, ConnectionState.Connected); return; } catch (ex) { @@ -214,11 +254,11 @@ export class HttpConnection implements IConnection { private constructTransport(transport: HttpTransportType) { switch (transport) { case HttpTransportType.WebSockets: - return new WebSocketTransport(this.options.accessTokenFactory, this.logger, this.options.logMessageContent); + return new WebSocketTransport(this.accessTokenFactory, this.logger, this.options.logMessageContent); case HttpTransportType.ServerSentEvents: - return new ServerSentEventsTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent); + return new ServerSentEventsTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent); case HttpTransportType.LongPolling: - return new LongPollingTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent); + return new LongPollingTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent); default: throw new Error(`Unknown transport: ${transport}.`); } diff --git a/specs/TransportProtocols.md b/specs/TransportProtocols.md index 91e1f81a2a..3d90ceb6cd 100644 --- a/specs/TransportProtocols.md +++ b/specs/TransportProtocols.md @@ -18,32 +18,49 @@ Throughout this document, the term `[endpoint-base]` is used to refer to the rou ## `POST [endpoint-base]/negotiate` request -The `POST [endpoint-base]/negotiate` request is used to establish connection between the client and the server. The response to the `POST [endpoint-base]/negotiate` request contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server. The content type of the response is `application/json`. The following is a sample response to the `POST [endpoint-base]/negotiate` request +The `POST [endpoint-base]/negotiate` request is used to establish a connection between the client and the server. The content type of the response is `application/json`. The response to the `POST [endpoint-base]/negotiate` request contains one of two types of responses: -``` -{ - "connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d", - "availableTransports":[ - { - "transport": "WebSockets", - "transferFormats": [ "Text", "Binary" ] - }, - { - "transport": "ServerSentEvents", - "transferFormats": [ "Text" ] - }, - { - "transport": "LongPolling", - "transferFormats": [ "Text", "Binary" ] - } - ] -} -``` +1. A response that contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server. -The payload returned from this endpoint provides the following data: + ``` + { + "connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d", + "availableTransports":[ + { + "transport": "WebSockets", + "transferFormats": [ "Text", "Binary" ] + }, + { + "transport": "ServerSentEvents", + "transferFormats": [ "Text" ] + }, + { + "transport": "LongPolling", + "transferFormats": [ "Text", "Binary" ] + } + ] + } + ``` -* The `connectionId` which is **required** by the Long Polling and Server-Sent Events transports (in order to correlate sends and receives). -* The `availableTransports` list which describes the transports the server supports. For each transport, the name of the transport (`transport`) is listed, as is a list of "transfer formats" supported by the transport (`transferFormats`) + The payload returned from this endpoint provides the following data: + + * The `connectionId` which is **required** by the Long Polling and Server-Sent Events transports (in order to correlate sends and receives). + * The `availableTransports` list which describes the transports the server supports. For each transport, the name of the transport (`transport`) is listed, as is a list of "transfer formats" supported by the transport (`transferFormats`) + + +2. A redirect response which tells the client which URL and optionally access token to use as a result. + + ``` + { + "url": "https://myapp.com/chat", + "accessToken": "accessToken" + } + ``` + + The payload returned from this endpoint provides the following data: + + * The `url` which is the URL the client should connect to. + * The `accessToken` which is an optional bearer token for accessing the specified url. ## Transfer Formats diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs index f2e827979a..7615ce7327 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnection.cs @@ -20,6 +20,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Client { public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature { + // Not configurable on purpose, high enough that if we reach here, it's likely + // a buggy server + private static readonly int _maxRedirects = 100; + private static readonly Task _noAccessToken = Task.FromResult(null); + private static readonly TimeSpan HttpClientTimeout = TimeSpan.FromSeconds(120); #if !NETCOREAPP2_1 private static readonly Version Windows8Version = new Version(6, 2); @@ -40,6 +45,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client private readonly ConnectionLogScope _logScope; private readonly IDisposable _scopeDisposable; private readonly ILoggerFactory _loggerFactory; + private Func> _accessTokenProvider; public override IDuplexPipe Transport { @@ -96,7 +102,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client _logger = _loggerFactory.CreateLogger(); _httpConnectionOptions = httpConnectionOptions; - if (httpConnectionOptions.Transports != HttpTransportType.WebSockets) + if (!httpConnectionOptions.SkipNegotiation || httpConnectionOptions.Transports != HttpTransportType.WebSockets) { _httpClient = CreateHttpClient(); } @@ -210,17 +216,54 @@ namespace Microsoft.AspNetCore.Http.Connections.Client private async Task SelectAndStartTransport(TransferFormat transferFormat) { - if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets) + var uri = _httpConnectionOptions.Url; + // Set the initial access token provider back to the original one from options + _accessTokenProvider = _httpConnectionOptions.AccessTokenProvider; + + if (_httpConnectionOptions.SkipNegotiation) { - Log.StartingTransport(_logger, _httpConnectionOptions.Transports, _httpConnectionOptions.Url); - await StartTransport(_httpConnectionOptions.Url, _httpConnectionOptions.Transports, transferFormat); + if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets) + { + Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri); + await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat); + } + else + { + throw new InvalidOperationException("Negotiation can only be skipped when using the WebSocket transport directly."); + } } else { - var negotiationResponse = await GetNegotiationResponse(); + NegotiationResponse negotiationResponse; + var redirects = 0; + + do + { + negotiationResponse = await GetNegotiationResponseAsync(uri); + + if (negotiationResponse.Url != null) + { + uri = new Uri(negotiationResponse.Url); + } + + if (negotiationResponse.AccessToken != null) + { + string accessToken = negotiationResponse.AccessToken; + // Set the current access token factory so that future requests use this access token + _accessTokenProvider = () => Task.FromResult(accessToken); + } + + redirects++; + } + while (negotiationResponse.Url != null && redirects < _maxRedirects); + + if (redirects == _maxRedirects && negotiationResponse.Url != null) + { + throw new InvalidOperationException("Negotiate redirection limit exceeded."); + } // This should only need to happen once - var connectUrl = CreateConnectUrl(_httpConnectionOptions.Url, negotiationResponse.ConnectionId); + var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId); // We're going to search for the transfer format as a string because we don't want to parse // all the transfer formats in the negotiation response, and we want to allow transfer formats @@ -256,8 +299,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client // The negotiation response gets cleared in the fallback scenario. if (negotiationResponse == null) { - negotiationResponse = await GetNegotiationResponse(); - connectUrl = CreateConnectUrl(_httpConnectionOptions.Url, negotiationResponse.ConnectionId); + negotiationResponse = await GetNegotiationResponseAsync(uri); + connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId); } Log.StartingTransport(_logger, transportType, connectUrl); @@ -281,7 +324,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client } } - private async Task Negotiate(Uri url, HttpClient httpClient, ILogger logger) + private async Task NegotiateAsync(Uri url, HttpClient httpClient, ILogger logger) { try { @@ -399,10 +442,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client } // Apply the authorization header in a handler instead of a default header because it can change with each request - if (_httpConnectionOptions.AccessTokenProvider != null) - { - httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, _httpConnectionOptions.AccessTokenProvider); - } + httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, this); } // Wrap message handler after HttpMessageHandlerFactory to ensure not overriden @@ -430,6 +470,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client return httpClient; } + internal Task GetAccessTokenAsync() + { + if (_accessTokenProvider == null) + { + return _noAccessToken; + } + return _accessTokenProvider(); + } + private void CheckDisposed() { if (_disposed) @@ -458,9 +507,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client #endif } - private async Task GetNegotiationResponse() + private async Task GetNegotiationResponseAsync(Uri uri) { - var negotiationResponse = await Negotiate(_httpConnectionOptions.Url, _httpClient, _logger); + var negotiationResponse = await NegotiateAsync(uri, _httpClient, _logger); _connectionId = negotiationResponse.ConnectionId; _logScope.ConnectionId = _connectionId; return negotiationResponse; diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnectionOptions.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnectionOptions.cs index 9bf7db2fb6..5e30dd4594 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnectionOptions.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/HttpConnectionOptions.cs @@ -52,6 +52,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client public Uri Url { get; set; } public HttpTransportType Transports { get; set; } + + public bool SkipNegotiation { get; set; } + public Func> AccessTokenProvider { get; set; } public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5); public ICredentials Credentials { get; set; } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs index b2c4eea071..327a81a5d7 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/AccessTokenHttpMessageHandler.cs @@ -11,17 +11,21 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { internal class AccessTokenHttpMessageHandler : DelegatingHandler { - private readonly Func> _accessTokenProvider; + private readonly HttpConnection _httpConnection; - public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func> accessTokenProvider) : base(inner) + public AccessTokenHttpMessageHandler(HttpMessageHandler inner, HttpConnection httpConnection) : base(inner) { - _accessTokenProvider = accessTokenProvider ?? throw new ArgumentNullException(nameof(accessTokenProvider)); + _httpConnection = httpConnection; } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) { - var accessToken = await _accessTokenProvider(); - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); + var accessToken = await _httpConnection.GetAccessTokenAsync(); + + if (!string.IsNullOrEmpty(accessToken)) + { + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken); + } return await base.SendAsync(request, cancellationToken); } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiateProtocol.cs b/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiateProtocol.cs index f1cd03ac62..b23dcb4be6 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiateProtocol.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiateProtocol.cs @@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.Http.Connections public static class NegotiateProtocol { private const string ConnectionIdPropertyName = "connectionId"; + private const string UrlPropertyName = "url"; + private const string AccessTokenPropertyName = "accessToken"; private const string AvailableTransportsPropertyName = "availableTransports"; private const string TransportPropertyName = "transport"; private const string TransferFormatsPropertyName = "transferFormats"; @@ -25,8 +27,25 @@ namespace Microsoft.AspNetCore.Http.Connections using (var jsonWriter = JsonUtils.CreateJsonTextWriter(textWriter)) { jsonWriter.WriteStartObject(); - jsonWriter.WritePropertyName(ConnectionIdPropertyName); - jsonWriter.WriteValue(response.ConnectionId); + + if (!string.IsNullOrEmpty(response.Url)) + { + jsonWriter.WritePropertyName(UrlPropertyName); + jsonWriter.WriteValue(response.Url); + } + + if (!string.IsNullOrEmpty(response.AccessToken)) + { + jsonWriter.WritePropertyName(AccessTokenPropertyName); + jsonWriter.WriteValue(response.AccessToken); + } + + if (!string.IsNullOrEmpty(response.ConnectionId)) + { + jsonWriter.WritePropertyName(ConnectionIdPropertyName); + jsonWriter.WriteValue(response.ConnectionId); + } + jsonWriter.WritePropertyName(AvailableTransportsPropertyName); jsonWriter.WriteStartArray(); @@ -69,6 +88,8 @@ namespace Microsoft.AspNetCore.Http.Connections JsonUtils.EnsureObjectStart(reader); string connectionId = null; + string url = null; + string accessToken = null; List availableTransports = null; var completed = false; @@ -81,6 +102,12 @@ namespace Microsoft.AspNetCore.Http.Connections switch (memberName) { + case UrlPropertyName: + url = JsonUtils.ReadAsString(reader, UrlPropertyName); + break; + case AccessTokenPropertyName: + accessToken = JsonUtils.ReadAsString(reader, AccessTokenPropertyName); + break; case ConnectionIdPropertyName: connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName); break; @@ -114,19 +141,25 @@ namespace Microsoft.AspNetCore.Http.Connections } } - if (connectionId == null) + if (url == null) { - throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); - } + // if url isn't specified, connectionId and available transports are required + if (connectionId == null) + { + throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'."); + } - if (availableTransports == null) - { - throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); + if (availableTransports == null) + { + throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'."); + } } return new NegotiationResponse { ConnectionId = connectionId, + Url = url, + AccessToken = accessToken, AvailableTransports = availableTransports }; } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiationResponse.cs b/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiationResponse.cs index 70a930c1a2..1d87b5e19f 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiationResponse.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Common/NegotiationResponse.cs @@ -7,6 +7,8 @@ namespace Microsoft.AspNetCore.Http.Connections { public class NegotiationResponse { + public string Url { get; set; } + public string AccessToken { get; set; } public string ConnectionId { get; set; } public IList AvailableTransports { get; set; } } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/NegotiateProtocolTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/NegotiateProtocolTests.cs index c52783fa65..85bbd0bf47 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/NegotiateProtocolTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/NegotiateProtocolTests.cs @@ -9,21 +9,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class NegotiateProtocolTests { [Theory] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0])] - [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0])] - [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new [] { "test"})] - public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports) + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null)] + [InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null)] + [InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null)] + [InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token")] + [InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null)] + public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken) { var responseData = Encoding.UTF8.GetBytes(json); var ms = new MemoryStream(responseData); var response = NegotiateProtocol.ParseResponse(ms); Assert.Equal(connectionId, response.ConnectionId); - Assert.Equal(availableTransports.Length, response.AvailableTransports.Count); + Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count); + Assert.Equal(url, response.Url); + Assert.Equal(accessToken, response.AccessToken); - var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList(); + if (response.AvailableTransports != null) + { + var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList(); - Assert.Equal(availableTransports, responseTransports); + Assert.Equal(availableTransports, responseTransports); + } } [Theory] diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs index 68753a526d..635aca510f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Negotiate.cs @@ -4,6 +4,7 @@ using System; using System.IO; using System.Net; +using System.Net.Http; using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Connections; @@ -70,7 +71,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); using (var noErrorScope = new VerifyNoErrorsScope()) - { + { await WithConnectionAsync( CreateConnection(testHttpHandler, url: requestedUrl, loggerFactory: noErrorScope.LoggerFactory), async (connection) => @@ -82,6 +83,173 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Equal(expectedNegotiate, await negotiateUrlTcs.Task.OrTimeout()); } + [Fact] + public async Task NegotiateThatReturnsUrlGetFollowed() + { + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + var firstNegotiate = true; + testHttpHandler.OnNegotiate((request, cancellationToken) => + { + if (firstNegotiate) + { + firstNegotiate = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + url = "https://another.domain.url/chat" + })); + } + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + } + })); + }); + + testHttpHandler.OnLongPoll((token) => + { + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent))); + + return tcs.Task; + }); + + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + }); + } + + Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); + } + + [Fact] + public async Task NegotiateThatReturnsRedirectUrlForeverThrowsAfter100Tries() + { + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + testHttpHandler.OnNegotiate((request, cancellationToken) => + { + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + url = "https://another.domain.url/chat" + })); + }); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory), + async (connection) => + { + var exception = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Text).OrTimeout()); + Assert.Equal("Negotiate redirection limit exceeded.", exception.Message); + }); + } + } + + [Fact] + public async Task NegotiateThatReturnsUrlGetFollowedWithAccessToken() + { + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); + var firstNegotiate = true; + testHttpHandler.OnNegotiate((request, cancellationToken) => + { + if (firstNegotiate) + { + firstNegotiate = false; + + // The first negotiate requires an access token + if (request.Headers.Authorization?.Parameter != "firstSecret") + { + return ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized); + } + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + url = "https://another.domain.url/chat", + accessToken = "secondSecret" + })); + } + + // All other requests require an access token + if (request.Headers.Authorization?.Parameter != "secondSecret") + { + return ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized); + } + + return ResponseUtils.CreateResponse(HttpStatusCode.OK, + JsonConvert.SerializeObject(new + { + connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00", + availableTransports = new object[] + { + new + { + transport = "LongPolling", + transferFormats = new[] { "Text" } + }, + } + })); + }); + + testHttpHandler.OnLongPoll((request, token) => + { + // All other requests require an access token + if (request.Headers.Authorization?.Parameter != "secondSecret") + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized)); + } + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + + token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent))); + + return tcs.Task; + }); + + testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted)); + + Task AccessTokenProvider() => Task.FromResult("firstSecret"); + + using (var noErrorScope = new VerifyNoErrorsScope()) + { + await WithConnectionAsync( + CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory, accessTokenProvider: AccessTokenProvider), + async (connection) => + { + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + }); + } + + Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString()); + Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString()); + // Delete request + Assert.Equal(5, testHttpHandler.ReceivedRequests.Count); + } + [Fact] public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand() { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs index 6bde5d0af2..93b234c43b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs @@ -185,12 +185,22 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public void OnLongPoll(Func handler) => OnLongPoll(cancellationToken => Task.FromResult(handler(cancellationToken))); public void OnLongPoll(Func> handler) + { + OnLongPoll((request, token) => handler(token)); + } + + public void OnLongPoll(Func handler) + { + OnLongPoll((request, token) => Task.FromResult(handler(request, token))); + } + + public void OnLongPoll(Func> handler) { OnRequest((request, next, cancellationToken) => { if (ResponseUtils.IsLongPollRequest(request)) { - return handler(cancellationToken); + return handler(request, cancellationToken); } else { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 62b5fedddf..e2f5ff150a 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [ConditionalFact] [WebSocketsSupportedCondition] - public async Task HttpRequestsNotSentWhenWebSocketsTransportRequested() + public async Task HttpRequestsNotSentWhenWebSocketsTransportRequestedAndSkipNegotiationSet() { using (StartVerifableLog(out var loggerFactory)) { @@ -184,6 +184,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { Url = new Uri(url), Transports = HttpTransportType.WebSockets, + SkipNegotiation = true, HttpMessageHandlerFactory = (httpMessageHandler) => mockHttpHandler.Object }; @@ -213,6 +214,49 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Theory] + [InlineData(HttpTransportType.LongPolling)] + [InlineData(HttpTransportType.ServerSentEvents)] + public async Task HttpConnectionThrowsIfSkipNegotiationSetAndTransportIsNotWebSockets(HttpTransportType transportType) + { + using (StartVerifableLog(out var loggerFactory)) + { + var logger = loggerFactory.CreateLogger(); + var url = ServerFixture.Url + "/echo"; + + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns( + (request, cancellationToken) => Task.FromException(new InvalidOperationException("HTTP requests should not be sent."))); + + var httpOptions = new HttpConnectionOptions + { + Url = new Uri(url), + Transports = transportType, + SkipNegotiation = true, + HttpMessageHandlerFactory = (httpMessageHandler) => mockHttpHandler.Object + }; + + var connection = new HttpConnection(httpOptions, loggerFactory); + + try + { + var exception = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Binary).OrTimeout()); + Assert.Equal("Negotiation can only be skipped when using the WebSocket transport directly.", exception.Message); + } + catch (Exception ex) + { + logger.LogInformation(ex, "Test threw exception"); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(TransportTypesAndTransferFormats))] public async Task ConnectionCanSendAndReceiveMessages(HttpTransportType transportType, TransferFormat requestedTransferFormat) @@ -326,7 +370,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests bool ExpectedErrors(WriteContext writeContext) { return writeContext.LoggerName == typeof(HttpConnection).FullName && - writeContext.EventId.Name == "ErrorStartingTransport"; + writeContext.EventId.Name == "ErrorWithNegotiation"; } using (StartVerifableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors)) @@ -336,24 +380,37 @@ namespace Microsoft.AspNetCore.SignalR.Tests var url = ServerFixture.Url + "/auth"; var connection = new HttpConnection(new Uri(url), HttpTransportType.WebSockets, loggerFactory); - try + var exception = await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Binary).OrTimeout()); + + Assert.Contains("401", exception.Message); + } + } + + [ConditionalFact] + [WebSocketsSupportedCondition] + public async Task UnauthorizedDirectWebSocketsConnectionDoesNotConnect() + { + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorStartingTransport"; + } + + using (StartVerifableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors)) + { + var logger = loggerFactory.CreateLogger(); + + var url = ServerFixture.Url + "/auth"; + var options = new HttpConnectionOptions { - logger.LogInformation("Starting connection to {url}", url); - await connection.StartAsync(TransferFormat.Binary).OrTimeout(); - Assert.True(false); - } - catch (WebSocketException) { } - catch (Exception ex) - { - logger.LogInformation(ex, "Test threw exception"); - throw; - } - finally - { - logger.LogInformation("Disposing Connection"); - await connection.DisposeAsync().OrTimeout(); - logger.LogInformation("Disposed Connection"); - } + Url = new Uri(url), + Transports = HttpTransportType.WebSockets, + SkipNegotiation = true + }; + + var connection = new HttpConnection(options, loggerFactory); + + await Assert.ThrowsAsync(() => connection.StartAsync(TransferFormat.Binary).OrTimeout()); } }