Make MemoryBufferWriter a Stream (#1907)

- Get rid of LimitArrayPoolWriteStream and use MemoryBufferWriter in its place in the MessagePackProtocol implementation.
- Added tests for MemoryPoolBufferWriter and fixed a bug in CopyToAsync
- Added CopyTo(`IBufferWriter<byte>`)
- Changed MemoryBufferWriter to fill the underlying arrays that back segments, the segment size is now a minimum.
This commit is contained in:
David Fowler 2018-04-08 16:11:17 -07:00 committed by GitHub
parent 27d18578d0
commit 9fd713c73a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 410 additions and 220 deletions

View File

@ -5,11 +5,13 @@ using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Internal
{
internal sealed class MemoryBufferWriter : IBufferWriter<byte>
internal sealed class MemoryBufferWriter : Stream, IBufferWriter<byte>
{
[ThreadStatic]
private static MemoryBufferWriter _cachedInstance;
@ -18,19 +20,27 @@ namespace Microsoft.AspNetCore.Internal
private bool _inUse;
#endif
private readonly int _segmentSize;
private readonly int _minimumSegmentSize;
private int _bytesWritten;
private List<byte[]> _fullSegments;
private byte[] _currentSegment;
private int _position;
public MemoryBufferWriter(int segmentSize = 2048)
public MemoryBufferWriter(int minimumSegmentSize = 4096)
{
_segmentSize = segmentSize;
_minimumSegmentSize = minimumSegmentSize;
}
public int Length => _bytesWritten;
public override long Length => _bytesWritten;
public override bool CanRead => false;
public override bool CanSeek => false;
public override bool CanWrite => true;
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public static MemoryBufferWriter Get()
{
@ -39,9 +49,11 @@ namespace Microsoft.AspNetCore.Internal
{
writer = new MemoryBufferWriter();
}
// Taken off the thread static
_cachedInstance = null;
else
{
// Taken off the thread static
_cachedInstance = null;
}
#if DEBUG
if (writer._inUse)
{
@ -93,54 +105,87 @@ namespace Microsoft.AspNetCore.Internal
public Memory<byte> GetMemory(int sizeHint = 0)
{
// TODO: Use sizeHint
if (_currentSegment == null)
{
_currentSegment = ArrayPool<byte>.Shared.Rent(_segmentSize);
_position = 0;
}
else if (_position == _segmentSize)
{
if (_fullSegments == null)
{
_fullSegments = new List<byte[]>();
}
_fullSegments.Add(_currentSegment);
_currentSegment = ArrayPool<byte>.Shared.Rent(_segmentSize);
_position = 0;
}
EnsureCapacity(sizeHint);
return _currentSegment.AsMemory(_position, _currentSegment.Length - _position);
}
public Span<byte> GetSpan(int sizeHint = 0)
{
return GetMemory(sizeHint).Span;
EnsureCapacity(sizeHint);
return _currentSegment.AsSpan(_position, _currentSegment.Length - _position);
}
public Task CopyToAsync(Stream stream)
{
if (_fullSegments == null)
{
// There is only one segment so write without async
return stream.WriteAsync(_currentSegment, 0, _position);
}
return CopyToSlowAsync(stream);
}
private async Task CopyToSlowAsync(Stream stream)
public void CopyTo(IBufferWriter<byte> destination)
{
if (_fullSegments != null)
{
// Copy full segments
for (var i = 0; i < _fullSegments.Count - 1; i++)
var count = _fullSegments.Count;
for (var i = 0; i < count; i++)
{
await stream.WriteAsync(_fullSegments[i], 0, _segmentSize);
destination.Write(_fullSegments[i]);
}
}
await stream.WriteAsync(_currentSegment, 0, _position);
destination.Write(_currentSegment.AsSpan(0, _position));
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
if (_fullSegments == null)
{
// There is only one segment so write without async
return destination.WriteAsync(_currentSegment, 0, _position);
}
return CopyToSlowAsync(destination);
}
private void EnsureCapacity(int sizeHint)
{
// TODO: Use sizeHint
if (_currentSegment != null && _position < _currentSegment.Length)
{
// We have capacity in the current segment
return;
}
AddSegment();
}
private void AddSegment()
{
if (_currentSegment != null)
{
// We're adding a segment to the list
if (_fullSegments == null)
{
_fullSegments = new List<byte[]>();
}
_fullSegments.Add(_currentSegment);
}
_currentSegment = ArrayPool<byte>.Shared.Rent(_minimumSegmentSize);
_position = 0;
}
private async Task CopyToSlowAsync(Stream destination)
{
if (_fullSegments != null)
{
// Copy full segments
var count = _fullSegments.Count;
for (var i = 0; i < count; i++)
{
var segment = _fullSegments[i];
await destination.WriteAsync(segment, 0, segment.Length);
}
}
await destination.WriteAsync(_currentSegment, 0, _position);
}
public byte[] ToArray()
@ -157,11 +202,12 @@ namespace Microsoft.AspNetCore.Internal
if (_fullSegments != null)
{
// Copy full segments
for (var i = 0; i < _fullSegments.Count; i++)
var count = _fullSegments.Count;
for (var i = 0; i < count; i++)
{
_fullSegments[i].CopyTo(result, totalWritten);
totalWritten += _segmentSize;
var segment = _fullSegments[i];
segment.CopyTo(result, totalWritten);
totalWritten += segment.Length;
}
}
@ -170,5 +216,66 @@ namespace Microsoft.AspNetCore.Internal
return result;
}
public override void Flush() { }
public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
public override int Read(byte[] buffer, int offset, int count) => throw new NotSupportedException();
public override long Seek(long offset, SeekOrigin origin) => throw new NotSupportedException();
public override void SetLength(long value) => throw new NotSupportedException();
public override void WriteByte(byte value)
{
if (_currentSegment != null && (uint)_position < (uint)_currentSegment.Length)
{
_currentSegment[_position] = value;
}
else
{
AddSegment();
_currentSegment[0] = value;
}
_position++;
_bytesWritten++;
}
public override void Write(byte[] buffer, int offset, int count)
{
var position = _position;
if (_currentSegment != null && position < _currentSegment.Length - count)
{
Buffer.BlockCopy(buffer, offset, _currentSegment, position, count);
_position = position + count;
_bytesWritten += count;
}
else
{
BuffersExtensions.Write(this, buffer.AsSpan(offset, count));
}
}
#if NETCOREAPP2_1
public override void Write(ReadOnlySpan<byte> span)
{
if (_currentSegment != null && span.TryCopyTo(_currentSegment.AsSpan().Slice(_position)))
{
_position += span.Length;
_bytesWritten += span.Length;
}
else
{
BuffersExtensions.Write(this, span);
}
}
#endif
protected override void Dispose(bool disposing)
{
if (disposing)
{
Reset();
}
}
}
}

View File

@ -12,11 +12,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Formatters
// will not occur (is not a valid character) and therefore it is safe to not escape it
public static readonly byte RecordSeparator = 0x1e;
public static void WriteRecordSeparator(Stream output)
{
output.WriteByte(RecordSeparator);
}
public static void WriteRecordSeparator(IBufferWriter<byte> output)
{
var buffer = output.GetSpan(1);

View File

@ -1,163 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public sealed class LimitArrayPoolWriteStream : Stream
{
private const int MaxByteArrayLength = 0x7FFFFFC7;
private const int InitialLength = 256;
private readonly int _maxBufferSize;
private byte[] _buffer;
private int _length;
public LimitArrayPoolWriteStream() : this(MaxByteArrayLength) { }
public LimitArrayPoolWriteStream(int maxBufferSize) : this(maxBufferSize, InitialLength) { }
public LimitArrayPoolWriteStream(int maxBufferSize, long capacity)
{
if (capacity < InitialLength)
{
capacity = InitialLength;
}
else if (capacity > maxBufferSize)
{
throw CreateOverCapacityException(maxBufferSize);
}
_maxBufferSize = maxBufferSize;
_buffer = ArrayPool<byte>.Shared.Rent((int)capacity);
}
protected override void Dispose(bool disposing)
{
if (_buffer != null)
{
ArrayPool<byte>.Shared.Return(_buffer);
_buffer = null;
}
base.Dispose(disposing);
}
public ArraySegment<byte> GetBuffer() => new ArraySegment<byte>(_buffer, 0, _length);
public byte[] ToArray()
{
var arr = new byte[_length];
Buffer.BlockCopy(_buffer, 0, arr, 0, _length);
return arr;
}
private void EnsureCapacity(int value)
{
if ((uint)value > (uint)_maxBufferSize) // value cast handles overflow to negative as well
{
throw CreateOverCapacityException(_maxBufferSize);
}
else if (value > _buffer.Length)
{
Grow(value);
}
}
private void Grow(int value)
{
Debug.Assert(value > _buffer.Length);
// Extract the current buffer to be replaced.
var currentBuffer = _buffer;
_buffer = null;
// Determine the capacity to request for the new buffer. It should be
// at least twice as long as the current one, if not more if the requested
// value is more than that. If the new value would put it longer than the max
// allowed byte array, than shrink to that (and if the required length is actually
// longer than that, we'll let the runtime throw).
var twiceLength = 2 * (uint)currentBuffer.Length;
var newCapacity = twiceLength > MaxByteArrayLength ?
(value > MaxByteArrayLength ? value : MaxByteArrayLength) :
Math.Max(value, (int)twiceLength);
// Get a new buffer, copy the current one to it, return the current one, and
// set the new buffer as current.
var newBuffer = ArrayPool<byte>.Shared.Rent(newCapacity);
Buffer.BlockCopy(currentBuffer, 0, newBuffer, 0, _length);
ArrayPool<byte>.Shared.Return(currentBuffer);
_buffer = newBuffer;
}
public override void Write(byte[] buffer, int offset, int count)
{
Debug.Assert(buffer != null);
Debug.Assert(offset >= 0);
Debug.Assert(count >= 0);
EnsureCapacity(_length + count);
Buffer.BlockCopy(buffer, offset, _buffer, _length, count);
_length += count;
}
#if NETCOREAPP2_1
public override void Write(ReadOnlySpan<byte> source)
{
EnsureCapacity(_length + source.Length);
source.CopyTo(new Span<byte>(_buffer, _length, source.Length));
_length += source.Length;
}
#endif
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
Write(buffer, offset, count);
return Task.CompletedTask;
}
#if NETCOREAPP2_1
public override ValueTask WriteAsync(ReadOnlyMemory<byte> source, CancellationToken cancellationToken = default)
{
Write(source.Span);
return default;
}
#endif
public override void WriteByte(byte value)
{
var newLength = _length + 1;
EnsureCapacity(newLength);
_buffer[_length] = value;
_length = newLength;
}
public override void Flush() { }
public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
public override long Length => _length;
public override bool CanWrite => true;
public override bool CanRead => false;
public override bool CanSeek => false;
public override long Position
{
get => throw new NotSupportedException();
set => throw new NotSupportedException();
}
public override int Read(byte[] buffer, int offset, int count) { throw new NotSupportedException(); }
public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); }
public override void SetLength(long value) { throw new NotSupportedException(); }
private static Exception CreateOverCapacityException(int maxBufferSize)
{
return new InvalidOperationException($"Buffer size of {maxBufferSize} exceeded.");
}
}
}

View File

@ -9,6 +9,7 @@ using System.IO;
using System.Runtime.ExceptionServices;
using System.Runtime.InteropServices;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Formatters;
using Microsoft.Extensions.Options;
using MsgPack;
@ -263,15 +264,20 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
public void WriteMessage(HubMessage message, IBufferWriter<byte> output)
{
using (var stream = new LimitArrayPoolWriteStream())
var writer = MemoryBufferWriter.Get();
try
{
// Write message to a buffer so we can get its length
WriteMessageCore(message, stream);
var buffer = stream.GetBuffer();
WriteMessageCore(message, writer);
// Write length then message to output
BinaryMessageFormatter.WriteLengthPrefix(buffer.Count, output);
output.Write(buffer);
BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output);
writer.CopyTo(output);
}
finally
{
MemoryBufferWriter.Return(writer);
}
}

View File

@ -7,6 +7,10 @@
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\MemoryBufferWriter.cs" Link="Internal\MemoryBufferWriter.cs" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="MsgPack.Cli" Version="$(MsgPackCliPackageVersion)" />
</ItemGroup>

View File

@ -170,7 +170,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
var output = new MemoryStream();
output.Write(message, 0, message.Length);
TextMessageFormatter.WriteRecordSeparator(output);
output.WriteByte(TextMessageFormatter.RecordSeparator);
return output.ToArray();
}
}

View File

@ -18,7 +18,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Formatters
{
var buffer = Encoding.UTF8.GetBytes("ABC");
ms.Write(buffer, 0, buffer.Length);
TextMessageFormatter.WriteRecordSeparator(ms);
ms.WriteByte(TextMessageFormatter.RecordSeparator);
Assert.Equal("ABC\u001e", Encoding.UTF8.GetString(ms.ToArray()));
}
}

View File

@ -271,7 +271,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
var output = new MemoryStream();
output.Write(message, 0, message.Length);
TextMessageFormatter.WriteRecordSeparator(output);
output.WriteByte(TextMessageFormatter.RecordSeparator);
return output.ToArray();
}

View File

@ -0,0 +1,241 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Buffers;
using System.IO;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Internal;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
{
public class MemoryBufferWriterTests
{
private static int MinimumSegmentSize;
static MemoryBufferWriterTests()
{
var buffer = ArrayPool<byte>.Shared.Rent(1);
// Compute the minimum segment size of the array pool
MinimumSegmentSize = buffer.Length;
ArrayPool<byte>.Shared.Return(buffer);
}
[Fact]
public void WritingNotingGivesEmptyData()
{
using (var bufferWriter = new MemoryBufferWriter())
{
Assert.Equal(0, bufferWriter.Length);
var data = bufferWriter.ToArray();
Assert.Empty(data);
}
}
[Fact]
public void WriteByteWorksAsFirstCall()
{
using (var bufferWriter = new MemoryBufferWriter())
{
bufferWriter.WriteByte(234);
var data = bufferWriter.ToArray();
Assert.Equal(1, bufferWriter.Length);
Assert.Single(data);
Assert.Equal(234, data[0]);
}
}
[Fact]
public void WriteByteWorksIfFirstByteInNewSegment()
{
var inputSize = MinimumSegmentSize;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(16, bufferWriter.Length);
bufferWriter.WriteByte(16);
Assert.Equal(17, bufferWriter.Length);
var data = bufferWriter.ToArray();
Assert.Equal(input, data.Take(16));
Assert.Equal(16, data[16]);
}
}
[Fact]
public void WriteByteWorksIfSegmentHasSpace()
{
var input = new byte[] { 11, 12, 13 };
using (var bufferWriter = new MemoryBufferWriter())
{
bufferWriter.Write(input, 0, input.Length);
bufferWriter.WriteByte(14);
Assert.Equal(4, bufferWriter.Length);
var data = bufferWriter.ToArray();
Assert.Equal(4, data.Length);
Assert.Equal(11, data[0]);
Assert.Equal(12, data[1]);
Assert.Equal(13, data[2]);
Assert.Equal(14, data[3]);
}
}
[Fact]
public void ToArrayWithExactlyFullSegmentsWorks()
{
var inputSize = MinimumSegmentSize * 2;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(input.Length, bufferWriter.Length);
var data = bufferWriter.ToArray();
Assert.Equal(input, data);
}
}
[Fact]
public void ToArrayWithSomeFullSegmentsWorks()
{
var inputSize = (MinimumSegmentSize * 2) + 1;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(input.Length, bufferWriter.Length);
var data = bufferWriter.ToArray();
Assert.Equal(input, data);
}
}
[Fact]
public async Task CopyToAsyncWithExactlyFullSegmentsWorks()
{
var inputSize = MinimumSegmentSize * 2;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(input.Length, bufferWriter.Length);
var ms = new MemoryStream();
await bufferWriter.CopyToAsync(ms);
var data = ms.ToArray();
Assert.Equal(input, data);
}
}
[Fact]
public async Task CopyToAsyncWithSomeFullSegmentsWorks()
{
// 2 segments + 1 extra byte
var inputSize = (MinimumSegmentSize * 2) + 1;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(input.Length, bufferWriter.Length);
var ms = new MemoryStream();
await bufferWriter.CopyToAsync(ms);
var data = ms.ToArray();
Assert.Equal(input, data);
}
}
[Fact]
public void CopyToWithExactlyFullSegmentsWorks()
{
var inputSize = MinimumSegmentSize * 2;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(input.Length, bufferWriter.Length);
using (var destination = new MemoryBufferWriter())
{
bufferWriter.CopyTo(destination);
var data = destination.ToArray();
Assert.Equal(input, data);
}
}
}
[Fact]
public void CopyToWithSomeFullSegmentsWorks()
{
var inputSize = (MinimumSegmentSize * 2) + 1;
var input = Enumerable.Range(0, inputSize).Select(i => (byte)i).ToArray();
using (var bufferWriter = new MemoryBufferWriter(MinimumSegmentSize))
{
bufferWriter.Write(input, 0, input.Length);
Assert.Equal(input.Length, bufferWriter.Length);
using (var destination = new MemoryBufferWriter())
{
bufferWriter.CopyTo(destination);
var data = destination.ToArray();
Assert.Equal(input, data);
}
}
}
#if NETCOREAPP2_1
[Fact]
public void WriteSpanWorksAtNonZeroOffset()
{
using (var bufferWriter = new MemoryBufferWriter())
{
bufferWriter.WriteByte(1);
bufferWriter.Write(new byte[] { 2, 3, 4 }.AsSpan());
Assert.Equal(4, bufferWriter.Length);
var data = bufferWriter.ToArray();
Assert.Equal(4, data.Length);
Assert.Equal(1, data[0]);
Assert.Equal(2, data[1]);
Assert.Equal(3, data[2]);
Assert.Equal(4, data[3]);
}
}
#endif
[Fact]
public void ResetResetsTheMemoryBufferWriter()
{
var bufferWriter = new MemoryBufferWriter();
bufferWriter.WriteByte(1);
Assert.Equal(1, bufferWriter.Length);
bufferWriter.Reset();
Assert.Equal(0, bufferWriter.Length);
}
[Fact]
public void DisposeResetsTheMemoryBufferWriter()
{
var bufferWriter = new MemoryBufferWriter();
bufferWriter.WriteByte(1);
Assert.Equal(1, bufferWriter.Length);
bufferWriter.Dispose();
Assert.Equal(0, bufferWriter.Length);
}
}
}