From c852bdcc332ffb998ec6a5b226e35d5e74d24009 Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Wed, 21 Nov 2018 11:47:39 -0800 Subject: [PATCH] Avoid zero-byte send in WebSockets (#3326) --- src/Common/WebSocketExtensions.cs | 16 +++++++--- .../WebSocketsTests.cs | 31 +++++++++++++++++++ 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/src/Common/WebSocketExtensions.cs b/src/Common/WebSocketExtensions.cs index a15ad78891..fedb954296 100644 --- a/src/Common/WebSocketExtensions.cs +++ b/src/Common/WebSocketExtensions.cs @@ -39,22 +39,28 @@ namespace System.Net.WebSockets private static async ValueTask SendMultiSegmentAsync(WebSocket webSocket, ReadOnlySequence buffer, WebSocketMessageType webSocketMessageType, CancellationToken cancellationToken = default) { var position = buffer.Start; + // Get a segment before the loop so we can be one segment behind while writing + // This allows us to do a non-zero byte write for the endOfMessage = true send + buffer.TryGet(ref position, out var prevSegment); while (buffer.TryGet(ref position, out var segment)) { #if NETCOREAPP3_0 - await webSocket.SendAsync(segment, webSocketMessageType, endOfMessage: false, cancellationToken); + await webSocket.SendAsync(prevSegment, webSocketMessageType, endOfMessage: false, cancellationToken); #else - var isArray = MemoryMarshal.TryGetArray(segment, out var arraySegment); + var isArray = MemoryMarshal.TryGetArray(prevSegment, out var arraySegment); Debug.Assert(isArray); await webSocket.SendAsync(arraySegment, webSocketMessageType, endOfMessage: false, cancellationToken); #endif + prevSegment = segment; } - // Empty end of message frame + // End of message frame #if NETCOREAPP3_0 - await webSocket.SendAsync(Memory.Empty, webSocketMessageType, endOfMessage: true, cancellationToken); + await webSocket.SendAsync(prevSegment, webSocketMessageType, endOfMessage: true, cancellationToken); #else - await webSocket.SendAsync(new ArraySegment(Array.Empty()), webSocketMessageType, endOfMessage: true, cancellationToken); + var isArrayEnd = MemoryMarshal.TryGetArray(prevSegment, out var arraySegmentEnd); + Debug.Assert(isArrayEnd); + await webSocket.SendAsync(arraySegmentEnd, webSocketMessageType, endOfMessage: true, cancellationToken); #endif } } diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs index 0af2f65812..8068853f17 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/WebSocketsTests.cs @@ -396,5 +396,36 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests } } } + + [Fact] + public async Task MultiSegmentSendWillNotSendEmptyEndOfMessageFrame() + { + using (var feature = new TestWebSocketConnectionFeature()) + { + var serverSocket = await feature.AcceptAsync(); + var sequence = ReadOnlySequenceFactory.CreateSegments(new byte[] { 1 }, new byte[] { 15 }); + Assert.False(sequence.IsSingleSegment); + + await serverSocket.SendAsync(sequence, WebSocketMessageType.Text); + + // Run the client socket + var client = feature.Client.ExecuteAndCaptureFramesAsync(); + + await serverSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "", default); + + var messages = await client.OrTimeout(); + Assert.Equal(2, messages.Received.Count); + + // First message: 1 byte, endOfMessage false + Assert.Single(messages.Received[0].Buffer); + Assert.Equal(1, messages.Received[0].Buffer[0]); + Assert.False(messages.Received[0].EndOfMessage); + + // Second message: 1 byte, endOfMessage true + Assert.Single(messages.Received[1].Buffer); + Assert.Equal(15, messages.Received[1].Buffer[0]); + Assert.True(messages.Received[1].EndOfMessage); + } + } } }