Implement client masking.

This commit is contained in:
Chris Ross 2014-03-05 16:55:32 -08:00
parent 7004026b5e
commit 14685821a6
7 changed files with 116 additions and 10 deletions

View File

@ -9,10 +9,15 @@ using System.Threading.Tasks;
namespace Microsoft.Net.WebSockets
{
public class ClientWebSocket : WebSocket
// https://tools.ietf.org/html/rfc6455
public class CommonWebSocket : WebSocket
{
private readonly static Random Random = new Random();
private readonly Stream _stream;
private readonly string _subProtocl;
private readonly bool _maskOutput;
private readonly bool _useZeroMask;
private WebSocketState _state;
private WebSocketCloseStatus? _closeStatus;
@ -25,12 +30,14 @@ namespace Microsoft.Net.WebSockets
private FrameHeader _frameInProgress;
private long _frameBytesRemaining = 0;
public ClientWebSocket(Stream stream, string subProtocol, int receiveBufferSize)
public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize)
{
_stream = stream;
_subProtocl = subProtocol;
_state = WebSocketState.Open;
_receiveBuffer = new byte[receiveBufferSize];
_maskOutput = true; // TODO: make optional for client. Add option to block unmasking from server.
_useZeroMask = false; // TODO: make optional
}
public override WebSocketCloseStatus? CloseStatus
@ -53,6 +60,28 @@ namespace Microsoft.Net.WebSockets
get { return _subProtocl; }
}
// https://tools.ietf.org/html/rfc6455#section-5.3
// The masking key is a 32-bit value chosen at random by the client.
// When preparing a masked frame, the client MUST pick a fresh masking
// key from the set of allowed 32-bit values. The masking key needs to
// be unpredictable; thus, the masking key MUST be derived from a strong
// source of entropy, and the masking key for a given frame MUST NOT
// make it simple for a server/proxy to predict the masking key for a
// subsequent frame. The unpredictability of the masking key is
// essential to prevent authors of malicious applications from selecting
// the bytes that appear on the wire. RFC 4086 [RFC4086] discusses what
// entails a suitable source of entropy for security-sensitive
// applications.
private int GetNextMask()
{
if (_useZeroMask)
{
return 0;
}
// TODO: Doesn't include negative numbers so it's only 31 bits, not 32.
return Random.Next();
}
public override async Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
{
// TODO: Validate arguments
@ -60,10 +89,20 @@ namespace Microsoft.Net.WebSockets
// TODO: Check concurrent writes
// TODO: Check ping/pong state
// TODO: Masking
FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), true, 0, buffer.Count);
// TODO: Block close frame?
int mask = GetNextMask();
FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), _maskOutput, mask, buffer.Count);
ArraySegment<byte> segment = frameHeader.Buffer;
await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken);
if (_maskOutput && mask != 0)
{
byte[] maskedFrame = Utilities.MergeAndMask(mask, segment, buffer);
await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length, cancellationToken);
}
else
{
await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken);
}
}
private int GetOpCode(WebSocketMessageType messageType)
@ -105,7 +144,8 @@ namespace Microsoft.Net.WebSockets
// TOOD: This assumes the close message fits in the buffer.
// TODO: Assert at least two bytes remaining for the close status code.
await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None);
// TODO: Unmask (server only)
// TODO: Throw if the client detects an incoming masked frame.
_closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]);
_closeStatusDescription = Encoding.UTF8.GetString(_receiveBuffer, _receiveOffset + 2, _receiveCount - 2) ?? string.Empty;
result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, (WebSocketCloseStatus)_closeStatus, _closeStatusDescription);
@ -132,6 +172,7 @@ namespace Microsoft.Net.WebSockets
int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining);
if (_receiveCount > 0)
{
// TODO: Unmask
int bytesToCopy = Math.Min(bytesToRead, _receiveCount);
Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy);
if (bytesToCopy == _frameBytesRemaining)

View File

@ -41,11 +41,11 @@ namespace Microsoft.Net.WebSockets
Fin = final;
OpCode = opCode;
Masked = masked;
DataLength = dataLength;
if (masked)
{
MaskKey = maskKey;
}
DataLength = dataLength;
}
public bool Fin
@ -109,7 +109,7 @@ namespace Microsoft.Net.WebSockets
}
int offset = ExtendedLengthFieldSize + 2;
return (_header[offset] << 24) + (_header[offset + 1] << 16)
+ (_header[offset + 2] << 8) + _header[offset + 4];
+ (_header[offset + 2] << 8) + _header[offset + 3];
}
private set
{

View File

@ -39,10 +39,11 @@
<Reference Include="System.Xml" />
</ItemGroup>
<ItemGroup>
<Compile Include="ClientWebSocket.cs" />
<Compile Include="CommonWebSocket.cs" />
<Compile Include="Constants.cs" />
<Compile Include="FrameHeader.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
<Compile Include="Utilities.cs" />
<Compile Include="WebSocketClient.cs" />
</ItemGroup>
<Import Project="$(MSBuildToolsPath)\Microsoft.CSharp.targets" />

View File

@ -0,0 +1,41 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
namespace Microsoft.Net.WebSockets
{
public static class Utilities
{
// Copies the header and data into a new buffer and masks the data.
public static byte[] MergeAndMask(int mask, ArraySegment<byte> header, ArraySegment<byte> data)
{
byte[] frame = new byte[header.Count + data.Count];
Array.Copy(header.Array, header.Offset, frame, 0, header.Count);
Array.Copy(data.Array, data.Offset, frame, header.Count, data.Count);
Mask(mask, new ArraySegment<byte>(frame, header.Count, data.Count));
return frame;
}
// Un/Masks the data in place
public static void Mask(int mask, ArraySegment<byte> data)
{
byte[] maskBytes = new byte[]
{
(byte)(mask >> 24),
(byte)(mask >> 16),
(byte)(mask >> 8),
(byte)mask,
};
int maskOffset = 0;
for (int i = data.Offset; i < data.Offset + data.Count; i++)
{
data.Array[i] = (byte)(data.Array[i] ^ maskBytes[maskOffset]);
maskOffset = (maskOffset + 1) % 4;
}
}
}
}

View File

@ -49,7 +49,7 @@ namespace Microsoft.Net.WebSockets.Client
Stream stream = response.GetResponseStream();
// Console.WriteLine(stream.CanWrite + " " + stream.CanRead);
return new ClientWebSocket(stream, null, ReceiveBufferSize);
return new CommonWebSocket(stream, null, ReceiveBufferSize);
}
}
}

View File

@ -45,6 +45,7 @@
</Reference>
</ItemGroup>
<ItemGroup>
<Compile Include="UtilitiesTests.cs" />
<Compile Include="WebSocketClientTests.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />
</ItemGroup>

View File

@ -0,0 +1,22 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Net.WebSockets.Test
{
public class UtilitiesTests
{
[Fact]
public void MaskDataRoundTrips()
{
byte[] data = Encoding.UTF8.GetBytes("Hello World");
byte[] orriginal = Encoding.UTF8.GetBytes("Hello World");
Utilities.Mask(16843009, new ArraySegment<byte>(data));
Utilities.Mask(16843009, new ArraySegment<byte>(data));
Assert.Equal(orriginal, data);
}
}
}