diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs index 3576f7b13a..a9c1f93f48 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/ServerSentEventsBenchmark.cs @@ -1,10 +1,12 @@ using System; using System.Buffers; using System.IO; +using System.Threading.Tasks; using BenchmarkDotNet.Attributes; using Microsoft.AspNetCore.Http.Connections.Client.Internal; using Microsoft.AspNetCore.Http.Connections.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { @@ -12,15 +14,32 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { private ServerSentEventsMessageParser _parser; private byte[] _sseFormattedData; - private byte[] _rawData; + private ReadOnlySequence _rawData; [Params(Message.NoArguments, Message.FewArguments, Message.ManyArguments, Message.LargeArguments)] public Message Input { get; set; } + [Params("json", "json-formatted")] + public string Protocol { get; set; } + [GlobalSetup] public void GlobalSetup() { - var hubProtocol = new JsonHubProtocol(); + IHubProtocol protocol; + + if (Protocol == "json") + { + protocol = new JsonHubProtocol(); + } + else + { + // New line in result to trigger SSE formatting + protocol = new JsonHubProtocol + { + PayloadSerializer = { Formatting = Formatting.Indented } + }; + } + HubMessage hubMessage = null; switch (Input) { @@ -39,9 +58,9 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks } _parser = new ServerSentEventsMessageParser(); - _rawData = hubProtocol.GetMessageBytes(hubMessage); + _rawData = new ReadOnlySequence(protocol.GetMessageBytes(hubMessage)); var ms = new MemoryStream(); - ServerSentEventsMessageFormatter.WriteMessage(_rawData, ms); + ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, ms).GetAwaiter().GetResult(); _sseFormattedData = ms.ToArray(); } @@ -59,9 +78,9 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks } [Benchmark] - public void WriteSingleMessage() + public Task WriteSingleMessage() { - ServerSentEventsMessageFormatter.WriteMessage(_rawData, Stream.Null); + return ServerSentEventsMessageFormatter.WriteMessageAsync(_rawData, Stream.Null); } public enum Message diff --git a/build/dependencies.props b/build/dependencies.props index a73b8b05b6..80daeb9094 100644 --- a/build/dependencies.props +++ b/build/dependencies.props @@ -35,6 +35,7 @@ 2.1.0-preview3-32170 2.1.0-preview3-32170 2.1.0-preview3-32170 + 2.1.0-preview3-32170 2.1.0-preview3-32170 2.1.0-preview3-32170 2.1.0-preview3-32170 diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/ServerSentEventsMessageFormatter.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/ServerSentEventsMessageFormatter.cs index 6347cba900..970a88b8c4 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/ServerSentEventsMessageFormatter.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/ServerSentEventsMessageFormatter.cs @@ -12,113 +12,101 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal { public static class ServerSentEventsMessageFormatter { - private static readonly byte[] DataPrefix = new[] { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' }; - private static readonly byte[] Newline = new[] { (byte)'\r', (byte)'\n' }; + private static readonly byte[] DataPrefix = { (byte)'d', (byte)'a', (byte)'t', (byte)'a', (byte)':', (byte)' ' }; + private static readonly byte[] Newline = { (byte)'\r', (byte)'\n' }; private const byte LineFeed = (byte)'\n'; - public static Task WriteMessageAsync(in ReadOnlySequence payload, Stream output) + public static async Task WriteMessageAsync(ReadOnlySequence payload, Stream output) { - var ms = new MemoryStream(); - - // TODO: There are 2 improvements to be made here - // 1. Don't convert the entire payload into an array if if's multi-segmented. - // 2. Don't allocate the memory stream unless the payload contains \n. If it doesn't we can just write the buffers directly - // to the stream without modification. While it does mean that there will be smaller writes, should be fine for the most part - // since we're using reasonably sized buffers. - - if (payload.IsSingleSegment) + // Payload does not contain a line feed so write it directly to output + if (payload.PositionOf(LineFeed) == null) { - WriteMessage(payload.First, ms); - } - else - { - WriteMessage(payload.ToArray(), ms); - } + if (payload.Length > 0) + { + await output.WriteAsync(DataPrefix, 0, DataPrefix.Length); + await output.WriteAsync(payload); + await output.WriteAsync(Newline, 0, Newline.Length); + } - ms.Position = 0; - - return ms.CopyToAsync(output); - } - - public static void WriteMessage(ReadOnlyMemory payload, Stream output) - { - // Write the payload - WritePayload(payload, output); - - // Write new \r\n - output.Write(Newline, 0, Newline.Length); - } - - private static void WritePayload(ReadOnlyMemory payload, Stream output) - { - // Short-cut for empty payload - if (payload.Length == 0) - { + await output.WriteAsync(Newline, 0, Newline.Length); return; } - // We can't just use while(payload.Length > 0) because we need to write a blank final "data: " line - // if the payload ends in a newline. For example, consider the following payload: - // "Hello\n" - // It needs to be written as: - // data: Hello\r\n - // data: \r\n - // \r\n - // Since we slice past the newline when we find it, after writing "Hello" in the previous example, we'll - // end up with an empty payload buffer, BUT we need to write it as an empty 'data:' line, so we need - // to use a condition that ensure the only time we stop writing is when we write the slice after the final - // newline. + var ms = new MemoryStream(); + + // Parse payload and write formatted output to memory + await WriteMessageToMemory(ms, payload); + ms.Position = 0; + + await ms.CopyToAsync(output); + } + + /// + /// Gets the last memory segment in a sequence. + /// + /// Source sequence. + /// The offset the segment starts at. + /// The last memory segment in a sequence. + private static ReadOnlyMemory GetLastSegment(in ReadOnlySequence source, out long offset) + { + offset = 0; + + var totalLength = source.Length; + var position = source.Start; + while (source.TryGet(ref position, out ReadOnlyMemory memory)) + { + // Last segment + if (offset + memory.Length >= totalLength) + { + return memory; + } + + offset += memory.Length; + } + + throw new InvalidOperationException("Could not get last segment from sequence."); + } + + private static async Task WriteMessageToMemory(Stream output, ReadOnlySequence payload) + { var keepWriting = true; while (keepWriting) { - var span = payload.Span; - // Seek to the end of buffer or newline - var sliceEnd = span.IndexOf(LineFeed); - var nextSliceStart = sliceEnd + 1; - if (sliceEnd < 0) - { - sliceEnd = payload.Length; - nextSliceStart = sliceEnd + 1; + var sliceEnd = payload.PositionOf(LineFeed); - // This is the last span + ReadOnlySequence lineSegment; + if (sliceEnd == null) + { + lineSegment = payload; + payload = ReadOnlySequence.Empty; keepWriting = false; } - if (sliceEnd > 0 && span[sliceEnd - 1] == '\r') - { - sliceEnd--; - } - - var slice = payload.Slice(0, sliceEnd); - - if (nextSliceStart >= payload.Length) - { - payload = ReadOnlyMemory.Empty; - } else { - payload = payload.Slice(nextSliceStart); + lineSegment = payload.Slice(payload.Start, sliceEnd.Value); + + if (lineSegment.Length > 1) + { + // Check if the line ended in \r\n. If it did then trim the \r + var memory = GetLastSegment(lineSegment, out var offset); + if (memory.Span[memory.Length - 1] == '\r') + { + lineSegment = lineSegment.Slice(lineSegment.Start, offset + memory.Length - 1); + } + } + + // Update payload to remove \n + payload = payload.Slice(payload.GetPosition(1, sliceEnd.Value)); } - WriteLine(slice, output); + // Write line + await output.WriteAsync(DataPrefix, 0, DataPrefix.Length); + await output.WriteAsync(lineSegment); + await output.WriteAsync(Newline, 0, Newline.Length); } - } - private static void WriteLine(ReadOnlyMemory payload, Stream output) - { - output.Write(DataPrefix, 0, DataPrefix.Length); - -#if NETCOREAPP2_1 - output.Write(payload.Span); -#else - if (payload.Length > 0) - { - var isArray = MemoryMarshal.TryGetArray(payload, out var segment); - Debug.Assert(isArray); - output.Write(segment.Array, segment.Offset, segment.Count); - } -#endif - output.Write(Newline, 0, Newline.Length); + await output.WriteAsync(Newline, 0, Newline.Length); } } } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/Microsoft.AspNetCore.Http.Connections.Tests.csproj b/test/Microsoft.AspNetCore.Http.Connections.Tests/Microsoft.AspNetCore.Http.Connections.Tests.csproj index 6ff6487c12..b161cd535b 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/Microsoft.AspNetCore.Http.Connections.Tests.csproj +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/Microsoft.AspNetCore.Http.Connections.Tests.csproj @@ -19,6 +19,7 @@ + diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsMessageFormatterTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsMessageFormatterTests.cs index c70a219c1e..2a58e8d4dd 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsMessageFormatterTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsMessageFormatterTests.cs @@ -1,8 +1,11 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Buffers; +using System.Collections.Generic; using System.IO; using System.Text; +using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Connections.Internal; using Xunit; @@ -11,18 +14,37 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests public class ServerSentEventsMessageFormatterTests { [Theory] - [InlineData("\r\n", "")] - [InlineData("data: Hello, World\r\n\r\n", "Hello, World")] - [InlineData("data: Hello\r\ndata: World\r\n\r\n", "Hello\r\nWorld")] - [InlineData("data: Hello\r\ndata: World\r\n\r\n", "Hello\nWorld")] - [InlineData("data: Hello\r\ndata: \r\n\r\n", "Hello\n")] - [InlineData("data: Hello\r\ndata: \r\n\r\n", "Hello\r\n")] - public void WriteTextMessage(string encoded, string payload) + [MemberData(nameof(PayloadData))] + public async Task WriteTextMessageFromSingleSegment(string encoded, string payload) { + var buffer = new ReadOnlySequence(Encoding.UTF8.GetBytes(payload)); + var output = new MemoryStream(); - ServerSentEventsMessageFormatter.WriteMessage(Encoding.UTF8.GetBytes(payload), output); + await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output); Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); } + + [Theory] + [MemberData(nameof(PayloadData))] + public async Task WriteTextMessageFromMultipleSegments(string encoded, string payload) + { + var buffer = ReadOnlySequenceFactory.SegmentPerByteFactory.CreateWithContent(Encoding.UTF8.GetBytes(payload)); + + var output = new MemoryStream(); + await ServerSentEventsMessageFormatter.WriteMessageAsync(buffer, output); + + Assert.Equal(encoded, Encoding.UTF8.GetString(output.ToArray())); + } + + public static IEnumerable PayloadData => new List + { + new object[] { "\r\n", "" }, + new object[] { "data: Hello, World\r\n\r\n", "Hello, World" }, + new object[] { "data: Hello\r\ndata: World\r\n\r\n", "Hello\r\nWorld" }, + new object[] { "data: Hello\r\ndata: World\r\n\r\n", "Hello\nWorld" }, + new object[] { "data: Hello\r\ndata: \r\n\r\n", "Hello\n" }, + new object[] { "data: Hello\r\ndata: \r\n\r\n", "Hello\r\n" }, + }; } } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsTests.cs index 32ebbcf208..ab0863c5e4 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/ServerSentEventsTests.cs @@ -70,6 +70,28 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests Assert.Equal(":\r\ndata: Hello\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray())); } + [Fact] + public async Task SSEWritesVeryLargeMessages() + { + var pair = DuplexPipe.CreateConnectionPair(PipeOptions.Default, new PipeOptions(readerScheduler: PipeScheduler.Inline)); + var connection = new DefaultConnectionContext("foo", pair.Transport, pair.Application); + var context = new DefaultHttpContext(); + + var ms = new MemoryStream(); + context.Response.Body = ms; + var sse = new ServerSentEventsTransport(connection.Application.Input, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + + var task = sse.ProcessRequestAsync(context, context.RequestAborted); + + string hText = new string('H', 60000); + string wText = new string('W', 60000); + + await connection.Transport.Output.WriteAsync(Encoding.ASCII.GetBytes(hText + wText)); + connection.Transport.Output.Complete(); + await task.OrTimeout(); + Assert.Equal(":\r\ndata: " + hText + wText + "\r\n\r\n", Encoding.ASCII.GetString(ms.ToArray())); + } + [Theory] [InlineData("Hello World", ":\r\ndata: Hello World\r\n\r\n")] [InlineData("Hello\nWorld", ":\r\ndata: Hello\r\ndata: World\r\n\r\n")]