diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs index aa3c41bbed..ffbe81a8e7 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/ServerSentEventsTransportTests.cs @@ -43,7 +43,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) .Returns(copyToAsyncTcs.Task); mockStream.Setup(s => s.CanRead).Returns(true); - return new HttpResponseMessage {Content = new StreamContent(mockStream.Object)}; + return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; }); try @@ -76,15 +76,17 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var mockStream = new Mock(); mockStream - .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(async (stream, bufferSize, t) => + .Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())) + .Returns, CancellationToken>(async (data, t) => { - var buffer = Encoding.ASCII.GetBytes("data: 3:abc\r\n\r\n"); - while (!t.IsCancellationRequested) + if (t.IsCancellationRequested) { - await stream.WriteAsync(buffer, 0, buffer.Length).OrTimeout(); - await Task.Delay(100); + return 0; } + + int count = Encoding.ASCII.GetBytes("data: 3:abc\r\n\r\n", data.Span); + await Task.Delay(100); + return count; }); mockStream.Setup(s => s.CanRead).Returns(true); @@ -120,6 +122,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task SSETransportStopsWithErrorIfServerSendsIncompleteResults() { var mockHttpHandler = new Mock(); + var calls = 0; mockHttpHandler.Protected() .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) .Returns(async (request, cancellationToken) => @@ -128,11 +131,15 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var mockStream = new Mock(); mockStream - .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(async (stream, bufferSize, t) => + .Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())) + .Returns, CancellationToken>((data, t) => { - var buffer = Encoding.ASCII.GetBytes("data: 3:a"); - await stream.WriteAsync(buffer, 0, buffer.Length); + if (calls == 0) + { + calls++; + return new ValueTask(Encoding.ASCII.GetBytes("data: 3:a", data.Span)); + } + return new ValueTask(0); }); mockStream.Setup(s => s.CanRead).Returns(true); @@ -165,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests } var eventStreamTcs = new TaskCompletionSource(); - var copyToAsyncTcs = new TaskCompletionSource(); + var readTcs = new TaskCompletionSource(); var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -182,8 +189,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // returns unfinished task to block pipelines var mockStream = new Mock(); mockStream - .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(copyToAsyncTcs.Task); + .Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())) + .Returns, CancellationToken>(async (data, ct) => + { + using (ct.Register(() => readTcs.TrySetCanceled())) + { + return await readTcs.Task; + } + }); mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; } @@ -214,7 +227,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task SSETransportStopsIfChannelClosed() { var eventStreamTcs = new TaskCompletionSource(); - var copyToAsyncTcs = new TaskCompletionSource(); + var readTcs = new TaskCompletionSource(); var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -229,8 +242,14 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // returns unfinished task to block pipelines var mockStream = new Mock(); mockStream - .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(copyToAsyncTcs.Task); + .Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())) + .Returns, CancellationToken>(async (data, ct) => + { + using (ct.Register(() => readTcs.TrySetCanceled())) + { + return await readTcs.Task; + } + }); mockStream.Setup(s => s.CanRead).Returns(true); return new HttpResponseMessage { Content = new StreamContent(mockStream.Object) }; }); @@ -281,7 +300,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task SSETransportCancelsSendOnStop() { var eventStreamTcs = new TaskCompletionSource(); - var copyToAsyncTcs = new TaskCompletionSource(); + var readTcs = new TaskCompletionSource(); var sendSyncPoint = new SyncPoint(); var mockHttpHandler = new Mock(); @@ -299,10 +318,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // returns unfinished task to block pipelines var mockStream = new Mock(); mockStream - .Setup(s => s.CopyToAsync(It.IsAny(), It.IsAny(), It.IsAny())) - .Returns(async (stream, bufferSize, t) => + .Setup(s => s.ReadAsync(It.IsAny>(), It.IsAny())) + .Returns(async () => { - await copyToAsyncTcs.Task; + await readTcs.Task; throw new TaskCanceledException(); }); @@ -332,7 +351,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var stopTask = sseTransport.StopAsync(); - copyToAsyncTcs.SetResult(null); + readTcs.SetResult(null); sendSyncPoint.Continue(); await stopTask; diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ClientPipeOptions.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ClientPipeOptions.cs index e2c546a192..4839ac7f9e 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ClientPipeOptions.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ClientPipeOptions.cs @@ -7,6 +7,6 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { internal static class ClientPipeOptions { - public static PipeOptions DefaultOptions = new PipeOptions(writerScheduler: PipeScheduler.ThreadPool, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false, pauseWriterThreshold: 0, resumeWriterThreshold: 0); + public static PipeOptions DefaultOptions = new PipeOptions(writerScheduler: PipeScheduler.ThreadPool, readerScheduler: PipeScheduler.ThreadPool, useSynchronizationContext: false); } } diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/PipeReaderFactory.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/PipeReaderFactory.cs deleted file mode 100644 index a96b756176..0000000000 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/PipeReaderFactory.cs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Threading; -using System.Threading.Tasks; - -namespace System.IO.Pipelines -{ - internal class PipeReaderFactory - { - private static readonly Action _cancelReader = state => ((PipeReader)state).CancelPendingRead(); - - public static PipeReader CreateFromStream(PipeOptions options, Stream stream, CancellationToken cancellationToken) - { - if (!stream.CanRead) - { - throw new NotSupportedException(); - } - - var pipe = new Pipe(options); - _ = CopyToAsync(stream, pipe, cancellationToken); - - return pipe.Reader; - } - - private static async Task CopyToAsync(Stream stream, Pipe pipe, CancellationToken cancellationToken) - { - // We manually register for cancellation here in case the Stream implementation ignores it - using (var registration = cancellationToken.Register(_cancelReader, pipe.Reader)) - { - try - { - await stream.CopyToAsync(new PipeWriterStream(pipe.Writer), bufferSize: 4096, cancellationToken); - } - catch (OperationCanceledException) - { - // Ignore the cancellation signal (the pipe reader is already wired up for cancellation when the token trips) - } - catch (Exception ex) - { - pipe.Writer.Complete(ex); - return; - } - pipe.Writer.Complete(); - } - } - } -} diff --git a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs index 66ab83bbca..2f42cc9630 100644 --- a/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs +++ b/src/SignalR/clients/csharp/Http.Connections.Client/src/Internal/ServerSentEventsTransport.cs @@ -129,12 +129,15 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal private async Task ProcessEventStream(HttpResponseMessage response, CancellationToken cancellationToken) { Log.StartReceive(_logger); + + static void CancelReader(object state) => ((PipeReader)state).CancelPendingRead(); using (response) using (var stream = await response.Content.ReadAsStreamAsync()) { - var options = new PipeOptions(pauseWriterThreshold: 0, resumeWriterThreshold: 0); - var reader = PipeReaderFactory.CreateFromStream(options, stream, cancellationToken); + var reader = PipeReader.Create(stream); + + using var registration = cancellationToken.Register(CancelReader, reader); try { diff --git a/src/SignalR/server/SignalR/test/EndToEndTests.cs b/src/SignalR/server/SignalR/test/EndToEndTests.cs index c9f85fefbd..675d939a3a 100644 --- a/src/SignalR/server/SignalR/test/EndToEndTests.cs +++ b/src/SignalR/server/SignalR/test/EndToEndTests.cs @@ -329,15 +329,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests logger.LogInformation("Started connection to {url}", url); var bytes = Encoding.UTF8.GetBytes(message); - logger.LogInformation("Sending {length} byte message", bytes.Length); - await connection.Transport.Output.WriteAsync(bytes).OrTimeout(); - logger.LogInformation("Sent message"); - logger.LogInformation("Receiving message"); - // Big timeout here because it can take a while to receive all the bytes - var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length).OrTimeout(TimeSpan.FromMinutes(2)); - Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); - logger.LogInformation("Completed receive"); + async Task SendMessage() + { + logger.LogInformation("Sending {length} byte message", bytes.Length); + await connection.Transport.Output.WriteAsync(bytes).OrTimeout(); + logger.LogInformation("Sent message"); + } + + async Task ReceiveMessage() + { + logger.LogInformation("Receiving message"); + // Big timeout here because it can take a while to receive all the bytes + var receivedData = await connection.Transport.Input.ReadAsync(bytes.Length).OrTimeout(TimeSpan.FromMinutes(2)); + Assert.Equal(message, Encoding.UTF8.GetString(receivedData)); + logger.LogInformation("Completed receive"); + } + + // Send the receive concurrently so that back pressure is released + // for server -> client sends + var sendingTask = SendMessage(); + var receivingTask = ReceiveMessage(); + + await sendingTask; + await receivingTask; } catch (Exception ex) {