From b492dbe5d18180c7e3d4aef7d2943295010020e0 Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Thu, 3 May 2018 19:10:15 -0700 Subject: [PATCH] fix #2187 by passing min size through to buffer writer when encoding (#2190) --- src/Common/MemoryBufferWriter.cs | 102 +++++++++----- src/Common/Utf8BufferTextWriter.cs | 14 +- .../Protocol/MemoryBufferWriterTests.cs | 29 +++- .../Protocol/Utf8BufferTextWriterTests.cs | 133 +++++++++++++----- 4 files changed, 205 insertions(+), 73 deletions(-) diff --git a/src/Common/MemoryBufferWriter.cs b/src/Common/MemoryBufferWriter.cs index 2b0fa29d62..139c647cbc 100644 --- a/src/Common/MemoryBufferWriter.cs +++ b/src/Common/MemoryBufferWriter.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -6,7 +6,6 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; @@ -24,7 +23,7 @@ namespace Microsoft.AspNetCore.Internal private readonly int _minimumSegmentSize; private int _bytesWritten; - private List _fullSegments; + private List _completedSegments; private byte[] _currentSegment; private int _position; @@ -78,14 +77,14 @@ namespace Microsoft.AspNetCore.Internal public void Reset() { - if (_fullSegments != null) + if (_completedSegments != null) { - for (var i = 0; i < _fullSegments.Count; i++) + for (var i = 0; i < _completedSegments.Count; i++) { - ArrayPool.Shared.Return(_fullSegments[i]); + _completedSegments[i].Return(); } - _fullSegments.Clear(); + _completedSegments.Clear(); } if (_currentSegment != null) @@ -120,13 +119,13 @@ namespace Microsoft.AspNetCore.Internal public void CopyTo(IBufferWriter destination) { - if (_fullSegments != null) + if (_completedSegments != null) { - // Copy full segments - var count = _fullSegments.Count; + // Copy completed segments + var count = _completedSegments.Count; for (var i = 0; i < count; i++) { - destination.Write(_fullSegments[i]); + destination.Write(_completedSegments[i].Span); } } @@ -135,9 +134,9 @@ namespace Microsoft.AspNetCore.Internal public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) { - if (_fullSegments == null) + if (_completedSegments == null) { - // There is only one segment so write without async + // There is only one segment so write without awaiting. return destination.WriteAsync(_currentSegment, 0, _position); } @@ -146,43 +145,52 @@ namespace Microsoft.AspNetCore.Internal private void EnsureCapacity(int sizeHint) { - // TODO: Use sizeHint - if (_currentSegment != null && _position < _currentSegment.Length) + // This does the Right Thing. It only subtracts _position from the current segment length if it's non-null. + // If _currentSegment is null, it returns 0. + var remainingSize = _currentSegment?.Length - _position ?? 0; + + // If the sizeHint is 0, any capacity will do + // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment. + if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint)) { // We have capacity in the current segment return; } - AddSegment(); + AddSegment(sizeHint); } - private void AddSegment() + private void AddSegment(int sizeHint = 0) { if (_currentSegment != null) { // We're adding a segment to the list - if (_fullSegments == null) + if (_completedSegments == null) { - _fullSegments = new List(); + _completedSegments = new List(); } - _fullSegments.Add(_currentSegment); + // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when + // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to + // ignore any empty space in it. + _completedSegments.Add(new CompletedBuffer(_currentSegment, _position)); } - _currentSegment = ArrayPool.Shared.Rent(_minimumSegmentSize); + // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment. + _currentSegment = ArrayPool.Shared.Rent(Math.Max(_minimumSegmentSize, sizeHint)); _position = 0; } private async Task CopyToSlowAsync(Stream destination) { - if (_fullSegments != null) + if (_completedSegments != null) { // Copy full segments - var count = _fullSegments.Count; + var count = _completedSegments.Count; for (var i = 0; i < count; i++) { - var segment = _fullSegments[i]; - await destination.WriteAsync(segment, 0, segment.Length); + var segment = _completedSegments[i]; + await destination.WriteAsync(segment.Buffer, 0, segment.Length); } } @@ -200,15 +208,15 @@ namespace Microsoft.AspNetCore.Internal var totalWritten = 0; - if (_fullSegments != null) + if (_completedSegments != null) { // Copy full segments - var count = _fullSegments.Count; + var count = _completedSegments.Count; for (var i = 0; i < count; i++) { - var segment = _fullSegments[i]; - segment.CopyTo(result, totalWritten); - totalWritten += segment.Length; + var segment = _completedSegments[i]; + segment.Span.CopyTo(result.AsSpan(totalWritten)); + totalWritten += segment.Span.Length; } } @@ -229,15 +237,15 @@ namespace Microsoft.AspNetCore.Internal var totalWritten = 0; - if (_fullSegments != null) + if (_completedSegments != null) { // Copy full segments - var count = _fullSegments.Count; + var count = _completedSegments.Count; for (var i = 0; i < count; i++) { - var segment = _fullSegments[i]; - segment.AsSpan().CopyTo(span.Slice(totalWritten)); - totalWritten += segment.Length; + var segment = _completedSegments[i]; + segment.Span.CopyTo(span.Slice(totalWritten)); + totalWritten += segment.Span.Length; } } @@ -307,5 +315,27 @@ namespace Microsoft.AspNetCore.Internal Reset(); } } + + /// + /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. + /// + private readonly struct CompletedBuffer + { + public byte[] Buffer { get; } + public int Length { get; } + + public ReadOnlySpan Span => Buffer.AsSpan(0, Length); + + public CompletedBuffer(byte[] buffer, int length) + { + Buffer = buffer; + Length = length; + } + + public void Return() + { + ArrayPool.Shared.Return(Buffer); + } + } } -} \ No newline at end of file +} diff --git a/src/Common/Utf8BufferTextWriter.cs b/src/Common/Utf8BufferTextWriter.cs index aed1681649..f0f1a4edc8 100644 --- a/src/Common/Utf8BufferTextWriter.cs +++ b/src/Common/Utf8BufferTextWriter.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -14,6 +14,7 @@ namespace Microsoft.AspNetCore.Internal internal sealed class Utf8BufferTextWriter : TextWriter { private static readonly UTF8Encoding _utf8NoBom = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false); + private static readonly int MaximumBytesPerUtf8Char = 4; [ThreadStatic] private static Utf8BufferTextWriter _cachedInstance; @@ -139,7 +140,12 @@ namespace Microsoft.AspNetCore.Internal private void EnsureBuffer() { - if (_memoryUsed == _memory.Length) + // We need at least enough bytes to encode a single UTF-8 character, or Encoder.Convert will throw. + // Normally, if there isn't enough space to write every character of a char buffer, Encoder.Convert just + // writes what it can. However, if it can't even write a single character, it throws. So if the buffer has only + // 2 bytes left and the next character to write is 3 bytes in UTF-8, an exception is thrown. + var remaining = _memory.Length - _memoryUsed; + if (remaining < MaximumBytesPerUtf8Char) { // Used up the memory from the buffer writer so advance and get more if (_memoryUsed > 0) @@ -147,7 +153,7 @@ namespace Microsoft.AspNetCore.Internal _bufferWriter.Advance(_memoryUsed); } - _memory = _bufferWriter.GetMemory(); + _memory = _bufferWriter.GetMemory(MaximumBytesPerUtf8Char); _memoryUsed = 0; } } @@ -199,4 +205,4 @@ namespace Microsoft.AspNetCore.Internal } } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs index 10bc0b41fc..ff51d1791e 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/MemoryBufferWriterTests.cs @@ -398,6 +398,33 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol } #endif + [Fact] + public void GetMemoryAllocatesNewSegmentWhenInsufficientSpaceInCurrentSegment() + { + // Have the buffer writer rent only the minimum size segments from the pool. + using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize)) + { + var data = new byte[MinimumSegmentSize]; + new Random().NextBytes(data); + + // Write half the minimum segment size + bufferWriter.Write(data.AsSpan(0, MinimumSegmentSize / 2)); + + // Request a new buffer of MinimumSegmentSize + var buffer = bufferWriter.GetMemory(MinimumSegmentSize); + Assert.Equal(MinimumSegmentSize, buffer.Length); + + // Write to the buffer + bufferWriter.Write(data); + + // Verify the data was all written correctly + var expectedOutput = new byte[MinimumSegmentSize + (MinimumSegmentSize / 2)]; + data.AsSpan(0, MinimumSegmentSize / 2).CopyTo(expectedOutput.AsSpan(0, MinimumSegmentSize / 2)); + data.CopyTo(expectedOutput, MinimumSegmentSize / 2); + Assert.Equal(expectedOutput, bufferWriter.ToArray()); + } + } + [Fact] public void ResetResetsTheMemoryBufferWriter() { @@ -418,4 +445,4 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Equal(0, bufferWriter.Length); } } -} \ No newline at end of file +} diff --git a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs index 629dde3edd..a109bf4719 100644 --- a/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Common.Tests/Internal/Protocol/Utf8BufferTextWriterTests.cs @@ -4,6 +4,7 @@ using System; using System.Buffers; using System.Collections.Generic; +using System.Linq; using System.Text; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.SignalR.Internal; @@ -202,20 +203,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol textWriter.Write(chars); textWriter.Flush(); - Assert.Equal(6, bufferWriter.Segments.Count); + var segments = bufferWriter.GetSegments(); + Assert.Equal(6, segments.Count); Assert.Equal(1, bufferWriter.Position); - Assert.Equal((byte)'H', bufferWriter.Segments[0].Span[0]); - Assert.Equal((byte)'e', bufferWriter.Segments[0].Span[1]); - Assert.Equal((byte)'l', bufferWriter.Segments[1].Span[0]); - Assert.Equal((byte)'l', bufferWriter.Segments[1].Span[1]); - Assert.Equal((byte)'o', bufferWriter.Segments[2].Span[0]); - Assert.Equal((byte)' ', bufferWriter.Segments[2].Span[1]); - Assert.Equal((byte)'w', bufferWriter.Segments[3].Span[0]); - Assert.Equal((byte)'o', bufferWriter.Segments[3].Span[1]); - Assert.Equal((byte)'r', bufferWriter.Segments[4].Span[0]); - Assert.Equal((byte)'l', bufferWriter.Segments[4].Span[1]); - Assert.Equal((byte)'d', bufferWriter.Segments[5].Span[0]); + Assert.Equal((byte)'H', segments[0].Span[0]); + Assert.Equal((byte)'e', segments[0].Span[1]); + Assert.Equal((byte)'l', segments[1].Span[0]); + Assert.Equal((byte)'l', segments[1].Span[1]); + Assert.Equal((byte)'o', segments[2].Span[0]); + Assert.Equal((byte)' ', segments[2].Span[1]); + Assert.Equal((byte)'w', segments[3].Span[0]); + Assert.Equal((byte)'o', segments[3].Span[1]); + Assert.Equal((byte)'r', segments[4].Span[0]); + Assert.Equal((byte)'l', segments[4].Span[1]); + Assert.Equal((byte)'d', segments[5].Span[0]); } [Fact] @@ -242,35 +244,89 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol Assert.Same(textWriter1, textWriter2); } + [Fact] + private void WriteMultiByteCharactersToSmallBuffers() + { + // Test string breakdown (char => UTF-8 hex values): + // a => 61 + // い => E3-81-84 + // b => 62 + // ろ => E3-82-8D + // c => 63 + // d => 64 + // は => E3-81-AF + // に => E3-81-AB + // e => 65 + // ほ => E3-81-BB + // f => 66 + // へ => E3-81-B8 + // ど => E3-81-A9 + // g => 67 + // h => 68 + // i => 69 + // \uD800\uDC00 => F0-90-80-80 (this is a surrogate pair that is represented as a single 4-byte UTF-8 encoding) + const string testString = "aいbろcdはにeほfへどghi\uD800\uDC00"; + + // By mixing single byte and multi-byte characters, we know that there will + // be spaces in the active segment that cannot fit the current character. This + // means we'll be testing the GetMemory(minimumSize) logic. + var bufferWriter = new TestMemoryBufferWriter(segmentSize: 5); + + var writer = new Utf8BufferTextWriter(); + writer.SetWriter(bufferWriter); + writer.Write(testString); + writer.Flush(); + + // Verify the output + var allSegments = bufferWriter.GetSegments().Select(s => s.ToArray()).ToArray(); + Assert.Collection(allSegments, + seg => Assert.Equal(new byte[] { 0x61, 0xE3, 0x81, 0x84, 0x62 }, seg), // "aいb" + seg => Assert.Equal(new byte[] { 0xE3, 0x82, 0x8D, 0x63, 0x64 }, seg), // "ろcd" + seg => Assert.Equal(new byte[] { 0xE3, 0x81, 0xAF }, seg), // "は" + seg => Assert.Equal(new byte[] { 0xE3, 0x81, 0xAB, 0x65 }, seg), // "にe" + seg => Assert.Equal(new byte[] { 0xE3, 0x81, 0xBB, 0x66 }, seg), // "ほf" + seg => Assert.Equal(new byte[] { 0xE3, 0x81, 0xB8 }, seg), // "へ" + seg => Assert.Equal(new byte[] { 0xE3, 0x81, 0xA9, 0x67, 0x68 }, seg), // "どgh" + seg => Assert.Equal(new byte[] { 0x69, 0xF0, 0x90, 0x80, 0x80 }, seg)); // "i\uD800\uDC00" + + Assert.Equal(testString, Encoding.UTF8.GetString(bufferWriter.ToArray())); + } + private sealed class TestMemoryBufferWriter : IBufferWriter { private readonly int _segmentSize; - internal List> Segments { get; } + private List> _completedSegments = new List>(); + private int _totalLength; + + public Memory CurrentSegment { get; private set; } internal int Position { get; private set; } public TestMemoryBufferWriter(int segmentSize = 2048) { _segmentSize = segmentSize; - - Segments = new List>(); + CurrentSegment = Memory.Empty; } - public Memory CurrentSegment => Segments.Count > 0 ? Segments[Segments.Count - 1] : null; - public void Advance(int count) { Position += count; + _totalLength += count; } public Memory GetMemory(int sizeHint = 0) { - // TODO: Use sizeHint - - if (Segments.Count == 0 || Position == _segmentSize) + // Need special handling for sizeHint == 0, because for that we want to enter the if even if there are "sizeHint" (i.e. 0) bytes left :). + if ((sizeHint == 0 && CurrentSegment.Length == Position) || (CurrentSegment.Length - Position < sizeHint)) { - // TODO: Rent memory from a pool - Segments.Add(new Memory(new byte[_segmentSize])); + if (Position > 0) + { + // Complete the current segment + _completedSegments.Add(CurrentSegment.Slice(0, Position)); + } + + // Allocate a new segment and reset the position. + CurrentSegment = new Memory(new byte[_segmentSize]); Position = 0; } @@ -284,31 +340,44 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol public byte[] ToArray() { - if (Segments.Count == 0) + if (CurrentSegment.IsEmpty && _completedSegments.Count == 0) { return Array.Empty(); } - var totalLength = (Segments.Count - 1) * _segmentSize; - totalLength += Position; - - var result = new byte[totalLength]; + var result = new byte[_totalLength]; var totalWritten = 0; - // Copy full segments - for (var i = 0; i < Segments.Count - 1; i++) + // Copy completed segments + foreach (var segment in _completedSegments) { - Segments[i].CopyTo(result.AsMemory(totalWritten, _segmentSize)); + segment.CopyTo(result.AsMemory(totalWritten, segment.Length)); - totalWritten += _segmentSize; + totalWritten += segment.Length; } - // Copy current incomplete segment + // Copy current segment CurrentSegment.Slice(0, Position).CopyTo(result.AsMemory(totalWritten, Position)); return result; } + + public IList> GetSegments() + { + var list = new List>(); + foreach (var segment in _completedSegments) + { + list.Add(segment); + } + + if (CurrentSegment.Length > 0) + { + list.Add(CurrentSegment.Slice(0, Position)); + } + + return list; + } } } }