diff --git a/src/Microsoft.Net.WebSockets/ClientWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs similarity index 81% rename from src/Microsoft.Net.WebSockets/ClientWebSocket.cs rename to src/Microsoft.Net.WebSockets/CommonWebSocket.cs index d91bf1cffd..6e3f19d41e 100644 --- a/src/Microsoft.Net.WebSockets/ClientWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -9,10 +9,15 @@ using System.Threading.Tasks; namespace Microsoft.Net.WebSockets { - public class ClientWebSocket : WebSocket + // https://tools.ietf.org/html/rfc6455 + public class CommonWebSocket : WebSocket { + private readonly static Random Random = new Random(); + private readonly Stream _stream; private readonly string _subProtocl; + private readonly bool _maskOutput; + private readonly bool _useZeroMask; private WebSocketState _state; private WebSocketCloseStatus? _closeStatus; @@ -25,12 +30,14 @@ namespace Microsoft.Net.WebSockets private FrameHeader _frameInProgress; private long _frameBytesRemaining = 0; - public ClientWebSocket(Stream stream, string subProtocol, int receiveBufferSize) + public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize) { _stream = stream; _subProtocl = subProtocol; _state = WebSocketState.Open; _receiveBuffer = new byte[receiveBufferSize]; + _maskOutput = true; // TODO: make optional for client. Add option to block unmasking from server. + _useZeroMask = false; // TODO: make optional } public override WebSocketCloseStatus? CloseStatus @@ -53,6 +60,28 @@ namespace Microsoft.Net.WebSockets get { return _subProtocl; } } + // https://tools.ietf.org/html/rfc6455#section-5.3 + // The masking key is a 32-bit value chosen at random by the client. + // When preparing a masked frame, the client MUST pick a fresh masking + // key from the set of allowed 32-bit values. The masking key needs to + // be unpredictable; thus, the masking key MUST be derived from a strong + // source of entropy, and the masking key for a given frame MUST NOT + // make it simple for a server/proxy to predict the masking key for a + // subsequent frame. The unpredictability of the masking key is + // essential to prevent authors of malicious applications from selecting + // the bytes that appear on the wire. RFC 4086 [RFC4086] discusses what + // entails a suitable source of entropy for security-sensitive + // applications. + private int GetNextMask() + { + if (_useZeroMask) + { + return 0; + } + // TODO: Doesn't include negative numbers so it's only 31 bits, not 32. + return Random.Next(); + } + public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { // TODO: Validate arguments @@ -60,10 +89,20 @@ namespace Microsoft.Net.WebSockets // TODO: Check concurrent writes // TODO: Check ping/pong state // TODO: Masking - FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), true, 0, buffer.Count); + // TODO: Block close frame? + int mask = GetNextMask(); + FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), _maskOutput, mask, buffer.Count); ArraySegment segment = frameHeader.Buffer; - await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); - await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + if (_maskOutput && mask != 0) + { + byte[] maskedFrame = Utilities.MergeAndMask(mask, segment, buffer); + await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length, cancellationToken); + } + else + { + await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); + await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + } } private int GetOpCode(WebSocketMessageType messageType) @@ -105,7 +144,8 @@ namespace Microsoft.Net.WebSockets // TOOD: This assumes the close message fits in the buffer. // TODO: Assert at least two bytes remaining for the close status code. await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None); - + // TODO: Unmask (server only) + // TODO: Throw if the client detects an incoming masked frame. _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]); _closeStatusDescription = Encoding.UTF8.GetString(_receiveBuffer, _receiveOffset + 2, _receiveCount - 2) ?? string.Empty; result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, (WebSocketCloseStatus)_closeStatus, _closeStatusDescription); @@ -132,6 +172,7 @@ namespace Microsoft.Net.WebSockets int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining); if (_receiveCount > 0) { + // TODO: Unmask int bytesToCopy = Math.Min(bytesToRead, _receiveCount); Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy); if (bytesToCopy == _frameBytesRemaining) diff --git a/src/Microsoft.Net.WebSockets/FrameHeader.cs b/src/Microsoft.Net.WebSockets/FrameHeader.cs index 54d53fc1c7..35e54c7233 100644 --- a/src/Microsoft.Net.WebSockets/FrameHeader.cs +++ b/src/Microsoft.Net.WebSockets/FrameHeader.cs @@ -41,11 +41,11 @@ namespace Microsoft.Net.WebSockets Fin = final; OpCode = opCode; Masked = masked; + DataLength = dataLength; if (masked) { MaskKey = maskKey; } - DataLength = dataLength; } public bool Fin @@ -109,7 +109,7 @@ namespace Microsoft.Net.WebSockets } int offset = ExtendedLengthFieldSize + 2; return (_header[offset] << 24) + (_header[offset + 1] << 16) - + (_header[offset + 2] << 8) + _header[offset + 4]; + + (_header[offset + 2] << 8) + _header[offset + 3]; } private set { diff --git a/src/Microsoft.Net.WebSockets/Microsoft.Net.WebSockets.csproj b/src/Microsoft.Net.WebSockets/Microsoft.Net.WebSockets.csproj index c0283932c2..dec92dc375 100644 --- a/src/Microsoft.Net.WebSockets/Microsoft.Net.WebSockets.csproj +++ b/src/Microsoft.Net.WebSockets/Microsoft.Net.WebSockets.csproj @@ -39,10 +39,11 @@ - + + diff --git a/src/Microsoft.Net.WebSockets/Utilities.cs b/src/Microsoft.Net.WebSockets/Utilities.cs new file mode 100644 index 0000000000..2293dbb328 --- /dev/null +++ b/src/Microsoft.Net.WebSockets/Utilities.cs @@ -0,0 +1,41 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace Microsoft.Net.WebSockets +{ + public static class Utilities + { + // Copies the header and data into a new buffer and masks the data. + public static byte[] MergeAndMask(int mask, ArraySegment header, ArraySegment data) + { + byte[] frame = new byte[header.Count + data.Count]; + Array.Copy(header.Array, header.Offset, frame, 0, header.Count); + Array.Copy(data.Array, data.Offset, frame, header.Count, data.Count); + + Mask(mask, new ArraySegment(frame, header.Count, data.Count)); + return frame; + } + + // Un/Masks the data in place + public static void Mask(int mask, ArraySegment data) + { + byte[] maskBytes = new byte[] + { + (byte)(mask >> 24), + (byte)(mask >> 16), + (byte)(mask >> 8), + (byte)mask, + }; + int maskOffset = 0; + + for (int i = data.Offset; i < data.Offset + data.Count; i++) + { + data.Array[i] = (byte)(data.Array[i] ^ maskBytes[maskOffset]); + maskOffset = (maskOffset + 1) % 4; + } + } + } +} diff --git a/src/Microsoft.Net.WebSockets/WebSocketClient.cs b/src/Microsoft.Net.WebSockets/WebSocketClient.cs index 3cac40fe79..1a32277d8f 100644 --- a/src/Microsoft.Net.WebSockets/WebSocketClient.cs +++ b/src/Microsoft.Net.WebSockets/WebSocketClient.cs @@ -49,7 +49,7 @@ namespace Microsoft.Net.WebSockets.Client Stream stream = response.GetResponseStream(); // Console.WriteLine(stream.CanWrite + " " + stream.CanRead); - return new ClientWebSocket(stream, null, ReceiveBufferSize); + return new CommonWebSocket(stream, null, ReceiveBufferSize); } } } diff --git a/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj b/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj index 23b2f9f0f3..fcfd99f1a4 100644 --- a/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj +++ b/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj @@ -45,6 +45,7 @@ + diff --git a/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs b/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs new file mode 100644 index 0000000000..ff9c8f4601 --- /dev/null +++ b/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs @@ -0,0 +1,22 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Net.WebSockets.Test +{ + public class UtilitiesTests + { + [Fact] + public void MaskDataRoundTrips() + { + byte[] data = Encoding.UTF8.GetBytes("Hello World"); + byte[] orriginal = Encoding.UTF8.GetBytes("Hello World"); + Utilities.Mask(16843009, new ArraySegment(data)); + Utilities.Mask(16843009, new ArraySegment(data)); + Assert.Equal(orriginal, data); + } + } +}