Implement KeepAliveInterval, send pings.

This commit is contained in:
Chris Ross 2014-03-08 15:14:54 -08:00
parent 30ca12933e
commit 678af7c22f
2 changed files with 84 additions and 7 deletions

View File

@ -12,6 +12,7 @@ namespace Microsoft.Net.WebSockets
public class CommonWebSocket : WebSocket
{
private readonly static Random Random = new Random();
private readonly static byte[] PingBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz");
private readonly Stream _stream;
private readonly string _subProtocl;
@ -19,6 +20,7 @@ namespace Microsoft.Net.WebSockets
private readonly bool _unmaskInput;
private readonly bool _useZeroMask;
private readonly SemaphoreSlim _writeLock;
private readonly Timer _keepAliveTimer;
private WebSocketState _state;
@ -35,7 +37,7 @@ namespace Microsoft.Net.WebSockets
private long _frameBytesRemaining;
private int? _firstDataOpCode;
public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput)
public CommonWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput)
{
_stream = stream;
_subProtocl = subProtocol;
@ -45,16 +47,20 @@ namespace Microsoft.Net.WebSockets
_useZeroMask = useZeroMask;
_unmaskInput = unmaskInput;
_writeLock = new SemaphoreSlim(1);
if (keepAliveInterval != Timeout.InfiniteTimeSpan)
{
_keepAliveTimer = new Timer(SendKeepAlive, this, keepAliveInterval, keepAliveInterval);
}
}
public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool useZeroMask)
public static CommonWebSocket CreateClientWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool useZeroMask)
{
return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false);
return new CommonWebSocket(stream, subProtocol, keepAliveInterval, receiveBufferSize, maskOutput: true, useZeroMask: useZeroMask, unmaskInput: false);
}
public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, int receiveBufferSize)
public static CommonWebSocket CreateServerWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize)
{
return new CommonWebSocket(stream, subProtocol, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true);
return new CommonWebSocket(stream, subProtocol, keepAliveInterval, receiveBufferSize, maskOutput: false, useZeroMask: false, unmaskInput: true);
}
public override WebSocketCloseStatus? CloseStatus
@ -140,6 +146,58 @@ namespace Microsoft.Net.WebSockets
}
}
private static void SendKeepAlive(object state)
{
CommonWebSocket websocket = (CommonWebSocket)state;
websocket.SendKeepAliveAsync();
}
private async void SendKeepAliveAsync()
{
// Check concurrent writes, pings & pongs, or closes
bool lockAquired = await _writeLock.WaitAsync(TimeSpan.FromMinutes(1)); // TODO: Wait up to KeepAliveInterval?
if (!lockAquired)
{
// Pings aren't that important, discard them if we can't take the lock.
return;
}
try
{
if (State == WebSocketState.CloseSent || State >= WebSocketState.Closed)
{
_keepAliveTimer.Dispose();
return;
}
int mask = GetNextMask();
FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.PingFrame, _maskOutput, mask, PingBuffer.Length);
ArraySegment<byte> headerSegment = frameHeader.Buffer;
// TODO: CancelationToken / timeout?
if (_maskOutput && mask != 0)
{
byte[] maskedFrame = Utilities.MergeAndMask(mask, headerSegment, new ArraySegment<byte>(PingBuffer));
await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length);
}
else
{
await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count);
await _stream.WriteAsync(PingBuffer, 0, PingBuffer.Length);
}
}
catch (Exception)
{
// TODO: Log exception, this is a background thread.
// Shut down, we must be in a faulted state;
Abort();
}
finally
{
_writeLock.Release();
}
}
public async override Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
{
ThrowIfDisposed();
@ -376,6 +434,10 @@ namespace Microsoft.Net.WebSockets
{
ThrowIfDisposed();
ThrowIfOutputClosed();
if (_keepAliveTimer != null)
{
_keepAliveTimer.Dispose();
}
byte[] descriptionBytes = Encoding.UTF8.GetBytes(statusDescription ?? string.Empty);
byte[] fullData = new byte[descriptionBytes.Length + 2];
@ -419,6 +481,10 @@ namespace Microsoft.Net.WebSockets
}
_state = WebSocketState.Aborted;
if (_keepAliveTimer != null)
{
_keepAliveTimer.Dispose();
}
_stream.Dispose();
}
@ -430,6 +496,10 @@ namespace Microsoft.Net.WebSockets
}
_state = WebSocketState.Closed;
if (_keepAliveTimer != null)
{
_keepAliveTimer.Dispose();
}
_stream.Dispose();
}

View File

@ -24,7 +24,14 @@ namespace Microsoft.Net.WebSockets.Client
public WebSocketClient()
{
ReceiveBufferSize = 1024 * 64;
ReceiveBufferSize = 1024 * 16;
KeepAliveInterval = TimeSpan.FromMinutes(2);
}
public TimeSpan KeepAliveInterval
{
get;
set;
}
public int ReceiveBufferSize
@ -85,7 +92,7 @@ namespace Microsoft.Net.WebSockets.Client
Stream stream = response.GetResponseStream();
return CommonWebSocket.CreateClientWebSocket(stream, null, ReceiveBufferSize, useZeroMask: UseZeroMask);
return CommonWebSocket.CreateClientWebSocket(stream, null, KeepAliveInterval, ReceiveBufferSize, useZeroMask: UseZeroMask);
}
}
}