diff --git a/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs b/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs index 32b39389bd..e221a64cf2 100644 --- a/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs +++ b/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs @@ -221,7 +221,12 @@ namespace Microsoft.AspNet.WebSockets.Protocol { if (!_firstDataOpCode.HasValue) { - throw new InvalidOperationException("A continuation can't be the first frame"); + if (State == WebSocketState.Open) + { + await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Invalid continuation frame", cancellationToken); + Abort(); + } + throw new InvalidOperationException("A continuation can't be the first frame"); // TODO: WebSocketException } opCode = _firstDataOpCode.Value; } @@ -256,8 +261,12 @@ namespace Microsoft.AspNet.WebSockets.Protocol if (messageType == WebSocketMessageType.Text && !Utilities.TryValidateUtf8(new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy), _frameInProgress.Fin, _incomingUtf8MessageState)) { - await CloseOutputAsync(WebSocketCloseStatus.InvalidPayloadData, string.Empty, cancellationToken); - throw new InvalidOperationException("An invalid UTF-8 payload was received."); + if (State == WebSocketState.Open) + { + await CloseOutputAsync(WebSocketCloseStatus.InvalidPayloadData, "Invalid UTF-8", cancellationToken); + Abort(); + } + throw new InvalidOperationException("An invalid UTF-8 payload was received."); // TODO: WebSocketException } if (bytesToCopy == _frameBytesRemaining) @@ -292,9 +301,24 @@ namespace Microsoft.AspNet.WebSockets.Protocol _receiveBufferBytes -= frameHeaderSize; _frameBytesRemaining = _frameInProgress.DataLength; + if (_frameInProgress.AreReservedSet()) + { + if (State == WebSocketState.Open) + { + await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Unexpected reserved bits set", cancellationToken); + Abort(); + } + throw new InvalidOperationException("Unexpected reserved bits are set."); // TODO: WebSocketException + } + if (_unmaskInput != _frameInProgress.Masked) { - throw new InvalidOperationException("Unmasking settings out of sync with data."); + if (State == WebSocketState.Open) + { + await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Incorrect masking", cancellationToken); + Abort(); + } + throw new InvalidOperationException("Unmasking settings out of sync with data."); // TODO: WebSocketException } if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) diff --git a/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs b/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs index 1307c69fd7..7c1cb0ab51 100644 --- a/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs +++ b/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs @@ -209,6 +209,12 @@ namespace Microsoft.AspNet.WebSockets.Protocol } } + // bits 1-3. + internal bool AreReservedSet() + { + return (_header[0] & 0x70) != 0; + } + // Given the second bytes of a frame, calculate how long the whole frame header should be. // Range 2-12 bytes public static int CalculateFrameHeaderSize(byte b2)