diff --git a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs index ff6b9f44a9..6930e38c4e 100644 --- a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -1,7 +1,6 @@ using System; -using System.Collections.Generic; +using System.Diagnostics.Contracts; using System.IO; -using System.Linq; using System.Net.WebSockets; using System.Text; using System.Threading; @@ -26,14 +25,14 @@ namespace Microsoft.Net.WebSockets private WebSocketCloseStatus? _closeStatus; private string _closeStatusDescription; - private bool _outgoingMessageInProgress; + private bool _isOutgoingMessageInProgress; private byte[] _receiveBuffer; - private int _receiveOffset; - private int _receiveCount; + private int _receiveBufferOffset; + private int _receiveBufferBytes; private FrameHeader _frameInProgress; - private long _frameBytesRemaining = 0; + private long _frameBytesRemaining; private int? _firstDataOpCode; public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput) @@ -102,31 +101,38 @@ namespace Microsoft.Net.WebSockets public override async Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) { - // TODO: Validate arguments - // TODO: Check state - // TODO: Check concurrent writes - // TODO: Check ping/pong state - // TODO: Masking - // TODO: Block close frame? + ValidateSegment(buffer); + if (messageType != WebSocketMessageType.Binary && messageType != WebSocketMessageType.Text) + { + // Block control frames + throw new ArgumentOutOfRangeException("messageType", messageType, string.Empty); + } + // Check concurrent writes, pings & pongs, or closes await _writeLock.WaitAsync(cancellationToken); - try { + ThrowIfDisposed(); + ThrowIfOutputClosed(); + int mask = GetNextMask(); - FrameHeader frameHeader = new FrameHeader(endOfMessage, _outgoingMessageInProgress ? Constants.OpCodes.ContinuationFrame : GetOpCode(messageType), _maskOutput, mask, buffer.Count); - ArraySegment segment = frameHeader.Buffer; + int opcode = _isOutgoingMessageInProgress ? Constants.OpCodes.ContinuationFrame : Utilities.GetOpCode(messageType); + FrameHeader frameHeader = new FrameHeader(endOfMessage, opcode, _maskOutput, mask, buffer.Count); + ArraySegment headerSegment = frameHeader.Buffer; + if (_maskOutput && mask != 0) { - byte[] maskedFrame = Utilities.MergeAndMask(mask, segment, buffer); + // TODO: For larger messages consider using a limited size buffer and masking & sending in segments. + byte[] maskedFrame = Utilities.MergeAndMask(mask, headerSegment, buffer); await _stream.WriteAsync(maskedFrame, 0, maskedFrame.Length, cancellationToken); } else { - await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); + await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count, cancellationToken); await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); } - _outgoingMessageInProgress = !endOfMessage; + + _isOutgoingMessageInProgress = !endOfMessage; } finally { @@ -134,60 +140,21 @@ namespace Microsoft.Net.WebSockets } } - private int GetOpCode(WebSocketMessageType messageType) - { - switch (messageType) - { - case WebSocketMessageType.Text: return Constants.OpCodes.TextFrame; - case WebSocketMessageType.Binary: return Constants.OpCodes.BinaryFrame; - case WebSocketMessageType.Close: return Constants.OpCodes.CloseFrame; - default: throw new NotImplementedException(messageType.ToString()); - } - } - public async override Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) { - // TODO: Validate arguments - // TODO: Check state - // TODO: Check concurrent reads - // TODO: Check ping/pong state + ThrowIfDisposed(); + ThrowIfInputClosed(); + ValidateSegment(buffer); + // TODO: InvalidOperationException if any receives are currently in progress. - // No active frame + // No active frame. Loop because we may be discarding ping/pong frames. while (_frameInProgress == null) { - await EnsureDataAvailableOrReadAsync(2, cancellationToken); - int frameHeaderSize = FrameHeader.CalculateFrameHeaderSize(_receiveBuffer[_receiveOffset + 1]); - await EnsureDataAvailableOrReadAsync(frameHeaderSize, cancellationToken); - _frameInProgress = new FrameHeader(new ArraySegment(_receiveBuffer, _receiveOffset, frameHeaderSize)); - _receiveOffset += frameHeaderSize; - _receiveCount -= frameHeaderSize; - _frameBytesRemaining = _frameInProgress.DataLength; - - if (_unmaskInput != _frameInProgress.Masked) - { - throw new InvalidOperationException("Unmasking settings out of sync with data."); - } - - // Ping or Pong frames - if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) - { - // Drain it, should be less than 125 bytes - await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, cancellationToken); - - if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame && State == WebSocketState.Open) - { - await SendPongReply(cancellationToken); - } - - _receiveOffset += (int)_frameBytesRemaining; - _receiveCount -= (int)_frameBytesRemaining; - _frameBytesRemaining = 0; - _frameInProgress = null; - } + await ReadNextFrameAsync(cancellationToken); } // Handle fragmentation, remember the first frame type - int opCode = 0; + int opCode = Constants.OpCodes.ContinuationFrame; if (_frameInProgress.OpCode == Constants.OpCodes.BinaryFrame || _frameInProgress.OpCode == Constants.OpCodes.TextFrame || _frameInProgress.OpCode == Constants.OpCodes.CloseFrame) @@ -209,48 +176,18 @@ namespace Microsoft.Net.WebSockets _firstDataOpCode = null; } - WebSocketReceiveResult result; - if (opCode == Constants.OpCodes.CloseFrame) { - // The close message should be less than 125 bytes and fit in the buffer. - await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None); - - // Status code and message are optional - if (_frameBytesRemaining >= 2) - { - ArraySegment dataSegment = new ArraySegment(_receiveBuffer, _receiveOffset + 2, (int)_frameBytesRemaining - 2); - if (_unmaskInput) - { - // In place - Utilities.Mask(_frameInProgress.MaskKey, dataSegment); - } - _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]); - _closeStatusDescription = Encoding.UTF8.GetString(dataSegment.Array, dataSegment.Offset, dataSegment.Count) ?? string.Empty; - } - else - { - _closeStatus = _closeStatus ?? WebSocketCloseStatus.NormalClosure; - _closeStatusDescription = _closeStatusDescription ?? string.Empty; - } - result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, (WebSocketCloseStatus)_closeStatus, _closeStatusDescription); - - if (State == WebSocketState.Open) - { - _state = WebSocketState.CloseReceived; - } - else if (State == WebSocketState.CloseSent) - { - _state = WebSocketState.Closed; - _stream.Dispose(); - } - return result; + return await ProcessCloseFrameAsync(cancellationToken); } + WebSocketReceiveResult result; + + WebSocketMessageType messageType = Utilities.GetMessageType(opCode); if (_frameBytesRemaining == 0) { // End of an empty frame? - result = new WebSocketReceiveResult(0, GetMessageType(opCode), _frameInProgress.Fin); + result = new WebSocketReceiveResult(0, messageType, _frameInProgress.Fin); _frameInProgress = null; return result; } @@ -259,50 +196,116 @@ namespace Microsoft.Net.WebSockets await EnsureDataAvailableOrReadAsync(1, cancellationToken); // Copy buffered data to the users buffer int bytesToRead = (int)Math.Min((long)buffer.Count, _frameBytesRemaining); - int bytesToCopy = Math.Min(bytesToRead, _receiveCount); - Array.Copy(_receiveBuffer, _receiveOffset, buffer.Array, buffer.Offset, bytesToCopy); + int bytesToCopy = Math.Min(bytesToRead, _receiveBufferBytes); + Array.Copy(_receiveBuffer, _receiveBufferOffset, buffer.Array, buffer.Offset, bytesToCopy); if (_unmaskInput) { // TODO: mask alignment may be off between reads. - Utilities.Mask(_frameInProgress.MaskKey, new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy)); + // _frameInProgress.Masked == _unmaskInput already verified + Utilities.MaskInPlace(_frameInProgress.MaskKey, new ArraySegment(buffer.Array, buffer.Offset, bytesToCopy)); } if (bytesToCopy == _frameBytesRemaining) { - result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(opCode), _frameInProgress.Fin); + result = new WebSocketReceiveResult(bytesToCopy, messageType, _frameInProgress.Fin); _frameInProgress = null; } else { - result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(opCode), false); + result = new WebSocketReceiveResult(bytesToCopy, messageType, false); } _frameBytesRemaining -= bytesToCopy; - _receiveCount -= bytesToCopy; - _receiveOffset += bytesToCopy; + _receiveBufferBytes -= bytesToCopy; + _receiveBufferOffset += bytesToCopy; return result; } - // We received a ping, send a pong in reply - private async Task SendPongReply(CancellationToken cancellationToken) + private async Task ReadNextFrameAsync(CancellationToken cancellationToken) { - ArraySegment dataSegment = new ArraySegment(_receiveBuffer, _receiveOffset, (int)_frameBytesRemaining); - if (_unmaskInput) + await EnsureDataAvailableOrReadAsync(2, cancellationToken); + int frameHeaderSize = FrameHeader.CalculateFrameHeaderSize(_receiveBuffer[_receiveBufferOffset + 1]); + await EnsureDataAvailableOrReadAsync(frameHeaderSize, cancellationToken); + _frameInProgress = new FrameHeader(new ArraySegment(_receiveBuffer, _receiveBufferOffset, frameHeaderSize)); + _receiveBufferOffset += frameHeaderSize; + _receiveBufferBytes -= frameHeaderSize; + _frameBytesRemaining = _frameInProgress.DataLength; + + if (_unmaskInput != _frameInProgress.Masked) { - // In place - Utilities.Mask(_frameInProgress.MaskKey, dataSegment); + throw new InvalidOperationException("Unmasking settings out of sync with data."); } - int mask = GetNextMask(); - FrameHeader header = new FrameHeader(true, Constants.OpCodes.PongFrame, _maskOutput, mask, _frameBytesRemaining); - if (_maskOutput) + if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame || _frameInProgress.OpCode == Constants.OpCodes.PongFrame) { - // In place - Utilities.Mask(_frameInProgress.MaskKey, dataSegment); - } + // Drain it, should be less than 125 bytes + await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, cancellationToken); + if (_frameInProgress.OpCode == Constants.OpCodes.PingFrame) + { + await SendPongReplyAsync(cancellationToken); + } + + _receiveBufferOffset += (int)_frameBytesRemaining; + _receiveBufferBytes -= (int)_frameBytesRemaining; + _frameBytesRemaining = 0; + _frameInProgress = null; + } + } + + private async Task EnsureDataAvailableOrReadAsync(int bytesNeeded, CancellationToken cancellationToken) + { + // Adequate buffer space? + Contract.Assert(bytesNeeded <= _receiveBuffer.Length); + + // Insufficient buffered data + while (_receiveBufferBytes < bytesNeeded) + { + cancellationToken.ThrowIfCancellationRequested(); + + int spaceRemaining = _receiveBuffer.Length - (_receiveBufferOffset + _receiveBufferBytes); + if (_receiveBufferOffset > 0 && bytesNeeded > spaceRemaining) + { + // Some data in the buffer, shift down to make room + Array.Copy(_receiveBuffer, _receiveBufferOffset, _receiveBuffer, 0, _receiveBufferBytes); + _receiveBufferOffset = 0; + spaceRemaining = _receiveBuffer.Length - _receiveBufferBytes; + } + // Add to the end + int read = await _stream.ReadAsync(_receiveBuffer, _receiveBufferOffset + _receiveBufferBytes, spaceRemaining, cancellationToken); + if (read == 0) + { + throw new IOException("Unexpected end of stream"); + } + _receiveBufferBytes += read; + } + } + + // We received a ping, send a pong in reply + private async Task SendPongReplyAsync(CancellationToken cancellationToken) + { await _writeLock.WaitAsync(cancellationToken); try { + if (State != WebSocketState.Open) + { + // Output closed, discard the pong. + return; + } + + ArraySegment dataSegment = new ArraySegment(_receiveBuffer, _receiveBufferOffset, (int)_frameBytesRemaining); + if (_unmaskInput) + { + // _frameInProgress.Masked == _unmaskInput already verified + Utilities.MaskInPlace(_frameInProgress.MaskKey, dataSegment); + } + + int mask = GetNextMask(); + FrameHeader header = new FrameHeader(true, Constants.OpCodes.PongFrame, _maskOutput, mask, _frameBytesRemaining); + if (_maskOutput) + { + Utilities.MaskInPlace(mask, dataSegment); + } + ArraySegment headerSegment = header.Buffer; await _stream.WriteAsync(headerSegment.Array, headerSegment.Offset, headerSegment.Count, cancellationToken); await _stream.WriteAsync(dataSegment.Array, dataSegment.Offset, dataSegment.Count, cancellationToken); @@ -313,49 +316,47 @@ namespace Microsoft.Net.WebSockets } } - private async Task EnsureDataAvailableOrReadAsync(int bytes, CancellationToken cancellationToken) + private async Task ProcessCloseFrameAsync(CancellationToken cancellationToken) { - // Insufficient data - while (_receiveCount < bytes && bytes <= _receiveBuffer.Length) - { - // Some data in the buffer, shift down to make room - if (_receiveCount > 0 && _receiveOffset > 0) - { - Array.Copy(_receiveBuffer, _receiveOffset, _receiveBuffer, 0, _receiveCount); - } - _receiveOffset = 0; - // Add to the end - int read = await _stream.ReadAsync(_receiveBuffer, _receiveCount, _receiveBuffer.Length - (_receiveCount), cancellationToken); - if (read == 0) - { - throw new IOException("Unexpected end of stream"); - } - _receiveCount += read; - } - } + // The close message should be less than 125 bytes and fit in the buffer. + await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None); - private WebSocketMessageType GetMessageType(int opCode) - { - switch (opCode) + // Status code and message are optional + if (_frameBytesRemaining >= 2) { - case Constants.OpCodes.TextFrame: return WebSocketMessageType.Text; - case Constants.OpCodes.BinaryFrame: return WebSocketMessageType.Binary; - case Constants.OpCodes.CloseFrame: return WebSocketMessageType.Close; - default: throw new NotImplementedException(opCode.ToString()); + if (_unmaskInput) + { + Utilities.MaskInPlace(_frameInProgress.MaskKey, new ArraySegment(_receiveBuffer, _receiveBufferOffset, (int)_frameBytesRemaining)); + } + _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveBufferOffset] << 8) | _receiveBuffer[_receiveBufferOffset + 1]); + _closeStatusDescription = Encoding.UTF8.GetString(_receiveBuffer, _receiveBufferOffset + 2, (int)_frameBytesRemaining - 2) ?? string.Empty; } + else + { + _closeStatus = _closeStatus ?? WebSocketCloseStatus.NormalClosure; + _closeStatusDescription = _closeStatusDescription ?? string.Empty; + } + + Contract.Assert(_frameInProgress.Fin); + WebSocketReceiveResult result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, _frameInProgress.Fin, + _closeStatus.Value, _closeStatusDescription); + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseReceived; + } + else if (State == WebSocketState.CloseSent) + { + _state = WebSocketState.Closed; + _stream.Dispose(); + } + + return result; } public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { - // TODO: Validate arguments - // TODO: Check state - // TODO: Check concurrent writes - // TODO: Check ping/pong state - - if (State >= WebSocketState.Closed) - { - throw new InvalidOperationException("Already closed."); - } + ThrowIfDisposed(); if (State == WebSocketState.Open || State == WebSocketState.CloseReceived) { @@ -373,52 +374,49 @@ namespace Microsoft.Net.WebSockets result = await ReceiveAsync(new ArraySegment(data), cancellationToken); } while (result.MessageType != WebSocketMessageType.Close); - - _closeStatus = result.CloseStatus; - _closeStatusDescription = result.CloseStatusDescription; } - - _state = WebSocketState.Closed; - _stream.Dispose(); } public override async Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { - // TODO: Validate arguments - // TODO: Check state - // TODO: Check concurrent writes - // TODO: Check ping/pong state - - if (State == WebSocketState.CloseSent || State >= WebSocketState.Closed) + await _writeLock.WaitAsync(cancellationToken); + try { - throw new InvalidOperationException("Already closed."); - } + ThrowIfDisposed(); + ThrowIfOutputClosed(); - if (State == WebSocketState.Open) + byte[] descriptionBytes = Encoding.UTF8.GetBytes(statusDescription ?? string.Empty); + byte[] fullData = new byte[descriptionBytes.Length + 2]; + fullData[0] = (byte)((int)closeStatus >> 8); + fullData[1] = (byte)closeStatus; + Array.Copy(descriptionBytes, 0, fullData, 2, descriptionBytes.Length); + + int mask = GetNextMask(); + if (_maskOutput) + { + Utilities.MaskInPlace(mask, new ArraySegment(fullData)); + } + + FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, _maskOutput, mask, fullData.Length); + + ArraySegment segment = frameHeader.Buffer; + await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); + await _stream.WriteAsync(fullData, 0, fullData.Length, cancellationToken); + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseSent; + } + else if (State == WebSocketState.CloseReceived) + { + _state = WebSocketState.Closed; + _stream.Dispose(); + } + } + finally { - _state = WebSocketState.CloseSent; + _writeLock.Release(); } - else if (State == WebSocketState.CloseReceived) - { - _state = WebSocketState.Closed; - } - - byte[] descriptionBytes = Encoding.UTF8.GetBytes(statusDescription ?? string.Empty); - byte[] fullData = new byte[descriptionBytes.Length + 2]; - fullData[0] = (byte)((int)closeStatus >> 8); - fullData[1] = (byte)closeStatus; - Array.Copy(descriptionBytes, 0, fullData, 2, descriptionBytes.Length); - - int mask = GetNextMask(); - if (_maskOutput) - { - Utilities.Mask(mask, new ArraySegment(fullData)); - } - - FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, _maskOutput, mask, fullData.Length); - ArraySegment segment = frameHeader.Buffer; - await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); - await _stream.WriteAsync(fullData, 0, fullData.Length, cancellationToken); } public override void Abort() @@ -442,5 +440,45 @@ namespace Microsoft.Net.WebSockets _state = WebSocketState.Closed; _stream.Dispose(); } + + private void ThrowIfDisposed() + { + if (_state >= WebSocketState.Closed) // or Aborted + { + throw new ObjectDisposedException(typeof(CommonWebSocket).FullName); + } + } + + private void ThrowIfOutputClosed() + { + if (State == WebSocketState.CloseSent) + { + throw new InvalidOperationException("Close already sent."); + } + } + + private void ThrowIfInputClosed() + { + if (State == WebSocketState.CloseReceived) + { + throw new InvalidOperationException("Close already received."); + } + } + + private void ValidateSegment(ArraySegment buffer) + { + if (buffer.Array == null) + { + throw new ArgumentNullException("buffer"); + } + if (buffer.Offset < 0 || buffer.Offset >= buffer.Array.Length) + { + throw new ArgumentOutOfRangeException("buffer.Offset", buffer.Offset, string.Empty); + } + if (buffer.Count <= 0 || buffer.Count > buffer.Array.Length - buffer.Offset) + { + throw new ArgumentOutOfRangeException("buffer.Count", buffer.Count, string.Empty); + } + } } } diff --git a/src/Microsoft.Net.WebSockets/Utilities.cs b/src/Microsoft.Net.WebSockets/Utilities.cs index 6dbc83630a..35be77a857 100644 --- a/src/Microsoft.Net.WebSockets/Utilities.cs +++ b/src/Microsoft.Net.WebSockets/Utilities.cs @@ -1,8 +1,5 @@ using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; +using System.Net.WebSockets; namespace Microsoft.Net.WebSockets { @@ -15,12 +12,12 @@ namespace Microsoft.Net.WebSockets Array.Copy(header.Array, header.Offset, frame, 0, header.Count); Array.Copy(data.Array, data.Offset, frame, header.Count, data.Count); - Mask(mask, new ArraySegment(frame, header.Count, data.Count)); + MaskInPlace(mask, new ArraySegment(frame, header.Count, data.Count)); return frame; } // Un/Masks the data in place - public static void Mask(int mask, ArraySegment data) + public static void MaskInPlace(int mask, ArraySegment data) { if (mask == 0) { @@ -42,5 +39,27 @@ namespace Microsoft.Net.WebSockets maskOffset = (maskOffset + 1) % 4; } } + + public static int GetOpCode(WebSocketMessageType messageType) + { + switch (messageType) + { + case WebSocketMessageType.Text: return Constants.OpCodes.TextFrame; + case WebSocketMessageType.Binary: return Constants.OpCodes.BinaryFrame; + case WebSocketMessageType.Close: return Constants.OpCodes.CloseFrame; + default: throw new NotImplementedException(messageType.ToString()); + } + } + + public static WebSocketMessageType GetMessageType(int opCode) + { + switch (opCode) + { + case Constants.OpCodes.TextFrame: return WebSocketMessageType.Text; + case Constants.OpCodes.BinaryFrame: return WebSocketMessageType.Binary; + case Constants.OpCodes.CloseFrame: return WebSocketMessageType.Close; + default: throw new NotImplementedException(opCode.ToString()); + } + } } } diff --git a/src/Microsoft.Net.WebSockets/WebSocketClient.cs b/src/Microsoft.Net.WebSockets/WebSocketClient.cs index a78b2f5e17..d1ec49b4d5 100644 --- a/src/Microsoft.Net.WebSockets/WebSocketClient.cs +++ b/src/Microsoft.Net.WebSockets/WebSocketClient.cs @@ -1,10 +1,7 @@ using System; -using System.Collections.Generic; using System.IO; -using System.Linq; using System.Net; using System.Net.WebSockets; -using System.Text; using System.Threading; using System.Threading.Tasks; @@ -36,20 +33,59 @@ namespace Microsoft.Net.WebSockets.Client set; } + public bool UseZeroMask + { + get; + set; + } + + public Action ConfigureRequest + { + get; + set; + } + + public Action InspectResponse + { + get; + set; + } + public async Task ConnectAsync(Uri uri, CancellationToken cancellationToken) { HttpWebRequest request = (HttpWebRequest)WebRequest.Create(uri); + CancellationTokenRegistration cancellation = cancellationToken.Register(() => request.Abort()); + request.Headers[Constants.Headers.WebSocketVersion] = Constants.Headers.SupportedVersion; // TODO: Sub-protocols - WebResponse response = await request.GetResponseAsync(); + if (ConfigureRequest != null) + { + ConfigureRequest(request); + } + + HttpWebResponse response = (HttpWebResponse)await request.GetResponseAsync(); + + cancellation.Dispose(); + + if (InspectResponse != null) + { + InspectResponse(response); + } + // TODO: Validate handshake + if (response.StatusCode != HttpStatusCode.SwitchingProtocols) + { + response.Dispose(); + throw new InvalidOperationException("Incomplete handshake"); + } + + // TODO: Sub protocol Stream stream = response.GetResponseStream(); - // Console.WriteLine(stream.CanWrite + " " + stream.CanRead); - return CommonWebSocket.CreateClientWebSocket(stream, null, ReceiveBufferSize, useZeroMask: false); + return CommonWebSocket.CreateClientWebSocket(stream, null, ReceiveBufferSize, useZeroMask: UseZeroMask); } } } diff --git a/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs b/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs index ff9c8f4601..0a218308c8 100644 --- a/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs +++ b/test/Microsoft.Net.WebSockets.Test/UtilitiesTests.cs @@ -1,8 +1,5 @@ using System; -using System.Collections.Generic; -using System.Linq; using System.Text; -using System.Threading.Tasks; using Xunit; namespace Microsoft.Net.WebSockets.Test @@ -14,8 +11,8 @@ namespace Microsoft.Net.WebSockets.Test { byte[] data = Encoding.UTF8.GetBytes("Hello World"); byte[] orriginal = Encoding.UTF8.GetBytes("Hello World"); - Utilities.Mask(16843009, new ArraySegment(data)); - Utilities.Mask(16843009, new ArraySegment(data)); + Utilities.MaskInPlace(16843009, new ArraySegment(data)); + Utilities.MaskInPlace(16843009, new ArraySegment(data)); Assert.Equal(orriginal, data); } } diff --git a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs index a466271278..91d1045e4c 100644 --- a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs +++ b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs @@ -1,7 +1,5 @@ using Microsoft.Net.WebSockets.Client; using System; -using System.Collections.Generic; -using System.Linq; using System.Net; using System.Net.WebSockets; using System.Text;