diff --git a/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs b/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs index 481e5a9393..c6d82a573b 100644 --- a/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs +++ b/src/Microsoft.AspNet.WebSockets.Protocol/CommonWebSocket.cs @@ -4,6 +4,7 @@ using System; using System.Diagnostics.Contracts; using System.IO; +using System.Linq; using System.Net.WebSockets; using System.Text; using System.Threading; @@ -302,6 +303,11 @@ namespace Microsoft.AspNet.WebSockets.Protocol await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Incorrect masking", cancellationToken); } + if (!ValidateOpCode(_frameInProgress.OpCode)) + { + await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid opcode: " + _frameInProgress.OpCode, cancellationToken); + } + if (_frameInProgress.IsControlFrame) { if (_frameBytesRemaining > 125) @@ -455,29 +461,6 @@ namespace Microsoft.AspNet.WebSockets.Protocol return result; } - private static bool ValidateCloseStatus(WebSocketCloseStatus closeStatus) - { - if (closeStatus < (WebSocketCloseStatus)1000 || closeStatus >= (WebSocketCloseStatus)5000) - { - return false; - } - else if (closeStatus >= (WebSocketCloseStatus)3000) - { - // 3000-3999 - Reserved for frameworks - // 4000-4999 - Reserved for private usage - return true; - } - int[] validCodes = new[] { 1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011 }; - foreach (var validCode in validCodes) - { - if (closeStatus == (WebSocketCloseStatus)validCode) - { - return true; - } - } - return false; - } - public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -617,6 +600,34 @@ namespace Microsoft.AspNet.WebSockets.Protocol } } + private bool ValidateOpCode(int opCode) + { + return Constants.OpCodes.ValidOpCodes.Contains(opCode); + } + + private static bool ValidateCloseStatus(WebSocketCloseStatus closeStatus) + { + if (closeStatus < (WebSocketCloseStatus)1000 || closeStatus >= (WebSocketCloseStatus)5000) + { + return false; + } + else if (closeStatus >= (WebSocketCloseStatus)3000) + { + // 3000-3999 - Reserved for frameworks + // 4000-4999 - Reserved for private usage + return true; + } + int[] validCodes = new[] { 1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011 }; + foreach (var validCode in validCodes) + { + if (closeStatus == (WebSocketCloseStatus)validCode) + { + return true; + } + } + return false; + } + private async Task SendErrorAbortAndThrow(WebSocketCloseStatus error, string message, CancellationToken cancellationToken) { if (State == WebSocketState.Open || State == WebSocketState.CloseReceived) diff --git a/src/Microsoft.AspNet.WebSockets.Protocol/Constants.cs b/src/Microsoft.AspNet.WebSockets.Protocol/Constants.cs index 07756e49e3..60a9340b67 100644 --- a/src/Microsoft.AspNet.WebSockets.Protocol/Constants.cs +++ b/src/Microsoft.AspNet.WebSockets.Protocol/Constants.cs @@ -25,7 +25,17 @@ namespace Microsoft.AspNet.WebSockets.Protocol public const int BinaryFrame = 0x2; public const int CloseFrame = 0x8; public const int PingFrame = 0x9; - public const int PongFrame = 0xA; + public const int PongFrame = 0xA; + + internal static readonly int[] ValidOpCodes = new int[] + { + ContinuationFrame, + TextFrame, + BinaryFrame, + CloseFrame, + PingFrame, + PongFrame, + }; } } }