port dotnet/corefx#17049 (#151)

This commit is contained in:
Andrew Stanton-Nurse 2017-03-14 12:04:59 -07:00 committed by GitHub
parent 92b37f85c2
commit cb150de808
5 changed files with 221 additions and 102 deletions

View File

@ -3,7 +3,7 @@ param([string]$CoreFxRepoRoot)
$RepoRoot = Split-Path -Parent $PSScriptRoot
$FilesToCopy = @(
"src\System.Net.WebSockets.Client\src\System\Net\WebSockets\ManagedWebSocket.cs",
"src\Common\src\System\Net\WebSockets\ManagedWebSocket.cs",
"src\Common\src\System\Net\WebSockets\WebSocketValidate.cs"
)

View File

@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@ -15,7 +15,7 @@ namespace Microsoft.AspNetCore.WebSockets.Internal
stream,
isServer: false,
subprotocol: subProtocol,
keepAliveIntervalSeconds: (int)keepAliveInterval.TotalSeconds,
keepAliveInterval: keepAliveInterval,
receiveBufferSize: receiveBufferSize);
}
@ -25,7 +25,7 @@ namespace Microsoft.AspNetCore.WebSockets.Internal
stream,
isServer: true,
subprotocol: subProtocol,
keepAliveIntervalSeconds: (int)keepAliveInterval.TotalSeconds,
keepAliveInterval: keepAliveInterval,
receiveBufferSize: receiveBufferSize);
}
}

View File

@ -2,8 +2,10 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Buffers;
using System.Diagnostics;
using System.IO;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security.Cryptography;
@ -30,14 +32,14 @@ namespace System.Net.WebSockets
/// <param name="stream">The connected Stream.</param>
/// <param name="isServer">true if this is the server-side of the connection; false if this is the client-side of the connection.</param>
/// <param name="subprotocol">The agreed upon subprotocol for the connection.</param>
/// <param name="keepAliveIntervalSeconds">The interval to use for keep-alive pings.</param>
/// <param name="keepAliveInterval">The interval to use for keep-alive pings.</param>
/// <param name="receiveBufferSize">The buffer size to use for received data.</param>
/// <param name="receiveBuffer">Optional buffer to use for receives.</param>
/// <returns>The created <see cref="ManagedWebSocket"/> instance.</returns>
public static ManagedWebSocket CreateFromConnectedStream(
Stream stream, bool isServer, string subprotocol,
int keepAliveIntervalSeconds = 30, int receiveBufferSize = 0x1000)
Stream stream, bool isServer, string subprotocol, TimeSpan keepAliveInterval, int receiveBufferSize, ArraySegment<byte>? receiveBuffer = null)
{
return new ManagedWebSocket(stream, isServer, subprotocol, TimeSpan.FromSeconds(keepAliveIntervalSeconds), receiveBufferSize);
return new ManagedWebSocket(stream, isServer, subprotocol, keepAliveInterval, receiveBufferSize, receiveBuffer);
}
/// <summary>Per-thread cached 4-byte mask byte array.</summary>
@ -81,7 +83,9 @@ namespace System.Net.WebSockets
/// <summary>CancellationTokenSource used to abort all current and future operations when anything is canceled or any error occurs.</summary>
private readonly CancellationTokenSource _abortSource = new CancellationTokenSource();
/// <summary>Buffer used for reading data from the network.</summary>
private readonly byte[] _receiveBuffer;
private byte[] _receiveBuffer;
/// <summary>Gets whether the receive buffer came from the ArrayPool.</summary>
private readonly bool _receiveBufferFromPool;
/// <summary>
/// Tracks the state of the validity of the UTF8 encoding of text payloads. Text may be split across fragments.
/// </summary>
@ -132,9 +136,10 @@ namespace System.Net.WebSockets
/// </summary>
private int _receivedMaskOffsetOffset = 0;
/// <summary>
/// Buffer used to store the complete message to be sent to the stream. This is needed
/// rather than just sending a header and then the user's buffer, as we need to mutate the
/// buffered data with the mask, and we don't want to change the data in the user's buffer.
/// Temporary send buffer. This should be released back to the ArrayPool once it's
/// no longer needed for the current send operation. It is stored as an instance
/// field to minimize needing to pass it around and to avoid it becoming a field on
/// various async state machine objects.
/// </summary>
private byte[] _sendBuffer;
/// <summary>
@ -168,7 +173,8 @@ namespace System.Net.WebSockets
/// <param name="subprotocol">The agreed upon subprotocol for the connection.</param>
/// <param name="keepAliveInterval">The interval to use for keep-alive pings.</param>
/// <param name="receiveBufferSize">The buffer size to use for received data.</param>
private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeSpan keepAliveInterval, int receiveBufferSize)
/// <param name="receiveBuffer">Optional buffer to use for receives</param>
private ManagedWebSocket(Stream stream, bool isServer, string subprotocol, TimeSpan keepAliveInterval, int receiveBufferSize, ArraySegment<byte>? receiveBuffer)
{
Debug.Assert(StateUpdateLock != null, $"Expected {nameof(StateUpdateLock)} to be non-null");
Debug.Assert(ReceiveAsyncLock != null, $"Expected {nameof(ReceiveAsyncLock)} to be non-null");
@ -183,7 +189,20 @@ namespace System.Net.WebSockets
_stream = stream;
_isServer = isServer;
_subprotocol = subprotocol;
_receiveBuffer = new byte[Math.Max(receiveBufferSize, MaxMessageHeaderLength)];
// If we were provided with a buffer to use, use it, as long as it's big enough for our needs, and for simplicity
// as long as we're not supposed to use only a portion of it. If it doesn't meet our criteria, just create a new one.
if (receiveBuffer.HasValue &&
receiveBuffer.Value.Offset == 0 && receiveBuffer.Value.Count == receiveBuffer.Value.Array.Length &&
receiveBuffer.Value.Count >= MaxMessageHeaderLength)
{
_receiveBuffer = receiveBuffer.Value.Array;
}
else
{
_receiveBufferFromPool = true;
_receiveBuffer = ArrayPool<byte>.Shared.Rent(Math.Max(receiveBufferSize, MaxMessageHeaderLength));
}
// Set up the abort source so that if it's triggered, we transition the instance appropriately.
_abortSource.Token.Register(s =>
@ -225,6 +244,12 @@ namespace System.Net.WebSockets
_disposed = true;
_keepAliveTimer?.Dispose();
_stream?.Dispose();
if (_receiveBufferFromPool)
{
byte[] old = _receiveBuffer;
_receiveBuffer = null;
ArrayPool<byte>.Shared.Return(old);
}
if (_state < WebSocketState.Aborted)
{
_state = WebSocketState.Closed;
@ -253,7 +278,7 @@ namespace System.Net.WebSockets
try
{
WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates);
WebSocketValidate.ThrowIfInvalidState(_state, _disposed, s_validSendStates);
ThrowIfOperationInProgress(_lastSendAsync);
}
catch (Exception exc)
@ -369,7 +394,7 @@ namespace System.Net.WebSockets
// If we get here, the cancellation token is not cancelable so we don't have to worry about it,
// and we own the semaphore, so we don't need to asynchronously wait for it.
Task writeTask = null;
bool releaseSemaphore = true;
bool releaseSemaphoreAndSendBuffer = true;
try
{
// Write the payload synchronously to the buffer, then write that buffer out to the network.
@ -386,9 +411,9 @@ namespace System.Net.WebSockets
}
// Up until this point, if an exception occurred (such as when accessing _stream or when
// calling GetResult), we want to release the semaphore. After this point, the semaphore needs
// to remain held until writeTask completes.
releaseSemaphore = false;
// calling GetResult), we want to release the semaphore and the send buffer. After this point,
// both need to be held until writeTask completes.
releaseSemaphoreAndSendBuffer = false;
}
catch (Exception exc)
{
@ -398,9 +423,10 @@ namespace System.Net.WebSockets
}
finally
{
if (releaseSemaphore)
if (releaseSemaphoreAndSendBuffer)
{
_sendFrameAsyncLock.Release();
ReleaseSendBuffer();
}
}
@ -410,6 +436,7 @@ namespace System.Net.WebSockets
{
var thisRef = (ManagedWebSocket)s;
thisRef._sendFrameAsyncLock.Release();
thisRef.ReleaseSendBuffer();
try { t.GetAwaiter().GetResult(); }
catch (Exception exc)
@ -441,14 +468,15 @@ namespace System.Net.WebSockets
finally
{
_sendFrameAsyncLock.Release();
ReleaseSendBuffer();
}
}
/// <summary>Writes a frame into the send buffer, which can then be sent over the network.</summary>
private int WriteFrameToSendBuffer(MessageOpcode opcode, bool endOfMessage, ArraySegment<byte> payloadBuffer)
{
// Grow our send buffer as needed. We reuse the buffer for all messages, with it protected by the send frame lock.
EnsureBufferLength(ref _sendBuffer, payloadBuffer.Count + MaxMessageHeaderLength);
// Ensure we have a _sendBuffer.
AllocateSendBuffer(payloadBuffer.Count + MaxMessageHeaderLength);
// Write the message header data to the buffer.
int headerLength;
@ -542,7 +570,7 @@ namespace System.Net.WebSockets
{
sendBuffer[1] = 126;
sendBuffer[2] = (byte)(payload.Count / 256);
sendBuffer[3] = (byte)payload.Count;
sendBuffer[3] = unchecked((byte)payload.Count);
maskOffset = 2 + sizeof(ushort); // additional 2 bytes for 16-bit length
}
else
@ -551,7 +579,7 @@ namespace System.Net.WebSockets
int length = payload.Count;
for (int i = 9; i >= 2; i--)
{
sendBuffer[i] = (byte)length;
sendBuffer[i] = unchecked((byte)length);
length = length / 256;
}
maskOffset = 2 + sizeof(ulong); // additional 8 bytes for 64-bit length
@ -983,37 +1011,44 @@ namespace System.Net.WebSockets
$"Unexpected state {State}.");
// Wait until we've received a close response
byte[] closeBuffer = new byte[MaxMessageHeaderLength + MaxControlPayloadLength];
while (!_receivedCloseFrame)
byte[] closeBuffer = ArrayPool<byte>.Shared.Rent(MaxMessageHeaderLength + MaxControlPayloadLength);
try
{
Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
Task<WebSocketReceiveResult> receiveTask;
lock (ReceiveAsyncLock)
while (!_receivedCloseFrame)
{
// Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame.
// It could have been received between our check above and now due to a concurrent receive completing.
if (_receivedCloseFrame)
Debug.Assert(!Monitor.IsEntered(StateUpdateLock), $"{nameof(StateUpdateLock)} must never be held when acquiring {nameof(ReceiveAsyncLock)}");
Task<WebSocketReceiveResult> receiveTask;
lock (ReceiveAsyncLock)
{
break;
// Now that we're holding the ReceiveAsyncLock, double-check that we've not yet received the close frame.
// It could have been received between our check above and now due to a concurrent receive completing.
if (_receivedCloseFrame)
{
break;
}
// We've not yet processed a received close frame, which means we need to wait for a received close to complete.
// There may already be one in flight, in which case we want to just wait for that one rather than kicking off
// another (we don't support concurrent receive operations). We need to kick off a new receive if either we've
// never issued a receive or if the last issued receive completed for reasons other than a close frame. There is
// a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst
// case is we then await it, find that it's not what we need, and try again.
receiveTask = _lastReceiveAsync;
if (receiveTask == null ||
(receiveTask.Status == TaskStatus.RanToCompletion && receiveTask.Result.MessageType != WebSocketMessageType.Close))
{
_lastReceiveAsync = receiveTask = ReceiveAsyncPrivate(new ArraySegment<byte>(closeBuffer), cancellationToken);
}
}
// We've not yet processed a received close frame, which means we need to wait for a received close to complete.
// There may already be one in flight, in which case we want to just wait for that one rather than kicking off
// another (we don't support concurrent receive operations). We need to kick off a new receive if either we've
// never issued a receive or if the last issued receive completed for reasons other than a close frame. There is
// a race condition here, e.g. if there's a in-flight receive that completes after we check, but that's fine: worst
// case is we then await it, find that it's not what we need, and try again.
receiveTask = _lastReceiveAsync;
if (receiveTask == null ||
(receiveTask.Status == TaskStatus.RanToCompletion && receiveTask.Result.MessageType != WebSocketMessageType.Close))
{
_lastReceiveAsync = receiveTask = ReceiveAsyncPrivate(new ArraySegment<byte>(closeBuffer), cancellationToken);
}
// Wait for whatever receive task we have. We'll then loop around again to re-check our state.
Debug.Assert(receiveTask != null);
await receiveTask.ConfigureAwait(false);
}
// Wait for whatever receive task we have. We'll then loop around again to re-check our state.
Debug.Assert(receiveTask != null);
await receiveTask.ConfigureAwait(false);
}
finally
{
ArrayPool<byte>.Shared.Return(closeBuffer);
}
// We're closed. Close the connection and update the status.
@ -1035,24 +1070,36 @@ namespace System.Net.WebSockets
{
// Close payload is two bytes containing the close status followed by a UTF8-encoding of the status description, if it exists.
byte[] buffer;
if (string.IsNullOrEmpty(closeStatusDescription))
byte[] buffer = null;
try
{
buffer = new byte[2];
int count = 2;
if (string.IsNullOrEmpty(closeStatusDescription))
{
buffer = ArrayPool<byte>.Shared.Rent(count);
}
else
{
count += s_textEncoding.GetByteCount(closeStatusDescription);
buffer = ArrayPool<byte>.Shared.Rent(count);
int encodedLength = s_textEncoding.GetBytes(closeStatusDescription, 0, closeStatusDescription.Length, buffer, 2);
Debug.Assert(count - 2 == encodedLength, $"GetByteCount and GetBytes encoded count didn't match");
}
ushort closeStatusValue = (ushort)closeStatus;
buffer[0] = (byte)(closeStatusValue >> 8);
buffer[1] = (byte)(closeStatusValue & 0xFF);
await SendFrameAsync(MessageOpcode.Close, true, new ArraySegment<byte>(buffer, 0, count), cancellationToken).ConfigureAwait(false);
}
else
finally
{
buffer = new byte[2 + s_textEncoding.GetByteCount(closeStatusDescription)];
int encodedLength = s_textEncoding.GetBytes(closeStatusDescription, 0, closeStatusDescription.Length, buffer, 2);
Debug.Assert(buffer.Length - 2 == encodedLength, $"GetByteCount and GetBytes encoded count didn't match");
if (buffer != null)
{
ArrayPool<byte>.Shared.Return(buffer);
}
}
ushort closeStatusValue = (ushort)closeStatus;
buffer[0] = (byte)(closeStatusValue >> 8);
buffer[1] = (byte)(closeStatusValue & 0xFF);
await SendFrameAsync(MessageOpcode.Close, true, new ArraySegment<byte>(buffer), cancellationToken).ConfigureAwait(false);
lock (StateUpdateLock)
{
_sentCloseFrame = true;
@ -1111,15 +1158,21 @@ namespace System.Net.WebSockets
}
}
/// <summary>
/// Grows the specified buffer if it's not at least the specified minimum length.
/// Data is not copied if the buffer is grown.
/// </summary>
private static void EnsureBufferLength(ref byte[] buffer, int minLength)
/// <summary>Gets a send buffer from the pool.</summary>
private void AllocateSendBuffer(int minLength)
{
if (buffer == null || buffer.Length < minLength)
Debug.Assert(_sendBuffer == null); // would only fail if had some catastrophic error previously that prevented cleaning up
_sendBuffer = ArrayPool<byte>.Shared.Rent(minLength);
}
/// <summary>Releases the send buffer to the pool.</summary>
private void ReleaseSendBuffer()
{
byte[] old = _sendBuffer;
if (old != null)
{
buffer = new byte[minLength];
_sendBuffer = null;
ArrayPool<byte>.Shared.Return(old);
}
}
@ -1150,21 +1203,68 @@ namespace System.Net.WebSockets
/// <returns>The next index into the mask to be used for future applications of the mask.</returns>
private static unsafe int ApplyMask(byte[] toMask, int toMaskOffset, int mask, int maskIndex, long count)
{
Debug.Assert(toMaskOffset <= toMask.Length - count, $"Unexpected inputs: {toMaskOffset}, {toMask.Length}, {count}");
Debug.Assert(maskIndex < sizeof(int), $"Unexpected {nameof(maskIndex)}: {maskIndex}");
int maskShift = maskIndex * 8;
int shiftedMask = (int)(((uint)mask >> maskShift) | ((uint)mask << (32 - maskShift)));
byte* maskPtr = (byte*)&mask;
fixed (byte* toMaskPtr = toMask)
// Try to use SIMD. We can if the number of bytes we're trying to mask is at least as much
// as the width of a vector and if the width is an even multiple of the mask.
if (Vector.IsHardwareAccelerated &&
Vector<byte>.Count % sizeof(int) == 0 &&
count >= Vector<byte>.Count)
{
byte* p = toMaskPtr + toMaskOffset;
byte* end = p + count;
while (p < end)
// Mask bytes a vector at a time.
Vector<byte> maskVector = Vector.AsVectorByte(new Vector<int>(shiftedMask));
while (count >= Vector<byte>.Count)
{
*p++ ^= maskPtr[maskIndex];
maskIndex = (maskIndex + 1) & 3; // & 3 == faster % MaskLength
count -= Vector<byte>.Count;
(maskVector ^ new Vector<byte>(toMask, toMaskOffset)).CopyTo(toMask, toMaskOffset);
toMaskOffset += Vector<byte>.Count;
}
return maskIndex;
// Fall through to processing any remaining bytes that were less than a vector width.
// Since we processed full masks at a time, we don't need to update maskIndex, and
// toMaskOffset has already been updated to point to the correct location.
}
// If there are any bytes remaining (either we couldn't use vectors, or the count wasn't
// an even multiple of the vector width), process them without vectors.
if (count > 0)
{
fixed (byte* toMaskPtr = toMask)
{
// Get the location in the target array to continue processing.
byte* p = toMaskPtr + toMaskOffset;
// Try to go an int at a time if the remaining data is 4-byte aligned and there's enough remaining.
if (((long)p % sizeof(int)) == 0)
{
while (count >= sizeof(int))
{
count -= sizeof(int);
*((int*)p) ^= shiftedMask;
p += sizeof(int);
}
// We don't need to update the maskIndex, as its mod-4 value won't have changed.
// `p` points to the remainder.
}
// Process any remaining data a byte at a time.
if (count > 0)
{
byte* maskPtr = (byte*)&mask;
byte* end = p + count;
while (p < end)
{
*p++ ^= maskPtr[maskIndex];
maskIndex = (maskIndex + 1) & 3;
}
}
}
}
// Return the updated index.
return maskIndex;
}
/// <summary>Aborts the websocket and throws an exception if an existing operation is in progress.</summary>

View File

@ -7,7 +7,7 @@ using System.Text;
namespace System.Net.WebSockets
{
internal static class WebSocketValidate
internal static partial class WebSocketValidate
{
internal const int MaxControlFramePayloadLength = 123;
private const int CloseStatusCodeAbort = 1006;
@ -16,6 +16,34 @@ namespace System.Net.WebSockets
private const int InvalidCloseStatusCodesTo = 999;
private const string Separators = "()<>@,;:\\\"/[]?={} ";
internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
{
string validStatesText = string.Empty;
if (validStates != null && validStates.Length > 0)
{
foreach (WebSocketState validState in validStates)
{
if (currentState == validState)
{
// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
if (isDisposed)
{
throw new ObjectDisposedException("ClientWebSocket");
}
return;
}
}
validStatesText = string.Join(", ", validStates);
}
throw new WebSocketException(
WebSocketError.InvalidState,
SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText));
}
internal static void ValidateSubprotocol(string subProtocol)
{
if (string.IsNullOrWhiteSpace(subProtocol))
@ -101,32 +129,22 @@ namespace System.Net.WebSockets
}
}
internal static void ThrowIfInvalidState(WebSocketState currentState, bool isDisposed, WebSocketState[] validStates)
internal static void ValidateBuffer(byte[] buffer, int offset, int count)
{
string validStatesText = string.Empty;
if (validStates != null && validStates.Length > 0)
if (buffer == null)
{
foreach (WebSocketState validState in validStates)
{
if (currentState == validState)
{
// Ordering is important to maintain .NET 4.5 WebSocket implementation exception behavior.
if (isDisposed)
{
throw new ObjectDisposedException("ClientWebSocket");
}
return;
}
}
validStatesText = string.Join(", ", validStates);
throw new ArgumentNullException(nameof(buffer));
}
throw new WebSocketException(
WebSocketError.InvalidState,
SR.Format(SR.net_WebSockets_InvalidState, currentState, validStatesText));
if (offset < 0 || offset > buffer.Length)
{
throw new ArgumentOutOfRangeException(nameof(offset));
}
if (count < 0 || count > (buffer.Length - offset))
{
throw new ArgumentOutOfRangeException(nameof(count));
}
}
}
}

View File

@ -14,6 +14,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.AspNetCore.Http.Extensions" Version="1.2.0-*" />
<PackageReference Include="Microsoft.Extensions.Options" Version="1.2.0-*" />
<PackageReference Include="System.Numerics.Vectors" Version="$(CoreFxVersion)" />
</ItemGroup>
</Project>