diff --git a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs index cae096cc94..b5fceae352 100644 --- a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -33,18 +33,28 @@ namespace Microsoft.Net.WebSockets private FrameHeader _frameInProgress; private long _frameBytesRemaining = 0; - public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize) + public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput) { _stream = stream; _subProtocl = subProtocol; _state = WebSocketState.Open; _receiveBuffer = new byte[receiveBufferSize]; - _maskOutput = true; // TODO: client only. - _useZeroMask = false; // TODO: make optional - _unmaskInput = false; // TODO: server only + _maskOutput = maskOutput; + _useZeroMask = useZeroMask; + _unmaskInput = unmaskInput; _writeLock = new SemaphoreSlim(1); } + public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool useZeroMask) + { + return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false); + } + + public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, int receiveBufferSize) + { + return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true); + } + public override WebSocketCloseStatus? CloseStatus { get { return _closeStatus; } @@ -149,6 +159,11 @@ namespace Microsoft.Net.WebSockets _receiveCount -= frameHeaderSize; _frameBytesRemaining = _frameInProgress.DataLength; + if (_unmaskInput != _frameInProgress.Masked) + { + throw new InvalidOperationException("Unmasking settings out of sync with data."); + } + // Ping or Pong frames if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) { @@ -171,13 +186,26 @@ namespace Microsoft.Net.WebSockets if (_frameInProgress.OpCode == Constants.OpCodes.CloseFrame) { - // TOOD: This assumes the close message fits in the buffer. - // TODO: Assert at least two bytes remaining for the close status code. + // The close message should be less than 125 bytes and fit in the buffer. await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None); - // TODO: Unmask (server only) - // TODO: Throw if the client detects an incoming masked frame. - _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]); - _closeStatusDescription = Encoding.UTF8.GetString(_receiveBuffer, _receiveOffset + 2, _receiveCount - 2) ?? string.Empty; + + // Status code and message are optional + if (_frameBytesRemaining >= 2) + { + ArraySegment dataSegment = new ArraySegment(_receiveBuffer, _receiveOffset + 2, (int)_frameBytesRemaining - 2); + if (_unmaskInput) + { + // In place + Utilities.Mask(_frameInProgress.MaskKey, dataSegment); + } + _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]); + _closeStatusDescription = Encoding.UTF8.GetString(dataSegment.Array, dataSegment.Offset, dataSegment.Count) ?? string.Empty; + } + else + { + _closeStatus = _closeStatus ?? WebSocketCloseStatus.NormalClosure; + _closeStatusDescription = _closeStatusDescription ?? string.Empty; + } result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, (WebSocketCloseStatus)_closeStatus, _closeStatusDescription); if (State == WebSocketState.Open) @@ -192,49 +220,44 @@ namespace Microsoft.Net.WebSockets return result; } - // Make sure there's at least some data in the buffer - if (_frameBytesRemaining > 0) - { - await EnsureDataAvailableOrReadAsync(1, cancellationToken); - } - - // Copy buffered data to the users buffer - int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining); - if (_receiveCount > 0) - { - // TODO: Unmask - int bytesToCopy = Math.Min(bytesToRead, _receiveCount); - Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy); - if (bytesToCopy == _frameBytesRemaining) - { - result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), _frameInProgress.Fin); - _frameInProgress = null; - } - else - { - result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), false); - } - _frameBytesRemaining -= bytesToCopy; - _receiveCount -= bytesToCopy; - _receiveOffset += bytesToCopy; - } - else + if (_frameBytesRemaining == 0) { // End of an empty frame? result = new WebSocketReceiveResult(0, GetMessageType(_frameInProgress.OpCode), true); + _frameInProgress = null; + return result; } + // Make sure there's at least some data in the buffer + await EnsureDataAvailableOrReadAsync(1, cancellationToken); + // Copy buffered data to the users buffer + int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining); + int bytesToCopy = Math.Min(bytesToRead, _receiveCount); + Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy); + if (_unmaskInput) + { + // TODO: mask alignment may be off between reads. + Utilities.Mask(_frameInProgress.MaskKey, new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy)); + } + if (bytesToCopy == _frameBytesRemaining) + { + result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), _frameInProgress.Fin); + _frameInProgress = null; + } + else + { + result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), false); + } + _frameBytesRemaining -= bytesToCopy; + _receiveCount -= bytesToCopy; + _receiveOffset += bytesToCopy; + return result; } // We received a ping, send a pong in reply private async Task SendPongReply(CancellationToken cancellationToken) { - // TODO: Unmask data - if (_unmaskInput != _frameInProgress.Masked) - { - throw new InvalidOperationException("Unmasking settings out of sync with data."); - } ArraySegment dataSegment = new ArraySegment(_receiveBuffer, _receiveOffset, (int)_frameBytesRemaining); if (_unmaskInput) { @@ -359,8 +382,13 @@ namespace Microsoft.Net.WebSockets fullData[1] = (byte)closeStatus; Array.Copy(descriptionBytes, 0, fullData, 2, descriptionBytes.Length); - // TODO: Masking - FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, true, 0, fullData.Length); + int mask = GetNextMask(); + if (_maskOutput) + { + Utilities.Mask(mask, new ArraySegment(fullData)); + } + + FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, _maskOutput, mask, fullData.Length); ArraySegment segment = frameHeader.Buffer; await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); await _stream.WriteAsync(fullData, 0, fullData.Length, cancellationToken); diff --git a/src/Microsoft.Net.WebSockets/WebSocketClient.cs b/src/Microsoft.Net.WebSockets/WebSocketClient.cs index 1a32277d8f..a78b2f5e17 100644 --- a/src/Microsoft.Net.WebSockets/WebSocketClient.cs +++ b/src/Microsoft.Net.WebSockets/WebSocketClient.cs @@ -49,7 +49,7 @@ namespace Microsoft.Net.WebSockets.Client Stream stream = response.GetResponseStream(); // Console.WriteLine(stream.CanWrite + " " + stream.CanRead); - return new CommonWebSocket(stream, null, ReceiveBufferSize); + return CommonWebSocket.CreateClientWebSocket(stream, null, ReceiveBufferSize, useZeroMask: false); } } }