Added support for negotiate response to redirect the client to another SignalR endpoint (#2070)

- Add a SkipNegotiation flag to the .NET and ts client
to allow skipping the negotiation phase. Don't infer it based on the transport type.
-  Updated the negotiate protocol to support returning a redirect url
- Added support to .NET client to handle redirect negotiations
- Handle poorly written endpoints that sends infinite redirects
- Added access token support and an infinite redirect guard
- Add delete handler for stopping the transport
This commit is contained in:
David Fowler 2018-04-18 14:22:45 -07:00 committed by GitHub
parent d9272032d0
commit 903fe1e902
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 655 additions and 117 deletions

View File

@ -267,10 +267,11 @@ describe("HttpConnection", () => {
}
});
it("does not send negotiate request if WebSockets transport requested explicitly", async (done) => {
it("does not send negotiate request if WebSockets transport requested explicitly and skipNegotiation is true", async (done) => {
const options: IHttpConnectionOptions = {
...commonOptions,
httpClient: new TestHttpClient(),
skipNegotiation: true,
transport: HttpTransportType.WebSockets,
} as IHttpConnectionOptions;
@ -287,6 +288,150 @@ describe("HttpConnection", () => {
}
});
it("does not start non WebSockets transport requested explicitly and skipNegotiation is true", async (done) => {
const options: IHttpConnectionOptions = {
...commonOptions,
httpClient: new TestHttpClient(),
skipNegotiation: true,
transport: HttpTransportType.LongPolling,
} as IHttpConnectionOptions;
const connection = new HttpConnection("http://tempuri.org", options);
try {
await connection.start(TransferFormat.Text);
fail();
done();
} catch (e) {
// WebSocket is created when the transport is connecting which happens after
// negotiate request would be sent. No better/easier way to test this.
expect(e.message).toBe("Negotiation can only be skipped when using the WebSocket transport directly.");
done();
}
});
it("redirects to url when negotiate returns it", async (done) => {
let firstNegotiate = true;
let firstPoll = true;
const httpClient = new TestHttpClient()
.on("POST", /negotiate$/, (r) => {
if (firstNegotiate) {
firstNegotiate = false;
return { url: "https://another.domain.url/chat" };
}
return {
availableTransports: [{ transport: "LongPolling", transferFormats: ["Text"] }],
connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00",
};
})
.on("GET", (r) => {
if (firstPoll) {
firstPoll = false;
return "";
}
return new HttpResponse(204, "No Content", "");
});
const options: IHttpConnectionOptions = {
...commonOptions,
httpClient,
transport: HttpTransportType.LongPolling,
} as IHttpConnectionOptions;
try {
const connection = new HttpConnection("http://tempuri.org", options);
await connection.start(TransferFormat.Text);
} catch (e) {
fail(e);
done();
}
expect(httpClient.sentRequests.length).toBe(4);
expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate");
expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate");
expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
done();
});
it("fails to start if negotiate redirects more than 100 times", async (done) => {
const httpClient = new TestHttpClient()
.on("POST", /negotiate$/, (r) => ({ url: "https://another.domain.url/chat" }));
const options: IHttpConnectionOptions = {
...commonOptions,
httpClient,
transport: HttpTransportType.LongPolling,
} as IHttpConnectionOptions;
try {
const connection = new HttpConnection("http://tempuri.org", options);
await connection.start(TransferFormat.Text);
fail();
} catch (e) {
expect(e.message).toBe("Negotiate redirection limit exceeded.");
done();
}
});
it("redirects to url when negotiate returns it with access token", async (done) => {
let firstNegotiate = true;
let firstPoll = true;
const httpClient = new TestHttpClient()
.on("POST", /negotiate$/, (r) => {
if (firstNegotiate) {
firstNegotiate = false;
if (r.headers && r.headers.Authorization !== "Bearer firstSecret") {
return new HttpResponse(401, "Unauthorized", "");
}
return { url: "https://another.domain.url/chat", accessToken: "secondSecret" };
}
if (r.headers && r.headers.Authorization !== "Bearer secondSecret") {
return new HttpResponse(401, "Unauthorized", "");
}
return {
availableTransports: [{ transport: "LongPolling", transferFormats: ["Text"] }],
connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00",
};
})
.on("GET", (r) => {
if (r.headers && r.headers.Authorization !== "Bearer secondSecret") {
return new HttpResponse(401, "Unauthorized", "");
}
if (firstPoll) {
firstPoll = false;
return "";
}
return new HttpResponse(204, "No Content", "");
});
const options: IHttpConnectionOptions = {
...commonOptions,
accessTokenFactory: () => "firstSecret",
httpClient,
transport: HttpTransportType.LongPolling,
} as IHttpConnectionOptions;
try {
const connection = new HttpConnection("http://tempuri.org", options);
await connection.start(TransferFormat.Text);
} catch (e) {
fail(e);
done();
}
expect(httpClient.sentRequests.length).toBe(4);
expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate");
expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate");
expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
done();
});
it("authorization header removed when token factory returns null and using LongPolling", async (done) => {
const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] };

View File

@ -8,15 +8,18 @@ export type TestHttpHandler = (request: HttpRequest, next?: (request: HttpReques
export class TestHttpClient extends HttpClient {
private handler: (request: HttpRequest) => Promise<HttpResponse>;
public sentRequests: HttpRequest[];
constructor() {
super();
this.sentRequests = [];
this.handler = (request: HttpRequest) =>
Promise.reject(`Request has no handler: ${request.method} ${request.url}`);
}
public send(request: HttpRequest): Promise<HttpResponse> {
this.sentRequests.push(request);
return this.handler(request);
}
@ -59,7 +62,7 @@ export class TestHttpClient extends HttpClient {
if (typeof val === "string") {
// string payload
return new HttpResponse(200, "OK", val);
} else if(typeof val === "object" && val.statusCode) {
} else if (typeof val === "object" && val.statusCode) {
// HttpResponse payload
return val as HttpResponse;
} else {

View File

@ -16,6 +16,7 @@ export interface IHttpConnectionOptions {
logger?: ILogger | LogLevel;
accessTokenFactory?: () => string | Promise<string>;
logMessageContent?: boolean;
skipNegotiation?: boolean;
}
const enum ConnectionState {
@ -27,6 +28,8 @@ const enum ConnectionState {
interface INegotiateResponse {
connectionId: string;
availableTransports: IAvailableTransport[];
url: string;
accessToken: string;
}
interface IAvailableTransport {
@ -34,17 +37,18 @@ interface IAvailableTransport {
transferFormats: Array<keyof typeof TransferFormat>;
}
const MAX_REDIRECTS = 100;
export class HttpConnection implements IConnection {
private connectionState: ConnectionState;
private baseUrl: string;
private url: string;
private readonly httpClient: HttpClient;
private readonly logger: ILogger;
private readonly options: IHttpConnectionOptions;
private transport: ITransport;
private connectionId: string;
private startPromise: Promise<void>;
private stopError?: Error;
private accessTokenFactory?: () => string | Promise<string>;
public readonly features: any = {};
public onreceive: (data: string | ArrayBuffer) => void;
@ -110,29 +114,53 @@ export class HttpConnection implements IConnection {
}
private async startInternal(transferFormat: TransferFormat): Promise<void> {
// Store the original base url and the access token factory since they may change
// as part of negotiating
let url = this.baseUrl;
this.accessTokenFactory = this.options.accessTokenFactory;
try {
if (this.options.transport === HttpTransportType.WebSockets) {
// No need to add a connection ID in this case
this.url = this.baseUrl;
this.transport = this.constructTransport(HttpTransportType.WebSockets);
// We should just call connect directly in this case.
// No fallback or negotiate in this case.
await this.transport.connect(this.url, transferFormat);
if (this.options.skipNegotiation) {
if (this.options.transport === HttpTransportType.WebSockets) {
// No need to add a connection ID in this case
this.transport = this.constructTransport(HttpTransportType.WebSockets);
// We should just call connect directly in this case.
// No fallback or negotiate in this case.
await this.transport.connect(url, transferFormat);
} else {
throw Error("Negotiation can only be skipped when using the WebSocket transport directly.");
}
} else {
const token = await this.options.accessTokenFactory();
let headers;
if (token) {
headers = {
["Authorization"]: `Bearer ${token}`,
};
let negotiateResponse: INegotiateResponse = null;
let redirects = 0;
do {
negotiateResponse = await this.getNegotiationResponse(url);
// the user tries to stop the connection when it is being started
if (this.connectionState === ConnectionState.Disconnected) {
return;
}
if (negotiateResponse.url) {
url = negotiateResponse.url;
}
if (negotiateResponse.accessToken) {
// Replace the current access token factory with one that uses
// the returned access token
const accessToken = negotiateResponse.accessToken;
this.accessTokenFactory = () => accessToken;
}
redirects++;
}
while (negotiateResponse.url && redirects < MAX_REDIRECTS);
if (redirects === MAX_REDIRECTS && negotiateResponse.url) {
throw Error("Negotiate redirection limit exceeded.");
}
const negotiateResponse = await this.getNegotiationResponse(headers);
// the user tries to stop the the connection when it is being started
if (this.connectionState === ConnectionState.Disconnected) {
return;
}
await this.createTransport(this.options.transport, negotiateResponse, transferFormat, headers);
await this.createTransport(url, this.options.transport, negotiateResponse, transferFormat);
}
if (this.transport instanceof LongPollingTransport) {
@ -153,32 +181,44 @@ export class HttpConnection implements IConnection {
}
}
private async getNegotiationResponse(headers: any): Promise<INegotiateResponse> {
const negotiateUrl = this.resolveNegotiateUrl(this.baseUrl);
private async getNegotiationResponse(url: string): Promise<INegotiateResponse> {
const token = await this.accessTokenFactory();
let headers;
if (token) {
headers = {
["Authorization"]: `Bearer ${token}`,
};
}
const negotiateUrl = this.resolveNegotiateUrl(url);
this.logger.log(LogLevel.Debug, `Sending negotiation request: ${negotiateUrl}`);
try {
const response = await this.httpClient.post(negotiateUrl, {
content: "",
headers,
});
return JSON.parse(response.content as string);
if (response.statusCode !== 200) {
throw Error(`Unexpected status code returned from negotiate ${response.statusCode}`);
}
return JSON.parse(response.content as string) as INegotiateResponse;
} catch (e) {
this.logger.log(LogLevel.Error, "Failed to complete negotiation with the server: " + e);
throw e;
}
}
private updateConnectionId(negotiateResponse: INegotiateResponse) {
this.connectionId = negotiateResponse.connectionId;
this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`;
private createConnectUrl(url: string, connectionId: string) {
return url + (url.indexOf("?") === -1 ? "?" : "&") + `id=${connectionId}`;
}
private async createTransport(requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise<void> {
this.updateConnectionId(negotiateResponse);
private async createTransport(url: string, requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat): Promise<void> {
let connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId);
if (this.isITransport(requestedTransport)) {
this.logger.log(LogLevel.Debug, "Connection was provided an instance of ITransport, using that directly.");
this.transport = requestedTransport;
await this.transport.connect(this.url, requestedTransferFormat);
await this.transport.connect(connectUrl, requestedTransferFormat);
// only change the state if we were connecting to not overwrite
// the state if the connection is already marked as Disconnected
@ -193,11 +233,11 @@ export class HttpConnection implements IConnection {
if (typeof transport === "number") {
this.transport = this.constructTransport(transport);
if (negotiateResponse.connectionId === null) {
negotiateResponse = await this.getNegotiationResponse(headers);
this.updateConnectionId(negotiateResponse);
negotiateResponse = await this.getNegotiationResponse(url);
connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId);
}
try {
await this.transport.connect(this.url, requestedTransferFormat);
await this.transport.connect(connectUrl, requestedTransferFormat);
this.changeState(ConnectionState.Connecting, ConnectionState.Connected);
return;
} catch (ex) {
@ -214,11 +254,11 @@ export class HttpConnection implements IConnection {
private constructTransport(transport: HttpTransportType) {
switch (transport) {
case HttpTransportType.WebSockets:
return new WebSocketTransport(this.options.accessTokenFactory, this.logger, this.options.logMessageContent);
return new WebSocketTransport(this.accessTokenFactory, this.logger, this.options.logMessageContent);
case HttpTransportType.ServerSentEvents:
return new ServerSentEventsTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent);
return new ServerSentEventsTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent);
case HttpTransportType.LongPolling:
return new LongPollingTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent);
return new LongPollingTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent);
default:
throw new Error(`Unknown transport: ${transport}.`);
}

View File

@ -18,32 +18,49 @@ Throughout this document, the term `[endpoint-base]` is used to refer to the rou
## `POST [endpoint-base]/negotiate` request
The `POST [endpoint-base]/negotiate` request is used to establish connection between the client and the server. The response to the `POST [endpoint-base]/negotiate` request contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server. The content type of the response is `application/json`. The following is a sample response to the `POST [endpoint-base]/negotiate` request
The `POST [endpoint-base]/negotiate` request is used to establish a connection between the client and the server. The content type of the response is `application/json`. The response to the `POST [endpoint-base]/negotiate` request contains one of two types of responses:
```
{
"connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d",
"availableTransports":[
{
"transport": "WebSockets",
"transferFormats": [ "Text", "Binary" ]
},
{
"transport": "ServerSentEvents",
"transferFormats": [ "Text" ]
},
{
"transport": "LongPolling",
"transferFormats": [ "Text", "Binary" ]
}
]
}
```
1. A response that contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server.
The payload returned from this endpoint provides the following data:
```
{
"connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d",
"availableTransports":[
{
"transport": "WebSockets",
"transferFormats": [ "Text", "Binary" ]
},
{
"transport": "ServerSentEvents",
"transferFormats": [ "Text" ]
},
{
"transport": "LongPolling",
"transferFormats": [ "Text", "Binary" ]
}
]
}
```
* The `connectionId` which is **required** by the Long Polling and Server-Sent Events transports (in order to correlate sends and receives).
* The `availableTransports` list which describes the transports the server supports. For each transport, the name of the transport (`transport`) is listed, as is a list of "transfer formats" supported by the transport (`transferFormats`)
The payload returned from this endpoint provides the following data:
* The `connectionId` which is **required** by the Long Polling and Server-Sent Events transports (in order to correlate sends and receives).
* The `availableTransports` list which describes the transports the server supports. For each transport, the name of the transport (`transport`) is listed, as is a list of "transfer formats" supported by the transport (`transferFormats`)
2. A redirect response which tells the client which URL and optionally access token to use as a result.
```
{
"url": "https://myapp.com/chat",
"accessToken": "accessToken"
}
```
The payload returned from this endpoint provides the following data:
* The `url` which is the URL the client should connect to.
* The `accessToken` which is an optional bearer token for accessing the specified url.
## Transfer Formats

View File

@ -20,6 +20,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
{
public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature
{
// Not configurable on purpose, high enough that if we reach here, it's likely
// a buggy server
private static readonly int _maxRedirects = 100;
private static readonly Task<string> _noAccessToken = Task.FromResult<string>(null);
private static readonly TimeSpan HttpClientTimeout = TimeSpan.FromSeconds(120);
#if !NETCOREAPP2_1
private static readonly Version Windows8Version = new Version(6, 2);
@ -40,6 +45,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
private readonly ConnectionLogScope _logScope;
private readonly IDisposable _scopeDisposable;
private readonly ILoggerFactory _loggerFactory;
private Func<Task<string>> _accessTokenProvider;
public override IDuplexPipe Transport
{
@ -96,7 +102,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
_logger = _loggerFactory.CreateLogger<HttpConnection>();
_httpConnectionOptions = httpConnectionOptions;
if (httpConnectionOptions.Transports != HttpTransportType.WebSockets)
if (!httpConnectionOptions.SkipNegotiation || httpConnectionOptions.Transports != HttpTransportType.WebSockets)
{
_httpClient = CreateHttpClient();
}
@ -210,17 +216,54 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
private async Task SelectAndStartTransport(TransferFormat transferFormat)
{
if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets)
var uri = _httpConnectionOptions.Url;
// Set the initial access token provider back to the original one from options
_accessTokenProvider = _httpConnectionOptions.AccessTokenProvider;
if (_httpConnectionOptions.SkipNegotiation)
{
Log.StartingTransport(_logger, _httpConnectionOptions.Transports, _httpConnectionOptions.Url);
await StartTransport(_httpConnectionOptions.Url, _httpConnectionOptions.Transports, transferFormat);
if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets)
{
Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri);
await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat);
}
else
{
throw new InvalidOperationException("Negotiation can only be skipped when using the WebSocket transport directly.");
}
}
else
{
var negotiationResponse = await GetNegotiationResponse();
NegotiationResponse negotiationResponse;
var redirects = 0;
do
{
negotiationResponse = await GetNegotiationResponseAsync(uri);
if (negotiationResponse.Url != null)
{
uri = new Uri(negotiationResponse.Url);
}
if (negotiationResponse.AccessToken != null)
{
string accessToken = negotiationResponse.AccessToken;
// Set the current access token factory so that future requests use this access token
_accessTokenProvider = () => Task.FromResult(accessToken);
}
redirects++;
}
while (negotiationResponse.Url != null && redirects < _maxRedirects);
if (redirects == _maxRedirects && negotiationResponse.Url != null)
{
throw new InvalidOperationException("Negotiate redirection limit exceeded.");
}
// This should only need to happen once
var connectUrl = CreateConnectUrl(_httpConnectionOptions.Url, negotiationResponse.ConnectionId);
var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId);
// We're going to search for the transfer format as a string because we don't want to parse
// all the transfer formats in the negotiation response, and we want to allow transfer formats
@ -256,8 +299,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
// The negotiation response gets cleared in the fallback scenario.
if (negotiationResponse == null)
{
negotiationResponse = await GetNegotiationResponse();
connectUrl = CreateConnectUrl(_httpConnectionOptions.Url, negotiationResponse.ConnectionId);
negotiationResponse = await GetNegotiationResponseAsync(uri);
connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId);
}
Log.StartingTransport(_logger, transportType, connectUrl);
@ -281,7 +324,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
}
}
private async Task<NegotiationResponse> Negotiate(Uri url, HttpClient httpClient, ILogger logger)
private async Task<NegotiationResponse> NegotiateAsync(Uri url, HttpClient httpClient, ILogger logger)
{
try
{
@ -399,10 +442,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
}
// Apply the authorization header in a handler instead of a default header because it can change with each request
if (_httpConnectionOptions.AccessTokenProvider != null)
{
httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, _httpConnectionOptions.AccessTokenProvider);
}
httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, this);
}
// Wrap message handler after HttpMessageHandlerFactory to ensure not overriden
@ -430,6 +470,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
return httpClient;
}
internal Task<string> GetAccessTokenAsync()
{
if (_accessTokenProvider == null)
{
return _noAccessToken;
}
return _accessTokenProvider();
}
private void CheckDisposed()
{
if (_disposed)
@ -458,9 +507,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
#endif
}
private async Task<NegotiationResponse> GetNegotiationResponse()
private async Task<NegotiationResponse> GetNegotiationResponseAsync(Uri uri)
{
var negotiationResponse = await Negotiate(_httpConnectionOptions.Url, _httpClient, _logger);
var negotiationResponse = await NegotiateAsync(uri, _httpClient, _logger);
_connectionId = negotiationResponse.ConnectionId;
_logScope.ConnectionId = _connectionId;
return negotiationResponse;

View File

@ -52,6 +52,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
public Uri Url { get; set; }
public HttpTransportType Transports { get; set; }
public bool SkipNegotiation { get; set; }
public Func<Task<string>> AccessTokenProvider { get; set; }
public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5);
public ICredentials Credentials { get; set; }

View File

@ -11,17 +11,21 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
internal class AccessTokenHttpMessageHandler : DelegatingHandler
{
private readonly Func<Task<string>> _accessTokenProvider;
private readonly HttpConnection _httpConnection;
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<Task<string>> accessTokenProvider) : base(inner)
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, HttpConnection httpConnection) : base(inner)
{
_accessTokenProvider = accessTokenProvider ?? throw new ArgumentNullException(nameof(accessTokenProvider));
_httpConnection = httpConnection;
}
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
var accessToken = await _accessTokenProvider();
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
var accessToken = await _httpConnection.GetAccessTokenAsync();
if (!string.IsNullOrEmpty(accessToken))
{
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
}
return await base.SendAsync(request, cancellationToken);
}

View File

@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.Http.Connections
public static class NegotiateProtocol
{
private const string ConnectionIdPropertyName = "connectionId";
private const string UrlPropertyName = "url";
private const string AccessTokenPropertyName = "accessToken";
private const string AvailableTransportsPropertyName = "availableTransports";
private const string TransportPropertyName = "transport";
private const string TransferFormatsPropertyName = "transferFormats";
@ -25,8 +27,25 @@ namespace Microsoft.AspNetCore.Http.Connections
using (var jsonWriter = JsonUtils.CreateJsonTextWriter(textWriter))
{
jsonWriter.WriteStartObject();
jsonWriter.WritePropertyName(ConnectionIdPropertyName);
jsonWriter.WriteValue(response.ConnectionId);
if (!string.IsNullOrEmpty(response.Url))
{
jsonWriter.WritePropertyName(UrlPropertyName);
jsonWriter.WriteValue(response.Url);
}
if (!string.IsNullOrEmpty(response.AccessToken))
{
jsonWriter.WritePropertyName(AccessTokenPropertyName);
jsonWriter.WriteValue(response.AccessToken);
}
if (!string.IsNullOrEmpty(response.ConnectionId))
{
jsonWriter.WritePropertyName(ConnectionIdPropertyName);
jsonWriter.WriteValue(response.ConnectionId);
}
jsonWriter.WritePropertyName(AvailableTransportsPropertyName);
jsonWriter.WriteStartArray();
@ -69,6 +88,8 @@ namespace Microsoft.AspNetCore.Http.Connections
JsonUtils.EnsureObjectStart(reader);
string connectionId = null;
string url = null;
string accessToken = null;
List<AvailableTransport> availableTransports = null;
var completed = false;
@ -81,6 +102,12 @@ namespace Microsoft.AspNetCore.Http.Connections
switch (memberName)
{
case UrlPropertyName:
url = JsonUtils.ReadAsString(reader, UrlPropertyName);
break;
case AccessTokenPropertyName:
accessToken = JsonUtils.ReadAsString(reader, AccessTokenPropertyName);
break;
case ConnectionIdPropertyName:
connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName);
break;
@ -114,19 +141,25 @@ namespace Microsoft.AspNetCore.Http.Connections
}
}
if (connectionId == null)
if (url == null)
{
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
}
// if url isn't specified, connectionId and available transports are required
if (connectionId == null)
{
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
}
if (availableTransports == null)
{
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
if (availableTransports == null)
{
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
}
}
return new NegotiationResponse
{
ConnectionId = connectionId,
Url = url,
AccessToken = accessToken,
AvailableTransports = availableTransports
};
}

View File

@ -7,6 +7,8 @@ namespace Microsoft.AspNetCore.Http.Connections
{
public class NegotiationResponse
{
public string Url { get; set; }
public string AccessToken { get; set; }
public string ConnectionId { get; set; }
public IList<AvailableTransport> AvailableTransports { get; set; }
}

View File

@ -9,21 +9,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
public class NegotiateProtocolTests
{
[Theory]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0])]
[InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0])]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new [] { "test"})]
public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports)
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null)]
[InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null)]
[InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null)]
[InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token")]
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null)]
public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken)
{
var responseData = Encoding.UTF8.GetBytes(json);
var ms = new MemoryStream(responseData);
var response = NegotiateProtocol.ParseResponse(ms);
Assert.Equal(connectionId, response.ConnectionId);
Assert.Equal(availableTransports.Length, response.AvailableTransports.Count);
Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count);
Assert.Equal(url, response.Url);
Assert.Equal(accessToken, response.AccessToken);
var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList();
if (response.AvailableTransports != null)
{
var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList();
Assert.Equal(availableTransports, responseTransports);
Assert.Equal(availableTransports, responseTransports);
}
}
[Theory]

View File

@ -4,6 +4,7 @@
using System;
using System.IO;
using System.Net;
using System.Net.Http;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Http.Connections;
@ -70,7 +71,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
});
using (var noErrorScope = new VerifyNoErrorsScope())
{
{
await WithConnectionAsync(
CreateConnection(testHttpHandler, url: requestedUrl, loggerFactory: noErrorScope.LoggerFactory),
async (connection) =>
@ -82,6 +83,173 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
Assert.Equal(expectedNegotiate, await negotiateUrlTcs.Task.OrTimeout());
}
[Fact]
public async Task NegotiateThatReturnsUrlGetFollowed()
{
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
var firstNegotiate = true;
testHttpHandler.OnNegotiate((request, cancellationToken) =>
{
if (firstNegotiate)
{
firstNegotiate = false;
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
JsonConvert.SerializeObject(new
{
url = "https://another.domain.url/chat"
}));
}
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
JsonConvert.SerializeObject(new
{
connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00",
availableTransports = new object[]
{
new
{
transport = "LongPolling",
transferFormats = new[] { "Text" }
},
}
}));
});
testHttpHandler.OnLongPoll((token) =>
{
var tcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)));
return tcs.Task;
});
testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted));
using (var noErrorScope = new VerifyNoErrorsScope())
{
await WithConnectionAsync(
CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory),
async (connection) =>
{
await connection.StartAsync(TransferFormat.Text).OrTimeout();
});
}
Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString());
Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString());
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString());
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString());
Assert.Equal(5, testHttpHandler.ReceivedRequests.Count);
}
[Fact]
public async Task NegotiateThatReturnsRedirectUrlForeverThrowsAfter100Tries()
{
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
testHttpHandler.OnNegotiate((request, cancellationToken) =>
{
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
JsonConvert.SerializeObject(new
{
url = "https://another.domain.url/chat"
}));
});
using (var noErrorScope = new VerifyNoErrorsScope())
{
await WithConnectionAsync(
CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory),
async (connection) =>
{
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => connection.StartAsync(TransferFormat.Text).OrTimeout());
Assert.Equal("Negotiate redirection limit exceeded.", exception.Message);
});
}
}
[Fact]
public async Task NegotiateThatReturnsUrlGetFollowedWithAccessToken()
{
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
var firstNegotiate = true;
testHttpHandler.OnNegotiate((request, cancellationToken) =>
{
if (firstNegotiate)
{
firstNegotiate = false;
// The first negotiate requires an access token
if (request.Headers.Authorization?.Parameter != "firstSecret")
{
return ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized);
}
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
JsonConvert.SerializeObject(new
{
url = "https://another.domain.url/chat",
accessToken = "secondSecret"
}));
}
// All other requests require an access token
if (request.Headers.Authorization?.Parameter != "secondSecret")
{
return ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized);
}
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
JsonConvert.SerializeObject(new
{
connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00",
availableTransports = new object[]
{
new
{
transport = "LongPolling",
transferFormats = new[] { "Text" }
},
}
}));
});
testHttpHandler.OnLongPoll((request, token) =>
{
// All other requests require an access token
if (request.Headers.Authorization?.Parameter != "secondSecret")
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized));
}
var tcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)));
return tcs.Task;
});
testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted));
Task<string> AccessTokenProvider() => Task.FromResult<string>("firstSecret");
using (var noErrorScope = new VerifyNoErrorsScope())
{
await WithConnectionAsync(
CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory, accessTokenProvider: AccessTokenProvider),
async (connection) =>
{
await connection.StartAsync(TransferFormat.Text).OrTimeout();
});
}
Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString());
Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString());
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString());
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString());
// Delete request
Assert.Equal(5, testHttpHandler.ReceivedRequests.Count);
}
[Fact]
public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand()
{

View File

@ -185,12 +185,22 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public void OnLongPoll(Func<CancellationToken, HttpResponseMessage> handler) => OnLongPoll(cancellationToken => Task.FromResult(handler(cancellationToken)));
public void OnLongPoll(Func<CancellationToken, Task<HttpResponseMessage>> handler)
{
OnLongPoll((request, token) => handler(token));
}
public void OnLongPoll(Func<HttpRequestMessage, CancellationToken, HttpResponseMessage> handler)
{
OnLongPoll((request, token) => Task.FromResult(handler(request, token)));
}
public void OnLongPoll(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
{
OnRequest((request, next, cancellationToken) =>
{
if (ResponseUtils.IsLongPollRequest(request))
{
return handler(cancellationToken);
return handler(request, cancellationToken);
}
else
{

View File

@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[ConditionalFact]
[WebSocketsSupportedCondition]
public async Task HttpRequestsNotSentWhenWebSocketsTransportRequested()
public async Task HttpRequestsNotSentWhenWebSocketsTransportRequestedAndSkipNegotiationSet()
{
using (StartVerifableLog(out var loggerFactory))
{
@ -184,6 +184,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
Url = new Uri(url),
Transports = HttpTransportType.WebSockets,
SkipNegotiation = true,
HttpMessageHandlerFactory = (httpMessageHandler) => mockHttpHandler.Object
};
@ -213,6 +214,49 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task HttpConnectionThrowsIfSkipNegotiationSetAndTransportIsNotWebSockets(HttpTransportType transportType)
{
using (StartVerifableLog(out var loggerFactory))
{
var logger = loggerFactory.CreateLogger<EndToEndTests>();
var url = ServerFixture.Url + "/echo";
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(
(request, cancellationToken) => Task.FromException<HttpResponseMessage>(new InvalidOperationException("HTTP requests should not be sent.")));
var httpOptions = new HttpConnectionOptions
{
Url = new Uri(url),
Transports = transportType,
SkipNegotiation = true,
HttpMessageHandlerFactory = (httpMessageHandler) => mockHttpHandler.Object
};
var connection = new HttpConnection(httpOptions, loggerFactory);
try
{
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => connection.StartAsync(TransferFormat.Binary).OrTimeout());
Assert.Equal("Negotiation can only be skipped when using the WebSocket transport directly.", exception.Message);
}
catch (Exception ex)
{
logger.LogInformation(ex, "Test threw exception");
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(TransportTypesAndTransferFormats))]
public async Task ConnectionCanSendAndReceiveMessages(HttpTransportType transportType, TransferFormat requestedTransferFormat)
@ -326,7 +370,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorStartingTransport";
writeContext.EventId.Name == "ErrorWithNegotiation";
}
using (StartVerifableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors))
@ -336,24 +380,37 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var url = ServerFixture.Url + "/auth";
var connection = new HttpConnection(new Uri(url), HttpTransportType.WebSockets, loggerFactory);
try
var exception = await Assert.ThrowsAsync<HttpRequestException>(() => connection.StartAsync(TransferFormat.Binary).OrTimeout());
Assert.Contains("401", exception.Message);
}
}
[ConditionalFact]
[WebSocketsSupportedCondition]
public async Task UnauthorizedDirectWebSocketsConnectionDoesNotConnect()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorStartingTransport";
}
using (StartVerifableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors))
{
var logger = loggerFactory.CreateLogger<EndToEndTests>();
var url = ServerFixture.Url + "/auth";
var options = new HttpConnectionOptions
{
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync(TransferFormat.Binary).OrTimeout();
Assert.True(false);
}
catch (WebSocketException) { }
catch (Exception ex)
{
logger.LogInformation(ex, "Test threw exception");
throw;
}
finally
{
logger.LogInformation("Disposing Connection");
await connection.DisposeAsync().OrTimeout();
logger.LogInformation("Disposed Connection");
}
Url = new Uri(url),
Transports = HttpTransportType.WebSockets,
SkipNegotiation = true
};
var connection = new HttpConnection(options, loggerFactory);
await Assert.ThrowsAsync<WebSocketException>(() => connection.StartAsync(TransferFormat.Binary).OrTimeout());
}
}