From 755ba7613ecb87497bb01c36b1c3654778e59def Mon Sep 17 00:00:00 2001 From: Andrew Stanton-Nurse Date: Tue, 21 Feb 2017 15:27:52 -0800 Subject: [PATCH] Fix #215 and restore tests (#218) * fix #215 by properly handling pipe closure * pr feedback * pr feedback --- .../Transports/WebSocketsTransport.cs | 4 +- .../IWebSocketConnection.cs | 5 + ...soft.Extensions.WebSockets.Internal.csproj | 3 +- .../PipeReaderExtensions.cs | 22 +- .../WebSocketConnection.cs | 492 ++++++++++-------- .../Microsoft.AspNetCore.Sockets.Tests.csproj | 1 + .../WebSocketsTests.cs | 7 +- ...cketConnectionTests.ConnectionLifecycle.cs | 2 +- ...WebSocketConnectionTests.ProtocolErrors.cs | 3 +- .../WebSocketConnectionTests.ReceiveAsync.cs | 3 +- .../WebSocketConnectionTests.SendAsync.cs | 15 +- .../WebSocketPair.cs | 4 +- 12 files changed, 304 insertions(+), 257 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs index f73205b8d0..9382612e37 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -75,7 +75,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports { if (receiving.IsCanceled || receiving.IsFaulted) { - // The receiver faulted or cancelled. This means the client is probably broken. Just propagate the exception and exit + // The receiver faulted or cancelled. This means the socket is probably broken. Abort the socket and propagate the exception receiving.GetAwaiter().GetResult(); // Should never get here because GetResult above will throw diff --git a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs index 9f74bc95c4..88a88c07b7 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/IWebSocketConnection.cs @@ -64,6 +64,11 @@ namespace Microsoft.Extensions.WebSockets.Internal /// A state parameter that will be passed to each invocation of /// A that will complete when the client has sent a close frame, or the connection has been terminated Task ExecuteAsync(Func messageHandler, object state); + + /// + /// Forcibly terminates the socket, cleaning up the necessary resources. + /// + void Abort(); } public static class WebSocketConnectionExtensions diff --git a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj index 45c314cb42..5d3aca9e14 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj +++ b/src/Microsoft.Extensions.WebSockets.Internal/Microsoft.Extensions.WebSockets.Internal.csproj @@ -13,9 +13,10 @@ + - + diff --git a/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs b/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs index 758920ab79..d3d54f68c3 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/PipeReaderExtensions.cs @@ -1,31 +1,29 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.Threading; -using System.Threading.Tasks; using System.IO.Pipelines; +using System.Threading.Tasks; namespace Microsoft.Extensions.WebSockets.Internal { public static class PipeReaderExtensions { - public static ValueTask ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes) => ReadAtLeastAsync(input, minimumRequiredBytes, CancellationToken.None); - // TODO: Pull this up to Channels. We should be able to do it there without allocating a Task in any case (rather than here where we can avoid allocation // only if the buffer is already ready and has enough data) - public static ValueTask ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes, CancellationToken cancellationToken) + public static async ValueTask ReadAtLeastAsync(this IPipeReader input, int minimumRequiredBytes) { var awaiter = input.ReadAsync(/* cancellationToken */); // Short-cut path! + ReadResult result; if (awaiter.IsCompleted) { // We have a buffer, is it big enough? - var result = awaiter.GetResult(); + result = awaiter.GetResult(); if (result.IsCompleted || result.Buffer.Length >= minimumRequiredBytes) { - return new ValueTask(result); + return result; } // Buffer wasn't big enough, mark it as examined and continue to the "slow" path below @@ -33,15 +31,9 @@ namespace Microsoft.Extensions.WebSockets.Internal consumed: result.Buffer.Start, examined: result.Buffer.End); } - return new ValueTask(ReadAtLeastSlowAsync(awaiter, input, minimumRequiredBytes, cancellationToken)); - } - - private static async Task ReadAtLeastSlowAsync(ReadableBufferAwaitable awaitable, IPipeReader input, int minimumRequiredBytes, CancellationToken cancellationToken) - { - var result = await awaitable; - while (!result.IsCompleted && result.Buffer.Length < minimumRequiredBytes) + result = await awaiter; + while (!result.IsCancelled && !result.IsCompleted && result.Buffer.Length < minimumRequiredBytes) { - cancellationToken.ThrowIfCancellationRequested(); input.Advance( consumed: result.Buffer.Start, examined: result.Buffer.End); diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs index d6dcca1bf6..b04673642f 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs @@ -31,7 +31,6 @@ namespace Microsoft.Extensions.WebSockets.Internal private readonly byte[] _maskingKeyBuffer; private readonly IPipeReader _inbound; private readonly IPipeWriter _outbound; - private readonly CancellationTokenSource _terminateReceiveCts = new CancellationTokenSource(); private readonly Timer _pinger; private readonly CancellationTokenSource _timerCts = new CancellationTokenSource(); private Utf8Validator _validator = new Utf8Validator(); @@ -113,8 +112,7 @@ namespace Microsoft.Extensions.WebSockets.Internal { // We don't need to wait for this task to complete, we're "tail calling" and // we are in a Timer thread-pool thread. -#pragma warning disable 4014 - connection.SendCoreLockAcquiredAsync( + var ignore = connection.SendCoreLockAcquiredAsync( fin: true, opcode: WebSocketOpcode.Ping, payloadAllocLength: 28, @@ -122,7 +120,6 @@ namespace Microsoft.Extensions.WebSockets.Internal payloadWriter: PingPayloadWriter, payload: DateTime.UtcNow, cancellationToken: connection._timerCts.Token); -#pragma warning restore 4014 } } @@ -131,7 +128,6 @@ namespace Microsoft.Extensions.WebSockets.Internal State = WebSocketConnectionState.Closed; _pinger?.Dispose(); _timerCts.Cancel(); - _terminateReceiveCts.Cancel(); _inbound.Complete(); _outbound.Complete(); } @@ -148,7 +144,7 @@ namespace Microsoft.Extensions.WebSockets.Internal throw new InvalidOperationException("Connection is already running."); } State = WebSocketConnectionState.Connected; - return ReceiveLoop(messageHandler, state, _terminateReceiveCts.Token); + return ReceiveLoop(messageHandler, state); } /// @@ -249,269 +245,322 @@ namespace Microsoft.Extensions.WebSockets.Internal _maskingKeyBuffer.CopyTo(buffer); } - private async Task ReceiveLoop(Func messageHandler, object state, CancellationToken cancellationToken) + /// + /// Terminates the socket abruptly. + /// + public void Abort() { - while (!cancellationToken.IsCancellationRequested) + // We duplicate some work from Dispose here, but that's OK. + _timerCts.Cancel(); + _inbound.CancelPendingRead(); + _outbound.Complete(); + } + + private async ValueTask<(bool Success, byte OpcodeByte, bool Masked, bool Fin, int Length, uint MaskingKey)> ReadHeaderAsync() + { + // Read at least 2 bytes + var readResult = await _inbound.ReadAtLeastAsync(2); + if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < 2)) { - // WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2): - // 0 1 2 3 - // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - // +-+-+-+-+-------+-+-------------+-------------------------------+ - // |F|R|R|R| opcode|M| Payload len | Extended payload length | - // |I|S|S|S| (4) |A| (7) | (16/64) | - // |N|V|V|V| |S| | (if payload len==126/127) | - // | |1|2|3| |K| | | - // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - // | Extended payload length continued, if payload len == 127 | - // + - - - - - - - - - - - - - - - +-------------------------------+ - // | |Masking-key, if MASK set to 1 | - // +-------------------------------+-------------------------------+ - // | Masking-key (continued) | Payload Data | - // +-------------------------------- - - - - - - - - - - - - - - - + - // : Payload Data continued ... : - // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - // | Payload Data continued ... | - // +---------------------------------------------------------------+ + _inbound.Advance(readResult.Buffer.End); + return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0); + } + var buffer = readResult.Buffer; - // Read at least 2 bytes - var result = await _inbound.ReadAtLeastAsync(2, cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); - if (result.IsCompleted && result.Buffer.Length < 2) + // Read the opcode and length + var opcodeByte = buffer.ReadBigEndian(); + buffer = buffer.Slice(1); + + // Read the first byte of the payload length + var lengthByte = buffer.ReadBigEndian(); + buffer = buffer.Slice(1); + + _inbound.Advance(buffer.Start); + + // Determine how much header there still is to read + var fin = (opcodeByte & 0x80) != 0; + var masked = (lengthByte & 0x80) != 0; + var length = lengthByte & 0x7F; + + // Calculate the rest of the header length + var headerLength = masked ? 4 : 0; + if (length == 126) + { + headerLength += 2; + } + else if (length == 127) + { + headerLength += 8; + } + + // Read the next set of header data + uint maskingKey = 0; + if (headerLength > 0) + { + readResult = await _inbound.ReadAtLeastAsync(headerLength); + if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < headerLength)) { - return WebSocketCloseResult.AbnormalClosure; + _inbound.Advance(readResult.Buffer.End); + return (Success: false, OpcodeByte: 0, Masked: false, Fin: false, Length: 0, MaskingKey: 0); } - var buffer = result.Buffer; + buffer = readResult.Buffer; - // Read the opcode - var opcodeByte = buffer.ReadBigEndian(); - buffer = buffer.Slice(1); - - var fin = (opcodeByte & 0x80) != 0; - var opcodeNum = opcodeByte & 0x0F; - var opcode = (WebSocketOpcode)opcodeNum; - - if ((opcodeByte & 0x70) != 0) + // Read extended payload length (if any) + if (length == 126) { - // Reserved bits set, this frame is invalid, close our side and terminate immediately - return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), "Reserved bits, which are required to be zero, were set."); + length = buffer.ReadBigEndian(); + buffer = buffer.Slice(sizeof(ushort)); } - else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F)) + else if (length == 127) { - // Reserved opcode - return await CloseFromProtocolError(cancellationToken, 0, default(ReadableBuffer), $"Received frame using reserved opcode: 0x{opcodeNum:X}"); + var longLen = buffer.ReadBigEndian(); + buffer = buffer.Slice(sizeof(ulong)); + if (longLen > int.MaxValue) + { + throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes"); + } + length = (int)longLen; } - // Read the first byte of the payload length - var lenByte = buffer.ReadBigEndian(); - buffer = buffer.Slice(1); + // Read masking key + if (masked) + { + var maskingKeyStart = buffer.Start; + maskingKey = buffer.Slice(0, sizeof(uint)).ReadBigEndian(); + buffer = buffer.Slice(sizeof(uint)); + } - var masked = (lenByte & 0x80) != 0; - var payloadLen = (lenByte & 0x7F); - - // Mark what we've got so far as consumed + // Mark the length and masking key consumed _inbound.Advance(buffer.Start); + } - // Calculate the rest of the header length - var headerLength = masked ? 4 : 0; - if (payloadLen == 126) + return (Success: true, opcodeByte, masked, fin, length, maskingKey); + } + + private async ValueTask<(bool Success, ReadableBuffer Buffer)> ReadPayloadAsync(int length, bool masked, uint maskingKey) + { + var payload = default(ReadableBuffer); + if (length > 0) + { + var readResult = await _inbound.ReadAtLeastAsync(length); + if (readResult.IsCancelled || (readResult.IsCompleted && readResult.Buffer.Length < length)) { - headerLength += 2; + return (Success: false, Buffer: readResult.Buffer); } - else if (payloadLen == 127) + var buffer = readResult.Buffer; + + payload = buffer.Slice(0, length); + + if (masked) { - headerLength += 8; + // Unmask + MaskingUtilities.ApplyMask(ref payload, maskingKey); } + } + return (Success: true, Buffer: payload); + } - uint maskingKey = 0; - - if (headerLength > 0) + private async Task ReceiveLoop(Func messageHandler, object state) + { + try + { + while (true) { - result = await _inbound.ReadAtLeastAsync(headerLength, cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); - if (result.IsCompleted && result.Buffer.Length < headerLength) - { - return WebSocketCloseResult.AbnormalClosure; - } - buffer = result.Buffer; + // WebSocket Frame layout (https://tools.ietf.org/html/rfc6455#section-5.2): + // 0 1 2 3 + // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + // +-+-+-+-+-------+-+-------------+-------------------------------+ + // |F|R|R|R| opcode|M| Payload len | Extended payload length | + // |I|S|S|S| (4) |A| (7) | (16/64) | + // |N|V|V|V| |S| | (if payload len==126/127) | + // | |1|2|3| |K| | | + // +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + // | Extended payload length continued, if payload len == 127 | + // + - - - - - - - - - - - - - - - +-------------------------------+ + // | |Masking-key, if MASK set to 1 | + // +-------------------------------+-------------------------------+ + // | Masking-key (continued) | Payload Data | + // +-------------------------------- - - - - - - - - - - - - - - - + + // : Payload Data continued ... : + // + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + // | Payload Data continued ... | + // +---------------------------------------------------------------+ - // Read extended payload length (if any) - if (payloadLen == 126) + var header = await ReadHeaderAsync(); + if (!header.Success) { - payloadLen = buffer.ReadBigEndian(); - buffer = buffer.Slice(sizeof(ushort)); + break; } - else if (payloadLen == 127) + + // Validate Opcode + var opcodeNum = header.OpcodeByte & 0x0F; + + if ((header.OpcodeByte & 0x70) != 0) { - var longLen = buffer.ReadBigEndian(); - buffer = buffer.Slice(sizeof(ulong)); - if (longLen > int.MaxValue) + // Reserved bits set, this frame is invalid, close our side and terminate immediately + await CloseFromProtocolError("Reserved bits, which are required to be zero, were set."); + break; + } + else if ((opcodeNum >= 0x03 && opcodeNum <= 0x07) || (opcodeNum >= 0x0B && opcodeNum <= 0x0F)) + { + // Reserved opcode + await CloseFromProtocolError($"Received frame using reserved opcode: 0x{opcodeNum:X}"); + break; + } + var opcode = (WebSocketOpcode)opcodeNum; + + var payload = await ReadPayloadAsync(header.Length, header.Masked, header.MaskingKey); + if (!payload.Success) + { + _inbound.Advance(payload.Buffer.End); + break; + } + + var frame = new WebSocketFrame(header.Fin, opcode, payload.Buffer); + + // Start a try-finally because we may get an exception while closing, if there's an error + // And we need to advance the buffer even if that happens. It wasn't needed above because + // we had already parsed the buffer before we verified it, so we had already advanced the + // buffer, if we encountered an error while closing we didn't have to advance the buffer. + // Side Note: Look at this gloriously aligned comment. You have anurse and brecon to thank for it. Oh wait, I ruined it. + try + { + if (frame.Opcode.IsControl() && !frame.EndOfMessage) { - throw new WebSocketException($"Frame is too large. Maximum frame size is {int.MaxValue} bytes"); + // Control frames cannot be fragmented. + await CloseFromProtocolError("Control frames may not be fragmented"); + break; } - payloadLen = (int)longLen; - } - - // Read masking key - if (masked) - { - var maskingKeyStart = buffer.Start; - maskingKey = buffer.Slice(0, 4).ReadBigEndian(); - buffer = buffer.Slice(4); - } - - // Mark the length and masking key consumed - _inbound.Advance(buffer.Start); - } - - var payload = default(ReadableBuffer); - if (payloadLen > 0) - { - result = await _inbound.ReadAtLeastAsync(payloadLen, cancellationToken); - cancellationToken.ThrowIfCancellationRequested(); - if (result.IsCompleted && result.Buffer.Length < payloadLen) - { - return WebSocketCloseResult.AbnormalClosure; - } - buffer = result.Buffer; - - payload = buffer.Slice(0, payloadLen); - - if (masked) - { - // Unmask - MaskingUtilities.ApplyMask(ref payload, maskingKey); - } - } - - // Run the callback, if we're not cancelled. - cancellationToken.ThrowIfCancellationRequested(); - - var frame = new WebSocketFrame(fin, opcode, payload); - - if (frame.Opcode.IsControl() && !frame.EndOfMessage) - { - // Control frames cannot be fragmented. - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Control frames may not be fragmented"); - } - else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0) - { - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Received non-continuation frame during a fragmented message"); - } - else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation) - { - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Continuation Frame was received when expecting a new message"); - } - - if (frame.Opcode == WebSocketOpcode.Close) - { - // Allowed frame lengths: - // 0 - No body - // 2 - Code with no reason phrase - // >2 - Code and reason phrase (must be valid UTF-8) - if (frame.Payload.Length > 125) - { - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload too long. Maximum size is 125 bytes"); - } - else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(payload.Slice(2)))) - { - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Close frame payload invalid"); - } - - ushort? actualStatusCode; - var closeResult = HandleCloseFrame(payload, frame, out actualStatusCode); - - // Verify the close result - if (actualStatusCode != null) - { - var statusCode = actualStatusCode.Value; - if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000)) + else if (_currentMessageType != WebSocketOpcode.Continuation && opcode.IsMessage() && opcode != 0) { - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, $"Invalid close status: {statusCode}."); + await CloseFromProtocolError("Received non-continuation frame during a fragmented message"); + break; + } + else if (_currentMessageType == WebSocketOpcode.Continuation && frame.Opcode == WebSocketOpcode.Continuation) + { + await CloseFromProtocolError("Continuation Frame was received when expecting a new message"); + break; + } + + if (frame.Opcode == WebSocketOpcode.Close) + { + return await ProcessCloseFrameAsync(frame); + } + else + { + if (frame.Opcode == WebSocketOpcode.Ping) + { + // Check the ping payload length + if (frame.Payload.Length > 125) + { + // Payload too long + await CloseFromProtocolError("Ping frame exceeded maximum size of 125 bytes"); + break; + } + + await SendCoreAsync( + frame.EndOfMessage, + WebSocketOpcode.Pong, + payloadAllocLength: 0, + payloadLength: frame.Payload.Length, + payloadWriter: AppendPayloadWriter, + payload: frame.Payload, + cancellationToken: CancellationToken.None); + } + var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode; + if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage)) + { + // Drop the frame and immediately close with InvalidPayload + await CloseFromProtocolError("An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData); + break; + } + else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong)) + { + await messageHandler(frame, state); + } + } + } + finally + { + if (frame.Payload.Length > 0) + { + _inbound.Advance(frame.Payload.End); } } - // Make the payload as consumed - if (payloadLen > 0) + if (header.Fin) { - _inbound.Advance(payload.End); - } + // Reset the UTF8 validator + _validator.Reset(); - return closeResult; - } - else - { - if (frame.Opcode == WebSocketOpcode.Ping) - { - // Check the ping payload length - if (frame.Payload.Length > 125) + // If it's a non-control frame, reset the message type tracker + if (opcode.IsMessage()) { - // Payload too long - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "Ping frame exceeded maximum size of 125 bytes"); + _currentMessageType = WebSocketOpcode.Continuation; } - - await SendCoreAsync( - frame.EndOfMessage, - WebSocketOpcode.Pong, - payloadAllocLength: 0, - payloadLength: payload.Length, - payloadWriter: AppendPayloadWriter, - payload: payload, - cancellationToken: cancellationToken); } - var effectiveOpcode = opcode == WebSocketOpcode.Continuation ? _currentMessageType : opcode; - if (effectiveOpcode == WebSocketOpcode.Text && !_validator.ValidateUtf8Frame(frame.Payload, frame.EndOfMessage)) + // If there isn't a current message type, and this was a fragmented message frame, set the current message type + else if (!header.Fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage()) { - // Drop the frame and immediately close with InvalidPayload - return await CloseFromProtocolError(cancellationToken, payloadLen, payload, "An invalid Text frame payload was received", statusCode: WebSocketCloseStatus.InvalidPayloadData); - } - else if (_options.PassAllFramesThrough || (frame.Opcode != WebSocketOpcode.Ping && frame.Opcode != WebSocketOpcode.Pong)) - { - await messageHandler(frame, state); + _currentMessageType = opcode; } } - - if (fin) - { - // Reset the UTF8 validator - _validator.Reset(); - - // If it's a non-control frame, reset the message type tracker - if (opcode.IsMessage()) - { - _currentMessageType = WebSocketOpcode.Continuation; - } - } - // If there isn't a current message type, and this was a fragmented message frame, set the current message type - else if (!fin && _currentMessageType == WebSocketOpcode.Continuation && opcode.IsMessage()) - { - _currentMessageType = opcode; - } - - // Mark the payload as consumed - if (payloadLen > 0) - { - _inbound.Advance(payload.End); - } + } + catch + { + // Abort the socket and rethrow + Abort(); + throw; } return WebSocketCloseResult.AbnormalClosure; } - private async Task CloseFromProtocolError(CancellationToken cancellationToken, int payloadLen, ReadableBuffer payload, string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError) + private async ValueTask ProcessCloseFrameAsync(WebSocketFrame frame) { - // Non-continuation non-control message during fragmented message - if (payloadLen > 0) + // Allowed frame lengths: + // 0 - No body + // 2 - Code with no reason phrase + // >2 - Code and reason phrase (must be valid UTF-8) + if (frame.Payload.Length > 125) { - _inbound.Advance(payload.End); + await CloseFromProtocolError("Close frame payload too long. Maximum size is 125 bytes"); + return WebSocketCloseResult.AbnormalClosure; } - var closeResult = new WebSocketCloseResult( - statusCode, - reason); - await CloseAsync(closeResult, cancellationToken); - Dispose(); + else if ((frame.Payload.Length == 1) || (frame.Payload.Length > 2 && !Utf8Validator.ValidateUtf8(frame.Payload.Slice(2)))) + { + await CloseFromProtocolError("Close frame payload invalid"); + return WebSocketCloseResult.AbnormalClosure; + } + + ushort? actualStatusCode; + var closeResult = ParseCloseFrame(frame.Payload, frame, out actualStatusCode); + + // Verify the close result + if (actualStatusCode != null) + { + var statusCode = actualStatusCode.Value; + if (statusCode < 1000 || statusCode == 1004 || statusCode == 1005 || statusCode == 1006 || (statusCode > 1011 && statusCode < 3000)) + { + await CloseFromProtocolError($"Invalid close status: {statusCode}."); + return WebSocketCloseResult.AbnormalClosure; + } + } + return closeResult; } - private WebSocketCloseResult HandleCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode) + private async Task CloseFromProtocolError(string reason, WebSocketCloseStatus statusCode = WebSocketCloseStatus.ProtocolError) + { + var closeResult = new WebSocketCloseResult( + statusCode, + reason); + await CloseAsync(closeResult, CancellationToken.None); + + // We can now terminate our connection, according to the spec. + Abort(); + } + + private WebSocketCloseResult ParseCloseFrame(ReadableBuffer payload, WebSocketFrame frame, out ushort? actualStatusCode) { // Update state if (State == WebSocketConnectionState.CloseSent) @@ -529,6 +578,7 @@ namespace Microsoft.Extensions.WebSockets.Internal { closeResult = WebSocketCloseResult.Empty; } + return closeResult; } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj index 1df0f57c42..06731c2136 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj +++ b/test/Microsoft.AspNetCore.Sockets.Tests/Microsoft.AspNetCore.Sockets.Tests.csproj @@ -12,6 +12,7 @@ + diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index f79d06a48a..e38542f111 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -6,6 +6,7 @@ using System.IO.Pipelines; using System.Text; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Sockets.Transports; using Microsoft.Extensions.Logging; @@ -258,7 +259,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } - [Fact(Skip="Fails after updating to new Pipelines")] + [Fact] public async Task TransportFailsWhenClientDisconnectsAbnormally() { var transportToApplication = Channel.CreateUnbounded(); @@ -286,7 +287,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests } } - [Fact(Skip="Fails after updating to new Pipelines")] + [Fact] public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails() { var transportToApplication = Channel.CreateUnbounded(); @@ -313,6 +314,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests // Close from the client await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + + await transport.OrTimeout(); } } } diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs index 1c344a9239..cad82468ea 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ConnectionLifecycle.cs @@ -65,7 +65,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests } } - [Fact(Skip="Fails after updating to new Pipelines")] + [Fact] public async Task AbnormalTerminationOfInboundChannelCausesExecuteToThrow() { using (var pair = WebSocketPair.Create()) diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs index 67ff16041a..18dc4dd42a 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ProtocolErrors.cs @@ -12,8 +12,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { public partial class WebSocketConnectionTests { - // Skipping tests after failures caused by updating to newer Pipelines - private class ProtocolErrors + public class ProtocolErrors { [Theory] [InlineData(new byte[] { 0x11, 0x00 })] diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs index 4b6bdb6802..5559d40510 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.ReceiveAsync.cs @@ -13,8 +13,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { public partial class WebSocketConnectionTests { - // Skipping tests after failures caused by updating to newer Pipelines - private class TheReceiveAsyncMethod + public class TheReceiveAsyncMethod { [Theory] [InlineData(new byte[] { 0x81, 0x00 }, "", true)] diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs index bc5ba8c255..5ddae6662c 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs @@ -13,8 +13,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { public partial class WebSocketConnectionTests { - // Skipping tests after failures caused by updating to newer Pipelines - private class TheSendAsyncMethod + public class TheSendAsyncMethod { // No auto-pinging for us! private readonly static WebSocketOptions DefaultTestOptions = new WebSocketOptions().WithAllFramesPassedThrough(); @@ -179,20 +178,18 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests var outbound = factory.Create(); var inbound = factory.Create(); - Task executeTask; using (var connection = new WebSocketConnection(inbound.Reader, outbound.Writer, options)) { - executeTask = connection.ExecuteAsync(f => - { - Assert.False(true, "Did not expect to receive any messages"); - return TaskCache.CompletedTask; - }); + var executeTask = connection.ExecuteAndCaptureFramesAsync(); await producer(connection).OrTimeout(); + connection.Abort(); inbound.Writer.Complete(); await executeTask.OrTimeout(); } - var data = (await outbound.Reader.ReadToEndAsync()).ToArray(); + var buffer = await outbound.Reader.ReadToEndAsync(); + var data = buffer.ToArray(); + outbound.Reader.Advance(buffer.End); inbound.Reader.Complete(); CompleteChannels(outbound); return data; diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs index eaa78294c1..1a74a1c7e1 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs @@ -8,8 +8,8 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { internal class WebSocketPair : IDisposable { - private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking(); - private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough(); + private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking(); + private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough(); private PipeFactory _factory; private readonly bool _ownFactory;