Cleanup, unmasking.

This commit is contained in:
Chris Ross 2014-03-06 12:32:46 -08:00
parent c34001ee68
commit 1d5b4582f1
2 changed files with 73 additions and 45 deletions

View File

@ -33,18 +33,28 @@ namespace Microsoft.Net.WebSockets
private FrameHeader _frameInProgress;
private long _frameBytesRemaining = 0;
public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize)
public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput)
{
_stream = stream;
_subProtocl = subProtocol;
_state = WebSocketState.Open;
_receiveBuffer = new byte[receiveBufferSize];
_maskOutput = true; // TODO: client only.
_useZeroMask = false; // TODO: make optional
_unmaskInput = false; // TODO: server only
_maskOutput = maskOutput;
_useZeroMask = useZeroMask;
_unmaskInput = unmaskInput;
_writeLock = new SemaphoreSlim(1);
}
public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool useZeroMask)
{
return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false);
}
public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, int receiveBufferSize)
{
return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true);
}
public override WebSocketCloseStatus? CloseStatus
{
get { return _closeStatus; }
@ -149,6 +159,11 @@ namespace Microsoft.Net.WebSockets
_receiveCount -= frameHeaderSize;
_frameBytesRemaining = _frameInProgress.DataLength;
if (_unmaskInput != _frameInProgress.Masked)
{
throw new InvalidOperationException("Unmasking settings out of sync with data.");
}
// Ping or Pong frames
if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame)
{
@ -171,13 +186,26 @@ namespace Microsoft.Net.WebSockets
if (_frameInProgress.OpCode == Constants.OpCodes.CloseFrame)
{
// TOOD: This assumes the close message fits in the buffer.
// TODO: Assert at least two bytes remaining for the close status code.
// The close message should be less than 125 bytes and fit in the buffer.
await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None);
// TODO: Unmask (server only)
// TODO: Throw if the client detects an incoming masked frame.
_closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]);
_closeStatusDescription = Encoding.UTF8.GetString(_receiveBuffer, _receiveOffset + 2, _receiveCount - 2) ?? string.Empty;
// Status code and message are optional
if (_frameBytesRemaining >= 2)
{
ArraySegment<byte> dataSegment = new ArraySegment<byte>(_receiveBuffer, _receiveOffset + 2, (int)_frameBytesRemaining - 2);
if (_unmaskInput)
{
// In place
Utilities.Mask(_frameInProgress.MaskKey, dataSegment);
}
_closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]);
_closeStatusDescription = Encoding.UTF8.GetString(dataSegment.Array, dataSegment.Offset, dataSegment.Count) ?? string.Empty;
}
else
{
_closeStatus = _closeStatus ?? WebSocketCloseStatus.NormalClosure;
_closeStatusDescription = _closeStatusDescription ?? string.Empty;
}
result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, (WebSocketCloseStatus)_closeStatus, _closeStatusDescription);
if (State == WebSocketState.Open)
@ -192,49 +220,44 @@ namespace Microsoft.Net.WebSockets
return result;
}
// Make sure there's at least some data in the buffer
if (_frameBytesRemaining > 0)
{
await EnsureDataAvailableOrReadAsync(1, cancellationToken);
}
// Copy buffered data to the users buffer
int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining);
if (_receiveCount > 0)
{
// TODO: Unmask
int bytesToCopy = Math.Min(bytesToRead, _receiveCount);
Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy);
if (bytesToCopy == _frameBytesRemaining)
{
result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), _frameInProgress.Fin);
_frameInProgress = null;
}
else
{
result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), false);
}
_frameBytesRemaining -= bytesToCopy;
_receiveCount -= bytesToCopy;
_receiveOffset += bytesToCopy;
}
else
if (_frameBytesRemaining == 0)
{
// End of an empty frame?
result = new WebSocketReceiveResult(0, GetMessageType(_frameInProgress.OpCode), true);
_frameInProgress = null;
return result;
}
// Make sure there's at least some data in the buffer
await EnsureDataAvailableOrReadAsync(1, cancellationToken);
// Copy buffered data to the users buffer
int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining);
int bytesToCopy = Math.Min(bytesToRead, _receiveCount);
Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy);
if (_unmaskInput)
{
// TODO: mask alignment may be off between reads.
Utilities.Mask(_frameInProgress.MaskKey, new ArraySegment<byte>(buffer.Array, buffer.Offset, bytesToCopy));
}
if (bytesToCopy == _frameBytesRemaining)
{
result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), _frameInProgress.Fin);
_frameInProgress = null;
}
else
{
result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), false);
}
_frameBytesRemaining -= bytesToCopy;
_receiveCount -= bytesToCopy;
_receiveOffset += bytesToCopy;
return result;
}
// We received a ping, send a pong in reply
private async Task SendPongReply(CancellationToken cancellationToken)
{
// TODO: Unmask data
if (_unmaskInput != _frameInProgress.Masked)
{
throw new InvalidOperationException("Unmasking settings out of sync with data.");
}
ArraySegment<byte> dataSegment = new ArraySegment<byte>(_receiveBuffer, _receiveOffset, (int)_frameBytesRemaining);
if (_unmaskInput)
{
@ -359,8 +382,13 @@ namespace Microsoft.Net.WebSockets
fullData[1] = (byte)closeStatus;
Array.Copy(descriptionBytes, 0, fullData, 2, descriptionBytes.Length);
// TODO: Masking
FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, true, 0, fullData.Length);
int mask = GetNextMask();
if (_maskOutput)
{
Utilities.Mask(mask, new ArraySegment<byte>(fullData));
}
FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, _maskOutput, mask, fullData.Length);
ArraySegment<byte> segment = frameHeader.Buffer;
await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
await _stream.WriteAsync(fullData, 0, fullData.Length, cancellationToken);

View File

@ -49,7 +49,7 @@ namespace Microsoft.Net.WebSockets.Client
Stream stream = response.GetResponseStream();
// Console.WriteLine(stream.CanWrite + " " + stream.CanRead);
return new CommonWebSocket(stream, null, ReceiveBufferSize);
return CommonWebSocket.CreateClientWebSocket(stream, null, ReceiveBufferSize, useZeroMask: false);
}
}
}