diff --git a/clients/ts/signalr-protocol-msgpack/package-lock.json b/clients/ts/signalr-protocol-msgpack/package-lock.json index 28f546b661..84c50c6f60 100644 --- a/clients/ts/signalr-protocol-msgpack/package-lock.json +++ b/clients/ts/signalr-protocol-msgpack/package-lock.json @@ -1,6 +1,6 @@ { "name": "@aspnet/signalr-protocol-msgpack", - "version": "1.0.0-preview3-t000", + "version": "1.0.0-rc1-t000", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/clients/ts/signalr/package-lock.json b/clients/ts/signalr/package-lock.json index 0447edc39b..e7e2e130b4 100644 --- a/clients/ts/signalr/package-lock.json +++ b/clients/ts/signalr/package-lock.json @@ -1,6 +1,6 @@ { "name": "@aspnet/signalr", - "version": "1.0.0-preview3-t000", + "version": "1.0.0-rc1-t000", "lockfileVersion": 1, "requires": true, "dependencies": { diff --git a/clients/ts/signalr/spec/HttpConnection.spec.ts b/clients/ts/signalr/spec/HttpConnection.spec.ts index 15184bd686..09bd65c4f7 100644 --- a/clients/ts/signalr/spec/HttpConnection.spec.ts +++ b/clients/ts/signalr/spec/HttpConnection.spec.ts @@ -288,13 +288,73 @@ describe("HttpConnection", () => { } }); - it("sets inherentKeepAlive feature when using LongPolling", async (done) => { + it("authorization header removed when token factory returns null and using LongPolling", async (done) => { const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] }; + var httpClientGetCount = 0; + var accessTokenFactoryCount = 0; const options: IHttpConnectionOptions = { ...commonOptions, httpClient: new TestHttpClient() - .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })), + .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })) + .on("GET", (r) => { + httpClientGetCount++; + const authorizationValue = r.headers["Authorization"]; + if (httpClientGetCount == 1) { + if (authorizationValue) { + fail("First long poll request should have a authorization header."); + } + // First long polling request must succeed so start completes + return ""; + } else { + // Check second long polling request has its header removed + if (authorizationValue) { + fail("Second long poll request should have no authorization header."); + } + throw new Error("fail"); + } + }), + accessTokenFactory: () => { + accessTokenFactoryCount++; + if (accessTokenFactoryCount == 1) { + return "A token value"; + } else { + // Return a null value after the first call to test the header being removed + return null; + } + }, + } as IHttpConnectionOptions; + + const connection = new HttpConnection("http://tempuri.org", options); + + try { + await connection.start(TransferFormat.Text); + expect(httpClientGetCount).toBeGreaterThanOrEqual(2); + expect(accessTokenFactoryCount).toBeGreaterThanOrEqual(2); + done(); + } catch (e) { + fail(e); + done(); + } + }); + + it("sets inherentKeepAlive feature when using LongPolling", async (done) => { + const availableTransport = { transport: "LongPolling", transferFormats: ["Text"] }; + + var httpClientGetCount = 0; + const options: IHttpConnectionOptions = { + ...commonOptions, + httpClient: new TestHttpClient() + .on("POST", (r) => ({ connectionId: "42", availableTransports: [availableTransport] })) + .on("GET", (r) => { + httpClientGetCount++; + if (httpClientGetCount == 1) { + // First long polling request must succeed so start completes + return ""; + } else { + throw new Error("fail"); + } + }), } as IHttpConnectionOptions; const connection = new HttpConnection("http://tempuri.org", options); diff --git a/clients/ts/signalr/src/LongPollingTransport.ts b/clients/ts/signalr/src/LongPollingTransport.ts index 5f3e453277..86d3e8376f 100644 --- a/clients/ts/signalr/src/LongPollingTransport.ts +++ b/clients/ts/signalr/src/LongPollingTransport.ts @@ -31,7 +31,7 @@ export class LongPollingTransport implements ITransport { this.logMessageContent = logMessageContent; } - public connect(url: string, transferFormat: TransferFormat): Promise { + public async connect(url: string, transferFormat: TransferFormat): Promise { Arg.isRequired(url, "url"); Arg.isRequired(transferFormat, "transferFormat"); Arg.isIn(transferFormat, TransferFormat, "transferFormat"); @@ -45,13 +45,6 @@ export class LongPollingTransport implements ITransport { throw new Error("Binary protocols over XmlHttpRequest not implementing advanced features are not supported."); } - this.poll(this.url, transferFormat); - return Promise.resolve(); - } - - private async poll(url: string, transferFormat: TransferFormat): Promise { - this.running = true; - const pollOptions: HttpRequest = { abortSignal: this.pollAbort.signal, headers: {}, @@ -62,15 +55,49 @@ export class LongPollingTransport implements ITransport { pollOptions.responseType = "arraybuffer"; } + const token = await this.accessTokenFactory(); + this.updateHeaderToken(pollOptions, token); + let closeError: Error; + + // Make initial long polling request + // Server uses first long polling request to finish initializing connection and it returns without data + const pollUrl = `${url}&_=${Date.now()}`; + this.logger.log(LogLevel.Trace, `(LongPolling transport) polling: ${pollUrl}`); + const response = await this.httpClient.get(pollUrl, pollOptions); + if (response.statusCode !== 200) { + this.logger.log(LogLevel.Error, `(LongPolling transport) Unexpected response code: ${response.statusCode}`); + + // Mark running as false so that the poll immediately ends and runs the close logic + closeError = new HttpError(response.statusText, response.statusCode); + this.running = false; + } else { + this.running = true; + } + + this.poll(this.url, pollOptions, closeError); + return Promise.resolve(); + } + + private updateHeaderToken(request: HttpRequest, token: string) { + if (token) { + // tslint:disable-next-line:no-string-literal + request.headers["Authorization"] = `Bearer ${token}`; + return; + } + // tslint:disable-next-line:no-string-literal + if (request.headers["Authorization"]) { + // tslint:disable-next-line:no-string-literal + delete request.headers["Authorization"]; + } + } + + private async poll(url: string, pollOptions: HttpRequest, closeError: Error): Promise { try { while (this.running) { // We have to get the access token on each poll, in case it changes const token = await this.accessTokenFactory(); - if (token) { - // tslint:disable-next-line:no-string-literal - pollOptions.headers["Authorization"] = `Bearer ${token}`; - } + this.updateHeaderToken(pollOptions, token); try { const pollUrl = `${url}&_=${Date.now()}`; @@ -142,14 +169,11 @@ export class LongPollingTransport implements ITransport { this.running = false; this.logger.log(LogLevel.Trace, `(LongPolling transport) sending DELETE request to ${this.url}.`); - const deleteOptions: HttpRequest = {}; + const deleteOptions: HttpRequest = { + headers: {}, + }; const token = await this.accessTokenFactory(); - if (token) { - // tslint:disable-next-line:no-string-literal - deleteOptions.headers = { - ["Authorization"]: `Bearer ${token}`, - }; - } + this.updateHeaderToken(deleteOptions, token); const response = await this.httpClient.delete(this.url, deleteOptions); this.logger.log(LogLevel.Trace, "(LongPolling transport) DELETE request accepted."); diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs index 4059049a96..860f54b3ee 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/LongPollingTransport.cs @@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, TransferFormat transferFormat) + public async Task StartAsync(Uri url, TransferFormat transferFormat) { if (transferFormat != TransferFormat.Binary && transferFormat != TransferFormat.Text) { @@ -49,6 +49,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal Log.StartTransport(_logger, transferFormat); + // Make initial long polling request + // Server uses first long polling request to finish initializing connection and it returns without data + var request = new HttpRequestMessage(HttpMethod.Get, url); + using (var response = await _httpClient.SendAsync(request)) + { + response.EnsureSuccessStatusCode(); + } + // Create the pipe pair (Application's writer is connected to Transport's reader, and vice versa) var options = ClientPipeOptions.DefaultOptions; var pair = DuplexPipe.CreateConnectionPair(options, options); @@ -57,8 +65,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _application = pair.Application; Running = ProcessAsync(url); - - return Task.CompletedTask; } private async Task ProcessAsync(Uri url) @@ -105,6 +111,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); + if (_application == null) + { + // We never started + return; + } + _application.Input.CancelPendingRead(); try diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs index 59ca9d7cbf..b92931a4d0 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs @@ -207,6 +207,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); + if (_application == null) + { + // We never started + return; + } + _transport.Output.Complete(); _transport.Input.Complete(); diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs index 4520162c72..c6ebeabce3 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs @@ -373,6 +373,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { Log.TransportStopping(_logger); + if (_application == null) + { + // We never started + return; + } + _transport.Output.Complete(); _transport.Input.Complete(); diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs index fb1b7a8701..0f4f88551f 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/HttpConnectionDispatcher.cs @@ -210,7 +210,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal using (connection.Cancellation) { // Cancel the previous request - connection.Cancellation.Cancel(); + connection.Cancellation?.Cancel(); // Wait for the previous request to drain await connection.TransportTask; @@ -228,29 +228,38 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal Log.EstablishedConnection(_logger); connection.ApplicationTask = ExecuteApplication(connectionDelegate, connection); + + context.Response.ContentType = "application/octet-stream"; + + // This request has no content + context.Response.ContentLength = 0; + + // On the first poll, we flush the response immediately to mark the poll as "initialized" so future + // requests can be made safely + connection.TransportTask = context.Response.Body.FlushAsync(); } else { Log.ResumingConnection(_logger); + + // REVIEW: Performance of this isn't great as this does a bunch of per request allocations + connection.Cancellation = new CancellationTokenSource(); + + var timeoutSource = new CancellationTokenSource(); + var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted, timeoutSource.Token); + + // Dispose these tokens when the request is over + context.Response.RegisterForDispose(timeoutSource); + context.Response.RegisterForDispose(tokenSource); + + var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory); + + // Start the transport + connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); + + // Start the timeout after we return from creating the transport task + timeoutSource.CancelAfter(options.LongPolling.PollTimeout); } - - // REVIEW: Performance of this isn't great as this does a bunch of per request allocations - connection.Cancellation = new CancellationTokenSource(); - - var timeoutSource = new CancellationTokenSource(); - var tokenSource = CancellationTokenSource.CreateLinkedTokenSource(connection.Cancellation.Token, context.RequestAborted, timeoutSource.Token); - - // Dispose these tokens when the request is over - context.Response.RegisterForDispose(timeoutSource); - context.Response.RegisterForDispose(tokenSource); - - var longPolling = new LongPollingTransport(timeoutSource.Token, connection.Application.Input, _loggerFactory); - - // Start the transport - connection.TransportTask = longPolling.ProcessRequestAsync(context, tokenSource.Token); - - // Start the timeout after we return from creating the transport task - timeoutSource.CancelAfter(options.LongPolling.PollTimeout); } finally { @@ -302,7 +311,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal connection.Status = HttpConnectionStatus.Inactive; // Dispose the cancellation token - connection.Cancellation.Dispose(); + connection.Cancellation?.Dispose(); connection.Cancellation = null; } @@ -474,7 +483,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal context.Response.ContentType = "text/plain"; return; } - + await context.Request.Body.CopyToAsync(connection.ApplicationStream, bufferSize); Log.ReceivedBytes(_logger, connection.ApplicationStream.Length); diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs index e0fbf43bd2..9443099b05 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/HttpConnectionDispatcherTests.cs @@ -636,6 +636,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // Start a poll var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); + Assert.True(task.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); // Send to the application var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -745,7 +749,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } [Theory] - [InlineData(HttpTransportType.LongPolling, 204)] + [InlineData(HttpTransportType.LongPolling, 200)] [InlineData(HttpTransportType.WebSockets, 404)] [InlineData(HttpTransportType.ServerSentEvents, 404)] public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(HttpTransportType transportType, int status) @@ -869,6 +873,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var builder = new ConnectionBuilder(services.BuildServiceProvider()); builder.UseConnectionHandler(); var app = builder.Build(); + // First poll will 200 + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app); Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode); @@ -998,6 +1006,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var app = builder.Build(); var options = new HttpConnectionDispatcherOptions(); var request1 = dispatcher.ExecuteAsync(context1, options, app); + Assert.True(request1.IsCompleted); + + request1 = dispatcher.ExecuteAsync(context1, options, app); var request2 = dispatcher.ExecuteAsync(context2, options, app); await request1; @@ -1132,7 +1143,14 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests builder.UseConnectionHandler(); var app = builder.Build(); var options = new HttpConnectionDispatcherOptions(); + + // Initial poll var task = dispatcher.ExecuteAsync(context, options, app); + Assert.True(task.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // Real long running poll + task = dispatcher.ExecuteAsync(context, options, app); var buffer = Encoding.UTF8.GetBytes("Hello World"); @@ -1166,7 +1184,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var options = new HttpConnectionDispatcherOptions(); var context1 = MakeRequest("/foo", connection); + // This is the initial poll to make sure things are setup var task1 = dispatcher.ExecuteAsync(context1, options, app); + Assert.True(task1.IsCompleted); + task1 = dispatcher.ExecuteAsync(context1, options, app); var context2 = MakeRequest("/foo", connection); var task2 = dispatcher.ExecuteAsync(context2, options, app); @@ -1363,10 +1384,13 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); - await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); - await connectionHandlerTask.OrTimeout(); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); + await connectionHandlerTask.OrTimeout(); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); Assert.Equal("Hello, World", GetContentAsString(context.Response.Body)); } @@ -1444,7 +1468,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests new Claim(ClaimTypes.StreetAddress, "12345 123rd St. NW") })); + // First poll var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(connectionHandlerTask.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); await connectionHandlerTask.OrTimeout(); @@ -1502,7 +1531,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests // "authorize" user context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.NameIdentifier, "name") })); + // Initial poll var connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(connectionHandlerTask.IsCompleted); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + connectionHandlerTask = dispatcher.ExecuteAsync(context, options, app); await connection.Transport.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello, World")).AsTask().OrTimeout(); await connectionHandlerTask.OrTimeout(); @@ -1660,6 +1694,10 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests var options = new HttpConnectionDispatcherOptions(); var pollTask = dispatcher.ExecuteAsync(context, options, app); + Assert.True(pollTask.IsCompleted); + + // Now send the second poll + pollTask = dispatcher.ExecuteAsync(context, options, app); // Issue the delete request and make sure the poll completes var deleteContext = new DefaultHttpContext(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs index 8357e9e693..42839f7472 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.Transport.cs @@ -13,14 +13,20 @@ using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.Http.Connections.Client; using Microsoft.AspNetCore.Http.Connections.Client.Internal; +using Microsoft.AspNetCore.SignalR.Tests; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.SignalR.Client.Tests { public partial class HttpConnectionTests { - public class Transport + public class Transport : VerifiableLoggedTest { + public Transport(ITestOutputHelper output) : base(output) + { + } + [Theory] [InlineData(HttpTransportType.LongPolling)] [InlineData(HttpTransportType.ServerSentEvents)] @@ -30,11 +36,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var requestsExecuted = false; var callCount = 0; - testHttpHandler.OnRequest((request, next, token) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); - testHttpHandler.OnNegotiate((_, cancellationToken) => { return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); @@ -52,6 +53,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + Task AccessTokenProvider() { callCount++; @@ -75,28 +81,25 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests [InlineData(HttpTransportType.ServerSentEvents, false)] public async Task HttpConnectionSetsInherentKeepAliveFeature(HttpTransportType transportType, bool expectedValue) { - var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); - - testHttpHandler.OnRequest((request, next, token) => + using (StartVerifableLog(out var loggerFactory, testName: $"HttpConnectionSetsInherentKeepAliveFeature_{transportType}_{expectedValue}")) { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); + var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); - testHttpHandler.OnNegotiate((_, cancellationToken) => - { - return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); - }); + testHttpHandler.OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent())); - await WithConnectionAsync( - CreateConnection(testHttpHandler, transportType: transportType), - async (connection) => - { - await connection.StartAsync(TransferFormat.Text).OrTimeout(); + testHttpHandler.OnRequest((request, next, token) => Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent))); - var feature = connection.Features.Get(); - Assert.NotNull(feature); - Assert.Equal(expectedValue, feature.HasInherentKeepAlive); - }); + await WithConnectionAsync( + CreateConnection(testHttpHandler, transportType: transportType, loggerFactory: loggerFactory), + async (connection) => + { + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + + var feature = connection.Features.Get(); + Assert.NotNull(feature); + Assert.Equal(expectedValue, feature.HasInherentKeepAlive); + }); + } } [Theory] @@ -107,10 +110,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); var requestsExecuted = false; - testHttpHandler.OnRequest((request, next, token) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); testHttpHandler.OnNegotiate((_, cancellationToken) => { @@ -135,6 +134,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + await WithConnectionAsync( CreateConnection(testHttpHandler, transportType: transportType), async (connection) => @@ -154,11 +158,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var testHttpHandler = new TestHttpMessageHandler(autoNegotiate: false); var requestsExecuted = false; - testHttpHandler.OnRequest((request, next, token) => - { - return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); - }); - testHttpHandler.OnNegotiate((_, cancellationToken) => { return ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent()); @@ -175,6 +174,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return await next(); }); + testHttpHandler.OnRequest((request, next, token) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.NoContent)); + }); + await WithConnectionAsync( CreateConnection(testHttpHandler, transportType: transportType), async (connection) => diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 5e203256f0..ee4c60daac 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -109,15 +109,20 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests if (requests == 0) { requests++; - return ResponseUtils.CreateResponse(HttpStatusCode.OK, "Hello"); + return ResponseUtils.CreateResponse(HttpStatusCode.OK); } else if (requests == 1) + { + requests++; + return ResponseUtils.CreateResponse(HttpStatusCode.OK, "Hello"); + } + else if (requests == 2) { requests++; // Time out return ResponseUtils.CreateResponse(HttpStatusCode.OK); } - else if (requests == 2) + else if (requests == 3) { requests++; return ResponseUtils.CreateResponse(HttpStatusCode.OK, "World"); @@ -147,7 +152,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } [Fact] - public async Task LongPollingTransportStopsWhenPollRequestFails() + public async Task LongPollingTransportStartAsyncFailsIfFirstRequestFails() { var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -158,6 +163,39 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError); }); + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + try + { + var exception = await Assert.ThrowsAsync(() => longPollingTransport.StartAsync(TestUri, TransferFormat.Binary)); + Assert.Contains(" 500 ", exception.Message); + } + finally + { + await longPollingTransport.StopAsync(); + } + } + } + + [Fact] + public async Task LongPollingTransportStopsWhenPollRequestFails() + { + var mockHttpHandler = new Mock(); + var firstPoll = true; + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + if (firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + return ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError); + }); + using (var httpClient = new HttpClient(mockHttpHandler.Object)) { var longPollingTransport = new LongPollingTransport(httpClient); @@ -314,7 +352,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var message1Payload = new[] { (byte)'H', (byte)'e', (byte)'l', (byte)'l', (byte)'o' }; - var firstCall = true; + var requests = 0; var mockHttpHandler = new Mock(); var sentRequests = new List(); mockHttpHandler.Protected() @@ -325,9 +363,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await Task.Yield(); - if (firstCall) + if (requests == 0) { - firstCall = false; + requests++; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + else if (requests == 1) + { + requests++; return ResponseUtils.CreateResponse(HttpStatusCode.OK, message1Payload); } @@ -349,7 +392,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var message = await longPollingTransport.Input.ReadAllAsync(); // Check the provided request - Assert.Equal(2, sentRequests.Count); + Assert.Equal(3, sentRequests.Count); // Check the messages received Assert.Equal(message1Payload, message); @@ -366,6 +409,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var sentRequests = new List(); var tcs = new TaskCompletionSource(); + var firstPoll = true; var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -380,6 +424,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } else if (request.Method == HttpMethod.Get) { + // First poll completes immediately + if (firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + cancellationToken.Register(() => tcs.TrySetCanceled(cancellationToken)); // This is the poll task return await tcs.Task; @@ -426,6 +477,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var sentRequests = new List(); var pollTcs = new TaskCompletionSource(); var deleteTcs = new TaskCompletionSource(); + var firstPoll = true; var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -440,6 +492,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } else if (request.Method == HttpMethod.Get) { + // First poll completes immediately + if (firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + cancellationToken.Register(() => pollTcs.TrySetCanceled(cancellationToken)); // This is the poll task return await pollTcs.Task; @@ -538,7 +597,13 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { await Task.Yield(); - if (Interlocked.Increment(ref numPolls) < 3) + if (numPolls == 0) + { + numPolls++; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + + if (numPolls++ < 3) { throw new OperationCanceledException(); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs index 23ab6d73b5..58d5fbb53c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ResponseUtils.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.Linq; using System.Net; using System.Net.Http; using Microsoft.AspNetCore.Connections; @@ -39,6 +40,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public static bool IsLongPollRequest(HttpRequestMessage request) { return request.Method == HttpMethod.Get && + !IsServerSentEventsRequest(request) && (request.RequestUri.PathAndQuery.Contains("?id=") || request.RequestUri.PathAndQuery.Contains("&id=")); } @@ -48,6 +50,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests (request.RequestUri.PathAndQuery.Contains("?id=") || request.RequestUri.PathAndQuery.Contains("&id=")); } + public static bool IsServerSentEventsRequest(HttpRequestMessage request) + { + return request.Method == HttpMethod.Get && request.Headers.Accept.Any(h => h.MediaType == "text/event-stream"); + } + public static bool IsSocketSendRequest(HttpRequestMessage request) { return request.Method == HttpMethod.Post && diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs index fc9a7e780b..6bde5d0af2 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestHttpMessageHandler.cs @@ -7,10 +7,14 @@ using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR.Client.Tests { + delegate Task RequestDelegate(HttpRequestMessage requestMessage, CancellationToken cancellationToken); + public class TestHttpMessageHandler : HttpMessageHandler { private List _receivedRequests = new List(); - private Func> _handler; + private RequestDelegate _app; + + private List> _middleware = new List>(); public IReadOnlyList ReceivedRequests { @@ -23,14 +27,29 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } } - public TestHttpMessageHandler(bool autoNegotiate = true) + public TestHttpMessageHandler(bool autoNegotiate = true, bool handleFirstPoll = true) { - _handler = BaseHandler; - if (autoNegotiate) { OnNegotiate((_, cancellationToken) => ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationContent())); } + + if (handleFirstPoll) + { + var firstPoll = true; + OnRequest(async (request, next, cancellationToken) => + { + if (ResponseUtils.IsLongPollRequest(request) && firstPoll) + { + firstPoll = false; + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + } + else + { + return await next(); + } + }); + } } protected override async Task SendAsync(HttpRequestMessage request, CancellationToken cancellationToken) @@ -40,9 +59,21 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests lock (_receivedRequests) { _receivedRequests.Add(request); + + if (_app == null) + { + _middleware.Reverse(); + RequestDelegate handler = BaseHandler; + foreach (var middleware in _middleware) + { + handler = middleware(handler); + } + + _app = handler; + } } - return await _handler(request, cancellationToken); + return await _app(request, cancellationToken); } public static TestHttpMessageHandler CreateDefault() @@ -80,8 +111,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public void OnRequest(Func>, CancellationToken, Task> handler) { - var nextHandler = _handler; - _handler = (request, cancellationToken) => handler(request, () => nextHandler(request, cancellationToken), cancellationToken); + void OnRequestCore(Func middleware) + { + _middleware.Add(middleware); + } + + OnRequestCore(next => + { + return (request, cancellationToken) => + { + return handler(request, () => next(request, cancellationToken), cancellationToken); + }; + }); } public void OnGet(string pathAndQuery, Func> handler) => OnRequest(HttpMethod.Get, pathAndQuery, handler); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index b9f3e63c0c..d77ba5e2ff 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -51,32 +51,44 @@ namespace Microsoft.AspNetCore.SignalR.Tests [Fact] public async Task CanStartAndStopConnectionUsingDefaultTransport() { - var url = _serverFixture.Url + "/echo"; - // The test should connect to the server using WebSockets transport on Windows 8 and newer. - // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. - var connection = new HttpConnection(new Uri(url)); - await connection.StartAsync(TransferFormat.Binary).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + using (StartVerifableLog(out var loggerFactory)) + { + var url = _serverFixture.Url + "/echo"; + // The test should connect to the server using WebSockets transport on Windows 8 and newer. + // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. + var connection = new HttpConnection(new Uri(url), HttpTransports.All, loggerFactory); + await connection.StartAsync(TransferFormat.Binary).OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } } [Fact] public async Task TransportThatFallsbackCreatesNewConnection() { - var url = _serverFixture.Url + "/echo"; - // The test should connect to the server using WebSockets transport on Windows 8 and newer. - // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. + bool ExpectedErrors(WriteContext writeContext) + { + return writeContext.LoggerName == typeof(HttpConnection).FullName && + writeContext.EventId.Name == "ErrorStartingTransport"; + } - // The test logic lives in the TestTransportFactory and FakeTransport. - var connection = new HttpConnection(new HttpConnectionOptions { Url = new Uri(url) }, null, new TestTransportFactory()); - await connection.StartAsync(TransferFormat.Text).OrTimeout(); - await connection.DisposeAsync().OrTimeout(); + using (StartVerifableLog(out var loggerFactory, expectedErrorsFilter: ExpectedErrors)) + { + var url = _serverFixture.Url + "/echo"; + // The test should connect to the server using WebSockets transport on Windows 8 and newer. + // On Windows 7/2008R2 it should use ServerSentEvents transport to connect to the server. + + // The test logic lives in the TestTransportFactory and FakeTransport. + var connection = new HttpConnection(new HttpConnectionOptions { Url = new Uri(url) }, loggerFactory, new TestTransportFactory()); + await connection.StartAsync(TransferFormat.Text).OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } } - [Theory(Skip = "https://github.com/aspnet/SignalR/issues/2031")] + [Theory] [MemberData(nameof(TransportTypes))] public async Task CanStartAndStopConnectionUsingGivenTransport(HttpTransportType transportType) { - using (StartVerifableLog(out var loggerFactory, testName: $"CanStartAndStopConnectionUsingGivenTransport_{transportType}")) + using (StartVerifableLog(out var loggerFactory, minLogLevel: LogLevel.Trace, testName: $"CanStartAndStopConnectionUsingGivenTransport_{transportType}")) { var url = _serverFixture.Url + "/echo"; var connection = new HttpConnection(new Uri(url), transportType, loggerFactory); @@ -532,7 +544,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests if (_tries < availableTransports) { - throw new Exception(); + return Task.FromException(new Exception()); } else {