From 05d6bbb78257e2bb93af85421e542f107934ba6f Mon Sep 17 00:00:00 2001 From: David Fowler Date: Tue, 17 Apr 2018 00:49:26 -0700 Subject: [PATCH] Flush first long poll immediately (#2032) There was a race condition between the first poll and any other http request that was sent. In particular, if you called StartAsync then StopAsync it was possible for the delete to happen before the poll started leading to 400 errors. This change fixes that by making the very first poll return immediately so that the client can use that to determine if there was an error connecting. --- .../package-lock.json | 2 +- clients/ts/signalr/package-lock.json | 2 +- .../ts/signalr/spec/HttpConnection.spec.ts | 64 ++++++++++++++- .../ts/signalr/src/LongPollingTransport.ts | 62 +++++++++----- .../Internal/LongPollingTransport.cs | 18 ++++- .../Internal/ServerSentEventsTransport.cs | 6 ++ .../Internal/WebSocketsTransport.cs | 6 ++ .../Internal/HttpConnectionDispatcher.cs | 51 +++++++----- .../HttpConnectionDispatcherTests.cs | 44 +++++++++- .../HttpConnectionTests.Transport.cs | 70 ++++++++-------- .../LongPollingTransportTests.cs | 81 +++++++++++++++++-- .../ResponseUtils.cs | 7 ++ .../TestHttpMessageHandler.cs | 55 +++++++++++-- .../EndToEndTests.cs | 44 ++++++---- 14 files changed, 398 insertions(+), 114 deletions(-) diff --git a/clients/ts/signalr-protocol-msgpack/package-lock.json b/clients/ts/signalr-protocol-msgpack/package-lock.json index 28f546b661..84c50c6f60 100644 --- a/clients/ts/signalr-protocol-msgpack/package-lock.json +++ b/clients/ts/signalr-protocol-msgpack/package-lock.json @@ -1,6 +1,6 @@ { "name": "@aspnet/signalr-protocol-msgpack", - "version": "1.0.0-preview3-t000", + "version": "1.0.0-rc1-t000", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/clients/ts/signalr/package-lock.json b/clients/ts/signalr/package-lock.json index 0447edc39b..e7e2e130b4 100644 --- a/clients/ts/signalr/package-lock.json +++ b/clients/ts/signalr/package-lock.json @@ -1,6 +1,6 @@ { "name": "@aspnet/signalr", - "version": "1.0.0-preview3-t000", + "version": "1.0.0-rc1-t000", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/clients/ts/signalr/spec/HttpConnection.spec.ts b/clients/ts/signalr/spec/HttpConnection.spec.ts index 15184bd686..09bd65c4f7 100644 --- a/clients/ts/signalr/spec/HttpConnection.spec.ts +++ b/clients/ts/signalr/spec/HttpConnection.spec.ts @@ -288,13 +288,73 @@ describe("HttpConnection", () => { } }); - it("sets inherentKeepAlive feature when using LongPolling", async (done) => { + it("authorization header removed when token factory returns null and using LongPolling", async (done) => { const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] }; + var httpClientGetCount = 0; + var accessTokenFactoryCount = 0; const options: IHttpConnectionOptions = { ...commonOptions, httpClient: new TestHttpClient() - .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })), + .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })) + .on("GET", (r) => { + httpClientGetCount++; + const authorizationValue = r.headers["Authorization"]; + if (httpClientGetCount == 1) { + if (authorizationValue) { + fail("First long poll request should have a authorization header."); + } + // First long polling request must succeed so start completes + return ""; + } else { + // Check second long polling request has its header removed + if (authorizationValue) { + fail("Second long poll request should have no authorization header."); + } + throw new Error("fail"); + } + }), + accessTokenFactory: () => { + accessTokenFactoryCount++; + if (accessTokenFactoryCount == 1) { + return "A token value"; + } else { + // Return a null value after the first call to test the header being removed + return null; + } + }, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + + try { + await connection.start(TransferFormat.Text); + expect(httpClientGetCount).toBeGreaterThanOrEqual(2); + expect(accessTokenFactoryCount).toBeGreaterThanOrEqual(2); + done(); + } catch (e) { + fail(e); + done(); + } + }); + + it("sets inherentKeepAlive feature when using LongPolling", async (done) => { + const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] }; + + var httpClientGetCount = 0; + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })) + .on("GET", (r) => { + httpClientGetCount++; + if (httpClientGetCount == 1) { + // First long polling request must succeed so start completes + return ""; + } else { + throw new Error("fail"); + } + }), } as IHttpConnectionOptions; const connection = new HttpConnection("http://tempuri.org", options); diff --git a/clients/ts/signalr/src/LongPollingTransport.ts b/clients/ts/signalr/src/LongPollingTransport.ts index 5f3e453277..86d3e8376f 100644 --- a/clients/ts/signalr/src/LongPollingTransport.ts +++ b/clients/ts/signalr/src/LongPollingTransport.ts @@ -31,7 +31,7 @@ export class LongPollingTransport implements ITransport { this.logMessageContent = logMessageContent; } - public connect(url: string, transferFormat: TransferFormat): Promise { + public async connect(url: string, transferFormat: TransferFormat): Promise { Arg.isRequired(url, "url"); Arg.isRequired(transferFormat, "transferFormat"); Arg.isIn(transferFormat, TransferFormat, "transferFormat"); @@ -45,13 +45,6 @@ export class LongPollingTransport implements ITransport { throw new Error("Binary protocols over XmlHttpRequest not implementing advanced features are not supported."); } - this.poll(this.url, transferFormat); - return Promise.resolve(); - } - - private async poll(url: string, transferFormat: TransferFormat): Promise { - this.running = true; - const pollOptions: HttpRequest = { abortSignal: this.pollAbort.signal, headers: {}, @@ -62,15 +55,49 @@ export class LongPollingTransport implements ITransport { pollOptions.responseType = "arraybuffer"; } + const token = await this.accessTokenFactory(); + this.updateHeaderToken(pollOptions, token); + let closeError: Error; + + // Make initial long polling request + // Server uses first long polling request to finish initializing connection and it returns without data + const pollUrl = `${url}&_=${Date.now()}`; + this.logger.log(LogLevel.Trace, `(LongPolling transport) polling: ${pollUrl}`); + const response = await this.httpClient.get(pollUrl, pollOptions); + if (response.statusCode !== 200) { + this.logger.log(LogLevel.Error, `(LongPolling transport) Unexpected response code: ${response.statusCode}`); + + // Mark running as false so that the poll immediately ends and runs the close logic + closeError = new HttpError(response.statusText, response.statusCode); + this.running = false; + } else { + this.running = true; + } + + this.poll(this.url, pollOptions, closeError); + return Promise.resolve(); + } + + private updateHeaderToken(request: HttpRequest, token: string) { + if (token) { + // tslint:disable-next-line:no-string-literal + request.headers["Authorization"] = `Bearer ${token}`; + return; + } + // tslint:disable-next-line:no-string-literal + if (request.headers["Authorization"]) { + // tslint:disable-next-line:no-string-literal + delete request.headers["Authorization"]; + } + } + + private async poll(url: string, pollOptions: HttpRequest, closeError: Error): Promise { try { while (this.running) { // We have to get the access token on each poll, in case it changes const token = await this.accessTokenFactory(); - if (token) { - // tslint:disable-next-line:no-string-literal - pollOptions.headers["Authorization"] = `Bearer ${token}`; - } + this.updateHeaderToken(pollOptions, token); try { const pollUrl = `${url}&_=${Date.now()}`; @@ -142,14 +169,11 @@ export class LongPollingTransport implements ITransport { this.running = false; this.logger.log(LogLevel.Trace, `(LongPolling transport) sending DELETE request to ${this.url}.`); - const deleteOptions: HttpRequest = {}; + const deleteOptions: HttpRequest = { + headers: {}, + }; const token = await this.accessTokenFactory(); - if (token) { - // tslint:disable-next-line:no-string-literal - deleteOptions.headers = { - ["Authorization"]: `Bearer ${token}`, - }; - } + this.updateHeaderToken(deleteOptions, token); const response = await this.httpClient.delete(this.url, deleteOptions); this.logger.log(LogLevel.Trace, "(LongPolling transport) DELETE request accepted."); diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs index 4059049a96..860f54b3ee 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, TransferFormat transferFormat) + public async Task StartAsync(Uri url, TransferFormat transferFormat) { if (transferFormat != TransferFormat.Binary && transferFormat != TransferFormat.Text) { @@ -49,6 +49,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal Log.StartTransport(_logger, transferFormat); + // Make initial long polling request + // Server uses first long polling request to finish initializing connection and it returns without data + var request = new HttpRequestMessage(HttpMethod.Get, url); + using (var response = await _httpClient.SendAsync(request)) + { + response.EnsureSuccessStatusCode(); + } + // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) var options = ClientPipeOptions.DefaultOptions; var pair = DuplexPipe.CreateConnectionPair(options, options); @@ -57,8 +65,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _application = pair.Application; Running = ProcessAsync(url); - - return Task.CompletedTask; } private async Task ProcessAsync(Uri url) @@ -105,6 +111,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); + if (_application == null) + { + // We never started + return; + } + _application.Input.CancelPendingRead(); try diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs index 59ca9d7cbf..b92931a4d0 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs @@ -207,6 +207,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); + if (_application == null) + { + // We never started + return; + } + _transport.Output.Complete(); _transport.Input.Complete(); diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs index 4520162c72..c6ebeabce3 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs @@ -373,6 +373,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); + if (_application == null) + { + // We never started + return; + } + _transport.Output.Complete(); _transport.Input.Complete(); diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs index fb1b7a8701..0f4f88551f 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs @@ -210,7 +210,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal using (connection.Cancellation) { // Cancel the previous request - connection.Cancellation.Cancel(); + connection.Cancellation?.Cancel(); // Wait for the previous request to drain await connection.TransportTask; @@ -228,29 +228,38 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Log.EstablishedConnection(_logger); connection.ApplicationTask = ExecuteApplication(connectionDelegate, connection); + + context.Response.ContentType = "application/octet-stream"; + + // This request has no content + context.Response.ContentLength = 0; + + // On the first poll, we flush the response immediately to mark the poll as "initialized" so future + // requests can be made safely + connection.TransportTask = context.Response.Body.FlushAsync(); } else { Log.ResumingConnection(_logger); + + // REVIEW: Performance of this isn't great as this does a bunch of per request allocations + connection.Cancellation = new CancellationTokenSource(); + + var timeoutSource = new CancellationTokenSource(); + var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted, timeoutSource.Token); + + // Dispose these tokens when the request is over + context.Response.RegisterForDispose(timeoutSource); + context.Response.RegisterForDispose(tokenSource); + + var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory); + + // Start the transport + connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); + + // Start the timeout after we return from creating the transport task + timeoutSource.CancelAfter(options.LongPolling.PollTimeout); } - - // REVIEW: Performance of this isn't great as this does a bunch of per request allocations - connection.Cancellation = new CancellationTokenSource(); - - var timeoutSource = new CancellationTokenSource(); - var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted, timeoutSource.Token); - - // Dispose these tokens when the request is over - context.Response.RegisterForDispose(timeoutSource); - context.Response.RegisterForDispose(tokenSource); - - var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory); - - // Start the transport - connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); - - // Start the timeout after we return from creating the transport task - timeoutSource.CancelAfter(options.LongPolling.PollTimeout); } finally { @@ -302,7 +311,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal connection.Status = HttpConnectionStatus.Inactive; // Dispose the cancellation token - connection.Cancellation.Dispose(); + connection.Cancellation?.Dispose(); connection.Cancellation = null; } @@ -474,7 +483,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal context.Response.ContentType = "text/plain"; return; } - + await context.Request.Body.CopyToAsync(connection.ApplicationStream, bufferSize); Log.ReceivedBytes(_logger, connection.ApplicationStream.Length); diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index e0fbf43bd2..9443099b05 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -636,6 +636,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Start a poll var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); + Assert.True(task.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); // Send to the application var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -745,7 +749,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } [Theory] - [InlineData(HttpTransportType.LongPolling, 204)] + [InlineData(HttpTransportType.LongPolling, 200)] [InlineData(HttpTransportType.WebSockets, 404)] [InlineData(HttpTransportType.ServerSentEvents, 404)] public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(HttpTransportType transportType, int status) @@ -869,6 +873,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); + // First poll will 200 + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -998,6 +1006,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var app = builder.Build(); var options = new HttpConnectionDispatcherOptions(); var request1 = dispatcher.ExecuteAsync(context1, options, app); + Assert.True(request1.IsCompleted); + + request1 = dispatcher.ExecuteAsync(context1, options, app); var request2 = dispatcher.ExecuteAsync(context2, options, app); await request1; @@ -1132,7 +1143,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests builder.UseConnectionHandler(); var app = builder.Build(); var options = new HttpConnectionDispatcherOptions(); + + // Initial poll var task = dispatcher.ExecuteAsync(context, options, app); + Assert.True(task.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // Real long running poll + task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -1166,7 +1184,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var options = new HttpConnectionDispatcherOptions(); var context1 = MakeRequest("/foo", connection); + // This is the initial poll to make sure things are setup var task1 = dispatcher.ExecuteAsync(context1, options, app); + Assert.True(task1.IsCompleted); + task1 = dispatcher.ExecuteAsync(context1, options, app); var context2 = MakeRequest("/foo", connection); var task2 = dispatcher.ExecuteAsync(context2, options, app); @@ -1363,10 +1384,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); - await connectionHandlerTask.OrTimeout(); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); + await connectionHandlerTask.OrTimeout(); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); Assert.Equal("Hello, World", GetContentAsString(context.Response.Body)); } @@ -1444,7 +1468,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") })); + // First poll var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(connectionHandlerTask.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); await connectionHandlerTask.OrTimeout(); @@ -1502,7 +1531,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); + // Initial poll var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(connectionHandlerTask.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); await connectionHandlerTask.OrTimeout(); @@ -1660,6 +1694,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var options = new HttpConnectionDispatcherOptions(); var pollTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(pollTask.IsCompleted); + + // Now send the second poll + pollTask = dispatcher.ExecuteAsync(context, options, app); // Issue the delete request and make sure the poll completes var deleteContext = new DefaultHttpContext(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs index 8357e9e693..42839f7472 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs @@ -13,14 +13,20 @@ using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Client; using Microsoft.AspNetCore.Http.Connections.Client.Internal; +using Microsoft.AspNetCore.SignalR.Tests; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HttpConnectionTests { - public class Transport + public class Transport : VerifiableLoggedTest { + public Transport(ITestOutputHelper output) : base(output) + { + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)] @@ -30,11 +36,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var requestsExecuted = false; var callCount = 0; - testHttpHandler.OnRequest((request, next, token) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); - testHttpHandler.OnNegotiate((_, cancellationToken) => { return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); @@ -52,6 +53,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + Task AccessTokenProvider() { callCount++; @@ -75,28 +81,25 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [InlineData(HttpTransportType.ServerSentEvents, false)] public async Task HttpConnectionSetsInherentKeepAliveFeature(HttpTransportType transportType, bool expectedValue) { - var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); - - testHttpHandler.OnRequest((request, next, token) => + using (StartVerifableLog(out var loggerFactory, testName: $"HttpConnectionSetsInherentKeepAliveFeature_{transportType}_{expectedValue}")) { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); - testHttpHandler.OnNegotiate((_, cancellationToken) => - { - return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); - }); + testHttpHandler.OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent())); - await WithConnectionAsync( - CreateConnection(testHttpHandler, transportType: transportType), - async (connection) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); + testHttpHandler.OnRequest((request, next, token) => Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent))); - var feature = connection.Features.Get(); - Assert.NotNull(feature); - Assert.Equal(expectedValue, feature.HasInherentKeepAlive); - }); + await WithConnectionAsync( + CreateConnection(testHttpHandler, transportType: transportType, loggerFactory: loggerFactory), + async (connection) => + { + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + + var feature = connection.Features.Get(); + Assert.NotNull(feature); + Assert.Equal(expectedValue, feature.HasInherentKeepAlive); + }); + } } [Theory] @@ -107,10 +110,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); var requestsExecuted = false; - testHttpHandler.OnRequest((request, next, token) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); testHttpHandler.OnNegotiate((_, cancellationToken) => { @@ -135,6 +134,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + await WithConnectionAsync( CreateConnection(testHttpHandler, transportType: transportType), async (connection) => @@ -154,11 +158,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); var requestsExecuted = false; - testHttpHandler.OnRequest((request, next, token) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); - testHttpHandler.OnNegotiate((_, cancellationToken) => { return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); @@ -175,6 +174,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + await WithConnectionAsync( CreateConnection(testHttpHandler, transportType: transportType), async (connection) => diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 5e203256f0..ee4c60daac 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -109,15 +109,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests if (requests == 0) { requests++; - return ResponseUtils.CreateResponse(HttpStatusCode.OK, "Hello"); + return ResponseUtils.CreateResponse(HttpStatusCode.OK); } else if (requests == 1) + { + requests++; + return ResponseUtils.CreateResponse(HttpStatusCode.OK, "Hello"); + } + else if (requests == 2) { requests++; // Time out return ResponseUtils.CreateResponse(HttpStatusCode.OK); } - else if (requests == 2) + else if (requests == 3) { requests++; return ResponseUtils.CreateResponse(HttpStatusCode.OK, "World"); @@ -147,7 +152,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task LongPollingTransportStopsWhenPollRequestFails() + public async Task LongPollingTransportStartAsyncFailsIfFirstRequestFails() { var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -158,6 +163,39 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError); }); + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + try + { + var exception = await Assert.ThrowsAsync(() => longPollingTransport.StartAsync(TestUri, TransferFormat.Binary)); + Assert.Contains(" 500 ", exception.Message); + } + finally + { + await longPollingTransport.StopAsync(); + } + } + } + + [Fact] + public async Task LongPollingTransportStopsWhenPollRequestFails() + { + var mockHttpHandler = new Mock(); + var firstPoll = true; + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError); + }); + using (var httpClient = new HttpClient(mockHttpHandler.Object)) { var longPollingTransport = new LongPollingTransport(httpClient); @@ -314,7 +352,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var message1Payload = new[] { (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o' }; - var firstCall = true; + var requests = 0; var mockHttpHandler = new Mock(); var sentRequests = new List(); mockHttpHandler.Protected() @@ -325,9 +363,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await Task.Yield(); - if (firstCall) + if (requests == 0) { - firstCall = false; + requests++; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + else if (requests == 1) + { + requests++; return ResponseUtils.CreateResponse(HttpStatusCode.OK, message1Payload); } @@ -349,7 +392,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var message = await longPollingTransport.Input.ReadAllAsync(); // Check the provided request - Assert.Equal(2, sentRequests.Count); + Assert.Equal(3, sentRequests.Count); // Check the messages received Assert.Equal(message1Payload, message); @@ -366,6 +409,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var sentRequests = new List(); var tcs = new TaskCompletionSource(); + var firstPoll = true; var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -380,6 +424,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } else if (request.Method == HttpMethod.Get) { + // First poll completes immediately + if (firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken)); // This is the poll task return await tcs.Task; @@ -426,6 +477,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var sentRequests = new List(); var pollTcs = new TaskCompletionSource(); var deleteTcs = new TaskCompletionSource(); + var firstPoll = true; var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -440,6 +492,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } else if (request.Method == HttpMethod.Get) { + // First poll completes immediately + if (firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + cancellationToken.Register(() => pollTcs.TrySetCanceled(cancellationToken)); // This is the poll task return await pollTcs.Task; @@ -538,7 +597,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await Task.Yield(); - if (Interlocked.Increment(ref numPolls) < 3) + if (numPolls == 0) + { + numPolls++; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + + if (numPolls++ < 3) { throw new OperationCanceledException(); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs index 23ab6d73b5..58d5fbb53c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Net; using System.Net.Http; using Microsoft.AspNetCore.Connections; @@ -39,6 +40,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public static bool IsLongPollRequest(HttpRequestMessage request) { return request.Method == HttpMethod.Get && + !IsServerSentEventsRequest(request) && (request.RequestUri.PathAndQuery.Contains("?id=") || request.RequestUri.PathAndQuery.Contains("&id=")); } @@ -48,6 +50,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests (request.RequestUri.PathAndQuery.Contains("?id=") || request.RequestUri.PathAndQuery.Contains("&id=")); } + public static bool IsServerSentEventsRequest(HttpRequestMessage request) + { + return request.Method == HttpMethod.Get && request.Headers.Accept.Any(h => h.MediaType == "text/event-stream"); + } + public static bool IsSocketSendRequest(HttpRequestMessage request) { return request.Method == HttpMethod.Post && diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs index fc9a7e780b..6bde5d0af2 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs @@ -7,10 +7,14 @@ using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR.Client.Tests { + delegate Task RequestDelegate(HttpRequestMessage requestMessage, CancellationToken cancellationToken); + public class TestHttpMessageHandler : HttpMessageHandler { private List _receivedRequests = new List(); - private Func> _handler; + private RequestDelegate _app; + + private List> _middleware = new List>(); public IReadOnlyList ReceivedRequests { @@ -23,14 +27,29 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } } - public TestHttpMessageHandler(bool autoNegotiate = true) + public TestHttpMessageHandler(bool autoNegotiate = true, bool handleFirstPoll = true) { - _handler = BaseHandler; - if (autoNegotiate) { OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent())); } + + if (handleFirstPoll) + { + var firstPoll = true; + OnRequest(async (request, next, cancellationToken) => + { + if (ResponseUtils.IsLongPollRequest(request) && firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + else + { + return await next(); + } + }); + } } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) @@ -40,9 +59,21 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests lock (_receivedRequests) { _receivedRequests.Add(request); + + if (_app == null) + { + _middleware.Reverse(); + RequestDelegate handler = BaseHandler; + foreach (var middleware in _middleware) + { + handler = middleware(handler); + } + + _app = handler; + } } - return await _handler(request, cancellationToken); + return await _app(request, cancellationToken); } public static TestHttpMessageHandler CreateDefault() @@ -80,8 +111,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public void OnRequest(Func>, CancellationToken, Task> handler) { - var nextHandler = _handler; - _handler = (request, cancellationToken) => handler(request, () => nextHandler(request, cancellationToken), cancellationToken); + void OnRequestCore(Func middleware) + { + _middleware.Add(middleware); + } + + OnRequestCore(next => + { + return (request, cancellationToken) => + { + return handler(request, () => next(request, cancellationToken), cancellationToken); + }; + }); } public void OnGet(string pathAndQuery, Func> handler) => OnRequest(HttpMethod.Get, pathAndQuery, handler); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index b9f3e63c0c..d77ba5e2ff 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -51,32 +51,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests [Fact] public async Task CanStartAndStopConnectionUsingDefaultTransport() { - var url = _serverFixture.Url + "/echo"; - // The test should connect to the server using WebSockets transport on Windows 8 and newer. - // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. - var connection = new HttpConnection(new Uri(url)); - await connection.StartAsync(TransferFormat.Binary).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + using (StartVerifableLog(out var loggerFactory)) + { + var url = _serverFixture.Url + "/echo"; + // The test should connect to the server using WebSockets transport on Windows 8 and newer. + // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. + var connection = new HttpConnection(new Uri(url), HttpTransports.All, loggerFactory); + await connection.StartAsync(TransferFormat.Binary).OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } } [Fact] public async Task TransportThatFallsbackCreatesNewConnection() { - var url = _serverFixture.Url + "/echo"; - // The test should connect to the server using WebSockets transport on Windows 8 and newer. - // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorStartingTransport"; + } - // The test logic lives in the TestTransportFactory and FakeTransport. - var connection = new HttpConnection(new HttpConnectionOptions { Url = new Uri(url) }, null, new TestTransportFactory()); - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + using (StartVerifableLog(out var loggerFactory, expectedErrorsFilter: ExpectedErrors)) + { + var url = _serverFixture.Url + "/echo"; + // The test should connect to the server using WebSockets transport on Windows 8 and newer. + // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. + + // The test logic lives in the TestTransportFactory and FakeTransport. + var connection = new HttpConnection(new HttpConnectionOptions { Url = new Uri(url) }, loggerFactory, new TestTransportFactory()); + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } } - [Theory(Skip = "https://github.com/aspnet/SignalR/issues/2031")] + [Theory] [MemberData(nameof(TransportTypes))] public async Task CanStartAndStopConnectionUsingGivenTransport(HttpTransportType transportType) { - using (StartVerifableLog(out var loggerFactory, testName: $"CanStartAndStopConnectionUsingGivenTransport_{transportType}")) + using (StartVerifableLog(out var loggerFactory, minLogLevel: LogLevel.Trace, testName: $"CanStartAndStopConnectionUsingGivenTransport_{transportType}")) { var url = _serverFixture.Url + "/echo"; var connection = new HttpConnection(new Uri(url), transportType, loggerFactory); @@ -532,7 +544,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests if (_tries < availableTransports) { - throw new Exception(); + return Task.FromException(new Exception()); } else {