Flush first long poll immediately (#2032)

There was a race condition between the first poll and any other http request that was sent. 
In particular, if you called StartAsync then StopAsync it was possible for the delete to happen before the poll started leading to 400 errors. This change fixes that by making the very first poll
return immediately so that the client can use that to determine if there was an error connecting.
This commit is contained in:
David Fowler 2018-04-17 00:49:26 -07:00 committed by GitHub
parent d35bcea0a5
commit 05d6bbb782
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 398 additions and 114 deletions

View File

@ -1,6 +1,6 @@
{
"name": "@aspnet/signalr-protocol-msgpack",
"version": "1.0.0-preview3-t000",
"version": "1.0.0-rc1-t000",
"lockfileVersion": 1,
"requires": true,
"dependencies": {

View File

@ -1,6 +1,6 @@
{
"name": "@aspnet/signalr",
"version": "1.0.0-preview3-t000",
"version": "1.0.0-rc1-t000",
"lockfileVersion": 1,
"requires": true,
"dependencies": {

View File

@ -288,13 +288,73 @@ describe("HttpConnection", () => {
}
});
it("sets inherentKeepAlive feature when using LongPolling", async (done) => {
it("authorization header removed when token factory returns null and using LongPolling", async (done) => {
const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] };
var httpClientGetCount = 0;
var accessTokenFactoryCount = 0;
const options: IHttpConnectionOptions = {
...commonOptions,
httpClient: new TestHttpClient()
.on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })),
.on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] }))
.on("GET", (r) => {
httpClientGetCount++;
const authorizationValue = r.headers["Authorization"];
if (httpClientGetCount == 1) {
if (authorizationValue) {
fail("First long poll request should have a authorization header.");
}
// First long polling request must succeed so start completes
return "";
} else {
// Check second long polling request has its header removed
if (authorizationValue) {
fail("Second long poll request should have no authorization header.");
}
throw new Error("fail");
}
}),
accessTokenFactory: () => {
accessTokenFactoryCount++;
if (accessTokenFactoryCount == 1) {
return "A token value";
} else {
// Return a null value after the first call to test the header being removed
return null;
}
},
} as IHttpConnectionOptions;
const connection = new HttpConnection("http://tempuri.org", options);
try {
await connection.start(TransferFormat.Text);
expect(httpClientGetCount).toBeGreaterThanOrEqual(2);
expect(accessTokenFactoryCount).toBeGreaterThanOrEqual(2);
done();
} catch (e) {
fail(e);
done();
}
});
it("sets inherentKeepAlive feature when using LongPolling", async (done) => {
const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] };
var httpClientGetCount = 0;
const options: IHttpConnectionOptions = {
...commonOptions,
httpClient: new TestHttpClient()
.on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] }))
.on("GET", (r) => {
httpClientGetCount++;
if (httpClientGetCount == 1) {
// First long polling request must succeed so start completes
return "";
} else {
throw new Error("fail");
}
}),
} as IHttpConnectionOptions;
const connection = new HttpConnection("http://tempuri.org", options);

View File

@ -31,7 +31,7 @@ export class LongPollingTransport implements ITransport {
this.logMessageContent = logMessageContent;
}
public connect(url: string, transferFormat: TransferFormat): Promise<void> {
public async connect(url: string, transferFormat: TransferFormat): Promise<void> {
Arg.isRequired(url, "url");
Arg.isRequired(transferFormat, "transferFormat");
Arg.isIn(transferFormat, TransferFormat, "transferFormat");
@ -45,13 +45,6 @@ export class LongPollingTransport implements ITransport {
throw new Error("Binary protocols over XmlHttpRequest not implementing advanced features are not supported.");
}
this.poll(this.url, transferFormat);
return Promise.resolve();
}
private async poll(url: string, transferFormat: TransferFormat): Promise<void> {
this.running = true;
const pollOptions: HttpRequest = {
abortSignal: this.pollAbort.signal,
headers: {},
@ -62,15 +55,49 @@ export class LongPollingTransport implements ITransport {
pollOptions.responseType = "arraybuffer";
}
const token = await this.accessTokenFactory();
this.updateHeaderToken(pollOptions, token);
let closeError: Error;
// Make initial long polling request
// Server uses first long polling request to finish initializing connection and it returns without data
const pollUrl = `${url}&_=${Date.now()}`;
this.logger.log(LogLevel.Trace, `(LongPolling transport) polling: ${pollUrl}`);
const response = await this.httpClient.get(pollUrl, pollOptions);
if (response.statusCode !== 200) {
this.logger.log(LogLevel.Error, `(LongPolling transport) Unexpected response code: ${response.statusCode}`);
// Mark running as false so that the poll immediately ends and runs the close logic
closeError = new HttpError(response.statusText, response.statusCode);
this.running = false;
} else {
this.running = true;
}
this.poll(this.url, pollOptions, closeError);
return Promise.resolve();
}
private updateHeaderToken(request: HttpRequest, token: string) {
if (token) {
// tslint:disable-next-line:no-string-literal
request.headers["Authorization"] = `Bearer ${token}`;
return;
}
// tslint:disable-next-line:no-string-literal
if (request.headers["Authorization"]) {
// tslint:disable-next-line:no-string-literal
delete request.headers["Authorization"];
}
}
private async poll(url: string, pollOptions: HttpRequest, closeError: Error): Promise<void> {
try {
while (this.running) {
// We have to get the access token on each poll, in case it changes
const token = await this.accessTokenFactory();
if (token) {
// tslint:disable-next-line:no-string-literal
pollOptions.headers["Authorization"] = `Bearer ${token}`;
}
this.updateHeaderToken(pollOptions, token);
try {
const pollUrl = `${url}&_=${Date.now()}`;
@ -142,14 +169,11 @@ export class LongPollingTransport implements ITransport {
this.running = false;
this.logger.log(LogLevel.Trace, `(LongPolling transport) sending DELETE request to ${this.url}.`);
const deleteOptions: HttpRequest = {};
const deleteOptions: HttpRequest = {
headers: {},
};
const token = await this.accessTokenFactory();
if (token) {
// tslint:disable-next-line:no-string-literal
deleteOptions.headers = {
["Authorization"]: `Bearer ${token}`,
};
}
this.updateHeaderToken(deleteOptions, token);
const response = await this.httpClient.delete(this.url, deleteOptions);
this.logger.log(LogLevel.Trace, "(LongPolling transport) DELETE request accepted.");

View File

@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<LongPollingTransport>();
}
public Task StartAsync(Uri url, TransferFormat transferFormat)
public async Task StartAsync(Uri url, TransferFormat transferFormat)
{
if (transferFormat != TransferFormat.Binary && transferFormat != TransferFormat.Text)
{
@ -49,6 +49,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
Log.StartTransport(_logger, transferFormat);
// Make initial long polling request
// Server uses first long polling request to finish initializing connection and it returns without data
var request = new HttpRequestMessage(HttpMethod.Get, url);
using (var response = await _httpClient.SendAsync(request))
{
response.EnsureSuccessStatusCode();
}
// Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa)
var options = ClientPipeOptions.DefaultOptions;
var pair = DuplexPipe.CreateConnectionPair(options, options);
@ -57,8 +65,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
_application = pair.Application;
Running = ProcessAsync(url);
return Task.CompletedTask;
}
private async Task ProcessAsync(Uri url)
@ -105,6 +111,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
Log.TransportStopping(_logger);
if (_application == null)
{
// We never started
return;
}
_application.Input.CancelPendingRead();
try

View File

@ -207,6 +207,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
Log.TransportStopping(_logger);
if (_application == null)
{
// We never started
return;
}
_transport.Output.Complete();
_transport.Input.Complete();

View File

@ -373,6 +373,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal
{
Log.TransportStopping(_logger);
if (_application == null)
{
// We never started
return;
}
_transport.Output.Complete();
_transport.Input.Complete();

View File

@ -210,7 +210,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
using (connection.Cancellation)
{
// Cancel the previous request
connection.Cancellation.Cancel();
connection.Cancellation?.Cancel();
// Wait for the previous request to drain
await connection.TransportTask;
@ -228,29 +228,38 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
Log.EstablishedConnection(_logger);
connection.ApplicationTask = ExecuteApplication(connectionDelegate, connection);
context.Response.ContentType = "application/octet-stream";
// This request has no content
context.Response.ContentLength = 0;
// On the first poll, we flush the response immediately to mark the poll as "initialized" so future
// requests can be made safely
connection.TransportTask = context.Response.Body.FlushAsync();
}
else
{
Log.ResumingConnection(_logger);
// REVIEW: Performance of this isn't great as this does a bunch of per request allocations
connection.Cancellation = new CancellationTokenSource();
var timeoutSource = new CancellationTokenSource();
var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted, timeoutSource.Token);
// Dispose these tokens when the request is over
context.Response.RegisterForDispose(timeoutSource);
context.Response.RegisterForDispose(tokenSource);
var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory);
// Start the transport
connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token);
// Start the timeout after we return from creating the transport task
timeoutSource.CancelAfter(options.LongPolling.PollTimeout);
}
// REVIEW: Performance of this isn't great as this does a bunch of per request allocations
connection.Cancellation = new CancellationTokenSource();
var timeoutSource = new CancellationTokenSource();
var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted, timeoutSource.Token);
// Dispose these tokens when the request is over
context.Response.RegisterForDispose(timeoutSource);
context.Response.RegisterForDispose(tokenSource);
var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory);
// Start the transport
connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token);
// Start the timeout after we return from creating the transport task
timeoutSource.CancelAfter(options.LongPolling.PollTimeout);
}
finally
{
@ -302,7 +311,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
connection.Status = HttpConnectionStatus.Inactive;
// Dispose the cancellation token
connection.Cancellation.Dispose();
connection.Cancellation?.Dispose();
connection.Cancellation = null;
}
@ -474,7 +483,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
context.Response.ContentType = "text/plain";
return;
}
await context.Request.Body.CopyToAsync(connection.ApplicationStream, bufferSize);
Log.ReceivedBytes(_logger, connection.ApplicationStream.Length);

View File

@ -636,6 +636,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
// Start a poll
var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
Assert.True(task.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
// Send to the application
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -745,7 +749,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
}
[Theory]
[InlineData(HttpTransportType.LongPolling, 204)]
[InlineData(HttpTransportType.LongPolling, 200)]
[InlineData(HttpTransportType.WebSockets, 404)]
[InlineData(HttpTransportType.ServerSentEvents, 404)]
public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(HttpTransportType transportType, int status)
@ -869,6 +873,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
var app = builder.Build();
// First poll will 200
await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
@ -998,6 +1006,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
var request1 = dispatcher.ExecuteAsync(context1, options, app);
Assert.True(request1.IsCompleted);
request1 = dispatcher.ExecuteAsync(context1, options, app);
var request2 = dispatcher.ExecuteAsync(context2, options, app);
await request1;
@ -1132,7 +1143,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
builder.UseConnectionHandler<BlockingConnectionHandler>();
var app = builder.Build();
var options = new HttpConnectionDispatcherOptions();
// Initial poll
var task = dispatcher.ExecuteAsync(context, options, app);
Assert.True(task.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
// Real long running poll
task = dispatcher.ExecuteAsync(context, options, app);
var buffer = Encoding.UTF8.GetBytes("Hello World");
@ -1166,7 +1184,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var options = new HttpConnectionDispatcherOptions();
var context1 = MakeRequest("/foo", connection);
// This is the initial poll to make sure things are setup
var task1 = dispatcher.ExecuteAsync(context1, options, app);
Assert.True(task1.IsCompleted);
task1 = dispatcher.ExecuteAsync(context1, options, app);
var context2 = MakeRequest("/foo", connection);
var task2 = dispatcher.ExecuteAsync(context2, options, app);
@ -1363,10 +1384,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
Assert.Equal("Hello, World", GetContentAsString(context.Response.Body));
}
@ -1444,7 +1468,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW")
}));
// First poll
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(connectionHandlerTask.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
@ -1502,7 +1531,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
// "authorize" user
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") }));
// Initial poll
var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(connectionHandlerTask.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app);
await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout();
await connectionHandlerTask.OrTimeout();
@ -1660,6 +1694,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
var options = new HttpConnectionDispatcherOptions();
var pollTask = dispatcher.ExecuteAsync(context, options, app);
Assert.True(pollTask.IsCompleted);
// Now send the second poll
pollTask = dispatcher.ExecuteAsync(context, options, app);
// Issue the delete request and make sure the poll completes
var deleteContext = new DefaultHttpContext();

View File

@ -13,14 +13,20 @@ using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Connections;
using Microsoft.AspNetCore.Http.Connections.Client;
using Microsoft.AspNetCore.Http.Connections.Client.Internal;
using Microsoft.AspNetCore.SignalR.Tests;
using Xunit;
using Xunit.Abstractions;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
public partial class HttpConnectionTests
{
public class Transport
public class Transport : VerifiableLoggedTest
{
public Transport(ITestOutputHelper output) : base(output)
{
}
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
@ -30,11 +36,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var requestsExecuted = false;
var callCount = 0;
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
testHttpHandler.OnNegotiate((_, cancellationToken) =>
{
return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent());
@ -52,6 +53,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
return await next();
});
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
Task<string> AccessTokenProvider()
{
callCount++;
@ -75,28 +81,25 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
[InlineData(HttpTransportType.ServerSentEvents, false)]
public async Task HttpConnectionSetsInherentKeepAliveFeature(HttpTransportType transportType, bool expectedValue)
{
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
testHttpHandler.OnRequest((request, next, token) =>
using (StartVerifableLog(out var loggerFactory, testName: $"HttpConnectionSetsInherentKeepAliveFeature_{transportType}_{expectedValue}"))
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
testHttpHandler.OnNegotiate((_, cancellationToken) =>
{
return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent());
});
testHttpHandler.OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()));
await WithConnectionAsync(
CreateConnection(testHttpHandler, transportType: transportType),
async (connection) =>
{
await connection.StartAsync(TransferFormat.Text).OrTimeout();
testHttpHandler.OnRequest((request, next, token) => Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)));
var feature = connection.Features.Get<IConnectionInherentKeepAliveFeature>();
Assert.NotNull(feature);
Assert.Equal(expectedValue, feature.HasInherentKeepAlive);
});
await WithConnectionAsync(
CreateConnection(testHttpHandler, transportType: transportType, loggerFactory: loggerFactory),
async (connection) =>
{
await connection.StartAsync(TransferFormat.Text).OrTimeout();
var feature = connection.Features.Get<IConnectionInherentKeepAliveFeature>();
Assert.NotNull(feature);
Assert.Equal(expectedValue, feature.HasInherentKeepAlive);
});
}
}
[Theory]
@ -107,10 +110,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
var requestsExecuted = false;
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
testHttpHandler.OnNegotiate((_, cancellationToken) =>
{
@ -135,6 +134,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
return await next();
});
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
await WithConnectionAsync(
CreateConnection(testHttpHandler, transportType: transportType),
async (connection) =>
@ -154,11 +158,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false);
var requestsExecuted = false;
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
testHttpHandler.OnNegotiate((_, cancellationToken) =>
{
return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent());
@ -175,6 +174,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
return await next();
});
testHttpHandler.OnRequest((request, next, token) =>
{
return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent));
});
await WithConnectionAsync(
CreateConnection(testHttpHandler, transportType: transportType),
async (connection) =>

View File

@ -109,15 +109,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
if (requests == 0)
{
requests++;
return ResponseUtils.CreateResponse(HttpStatusCode.OK, "Hello");
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
else if (requests == 1)
{
requests++;
return ResponseUtils.CreateResponse(HttpStatusCode.OK, "Hello");
}
else if (requests == 2)
{
requests++;
// Time out
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
else if (requests == 2)
else if (requests == 3)
{
requests++;
return ResponseUtils.CreateResponse(HttpStatusCode.OK, "World");
@ -147,7 +152,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
[Fact]
public async Task LongPollingTransportStopsWhenPollRequestFails()
public async Task LongPollingTransportStartAsyncFailsIfFirstRequestFails()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
@ -158,6 +163,39 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError);
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var longPollingTransport = new LongPollingTransport(httpClient);
try
{
var exception = await Assert.ThrowsAsync<HttpRequestException>(() => longPollingTransport.StartAsync(TestUri, TransferFormat.Binary));
Assert.Contains(" 500 ", exception.Message);
}
finally
{
await longPollingTransport.StopAsync();
}
}
}
[Fact]
public async Task LongPollingTransportStopsWhenPollRequestFails()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
var firstPoll = true;
mockHttpHandler.Protected()
.Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
.Returns<HttpRequestMessage, CancellationToken>(async (request, cancellationToken) =>
{
await Task.Yield();
if (firstPoll)
{
firstPoll = false;
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError);
});
using (var httpClient = new HttpClient(mockHttpHandler.Object))
{
var longPollingTransport = new LongPollingTransport(httpClient);
@ -314,7 +352,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var message1Payload = new[] { (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o' };
var firstCall = true;
var requests = 0;
var mockHttpHandler = new Mock<HttpMessageHandler>();
var sentRequests = new List<HttpRequestMessage>();
mockHttpHandler.Protected()
@ -325,9 +363,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
await Task.Yield();
if (firstCall)
if (requests == 0)
{
firstCall = false;
requests++;
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
else if (requests == 1)
{
requests++;
return ResponseUtils.CreateResponse(HttpStatusCode.OK, message1Payload);
}
@ -349,7 +392,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var message = await longPollingTransport.Input.ReadAllAsync();
// Check the provided request
Assert.Equal(2, sentRequests.Count);
Assert.Equal(3, sentRequests.Count);
// Check the messages received
Assert.Equal(message1Payload, message);
@ -366,6 +409,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var sentRequests = new List<byte[]>();
var tcs = new TaskCompletionSource<HttpResponseMessage>();
var firstPoll = true;
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
@ -380,6 +424,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
else if (request.Method == HttpMethod.Get)
{
// First poll completes immediately
if (firstPoll)
{
firstPoll = false;
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken));
// This is the poll task
return await tcs.Task;
@ -426,6 +477,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
var sentRequests = new List<byte[]>();
var pollTcs = new TaskCompletionSource<HttpResponseMessage>();
var deleteTcs = new TaskCompletionSource<object>();
var firstPoll = true;
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
@ -440,6 +492,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
else if (request.Method == HttpMethod.Get)
{
// First poll completes immediately
if (firstPoll)
{
firstPoll = false;
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
cancellationToken.Register(() => pollTcs.TrySetCanceled(cancellationToken));
// This is the poll task
return await pollTcs.Task;
@ -538,7 +597,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
await Task.Yield();
if (Interlocked.Increment(ref numPolls) < 3)
if (numPolls == 0)
{
numPolls++;
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
if (numPolls++ < 3)
{
throw new OperationCanceledException();
}

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using Microsoft.AspNetCore.Connections;
@ -39,6 +40,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public static bool IsLongPollRequest(HttpRequestMessage request)
{
return request.Method == HttpMethod.Get &&
!IsServerSentEventsRequest(request) &&
(request.RequestUri.PathAndQuery.Contains("?id=") || request.RequestUri.PathAndQuery.Contains("&id="));
}
@ -48,6 +50,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
(request.RequestUri.PathAndQuery.Contains("?id=") || request.RequestUri.PathAndQuery.Contains("&id="));
}
public static bool IsServerSentEventsRequest(HttpRequestMessage request)
{
return request.Method == HttpMethod.Get && request.Headers.Accept.Any(h => h.MediaType == "text/event-stream");
}
public static bool IsSocketSendRequest(HttpRequestMessage request)
{
return request.Method == HttpMethod.Post &&

View File

@ -7,10 +7,14 @@ using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
delegate Task<HttpResponseMessage> RequestDelegate(HttpRequestMessage requestMessage, CancellationToken cancellationToken);
public class TestHttpMessageHandler : HttpMessageHandler
{
private List<HttpRequestMessage> _receivedRequests = new List<HttpRequestMessage>();
private Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> _handler;
private RequestDelegate _app;
private List<Func<RequestDelegate, RequestDelegate>> _middleware = new List<Func<RequestDelegate, RequestDelegate>>();
public IReadOnlyList<HttpRequestMessage> ReceivedRequests
{
@ -23,14 +27,29 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
public TestHttpMessageHandler(bool autoNegotiate = true)
public TestHttpMessageHandler(bool autoNegotiate = true, bool handleFirstPoll = true)
{
_handler = BaseHandler;
if (autoNegotiate)
{
OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()));
}
if (handleFirstPoll)
{
var firstPoll = true;
OnRequest(async (request, next, cancellationToken) =>
{
if (ResponseUtils.IsLongPollRequest(request) && firstPoll)
{
firstPoll = false;
return ResponseUtils.CreateResponse(HttpStatusCode.OK);
}
else
{
return await next();
}
});
}
}
protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
@ -40,9 +59,21 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
lock (_receivedRequests)
{
_receivedRequests.Add(request);
if (_app == null)
{
_middleware.Reverse();
RequestDelegate handler = BaseHandler;
foreach (var middleware in _middleware)
{
handler = middleware(handler);
}
_app = handler;
}
}
return await _handler(request, cancellationToken);
return await _app(request, cancellationToken);
}
public static TestHttpMessageHandler CreateDefault()
@ -80,8 +111,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public void OnRequest(Func<HttpRequestMessage, Func<Task<HttpResponseMessage>>, CancellationToken, Task<HttpResponseMessage>> handler)
{
var nextHandler = _handler;
_handler = (request, cancellationToken) => handler(request, () => nextHandler(request, cancellationToken), cancellationToken);
void OnRequestCore(Func<RequestDelegate, RequestDelegate> middleware)
{
_middleware.Add(middleware);
}
OnRequestCore(next =>
{
return (request, cancellationToken) =>
{
return handler(request, () => next(request, cancellationToken), cancellationToken);
};
});
}
public void OnGet(string pathAndQuery, Func<HttpRequestMessage, CancellationToken, Task<HttpResponseMessage>> handler) => OnRequest(HttpMethod.Get, pathAndQuery, handler);

View File

@ -51,32 +51,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests
[Fact]
public async Task CanStartAndStopConnectionUsingDefaultTransport()
{
var url = _serverFixture.Url + "/echo";
// The test should connect to the server using WebSockets transport on Windows 8 and newer.
// On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server.
var connection = new HttpConnection(new Uri(url));
await connection.StartAsync(TransferFormat.Binary).OrTimeout();
await connection.DisposeAsync().OrTimeout();
using (StartVerifableLog(out var loggerFactory))
{
var url = _serverFixture.Url + "/echo";
// The test should connect to the server using WebSockets transport on Windows 8 and newer.
// On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server.
var connection = new HttpConnection(new Uri(url), HttpTransports.All, loggerFactory);
await connection.StartAsync(TransferFormat.Binary).OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task TransportThatFallsbackCreatesNewConnection()
{
var url = _serverFixture.Url + "/echo";
// The test should connect to the server using WebSockets transport on Windows 8 and newer.
// On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server.
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HttpConnection).FullName &&
writeContext.EventId.Name == "ErrorStartingTransport";
}
// The test logic lives in the TestTransportFactory and FakeTransport.
var connection = new HttpConnection(new HttpConnectionOptions { Url = new Uri(url) }, null, new TestTransportFactory());
await connection.StartAsync(TransferFormat.Text).OrTimeout();
await connection.DisposeAsync().OrTimeout();
using (StartVerifableLog(out var loggerFactory, expectedErrorsFilter: ExpectedErrors))
{
var url = _serverFixture.Url + "/echo";
// The test should connect to the server using WebSockets transport on Windows 8 and newer.
// On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server.
// The test logic lives in the TestTransportFactory and FakeTransport.
var connection = new HttpConnection(new HttpConnectionOptions { Url = new Uri(url) }, loggerFactory, new TestTransportFactory());
await connection.StartAsync(TransferFormat.Text).OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Theory(Skip = "https://github.com/aspnet/SignalR/issues/2031")]
[Theory]
[MemberData(nameof(TransportTypes))]
public async Task CanStartAndStopConnectionUsingGivenTransport(HttpTransportType transportType)
{
using (StartVerifableLog(out var loggerFactory, testName: $"CanStartAndStopConnectionUsingGivenTransport_{transportType}"))
using (StartVerifableLog(out var loggerFactory, minLogLevel: LogLevel.Trace, testName: $"CanStartAndStopConnectionUsingGivenTransport_{transportType}"))
{
var url = _serverFixture.Url + "/echo";
var connection = new HttpConnection(new Uri(url), transportType, loggerFactory);
@ -532,7 +544,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
if (_tries < availableTransports)
{
throw new Exception();
return Task.FromException(new Exception());
}
else
{