diff --git a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs index b5fceae352..ff6b9f44a9 100644 --- a/src/Microsoft.Net.WebSockets/CommonWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/CommonWebSocket.cs @@ -26,12 +26,15 @@ namespace Microsoft.Net.WebSockets private WebSocketCloseStatus? _closeStatus; private string _closeStatusDescription; + private bool _outgoingMessageInProgress; + private byte[] _receiveBuffer; private int _receiveOffset; private int _receiveCount; private FrameHeader _frameInProgress; private long _frameBytesRemaining = 0; + private int? _firstDataOpCode; public CommonWebSocket(Stream stream, string subProtocol, int receiveBufferSize, bool maskOutput, bool useZeroMask, bool unmaskInput) { @@ -111,7 +114,7 @@ namespace Microsoft.Net.WebSockets try { int mask = GetNextMask(); - FrameHeader frameHeader = new FrameHeader(endOfMessage, GetOpCode(messageType), _maskOutput, mask, buffer.Count); + FrameHeader frameHeader = new FrameHeader(endOfMessage, _outgoingMessageInProgress ? Constants.OpCodes.ContinuationFrame : GetOpCode(messageType), _maskOutput, mask, buffer.Count); ArraySegment segment = frameHeader.Buffer; if (_maskOutput && mask != 0) { @@ -123,6 +126,7 @@ namespace Microsoft.Net.WebSockets await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); await _stream.WriteAsync(buffer.Array, buffer.Offset, buffer.Count, cancellationToken); } + _outgoingMessageInProgress = !endOfMessage; } finally { @@ -182,9 +186,32 @@ namespace Microsoft.Net.WebSockets } } + // Handle fragmentation, remember the first frame type + int opCode = 0; + if (_frameInProgress.OpCode == Constants.OpCodes.BinaryFrame + || _frameInProgress.OpCode == Constants.OpCodes.TextFrame + || _frameInProgress.OpCode == Constants.OpCodes.CloseFrame) + { + opCode = _frameInProgress.OpCode; + _firstDataOpCode = opCode; + } + else if (_frameInProgress.OpCode == Constants.OpCodes.ContinuationFrame) + { + if (!_firstDataOpCode.HasValue) + { + throw new InvalidOperationException("A continuation can't be the first frame"); + } + opCode = _firstDataOpCode.Value; + } + + if (_frameInProgress.Fin) + { + _firstDataOpCode = null; + } + WebSocketReceiveResult result; - if (_frameInProgress.OpCode == Constants.OpCodes.CloseFrame) + if (opCode == Constants.OpCodes.CloseFrame) { // The close message should be less than 125 bytes and fit in the buffer. await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None); @@ -223,7 +250,7 @@ namespace Microsoft.Net.WebSockets if (_frameBytesRemaining == 0) { // End of an empty frame? - result = new WebSocketReceiveResult(0, GetMessageType(_frameInProgress.OpCode), true); + result = new WebSocketReceiveResult(0, GetMessageType(opCode), _frameInProgress.Fin); _frameInProgress = null; return result; } @@ -241,12 +268,12 @@ namespace Microsoft.Net.WebSockets } if (bytesToCopy == _frameBytesRemaining) { - result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), _frameInProgress.Fin); + result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(opCode), _frameInProgress.Fin); _frameInProgress = null; } else { - result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(_frameInProgress.OpCode), false); + result = new WebSocketReceiveResult(bytesToCopy, GetMessageType(opCode), false); } _frameBytesRemaining -= bytesToCopy; _receiveCount -= bytesToCopy; diff --git a/src/Microsoft.Net.WebSockets/Constants.cs b/src/Microsoft.Net.WebSockets/Constants.cs index 740ea9ab41..8fb002913e 100644 --- a/src/Microsoft.Net.WebSockets/Constants.cs +++ b/src/Microsoft.Net.WebSockets/Constants.cs @@ -16,6 +16,7 @@ namespace Microsoft.Net.WebSockets public static class OpCodes { + public const int ContinuationFrame = 0x0; public const int TextFrame = 0x1; public const int BinaryFrame = 0x2; public const int CloseFrame = 0x8; diff --git a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs index 7443cc0c7b..a466271278 100644 --- a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs +++ b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs @@ -145,6 +145,57 @@ namespace Microsoft.Net.WebSockets.Test } } + [Fact] + public async Task SendFragmentedData_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + WebSocket serverSocket = serverWebSocketContext.WebSocket; + + WebSocket clientSocket = await clientConnect; + + byte[] orriginalData = Encoding.UTF8.GetBytes("Hello World"); + await clientSocket.SendAsync(new ArraySegment(orriginalData, 0, 2), WebSocketMessageType.Binary, false, CancellationToken.None); + await clientSocket.SendAsync(new ArraySegment(orriginalData, 2, 2), WebSocketMessageType.Binary, false, CancellationToken.None); + await clientSocket.SendAsync(new ArraySegment(orriginalData, 4, 7), WebSocketMessageType.Binary, true, CancellationToken.None); + + byte[] serverBuffer = new byte[orriginalData.Length]; + WebSocketReceiveResult result = await serverSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.False(result.EndOfMessage); + Assert.Equal(2, result.Count); + int totalReceived = result.Count; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + result = await serverSocket.ReceiveAsync( + new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None); + Assert.False(result.EndOfMessage); + Assert.Equal(2, result.Count); + totalReceived += result.Count; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + result = await serverSocket.ReceiveAsync( + new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(7, result.Count); + totalReceived += result.Count; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + Assert.Equal(orriginalData, serverBuffer); + + clientSocket.Dispose(); + } + } + [Fact] public async Task ReceiveShortData_Success() { @@ -281,6 +332,57 @@ namespace Microsoft.Net.WebSockets.Test } } + [Fact] + public async Task ReceiveFragmentedData_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + WebSocket serverSocket = serverWebSocketContext.WebSocket; + + WebSocket clientSocket = await clientConnect; + + byte[] orriginalData = Encoding.UTF8.GetBytes("Hello World"); + await serverSocket.SendAsync(new ArraySegment(orriginalData, 0, 2), WebSocketMessageType.Binary, false, CancellationToken.None); + await serverSocket.SendAsync(new ArraySegment(orriginalData, 2, 2), WebSocketMessageType.Binary, false, CancellationToken.None); + await serverSocket.SendAsync(new ArraySegment(orriginalData, 4, 7), WebSocketMessageType.Binary, true, CancellationToken.None); + + byte[] serverBuffer = new byte[orriginalData.Length]; + WebSocketReceiveResult result = await clientSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.False(result.EndOfMessage); + Assert.Equal(2, result.Count); + int totalReceived = result.Count; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + result = await clientSocket.ReceiveAsync( + new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None); + Assert.False(result.EndOfMessage); + Assert.Equal(2, result.Count); + totalReceived += result.Count; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + result = await clientSocket.ReceiveAsync( + new ArraySegment(serverBuffer, totalReceived, serverBuffer.Length - totalReceived), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(7, result.Count); + totalReceived += result.Count; + Assert.Equal(WebSocketMessageType.Binary, result.MessageType); + + Assert.Equal(orriginalData, serverBuffer); + + clientSocket.Dispose(); + } + } + [Fact] public async Task SendClose_Success() {