using System;
using System.Collections.Generic;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http.Features;

namespace Microsoft.AspNetCore.Http.Connections.Tests
{
    /// <summary>
    /// In-memory <see cref="IHttpWebSocketFeature"/> for tests. Accepting the request creates two
    /// <see cref="WebSocketChannel"/> endpoints joined by a pair of unbounded channels. The server end is
    /// returned from <see cref="AcceptAsync(WebSocketAcceptContext)"/>; the client end is exposed via <see cref="Client"/>.
    /// </summary>
    internal class TestWebSocketConnectionFeature : IHttpWebSocketFeature, IDisposable
    {
        // Run continuations asynchronously so code awaiting Accepted is not executed inline inside AcceptAsync
        // (i.e. before the server socket has even been returned to the caller).
        private readonly TaskCompletionSource<object> _accepted =
            new TaskCompletionSource<object>(TaskCreationOptions.RunContinuationsAsynchronously);

        /// <summary>Always <c>true</c>.</summary>
        public bool IsWebSocketRequest => true;

        /// <summary>The client end of the connection; <c>null</c> until an accept call has run.</summary>
        public WebSocketChannel Client { get; private set; }

        /// <summary>The sub-protocol taken from the context of the most recent accept call.</summary>
        public string SubProtocol { get; private set; }

        /// <summary>Completes once <see cref="AcceptAsync(WebSocketAcceptContext)"/> has been called.</summary>
        public Task Accepted => _accepted.Task;

        /// <summary>Accepts with a default <see cref="WebSocketAcceptContext"/>.</summary>
        public Task<WebSocket> AcceptAsync() => AcceptAsync(new WebSocketAcceptContext());

        /// <summary>
        /// Creates the connected client/server socket pair, records <paramref name="context"/>'s sub-protocol,
        /// signals <see cref="Accepted"/> and returns the server-side socket.
        /// </summary>
        public Task<WebSocket> AcceptAsync(WebSocketAcceptContext context)
        {
            var clientToServer = Channel.CreateUnbounded<WebSocketMessage>();
            var serverToClient = Channel.CreateUnbounded<WebSocketMessage>();

            // Each side reads what the other side writes.
            var clientSocket = new WebSocketChannel(serverToClient.Reader, clientToServer.Writer);
            var serverSocket = new WebSocketChannel(clientToServer.Reader, serverToClient.Writer);

            Client = clientSocket;
            SubProtocol = context.SubProtocol;
            _accepted.TrySetResult(null);
            return Task.FromResult<WebSocket>(serverSocket);
        }

        /// <summary>No-op; the feature owns nothing that needs releasing.</summary>
        public void Dispose()
        {
        }

        /// <summary>
        /// A <see cref="WebSocket"/> whose frames travel over an in-memory channel pair instead of a network stream.
        /// </summary>
        public class WebSocketChannel : WebSocket
        {
            private readonly ChannelReader<WebSocketMessage> _input;
            private readonly ChannelWriter<WebSocketMessage> _output;

            private WebSocketCloseStatus? _closeStatus;
            private string _closeStatusDescription;
            private WebSocketState _state;

            // The message currently being handed out by ReceiveAsync; Buffer shrinks as it is consumed.
            private WebSocketMessage _internalBuffer = new WebSocketMessage();

            /// <param name="input">Frames sent by the peer.</param>
            /// <param name="output">Frames this endpoint sends to the peer.</param>
            public WebSocketChannel(ChannelReader<WebSocketMessage> input, ChannelWriter<WebSocketMessage> output)
            {
                _input = input;
                _output = output;
            }

            public override WebSocketCloseStatus? CloseStatus => _closeStatus;

            public override string CloseStatusDescription => _closeStatusDescription;

            public override WebSocketState State => _state;

            public override string SubProtocol => null;

            /// <summary>Completes the outgoing channel with an <see cref="OperationCanceledException"/> and marks the socket aborted.</summary>
            public override void Abort()
            {
                _output.TryComplete(new OperationCanceledException());
                _state = WebSocketState.Aborted;
            }

            /// <summary>
            /// Simulates the connection dropping: the peer's next read fails with a
            /// <see cref="WebSocketException"/> carrying <see cref="WebSocketError.ConnectionClosedPrematurely"/>.
            /// </summary>
            public void SendAbort()
            {
                _output.TryComplete(new WebSocketException(WebSocketError.ConnectionClosedPrematurely));
            }

            /// <summary>
            /// Sends a Close frame and completes the outgoing channel. Unlike a real socket, this does not
            /// wait for the peer's Close frame.
            /// </summary>
            public override Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
                => SendCloseAndCompleteAsync(closeStatus, statusDescription, cancellationToken);

            /// <summary>Sends a Close frame and completes the outgoing channel.</summary>
            public override Task CloseOutputAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
                => SendCloseAndCompleteAsync(closeStatus, statusDescription, cancellationToken);

            /// <summary>Marks the socket closed and completes the outgoing channel without sending a Close frame.</summary>
            public override void Dispose()
            {
                _state = WebSocketState.Closed;
                _output.TryComplete();
            }

            /// <summary>
            /// Copies as much of the current message as fits into <paramref name="buffer"/>, pulling the next
            /// message from the peer only once the previous one has been fully consumed.
            /// </summary>
            /// <exception cref="WebSocketException">
            /// The peer completed its channel with this exception (see <see cref="SendAbort"/>). The outgoing
            /// channel is completed before rethrowing.
            /// </exception>
            /// <exception cref="InvalidOperationException">The peer completed its channel without sending a Close frame.</exception>
            public override async Task<WebSocketReceiveResult> ReceiveAsync(ArraySegment<byte> buffer, CancellationToken cancellationToken)
            {
                try
                {
                    if (_internalBuffer.Buffer == null || _internalBuffer.Buffer.Length == 0)
                    {
                        // If the peer completed the channel with an exception, it is thrown here.
                        await _input.WaitToReadAsync(cancellationToken);

                        if (!_input.TryRead(out var message))
                        {
                            // The channel completed cleanly but no Close frame was ever sent.
                            throw new InvalidOperationException("Unexpected close");
                        }

                        if (message.MessageType == WebSocketMessageType.Close)
                        {
                            _state = WebSocketState.CloseReceived;
                            _closeStatus = message.CloseStatus;
                            _closeStatusDescription = message.CloseStatusDescription;
                            return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription);
                        }

                        _internalBuffer = message;
                    }

                    // ArraySegment<T>.Count is already the usable length of the segment; Offset must not be subtracted.
                    var length = Math.Min(buffer.Count, _internalBuffer.Buffer.Length);
                    _internalBuffer.Buffer.AsSpan(0, length).CopyTo(buffer.AsSpan());

                    if (length > 0)
                    {
                        // Drop the bytes just handed out from the pending buffer.
                        _internalBuffer.Buffer = _internalBuffer.Buffer.AsSpan(length).ToArray();
                    }

                    // Only report end-of-message once nothing of the current message remains.
                    var endOfMessage = _internalBuffer.EndOfMessage && _internalBuffer.Buffer.Length == 0;
                    return new WebSocketReceiveResult(length, _internalBuffer.MessageType, endOfMessage);
                }
                catch (WebSocketException ex)
                {
                    if (ex.WebSocketErrorCode == WebSocketError.ConnectionClosedPrematurely)
                    {
                        _state = WebSocketState.Aborted;
                    }

                    // Complete our side too so the peer observes the failure.
                    _output.TryComplete();
                    throw;
                }
            }

            /// <summary>Sends a copy of <paramref name="buffer"/> as a single frame.</summary>
            public override Task SendAsync(ArraySegment<byte> buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken)
            {
                // Copy so the caller may reuse its buffer as soon as this returns.
                var copy = buffer.AsSpan().ToArray();

                return SendMessageAsync(new WebSocketMessage
                {
                    Buffer = copy,
                    MessageType = messageType,
                    EndOfMessage = endOfMessage,
                }, CancellationToken.None.Equals(cancellationToken) ? cancellationToken : cancellationToken);
            }

            /// <summary>
            /// Drains every frame from the peer until a Close frame arrives or the channel completes,
            /// returning the captured data frames and the close result.
            /// </summary>
            /// <remarks>
            /// If the channel completes without a Close frame, the socket is marked closed and the close result
            /// reports <see cref="WebSocketCloseStatus.InternalServerError"/>.
            /// </remarks>
            public async Task<WebSocketConnectionSummary> ExecuteAndCaptureFramesAsync()
            {
                var frames = new List<WebSocketMessage>();

                while (await _input.WaitToReadAsync())
                {
                    while (_input.TryRead(out var message))
                    {
                        if (message.MessageType == WebSocketMessageType.Close)
                        {
                            _state = WebSocketState.CloseReceived;
                            _closeStatus = message.CloseStatus;
                            _closeStatusDescription = message.CloseStatusDescription;
                            return new WebSocketConnectionSummary(frames, new WebSocketReceiveResult(0, message.MessageType, message.EndOfMessage, message.CloseStatus, message.CloseStatusDescription));
                        }

                        frames.Add(message);
                    }
                }

                _state = WebSocketState.Closed;
                _closeStatus = WebSocketCloseStatus.InternalServerError;
                return new WebSocketConnectionSummary(frames, new WebSocketReceiveResult(0, WebSocketMessageType.Close, endOfMessage: true, closeStatus: WebSocketCloseStatus.InternalServerError, closeStatusDescription: ""));
            }

            // Shared body of CloseAsync/CloseOutputAsync: send the Close frame, then stop writing.
            private async Task SendCloseAndCompleteAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken)
            {
                await SendMessageAsync(new WebSocketMessage
                {
                    CloseStatus = closeStatus,
                    CloseStatusDescription = statusDescription,
                    MessageType = WebSocketMessageType.Close,
                }, cancellationToken);

                _state = WebSocketState.CloseSent;
                _output.TryComplete();
            }

            // Writes one message to the peer. If the outgoing channel has already been completed,
            // WaitToWriteAsync returns false and the message is dropped without an error.
            private async Task SendMessageAsync(WebSocketMessage webSocketMessage, CancellationToken cancellationToken)
            {
                while (await _output.WaitToWriteAsync(cancellationToken))
                {
                    if (_output.TryWrite(webSocketMessage))
                    {
                        break;
                    }
                }
            }
        }

        /// <summary>Result of <see cref="WebSocketChannel.ExecuteAndCaptureFramesAsync"/>.</summary>
        public class WebSocketConnectionSummary
        {
            /// <summary>Data frames received before the connection closed, in arrival order.</summary>
            public IList<WebSocketMessage> Received { get; }

            /// <summary>The receive result describing how the connection closed.</summary>
            public WebSocketReceiveResult CloseResult { get; }

            public WebSocketConnectionSummary(IList<WebSocketMessage> received, WebSocketReceiveResult closeResult)
            {
                Received = received;
                CloseResult = closeResult;
            }
        }

        /// <summary>A single frame exchanged between the two <see cref="WebSocketChannel"/> endpoints.</summary>
        public class WebSocketMessage
        {
            /// <summary>Payload bytes; <c>null</c> for Close frames.</summary>
            public byte[] Buffer { get; set; }

            public WebSocketMessageType MessageType { get; set; }

            public bool EndOfMessage { get; set; }

            /// <summary>Set only on Close frames.</summary>
            public WebSocketCloseStatus? CloseStatus { get; set; }

            /// <summary>Set only on Close frames.</summary>
            public string CloseStatusDescription { get; set; }
        }
    }
}