diff --git a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs index e1795dbef2..2edd8f6e47 100644 --- a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -36,6 +36,7 @@ namespace Microsoft.Net.WebSockets private FrameHeader _frameInProgress; private long _frameBytesRemaining; private int? _firstDataOpCode; + private int _dataUnmaskOffset; public CommonWebSocket(Stream stream, string subProtocol, TimeSpan keepAliveInterval, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput) { @@ -242,9 +243,8 @@ namespace Microsoft.Net.WebSockets if (_unmaskInput) { - // TODO: mask alignment may be off between reads. // _frameInProgress.Masked == _unmaskInput already verified - Utilities.MaskInPlace(_frameInProgress.MaskKey, new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy)); + Utilities.MaskInPlace(_frameInProgress.MaskKey, ref _dataUnmaskOffset, new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy)); } WebSocketReceiveResult result; @@ -257,6 +257,7 @@ namespace Microsoft.Net.WebSockets _firstDataOpCode = null; } _frameInProgress = null; + _dataUnmaskOffset = 0; } else { diff --git a/src/Microsoft.Net.WebSockets/Utilities.cs b/src/Microsoft.Net.WebSockets/Utilities.cs index a4fdcfda4b..c478b4140e 100644 --- a/src/Microsoft.Net.WebSockets/Utilities.cs +++ b/src/Microsoft.Net.WebSockets/Utilities.cs @@ -16,8 +16,13 @@ namespace Microsoft.Net.WebSockets return frame; } - // Un/Masks the data in place public static void MaskInPlace(int mask, ArraySegment data) + { + int maskOffset = 0; + MaskInPlace(mask, ref maskOffset, data); + } + + public static void MaskInPlace(int mask, ref int maskOffset, ArraySegment data) { if (mask == 0) { @@ -31,7 +36,6 @@ namespace Microsoft.Net.WebSockets (byte)(mask >> 8), (byte)mask, }; - int maskOffset = 0; int end = data.Offset + data.Count; for (int i = data.Offset; i < end; i++) diff --git a/test/Microsoft.Net.WebSockets.Test/BufferStream.cs b/test/Microsoft.Net.WebSockets.Test/BufferStream.cs new file mode 100644 index 0000000000..de2c308f14 --- /dev/null +++ b/test/Microsoft.Net.WebSockets.Test/BufferStream.cs @@ -0,0 +1,337 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +using System; +using System.Collections.Concurrent; +using System.Diagnostics.CodeAnalysis; +using System.Diagnostics.Contracts; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Net.WebSockets.Test +{ + // This steam accepts writes from one side, buffers them internally, and returns the data via Reads + // when requested on the other side. + public class BufferStream : Stream + { + private bool _disposed; + private bool _aborted; + private Exception _abortException; + private ConcurrentQueue _bufferedData; + private ArraySegment _topBuffer; + private SemaphoreSlim _readLock; + private SemaphoreSlim _writeLock; + private TaskCompletionSource _readWaitingForData; + + internal BufferStream() + { + _readLock = new SemaphoreSlim(1, 1); + _writeLock = new SemaphoreSlim(1, 1); + _bufferedData = new ConcurrentQueue(); + _readWaitingForData = new TaskCompletionSource(); + } + + public override bool CanRead + { + get { return true; } + } + + public override bool CanSeek + { + get { return false; } + } + + public override bool CanWrite + { + get { return true; } + } + + #region NotSupported + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + #endregion NotSupported + + public override void Flush() + { + CheckDisposed(); + // TODO: Wait for data to drain? + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + if (cancellationToken.IsCancellationRequested) + { + TaskCompletionSource tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; + } + + Flush(); + + // TODO: Wait for data to drain? + + return Task.FromResult(0); + } + + public override int Read(byte[] buffer, int offset, int count) + { + VerifyBuffer(buffer, offset, count, allowEmpty: false); + _readLock.Wait(); + try + { + int totalRead = 0; + do + { + // Don't drain buffered data when signaling an abort. + CheckAborted(); + if (_topBuffer.Count <= 0) + { + byte[] topBuffer = null; + while (!_bufferedData.TryDequeue(out topBuffer)) + { + if (_disposed) + { + CheckAborted(); + // Graceful close + return totalRead; + } + WaitForDataAsync().Wait(); + } + _topBuffer = new ArraySegment(topBuffer); + } + int actualCount = Math.Min(count, _topBuffer.Count); + Buffer.BlockCopy(_topBuffer.Array, _topBuffer.Offset, buffer, offset, actualCount); + _topBuffer = new ArraySegment(_topBuffer.Array, + _topBuffer.Offset + actualCount, + _topBuffer.Count - actualCount); + totalRead += actualCount; + offset += actualCount; + count -= actualCount; + } + while (count > 0 && (_topBuffer.Count > 0 || _bufferedData.Count > 0)); + // Keep reading while there is more data available and we have more space to put it in. + return totalRead; + } + finally + { + _readLock.Release(); + } + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + // TODO: This option doesn't preserve the state object. + // return ReadAsync(buffer, offset, count); + return base.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + // return ((Task)asyncResult).Result; + return base.EndRead(asyncResult); + } + + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + VerifyBuffer(buffer, offset, count, allowEmpty: false); + CancellationTokenRegistration registration = cancellationToken.Register(Abort); + await _readLock.WaitAsync(cancellationToken); + try + { + int totalRead = 0; + do + { + // Don't drained buffered data on abort. + CheckAborted(); + if (_topBuffer.Count <= 0) + { + byte[] topBuffer = null; + while (!_bufferedData.TryDequeue(out topBuffer)) + { + if (_disposed) + { + CheckAborted(); + // Graceful close + return totalRead; + } + await WaitForDataAsync(); + } + _topBuffer = new ArraySegment(topBuffer); + } + int actualCount = Math.Min(count, _topBuffer.Count); + Buffer.BlockCopy(_topBuffer.Array, _topBuffer.Offset, buffer, offset, actualCount); + _topBuffer = new ArraySegment(_topBuffer.Array, + _topBuffer.Offset + actualCount, + _topBuffer.Count - actualCount); + totalRead += actualCount; + offset += actualCount; + count -= actualCount; + } + while (count > 0 && (_topBuffer.Count > 0 || _bufferedData.Count > 0)); + // Keep reading while there is more data available and we have more space to put it in. + return totalRead; + } + finally + { + registration.Dispose(); + _readLock.Release(); + } + } + + // Write with count 0 will still trigger OnFirstWrite + public override void Write(byte[] buffer, int offset, int count) + { + VerifyBuffer(buffer, offset, count, allowEmpty: true); + CheckDisposed(); + + _writeLock.Wait(); + try + { + if (count == 0) + { + return; + } + // Copies are necessary because we don't know what the caller is going to do with the buffer afterwards. + byte[] internalBuffer = new byte[count]; + Buffer.BlockCopy(buffer, offset, internalBuffer, 0, count); + _bufferedData.Enqueue(internalBuffer); + + SignalDataAvailable(); + } + finally + { + _writeLock.Release(); + } + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + Write(buffer, offset, count); + TaskCompletionSource tcs = new TaskCompletionSource(state); + tcs.TrySetResult(null); + IAsyncResult result = tcs.Task; + if (callback != null) + { + callback(result); + } + return result; + } + + public override void EndWrite(IAsyncResult asyncResult) + { + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + VerifyBuffer(buffer, offset, count, allowEmpty: true); + if (cancellationToken.IsCancellationRequested) + { + TaskCompletionSource tcs = new TaskCompletionSource(); + tcs.TrySetCanceled(); + return tcs.Task; + } + + Write(buffer, offset, count); + return Task.FromResult(null); + } + + private static void VerifyBuffer(byte[] buffer, int offset, int count, bool allowEmpty) + { + if (buffer == null) + { + throw new ArgumentNullException("buffer"); + } + if (offset < 0 || offset > buffer.Length) + { + throw new ArgumentOutOfRangeException("offset", offset, string.Empty); + } + if (count < 0 || count > buffer.Length - offset + || (!allowEmpty && count == 0)) + { + throw new ArgumentOutOfRangeException("count", count, string.Empty); + } + } + + private void SignalDataAvailable() + { + // Dispatch, as TrySetResult will synchronously execute the waiters callback and block our Write. + Task.Factory.StartNew(() => _readWaitingForData.TrySetResult(null)); + } + + private Task WaitForDataAsync() + { + _readWaitingForData = new TaskCompletionSource(); + + if (!_bufferedData.IsEmpty || _disposed) + { + // Race, data could have arrived before we created the TCS. + _readWaitingForData.TrySetResult(null); + } + + return _readWaitingForData.Task; + } + + internal void Abort() + { + Abort(new OperationCanceledException()); + } + + internal void Abort(Exception innerException) + { + Contract.Requires(innerException != null); + _aborted = true; + _abortException = innerException; + Dispose(); + } + + private void CheckAborted() + { + if (_aborted) + { + throw new IOException(string.Empty, _abortException); + } + } + + [SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_writeLock", Justification = "ODEs from the locks would mask IOEs from abort.")] + [SuppressMessage("Microsoft.Usage", "CA2213:DisposableFieldsShouldBeDisposed", MessageId = "_readLock", Justification = "Data can still be read unless we get aborted.")] + protected override void Dispose(bool disposing) + { + if (disposing) + { + // Throw for further writes, but not reads. Allow reads to drain the buffered data and then return 0 for further reads. + _disposed = true; + _readWaitingForData.TrySetResult(null); + } + + base.Dispose(disposing); + } + + private void CheckDisposed() + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + } + } +} diff --git a/test/Microsoft.Net.WebSockets.Test/DuplexStream.cs b/test/Microsoft.Net.WebSockets.Test/DuplexStream.cs new file mode 100644 index 0000000000..7cd19c17f8 --- /dev/null +++ b/test/Microsoft.Net.WebSockets.Test/DuplexStream.cs @@ -0,0 +1,172 @@ +// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.txt in the project root for license information. + +using System; +using System.IO; +using System.Threading; +using System.Threading.Tasks; + +namespace Microsoft.Net.WebSockets.Test +{ + // A duplex wrapper around a read and write stream. + public class DuplexStream : Stream + { + private readonly Stream _readStream; + private readonly Stream _writeStream; + + public DuplexStream() + : this (new BufferStream(), new BufferStream()) + { + } + + public DuplexStream(Stream readStream, Stream writeStream) + { + _readStream = readStream; + _writeStream = writeStream; + } + + public DuplexStream CreateReverseDuplexStream() + { + return new DuplexStream(_writeStream, _readStream); + } + +#region Properties + + public override bool CanRead + { + get { return _readStream.CanRead; } + } + + public override bool CanSeek + { + get { return false; } + } + + public override bool CanTimeout + { + get { return _readStream.CanTimeout || _writeStream.CanTimeout; } + } + + public override bool CanWrite + { + get { return _writeStream.CanWrite; } + } + + public override long Length + { + get { throw new NotSupportedException(); } + } + + public override long Position + { + get { throw new NotSupportedException(); } + set { throw new NotSupportedException(); } + } + + public override int ReadTimeout + { + get { return _readStream.ReadTimeout; } + set { _readStream.ReadTimeout = value; } + } + + public override int WriteTimeout + { + get { return _writeStream.WriteTimeout; } + set { _writeStream.WriteTimeout = value; } + } + +#endregion Properties + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + +#region Read + + public override int Read(byte[] buffer, int offset, int count) + { + return _readStream.Read(buffer, offset, count); + } + + public override int ReadByte() + { + return _readStream.ReadByte(); + } + + public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _readStream.BeginRead(buffer, offset, count, callback, state); + } + + public override int EndRead(IAsyncResult asyncResult) + { + return _readStream.EndRead(asyncResult); + } + + public override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _readStream.ReadAsync(buffer, offset, count, cancellationToken); + } + + public override Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken) + { + return _readStream.CopyToAsync(destination, bufferSize, cancellationToken); + } + +#endregion Read + +#region Write + + public override void Write(byte[] buffer, int offset, int count) + { + _writeStream.Write(buffer, offset, count); + } + + public override void WriteByte(byte value) + { + _writeStream.WriteByte(value); + } + + public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, AsyncCallback callback, object state) + { + return _writeStream.BeginWrite(buffer, offset, count, callback, state); + } + + public override void EndWrite(IAsyncResult asyncResult) + { + _writeStream.EndWrite(asyncResult); + } + + public override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + return _writeStream.WriteAsync(buffer, offset, count, cancellationToken); + } + + public override void Flush() + { + _writeStream.Flush(); + } + + public override Task FlushAsync(CancellationToken cancellationToken) + { + return _writeStream.FlushAsync(cancellationToken); + } + +#endregion Write + + protected override void Dispose(bool disposing) + { + if (disposing) + { + _readStream.Dispose(); + _writeStream.Dispose(); + } + base.Dispose(disposing); + } + } +} diff --git a/test/Microsoft.Net.WebSockets.Test/DuplexTests.cs b/test/Microsoft.Net.WebSockets.Test/DuplexTests.cs new file mode 100644 index 0000000000..117c233e88 --- /dev/null +++ b/test/Microsoft.Net.WebSockets.Test/DuplexTests.cs @@ -0,0 +1,63 @@ +using System; +using System.Net.WebSockets; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Xunit; + +namespace Microsoft.Net.WebSockets.Test +{ + public class DuplexTests + { + [Fact] + public async Task SendAndReceive() + { + DuplexStream serverStream = new DuplexStream(); + DuplexStream clientStream = serverStream.CreateReverseDuplexStream(); + + WebSocket serverWebSocket = CommonWebSocket.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024); + WebSocket clientWebSocket = CommonWebSocket.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024, false); + + byte[] clientBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz"); + byte[] serverBuffer = new byte[clientBuffer.Length]; + + await clientWebSocket.SendAsync(new ArraySegment(clientBuffer), WebSocketMessageType.Text, true, CancellationToken.None); + WebSocketReceiveResult serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.True(serverResult.EndOfMessage); + Assert.Equal(clientBuffer.Length, serverResult.Count); + Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType); + Assert.Equal(clientBuffer, serverBuffer); + } + + [Fact] + // Tests server unmasking with offset masks + public async Task ServerReceiveOffsetData() + { + DuplexStream serverStream = new DuplexStream(); + DuplexStream clientStream = serverStream.CreateReverseDuplexStream(); + + WebSocket serverWebSocket = CommonWebSocket.CreateServerWebSocket(serverStream, null, TimeSpan.FromMinutes(2), 1024); + WebSocket clientWebSocket = CommonWebSocket.CreateClientWebSocket(clientStream, null, TimeSpan.FromMinutes(2), 1024, false); + + byte[] clientBuffer = Encoding.ASCII.GetBytes("abcdefghijklmnopqrstuvwxyz"); + byte[] serverBuffer = new byte[clientBuffer.Length]; + + await clientWebSocket.SendAsync(new ArraySegment(clientBuffer), WebSocketMessageType.Text, true, CancellationToken.None); + WebSocketReceiveResult serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment(serverBuffer, 0, 3), CancellationToken.None); + Assert.False(serverResult.EndOfMessage); + Assert.Equal(3, serverResult.Count); + Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType); + + serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment(serverBuffer, 3, 10), CancellationToken.None); + Assert.False(serverResult.EndOfMessage); + Assert.Equal(10, serverResult.Count); + Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType); + + serverResult = await serverWebSocket.ReceiveAsync(new ArraySegment(serverBuffer, 13, 13), CancellationToken.None); + Assert.True(serverResult.EndOfMessage); + Assert.Equal(13, serverResult.Count); + Assert.Equal(WebSocketMessageType.Text, serverResult.MessageType); + Assert.Equal(clientBuffer, serverBuffer); + } + } +} diff --git a/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj b/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj index fcfd99f1a4..25f16b963c 100644 --- a/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj +++ b/test/Microsoft.Net.WebSockets.Test/Microsoft.Net.WebSockets.Test.csproj @@ -45,6 +45,9 @@ + + +