From b5c46f35b3550936655c94c9ae8925d1fd4256be Mon Sep 17 00:00:00 2001 From: Mikael Mengistu Date: Thu, 22 Mar 2018 19:03:48 +0000 Subject: [PATCH] Check for actual start in SSE (#1681) --- .../ServerSentEventsTransport.cs | 22 ++++++-- ...HttpConnectionTests.ConnectionLifecycle.cs | 53 +++++++++++++++++++ 2 files changed, 71 insertions(+), 4 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index f8784756e2..8d819d6cbd 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -54,8 +54,9 @@ namespace Microsoft.AspNetCore.Sockets.Client Log.StartTransport(_logger, transferFormat); + var startTcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); var sendTask = SendUtils.SendMessages(url, _application, _httpClient, _httpOptions, _transportCts, _logger); - var receiveTask = OpenConnection(_application, url, _transportCts.Token); + var receiveTask = OpenConnection(_application, url, startTcs, _transportCts.Token); Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t => { @@ -66,17 +67,30 @@ namespace Microsoft.AspNetCore.Sockets.Client return t; }).Unwrap(); - return Task.CompletedTask; + return startTcs.Task; } - private async Task OpenConnection(IDuplexPipe application, Uri url, CancellationToken cancellationToken) + private async Task OpenConnection(IDuplexPipe application, Uri url, TaskCompletionSource startTcs, CancellationToken cancellationToken) { Log.StartReceive(_logger); var request = new HttpRequestMessage(HttpMethod.Get, url); SendUtils.PrepareHttpRequest(request, _httpOptions); request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + + HttpResponseMessage response; + try + { + response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + response.EnsureSuccessStatusCode(); + startTcs.TrySetResult(null); + } + catch (Exception ex) + { + Log.TransportStopping(_logger); + startTcs.TrySetException(ex); + return; + } using (var stream = await response.Content.ReadAsStreamAsync()) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs index db2cb3879b..bff6c5b0af 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.ConnectionLifecycle.cs @@ -370,6 +370,59 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }); } + [Fact] + public async Task SSEWontStartIfSuccessfulConnectionIsNotEstablished() + { + using (StartLog(out var loggerFactory)) + { + var httpHandler = new TestHttpMessageHandler(); + + httpHandler.OnGet("/?id=00000000-0000-0000-0000-000000000000", (_, __) => + { + return Task.FromResult(ResponseUtils.CreateResponse(HttpStatusCode.InternalServerError)); + }); + + var sse = new ServerSentEventsTransport(new HttpClient(httpHandler)); + + await WithConnectionAsync( + CreateConnection(httpHandler, loggerFactory: loggerFactory, url: null, transport: sse), + async (connection, closed) => + { + await Assert.ThrowsAsync( + () => connection.StartAsync(TransferFormat.Text).OrTimeout()); + }); + } + } + + [Fact] + public async Task SSEWaitsForResponseToStart() + { + using (StartLog(out var loggerFactory)) + { + var httpHandler = new TestHttpMessageHandler(); + + var connectResponseTcs = new TaskCompletionSource(); + httpHandler.OnGet("/?id=00000000-0000-0000-0000-000000000000", async (_, __) => + { + await connectResponseTcs.Task; + return ResponseUtils.CreateResponse(HttpStatusCode.Accepted); + }); + + var sse = new ServerSentEventsTransport(new HttpClient(httpHandler)); + + await WithConnectionAsync( + CreateConnection(httpHandler, loggerFactory: loggerFactory, url: null, transport: sse), + async (connection, closed) => + { + var startTask = connection.StartAsync(TransferFormat.Text).OrTimeout(); + Assert.False(connectResponseTcs.Task.IsCompleted); + Assert.False(startTask.IsCompleted); + connectResponseTcs.TrySetResult(null); + await startTask; + }); + } + } + [Fact] public async Task TransportIsStoppedWhenConnectionIsStopped() {