diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs
index f73205b8d0..9382612e37 100644
--- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs
+++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs
@@ -1,4 +1,4 @@
-// Copyright (c) .NET Foundation. All rights reserved.
+// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
@@ -75,7 +75,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports
{
if (receiving.IsCanceled || receiving.IsFaulted)
{
- // The receiver faulted or cancelled. This means the client is probably broken. Just propagate the exception and exit
+ // The receiver faulted or cancelled. This means the socket is probably broken. Abort the socket and propagate the exception
receiving.GetAwaiter().GetResult();
// Should never get here because GetResult above will throw
diff --git a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs
index 9f74bc95c4..88a88c07b7 100644
--- a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs
+++ b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs
@@ -64,6 +64,11 @@ namespace Microsoft.Extensions.WebSockets.Internal
/// A state parameter that will be passed to each invocation of
/// A that will complete when the client has sent a close frame, or the connection has been terminated
Task ExecuteAsync(Func messageHandler, object state);
+
+ ///
+ /// Forcibly terminates the socket, cleaning up the necessary resources.
+ ///
+ void Abort();
}
public static class WebSocketConnectionExtensions
diff --git a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj
index 45c314cb42..5d3aca9e14 100644
--- a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj
+++ b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj
@@ -13,9 +13,10 @@
+
-
+
diff --git a/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs b/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs
index 758920ab79..d3d54f68c3 100644
--- a/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs
+++ b/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs
@@ -1,31 +1,29 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
-using System.Threading;
-using System.Threading.Tasks;
using System.IO.Pipelines;
+using System.Threading.Tasks;
namespace Microsoft.Extensions.WebSockets.Internal
{
public static class PipeReaderExtensions
{
- public static ValueTask ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes) => ReadAtLeastAsync(input, minimumRequiredBytes, CancellationToken.None);
-
// TODO: Pull this up to Channels. We should be able to do it there without allocating a Task in any case (rather than here where we can avoid allocation
// only if the buffer is already ready and has enough data)
- public static ValueTask ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes, CancellationToken cancellationToken)
+ public static async ValueTask ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes)
{
var awaiter = input.ReadAsync(/* cancellationToken */);
// Short-cut path!
+ ReadResult result;
if (awaiter.IsCompleted)
{
// We have a buffer, is it big enough?
- var result = awaiter.GetResult();
+ result = awaiter.GetResult();
if (result.IsCompleted || result.Buffer.Length >= minimumRequiredBytes)
{
- return new ValueTask(result);
+ return result;
}
// Buffer wasn't big enough, mark it as examined and continue to the "slow" path below
@@ -33,15 +31,9 @@ namespace Microsoft.Extensions.WebSockets.Internal
consumed: result.Buffer.Start,
examined: result.Buffer.End);
}
- return new ValueTask(ReadAtLeastSlowAsync(awaiter, input, minimumRequiredBytes, cancellationToken));
- }
-
- private static async Task ReadAtLeastSlowAsync(ReadableBufferAwaitable awaitable, IPipeReader input, int minimumRequiredBytes, CancellationToken cancellationToken)
- {
- var result = await awaitable;
- while (!result.IsCompleted && result.Buffer.Length < minimumRequiredBytes)
+ result = await awaiter;
+ while (!result.IsCancelled && !result.IsCompleted && result.Buffer.Length < minimumRequiredBytes)
{
- cancellationToken.ThrowIfCancellationRequested();
input.Advance(
consumed: result.Buffer.Start,
examined: result.Buffer.End);
diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs
index d6dcca1bf6..b04673642f 100644
--- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs
+++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs
@@ -31,7 +31,6 @@ namespace Microsoft.Extensions.WebSockets.Internal
private readonly byte[] _maskingKeyBuffer;
private readonly IPipeReader _inbound;
private readonly IPipeWriter _outbound;
- private readonly CancellationTokenSource _terminateReceiveCts = new CancellationTokenSource();
private readonly Timer _pinger;
private readonly CancellationTokenSource _timerCts = new CancellationTokenSource();
private Utf8Validator _validator = new Utf8Validator();
@@ -113,8 +112,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
{
// We don't need to wait for this task to complete, we're "tail calling" and
// we are in a Timer thread-pool thread.
-#pragma warning disable 4014
- connection.SendCoreLockAcquiredAsync(
+ var ignore = connection.SendCoreLockAcquiredAsync(
fin: true,
opcode: WebSocketOpcode.Ping,
payloadAllocLength: 28,
@@ -122,7 +120,6 @@ namespace Microsoft.Extensions.WebSockets.Internal
payloadWriter: PingPayloadWriter,
payload: DateTime.UtcNow,
cancellationToken: connection._timerCts.Token);
-#pragma warning restore 4014
}
}
@@ -131,7 +128,6 @@ namespace Microsoft.Extensions.WebSockets.Internal
State = WebSocketConnectionState.Closed;
_pinger?.Dispose();
_timerCts.Cancel();
- _terminateReceiveCts.Cancel();
_inbound.Complete();
_outbound.Complete();
}
@@ -148,7 +144,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
throw new InvalidOperationException("Connection is already running.");
}
State = WebSocketConnectionState.Connected;
- return ReceiveLoop(messageHandler, state, _terminateReceiveCts.Token);
+ return ReceiveLoop(messageHandler, state);
}
///
@@ -249,269 +245,322 @@ namespace Microsoft.Extensions.WebSockets.Internal
_maskingKeyBuffer.CopyTo(buffer);
}
- private async Task ReceiveLoop(Func messageHandler, object state, CancellationToken cancellationToken)
+ ///
+ /// Terminates the socket abruptly.
+ ///
+ public void Abort()
{
- while (!cancellationToken.IsCancellationRequested)
+ // We duplicate some work from Dispose here, but that's OK.
+ _timerCts.Cancel();
+ _inbound.CancelPendingRead();
+ _outbound.Complete();
+ }
+
+ private async ValueTask<(bool Success, byte OpcodeByte, bool Masked, bool Fin, int Length, uint MaskingKey)> ReadHeaderAsync()
+ {
+ // Read at least 2 bytes
+ var readResult = await _inbound.ReadAtLeastAsync(2);
+ if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < 2))
{
- // WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2):
- // 0 1 2 3
- // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
- // +-+-+-+-+-------+-+-------------+-------------------------------+
- // |F|R|R|R| opcode|M| Payload len | Extended payload length |
- // |I|S|S|S| (4) |A| (7) | (16/64) |
- // |N|V|V|V| |S| | (if payload len==126/127) |
- // | |1|2|3| |K| | |
- // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
- // | Extended payload length continued, if payload len == 127 |
- // + - - - - - - - - - - - - - - - +-------------------------------+
- // | |Masking-key, if MASK set to 1 |
- // +-------------------------------+-------------------------------+
- // | Masking-key (continued) | Payload Data |
- // +-------------------------------- - - - - - - - - - - - - - - - +
- // : Payload Data continued ... :
- // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
- // | Payload Data continued ... |
- // +---------------------------------------------------------------+
+ _inbound.Advance(readResult.Buffer.End);
+ return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0);
+ }
+ var buffer = readResult.Buffer;
- // Read at least 2 bytes
- var result = await _inbound.ReadAtLeastAsync(2, cancellationToken);
- cancellationToken.ThrowIfCancellationRequested();
- if (result.IsCompleted && result.Buffer.Length < 2)
+ // Read the opcode and length
+ var opcodeByte = buffer.ReadBigEndian();
+ buffer = buffer.Slice(1);
+
+ // Read the first byte of the payload length
+ var lengthByte = buffer.ReadBigEndian();
+ buffer = buffer.Slice(1);
+
+ _inbound.Advance(buffer.Start);
+
+ // Determine how much header there still is to read
+ var fin = (opcodeByte & 0x80) != 0;
+ var masked = (lengthByte & 0x80) != 0;
+ var length = lengthByte & 0x7F;
+
+ // Calculate the rest of the header length
+ var headerLength = masked ? 4 : 0;
+ if (length == 126)
+ {
+ headerLength += 2;
+ }
+ else if (length == 127)
+ {
+ headerLength += 8;
+ }
+
+ // Read the next set of header data
+ uint maskingKey = 0;
+ if (headerLength > 0)
+ {
+ readResult = await _inbound.ReadAtLeastAsync(headerLength);
+ if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < headerLength))
{
- return WebSocketCloseResult.AbnormalClosure;
+ _inbound.Advance(readResult.Buffer.End);
+ return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0);
}
- var buffer = result.Buffer;
+ buffer = readResult.Buffer;
- // Read the opcode
- var opcodeByte = buffer.ReadBigEndian();
- buffer = buffer.Slice(1);
-
- var fin = (opcodeByte & 0x80) != 0;
- var opcodeNum = opcodeByte & 0x0F;
- var opcode = (WebSocketOpcode)opcodeNum;
-
- if ((opcodeByte & 0x70) != 0)
+ // Read extended payload length (if any)
+ if (length == 126)
{
- // Reserved bits set, this frame is invalid, close our side and terminate immediately
- return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), "Reserved bits, which are required to be zero, were set.");
+ length = buffer.ReadBigEndian();
+ buffer = buffer.Slice(sizeof(ushort));
}
- else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F))
+ else if (length == 127)
{
- // Reserved opcode
- return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), $"Received frame using reserved opcode: 0x{opcodeNum:X}");
+ var longLen = buffer.ReadBigEndian();
+ buffer = buffer.Slice(sizeof(ulong));
+ if (longLen > int.MaxValue)
+ {
+ throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes");
+ }
+ length = (int)longLen;
}
- // Read the first byte of the payload length
- var lenByte = buffer.ReadBigEndian();
- buffer = buffer.Slice(1);
+ // Read masking key
+ if (masked)
+ {
+ var maskingKeyStart = buffer.Start;
+ maskingKey = buffer.Slice(0, sizeof(uint)).ReadBigEndian();
+ buffer = buffer.Slice(sizeof(uint));
+ }
- var masked = (lenByte & 0x80) != 0;
- var payloadLen = (lenByte & 0x7F);
-
- // Mark what we've got so far as consumed
+ // Mark the length and masking key consumed
_inbound.Advance(buffer.Start);
+ }
- // Calculate the rest of the header length
- var headerLength = masked ? 4 : 0;
- if (payloadLen == 126)
+ return (Success: true, opcodeByte, masked, fin, length, maskingKey);
+ }
+
+ private async ValueTask<(bool Success, ReadableBuffer Buffer)> ReadPayloadAsync(int length, bool masked, uint maskingKey)
+ {
+ var payload = default(ReadableBuffer);
+ if (length > 0)
+ {
+ var readResult = await _inbound.ReadAtLeastAsync(length);
+ if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < length))
{
- headerLength += 2;
+ return (Success: false, Buffer: readResult.Buffer);
}
- else if (payloadLen == 127)
+ var buffer = readResult.Buffer;
+
+ payload = buffer.Slice(0, length);
+
+ if (masked)
{
- headerLength += 8;
+ // Unmask
+ MaskingUtilities.ApplyMask(ref payload, maskingKey);
}
+ }
+ return (Success: true, Buffer: payload);
+ }
- uint maskingKey = 0;
-
- if (headerLength > 0)
+ private async Task ReceiveLoop(Func messageHandler, object state)
+ {
+ try
+ {
+ while (true)
{
- result = await _inbound.ReadAtLeastAsync(headerLength, cancellationToken);
- cancellationToken.ThrowIfCancellationRequested();
- if (result.IsCompleted && result.Buffer.Length < headerLength)
- {
- return WebSocketCloseResult.AbnormalClosure;
- }
- buffer = result.Buffer;
+ // WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2):
+ // 0 1 2 3
+ // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+ // +-+-+-+-+-------+-+-------------+-------------------------------+
+ // |F|R|R|R| opcode|M| Payload len | Extended payload length |
+ // |I|S|S|S| (4) |A| (7) | (16/64) |
+ // |N|V|V|V| |S| | (if payload len==126/127) |
+ // | |1|2|3| |K| | |
+ // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
+ // | Extended payload length continued, if payload len == 127 |
+ // + - - - - - - - - - - - - - - - +-------------------------------+
+ // | |Masking-key, if MASK set to 1 |
+ // +-------------------------------+-------------------------------+
+ // | Masking-key (continued) | Payload Data |
+ // +-------------------------------- - - - - - - - - - - - - - - - +
+ // : Payload Data continued ... :
+ // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
+ // | Payload Data continued ... |
+ // +---------------------------------------------------------------+
- // Read extended payload length (if any)
- if (payloadLen == 126)
+ var header = await ReadHeaderAsync();
+ if (!header.Success)
{
- payloadLen = buffer.ReadBigEndian();
- buffer = buffer.Slice(sizeof(ushort));
+ break;
}
- else if (payloadLen == 127)
+
+ // Validate Opcode
+ var opcodeNum = header.OpcodeByte & 0x0F;
+
+ if ((header.OpcodeByte & 0x70) != 0)
{
- var longLen = buffer.ReadBigEndian();
- buffer = buffer.Slice(sizeof(ulong));
- if (longLen > int.MaxValue)
+ // Reserved bits set, this frame is invalid, close our side and terminate immediately
+ await CloseFromProtocolError("Reserved bits, which are required to be zero, were set.");
+ break;
+ }
+ else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F))
+ {
+ // Reserved opcode
+ await CloseFromProtocolError($"Received frame using reserved opcode: 0x{opcodeNum:X}");
+ break;
+ }
+ var opcode = (WebSocketOpcode)opcodeNum;
+
+ var payload = await ReadPayloadAsync(header.Length, header.Masked, header.MaskingKey);
+ if (!payload.Success)
+ {
+ _inbound.Advance(payload.Buffer.End);
+ break;
+ }
+
+ var frame = new WebSocketFrame(header.Fin, opcode, payload.Buffer);
+
+ // Start a try-finally because we may get an exception while closing, if there's an error
+ // And we need to advance the buffer even if that happens. It wasn't needed above because
+ // we had already parsed the buffer before we verified it, so we had already advanced the
+ // buffer, if we encountered an error while closing we didn't have to advance the buffer.
+ // Side Note: Look at this gloriously aligned comment. You have anurse and brecon to thank for it. Oh wait, I ruined it.
+ try
+ {
+ if (frame.Opcode.IsControl() && !frame.EndOfMessage)
{
- throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes");
+ // Control frames cannot be fragmented.
+ await CloseFromProtocolError("Control frames may not be fragmented");
+ break;
}
- payloadLen = (int)longLen;
- }
-
- // Read masking key
- if (masked)
- {
- var maskingKeyStart = buffer.Start;
- maskingKey = buffer.Slice(0, 4).ReadBigEndian();
- buffer = buffer.Slice(4);
- }
-
- // Mark the length and masking key consumed
- _inbound.Advance(buffer.Start);
- }
-
- var payload = default(ReadableBuffer);
- if (payloadLen > 0)
- {
- result = await _inbound.ReadAtLeastAsync(payloadLen, cancellationToken);
- cancellationToken.ThrowIfCancellationRequested();
- if (result.IsCompleted && result.Buffer.Length < payloadLen)
- {
- return WebSocketCloseResult.AbnormalClosure;
- }
- buffer = result.Buffer;
-
- payload = buffer.Slice(0, payloadLen);
-
- if (masked)
- {
- // Unmask
- MaskingUtilities.ApplyMask(ref payload, maskingKey);
- }
- }
-
- // Run the callback, if we're not cancelled.
- cancellationToken.ThrowIfCancellationRequested();
-
- var frame = new WebSocketFrame(fin, opcode, payload);
-
- if (frame.Opcode.IsControl() && !frame.EndOfMessage)
- {
- // Control frames cannot be fragmented.
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Control frames may not be fragmented");
- }
- else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0)
- {
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Received non-continuation frame during a fragmented message");
- }
- else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation)
- {
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Continuation Frame was received when expecting a new message");
- }
-
- if (frame.Opcode == WebSocketOpcode.Close)
- {
- // Allowed frame lengths:
- // 0 - No body
- // 2 - Code with no reason phrase
- // >2 - Code and reason phrase (must be valid UTF-8)
- if (frame.Payload.Length > 125)
- {
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload too long. Maximum size is 125 bytes");
- }
- else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(payload.Slice(2))))
- {
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload invalid");
- }
-
- ushort? actualStatusCode;
- var closeResult = HandleCloseFrame(payload, frame, out actualStatusCode);
-
- // Verify the close result
- if (actualStatusCode != null)
- {
- var statusCode = actualStatusCode.Value;
- if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000))
+ else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0)
{
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, $"Invalid close status: {statusCode}.");
+ await CloseFromProtocolError("Received non-continuation frame during a fragmented message");
+ break;
+ }
+ else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation)
+ {
+ await CloseFromProtocolError("Continuation Frame was received when expecting a new message");
+ break;
+ }
+
+ if (frame.Opcode == WebSocketOpcode.Close)
+ {
+ return await ProcessCloseFrameAsync(frame);
+ }
+ else
+ {
+ if (frame.Opcode == WebSocketOpcode.Ping)
+ {
+ // Check the ping payload length
+ if (frame.Payload.Length > 125)
+ {
+ // Payload too long
+ await CloseFromProtocolError("Ping frame exceeded maximum size of 125 bytes");
+ break;
+ }
+
+ await SendCoreAsync(
+ frame.EndOfMessage,
+ WebSocketOpcode.Pong,
+ payloadAllocLength: 0,
+ payloadLength: frame.Payload.Length,
+ payloadWriter: AppendPayloadWriter,
+ payload: frame.Payload,
+ cancellationToken: CancellationToken.None);
+ }
+ var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode;
+ if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage))
+ {
+ // Drop the frame and immediately close with InvalidPayload
+ await CloseFromProtocolError("An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData);
+ break;
+ }
+ else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong))
+ {
+ await messageHandler(frame, state);
+ }
+ }
+ }
+ finally
+ {
+ if (frame.Payload.Length > 0)
+ {
+ _inbound.Advance(frame.Payload.End);
}
}
- // Make the payload as consumed
- if (payloadLen > 0)
+ if (header.Fin)
{
- _inbound.Advance(payload.End);
- }
+ // Reset the UTF8 validator
+ _validator.Reset();
- return closeResult;
- }
- else
- {
- if (frame.Opcode == WebSocketOpcode.Ping)
- {
- // Check the ping payload length
- if (frame.Payload.Length > 125)
+ // If it's a non-control frame, reset the message type tracker
+ if (opcode.IsMessage())
{
- // Payload too long
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Ping frame exceeded maximum size of 125 bytes");
+ _currentMessageType = WebSocketOpcode.Continuation;
}
-
- await SendCoreAsync(
- frame.EndOfMessage,
- WebSocketOpcode.Pong,
- payloadAllocLength: 0,
- payloadLength: payload.Length,
- payloadWriter: AppendPayloadWriter,
- payload: payload,
- cancellationToken: cancellationToken);
}
- var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode;
- if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage))
+ // If there isn't a current message type, and this was a fragmented message frame, set the current message type
+ else if (!header.Fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage())
{
- // Drop the frame and immediately close with InvalidPayload
- return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData);
- }
- else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong))
- {
- await messageHandler(frame, state);
+ _currentMessageType = opcode;
}
}
-
- if (fin)
- {
- // Reset the UTF8 validator
- _validator.Reset();
-
- // If it's a non-control frame, reset the message type tracker
- if (opcode.IsMessage())
- {
- _currentMessageType = WebSocketOpcode.Continuation;
- }
- }
- // If there isn't a current message type, and this was a fragmented message frame, set the current message type
- else if (!fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage())
- {
- _currentMessageType = opcode;
- }
-
- // Mark the payload as consumed
- if (payloadLen > 0)
- {
- _inbound.Advance(payload.End);
- }
+ }
+ catch
+ {
+ // Abort the socket and rethrow
+ Abort();
+ throw;
}
return WebSocketCloseResult.AbnormalClosure;
}
- private async Task CloseFromProtocolError(CancellationToken cancellationToken, int payloadLen, ReadableBuffer payload, string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError)
+ private async ValueTask ProcessCloseFrameAsync(WebSocketFrame frame)
{
- // Non-continuation non-control message during fragmented message
- if (payloadLen > 0)
+ // Allowed frame lengths:
+ // 0 - No body
+ // 2 - Code with no reason phrase
+ // >2 - Code and reason phrase (must be valid UTF-8)
+ if (frame.Payload.Length > 125)
{
- _inbound.Advance(payload.End);
+ await CloseFromProtocolError("Close frame payload too long. Maximum size is 125 bytes");
+ return WebSocketCloseResult.AbnormalClosure;
}
- var closeResult = new WebSocketCloseResult(
- statusCode,
- reason);
- await CloseAsync(closeResult, cancellationToken);
- Dispose();
+ else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(frame.Payload.Slice(2))))
+ {
+ await CloseFromProtocolError("Close frame payload invalid");
+ return WebSocketCloseResult.AbnormalClosure;
+ }
+
+ ushort? actualStatusCode;
+ var closeResult = ParseCloseFrame(frame.Payload, frame, out actualStatusCode);
+
+ // Verify the close result
+ if (actualStatusCode != null)
+ {
+ var statusCode = actualStatusCode.Value;
+ if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000))
+ {
+ await CloseFromProtocolError($"Invalid close status: {statusCode}.");
+ return WebSocketCloseResult.AbnormalClosure;
+ }
+ }
+
return closeResult;
}
- private WebSocketCloseResult HandleCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode)
+ private async Task CloseFromProtocolError(string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError)
+ {
+ var closeResult = new WebSocketCloseResult(
+ statusCode,
+ reason);
+ await CloseAsync(closeResult, CancellationToken.None);
+
+ // We can now terminate our connection, according to the spec.
+ Abort();
+ }
+
+ private WebSocketCloseResult ParseCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode)
{
// Update state
if (State == WebSocketConnectionState.CloseSent)
@@ -529,6 +578,7 @@ namespace Microsoft.Extensions.WebSockets.Internal
{
closeResult = WebSocketCloseResult.Empty;
}
+
return closeResult;
}
diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj
index 1df0f57c42..06731c2136 100644
--- a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj
+++ b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj
@@ -12,6 +12,7 @@
+
diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs
index f79d06a48a..e38542f111 100644
--- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs
+++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs
@@ -6,6 +6,7 @@ using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
+using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Sockets.Transports;
using Microsoft.Extensions.Logging;
@@ -258,7 +259,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
- [Fact(Skip="Fails after updating to new Pipelines")]
+ [Fact]
public async Task TransportFailsWhenClientDisconnectsAbnormally()
{
var transportToApplication = Channel.CreateUnbounded();
@@ -286,7 +287,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests
}
}
- [Fact(Skip="Fails after updating to new Pipelines")]
+ [Fact]
public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails()
{
var transportToApplication = Channel.CreateUnbounded();
@@ -313,6 +314,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests
// Close from the client
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
+
+ await transport.OrTimeout();
}
}
}
diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs
index 1c344a9239..cad82468ea 100644
--- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs
+++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs
@@ -65,7 +65,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
}
}
- [Fact(Skip="Fails after updating to new Pipelines")]
+ [Fact]
public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow()
{
using (var pair = WebSocketPair.Create())
diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs
index 67ff16041a..18dc4dd42a 100644
--- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs
+++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs
@@ -12,8 +12,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
- // Skipping tests after failures caused by updating to newer Pipelines
- private class ProtocolErrors
+ public class ProtocolErrors
{
[Theory]
[InlineData(new byte[] { 0x11, 0x00 })]
diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs
index 4b6bdb6802..5559d40510 100644
--- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs
+++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs
@@ -13,8 +13,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
- // Skipping tests after failures caused by updating to newer Pipelines
- private class TheReceiveAsyncMethod
+ public class TheReceiveAsyncMethod
{
[Theory]
[InlineData(new byte[] { 0x81, 0x00 }, "", true)]
diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs
index bc5ba8c255..5ddae6662c 100644
--- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs
+++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs
@@ -13,8 +13,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
public partial class WebSocketConnectionTests
{
- // Skipping tests after failures caused by updating to newer Pipelines
- private class TheSendAsyncMethod
+ public class TheSendAsyncMethod
{
// No auto-pinging for us!
private readonly static WebSocketOptions DefaultTestOptions = new WebSocketOptions().WithAllFramesPassedThrough();
@@ -179,20 +178,18 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
var outbound = factory.Create();
var inbound = factory.Create();
- Task executeTask;
using (var connection = new WebSocketConnection(inbound.Reader, outbound.Writer, options))
{
- executeTask = connection.ExecuteAsync(f =>
- {
- Assert.False(true, "Did not expect to receive any messages");
- return TaskCache.CompletedTask;
- });
+ var executeTask = connection.ExecuteAndCaptureFramesAsync();
await producer(connection).OrTimeout();
+ connection.Abort();
inbound.Writer.Complete();
await executeTask.OrTimeout();
}
- var data = (await outbound.Reader.ReadToEndAsync()).ToArray();
+ var buffer = await outbound.Reader.ReadToEndAsync();
+ var data = buffer.ToArray();
+ outbound.Reader.Advance(buffer.End);
inbound.Reader.Complete();
CompleteChannels(outbound);
return data;
diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs
index eaa78294c1..1a74a1c7e1 100644
--- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs
+++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs
@@ -8,8 +8,8 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
internal class WebSocketPair : IDisposable
{
- private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking();
- private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough();
+ private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking();
+ private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough();
private PipeFactory _factory;
private readonly bool _ownFactory;