Handle unmasking offset data.

This commit is contained in:
Chris Ross 2014-03-08 23:01:53 -08:00
parent 678af7c22f
commit 65532849f6
6 changed files with 584 additions and 4 deletions

View File

@ -36,6 +36,7 @@ namespace Microsoft.Net.WebSockets
private FrameHeader _frameInProgress;
private long _frameBytesRemaining;
private int? _firstDataOpCode;
private int _dataUnmaskOffset;
public CommonWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput)
{
@ -242,9 +243,8 @@ namespace Microsoft.Net.WebSockets
if (_unmaskInput)
{
// TODO: mask alignment may be off between reads.
// _frameInProgress.Masked == _unmaskInput already verified
Utilities.MaskInPlace(_frameInProgress.MaskKey, new ArraySegment<byte>(buffer.Array, buffer.Offset, bytesToCopy));
Utilities.MaskInPlace(_frameInProgress.MaskKey, ref _dataUnmaskOffset, new ArraySegment<byte>(buffer.Array, buffer.Offset, bytesToCopy));
}
WebSocketReceiveResult result;
@ -257,6 +257,7 @@ namespace Microsoft.Net.WebSockets
_firstDataOpCode = null;
}
_frameInProgress = null;
_dataUnmaskOffset = 0;
}
else
{

View File

@ -16,8 +16,13 @@ namespace Microsoft.Net.WebSockets
return frame;
}
// Un/Masks the data in place
public static void MaskInPlace(int mask, ArraySegment<byte> data)
{
int maskOffset = 0;
MaskInPlace(mask, ref maskOffset, data);
}
public static void MaskInPlace(int mask, ref int maskOffset, ArraySegment<byte> data)
{
if (mask == 0)
{
@ -31,7 +36,6 @@ namespace Microsoft.Net.WebSockets
(byte)(mask >> 8),
(byte)mask,
};
int maskOffset = 0;
int end = data.Offset + data.Count;
for (int i = data.Offset; i < end; i++)

View File

@ -0,0 +1,337 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Contracts;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Net.WebSockets.Test
{
// This steam accepts writes from one side, buffers them internally, and returns the data via Reads
// when requested on the other side.
public class BufferStream : Stream
{
private bool _disposed;
private bool _aborted;
private Exception _abortException;
private ConcurrentQueue<byte[]> _bufferedData;
private ArraySegment<byte> _topBuffer;
private SemaphoreSlim _readLock;
private SemaphoreSlim _writeLock;
private TaskCompletionSource<object> _readWaitingForData;
internal BufferStream()
{
_readLock = new SemaphoreSlim(1, 1);
_writeLock = new SemaphoreSlim(1, 1);
_bufferedData = new ConcurrentQueue<byte[]>();
_readWaitingForData = new TaskCompletionSource<object>();
}
public override bool CanRead
{
get { return true; }
}
public override bool CanSeek
{
get { return false; }
}
public override bool CanWrite
{
get { return true; }
}
#region NotSupported
public override long Length
{
get { throw new NotSupportedException(); }
}
public override long Position
{
get { throw new NotSupportedException(); }
set { throw new NotSupportedException(); }
}
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
#endregion NotSupported
public override void Flush()
{
CheckDisposed();
// TODO: Wait for data to drain?
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
if (cancellationToken.IsCancellationRequested)
{
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>();
tcs.TrySetCanceled();
return tcs.Task;
}
Flush();
// TODO: Wait for data to drain?
return Task.FromResult(0);
}
public override int Read(byte[] buffer, int offset, int count)
{
VerifyBuffer(buffer, offset, count, allowEmpty: false);
_readLock.Wait();
try
{
int totalRead = 0;
do
{
// Don't drain buffered data when signaling an abort.
CheckAborted();
if (_topBuffer.Count <= 0)
{
byte[] topBuffer = null;
while (!_bufferedData.TryDequeue(out topBuffer))
{
if (_disposed)
{
CheckAborted();
// Graceful close
return totalRead;
}
WaitForDataAsync().Wait();
}
_topBuffer = new ArraySegment<byte>(topBuffer);
}
int actualCount = Math.Min(count, _topBuffer.Count);
Buffer.BlockCopy(_topBuffer.Array, _topBuffer.Offset, buffer, offset, actualCount);
_topBuffer = new ArraySegment<byte>(_topBuffer.Array,
_topBuffer.Offset + actualCount,
_topBuffer.Count - actualCount);
totalRead += actualCount;
offset += actualCount;
count -= actualCount;
}
while (count > 0 && (_topBuffer.Count > 0 || _bufferedData.Count > 0));
// Keep reading while there is more data available and we have more space to put it in.
return totalRead;
}
finally
{
_readLock.Release();
}
}
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
// TODO: This option doesn't preserve the state object.
// return ReadAsync(buffer, offset, count);
return base.BeginRead(buffer, offset, count, callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
// return ((Task<int>)asyncResult).Result;
return base.EndRead(asyncResult);
}
public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
VerifyBuffer(buffer, offset, count, allowEmpty: false);
CancellationTokenRegistration registration = cancellationToken.Register(Abort);
await _readLock.WaitAsync(cancellationToken);
try
{
int totalRead = 0;
do
{
// Don't drained buffered data on abort.
CheckAborted();
if (_topBuffer.Count <= 0)
{
byte[] topBuffer = null;
while (!_bufferedData.TryDequeue(out topBuffer))
{
if (_disposed)
{
CheckAborted();
// Graceful close
return totalRead;
}
await WaitForDataAsync();
}
_topBuffer = new ArraySegment<byte>(topBuffer);
}
int actualCount = Math.Min(count, _topBuffer.Count);
Buffer.BlockCopy(_topBuffer.Array, _topBuffer.Offset, buffer, offset, actualCount);
_topBuffer = new ArraySegment<byte>(_topBuffer.Array,
_topBuffer.Offset + actualCount,
_topBuffer.Count - actualCount);
totalRead += actualCount;
offset += actualCount;
count -= actualCount;
}
while (count > 0 && (_topBuffer.Count > 0 || _bufferedData.Count > 0));
// Keep reading while there is more data available and we have more space to put it in.
return totalRead;
}
finally
{
registration.Dispose();
_readLock.Release();
}
}
// Write with count 0 will still trigger OnFirstWrite
public override void Write(byte[] buffer, int offset, int count)
{
VerifyBuffer(buffer, offset, count, allowEmpty: true);
CheckDisposed();
_writeLock.Wait();
try
{
if (count == 0)
{
return;
}
// Copies are necessary because we don't know what the caller is going to do with the buffer afterwards.
byte[] internalBuffer = new byte[count];
Buffer.BlockCopy(buffer, offset, internalBuffer, 0, count);
_bufferedData.Enqueue(internalBuffer);
SignalDataAvailable();
}
finally
{
_writeLock.Release();
}
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
Write(buffer, offset, count);
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>(state);
tcs.TrySetResult(null);
IAsyncResult result = tcs.Task;
if (callback != null)
{
callback(result);
}
return result;
}
public override void EndWrite(IAsyncResult asyncResult)
{
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
VerifyBuffer(buffer, offset, count, allowEmpty: true);
if (cancellationToken.IsCancellationRequested)
{
TaskCompletionSource<object> tcs = new TaskCompletionSource<object>();
tcs.TrySetCanceled();
return tcs.Task;
}
Write(buffer, offset, count);
return Task.FromResult<object>(null);
}
private static void VerifyBuffer(byte[] buffer, int offset, int count, bool allowEmpty)
{
if (buffer == null)
{
throw new ArgumentNullException("buffer");
}
if (offset < 0 || offset > buffer.Length)
{
throw new ArgumentOutOfRangeException("offset", offset, string.Empty);
}
if (count < 0 || count > buffer.Length - offset
|| (!allowEmpty && count == 0))
{
throw new ArgumentOutOfRangeException("count", count, string.Empty);
}
}
private void SignalDataAvailable()
{
// Dispatch, as TrySetResult will synchronously execute the waiters callback and block our Write.
Task.Factory.StartNew(() => _readWaitingForData.TrySetResult(null));
}
private Task WaitForDataAsync()
{
_readWaitingForData = new TaskCompletionSource<object>();
if (!_bufferedData.IsEmpty || _disposed)
{
// Race, data could have arrived before we created the TCS.
_readWaitingForData.TrySetResult(null);
}
return _readWaitingForData.Task;
}
internal void Abort()
{
Abort(new OperationCanceledException());
}
internal void Abort(Exception innerException)
{
Contract.Requires(innerException != null);
_aborted = true;
_abortException = innerException;
Dispose();
}
private void CheckAborted()
{
if (_aborted)
{
throw new IOException(string.Empty, _abortException);
}
}
[SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_writeLock", Justification = "ODEs from the locks would mask IOEs from abort.")]
[SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_readLock", Justification = "Data can still be read unless we get aborted.")]
protected override void Dispose(bool disposing)
{
if (disposing)
{
// Throw for further writes, but not reads. Allow reads to drain the buffered data and then return 0 for further reads.
_disposed = true;
_readWaitingForData.TrySetResult(null);
}
base.Dispose(disposing);
}
private void CheckDisposed()
{
if (_disposed)
{
throw new ObjectDisposedException(GetType().FullName);
}
}
}
}

View File

@ -0,0 +1,172 @@
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.Net.WebSockets.Test
{
// A duplex wrapper around a read and write stream.
public class DuplexStream : Stream
{
private readonly Stream _readStream;
private readonly Stream _writeStream;
public DuplexStream()
: this (new BufferStream(), new BufferStream())
{
}
public DuplexStream(Stream readStream, Stream writeStream)
{
_readStream = readStream;
_writeStream = writeStream;
}
public DuplexStream CreateReverseDuplexStream()
{
return new DuplexStream(_writeStream, _readStream);
}
#region Properties
public override bool CanRead
{
get { return _readStream.CanRead; }
}
public override bool CanSeek
{
get { return false; }
}
public override bool CanTimeout
{
get { return _readStream.CanTimeout || _writeStream.CanTimeout; }
}
public override bool CanWrite
{
get { return _writeStream.CanWrite; }
}
public override long Length
{
get { throw new NotSupportedException(); }
}
public override long Position
{
get { throw new NotSupportedException(); }
set { throw new NotSupportedException(); }
}
public override int ReadTimeout
{
get { return _readStream.ReadTimeout; }
set { _readStream.ReadTimeout = value; }
}
public override int WriteTimeout
{
get { return _writeStream.WriteTimeout; }
set { _writeStream.WriteTimeout = value; }
}
#endregion Properties
public override long Seek(long offset, SeekOrigin origin)
{
throw new NotSupportedException();
}
public override void SetLength(long value)
{
throw new NotSupportedException();
}
#region Read
public override int Read(byte[] buffer, int offset, int count)
{
return _readStream.Read(buffer, offset, count);
}
public override int ReadByte()
{
return _readStream.ReadByte();
}
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return _readStream.BeginRead(buffer, offset, count, callback, state);
}
public override int EndRead(IAsyncResult asyncResult)
{
return _readStream.EndRead(asyncResult);
}
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _readStream.ReadAsync(buffer, offset, count, cancellationToken);
}
public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
{
return _readStream.CopyToAsync(destination, bufferSize, cancellationToken);
}
#endregion Read
#region Write
public override void Write(byte[] buffer, int offset, int count)
{
_writeStream.Write(buffer, offset, count);
}
public override void WriteByte(byte value)
{
_writeStream.WriteByte(value);
}
public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
{
return _writeStream.BeginWrite(buffer, offset, count, callback, state);
}
public override void EndWrite(IAsyncResult asyncResult)
{
_writeStream.EndWrite(asyncResult);
}
public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
return _writeStream.WriteAsync(buffer, offset, count, cancellationToken);
}
public override void Flush()
{
_writeStream.Flush();
}
public override Task FlushAsync(CancellationToken cancellationToken)
{
return _writeStream.FlushAsync(cancellationToken);
}
#endregion Write
protected override void Dispose(bool disposing)
{
if (disposing)
{
_readStream.Dispose();
_writeStream.Dispose();
}
base.Dispose(disposing);
}
}
}

View File

@ -0,0 +1,63 @@
using System;
using System.Net.WebSockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
namespace Microsoft.Net.WebSockets.Test
{
public class DuplexTests
{
[Fact]
public async Task SendAndReceive()
{
DuplexStream serverStream = new DuplexStream();
DuplexStream clientStream = serverStream.CreateReverseDuplexStream();
WebSocket serverWebSocket = CommonWebSocket.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024);
WebSocket clientWebSocket = CommonWebSocket.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024, false);
byte[] clientBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz");
byte[] serverBuffer = new byte[clientBuffer.Length];
await clientWebSocket.SendAsync(new ArraySegment<byte>(clientBuffer), WebSocketMessageType.Text, true, CancellationToken.None);
WebSocketReceiveResult serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment<byte>(serverBuffer), CancellationToken.None);
Assert.True(serverResult.EndOfMessage);
Assert.Equal(clientBuffer.Length, serverResult.Count);
Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType);
Assert.Equal(clientBuffer, serverBuffer);
}
[Fact]
// Tests server unmasking with offset masks
public async Task ServerReceiveOffsetData()
{
DuplexStream serverStream = new DuplexStream();
DuplexStream clientStream = serverStream.CreateReverseDuplexStream();
WebSocket serverWebSocket = CommonWebSocket.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024);
WebSocket clientWebSocket = CommonWebSocket.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024, false);
byte[] clientBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz");
byte[] serverBuffer = new byte[clientBuffer.Length];
await clientWebSocket.SendAsync(new ArraySegment<byte>(clientBuffer), WebSocketMessageType.Text, true, CancellationToken.None);
WebSocketReceiveResult serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment<byte>(serverBuffer, 0, 3), CancellationToken.None);
Assert.False(serverResult.EndOfMessage);
Assert.Equal(3, serverResult.Count);
Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType);
serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment<byte>(serverBuffer, 3, 10), CancellationToken.None);
Assert.False(serverResult.EndOfMessage);
Assert.Equal(10, serverResult.Count);
Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType);
serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment<byte>(serverBuffer, 13, 13), CancellationToken.None);
Assert.True(serverResult.EndOfMessage);
Assert.Equal(13, serverResult.Count);
Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType);
Assert.Equal(clientBuffer, serverBuffer);
}
}
}

View File

@ -45,6 +45,9 @@
</Reference>
</ItemGroup>
<ItemGroup>
<Compile Include="DuplexTests.cs" />
<Compile Include="BufferStream.cs" />
<Compile Include="DuplexStream.cs" />
<Compile Include="UtilitiesTests.cs" />
<Compile Include="WebSocketClientTests.cs" />
<Compile Include="Properties\AssemblyInfo.cs" />