SSE formatting refactor (#1916)

This commit is contained in:
James Newton-King 2018-04-11 17:13:15 +12:00 committed by GitHub
parent 83821a028d
commit b30c2fecbf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 153 additions and 100 deletions

View File

@ -1,10 +1,12 @@
using System;
using System.Buffers;
using System.IO;
using System.Threading.Tasks;
using BenchmarkDotNet.Attributes;
using Microsoft.AspNetCore.Http.Connections.Client.Internal;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
@ -12,15 +14,32 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
{
private ServerSentEventsMessageParser _parser;
private byte[] _sseFormattedData;
private byte[] _rawData;
private ReadOnlySequence<byte> _rawData;
[Params(Message.NoArguments, Message.FewArguments, Message.ManyArguments, Message.LargeArguments)]
public Message Input { get; set; }
[Params("json", "json-formatted")]
public string Protocol { get; set; }
[GlobalSetup]
public void GlobalSetup()
{
var hubProtocol = new JsonHubProtocol();
IHubProtocol protocol;
if (Protocol == "json")
{
protocol = new JsonHubProtocol();
}
else
{
// New line in result to trigger SSE formatting
protocol = new JsonHubProtocol
{
PayloadSerializer = { Formatting = Formatting.Indented }
};
}
HubMessage hubMessage = null;
switch (Input)
{
@ -39,9 +58,9 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
}
_parser = new ServerSentEventsMessageParser();
_rawData = hubProtocol.GetMessageBytes(hubMessage);
_rawData = new ReadOnlySequence<byte>(protocol.GetMessageBytes(hubMessage));
var ms = new MemoryStream();
ServerSentEventsMessageFormatter.WriteMessage(_rawData, ms);
ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, ms).GetAwaiter().GetResult();
_sseFormattedData = ms.ToArray();
}
@ -59,9 +78,9 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
}
[Benchmark]
public void WriteSingleMessage()
public Task WriteSingleMessage()
{
ServerSentEventsMessageFormatter.WriteMessage(_rawData, Stream.Null);
return ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, Stream.Null);
}
public enum Message

View File

@ -35,6 +35,7 @@
<MicrosoftEntityFrameworkCoreDesignPackageVersion>2.1.0-preview3-32170</MicrosoftEntityFrameworkCoreDesignPackageVersion>
<MicrosoftEntityFrameworkCoreSqlServerPackageVersion>2.1.0-preview3-32170</MicrosoftEntityFrameworkCoreSqlServerPackageVersion>
<MicrosoftEntityFrameworkCoreToolsPackageVersion>2.1.0-preview3-32170</MicrosoftEntityFrameworkCoreToolsPackageVersion>
<MicrosoftExtensionsBuffersTestingSourcesPackageVersion>2.1.0-preview3-32170</MicrosoftExtensionsBuffersTestingSourcesPackageVersion>
<MicrosoftExtensionsClosedGenericMatcherSourcesPackageVersion>2.1.0-preview3-32170</MicrosoftExtensionsClosedGenericMatcherSourcesPackageVersion>
<MicrosoftExtensionsCommandLineUtilsSourcesPackageVersion>2.1.0-preview3-32170</MicrosoftExtensionsCommandLineUtilsSourcesPackageVersion>
<MicrosoftExtensionsConfigurationCommandLinePackageVersion>2.1.0-preview3-32170</MicrosoftExtensionsConfigurationCommandLinePackageVersion>

View File

@ -12,113 +12,101 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal
{
public static class ServerSentEventsMessageFormatter
{
private static readonly byte[] DataPrefix = new[] { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' };
private static readonly byte[] Newline = new[] { (byte)'\r', (byte)'\n' };
private static readonly byte[] DataPrefix = { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' };
private static readonly byte[] Newline = { (byte)'\r', (byte)'\n' };
private const byte LineFeed = (byte)'\n';
public static Task WriteMessageAsync(in ReadOnlySequence<byte> payload, Stream output)
public static async Task WriteMessageAsync(ReadOnlySequence<byte> payload, Stream output)
{
var ms = new MemoryStream();
// TODO: There are 2 improvements to be made here
// 1. Don't convert the entire payload into an array if if's multi-segmented.
// 2. Don't allocate the memory stream unless the payload contains \n. If it doesn't we can just write the buffers directly
// to the stream without modification. While it does mean that there will be smaller writes, should be fine for the most part
// since we're using reasonably sized buffers.
if (payload.IsSingleSegment)
// Payload does not contain a line feed so write it directly to output
if (payload.PositionOf(LineFeed) == null)
{
WriteMessage(payload.First, ms);
}
else
{
WriteMessage(payload.ToArray(), ms);
}
if (payload.Length > 0)
{
await output.WriteAsync(DataPrefix, 0, DataPrefix.Length);
await output.WriteAsync(payload);
await output.WriteAsync(Newline, 0, Newline.Length);
}
ms.Position = 0;
return ms.CopyToAsync(output);
}
public static void WriteMessage(ReadOnlyMemory<byte> payload, Stream output)
{
// Write the payload
WritePayload(payload, output);
// Write new \r\n
output.Write(Newline, 0, Newline.Length);
}
private static void WritePayload(ReadOnlyMemory<byte> payload, Stream output)
{
// Short-cut for empty payload
if (payload.Length == 0)
{
await output.WriteAsync(Newline, 0, Newline.Length);
return;
}
// We can't just use while(payload.Length > 0) because we need to write a blank final "data: " line
// if the payload ends in a newline. For example, consider the following payload:
// "Hello\n"
// It needs to be written as:
// data: Hello\r\n
// data: \r\n
// \r\n
// Since we slice past the newline when we find it, after writing "Hello" in the previous example, we'll
// end up with an empty payload buffer, BUT we need to write it as an empty 'data:' line, so we need
// to use a condition that ensure the only time we stop writing is when we write the slice after the final
// newline.
var ms = new MemoryStream();
// Parse payload and write formatted output to memory
await WriteMessageToMemory(ms, payload);
ms.Position = 0;
await ms.CopyToAsync(output);
}
/// <summary>
/// Gets the last memory segment in a sequence.
/// </summary>
/// <param name="source">Source sequence.</param>
/// <param name="offset">The offset the segment starts at.</param>
/// <returns>The last memory segment in a sequence.</returns>
private static ReadOnlyMemory<byte> GetLastSegment(in ReadOnlySequence<byte> source, out long offset)
{
offset = 0;
var totalLength = source.Length;
var position = source.Start;
while (source.TryGet(ref position, out ReadOnlyMemory<byte> memory))
{
// Last segment
if (offset + memory.Length >= totalLength)
{
return memory;
}
offset += memory.Length;
}
throw new InvalidOperationException("Could not get last segment from sequence.");
}
private static async Task WriteMessageToMemory(Stream output, ReadOnlySequence<byte> payload)
{
var keepWriting = true;
while (keepWriting)
{
var span = payload.Span;
// Seek to the end of buffer or newline
var sliceEnd = span.IndexOf(LineFeed);
var nextSliceStart = sliceEnd + 1;
if (sliceEnd < 0)
{
sliceEnd = payload.Length;
nextSliceStart = sliceEnd + 1;
var sliceEnd = payload.PositionOf(LineFeed);
// This is the last span
ReadOnlySequence<byte> lineSegment;
if (sliceEnd == null)
{
lineSegment = payload;
payload = ReadOnlySequence<byte>.Empty;
keepWriting = false;
}
if (sliceEnd > 0 && span[sliceEnd - 1] == '\r')
{
sliceEnd--;
}
var slice = payload.Slice(0, sliceEnd);
if (nextSliceStart >= payload.Length)
{
payload = ReadOnlyMemory<byte>.Empty;
}
else
{
payload = payload.Slice(nextSliceStart);
lineSegment = payload.Slice(payload.Start, sliceEnd.Value);
if (lineSegment.Length > 1)
{
// Check if the line ended in \r\n. If it did then trim the \r
var memory = GetLastSegment(lineSegment, out var offset);
if (memory.Span[memory.Length - 1] == '\r')
{
lineSegment = lineSegment.Slice(lineSegment.Start, offset + memory.Length - 1);
}
}
// Update payload to remove \n
payload = payload.Slice(payload.GetPosition(1, sliceEnd.Value));
}
WriteLine(slice, output);
// Write line
await output.WriteAsync(DataPrefix, 0, DataPrefix.Length);
await output.WriteAsync(lineSegment);
await output.WriteAsync(Newline, 0, Newline.Length);
}
}
private static void WriteLine(ReadOnlyMemory<byte> payload, Stream output)
{
output.Write(DataPrefix, 0, DataPrefix.Length);
#if NETCOREAPP2_1
output.Write(payload.Span);
#else
if (payload.Length > 0)
{
var isArray = MemoryMarshal.TryGetArray(payload, out var segment);
Debug.Assert(isArray);
output.Write(segment.Array, segment.Offset, segment.Count);
}
#endif
output.Write(Newline, 0, Newline.Length);
await output.WriteAsync(Newline, 0, Newline.Length);
}
}
}

View File

@ -19,6 +19,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Authentication.Core" Version="$(MicrosoftAspNetCoreAuthenticationCorePackageVersion)" />
<PackageReference Include="Microsoft.AspNetCore.Http" Version="$(MicrosoftAspNetCoreHttpPackageVersion)" />
<PackageReference Include="Microsoft.Extensions.Buffers.Testing.Sources" Version="$(MicrosoftExtensionsBuffersTestingSourcesPackageVersion)" />
<PackageReference Include="Newtonsoft.Json" Version="$(NewtonsoftJsonPackageVersion)" />
</ItemGroup>

View File

@ -1,8 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Xunit;
@ -11,18 +14,37 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
public class ServerSentEventsMessageFormatterTests
{
[Theory]
[InlineData("\r\n", "")]
[InlineData("data: Hello, World\r\n\r\n", "Hello, World")]
[InlineData("data: Hello\r\ndata: World\r\n\r\n", "Hello\r\nWorld")]
[InlineData("data: Hello\r\ndata: World\r\n\r\n", "Hello\nWorld")]
[InlineData("data: Hello\r\ndata: \r\n\r\n", "Hello\n")]
[InlineData("data: Hello\r\ndata: \r\n\r\n", "Hello\r\n")]
public void WriteTextMessage(string encoded, string payload)
[MemberData(nameof(PayloadData))]
public async Task WriteTextMessageFromSingleSegment(string encoded, string payload)
{
var buffer = new ReadOnlySequence<byte>(Encoding.UTF8.GetBytes(payload));
var output = new MemoryStream();
ServerSentEventsMessageFormatter.WriteMessage(Encoding.UTF8.GetBytes(payload), output);
await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output);
Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray()));
}
[Theory]
[MemberData(nameof(PayloadData))]
public async Task WriteTextMessageFromMultipleSegments(string encoded, string payload)
{
var buffer = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent(Encoding.UTF8.GetBytes(payload));
var output = new MemoryStream();
await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output);
Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray()));
}
public static IEnumerable<object[]> PayloadData => new List<object[]>
{
new object[] { "\r\n", "" },
new object[] { "data: Hello, World\r\n\r\n", "Hello, World" },
new object[] { "data: Hello\r\ndata: World\r\n\r\n", "Hello\r\nWorld" },
new object[] { "data: Hello\r\ndata: World\r\n\r\n", "Hello\nWorld" },
new object[] { "data: Hello\r\ndata: \r\n\r\n", "Hello\n" },
new object[] { "data: Hello\r\ndata: \r\n\r\n", "Hello\r\n" },
};
}
}

View File

@ -70,6 +70,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests
Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray()));
}
[Fact]
public async Task SSEWritesVeryLargeMessages()
{
var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, new PipeOptions(readerScheduler: PipeScheduler.Inline));
var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application);
var context = new DefaultHttpContext();
var ms = new MemoryStream();
context.Response.Body = ms;
var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory());
var task = sse.ProcessRequestAsync(context, context.RequestAborted);
string hText = new string('H', 60000);
string wText = new string('W', 60000);
await connection.Transport.Output.WriteAsync(Encoding.ASCII.GetBytes(hText + wText));
connection.Transport.Output.Complete();
await task.OrTimeout();
Assert.Equal(":\r\ndata: " + hText + wText + "\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray()));
}
[Theory]
[InlineData("Hello World", ":\r\ndata: Hello World\r\n\r\n")]
[InlineData("Hello\nWorld", ":\r\ndata: Hello\r\ndata: World\r\n\r\n")]