Added support for negotiate response to redirect the client to another SignalR endpoint (#2070)
- Add a SkipNegotiation flag to the .NET and ts client to allow skipping the negotiation phase. Don't infer it based on the transport type. - Updated the negotiate protocol to support returning a redirect url - Added support to .NET client to handle redirect negotiations - Handle poorly written endpoints that sends infinite redirects - Added access token support and an infinite redirect guard - Add delete handler for stopping the transport
This commit is contained in:
parent
d9272032d0
commit
903fe1e902
|
|
@ -267,10 +267,11 @@ describe("HttpConnection", () => {
|
|||
}
|
||||
});
|
||||
|
||||
it("does not send negotiate request if WebSockets transport requested explicitly", async (done) => {
|
||||
it("does not send negotiate request if WebSockets transport requested explicitly and skipNegotiation is true", async (done) => {
|
||||
const options: IHttpConnectionOptions = {
|
||||
...commonOptions,
|
||||
httpClient: new TestHttpClient(),
|
||||
skipNegotiation: true,
|
||||
transport: HttpTransportType.WebSockets,
|
||||
} as IHttpConnectionOptions;
|
||||
|
||||
|
|
@ -287,6 +288,150 @@ describe("HttpConnection", () => {
|
|||
}
|
||||
});
|
||||
|
||||
it("does not start non WebSockets transport requested explicitly and skipNegotiation is true", async (done) => {
|
||||
const options: IHttpConnectionOptions = {
|
||||
...commonOptions,
|
||||
httpClient: new TestHttpClient(),
|
||||
skipNegotiation: true,
|
||||
transport: HttpTransportType.LongPolling,
|
||||
} as IHttpConnectionOptions;
|
||||
|
||||
const connection = new HttpConnection("http://tempuri.org", options);
|
||||
try {
|
||||
await connection.start(TransferFormat.Text);
|
||||
fail();
|
||||
done();
|
||||
} catch (e) {
|
||||
// WebSocket is created when the transport is connecting which happens after
|
||||
// negotiate request would be sent. No better/easier way to test this.
|
||||
expect(e.message).toBe("Negotiation can only be skipped when using the WebSocket transport directly.");
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
it("redirects to url when negotiate returns it", async (done) => {
|
||||
let firstNegotiate = true;
|
||||
let firstPoll = true;
|
||||
const httpClient = new TestHttpClient()
|
||||
.on("POST", /negotiate$/, (r) => {
|
||||
if (firstNegotiate) {
|
||||
firstNegotiate = false;
|
||||
return { url: "https://another.domain.url/chat" };
|
||||
}
|
||||
return {
|
||||
availableTransports: [{ transport: "LongPolling", transferFormats: ["Text"] }],
|
||||
connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00",
|
||||
};
|
||||
})
|
||||
.on("GET", (r) => {
|
||||
if (firstPoll) {
|
||||
firstPoll = false;
|
||||
return "";
|
||||
}
|
||||
return new HttpResponse(204, "No Content", "");
|
||||
});
|
||||
|
||||
const options: IHttpConnectionOptions = {
|
||||
...commonOptions,
|
||||
httpClient,
|
||||
transport: HttpTransportType.LongPolling,
|
||||
} as IHttpConnectionOptions;
|
||||
|
||||
try {
|
||||
const connection = new HttpConnection("http://tempuri.org", options);
|
||||
await connection.start(TransferFormat.Text);
|
||||
} catch (e) {
|
||||
fail(e);
|
||||
done();
|
||||
}
|
||||
|
||||
expect(httpClient.sentRequests.length).toBe(4);
|
||||
expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate");
|
||||
expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate");
|
||||
expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
|
||||
expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
|
||||
done();
|
||||
});
|
||||
|
||||
it("fails to start if negotiate redirects more than 100 times", async (done) => {
|
||||
const httpClient = new TestHttpClient()
|
||||
.on("POST", /negotiate$/, (r) => ({ url: "https://another.domain.url/chat" }));
|
||||
|
||||
const options: IHttpConnectionOptions = {
|
||||
...commonOptions,
|
||||
httpClient,
|
||||
transport: HttpTransportType.LongPolling,
|
||||
} as IHttpConnectionOptions;
|
||||
|
||||
try {
|
||||
const connection = new HttpConnection("http://tempuri.org", options);
|
||||
await connection.start(TransferFormat.Text);
|
||||
fail();
|
||||
} catch (e) {
|
||||
expect(e.message).toBe("Negotiate redirection limit exceeded.");
|
||||
done();
|
||||
}
|
||||
});
|
||||
|
||||
it("redirects to url when negotiate returns it with access token", async (done) => {
|
||||
let firstNegotiate = true;
|
||||
let firstPoll = true;
|
||||
const httpClient = new TestHttpClient()
|
||||
.on("POST", /negotiate$/, (r) => {
|
||||
if (firstNegotiate) {
|
||||
firstNegotiate = false;
|
||||
|
||||
if (r.headers && r.headers.Authorization !== "Bearer firstSecret") {
|
||||
return new HttpResponse(401, "Unauthorized", "");
|
||||
}
|
||||
|
||||
return { url: "https://another.domain.url/chat", accessToken: "secondSecret" };
|
||||
}
|
||||
|
||||
if (r.headers && r.headers.Authorization !== "Bearer secondSecret") {
|
||||
return new HttpResponse(401, "Unauthorized", "");
|
||||
}
|
||||
|
||||
return {
|
||||
availableTransports: [{ transport: "LongPolling", transferFormats: ["Text"] }],
|
||||
connectionId: "0rge0d00-0040-0030-0r00-000q00r00e00",
|
||||
};
|
||||
})
|
||||
.on("GET", (r) => {
|
||||
if (r.headers && r.headers.Authorization !== "Bearer secondSecret") {
|
||||
return new HttpResponse(401, "Unauthorized", "");
|
||||
}
|
||||
|
||||
if (firstPoll) {
|
||||
firstPoll = false;
|
||||
return "";
|
||||
}
|
||||
return new HttpResponse(204, "No Content", "");
|
||||
});
|
||||
|
||||
const options: IHttpConnectionOptions = {
|
||||
...commonOptions,
|
||||
accessTokenFactory: () => "firstSecret",
|
||||
httpClient,
|
||||
transport: HttpTransportType.LongPolling,
|
||||
} as IHttpConnectionOptions;
|
||||
|
||||
try {
|
||||
const connection = new HttpConnection("http://tempuri.org", options);
|
||||
await connection.start(TransferFormat.Text);
|
||||
} catch (e) {
|
||||
fail(e);
|
||||
done();
|
||||
}
|
||||
|
||||
expect(httpClient.sentRequests.length).toBe(4);
|
||||
expect(httpClient.sentRequests[0].url).toBe("http://tempuri.org/negotiate");
|
||||
expect(httpClient.sentRequests[1].url).toBe("https://another.domain.url/chat/negotiate");
|
||||
expect(httpClient.sentRequests[2].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
|
||||
expect(httpClient.sentRequests[3].url).toMatch(/^https:\/\/another\.domain\.url\/chat\?id=0rge0d00-0040-0030-0r00-000q00r00e00/i);
|
||||
done();
|
||||
});
|
||||
|
||||
it("authorization header removed when token factory returns null and using LongPolling", async (done) => {
|
||||
const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] };
|
||||
|
||||
|
|
|
|||
|
|
@ -8,15 +8,18 @@ export type TestHttpHandler = (request: HttpRequest, next?: (request: HttpReques
|
|||
|
||||
export class TestHttpClient extends HttpClient {
|
||||
private handler: (request: HttpRequest) => Promise<HttpResponse>;
|
||||
public sentRequests: HttpRequest[];
|
||||
|
||||
constructor() {
|
||||
super();
|
||||
this.sentRequests = [];
|
||||
this.handler = (request: HttpRequest) =>
|
||||
Promise.reject(`Request has no handler: ${request.method} ${request.url}`);
|
||||
|
||||
}
|
||||
|
||||
public send(request: HttpRequest): Promise<HttpResponse> {
|
||||
this.sentRequests.push(request);
|
||||
return this.handler(request);
|
||||
}
|
||||
|
||||
|
|
@ -59,7 +62,7 @@ export class TestHttpClient extends HttpClient {
|
|||
if (typeof val === "string") {
|
||||
// string payload
|
||||
return new HttpResponse(200, "OK", val);
|
||||
} else if(typeof val === "object" && val.statusCode) {
|
||||
} else if (typeof val === "object" && val.statusCode) {
|
||||
// HttpResponse payload
|
||||
return val as HttpResponse;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ export interface IHttpConnectionOptions {
|
|||
logger?: ILogger | LogLevel;
|
||||
accessTokenFactory?: () => string | Promise<string>;
|
||||
logMessageContent?: boolean;
|
||||
skipNegotiation?: boolean;
|
||||
}
|
||||
|
||||
const enum ConnectionState {
|
||||
|
|
@ -27,6 +28,8 @@ const enum ConnectionState {
|
|||
interface INegotiateResponse {
|
||||
connectionId: string;
|
||||
availableTransports: IAvailableTransport[];
|
||||
url: string;
|
||||
accessToken: string;
|
||||
}
|
||||
|
||||
interface IAvailableTransport {
|
||||
|
|
@ -34,17 +37,18 @@ interface IAvailableTransport {
|
|||
transferFormats: Array<keyof typeof TransferFormat>;
|
||||
}
|
||||
|
||||
const MAX_REDIRECTS = 100;
|
||||
|
||||
export class HttpConnection implements IConnection {
|
||||
private connectionState: ConnectionState;
|
||||
private baseUrl: string;
|
||||
private url: string;
|
||||
private readonly httpClient: HttpClient;
|
||||
private readonly logger: ILogger;
|
||||
private readonly options: IHttpConnectionOptions;
|
||||
private transport: ITransport;
|
||||
private connectionId: string;
|
||||
private startPromise: Promise<void>;
|
||||
private stopError?: Error;
|
||||
private accessTokenFactory?: () => string | Promise<string>;
|
||||
|
||||
public readonly features: any = {};
|
||||
public onreceive: (data: string | ArrayBuffer) => void;
|
||||
|
|
@ -110,29 +114,53 @@ export class HttpConnection implements IConnection {
|
|||
}
|
||||
|
||||
private async startInternal(transferFormat: TransferFormat): Promise<void> {
|
||||
// Store the original base url and the access token factory since they may change
|
||||
// as part of negotiating
|
||||
let url = this.baseUrl;
|
||||
this.accessTokenFactory = this.options.accessTokenFactory;
|
||||
|
||||
try {
|
||||
if (this.options.transport === HttpTransportType.WebSockets) {
|
||||
// No need to add a connection ID in this case
|
||||
this.url = this.baseUrl;
|
||||
this.transport = this.constructTransport(HttpTransportType.WebSockets);
|
||||
// We should just call connect directly in this case.
|
||||
// No fallback or negotiate in this case.
|
||||
await this.transport.connect(this.url, transferFormat);
|
||||
if (this.options.skipNegotiation) {
|
||||
if (this.options.transport === HttpTransportType.WebSockets) {
|
||||
// No need to add a connection ID in this case
|
||||
this.transport = this.constructTransport(HttpTransportType.WebSockets);
|
||||
// We should just call connect directly in this case.
|
||||
// No fallback or negotiate in this case.
|
||||
await this.transport.connect(url, transferFormat);
|
||||
} else {
|
||||
throw Error("Negotiation can only be skipped when using the WebSocket transport directly.");
|
||||
}
|
||||
} else {
|
||||
const token = await this.options.accessTokenFactory();
|
||||
let headers;
|
||||
if (token) {
|
||||
headers = {
|
||||
["Authorization"]: `Bearer ${token}`,
|
||||
};
|
||||
let negotiateResponse: INegotiateResponse = null;
|
||||
let redirects = 0;
|
||||
|
||||
do {
|
||||
negotiateResponse = await this.getNegotiationResponse(url);
|
||||
// the user tries to stop the connection when it is being started
|
||||
if (this.connectionState === ConnectionState.Disconnected) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (negotiateResponse.url) {
|
||||
url = negotiateResponse.url;
|
||||
}
|
||||
|
||||
if (negotiateResponse.accessToken) {
|
||||
// Replace the current access token factory with one that uses
|
||||
// the returned access token
|
||||
const accessToken = negotiateResponse.accessToken;
|
||||
this.accessTokenFactory = () => accessToken;
|
||||
}
|
||||
|
||||
redirects++;
|
||||
}
|
||||
while (negotiateResponse.url && redirects < MAX_REDIRECTS);
|
||||
|
||||
if (redirects === MAX_REDIRECTS && negotiateResponse.url) {
|
||||
throw Error("Negotiate redirection limit exceeded.");
|
||||
}
|
||||
|
||||
const negotiateResponse = await this.getNegotiationResponse(headers);
|
||||
// the user tries to stop the the connection when it is being started
|
||||
if (this.connectionState === ConnectionState.Disconnected) {
|
||||
return;
|
||||
}
|
||||
await this.createTransport(this.options.transport, negotiateResponse, transferFormat, headers);
|
||||
await this.createTransport(url, this.options.transport, negotiateResponse, transferFormat);
|
||||
}
|
||||
|
||||
if (this.transport instanceof LongPollingTransport) {
|
||||
|
|
@ -153,32 +181,44 @@ export class HttpConnection implements IConnection {
|
|||
}
|
||||
}
|
||||
|
||||
private async getNegotiationResponse(headers: any): Promise<INegotiateResponse> {
|
||||
const negotiateUrl = this.resolveNegotiateUrl(this.baseUrl);
|
||||
private async getNegotiationResponse(url: string): Promise<INegotiateResponse> {
|
||||
const token = await this.accessTokenFactory();
|
||||
let headers;
|
||||
if (token) {
|
||||
headers = {
|
||||
["Authorization"]: `Bearer ${token}`,
|
||||
};
|
||||
}
|
||||
|
||||
const negotiateUrl = this.resolveNegotiateUrl(url);
|
||||
this.logger.log(LogLevel.Debug, `Sending negotiation request: ${negotiateUrl}`);
|
||||
try {
|
||||
const response = await this.httpClient.post(negotiateUrl, {
|
||||
content: "",
|
||||
headers,
|
||||
});
|
||||
return JSON.parse(response.content as string);
|
||||
|
||||
if (response.statusCode !== 200) {
|
||||
throw Error(`Unexpected status code returned from negotiate ${response.statusCode}`);
|
||||
}
|
||||
|
||||
return JSON.parse(response.content as string) as INegotiateResponse;
|
||||
} catch (e) {
|
||||
this.logger.log(LogLevel.Error, "Failed to complete negotiation with the server: " + e);
|
||||
throw e;
|
||||
}
|
||||
}
|
||||
|
||||
private updateConnectionId(negotiateResponse: INegotiateResponse) {
|
||||
this.connectionId = negotiateResponse.connectionId;
|
||||
this.url = this.baseUrl + (this.baseUrl.indexOf("?") === -1 ? "?" : "&") + `id=${this.connectionId}`;
|
||||
private createConnectUrl(url: string, connectionId: string) {
|
||||
return url + (url.indexOf("?") === -1 ? "?" : "&") + `id=${connectionId}`;
|
||||
}
|
||||
|
||||
private async createTransport(requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat, headers: any): Promise<void> {
|
||||
this.updateConnectionId(negotiateResponse);
|
||||
private async createTransport(url: string, requestedTransport: HttpTransportType | ITransport, negotiateResponse: INegotiateResponse, requestedTransferFormat: TransferFormat): Promise<void> {
|
||||
let connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId);
|
||||
if (this.isITransport(requestedTransport)) {
|
||||
this.logger.log(LogLevel.Debug, "Connection was provided an instance of ITransport, using that directly.");
|
||||
this.transport = requestedTransport;
|
||||
await this.transport.connect(this.url, requestedTransferFormat);
|
||||
await this.transport.connect(connectUrl, requestedTransferFormat);
|
||||
|
||||
// only change the state if we were connecting to not overwrite
|
||||
// the state if the connection is already marked as Disconnected
|
||||
|
|
@ -193,11 +233,11 @@ export class HttpConnection implements IConnection {
|
|||
if (typeof transport === "number") {
|
||||
this.transport = this.constructTransport(transport);
|
||||
if (negotiateResponse.connectionId === null) {
|
||||
negotiateResponse = await this.getNegotiationResponse(headers);
|
||||
this.updateConnectionId(negotiateResponse);
|
||||
negotiateResponse = await this.getNegotiationResponse(url);
|
||||
connectUrl = this.createConnectUrl(url, negotiateResponse.connectionId);
|
||||
}
|
||||
try {
|
||||
await this.transport.connect(this.url, requestedTransferFormat);
|
||||
await this.transport.connect(connectUrl, requestedTransferFormat);
|
||||
this.changeState(ConnectionState.Connecting, ConnectionState.Connected);
|
||||
return;
|
||||
} catch (ex) {
|
||||
|
|
@ -214,11 +254,11 @@ export class HttpConnection implements IConnection {
|
|||
private constructTransport(transport: HttpTransportType) {
|
||||
switch (transport) {
|
||||
case HttpTransportType.WebSockets:
|
||||
return new WebSocketTransport(this.options.accessTokenFactory, this.logger, this.options.logMessageContent);
|
||||
return new WebSocketTransport(this.accessTokenFactory, this.logger, this.options.logMessageContent);
|
||||
case HttpTransportType.ServerSentEvents:
|
||||
return new ServerSentEventsTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent);
|
||||
return new ServerSentEventsTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent);
|
||||
case HttpTransportType.LongPolling:
|
||||
return new LongPollingTransport(this.httpClient, this.options.accessTokenFactory, this.logger, this.options.logMessageContent);
|
||||
return new LongPollingTransport(this.httpClient, this.accessTokenFactory, this.logger, this.options.logMessageContent);
|
||||
default:
|
||||
throw new Error(`Unknown transport: ${transport}.`);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,32 +18,49 @@ Throughout this document, the term `[endpoint-base]` is used to refer to the rou
|
|||
|
||||
## `POST [endpoint-base]/negotiate` request
|
||||
|
||||
The `POST [endpoint-base]/negotiate` request is used to establish connection between the client and the server. The response to the `POST [endpoint-base]/negotiate` request contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server. The content type of the response is `application/json`. The following is a sample response to the `POST [endpoint-base]/negotiate` request
|
||||
The `POST [endpoint-base]/negotiate` request is used to establish a connection between the client and the server. The content type of the response is `application/json`. The response to the `POST [endpoint-base]/negotiate` request contains one of two types of responses:
|
||||
|
||||
```
|
||||
{
|
||||
"connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d",
|
||||
"availableTransports":[
|
||||
{
|
||||
"transport": "WebSockets",
|
||||
"transferFormats": [ "Text", "Binary" ]
|
||||
},
|
||||
{
|
||||
"transport": "ServerSentEvents",
|
||||
"transferFormats": [ "Text" ]
|
||||
},
|
||||
{
|
||||
"transport": "LongPolling",
|
||||
"transferFormats": [ "Text", "Binary" ]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
1. A response that contains the `connectionId` which will be used to identify the connection on the server and the list of the transports supported by the server.
|
||||
|
||||
The payload returned from this endpoint provides the following data:
|
||||
```
|
||||
{
|
||||
"connectionId":"807809a5-31bf-470d-9e23-afaee35d8a0d",
|
||||
"availableTransports":[
|
||||
{
|
||||
"transport": "WebSockets",
|
||||
"transferFormats": [ "Text", "Binary" ]
|
||||
},
|
||||
{
|
||||
"transport": "ServerSentEvents",
|
||||
"transferFormats": [ "Text" ]
|
||||
},
|
||||
{
|
||||
"transport": "LongPolling",
|
||||
"transferFormats": [ "Text", "Binary" ]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
* The `connectionId` which is **required** by the Long Polling and Server-Sent Events transports (in order to correlate sends and receives).
|
||||
* The `availableTransports` list which describes the transports the server supports. For each transport, the name of the transport (`transport`) is listed, as is a list of "transfer formats" supported by the transport (`transferFormats`)
|
||||
The payload returned from this endpoint provides the following data:
|
||||
|
||||
* The `connectionId` which is **required** by the Long Polling and Server-Sent Events transports (in order to correlate sends and receives).
|
||||
* The `availableTransports` list which describes the transports the server supports. For each transport, the name of the transport (`transport`) is listed, as is a list of "transfer formats" supported by the transport (`transferFormats`)
|
||||
|
||||
|
||||
2. A redirect response which tells the client which URL and optionally access token to use as a result.
|
||||
|
||||
```
|
||||
{
|
||||
"url": "https://myapp.com/chat",
|
||||
"accessToken": "accessToken"
|
||||
}
|
||||
```
|
||||
|
||||
The payload returned from this endpoint provides the following data:
|
||||
|
||||
* The `url` which is the URL the client should connect to.
|
||||
* The `accessToken` which is an optional bearer token for accessing the specified url.
|
||||
|
||||
## Transfer Formats
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,11 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
{
|
||||
public partial class HttpConnection : ConnectionContext, IConnectionInherentKeepAliveFeature
|
||||
{
|
||||
// Not configurable on purpose, high enough that if we reach here, it's likely
|
||||
// a buggy server
|
||||
private static readonly int _maxRedirects = 100;
|
||||
private static readonly Task<string> _noAccessToken = Task.FromResult<string>(null);
|
||||
|
||||
private static readonly TimeSpan HttpClientTimeout = TimeSpan.FromSeconds(120);
|
||||
#if !NETCOREAPP2_1
|
||||
private static readonly Version Windows8Version = new Version(6, 2);
|
||||
|
|
@ -40,6 +45,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
private readonly ConnectionLogScope _logScope;
|
||||
private readonly IDisposable _scopeDisposable;
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private Func<Task<string>> _accessTokenProvider;
|
||||
|
||||
public override IDuplexPipe Transport
|
||||
{
|
||||
|
|
@ -96,7 +102,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
_logger = _loggerFactory.CreateLogger<HttpConnection>();
|
||||
_httpConnectionOptions = httpConnectionOptions;
|
||||
|
||||
if (httpConnectionOptions.Transports != HttpTransportType.WebSockets)
|
||||
if (!httpConnectionOptions.SkipNegotiation || httpConnectionOptions.Transports != HttpTransportType.WebSockets)
|
||||
{
|
||||
_httpClient = CreateHttpClient();
|
||||
}
|
||||
|
|
@ -210,17 +216,54 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
|
||||
private async Task SelectAndStartTransport(TransferFormat transferFormat)
|
||||
{
|
||||
if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets)
|
||||
var uri = _httpConnectionOptions.Url;
|
||||
// Set the initial access token provider back to the original one from options
|
||||
_accessTokenProvider = _httpConnectionOptions.AccessTokenProvider;
|
||||
|
||||
if (_httpConnectionOptions.SkipNegotiation)
|
||||
{
|
||||
Log.StartingTransport(_logger, _httpConnectionOptions.Transports, _httpConnectionOptions.Url);
|
||||
await StartTransport(_httpConnectionOptions.Url, _httpConnectionOptions.Transports, transferFormat);
|
||||
if (_httpConnectionOptions.Transports == HttpTransportType.WebSockets)
|
||||
{
|
||||
Log.StartingTransport(_logger, _httpConnectionOptions.Transports, uri);
|
||||
await StartTransport(uri, _httpConnectionOptions.Transports, transferFormat);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new InvalidOperationException("Negotiation can only be skipped when using the WebSocket transport directly.");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
var negotiationResponse = await GetNegotiationResponse();
|
||||
NegotiationResponse negotiationResponse;
|
||||
var redirects = 0;
|
||||
|
||||
do
|
||||
{
|
||||
negotiationResponse = await GetNegotiationResponseAsync(uri);
|
||||
|
||||
if (negotiationResponse.Url != null)
|
||||
{
|
||||
uri = new Uri(negotiationResponse.Url);
|
||||
}
|
||||
|
||||
if (negotiationResponse.AccessToken != null)
|
||||
{
|
||||
string accessToken = negotiationResponse.AccessToken;
|
||||
// Set the current access token factory so that future requests use this access token
|
||||
_accessTokenProvider = () => Task.FromResult(accessToken);
|
||||
}
|
||||
|
||||
redirects++;
|
||||
}
|
||||
while (negotiationResponse.Url != null && redirects < _maxRedirects);
|
||||
|
||||
if (redirects == _maxRedirects && negotiationResponse.Url != null)
|
||||
{
|
||||
throw new InvalidOperationException("Negotiate redirection limit exceeded.");
|
||||
}
|
||||
|
||||
// This should only need to happen once
|
||||
var connectUrl = CreateConnectUrl(_httpConnectionOptions.Url, negotiationResponse.ConnectionId);
|
||||
var connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId);
|
||||
|
||||
// We're going to search for the transfer format as a string because we don't want to parse
|
||||
// all the transfer formats in the negotiation response, and we want to allow transfer formats
|
||||
|
|
@ -256,8 +299,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
// The negotiation response gets cleared in the fallback scenario.
|
||||
if (negotiationResponse == null)
|
||||
{
|
||||
negotiationResponse = await GetNegotiationResponse();
|
||||
connectUrl = CreateConnectUrl(_httpConnectionOptions.Url, negotiationResponse.ConnectionId);
|
||||
negotiationResponse = await GetNegotiationResponseAsync(uri);
|
||||
connectUrl = CreateConnectUrl(uri, negotiationResponse.ConnectionId);
|
||||
}
|
||||
|
||||
Log.StartingTransport(_logger, transportType, connectUrl);
|
||||
|
|
@ -281,7 +324,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
}
|
||||
}
|
||||
|
||||
private async Task<NegotiationResponse> Negotiate(Uri url, HttpClient httpClient, ILogger logger)
|
||||
private async Task<NegotiationResponse> NegotiateAsync(Uri url, HttpClient httpClient, ILogger logger)
|
||||
{
|
||||
try
|
||||
{
|
||||
|
|
@ -399,10 +442,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
}
|
||||
|
||||
// Apply the authorization header in a handler instead of a default header because it can change with each request
|
||||
if (_httpConnectionOptions.AccessTokenProvider != null)
|
||||
{
|
||||
httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, _httpConnectionOptions.AccessTokenProvider);
|
||||
}
|
||||
httpMessageHandler = new AccessTokenHttpMessageHandler(httpMessageHandler, this);
|
||||
}
|
||||
|
||||
// Wrap message handler after HttpMessageHandlerFactory to ensure not overriden
|
||||
|
|
@ -430,6 +470,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
return httpClient;
|
||||
}
|
||||
|
||||
internal Task<string> GetAccessTokenAsync()
|
||||
{
|
||||
if (_accessTokenProvider == null)
|
||||
{
|
||||
return _noAccessToken;
|
||||
}
|
||||
return _accessTokenProvider();
|
||||
}
|
||||
|
||||
private void CheckDisposed()
|
||||
{
|
||||
if (_disposed)
|
||||
|
|
@ -458,9 +507,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
#endif
|
||||
}
|
||||
|
||||
private async Task<NegotiationResponse> GetNegotiationResponse()
|
||||
private async Task<NegotiationResponse> GetNegotiationResponseAsync(Uri uri)
|
||||
{
|
||||
var negotiationResponse = await Negotiate(_httpConnectionOptions.Url, _httpClient, _logger);
|
||||
var negotiationResponse = await NegotiateAsync(uri, _httpClient, _logger);
|
||||
_connectionId = negotiationResponse.ConnectionId;
|
||||
_logScope.ConnectionId = _connectionId;
|
||||
return negotiationResponse;
|
||||
|
|
|
|||
|
|
@ -52,6 +52,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client
|
|||
|
||||
public Uri Url { get; set; }
|
||||
public HttpTransportType Transports { get; set; }
|
||||
|
||||
public bool SkipNegotiation { get; set; }
|
||||
|
||||
public Func<Task<string>> AccessTokenProvider { get; set; }
|
||||
public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5);
|
||||
public ICredentials Credentials { get; set; }
|
||||
|
|
|
|||
|
|
@ -11,17 +11,21 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
|
|||
{
|
||||
internal class AccessTokenHttpMessageHandler : DelegatingHandler
|
||||
{
|
||||
private readonly Func<Task<string>> _accessTokenProvider;
|
||||
private readonly HttpConnection _httpConnection;
|
||||
|
||||
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, Func<Task<string>> accessTokenProvider) : base(inner)
|
||||
public AccessTokenHttpMessageHandler(HttpMessageHandler inner, HttpConnection httpConnection) : base(inner)
|
||||
{
|
||||
_accessTokenProvider = accessTokenProvider ?? throw new ArgumentNullException(nameof(accessTokenProvider));
|
||||
_httpConnection = httpConnection;
|
||||
}
|
||||
|
||||
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
|
||||
{
|
||||
var accessToken = await _accessTokenProvider();
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
|
||||
var accessToken = await _httpConnection.GetAccessTokenAsync();
|
||||
|
||||
if (!string.IsNullOrEmpty(accessToken))
|
||||
{
|
||||
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken);
|
||||
}
|
||||
|
||||
return await base.SendAsync(request, cancellationToken);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,8 @@ namespace Microsoft.AspNetCore.Http.Connections
|
|||
public static class NegotiateProtocol
|
||||
{
|
||||
private const string ConnectionIdPropertyName = "connectionId";
|
||||
private const string UrlPropertyName = "url";
|
||||
private const string AccessTokenPropertyName = "accessToken";
|
||||
private const string AvailableTransportsPropertyName = "availableTransports";
|
||||
private const string TransportPropertyName = "transport";
|
||||
private const string TransferFormatsPropertyName = "transferFormats";
|
||||
|
|
@ -25,8 +27,25 @@ namespace Microsoft.AspNetCore.Http.Connections
|
|||
using (var jsonWriter = JsonUtils.CreateJsonTextWriter(textWriter))
|
||||
{
|
||||
jsonWriter.WriteStartObject();
|
||||
jsonWriter.WritePropertyName(ConnectionIdPropertyName);
|
||||
jsonWriter.WriteValue(response.ConnectionId);
|
||||
|
||||
if (!string.IsNullOrEmpty(response.Url))
|
||||
{
|
||||
jsonWriter.WritePropertyName(UrlPropertyName);
|
||||
jsonWriter.WriteValue(response.Url);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(response.AccessToken))
|
||||
{
|
||||
jsonWriter.WritePropertyName(AccessTokenPropertyName);
|
||||
jsonWriter.WriteValue(response.AccessToken);
|
||||
}
|
||||
|
||||
if (!string.IsNullOrEmpty(response.ConnectionId))
|
||||
{
|
||||
jsonWriter.WritePropertyName(ConnectionIdPropertyName);
|
||||
jsonWriter.WriteValue(response.ConnectionId);
|
||||
}
|
||||
|
||||
jsonWriter.WritePropertyName(AvailableTransportsPropertyName);
|
||||
jsonWriter.WriteStartArray();
|
||||
|
||||
|
|
@ -69,6 +88,8 @@ namespace Microsoft.AspNetCore.Http.Connections
|
|||
JsonUtils.EnsureObjectStart(reader);
|
||||
|
||||
string connectionId = null;
|
||||
string url = null;
|
||||
string accessToken = null;
|
||||
List<AvailableTransport> availableTransports = null;
|
||||
|
||||
var completed = false;
|
||||
|
|
@ -81,6 +102,12 @@ namespace Microsoft.AspNetCore.Http.Connections
|
|||
|
||||
switch (memberName)
|
||||
{
|
||||
case UrlPropertyName:
|
||||
url = JsonUtils.ReadAsString(reader, UrlPropertyName);
|
||||
break;
|
||||
case AccessTokenPropertyName:
|
||||
accessToken = JsonUtils.ReadAsString(reader, AccessTokenPropertyName);
|
||||
break;
|
||||
case ConnectionIdPropertyName:
|
||||
connectionId = JsonUtils.ReadAsString(reader, ConnectionIdPropertyName);
|
||||
break;
|
||||
|
|
@ -114,19 +141,25 @@ namespace Microsoft.AspNetCore.Http.Connections
|
|||
}
|
||||
}
|
||||
|
||||
if (connectionId == null)
|
||||
if (url == null)
|
||||
{
|
||||
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
|
||||
}
|
||||
// if url isn't specified, connectionId and available transports are required
|
||||
if (connectionId == null)
|
||||
{
|
||||
throw new InvalidDataException($"Missing required property '{ConnectionIdPropertyName}'.");
|
||||
}
|
||||
|
||||
if (availableTransports == null)
|
||||
{
|
||||
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
|
||||
if (availableTransports == null)
|
||||
{
|
||||
throw new InvalidDataException($"Missing required property '{AvailableTransportsPropertyName}'.");
|
||||
}
|
||||
}
|
||||
|
||||
return new NegotiationResponse
|
||||
{
|
||||
ConnectionId = connectionId,
|
||||
Url = url,
|
||||
AccessToken = accessToken,
|
||||
AvailableTransports = availableTransports
|
||||
};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ namespace Microsoft.AspNetCore.Http.Connections
|
|||
{
|
||||
public class NegotiationResponse
|
||||
{
|
||||
public string Url { get; set; }
|
||||
public string AccessToken { get; set; }
|
||||
public string ConnectionId { get; set; }
|
||||
public IList<AvailableTransport> AvailableTransports { get; set; }
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,21 +9,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
|
|||
public class NegotiateProtocolTests
|
||||
{
|
||||
[Theory]
|
||||
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0])]
|
||||
[InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0])]
|
||||
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new [] { "test"})]
|
||||
public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports)
|
||||
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[]}", "123", new string[0], null, null)]
|
||||
[InlineData("{\"connectionId\":\"\",\"availableTransports\":[]}", "", new string[0], null, null)]
|
||||
[InlineData("{\"url\": \"http://foo.com/chat\"}", null, null, "http://foo.com/chat", null)]
|
||||
[InlineData("{\"url\": \"http://foo.com/chat\", \"accessToken\": \"token\"}", null, null, "http://foo.com/chat", "token")]
|
||||
[InlineData("{\"connectionId\":\"123\",\"availableTransports\":[{\"transport\":\"test\",\"transferFormats\":[]}]}", "123", new[] { "test" }, null, null)]
|
||||
public void ParsingNegotiateResponseMessageSuccessForValid(string json, string connectionId, string[] availableTransports, string url, string accessToken)
|
||||
{
|
||||
var responseData = Encoding.UTF8.GetBytes(json);
|
||||
var ms = new MemoryStream(responseData);
|
||||
var response = NegotiateProtocol.ParseResponse(ms);
|
||||
|
||||
Assert.Equal(connectionId, response.ConnectionId);
|
||||
Assert.Equal(availableTransports.Length, response.AvailableTransports.Count);
|
||||
Assert.Equal(availableTransports?.Length, response.AvailableTransports?.Count);
|
||||
Assert.Equal(url, response.Url);
|
||||
Assert.Equal(accessToken, response.AccessToken);
|
||||
|
||||
var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList();
|
||||
if (response.AvailableTransports != null)
|
||||
{
|
||||
var responseTransports = response.AvailableTransports.Select(t => t.Transport).ToList();
|
||||
|
||||
Assert.Equal(availableTransports, responseTransports);
|
||||
Assert.Equal(availableTransports, responseTransports);
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
using System;
|
||||
using System.IO;
|
||||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Connections;
|
||||
using Microsoft.AspNetCore.Http.Connections;
|
||||
|
|
@ -70,7 +71,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
});
|
||||
|
||||
using (var noErrorScope = new VerifyNoErrorsScope())
|
||||
{
|
||||
{
|
||||
await WithConnectionAsync(
|
||||
CreateConnection(testHttpHandler, url: requestedUrl, loggerFactory: noErrorScope.LoggerFactory),
|
||||
async (connection) =>
|
||||
|
|
@ -82,6 +83,173 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
Assert.Equal(expectedNegotiate, await negotiateUrlTcs.Task.OrTimeout());
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task NegotiateThatReturnsUrlGetFollowed()
|
||||
{
|
||||
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
|
||||
var firstNegotiate = true;
|
||||
testHttpHandler.OnNegotiate((request, cancellationToken) =>
|
||||
{
|
||||
if (firstNegotiate)
|
||||
{
|
||||
firstNegotiate = false;
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
|
||||
JsonConvert.SerializeObject(new
|
||||
{
|
||||
url = "https://another.domain.url/chat"
|
||||
}));
|
||||
}
|
||||
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
|
||||
JsonConvert.SerializeObject(new
|
||||
{
|
||||
connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00",
|
||||
availableTransports = new object[]
|
||||
{
|
||||
new
|
||||
{
|
||||
transport = "LongPolling",
|
||||
transferFormats = new[] { "Text" }
|
||||
},
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
testHttpHandler.OnLongPoll((token) =>
|
||||
{
|
||||
var tcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)));
|
||||
|
||||
return tcs.Task;
|
||||
});
|
||||
|
||||
testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted));
|
||||
|
||||
using (var noErrorScope = new VerifyNoErrorsScope())
|
||||
{
|
||||
await WithConnectionAsync(
|
||||
CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory),
|
||||
async (connection) =>
|
||||
{
|
||||
await connection.StartAsync(TransferFormat.Text).OrTimeout();
|
||||
});
|
||||
}
|
||||
|
||||
Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString());
|
||||
Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString());
|
||||
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString());
|
||||
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString());
|
||||
Assert.Equal(5, testHttpHandler.ReceivedRequests.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task NegotiateThatReturnsRedirectUrlForeverThrowsAfter100Tries()
|
||||
{
|
||||
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
|
||||
testHttpHandler.OnNegotiate((request, cancellationToken) =>
|
||||
{
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
|
||||
JsonConvert.SerializeObject(new
|
||||
{
|
||||
url = "https://another.domain.url/chat"
|
||||
}));
|
||||
});
|
||||
|
||||
using (var noErrorScope = new VerifyNoErrorsScope())
|
||||
{
|
||||
await WithConnectionAsync(
|
||||
CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory),
|
||||
async (connection) =>
|
||||
{
|
||||
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => connection.StartAsync(TransferFormat.Text).OrTimeout());
|
||||
Assert.Equal("Negotiate redirection limit exceeded.", exception.Message);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task NegotiateThatReturnsUrlGetFollowedWithAccessToken()
|
||||
{
|
||||
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
|
||||
var firstNegotiate = true;
|
||||
testHttpHandler.OnNegotiate((request, cancellationToken) =>
|
||||
{
|
||||
if (firstNegotiate)
|
||||
{
|
||||
firstNegotiate = false;
|
||||
|
||||
// The first negotiate requires an access token
|
||||
if (request.Headers.Authorization?.Parameter != "firstSecret")
|
||||
{
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized);
|
||||
}
|
||||
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
|
||||
JsonConvert.SerializeObject(new
|
||||
{
|
||||
url = "https://another.domain.url/chat",
|
||||
accessToken = "secondSecret"
|
||||
}));
|
||||
}
|
||||
|
||||
// All other requests require an access token
|
||||
if (request.Headers.Authorization?.Parameter != "secondSecret")
|
||||
{
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized);
|
||||
}
|
||||
|
||||
return ResponseUtils.CreateResponse(HttpStatusCode.OK,
|
||||
JsonConvert.SerializeObject(new
|
||||
{
|
||||
connectionId = "0rge0d00-0040-0030-0r00-000q00r00e00",
|
||||
availableTransports = new object[]
|
||||
{
|
||||
new
|
||||
{
|
||||
transport = "LongPolling",
|
||||
transferFormats = new[] { "Text" }
|
||||
},
|
||||
}
|
||||
}));
|
||||
});
|
||||
|
||||
testHttpHandler.OnLongPoll((request, token) =>
|
||||
{
|
||||
// All other requests require an access token
|
||||
if (request.Headers.Authorization?.Parameter != "secondSecret")
|
||||
{
|
||||
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.Unauthorized));
|
||||
}
|
||||
var tcs = new TaskCompletionSource<HttpResponseMessage>(TaskCreationOptions.RunContinuationsAsynchronously);
|
||||
|
||||
token.Register(() => tcs.TrySetResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)));
|
||||
|
||||
return tcs.Task;
|
||||
});
|
||||
|
||||
testHttpHandler.OnLongPollDelete((token) => ResponseUtils.CreateResponse(HttpStatusCode.Accepted));
|
||||
|
||||
Task<string> AccessTokenProvider() => Task.FromResult<string>("firstSecret");
|
||||
|
||||
using (var noErrorScope = new VerifyNoErrorsScope())
|
||||
{
|
||||
await WithConnectionAsync(
|
||||
CreateConnection(testHttpHandler, loggerFactory: noErrorScope.LoggerFactory, accessTokenProvider: AccessTokenProvider),
|
||||
async (connection) =>
|
||||
{
|
||||
await connection.StartAsync(TransferFormat.Text).OrTimeout();
|
||||
});
|
||||
}
|
||||
|
||||
Assert.Equal("http://fakeuri.org/negotiate", testHttpHandler.ReceivedRequests[0].RequestUri.ToString());
|
||||
Assert.Equal("https://another.domain.url/chat/negotiate", testHttpHandler.ReceivedRequests[1].RequestUri.ToString());
|
||||
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[2].RequestUri.ToString());
|
||||
Assert.Equal("https://another.domain.url/chat?id=0rge0d00-0040-0030-0r00-000q00r00e00", testHttpHandler.ReceivedRequests[3].RequestUri.ToString());
|
||||
// Delete request
|
||||
Assert.Equal(5, testHttpHandler.ReceivedRequests.Count);
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task StartSkipsOverTransportsThatTheClientDoesNotUnderstand()
|
||||
{
|
||||
|
|
|
|||
|
|
@ -185,12 +185,22 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
|
|||
public void OnLongPoll(Func<CancellationToken, HttpResponseMessage> handler) => OnLongPoll(cancellationToken => Task.FromResult(handler(cancellationToken)));
|
||||
|
||||
public void OnLongPoll(Func<CancellationToken, Task<HttpResponseMessage>> handler)
|
||||
{
|
||||
OnLongPoll((request, token) => handler(token));
|
||||
}
|
||||
|
||||
public void OnLongPoll(Func<HttpRequestMessage, CancellationToken, HttpResponseMessage> handler)
|
||||
{
|
||||
OnLongPoll((request, token) => Task.FromResult(handler(request, token)));
|
||||
}
|
||||
|
||||
public void OnLongPoll(Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler)
|
||||
{
|
||||
OnRequest((request, next, cancellationToken) =>
|
||||
{
|
||||
if (ResponseUtils.IsLongPollRequest(request))
|
||||
{
|
||||
return handler(cancellationToken);
|
||||
return handler(request, cancellationToken);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
|||
|
|
@ -167,7 +167,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
[ConditionalFact]
|
||||
[WebSocketsSupportedCondition]
|
||||
public async Task HttpRequestsNotSentWhenWebSocketsTransportRequested()
|
||||
public async Task HttpRequestsNotSentWhenWebSocketsTransportRequestedAndSkipNegotiationSet()
|
||||
{
|
||||
using (StartVerifableLog(out var loggerFactory))
|
||||
{
|
||||
|
|
@ -184,6 +184,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
{
|
||||
Url = new Uri(url),
|
||||
Transports = HttpTransportType.WebSockets,
|
||||
SkipNegotiation = true,
|
||||
HttpMessageHandlerFactory = (httpMessageHandler) => mockHttpHandler.Object
|
||||
};
|
||||
|
||||
|
|
@ -213,6 +214,49 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[InlineData(HttpTransportType.LongPolling)]
|
||||
[InlineData(HttpTransportType.ServerSentEvents)]
|
||||
public async Task HttpConnectionThrowsIfSkipNegotiationSetAndTransportIsNotWebSockets(HttpTransportType transportType)
|
||||
{
|
||||
using (StartVerifableLog(out var loggerFactory))
|
||||
{
|
||||
var logger = loggerFactory.CreateLogger<EndToEndTests>();
|
||||
var url = ServerFixture.Url + "/echo";
|
||||
|
||||
var mockHttpHandler = new Mock<HttpMessageHandler>();
|
||||
mockHttpHandler.Protected()
|
||||
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
|
||||
.Returns<HttpRequestMessage, CancellationToken>(
|
||||
(request, cancellationToken) => Task.FromException<HttpResponseMessage>(new InvalidOperationException("HTTP requests should not be sent.")));
|
||||
|
||||
var httpOptions = new HttpConnectionOptions
|
||||
{
|
||||
Url = new Uri(url),
|
||||
Transports = transportType,
|
||||
SkipNegotiation = true,
|
||||
HttpMessageHandlerFactory = (httpMessageHandler) => mockHttpHandler.Object
|
||||
};
|
||||
|
||||
var connection = new HttpConnection(httpOptions, loggerFactory);
|
||||
|
||||
try
|
||||
{
|
||||
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => connection.StartAsync(TransferFormat.Binary).OrTimeout());
|
||||
Assert.Equal("Negotiation can only be skipped when using the WebSocket transport directly.", exception.Message);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
logger.LogInformation(ex, "Test threw exception");
|
||||
throw;
|
||||
}
|
||||
finally
|
||||
{
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(TransportTypesAndTransferFormats))]
|
||||
public async Task ConnectionCanSendAndReceiveMessages(HttpTransportType transportType, TransferFormat requestedTransferFormat)
|
||||
|
|
@ -326,7 +370,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
bool ExpectedErrors(WriteContext writeContext)
|
||||
{
|
||||
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
|
||||
writeContext.EventId.Name == "ErrorStartingTransport";
|
||||
writeContext.EventId.Name == "ErrorWithNegotiation";
|
||||
}
|
||||
|
||||
using (StartVerifableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors))
|
||||
|
|
@ -336,24 +380,37 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
var url = ServerFixture.Url + "/auth";
|
||||
var connection = new HttpConnection(new Uri(url), HttpTransportType.WebSockets, loggerFactory);
|
||||
|
||||
try
|
||||
var exception = await Assert.ThrowsAsync<HttpRequestException>(() => connection.StartAsync(TransferFormat.Binary).OrTimeout());
|
||||
|
||||
Assert.Contains("401", exception.Message);
|
||||
}
|
||||
}
|
||||
|
||||
[ConditionalFact]
|
||||
[WebSocketsSupportedCondition]
|
||||
public async Task UnauthorizedDirectWebSocketsConnectionDoesNotConnect()
|
||||
{
|
||||
bool ExpectedErrors(WriteContext writeContext)
|
||||
{
|
||||
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
|
||||
writeContext.EventId.Name == "ErrorStartingTransport";
|
||||
}
|
||||
|
||||
using (StartVerifableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors))
|
||||
{
|
||||
var logger = loggerFactory.CreateLogger<EndToEndTests>();
|
||||
|
||||
var url = ServerFixture.Url + "/auth";
|
||||
var options = new HttpConnectionOptions
|
||||
{
|
||||
logger.LogInformation("Starting connection to {url}", url);
|
||||
await connection.StartAsync(TransferFormat.Binary).OrTimeout();
|
||||
Assert.True(false);
|
||||
}
|
||||
catch (WebSocketException) { }
|
||||
catch (Exception ex)
|
||||
{
|
||||
logger.LogInformation(ex, "Test threw exception");
|
||||
throw;
|
||||
}
|
||||
finally
|
||||
{
|
||||
logger.LogInformation("Disposing Connection");
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
logger.LogInformation("Disposed Connection");
|
||||
}
|
||||
Url = new Uri(url),
|
||||
Transports = HttpTransportType.WebSockets,
|
||||
SkipNegotiation = true
|
||||
};
|
||||
|
||||
var connection = new HttpConnection(options, loggerFactory);
|
||||
|
||||
await Assert.ThrowsAsync<WebSocketException>(() => connection.StartAsync(TransferFormat.Binary).OrTimeout());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue