From b0c4e9d0f71766340a6285ab0d5c17e01542f6be Mon Sep 17 00:00:00 2001 From: BrennanConroy Date: Thu, 5 Apr 2018 18:50:30 -0700 Subject: [PATCH] 0 byte read in WebSockets (#1878) --- .../Internal/WebSocketsTransport.cs | 24 ++++++++-- .../Transports/WebSocketsTransport.cs | 14 +++++- .../TestWebSocketConnectionFeature.cs | 48 ++++++++++++++----- 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs index db06b7dae9..e3ef3e3309 100644 --- a/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections.Client/Internal/WebSocketsTransport.cs @@ -183,9 +183,25 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal { while (true) { - var memory = _application.Output.GetMemory(); - #if NETCOREAPP2_1 + // Do a 0 byte read so that idle connections don't allocate a buffer when waiting for a read + var result = await socket.ReceiveAsync(Memory.Empty, CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Close) + { + Log.WebSocketClosed(_logger, _webSocket.CloseStatus); + + if (_webSocket.CloseStatus != WebSocketCloseStatus.NormalClosure) + { + throw new InvalidOperationException($"Websocket closed with error: {_webSocket.CloseStatus}."); + } + + return; + } +#endif + var memory = _application.Output.GetMemory(); +#if NETCOREAPP2_1 + // Because we checked the CloseStatus from the 0 byte read above, we don't need to check again after reading var receiveResult = await socket.ReceiveAsync(memory, CancellationToken.None); #else var isArray = MemoryMarshal.TryGetArray(memory, out var arraySegment); @@ -193,7 +209,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal // Exceptions are handled above where the send and receive tasks are being run. var receiveResult = await socket.ReceiveAsync(arraySegment, CancellationToken.None); -#endif + if (receiveResult.MessageType == WebSocketMessageType.Close) { Log.WebSocketClosed(_logger, _webSocket.CloseStatus); @@ -205,7 +221,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Client.Internal return; } - +#endif Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); _application.Output.Advance(receiveResult.Count); diff --git a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs index 24812e93b8..e8f51781af 100644 --- a/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Http.Connections/Internal/Transports/WebSocketsTransport.cs @@ -140,9 +140,19 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports { while (true) { +#if NETCOREAPP2_1 + // Do a 0 byte read so that idle connections don't allocate a buffer when waiting for a read + var result = await socket.ReceiveAsync(Memory.Empty, CancellationToken.None); + + if (result.MessageType == WebSocketMessageType.Close) + { + return; + } +#endif var memory = _application.Output.GetMemory(); #if NETCOREAPP2_1 + // Because we checked the CloseStatus from the 0 byte read above, we don't need to check again after reading var receiveResult = await socket.ReceiveAsync(memory, CancellationToken.None); #else var isArray = MemoryMarshal.TryGetArray(memory, out var arraySegment); @@ -150,12 +160,12 @@ namespace Microsoft.AspNetCore.Http.Connections.Internal.Transports // Exceptions are handled above where the send and receive tasks are being run. var receiveResult = await socket.ReceiveAsync(arraySegment, CancellationToken.None); -#endif + if (receiveResult.MessageType == WebSocketMessageType.Close) { return; } - +#endif Log.MessageReceived(_logger, receiveResult.MessageType, receiveResult.Count, receiveResult.EndOfMessage); _application.Output.Advance(receiveResult.Count); diff --git a/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs b/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs index e9be9faa71..9f6e1a5092 100644 --- a/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs +++ b/test/Microsoft.AspNetCore.Http.Connections.Tests/TestWebSocketConnectionFeature.cs @@ -40,6 +40,7 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests private WebSocketCloseStatus? _closeStatus; private string _closeStatusDescription; private WebSocketState _state; + private WebSocketMessage _internalBuffer = new WebSocketMessage(); public WebSocketChannel(ChannelReader input, ChannelWriter output) { @@ -106,23 +107,44 @@ namespace Microsoft.AspNetCore.Http.Connections.Tests { try { - await _input.WaitToReadAsync(); - - if (_input.TryRead(out var message)) + if (_internalBuffer.Buffer == null || _internalBuffer.Buffer.Length == 0) { - if (message.MessageType == WebSocketMessageType.Close) + await _input.WaitToReadAsync(); + + if (_input.TryRead(out var message)) { - _state = WebSocketState.CloseReceived; - _closeStatus = message.CloseStatus; - _closeStatusDescription = message.CloseStatusDescription; - return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription); + if (message.MessageType == WebSocketMessageType.Close) + { + _state = WebSocketState.CloseReceived; + _closeStatus = message.CloseStatus; + _closeStatusDescription = message.CloseStatusDescription; + return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription); + } + + _internalBuffer = message; } - - // REVIEW: This assumes the buffer passed in is > the buffer received - Buffer.BlockCopy(message.Buffer, 0, buffer.Array, buffer.Offset, message.Buffer.Length); - - return new WebSocketReceiveResult(message.Buffer.Length, message.MessageType, message.EndOfMessage); } + + var length = _internalBuffer.Buffer.Length; + if (buffer.Count - buffer.Offset < _internalBuffer.Buffer.Length) + { + length = Math.Min(buffer.Count - buffer.Offset, _internalBuffer.Buffer.Length); + Buffer.BlockCopy(_internalBuffer.Buffer, 0, buffer.Array, buffer.Offset, length); + } + else + { + Buffer.BlockCopy(_internalBuffer.Buffer, 0, buffer.Array, buffer.Offset, length); + } + + var endOfMessage = _internalBuffer.EndOfMessage; + if (length > 0) + { + // Remove the sent bytes from the remaining buffer + _internalBuffer.Buffer = _internalBuffer.Buffer.AsMemory().Slice(length).ToArray(); + endOfMessage = _internalBuffer.Buffer.Length == 0 && endOfMessage; + } + + return new WebSocketReceiveResult(length, _internalBuffer.MessageType, endOfMessage); } catch (WebSocketException ex) {