diff --git a/src/Common/PipeWriterStream.cs b/src/Common/PipeWriterStream.cs index 8c294b95b6..d68a03f26c 100644 --- a/src/Common/PipeWriterStream.cs +++ b/src/Common/PipeWriterStream.cs @@ -69,6 +69,11 @@ namespace System.IO.Pipelines private ValueTask WriteCoreAsync(ReadOnlyMemory source, CancellationToken cancellationToken = default) { + if (cancellationToken.IsCancellationRequested) + { + return new ValueTask(Task.FromCanceled(cancellationToken)); + } + _length += source.Length; var task = _pipeWriter.WriteAsync(source); if (!task.IsCompletedSuccessfully) diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/PipeReaderFactory.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/PipeReaderFactory.cs new file mode 100644 index 0000000000..8604457be1 --- /dev/null +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/PipeReaderFactory.cs @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft. All rights reserved. +// Licensed under the MIT license. See LICENSE file in the project root for full license information. + +using System.Threading; +using System.Threading.Tasks; + +namespace System.IO.Pipelines +{ + internal class PipeReaderFactory + { + private static readonly Action _cancelReader = state => ((PipeReader)state).CancelPendingRead(); + + public static PipeReader CreateFromStream(PipeOptions options, Stream stream, CancellationToken cancellationToken) + { + if (!stream.CanRead) + { + throw new NotSupportedException(); + } + + var pipe = new Pipe(options); + _ = CopyToAsync(stream, pipe, cancellationToken); + + return pipe.Reader; + } + + private static async Task CopyToAsync(Stream stream, Pipe pipe, CancellationToken cancellationToken) + { + // We manually register for cancellation here in case the Stream implementation ignores it + using (var registration = cancellationToken.Register(_cancelReader, pipe.Reader)) + { + try + { + // REVIEW: Should we use the default buffer size here? + // 81920 is the default bufferSize, there is no stream.CopyToAsync overload that takes only a cancellationToken + await stream.CopyToAsync(new PipeWriterStream(pipe.Writer), bufferSize: 81920, cancellationToken); + } + catch (OperationCanceledException) + { + // Ignore the cancellation signal (the pipe reader is already wired up for cancellation when the token trips) + } + catch (Exception ex) + { + pipe.Writer.Complete(ex); + return; + } + pipe.Writer.Complete(); + } + } + } +} diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.Log.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.Log.cs index 03edd6c455..71b1aa2d88 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.Log.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.Log.cs @@ -29,8 +29,8 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal private static readonly Action _transportStopping = LoggerMessage.Define(LogLevel.Information, new EventId(6, "TransportStopping"), "Transport is stopping."); - private static readonly Action _messageToApp = - LoggerMessage.Define(LogLevel.Debug, new EventId(7, "MessageToApp"), "Passing message to application. Payload size: {Count}."); + private static readonly Action _messageToApplication = + LoggerMessage.Define(LogLevel.Debug, new EventId(7, "MessageToApplication"), "Passing message to application. Payload size: {Count}."); private static readonly Action _eventStreamEnded = LoggerMessage.Define(LogLevel.Debug, new EventId(8, "EventStreamEnded"), "Server-Sent Event Stream ended."); @@ -60,9 +60,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _transportStopping(logger, null); } - public static void MessageToApp(ILogger logger, int count) + public static void MessageToApplication(ILogger logger, int count) { - _messageToApp(logger, count, null); + _messageToApplication(logger, count, null); } public static void ReceiveCanceled(ILogger logger) diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs index 94536e8778..8b930015fe 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/ServerSentEventsTransport.cs @@ -42,7 +42,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, IDuplexPipe application, TransferFormat transferFormat, IConnection connection) + public async Task StartAsync(Uri url, IDuplexPipe application, TransferFormat transferFormat, IConnection connection) { if (transferFormat != TransferFormat.Text) { @@ -53,17 +53,32 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal Log.StartTransport(_logger, transferFormat); - var startTcs = new TaskCompletionSource(TaskContinuationOptions.RunContinuationsAsynchronously); + var request = new HttpRequestMessage(HttpMethod.Get, url); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - Running = ProcessAsync(url, startTcs); + HttpResponseMessage response = null; - return startTcs.Task; + try + { + response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, CancellationToken.None); + response.EnsureSuccessStatusCode(); + } + catch + { + response?.Dispose(); + + Log.TransportStopping(_logger); + + throw; + } + + Running = ProcessAsync(url, response); } - private async Task ProcessAsync(Uri url, TaskCompletionSource startTcs) + private async Task ProcessAsync(Uri url, HttpResponseMessage response) { // Start sending and polling (ask for binary if the server supports it) - var receiving = OpenConnection(_application, url, startTcs, _transportCts.Token); + var receiving = ProcessEventStream(_application, response, _transportCts.Token); var sending = SendUtils.SendMessages(url, _application, _httpClient, _logger); // Wait for send or receive to complete @@ -90,90 +105,75 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal } } - private async Task OpenConnection(IDuplexPipe application, Uri url, TaskCompletionSource startTcs, CancellationToken cancellationToken) + private async Task ProcessEventStream(IDuplexPipe application, HttpResponseMessage response, CancellationToken cancellationToken) { Log.StartReceive(_logger); - var request = new HttpRequestMessage(HttpMethod.Get, url); - request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("text/event-stream")); - - HttpResponseMessage response = null; - - try - { - response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); - response.EnsureSuccessStatusCode(); - startTcs.TrySetResult(null); - } - catch (Exception ex) - { - response?.Dispose(); - Log.TransportStopping(_logger); - startTcs.TrySetException(ex); - return; - } - using (response) using (var stream = await response.Content.ReadAsStreamAsync()) { - var pipeOptions = new PipeOptions(pauseWriterThreshold: 0, resumeWriterThreshold: 0); - var pipelineReader = StreamPipeConnection.CreateReader(pipeOptions, stream); - var readCancellationRegistration = cancellationToken.Register( - reader => ((PipeReader)reader).CancelPendingRead(), pipelineReader); + var options = new PipeOptions(pauseWriterThreshold: 0, resumeWriterThreshold: 0); + var reader = PipeReaderFactory.CreateFromStream(options, stream, cancellationToken); + try { while (true) { - var result = await pipelineReader.ReadAsync(); - var input = result.Buffer; - if (result.IsCanceled || (input.IsEmpty && result.IsCompleted)) - { - Log.EventStreamEnded(_logger); - break; - } + var result = await reader.ReadAsync(); + var buffer = result.Buffer; + var consumed = buffer.Start; + var examined = buffer.End; - var consumed = input.Start; - var examined = input.End; try { - Log.ParsingSSE(_logger, input.Length); - var parseResult = _parser.ParseMessage(input, out consumed, out examined, out var buffer); - FlushResult flushResult = default; - - switch (parseResult) + if (result.IsCanceled) { - case ServerSentEventsMessageParser.ParseResult.Completed: - Log.MessageToApp(_logger, buffer.Length); - - flushResult = await _application.Output.WriteAsync(buffer); - - _parser.Reset(); - break; - case ServerSentEventsMessageParser.ParseResult.Incomplete: - if (result.IsCompleted) - { - throw new FormatException("Incomplete message."); - } - break; + Log.ReceiveCanceled(_logger); + break; } - // We canceled in the middle of applying back pressure - // or if the consumer is done - if (flushResult.IsCanceled || flushResult.IsCompleted) + if (!buffer.IsEmpty) + { + Log.ParsingSSE(_logger, buffer.Length); + + var parseResult = _parser.ParseMessage(buffer, out consumed, out examined, out var message); + FlushResult flushResult = default; + + switch (parseResult) + { + case ServerSentEventsMessageParser.ParseResult.Completed: + Log.MessageToApplication(_logger, message.Length); + + flushResult = await _application.Output.WriteAsync(message); + + _parser.Reset(); + break; + case ServerSentEventsMessageParser.ParseResult.Incomplete: + if (result.IsCompleted) + { + throw new FormatException("Incomplete message."); + } + break; + } + + // We canceled in the middle of applying back pressure + // or if the consumer is done + if (flushResult.IsCanceled || flushResult.IsCompleted) + { + break; + } + } + else if (result.IsCompleted) { break; } } finally { - pipelineReader.AdvanceTo(consumed, examined); + reader.AdvanceTo(consumed, examined); } } } - catch (OperationCanceledException) - { - Log.ReceiveCanceled(_logger); - } catch (Exception ex) { _error = ex; @@ -182,9 +182,9 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { _application.Output.Complete(_error); - readCancellationRegistration.Dispose(); - Log.ReceiveStopped(_logger); + + reader.Complete(); } } } diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/StreamExtensions.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/StreamExtensions.cs deleted file mode 100644 index 6b3df653d9..0000000000 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/StreamExtensions.cs +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -using System.Threading; -using System.Threading.Tasks; - -namespace System.IO.Pipelines -{ - internal static class StreamExtensions - { - public static async Task CopyToEndAsync(this Stream stream, PipeWriter writer, CancellationToken cancellationToken = default) - { - try - { - // REVIEW: Should we use the default buffer size here? - // 81920 is the default bufferSize, there is no stream.CopyToAsync overload that takes only a cancellationToken - await stream.CopyToAsync(new PipelineWriterStream(writer), bufferSize: 81920, cancellationToken: cancellationToken); - } - catch (Exception ex) - { - writer.Complete(ex); - return; - } - writer.Complete(); - } - - private class PipelineWriterStream : Stream - { - private readonly PipeWriter _writer; - - public PipelineWriterStream(PipeWriter writer) - { - _writer = writer; - } - - public override bool CanRead => false; - - public override bool CanSeek => false; - - public override bool CanWrite => true; - - public override long Length => throw new NotSupportedException(); - - public override long Position - { - get => throw new NotSupportedException(); - set => throw new NotSupportedException(); - } - - public override void Flush() - { - throw new NotSupportedException(); - } - - public override int Read(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } - - public override long Seek(long offset, SeekOrigin origin) - { - throw new NotSupportedException(); - } - - public override void SetLength(long value) - { - throw new NotSupportedException(); - } - - public override void Write(byte[] buffer, int offset, int count) - { - throw new NotSupportedException(); - } - - public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) - { - cancellationToken.ThrowIfCancellationRequested(); - - await _writer.WriteAsync(new ReadOnlyMemory(buffer, offset, count), cancellationToken); - } - } - } -} diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/StreamPipeConnection.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/StreamPipeConnection.cs deleted file mode 100644 index 69534a2e8a..0000000000 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/StreamPipeConnection.cs +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright (c) Microsoft. All rights reserved. -// Licensed under the MIT license. See LICENSE file in the project root for full license information. - -namespace System.IO.Pipelines -{ - internal class StreamPipeConnection - { - public static PipeReader CreateReader(PipeOptions options, Stream stream) - { - if (!stream.CanRead) - { - throw new NotSupportedException(); - } - - var pipe = new Pipe(options); - _ = stream.CopyToEndAsync(pipe.Writer); - - return pipe.Reader; - } - } -}