diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index a1a1fd48ba..a6cae86868 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -71,25 +71,26 @@ namespace Microsoft.AspNetCore.Sockets.Client var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); var stream = await response.Content.ReadAsStreamAsync(); - - var pipelineReader = stream.AsPipelineReader(); + var pipelineReader = stream.AsPipelineReader(cancellationToken); + var readCancellationRegistration = cancellationToken.Register( + reader => ((IPipeReader)reader).CancelPendingRead(), pipelineReader); try { while (true) { var result = await pipelineReader.ReadAsync(); var input = result.Buffer; + if (result.IsCancelled || (input.IsEmpty && result.IsCompleted)) + { + _logger.LogDebug("Server-Sent Event Stream ended"); + break; + } + var consumed = input.Start; var examined = input.End; try { - if (input.IsEmpty && result.IsCompleted) - { - _logger.LogDebug("Server-Sent Event Stream ended"); - break; - } - var parseResult = _parser.ParseMessage(input, out consumed, out examined, out var buffer); switch (parseResult) @@ -114,6 +115,7 @@ namespace Microsoft.AspNetCore.Sockets.Client } finally { + readCancellationRegistration.Dispose(); _transportCts.Cancel(); stream.Dispose(); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 22f7a29306..167c56d0fd 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -37,7 +37,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { @@ -74,7 +74,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); @@ -127,7 +127,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); @@ -163,7 +163,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); @@ -199,7 +199,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); @@ -240,7 +240,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); @@ -289,7 +289,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); @@ -347,7 +347,7 @@ namespace Microsoft.AspNetCore.Client.Tests using (var httpClient = new HttpClient(mockHttpHandler.Object)) { - var longPollingTransport = new LongPollingTransport(httpClient, new LoggerFactory()); + var longPollingTransport = new LongPollingTransport(httpClient); try { var connectionToTransport = Channel.CreateUnbounded(); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs new file mode 100644 index 0000000000..8b0cf074a1 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -0,0 +1,65 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Net.Http; +using System.Threading; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Internal; +using Moq; +using Moq.Protected; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Client.Tests +{ + public class ServerSentEventsTransportTests + { + [Fact] + public async Task CanStartStopSSETransport() + { + var eventStreamTcs = new TaskCompletionSource(); + var copyToAsyncTcs = new TaskCompletionSource(); + + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + // Receive loop started - allow stopping the transport + eventStreamTcs.SetResult(null); + + // returns unfinished task to block pipelines + var mockStream = new Mock(); + mockStream + .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) + .Returns(copyToAsyncTcs.Task); + return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; + }); + + try + { + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var sseTransport = new ServerSentEventsTransport(httpClient); + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection).OrTimeout(); + + await eventStreamTcs.Task.OrTimeout(); + await sseTransport.StopAsync().OrTimeout(); + await sseTransport.Running.OrTimeout(); + } + } + finally + { + copyToAsyncTcs.SetResult(0); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 08ac903e4b..341f08a000 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -249,7 +249,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - public static IEnumerable TransportTypes() => + public static IEnumerable TransportTypes => new[] { new object[] { TransportType.WebSockets },