From 678af7c22fd21d2f35b57e138408c78c32c5766b Mon Sep 17 00:00:00 2001 From: Chris Ross Date: Sat, 8 Mar 2014 15:14:54 -0800 Subject: [PATCH] Implement KeepAliveInterval, send pings. --- .../CommonWebSocket.cs | 80 +++++++++++++++++-- .../WebSocketClient.cs | 11 ++- 2 files changed, 84 insertions(+), 7 deletions(-) diff --git a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs index 979ae18606..e1795dbef2 100644 --- a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -12,6 +12,7 @@ namespace Microsoft.Net.WebSockets public class CommonWebSocket : WebSocket { private readonly static Random Random = new Random(); + private readonly static byte[] PingBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz"); private readonly Stream _stream; private readonly string _subProtocl; @@ -19,6 +20,7 @@ namespace Microsoft.Net.WebSockets private readonly bool _unmaskInput; private readonly bool _useZeroMask; private readonly SemaphoreSlim _writeLock; + private readonly Timer _keepAliveTimer; private WebSocketState _state; @@ -35,7 +37,7 @@ namespace Microsoft.Net.WebSockets private long _frameBytesRemaining; private int? _firstDataOpCode; - public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput) + public CommonWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput) { _stream = stream; _subProtocl = subProtocol; @@ -45,16 +47,20 @@ namespace Microsoft.Net.WebSockets _useZeroMask = useZeroMask; _unmaskInput = unmaskInput; _writeLock = new SemaphoreSlim(1); + if (keepAliveInterval != Timeout.InfiniteTimeSpan) + { + _keepAliveTimer = new Timer(SendKeepAlive, this, keepAliveInterval, keepAliveInterval); + } } - public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool useZeroMask) + public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool useZeroMask) { - return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false); + return new CommonWebSocket(stream, subProtocol, keepAliveInterval, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false); } - public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, int receiveBufferSize) + public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize) { - return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true); + return new CommonWebSocket(stream, subProtocol, keepAliveInterval, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true); } public override WebSocketCloseStatus? CloseStatus @@ -140,6 +146,58 @@ namespace Microsoft.Net.WebSockets } } + private static void SendKeepAlive(object state) + { + CommonWebSocket websocket = (CommonWebSocket)state; + websocket.SendKeepAliveAsync(); + } + + private async void SendKeepAliveAsync() + { + // Check concurrent writes, pings & pongs, or closes + bool lockAquired = await _writeLock.WaitAsync(TimeSpan.FromMinutes(1)); // TODO: Wait up to KeepAliveInterval? + if (!lockAquired) + { + // Pings aren't that important, discard them if we can't take the lock. + return; + } + try + { + if (State == WebSocketState.CloseSent || State >= WebSocketState.Closed) + { + _keepAliveTimer.Dispose(); + return; + } + + int mask = GetNextMask(); + FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.PingFrame, _maskOutput, mask, PingBuffer.Length); + ArraySegment headerSegment = frameHeader.Buffer; + + // TODO: CancelationToken / timeout? + if (_maskOutput && mask != 0) + { + byte[] maskedFrame = Utilities.MergeAndMask(mask, headerSegment, new ArraySegment(PingBuffer)); + await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length); + } + else + { + await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count); + await _stream.WriteAsync(PingBuffer, 0, PingBuffer.Length); + } + } + catch (Exception) + { + // TODO: Log exception, this is a background thread. + + // Shut down, we must be in a faulted state; + Abort(); + } + finally + { + _writeLock.Release(); + } + } + public async override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) { ThrowIfDisposed(); @@ -376,6 +434,10 @@ namespace Microsoft.Net.WebSockets { ThrowIfDisposed(); ThrowIfOutputClosed(); + if (_keepAliveTimer != null) + { + _keepAliveTimer.Dispose(); + } byte[] descriptionBytes = Encoding.UTF8.GetBytes(statusDescription ?? string.Empty); byte[] fullData = new byte[descriptionBytes.Length + 2]; @@ -419,6 +481,10 @@ namespace Microsoft.Net.WebSockets } _state = WebSocketState.Aborted; + if (_keepAliveTimer != null) + { + _keepAliveTimer.Dispose(); + } _stream.Dispose(); } @@ -430,6 +496,10 @@ namespace Microsoft.Net.WebSockets } _state = WebSocketState.Closed; + if (_keepAliveTimer != null) + { + _keepAliveTimer.Dispose(); + } _stream.Dispose(); } diff --git a/src/Microsoft.Net.WebSockets/WebSocketClient.cs b/src/Microsoft.Net.WebSockets/WebSocketClient.cs index a5c5c666a4..60636a05f0 100644 --- a/src/Microsoft.Net.WebSockets/WebSocketClient.cs +++ b/src/Microsoft.Net.WebSockets/WebSocketClient.cs @@ -24,7 +24,14 @@ namespace Microsoft.Net.WebSockets.Client public WebSocketClient() { - ReceiveBufferSize = 1024 * 64; + ReceiveBufferSize = 1024 * 16; + KeepAliveInterval = TimeSpan.FromMinutes(2); + } + + public TimeSpan KeepAliveInterval + { + get; + set; } public int ReceiveBufferSize @@ -85,7 +92,7 @@ namespace Microsoft.Net.WebSockets.Client Stream stream = response.GetResponseStream(); - return CommonWebSocket.CreateClientWebSocket(stream, null, ReceiveBufferSize, useZeroMask: UseZeroMask); + return CommonWebSocket.CreateClientWebSocket(stream, null, KeepAliveInterval, ReceiveBufferSize, useZeroMask: UseZeroMask); } } }