diff --git a/src/Common/PipeWriterStream.cs b/src/Common/PipeWriterStream.cs index 56472af59e..df6f85aaa1 100644 --- a/src/Common/PipeWriterStream.cs +++ b/src/Common/PipeWriterStream.cs @@ -67,7 +67,7 @@ namespace System.IO.Pipelines { _pipeWriter.Write(source.Span); _length += source.Length; - return new ValueTask(Task.CompletedTask); + return default; } #endif } diff --git a/src/Common/StreamExtensions.cs b/src/Common/StreamExtensions.cs index ad801470c2..60da892475 100644 --- a/src/Common/StreamExtensions.cs +++ b/src/Common/StreamExtensions.cs @@ -11,10 +11,27 @@ namespace System.IO { internal static class StreamExtensions { - public static async Task WriteAsync(this Stream stream, ReadOnlySequence buffer, CancellationToken cancellationToken = default) + public static ValueTask WriteAsync(this Stream stream, ReadOnlySequence buffer, CancellationToken cancellationToken = default) { - // REVIEW: Should we special case IsSingleSegment here? - foreach (var segment in buffer) + if (buffer.IsSingleSegment) + { +#if NETCOREAPP2_1 + return stream.WriteAsync(buffer.First, cancellationToken); +#else + var isArray = MemoryMarshal.TryGetArray(buffer.First, out var arraySegment); + // We're using the managed memory pool which is backed by managed buffers + Debug.Assert(isArray); + return new ValueTask(stream.WriteAsync(arraySegment.Array, arraySegment.Offset, arraySegment.Count, cancellationToken)); +#endif + } + + return WriteMultiSegmentAsync(stream, buffer, cancellationToken); + } + + private static async ValueTask WriteMultiSegmentAsync(Stream stream, ReadOnlySequence buffer, CancellationToken cancellationToken) + { + var position = buffer.Start; + while (buffer.TryGet(ref position, out var segment)) { #if NETCOREAPP2_1 await stream.WriteAsync(segment, cancellationToken); diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs index 06094a8df8..8e3d4feb50 100644 --- a/src/Common/WebSocketExtensions.cs +++ b/src/Common/WebSocketExtensions.cs @@ -1,12 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Buffers; -using System.Collections.Generic; using System.Diagnostics; using System.Runtime.InteropServices; -using System.Text; using System.Threading; using System.Threading.Tasks; @@ -14,29 +11,50 @@ namespace System.Net.WebSockets { internal static class WebSocketExtensions { - public static Task SendAsync(this WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) + public static ValueTask SendAsync(this WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) { - // TODO: Consider chunking writes here if we get a multi segment buffer #if NETCOREAPP2_1 if (buffer.IsSingleSegment) { - return webSocket.SendAsync(buffer.First, webSocketMessageType, endOfMessage: true, cancellationToken).AsTask(); + return webSocket.SendAsync(buffer.First, webSocketMessageType, endOfMessage: true, cancellationToken); } else { - return webSocket.SendAsync(buffer.ToArray(), webSocketMessageType, endOfMessage: true, cancellationToken); + return SendMultiSegmentAsync(webSocket, buffer, webSocketMessageType, cancellationToken); } #else if (buffer.IsSingleSegment) { var isArray = MemoryMarshal.TryGetArray(buffer.First, out var segment); Debug.Assert(isArray); - return webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: true, cancellationToken); + return new ValueTask(webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: true, cancellationToken)); } else { - return webSocket.SendAsync(new ArraySegment(buffer.ToArray()), webSocketMessageType, true, cancellationToken); + return SendMultiSegmentAsync(webSocket, buffer, webSocketMessageType, cancellationToken); } +#endif + } + + private static async ValueTask SendMultiSegmentAsync(WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) + { + var position = buffer.Start; + while (buffer.TryGet(ref position, out var segment)) + { +#if NETCOREAPP2_1 + await webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: false, cancellationToken); +#else + var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment); + Debug.Assert(isArray); + await webSocket.SendAsync(arraySegment, webSocketMessageType, endOfMessage: false, cancellationToken); +#endif + } + + // Empty end of message frame +#if NETCOREAPP2_1 + await webSocket.SendAsync(Memory.Empty, webSocketMessageType, endOfMessage: true, cancellationToken); +#else + await webSocket.SendAsync(new ArraySegment(Array.Empty()), webSocketMessageType, endOfMessage: true, cancellationToken); #endif } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs index 8adf707c49..6b3df653d9 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/StreamExtensions.cs @@ -12,7 +12,9 @@ namespace System.IO.Pipelines { try { - await stream.CopyToAsync(writer, cancellationToken); + // REVIEW: Should we use the default buffer size here? + // 81920 is the default bufferSize, there is no stream.CopyToAsync overload that takes only a cancellationToken + await stream.CopyToAsync(new PipelineWriterStream(writer), bufferSize: 81920, cancellationToken: cancellationToken); } catch (Exception ex) { @@ -22,19 +24,6 @@ namespace System.IO.Pipelines writer.Complete(); } - /// - /// Copies the content of a into a . - /// - /// - /// - /// - /// - private static Task CopyToAsync(this Stream stream, PipeWriter writer, CancellationToken cancellationToken = default) - { - // 81920 is the default bufferSize, there is not stream.CopyToAsync overload that takes only a cancellationToken - return stream.CopyToAsync(new PipelineWriterStream(writer), bufferSize: 81920, cancellationToken: cancellationToken); - } - private class PipelineWriterStream : Stream { private readonly PipeWriter _writer; diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs index d113feec6e..b8b8e31060 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/SendUtils.cs @@ -111,7 +111,7 @@ namespace Microsoft.AspNetCore.Sockets.Client protected override Task SerializeToStreamAsync(Stream stream, TransportContext context) { - return stream.WriteAsync(_buffer); + return stream.WriteAsync(_buffer).AsTask(); } protected override bool TryComputeLength(out long length)