diff --git a/scripts/UpdateCoreFxCode.ps1 b/scripts/UpdateCoreFxCode.ps1 index d94e04a161..09cfc3e448 100644 --- a/scripts/UpdateCoreFxCode.ps1 +++ b/scripts/UpdateCoreFxCode.ps1 @@ -3,7 +3,7 @@ param([string]$CoreFxRepoRoot) $RepoRoot = Split-Path -Parent $PSScriptRoot $FilesToCopy = @( - "src\System.Net.WebSockets.Client\src\System\Net\WebSockets\ManagedWebSocket.cs", + "src\Common\src\System\Net\WebSockets\ManagedWebSocket.cs", "src\Common\src\System\Net\WebSockets\WebSocketValidate.cs" ) diff --git a/src/Microsoft.AspNetCore.WebSockets/Internal/WebSocketFactory.cs b/src/Microsoft.AspNetCore.WebSockets/Internal/WebSocketFactory.cs index 9946d5e9b2..d9f13b6d3d 100644 --- a/src/Microsoft.AspNetCore.WebSockets/Internal/WebSocketFactory.cs +++ b/src/Microsoft.AspNetCore.WebSockets/Internal/WebSocketFactory.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.WebSockets.Internal stream, isServer: false, subprotocol: subProtocol, - keepAliveIntervalSeconds: (int)keepAliveInterval.TotalSeconds, + keepAliveInterval: keepAliveInterval, receiveBufferSize: receiveBufferSize); } @@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.WebSockets.Internal stream, isServer: true, subprotocol: subProtocol, - keepAliveIntervalSeconds: (int)keepAliveInterval.TotalSeconds, + keepAliveInterval: keepAliveInterval, receiveBufferSize: receiveBufferSize); } } diff --git a/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs b/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs similarity index 86% rename from src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs rename to src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs index 8253aebde5..494b6a786d 100644 --- a/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/System.Net.WebSockets.Client/src/System/Net/WebSockets/ManagedWebSocket.cs +++ b/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/ManagedWebSocket.cs @@ -2,8 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. +using System.Buffers; using System.Diagnostics; using System.IO; +using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.InteropServices; using System.Security.Cryptography; @@ -30,14 +32,14 @@ namespace System.Net.WebSockets /// The connected Stream. /// true if this is the server-side of the connection; false if this is the client-side of the connection. /// The agreed upon subprotocol for the connection. - /// The interval to use for keep-alive pings. + /// The interval to use for keep-alive pings. /// The buffer size to use for received data. + /// Optional buffer to use for receives. /// The created instance. public static ManagedWebSocket CreateFromConnectedStream( - Stream stream, bool isServer, string subprotocol, - int keepAliveIntervalSeconds = 30, int receiveBufferSize = 0x1000) + Stream stream, bool isServer, string subprotocol, TimeSpan keepAliveInterval, int receiveBufferSize, ArraySegment? receiveBuffer = null) { - return new ManagedWebSocket(stream, isServer, subprotocol, TimeSpan.FromSeconds(keepAliveIntervalSeconds), receiveBufferSize); + return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval, receiveBufferSize, receiveBuffer); } /// Per-thread cached 4-byte mask byte array. @@ -81,7 +83,9 @@ namespace System.Net.WebSockets /// CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs. private readonly CancellationTokenSource _abortSource = new CancellationTokenSource(); /// Buffer used for reading data from the network. - private readonly byte[] _receiveBuffer; + private byte[] _receiveBuffer; + /// Gets whether the receive buffer came from the ArrayPool. + private readonly bool _receiveBufferFromPool; /// /// Tracks the state of the validity of the UTF8 encoding of text payloads. Text may be split across fragments. /// @@ -132,9 +136,10 @@ namespace System.Net.WebSockets /// private int _receivedMaskOffsetOffset = 0; /// - /// Buffer used to store the complete message to be sent to the stream. This is needed - /// rather than just sending a header and then the user's buffer, as we need to mutate the - /// buffered data with the mask, and we don't want to change the data in the user's buffer. + /// Temporary send buffer. This should be released back to the ArrayPool once it's + /// no longer needed for the current send operation. It is stored as an instance + /// field to minimize needing to pass it around and to avoid it becoming a field on + /// various async state machine objects. /// private byte[] _sendBuffer; /// @@ -168,7 +173,8 @@ namespace System.Net.WebSockets /// The agreed upon subprotocol for the connection. /// The interval to use for keep-alive pings. /// The buffer size to use for received data. - private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeSpan keepAliveInterval, int receiveBufferSize) + /// Optional buffer to use for receives + private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeSpan keepAliveInterval, int receiveBufferSize, ArraySegment? receiveBuffer) { Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null"); Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null"); @@ -183,7 +189,20 @@ namespace System.Net.WebSockets _stream = stream; _isServer = isServer; _subprotocol = subprotocol; - _receiveBuffer = new byte[Math.Max(receiveBufferSize, MaxMessageHeaderLength)]; + + // If we were provided with a buffer to use, use it, as long as it's big enough for our needs, and for simplicity + // as long as we're not supposed to use only a portion of it. If it doesn't meet our criteria, just create a new one. + if (receiveBuffer.HasValue && + receiveBuffer.Value.Offset == 0 && receiveBuffer.Value.Count == receiveBuffer.Value.Array.Length && + receiveBuffer.Value.Count >= MaxMessageHeaderLength) + { + _receiveBuffer = receiveBuffer.Value.Array; + } + else + { + _receiveBufferFromPool = true; + _receiveBuffer = ArrayPool.Shared.Rent(Math.Max(receiveBufferSize, MaxMessageHeaderLength)); + } // Set up the abort source so that if it's triggered, we transition the instance appropriately. _abortSource.Token.Register(s => @@ -225,6 +244,12 @@ namespace System.Net.WebSockets _disposed = true; _keepAliveTimer?.Dispose(); _stream?.Dispose(); + if (_receiveBufferFromPool) + { + byte[] old = _receiveBuffer; + _receiveBuffer = null; + ArrayPool.Shared.Return(old); + } if (_state < WebSocketState.Aborted) { _state = WebSocketState.Closed; @@ -253,7 +278,7 @@ namespace System.Net.WebSockets try { - WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates); + WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates); ThrowIfOperationInProgress(_lastSendAsync); } catch (Exception exc) @@ -369,7 +394,7 @@ namespace System.Net.WebSockets // If we get here, the cancellation token is not cancelable so we don't have to worry about it, // and we own the semaphore, so we don't need to asynchronously wait for it. Task writeTask = null; - bool releaseSemaphore = true; + bool releaseSemaphoreAndSendBuffer = true; try { // Write the payload synchronously to the buffer, then write that buffer out to the network. @@ -386,9 +411,9 @@ namespace System.Net.WebSockets } // Up until this point, if an exception occurred (such as when accessing _stream or when - // calling GetResult), we want to release the semaphore. After this point, the semaphore needs - // to remain held until writeTask completes. - releaseSemaphore = false; + // calling GetResult), we want to release the semaphore and the send buffer. After this point, + // both need to be held until writeTask completes. + releaseSemaphoreAndSendBuffer = false; } catch (Exception exc) { @@ -398,9 +423,10 @@ namespace System.Net.WebSockets } finally { - if (releaseSemaphore) + if (releaseSemaphoreAndSendBuffer) { _sendFrameAsyncLock.Release(); + ReleaseSendBuffer(); } } @@ -410,6 +436,7 @@ namespace System.Net.WebSockets { var thisRef = (ManagedWebSocket)s; thisRef._sendFrameAsyncLock.Release(); + thisRef.ReleaseSendBuffer(); try { t.GetAwaiter().GetResult(); } catch (Exception exc) @@ -441,14 +468,15 @@ namespace System.Net.WebSockets finally { _sendFrameAsyncLock.Release(); + ReleaseSendBuffer(); } } /// Writes a frame into the send buffer, which can then be sent over the network. private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ArraySegment payloadBuffer) { - // Grow our send buffer as needed. We reuse the buffer for all messages, with it protected by the send frame lock. - EnsureBufferLength(ref _sendBuffer, payloadBuffer.Count + MaxMessageHeaderLength); + // Ensure we have a _sendBuffer. + AllocateSendBuffer(payloadBuffer.Count + MaxMessageHeaderLength); // Write the message header data to the buffer. int headerLength; @@ -542,7 +570,7 @@ namespace System.Net.WebSockets { sendBuffer[1] = 126; sendBuffer[2] = (byte)(payload.Count / 256); - sendBuffer[3] = (byte)payload.Count; + sendBuffer[3] = unchecked((byte)payload.Count); maskOffset = 2 + sizeof(ushort); // additional 2 bytes for 16-bit length } else @@ -551,7 +579,7 @@ namespace System.Net.WebSockets int length = payload.Count; for (int i = 9; i >= 2; i--) { - sendBuffer[i] = (byte)length; + sendBuffer[i] = unchecked((byte)length); length = length / 256; } maskOffset = 2 + sizeof(ulong); // additional 8 bytes for 64-bit length @@ -983,37 +1011,44 @@ namespace System.Net.WebSockets $"Unexpected state {State}."); // Wait until we've received a close response - byte[] closeBuffer = new byte[MaxMessageHeaderLength + MaxControlPayloadLength]; - while (!_receivedCloseFrame) + byte[] closeBuffer = ArrayPool.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength); + try { - Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); - Task receiveTask; - lock (ReceiveAsyncLock) + while (!_receivedCloseFrame) { - // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame. - // It could have been received between our check above and now due to a concurrent receive completing. - if (_receivedCloseFrame) + Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}"); + Task receiveTask; + lock (ReceiveAsyncLock) { - break; + // Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame. + // It could have been received between our check above and now due to a concurrent receive completing. + if (_receivedCloseFrame) + { + break; + } + + // We've not yet processed a received close frame, which means we need to wait for a received close to complete. + // There may already be one in flight, in which case we want to just wait for that one rather than kicking off + // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've + // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is + // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst + // case is we then await it, find that it's not what we need, and try again. + receiveTask = _lastReceiveAsync; + if (receiveTask == null || + (receiveTask.Status == TaskStatus.RanToCompletion && receiveTask.Result.MessageType != WebSocketMessageType.Close)) + { + _lastReceiveAsync = receiveTask = ReceiveAsyncPrivate(new ArraySegment(closeBuffer), cancellationToken); + } } - // We've not yet processed a received close frame, which means we need to wait for a received close to complete. - // There may already be one in flight, in which case we want to just wait for that one rather than kicking off - // another (we don't support concurrent receive operations). We need to kick off a new receive if either we've - // never issued a receive or if the last issued receive completed for reasons other than a close frame. There is - // a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst - // case is we then await it, find that it's not what we need, and try again. - receiveTask = _lastReceiveAsync; - if (receiveTask == null || - (receiveTask.Status == TaskStatus.RanToCompletion && receiveTask.Result.MessageType != WebSocketMessageType.Close)) - { - _lastReceiveAsync = receiveTask = ReceiveAsyncPrivate(new ArraySegment(closeBuffer), cancellationToken); - } + // Wait for whatever receive task we have. We'll then loop around again to re-check our state. + Debug.Assert(receiveTask != null); + await receiveTask.ConfigureAwait(false); } - - // Wait for whatever receive task we have. We'll then loop around again to re-check our state. - Debug.Assert(receiveTask != null); - await receiveTask.ConfigureAwait(false); + } + finally + { + ArrayPool.Shared.Return(closeBuffer); } // We're closed. Close the connection and update the status. @@ -1035,24 +1070,36 @@ namespace System.Net.WebSockets { // Close payload is two bytes containing the close status followed by a UTF8-encoding of the status description, if it exists. - byte[] buffer; - if (string.IsNullOrEmpty(closeStatusDescription)) + byte[] buffer = null; + try { - buffer = new byte[2]; + int count = 2; + if (string.IsNullOrEmpty(closeStatusDescription)) + { + buffer = ArrayPool.Shared.Rent(count); + } + else + { + count += s_textEncoding.GetByteCount(closeStatusDescription); + buffer = ArrayPool.Shared.Rent(count); + int encodedLength = s_textEncoding.GetBytes(closeStatusDescription, 0, closeStatusDescription.Length, buffer, 2); + Debug.Assert(count - 2 == encodedLength, $"GetByteCount and GetBytes encoded count didn't match"); + } + + ushort closeStatusValue = (ushort)closeStatus; + buffer[0] = (byte)(closeStatusValue >> 8); + buffer[1] = (byte)(closeStatusValue & 0xFF); + + await SendFrameAsync(MessageOpcode.Close, true, new ArraySegment(buffer, 0, count), cancellationToken).ConfigureAwait(false); } - else + finally { - buffer = new byte[2 + s_textEncoding.GetByteCount(closeStatusDescription)]; - int encodedLength = s_textEncoding.GetBytes(closeStatusDescription, 0, closeStatusDescription.Length, buffer, 2); - Debug.Assert(buffer.Length - 2 == encodedLength, $"GetByteCount and GetBytes encoded count didn't match"); + if (buffer != null) + { + ArrayPool.Shared.Return(buffer); + } } - ushort closeStatusValue = (ushort)closeStatus; - buffer[0] = (byte)(closeStatusValue >> 8); - buffer[1] = (byte)(closeStatusValue & 0xFF); - - await SendFrameAsync(MessageOpcode.Close, true, new ArraySegment(buffer), cancellationToken).ConfigureAwait(false); - lock (StateUpdateLock) { _sentCloseFrame = true; @@ -1111,15 +1158,21 @@ namespace System.Net.WebSockets } } - /// - /// Grows the specified buffer if it's not at least the specified minimum length. - /// Data is not copied if the buffer is grown. - /// - private static void EnsureBufferLength(ref byte[] buffer, int minLength) + /// Gets a send buffer from the pool. + private void AllocateSendBuffer(int minLength) { - if (buffer == null || buffer.Length < minLength) + Debug.Assert(_sendBuffer == null); // would only fail if had some catastrophic error previously that prevented cleaning up + _sendBuffer = ArrayPool.Shared.Rent(minLength); + } + + /// Releases the send buffer to the pool. + private void ReleaseSendBuffer() + { + byte[] old = _sendBuffer; + if (old != null) { - buffer = new byte[minLength]; + _sendBuffer = null; + ArrayPool.Shared.Return(old); } } @@ -1150,21 +1203,68 @@ namespace System.Net.WebSockets /// The next index into the mask to be used for future applications of the mask. private static unsafe int ApplyMask(byte[] toMask, int toMaskOffset, int mask, int maskIndex, long count) { - Debug.Assert(toMaskOffset <= toMask.Length - count, $"Unexpected inputs: {toMaskOffset}, {toMask.Length}, {count}"); - Debug.Assert(maskIndex < sizeof(int), $"Unexpected {nameof(maskIndex)}: {maskIndex}"); + int maskShift = maskIndex * 8; + int shiftedMask = (int)(((uint)mask >> maskShift) | ((uint)mask << (32 - maskShift))); - byte* maskPtr = (byte*)&mask; - fixed (byte* toMaskPtr = toMask) + // Try to use SIMD. We can if the number of bytes we're trying to mask is at least as much + // as the width of a vector and if the width is an even multiple of the mask. + if (Vector.IsHardwareAccelerated && + Vector.Count % sizeof(int) == 0 && + count >= Vector.Count) { - byte* p = toMaskPtr + toMaskOffset; - byte* end = p + count; - while (p < end) + // Mask bytes a vector at a time. + Vector maskVector = Vector.AsVectorByte(new Vector(shiftedMask)); + while (count >= Vector.Count) { - *p++ ^= maskPtr[maskIndex]; - maskIndex = (maskIndex + 1) & 3; // & 3 == faster % MaskLength + count -= Vector.Count; + (maskVector ^ new Vector(toMask, toMaskOffset)).CopyTo(toMask, toMaskOffset); + toMaskOffset += Vector.Count; } - return maskIndex; + + // Fall through to processing any remaining bytes that were less than a vector width. + // Since we processed full masks at a time, we don't need to update maskIndex, and + // toMaskOffset has already been updated to point to the correct location. } + + // If there are any bytes remaining (either we couldn't use vectors, or the count wasn't + // an even multiple of the vector width), process them without vectors. + if (count > 0) + { + fixed (byte* toMaskPtr = toMask) + { + // Get the location in the target array to continue processing. + byte* p = toMaskPtr + toMaskOffset; + + // Try to go an int at a time if the remaining data is 4-byte aligned and there's enough remaining. + if (((long)p % sizeof(int)) == 0) + { + while (count >= sizeof(int)) + { + count -= sizeof(int); + *((int*)p) ^= shiftedMask; + p += sizeof(int); + } + + // We don't need to update the maskIndex, as its mod-4 value won't have changed. + // `p` points to the remainder. + } + + // Process any remaining data a byte at a time. + if (count > 0) + { + byte* maskPtr = (byte*)&mask; + byte* end = p + count; + while (p < end) + { + *p++ ^= maskPtr[maskIndex]; + maskIndex = (maskIndex + 1) & 3; + } + } + } + } + + // Return the updated index. + return maskIndex; } /// Aborts the websocket and throws an exception if an existing operation is in progress. diff --git a/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/WebSocketValidate.cs b/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/WebSocketValidate.cs index 06e07f29dd..4a1018610e 100644 --- a/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/WebSocketValidate.cs +++ b/src/Microsoft.AspNetCore.WebSockets/Internal/fx/src/Common/src/System/Net/WebSockets/WebSocketValidate.cs @@ -7,7 +7,7 @@ using System.Text; namespace System.Net.WebSockets { - internal static class WebSocketValidate + internal static partial class WebSocketValidate { internal const int MaxControlFramePayloadLength = 123; private const int CloseStatusCodeAbort = 1006; @@ -16,6 +16,34 @@ namespace System.Net.WebSockets private const int InvalidCloseStatusCodesTo = 999; private const string Separators = "()<>@,;:\\\"/[]?={} "; + internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates) + { + string validStatesText = string.Empty; + + if (validStates != null && validStates.Length > 0) + { + foreach (WebSocketState validState in validStates) + { + if (currentState == validState) + { + // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. + if (isDisposed) + { + throw new ObjectDisposedException("ClientWebSocket"); + } + + return; + } + } + + validStatesText = string.Join(", ", validStates); + } + + throw new WebSocketException( + WebSocketError.InvalidState, + SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText)); + } + internal static void ValidateSubprotocol(string subProtocol) { if (string.IsNullOrWhiteSpace(subProtocol)) @@ -101,32 +129,22 @@ namespace System.Net.WebSockets } } - internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates) + internal static void ValidateBuffer(byte[] buffer, int offset, int count) { - string validStatesText = string.Empty; - - if (validStates != null && validStates.Length > 0) + if (buffer == null) { - foreach (WebSocketState validState in validStates) - { - if (currentState == validState) - { - // Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior. - if (isDisposed) - { - throw new ObjectDisposedException("ClientWebSocket"); - } - - return; - } - } - - validStatesText = string.Join(", ", validStates); + throw new ArgumentNullException(nameof(buffer)); } - throw new WebSocketException( - WebSocketError.InvalidState, - SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText)); + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException(nameof(offset)); + } + + if (count < 0 || count > (buffer.Length - offset)) + { + throw new ArgumentOutOfRangeException(nameof(count)); + } } } } diff --git a/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj b/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj index ed0bcaab59..40923e0978 100644 --- a/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj +++ b/src/Microsoft.AspNetCore.WebSockets/Microsoft.AspNetCore.WebSockets.csproj @@ -14,6 +14,7 @@ +