// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Diagnostics.Contracts;
using System.IO;
using System.Linq;
using System.Net.WebSockets;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;

namespace Microsoft.AspNet.WebSockets.Protocol
{
    /// <summary>
    /// A <see cref="WebSocket"/> implementation that frames and parses messages over an arbitrary
    /// <see cref="Stream"/>. The same type serves both roles: a client instance masks outgoing frames,
    /// a server instance unmasks incoming frames.
    /// </summary>
    // https://tools.ietf.org/html/rfc6455
    public class CommonWebSocket : WebSocket
    {
        // Shared source of masking keys. A cryptographic generator is used instead of System.Random because
        // this instance is shared by every socket (and by keep-alive timer callbacks), and because
        // RFC 6455 section 5.3 requires the key to come from a strong source of entropy.
        private static readonly RandomNumberGenerator MaskGenerator = RandomNumberGenerator.Create();

        // Fixed payload sent with every keep-alive ping.
        private static readonly byte[] PingBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz");

        // Close codes below 3000 that are accepted from the remote party.
        private static readonly int[] ValidCloseStatusCodes = new[] { 1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011 };

        private readonly Stream _stream;
        private readonly string _subProtocol;
        private readonly bool _maskOutput;
        private readonly bool _unmaskInput;
        private readonly bool _useZeroMask;

        // Serializes everything that writes to the stream: data frames, pings, pongs and close frames.
        private readonly SemaphoreSlim _writeLock;

        // Null when keep-alive is disabled (Timeout.InfiniteTimeSpan).
        private readonly Timer _keepAliveTimer;

        private WebSocketState _state;

        private WebSocketCloseStatus? _closeStatus;
        private string _closeStatusDescription;

        // True between a SendAsync with endOfMessage == false and the one that finishes the message.
        private bool _isOutgoingMessageInProgress;

        // Receive buffer plus the window [_receiveBufferOffset, _receiveBufferOffset + _receiveBufferBytes)
        // of bytes that have been read from the stream but not yet consumed.
        private byte[] _receiveBuffer;
        private int _receiveBufferOffset;
        private int _receiveBufferBytes;

        // Header of the frame currently being delivered to the caller, or null between frames.
        private FrameHeader _frameInProgress;
        private long _frameBytesRemaining;

        // Opcode of the first frame of a fragmented message; null when no message is in progress.
        private int? _firstDataOpCode;
        private int _dataUnmaskOffset;
        private Utilities.Utf8MessageState _incomingUtf8MessageState = new Utilities.Utf8MessageState();

        /// <summary>
        /// Creates a WebSocket over <paramref name="stream"/> in the Open state.
        /// </summary>
        /// <param name="stream">The transport to read frames from and write frames to.</param>
        /// <param name="subProtocol">The value reported by <see cref="SubProtocol"/>.</param>
        /// <param name="keepAliveInterval">Ping interval, or <see cref="Timeout.InfiniteTimeSpan"/> to send no pings.</param>
        /// <param name="receiveBufferSize">Size in bytes of the internal receive buffer.</param>
        /// <param name="maskOutput">Whether outgoing frames are flagged as masked.</param>
        /// <param name="useZeroMask">When true the masking key is always zero.</param>
        /// <param name="unmaskInput">Whether incoming frames must be masked and are unmasked on receipt.</param>
        public CommonWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput)
        {
            _stream = stream;
            _subProtocol = subProtocol;
            _state = WebSocketState.Open;
            _receiveBuffer = new byte[receiveBufferSize];
            _maskOutput = maskOutput;
            _useZeroMask = useZeroMask;
            _unmaskInput = unmaskInput;
            _writeLock = new SemaphoreSlim(1);
            if (keepAliveInterval != Timeout.InfiniteTimeSpan)
            {
                _keepAliveTimer = new Timer(SendKeepAlive, this, keepAliveInterval, keepAliveInterval);
            }
        }

        /// <summary>
        /// Creates a client-side socket: output is masked, input is not unmasked.
        /// </summary>
        public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool useZeroMask)
        {
            return new CommonWebSocket(stream, subProtocol, keepAliveInterval, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false);
        }

        /// <summary>
        /// Creates a server-side socket: output is not masked, input is unmasked.
        /// </summary>
        public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize)
        {
            return new CommonWebSocket(stream, subProtocol, keepAliveInterval, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true);
        }

        /// <summary>The close status received from the remote party, if any.</summary>
        public override WebSocketCloseStatus? CloseStatus
        {
            get { return _closeStatus; }
        }

        /// <summary>The close description received from the remote party, if any.</summary>
        public override string CloseStatusDescription
        {
            get { return _closeStatusDescription; }
        }

        /// <summary>The current connection state.</summary>
        public override WebSocketState State
        {
            get { return _state; }
        }

        /// <summary>The sub-protocol supplied at construction.</summary>
        public override string SubProtocol
        {
            get { return _subProtocol; }
        }

        // https://tools.ietf.org/html/rfc6455#section-5.3
        // The masking key is a 32-bit value chosen at random by the client.
        // When preparing a masked frame, the client MUST pick a fresh masking
        // key from the set of allowed 32-bit values.  The masking key needs to
        // be unpredictable; thus, the masking key MUST be derived from a strong
        // source of entropy, and the masking key for a given frame MUST NOT
        // make it simple for a server/proxy to predict the masking key for a
        // subsequent frame.  The unpredictability of the masking key is
        // essential to prevent authors of malicious applications from selecting
        // the bytes that appear on the wire.  RFC 4086 [RFC4086] discusses what
        // entails a suitable source of entropy for security-sensitive
        // applications.
        /// <summary>
        /// Produces the masking key for the next outgoing frame.
        /// </summary>
        /// <returns>Zero when zero-masking was requested, otherwise 32 random bits (may be negative).</returns>
        private int GetNextMask()
        {
            if (_useZeroMask)
            {
                return 0;
            }

            // Four bytes from the cryptographic generator cover the full 32-bit key space.
            byte[] maskBytes = new byte[sizeof(int)];
            MaskGenerator.GetBytes(maskBytes);
            return BitConverter.ToInt32(maskBytes, 0);
        }

        /// <summary>
        /// Sends one frame of a text or binary message.
        /// </summary>
        /// <param name="buffer">The payload for this frame.</param>
        /// <param name="messageType">Text or Binary; any other value is rejected.</param>
        /// <param name="endOfMessage">True when this frame completes the message.</param>
        /// <param name="cancellationToken">Cancels waiting for the write lock and the stream writes.</param>
        /// <exception cref="ArgumentOutOfRangeException"><paramref name="messageType"/> is not Text or Binary.</exception>
        /// <exception cref="ObjectDisposedException">The socket is closed or aborted.</exception>
        /// <exception cref="InvalidOperationException">A close frame has already been sent.</exception>
        public override async Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
        {
            ValidateSegment(buffer);
            if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text)
            {
                // Block control frames
                throw new ArgumentOutOfRangeException("messageType", messageType, string.Empty);
            }

            // Check concurrent writes, pings & pongs, or closes
            await _writeLock.WaitAsync(cancellationToken);
            try
            {
                ThrowIfDisposed();
                ThrowIfOutputClosed();

                int mask = GetNextMask();
                // Every frame after the first one of a message is sent as a continuation frame.
                int opcode = _isOutgoingMessageInProgress ? Constants.OpCodes.ContinuationFrame : Utilities.GetOpCode(messageType);
                FrameHeader frameHeader = new FrameHeader(endOfMessage, opcode, _maskOutput, mask, buffer.Count);
                ArraySegment<byte> headerSegment = frameHeader.Buffer;
                if (_maskOutput && mask != 0)
                {
                    // TODO: For larger messages consider using a limited size buffer and masking & sending in segments.
                    // The caller's buffer is left untouched; the returned array is written as the whole frame.
                    byte[] maskedFrame = Utilities.MergeAndMask(mask, headerSegment, buffer);
                    await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length, cancellationToken);
                }
                else
                {
                    // Unmasked output, or a zero key: the payload goes out as-is.
                    await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count, cancellationToken);
                    await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken);
                }

                _isOutgoingMessageInProgress = !endOfMessage;
            }
            finally
            {
                _writeLock.Release();
            }
        }

        // Timer callback; the state object is the socket that owns the timer.
        private static void SendKeepAlive(object state)
        {
            CommonWebSocket websocket = (CommonWebSocket)state;
            websocket.SendKeepAliveAsync();
        }

        // Sends a ping unless a write is already in flight. Runs from the timer, so every failure is
        // handled here (by aborting the socket) rather than propagated.
        private async void SendKeepAliveAsync()
        {
            // Check concurrent writes, pings & pongs, or closes
            if (!_writeLock.Wait(0))
            {
                // Sending real data is better than a ping, discard it.
                return;
            }

            try
            {
                if (State == WebSocketState.CloseSent || State >= WebSocketState.Closed)
                {
                    _keepAliveTimer.Dispose();
                    return;
                }

                int mask = GetNextMask();
                FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.PingFrame, _maskOutput, mask, PingBuffer.Length);
                ArraySegment<byte> headerSegment = frameHeader.Buffer;

                // TODO: CancelationToken / timeout?
                if (_maskOutput && mask != 0)
                {
                    byte[] maskedFrame = Utilities.MergeAndMask(mask, headerSegment, new ArraySegment<byte>(PingBuffer));
                    await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length);
                }
                else
                {
                    await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count);
                    await _stream.WriteAsync(PingBuffer, 0, PingBuffer.Length);
                }
            }
            catch (Exception)
            {
                // TODO: Log exception, this is a background thread.

                // Shut down, we must be in a faulted state;
                Abort();
            }
            finally
            {
                _writeLock.Release();
            }
        }

        /// <summary>
        /// Receives the next chunk of message data, or the close notification.
        /// Ping and pong frames are consumed internally and never surfaced to the caller.
        /// </summary>
        /// <param name="buffer">Destination for the payload bytes.</param>
        /// <param name="cancellationToken">Cancels the pending reads.</param>
        /// <returns>
        /// The number of bytes copied, the message type, and whether the message is complete.
        /// A frame larger than <paramref name="buffer"/> or the internal receive buffer is returned across several calls.
        /// </returns>
        /// <exception cref="ObjectDisposedException">The socket is closed or aborted.</exception>
        /// <exception cref="InvalidOperationException">A close frame was already received, or the peer violated the protocol.</exception>
        /// <exception cref="IOException">The stream ended in the middle of a frame.</exception>
        public async override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
        {
            ThrowIfDisposed();
            ThrowIfInputClosed();
            ValidateSegment(buffer);
            // TODO: InvalidOperationException if any receives are currently in progress.

            // No active frame. Loop because we may be discarding ping/pong frames.
            while (_frameInProgress == null)
            {
                await ReadNextFrameAsync(cancellationToken);
            }

            int opCode = _frameInProgress.OpCode;

            if (opCode == Constants.OpCodes.CloseFrame)
            {
                return await ProcessCloseFrameAsync(cancellationToken);
            }

            // Handle fragmentation, remember the first frame type
            if (opCode == Constants.OpCodes.ContinuationFrame)
            {
                if (!_firstDataOpCode.HasValue)
                {
                    await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid continuation frame", cancellationToken);
                }
                opCode = _firstDataOpCode.Value;
            }
            else
            {
                _firstDataOpCode = opCode;
            }

            // Make sure there's at least some data in the buffer
            int bytesToBuffer = (int)Math.Min((long)_receiveBuffer.Length, _frameBytesRemaining);
            await EnsureDataAvailableOrReadAsync(bytesToBuffer, cancellationToken);

            // Copy buffered data to the users buffer
            int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining);
            int bytesToCopy = Math.Min(bytesToRead, _receiveBufferBytes);
            Array.Copy(_receiveBuffer, _receiveBufferOffset, buffer.Array, buffer.Offset, bytesToCopy);

            if (_unmaskInput)
            {
                // _frameInProgress.Masked == _unmaskInput already verified
                Utilities.MaskInPlace(_frameInProgress.MaskKey, ref _dataUnmaskOffset, new ArraySegment<byte>(buffer.Array, buffer.Offset, bytesToCopy));
            }

            // True only for the chunk that delivers the last byte of the current frame.
            bool frameCompleted = bytesToCopy == _frameBytesRemaining;

            WebSocketReceiveResult result;
            WebSocketMessageType messageType = Utilities.GetMessageType(opCode);

            // Only flag end-of-message on the chunk that actually completes the final frame: a final frame
            // drained across several calls may split a multi-byte sequence between two chunks.
            // NOTE(review): assumes TryValidateUtf8 rejects an incomplete trailing sequence when this flag
            // is set and otherwise carries it over in the state object - confirm against Utilities.
            bool endOfTextMessage = frameCompleted && _frameInProgress.Fin;
            if (messageType == WebSocketMessageType.Text
                && !Utilities.TryValidateUtf8(new ArraySegment<byte>(buffer.Array, buffer.Offset, bytesToCopy), endOfTextMessage, _incomingUtf8MessageState))
            {
                await SendErrorAbortAndThrow(WebSocketCloseStatus.InvalidPayloadData, "Invalid UTF-8", cancellationToken);
            }

            if (frameCompleted)
            {
                result = new WebSocketReceiveResult(bytesToCopy, messageType, _frameInProgress.Fin);
                if (_frameInProgress.Fin)
                {
                    _firstDataOpCode = null;
                }
                _frameInProgress = null;
                _dataUnmaskOffset = 0;
            }
            else
            {
                result = new WebSocketReceiveResult(bytesToCopy, messageType, false);
            }

            _frameBytesRemaining -= bytesToCopy;
            _receiveBufferBytes -= bytesToCopy;
            _receiveBufferOffset += bytesToCopy;

            return result;
        }

        // Reads and validates the next frame header. Ping and pong frames are fully consumed here
        // (leaving _frameInProgress null); any other frame is left in _frameInProgress for the caller.
        private async Task ReadNextFrameAsync(CancellationToken cancellationToken)
        {
            // The second header byte determines the total header size.
            await EnsureDataAvailableOrReadAsync(2, cancellationToken);
            int frameHeaderSize = FrameHeader.CalculateFrameHeaderSize(_receiveBuffer[_receiveBufferOffset + 1]);
            await EnsureDataAvailableOrReadAsync(frameHeaderSize, cancellationToken);
            _frameInProgress = new FrameHeader(new ArraySegment<byte>(_receiveBuffer, _receiveBufferOffset, frameHeaderSize));
            _receiveBufferOffset += frameHeaderSize;
            _receiveBufferBytes -= frameHeaderSize;
            _frameBytesRemaining = _frameInProgress.DataLength;

            if (_frameInProgress.AreReservedSet())
            {
                await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Unexpected reserved bits set", cancellationToken);
            }

            if (_unmaskInput != _frameInProgress.Masked)
            {
                await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Incorrect masking", cancellationToken);
            }

            if (!ValidateOpCode(_frameInProgress.OpCode))
            {
                await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid opcode: " + _frameInProgress.OpCode, cancellationToken);
            }

            if (_frameInProgress.IsControlFrame)
            {
                if (_frameBytesRemaining > 125)
                {
                    await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid control frame size", cancellationToken);
                }

                if (!_frameInProgress.Fin)
                {
                    await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Fragmented control frame", cancellationToken);
                }

                if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame)
                {
                    // Drain it, should be less than 125 bytes
                    await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, cancellationToken);

                    if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame)
                    {
                        await SendPongReplyAsync(cancellationToken);
                    }

                    _receiveBufferOffset += (int)_frameBytesRemaining;
                    _receiveBufferBytes -= (int)_frameBytesRemaining;
                    _frameBytesRemaining = 0;
                    _frameInProgress = null;
                }
            }
            else if (_firstDataOpCode.HasValue && _frameInProgress.OpCode != Constants.OpCodes.ContinuationFrame)
            {
                // A data frame is already in progress, but this new frame is not a continuation frame.
                await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Expected a continuation frame: " + _frameInProgress.OpCode, cancellationToken);
            }
        }

        // Reads from the stream until at least bytesNeeded unconsumed bytes are buffered, compacting the
        // buffer first when the tail does not have enough room. Throws IOException if the stream ends.
        private async Task EnsureDataAvailableOrReadAsync(int bytesNeeded, CancellationToken cancellationToken)
        {
            // Adequate buffer space?
            Contract.Assert(bytesNeeded <= _receiveBuffer.Length);

            // Insufficient buffered data
            while (_receiveBufferBytes < bytesNeeded)
            {
                cancellationToken.ThrowIfCancellationRequested();

                int spaceRemaining = _receiveBuffer.Length - (_receiveBufferOffset + _receiveBufferBytes);
                if (_receiveBufferOffset > 0 && bytesNeeded > spaceRemaining)
                {
                    // Some data in the buffer, shift down to make room
                    Array.Copy(_receiveBuffer, _receiveBufferOffset, _receiveBuffer, 0, _receiveBufferBytes);
                    _receiveBufferOffset = 0;
                    spaceRemaining = _receiveBuffer.Length - _receiveBufferBytes;
                }
                // Add to the end
                int read = await _stream.ReadAsync(_receiveBuffer, _receiveBufferOffset + _receiveBufferBytes, spaceRemaining, cancellationToken);
                if (read == 0)
                {
                    throw new IOException("Unexpected end of stream");
                }
                _receiveBufferBytes += read;
            }
        }

        // We received a ping, send a pong in reply
        // The ping payload is echoed straight out of the receive buffer (re-masked in place when needed).
        private async Task SendPongReplyAsync(CancellationToken cancellationToken)
        {
            await _writeLock.WaitAsync(cancellationToken);
            try
            {
                if (State != WebSocketState.Open)
                {
                    // Output closed, discard the pong.
                    return;
                }

                ArraySegment<byte> dataSegment = new ArraySegment<byte>(_receiveBuffer, _receiveBufferOffset, (int)_frameBytesRemaining);
                if (_unmaskInput)
                {
                    // _frameInProgress.Masked == _unmaskInput already verified
                    Utilities.MaskInPlace(_frameInProgress.MaskKey, dataSegment);
                }

                int mask = GetNextMask();
                FrameHeader header = new FrameHeader(true, Constants.OpCodes.PongFrame, _maskOutput, mask, _frameBytesRemaining);
                if (_maskOutput)
                {
                    Utilities.MaskInPlace(mask, dataSegment);
                }

                ArraySegment<byte> headerSegment = header.Buffer;
                await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count, cancellationToken);
                await _stream.WriteAsync(dataSegment.Array, dataSegment.Offset, dataSegment.Count, cancellationToken);
            }
            finally
            {
                _writeLock.Release();
            }
        }

        // Parses the close frame in _frameInProgress, records the status and description, advances the
        // connection state, and returns the Close result handed back from ReceiveAsync.
        private async Task<WebSocketReceiveResult> ProcessCloseFrameAsync(CancellationToken cancellationToken)
        {
            // The close message should be less than 125 bytes and fit in the buffer.
            // NOTE(review): this read ignores the caller's token (CancellationToken.None) - confirm that is intended.
            await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None);

            // Status code and message are optional
            if (_frameBytesRemaining >= 2)
            {
                if (_unmaskInput)
                {
                    Utilities.MaskInPlace(_frameInProgress.MaskKey, new ArraySegment<byte>(_receiveBuffer, _receiveBufferOffset, (int)_frameBytesRemaining));
                }

                // The status code is the first two payload bytes, most significant byte first.
                _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveBufferOffset] << 8) | _receiveBuffer[_receiveBufferOffset + 1]);
                if (!ValidateCloseStatus(_closeStatus.Value))
                {
                    await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid close status code.", cancellationToken);
                }

                try
                {
                    // A strict decoder so that malformed UTF-8 raises DecoderFallbackException.
                    var encoding = new UTF8Encoding(encoderShouldEmitUTF8Identifier: false, throwOnInvalidBytes: true);
                    _closeStatusDescription = encoding.GetString(_receiveBuffer, _receiveBufferOffset + 2, (int)_frameBytesRemaining - 2) ?? string.Empty;
                }
                catch (DecoderFallbackException)
                {
                    await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid UTF-8 close message.", cancellationToken);
                }
            }
            else if (_frameBytesRemaining == 1)
            {
                await SendErrorAbortAndThrow(WebSocketCloseStatus.ProtocolError, "Invalid close body.", cancellationToken);
            }
            else
            {
                // Empty close body: keep any status already recorded, otherwise report a normal closure.
                _closeStatus = _closeStatus ?? WebSocketCloseStatus.NormalClosure;
                _closeStatusDescription = _closeStatusDescription ?? string.Empty;
            }

            Contract.Assert(_frameInProgress.Fin);
            WebSocketReceiveResult result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, _frameInProgress.Fin,
                _closeStatus.Value, _closeStatusDescription);

            if (State == WebSocketState.Open)
            {
                _state = WebSocketState.CloseReceived;
            }
            else if (State == WebSocketState.CloseSent)
            {
                // Both sides have now sent a close frame; the connection is finished.
                _state = WebSocketState.Closed;
                _stream.Dispose();
            }

            return result;
        }

        /// <summary>
        /// Performs the full close handshake: sends a close frame if one has not been sent,
        /// then receives (and discards) incoming data until the peer's close frame arrives.
        /// </summary>
        /// <param name="closeStatus">The status code to send.</param>
        /// <param name="statusDescription">The description to send; null is sent as empty.</param>
        /// <param name="cancellationToken">Cancels the pending send and receives.</param>
        /// <exception cref="ObjectDisposedException">The socket is closed or aborted.</exception>
        public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
        {
            ThrowIfDisposed();

            if (State == WebSocketState.Open || State == WebSocketState.CloseReceived)
            {
                // Send a close message.
                await CloseOutputAsync(closeStatus, statusDescription, cancellationToken);
            }

            if (State == WebSocketState.CloseSent)
            {
                // Do a receiving drain
                byte[] data = new byte[_receiveBuffer.Length];
                WebSocketReceiveResult result;
                do
                {
                    result = await ReceiveAsync(new ArraySegment<byte>(data), cancellationToken);
                }
                while (result.MessageType != WebSocketMessageType.Close);
            }
        }

        /// <summary>
        /// Sends a close frame without waiting for the peer's reply, and stops the keep-alive timer.
        /// </summary>
        /// <param name="closeStatus">The status code to send.</param>
        /// <param name="statusDescription">The description to send; null is sent as empty.</param>
        /// <param name="cancellationToken">Cancels waiting for the write lock and the stream writes.</param>
        /// <exception cref="ObjectDisposedException">The socket is closed or aborted.</exception>
        /// <exception cref="InvalidOperationException">A close frame has already been sent.</exception>
        public override async Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
        {
            await _writeLock.WaitAsync(cancellationToken);
            try
            {
                ThrowIfDisposed();
                ThrowIfOutputClosed();

                if (_keepAliveTimer != null)
                {
                    _keepAliveTimer.Dispose();
                }

                // Payload layout: two status-code bytes (most significant first) followed by the UTF-8 description.
                byte[] descriptionBytes = Encoding.UTF8.GetBytes(statusDescription ?? string.Empty);
                byte[] fullData = new byte[descriptionBytes.Length + 2];
                fullData[0] = (byte)((int)closeStatus >> 8);
                fullData[1] = (byte)closeStatus;
                Array.Copy(descriptionBytes, 0, fullData, 2, descriptionBytes.Length);

                int mask = GetNextMask();
                if (_maskOutput)
                {
                    Utilities.MaskInPlace(mask, new ArraySegment<byte>(fullData));
                }

                FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, _maskOutput, mask, fullData.Length);

                ArraySegment<byte> segment = frameHeader.Buffer;
                await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken);
                await _stream.WriteAsync(fullData, 0, fullData.Length, cancellationToken);

                if (State == WebSocketState.Open)
                {
                    _state = WebSocketState.CloseSent;
                }
                else if (State == WebSocketState.CloseReceived)
                {
                    // Both sides have now sent a close frame; the connection is finished.
                    _state = WebSocketState.Closed;
                    _stream.Dispose();
                }
            }
            finally
            {
                _writeLock.Release();
            }
        }

        /// <summary>
        /// Moves the socket to Aborted, stops the keep-alive timer and disposes the stream.
        /// Does nothing if the socket is already closed or aborted.
        /// </summary>
        public override void Abort()
        {
            if (_state >= WebSocketState.Closed) // or Aborted
            {
                return;
            }

            _state = WebSocketState.Aborted;
            if (_keepAliveTimer != null)
            {
                _keepAliveTimer.Dispose();
            }
            _stream.Dispose();
        }

        /// <summary>
        /// Moves the socket to Closed, stops the keep-alive timer and disposes the stream.
        /// Does nothing if the socket is already closed or aborted.
        /// </summary>
        public override void Dispose()
        {
            if (_state >= WebSocketState.Closed) // or Aborted
            {
                return;
            }

            _state = WebSocketState.Closed;
            if (_keepAliveTimer != null)
            {
                _keepAliveTimer.Dispose();
            }
            _stream.Dispose();
        }

        private void ThrowIfDisposed()
        {
            if (_state >= WebSocketState.Closed) // or Aborted
            {
                throw new ObjectDisposedException(typeof(CommonWebSocket).FullName);
            }
        }

        private void ThrowIfOutputClosed()
        {
            if (State == WebSocketState.CloseSent)
            {
                throw new InvalidOperationException("Close already sent.");
            }
        }

        private void ThrowIfInputClosed()
        {
            if (State == WebSocketState.CloseReceived)
            {
                throw new InvalidOperationException("Close already received.");
            }
        }

        // Rejects a segment with no backing array or with an offset/count outside that array.
        private void ValidateSegment(ArraySegment<byte> buffer)
        {
            if (buffer.Array == null)
            {
                throw new ArgumentNullException("buffer");
            }
            if (buffer.Offset < 0 || buffer.Offset > buffer.Array.Length)
            {
                throw new ArgumentOutOfRangeException("buffer.Offset", buffer.Offset, string.Empty);
            }
            if (buffer.Count < 0 || buffer.Count > buffer.Array.Length - buffer.Offset)
            {
                throw new ArgumentOutOfRangeException("buffer.Count", buffer.Count, string.Empty);
            }
        }

        private bool ValidateOpCode(int opCode)
        {
            return Constants.OpCodes.ValidOpCodes.Contains(opCode);
        }

        // Returns true when closeStatus may legitimately appear in a received close frame:
        // one of the listed codes below 3000, or anything in 3000-4999.
        private static bool ValidateCloseStatus(WebSocketCloseStatus closeStatus)
        {
            if (closeStatus < (WebSocketCloseStatus)1000 || closeStatus >= (WebSocketCloseStatus)5000)
            {
                return false;
            }
            else if (closeStatus >= (WebSocketCloseStatus)3000)
            {
                // 3000-3999 - Reserved for frameworks
                // 4000-4999 - Reserved for private usage
                return true;
            }

            foreach (var validCode in ValidCloseStatusCodes)
            {
                if (closeStatus == (WebSocketCloseStatus)validCode)
                {
                    return true;
                }
            }

            return false;
        }

        // Reports a protocol failure: sends a close frame when output is still possible, aborts the
        // socket, and always throws. Callers rely on this never returning normally.
        private async Task SendErrorAbortAndThrow(WebSocketCloseStatus error, string message, CancellationToken cancellationToken)
        {
            if (State == WebSocketState.Open || State == WebSocketState.CloseReceived)
            {
                await CloseOutputAsync(error, message, cancellationToken);
            }
            Abort();
            throw new InvalidOperationException(message); // TODO: WebSocketException
        }
    }
}