diff --git a/src/Microsoft.Net.WebSockets/ClientWebSocket.cs b/src/Microsoft.Net.WebSockets/ClientWebSocket.cs index 3334508bbd..d91bf1cffd 100644 --- a/src/Microsoft.Net.WebSockets/ClientWebSocket.cs +++ b/src/Microsoft.Net.WebSockets/ClientWebSocket.cs @@ -15,6 +15,9 @@ namespace Microsoft.Net.WebSockets private readonly string _subProtocl; private WebSocketState _state; + private WebSocketCloseStatus? _closeStatus; + private string _closeStatusDescription; + private byte[] _receiveBuffer; private int _receiveOffset; private int _receiveCount; @@ -32,12 +35,12 @@ namespace Microsoft.Net.WebSockets public override WebSocketCloseStatus? CloseStatus { - get { throw new NotImplementedException(); } + get { return _closeStatus; } } public override string CloseStatusDescription { - get { throw new NotImplementedException(); } + get { return _closeStatusDescription; } } public override WebSocketState State @@ -94,9 +97,31 @@ namespace Microsoft.Net.WebSockets } WebSocketReceiveResult result; - // TODO: Close frame + // TODO: Ping or Pong frames + if (_frameInProgress.OpCode == Constants.OpCodes.CloseFrame) + { + // TOOD: This assumes the close message fits in the buffer. + // TODO: Assert at least two bytes remaining for the close status code. + await EnsureDataAvailableOrReadAsync((int)_frameBytesRemaining, CancellationToken.None); + + _closeStatus = (WebSocketCloseStatus)((_receiveBuffer[_receiveOffset] << 8) | _receiveBuffer[_receiveOffset + 1]); + _closeStatusDescription = Encoding.UTF8.GetString(_receiveBuffer, _receiveOffset + 2, _receiveCount - 2) ?? string.Empty; + result = new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, (WebSocketCloseStatus)_closeStatus, _closeStatusDescription); + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseReceived; + } + else if (State == WebSocketState.CloseSent) + { + _state = WebSocketState.Closed; + _stream.Dispose(); + } + return result; + } + // Make sure there's at least some data in the buffer if (_frameBytesRemaining > 0) { @@ -163,9 +188,41 @@ namespace Microsoft.Net.WebSockets } } - public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) + public async override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { - throw new NotImplementedException(); + // TODO: Validate arguments + // TODO: Check state + // TODO: Check concurrent writes + // TODO: Check ping/pong state + + if (State >= WebSocketState.Closed) + { + throw new InvalidOperationException("Already closed."); + } + + if (State == WebSocketState.Open || State == WebSocketState.CloseReceived) + { + // Send a close message. + await CloseOutputAsync(closeStatus, statusDescription, cancellationToken); + } + + if (State == WebSocketState.CloseSent) + { + // Do a receiving drain + byte[] data = new byte[1024]; + WebSocketReceiveResult result; + do + { + result = await ReceiveAsync(new ArraySegment(data), cancellationToken); + } + while (result.MessageType != WebSocketMessageType.Close); + + _closeStatus = result.CloseStatus; + _closeStatusDescription = result.CloseStatusDescription; + } + + _state = WebSocketState.Closed; + _stream.Dispose(); } public override async Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) @@ -174,14 +231,32 @@ namespace Microsoft.Net.WebSockets // TODO: Check state // TODO: Check concurrent writes // TODO: Check ping/pong state - _state = WebSocketState.CloseSent; + + if (State == WebSocketState.CloseSent || State >= WebSocketState.Closed) + { + throw new InvalidOperationException("Already closed."); + } + + if (State == WebSocketState.Open) + { + _state = WebSocketState.CloseSent; + } + else if (State == WebSocketState.CloseReceived) + { + _state = WebSocketState.Closed; + } + + byte[] descriptionBytes = Encoding.UTF8.GetBytes(statusDescription ?? string.Empty); + byte[] fullData = new byte[descriptionBytes.Length + 2]; + fullData[0] = (byte)((int)closeStatus >> 8); + fullData[1] = (byte)closeStatus; + Array.Copy(descriptionBytes, 0, fullData, 2, descriptionBytes.Length); // TODO: Masking - byte[] buffer = Encoding.UTF8.GetBytes(statusDescription); - FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, true, 0, buffer.Length); + FrameHeader frameHeader = new FrameHeader(true, Constants.OpCodes.CloseFrame, true, 0, fullData.Length); ArraySegment segment = frameHeader.Buffer; await _stream.WriteAsync(segment.Array, segment.Offset, segment.Count, cancellationToken); - await _stream.WriteAsync(buffer, 0, buffer.Length, cancellationToken); + await _stream.WriteAsync(fullData, 0, fullData.Length, cancellationToken); } public override void Abort() diff --git a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs index 7ac37e8bc3..7443cc0c7b 100644 --- a/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs +++ b/test/Microsoft.Net.WebSockets.Test/WebSocketClientTests.cs @@ -280,5 +280,193 @@ namespace Microsoft.Net.WebSockets.Test clientSocket.Dispose(); } } + + [Fact] + public async Task SendClose_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + + WebSocket clientSocket = await clientConnect; + + string closeDescription = "Test Closed"; + await clientSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None); + + byte[] serverBuffer = new byte[1024]; + WebSocketReceiveResult result = await serverWebSocketContext.WebSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(0, result.Count); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(closeDescription, result.CloseStatusDescription); + + Assert.Equal(WebSocketState.CloseSent, clientSocket.State); + + clientSocket.Dispose(); + } + } + + [Fact] + public async Task ReceiveClose_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + + WebSocket clientSocket = await clientConnect; + + string closeDescription = "Test Closed"; + await serverWebSocketContext.WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None); + + byte[] serverBuffer = new byte[1024]; + WebSocketReceiveResult result = await clientSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(0, result.Count); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(closeDescription, result.CloseStatusDescription); + + Assert.Equal(WebSocketState.CloseReceived, clientSocket.State); + + clientSocket.Dispose(); + } + } + + [Fact] + public async Task CloseFromOpen_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + + WebSocket clientSocket = await clientConnect; + + string closeDescription = "Test Closed"; + Task closeTask = clientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None); + + byte[] serverBuffer = new byte[1024]; + WebSocketReceiveResult result = await serverWebSocketContext.WebSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(0, result.Count); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(closeDescription, result.CloseStatusDescription); + + await serverWebSocketContext.WebSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); + + await closeTask; + + Assert.Equal(WebSocketState.Closed, clientSocket.State); + + clientSocket.Dispose(); + } + } + + [Fact] + public async Task CloseFromCloseSent_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + + WebSocket clientSocket = await clientConnect; + + string closeDescription = "Test Closed"; + await clientSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None); + Assert.Equal(WebSocketState.CloseSent, clientSocket.State); + + byte[] serverBuffer = new byte[1024]; + WebSocketReceiveResult result = await serverWebSocketContext.WebSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(0, result.Count); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(closeDescription, result.CloseStatusDescription); + + await serverWebSocketContext.WebSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); + + await clientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None); + + Assert.Equal(WebSocketState.Closed, clientSocket.State); + + clientSocket.Dispose(); + } + } + + [Fact] + public async Task CloseFromCloseReceived_Success() + { + using (HttpListener listener = new HttpListener()) + { + listener.Prefixes.Add(ServerAddress); + listener.Start(); + Task serverAccept = listener.GetContextAsync(); + + WebSocketClient client = new WebSocketClient(); + Task clientConnect = client.ConnectAsync(new Uri(ClientAddress), CancellationToken.None); + + HttpListenerContext serverContext = await serverAccept; + Assert.True(serverContext.Request.IsWebSocketRequest); + HttpListenerWebSocketContext serverWebSocketContext = await serverContext.AcceptWebSocketAsync(null); + + WebSocket clientSocket = await clientConnect; + + string closeDescription = "Test Closed"; + await serverWebSocketContext.WebSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, closeDescription, CancellationToken.None); + + byte[] serverBuffer = new byte[1024]; + WebSocketReceiveResult result = await clientSocket.ReceiveAsync(new ArraySegment(serverBuffer), CancellationToken.None); + Assert.True(result.EndOfMessage); + Assert.Equal(0, result.Count); + Assert.Equal(WebSocketMessageType.Close, result.MessageType); + Assert.Equal(WebSocketCloseStatus.NormalClosure, result.CloseStatus); + Assert.Equal(closeDescription, result.CloseStatusDescription); + + Assert.Equal(WebSocketState.CloseReceived, clientSocket.State); + + await clientSocket.CloseAsync(result.CloseStatus.Value, result.CloseStatusDescription, CancellationToken.None); + + Assert.Equal(WebSocketState.Closed, clientSocket.State); + + clientSocket.Dispose(); + } + } } }