From 2f770ca4d40421d001b497c05367f1dc0e410385 Mon Sep 17 00:00:00 2001 From: Chris Ross Date: Thu, 6 Mar 2014 10:29:54 -0800 Subject: [PATCH] Handle incoming pings and pongs. --- .../CommonWebSocket.cs | 92 ++++++++++++++++--- src/Microsoft.Net.WebSockets/Constants.cs | 2 +- src/Microsoft.Net.WebSockets/Utilities.cs | 5 + 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs index 6e3f19d41e..cae096cc94 100644 --- a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -17,7 +17,10 @@ namespace Microsoft.Net.WebSockets private readonly Stream _stream; private readonly string _subProtocl; private readonly bool _maskOutput; + private readonly bool _unmaskInput; private readonly bool _useZeroMask; + private readonly SemaphoreSlim _writeLock; + private WebSocketState _state; private WebSocketCloseStatus? _closeStatus; @@ -36,8 +39,10 @@ namespace Microsoft.Net.WebSockets _subProtocl = subProtocol; _state = WebSocketState.Open; _receiveBuffer = new byte[receiveBufferSize]; - _maskOutput = true; // TODO: make optional for client. Add option to block unmasking from server. + _maskOutput = true; // TODO: client only. _useZeroMask = false; // TODO: make optional + _unmaskInput = false; // TODO: server only + _writeLock = new SemaphoreSlim(1); } public override WebSocketCloseStatus? CloseStatus @@ -90,18 +95,28 @@ namespace Microsoft.Net.WebSockets // TODO: Check ping/pong state // TODO: Masking // TODO: Block close frame? - int mask = GetNextMask(); - FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), _maskOutput, mask, buffer.Count); - ArraySegment segment = frameHeader.Buffer; - if (_maskOutput && mask != 0) + + await _writeLock.WaitAsync(cancellationToken); + + try { - byte[] maskedFrame = Utilities.MergeAndMask(mask, segment, buffer); - await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length, cancellationToken); + int mask = GetNextMask(); + FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), _maskOutput, mask, buffer.Count); + ArraySegment segment = frameHeader.Buffer; + if (_maskOutput && mask != 0) + { + byte[] maskedFrame = Utilities.MergeAndMask(mask, segment, buffer); + await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length, cancellationToken); + } + else + { + await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); + await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + } } - else + finally { - await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); - await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); + _writeLock.Release(); } } @@ -124,7 +139,7 @@ namespace Microsoft.Net.WebSockets // TODO: Check ping/pong state // No active frame - if (_frameInProgress == null) + while (_frameInProgress == null) { await EnsureDataAvailableOrReadAsync(2, cancellationToken); int frameHeaderSize = FrameHeader.CalculateFrameHeaderSize(_receiveBuffer[_receiveOffset + 1]); @@ -133,12 +148,27 @@ namespace Microsoft.Net.WebSockets _receiveOffset += frameHeaderSize; _receiveCount -= frameHeaderSize; _frameBytesRemaining = _frameInProgress.DataLength; + + // Ping or Pong frames + if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) + { + // Drain it, should be less than 125 bytes + await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, cancellationToken); + + if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame && State == WebSocketState.Open) + { + await SendPongReply(cancellationToken); + } + + _receiveOffset += (int)_frameBytesRemaining; + _receiveCount -= (int)_frameBytesRemaining; + _frameBytesRemaining = 0; + _frameInProgress = null; + } } WebSocketReceiveResult result; - // TODO: Ping or Pong frames - if (_frameInProgress.OpCode == Constants.OpCodes.CloseFrame) { // TOOD: This assumes the close message fits in the buffer. @@ -197,6 +227,42 @@ namespace Microsoft.Net.WebSockets return result; } + // We received a ping, send a pong in reply + private async Task SendPongReply(CancellationToken cancellationToken) + { + // TODO: Unmask data + if (_unmaskInput != _frameInProgress.Masked) + { + throw new InvalidOperationException("Unmasking settings out of sync with data."); + } + ArraySegment dataSegment = new ArraySegment(_receiveBuffer, _receiveOffset, (int)_frameBytesRemaining); + if (_unmaskInput) + { + // In place + Utilities.Mask(_frameInProgress.MaskKey, dataSegment); + } + + int mask = GetNextMask(); + FrameHeader header = new FrameHeader(true, Constants.OpCodes.PongFrame, _maskOutput, mask, _frameBytesRemaining); + if (_maskOutput) + { + // In place + Utilities.Mask(_frameInProgress.MaskKey, dataSegment); + } + + await _writeLock.WaitAsync(cancellationToken); + try + { + ArraySegment headerSegment = header.Buffer; + await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count, cancellationToken); + await _stream.WriteAsync(dataSegment.Array, dataSegment.Offset, dataSegment.Count, cancellationToken); + } + finally + { + _writeLock.Release(); + } + } + private async Task EnsureDataAvailableOrReadAsync(int bytes, CancellationToken cancellationToken) { // Insufficient data diff --git a/src/Microsoft.Net.WebSockets/Constants.cs b/src/Microsoft.Net.WebSockets/Constants.cs index c7ebabaab7..740ea9ab41 100644 --- a/src/Microsoft.Net.WebSockets/Constants.cs +++ b/src/Microsoft.Net.WebSockets/Constants.cs @@ -19,7 +19,7 @@ namespace Microsoft.Net.WebSockets public const int TextFrame = 0x1; public const int BinaryFrame = 0x2; public const int CloseFrame = 0x8; - public const int PingFrame = 0x8; + public const int PingFrame = 0x9; public const int PongFrame = 0xA; } } diff --git a/src/Microsoft.Net.WebSockets/Utilities.cs b/src/Microsoft.Net.WebSockets/Utilities.cs index 2293dbb328..6dbc83630a 100644 --- a/src/Microsoft.Net.WebSockets/Utilities.cs +++ b/src/Microsoft.Net.WebSockets/Utilities.cs @@ -22,6 +22,11 @@ namespace Microsoft.Net.WebSockets // Un/Masks the data in place public static void Mask(int mask, ArraySegment data) { + if (mask == 0) + { + return; + } + byte[] maskBytes = new byte[] { (byte)(mask >> 24),