diff --git a/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs b/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs index e221a64cf2..ba76cc74d0 100644 --- a/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs +++ b/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs @@ -221,12 +221,7 @@ namespace Microsoft.AspNet.WebSockets.Protocol { if (!_firstDataOpCode.HasValue) { - if (State == WebSocketState.Open) - { - await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Invalid continuation frame", cancellationToken); - Abort(); - } - throw new InvalidOperationException("A continuation can't be the first frame"); // TODO: WebSocketException + await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid continuation frame", cancellationToken); } opCode = _firstDataOpCode.Value; } @@ -261,12 +256,7 @@ namespace Microsoft.AspNet.WebSockets.Protocol if (messageType == WebSocketMessageType.Text && !Utilities.TryValidateUtf8(new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy), _frameInProgress.Fin, _incomingUtf8MessageState)) { - if (State == WebSocketState.Open) - { - await CloseOutputAsync(WebSocketCloseStatus.InvalidPayloadData, "Invalid UTF-8", cancellationToken); - Abort(); - } - throw new InvalidOperationException("An invalid UTF-8 payload was received."); // TODO: WebSocketException + await SendErrorAbortAndThrow(WebSocketCloseStatus.InvalidPayloadData, "Invalid UTF-8", cancellationToken); } if (bytesToCopy == _frameBytesRemaining) @@ -303,47 +293,41 @@ namespace Microsoft.AspNet.WebSockets.Protocol if (_frameInProgress.AreReservedSet()) { - if (State == WebSocketState.Open) - { - await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Unexpected reserved bits set", cancellationToken); - Abort(); - } - throw new InvalidOperationException("Unexpected reserved bits are set."); // TODO: WebSocketException + await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Unexpected reserved bits set", cancellationToken); } if (_unmaskInput != _frameInProgress.Masked) { - if (State == WebSocketState.Open) - { - await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Incorrect masking", cancellationToken); - Abort(); - } - throw new InvalidOperationException("Unmasking settings out of sync with data."); // TODO: WebSocketException + await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Incorrect masking", cancellationToken); } - if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) + if (_frameInProgress.IsControlFrame) { if (_frameBytesRemaining > 125) { - if (State == WebSocketState.Open) - { - await CloseOutputAsync(WebSocketCloseStatus.ProtocolError, "Invalid control frame size", cancellationToken); - Abort(); - } - throw new InvalidOperationException("Control frame too large."); // TODO: WebSocketException + await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid control frame size", cancellationToken); } - // Drain it, should be less than 125 bytes - await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, cancellationToken); - if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame) + if (!_frameInProgress.Fin) { - await SendPongReplyAsync(cancellationToken); + await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Fragmented control frame", cancellationToken); } - _receiveBufferOffset += (int)_frameBytesRemaining; - _receiveBufferBytes -= (int)_frameBytesRemaining; - _frameBytesRemaining = 0; - _frameInProgress = null; + if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) + { + // Drain it, should be less than 125 bytes + await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, cancellationToken); + + if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame) + { + await SendPongReplyAsync(cancellationToken); + } + + _receiveBufferOffset += (int)_frameBytesRemaining; + _receiveBufferBytes -= (int)_frameBytesRemaining; + _frameBytesRemaining = 0; + _frameInProgress = null; + } } } @@ -587,5 +571,15 @@ namespace Microsoft.AspNet.WebSockets.Protocol throw new ArgumentOutOfRangeException("buffer.Count", buffer.Count, string.Empty); } } + + private async Task SendErrorAbortAndThrow(WebSocketCloseStatus error, string message, CancellationToken cancellationToken) + { + if (State == WebSocketState.Open) + { + await CloseOutputAsync(error, message, cancellationToken); + } + Abort(); + throw new InvalidOperationException(message); // TODO: WebSocketException + } } } diff --git a/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs b/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs index 7c1cb0ab51..5d9d693763 100644 --- a/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs +++ b/src/Microsoft.AspNet.WebSockets.Protocol/FrameHeader.cs @@ -209,6 +209,14 @@ namespace Microsoft.AspNet.WebSockets.Protocol } } + public bool IsControlFrame + { + get + { + return OpCode >= Constants.OpCodes.CloseFrame; + } + } + // bits 1-3. internal bool AreReservedSet() {