diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs b/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs index d637b66313..29b5cac70f 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs @@ -62,6 +62,11 @@ namespace Microsoft.AspNetCore.Sockets.Tests _state = WebSocketState.Aborted; } + public void SendAbort() + { + _output.TryComplete(new WebSocketException(WebSocketError.ConnectionClosedPrematurely)); + } + public override async Task CloseAsync(WebSocketCloseStatus closeStatus, string statusDescription, CancellationToken cancellationToken) { await SendMessageAsync(new WebSocketMessage @@ -100,20 +105,39 @@ namespace Microsoft.AspNetCore.Sockets.Tests public override async Task ReceiveAsync(ArraySegment buffer, CancellationToken cancellationToken) { - var message = await _input.ReadAsync(); - - if (message.MessageType == WebSocketMessageType.Close) + try { - _state = WebSocketState.CloseReceived; - _closeStatus = message.CloseStatus; - _closeStatusDescription = message.CloseStatusDescription; - return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription); + await _input.WaitToReadAsync(); + + if (_input.TryRead(out var message)) + { + if (message.MessageType == WebSocketMessageType.Close) + { + _state = WebSocketState.CloseReceived; + _closeStatus = message.CloseStatus; + _closeStatusDescription = message.CloseStatusDescription; + return new WebSocketReceiveResult(0, WebSocketMessageType.Close, true, message.CloseStatus, message.CloseStatusDescription); + } + + // REVIEW: This assumes the buffer passed in is > the buffer received + Buffer.BlockCopy(message.Buffer, 0, buffer.Array, buffer.Offset, message.Buffer.Length); + + return new WebSocketReceiveResult(message.Buffer.Length, message.MessageType, message.EndOfMessage); + } + } + catch (WebSocketException ex) + { + switch (ex.WebSocketErrorCode) + { + case WebSocketError.ConnectionClosedPrematurely: + _state = WebSocketState.Aborted; + break; + } + + throw; } - // REVIEW: This assumes the buffer passed in is > the buffer received - Buffer.BlockCopy(message.Buffer, 0, buffer.Array, buffer.Offset, message.Buffer.Length); - - return new WebSocketReceiveResult(message.Buffer.Length, message.MessageType, message.EndOfMessage); + throw new InvalidOperationException("Unexpected close"); } public override Task SendAsync(ArraySegment buffer, WebSocketMessageType messageType, bool endOfMessage, CancellationToken cancellationToken) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index 40ca21e0a9..26809f5665 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -106,12 +106,16 @@ namespace Microsoft.AspNetCore.Sockets.Tests using (var applicationSide = ChannelConnection.Create(transportToApplication, applicationToTransport)) using (var feature = new TestWebSocketConnectionFeature()) { - var options = new WebSocketOptions() + async Task CompleteApplicationAfterTransportCompletes() { - CloseTimeout = TimeSpan.FromMilliseconds(100) - }; + // Wait until the transport completes so that we can end the application + await applicationSide.In.WaitToReadAsync(); - var ws = new WebSocketsTransport(options, transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); + // Complete the application so that the connection unwinds without aborting + applicationSide.Out.TryComplete(); + } + + var ws = new WebSocketsTransport(new WebSocketOptions(), transportSide, connectionId: string.Empty, loggerFactory: new LoggerFactory()); // Give the server socket to the transport and run it var transport = ws.ProcessSocketAsync(await feature.AcceptAsync()); @@ -119,11 +123,15 @@ namespace Microsoft.AspNetCore.Sockets.Tests // Run the client socket var client = feature.Client.ExecuteAndCaptureFramesAsync(); + // When the close frame is received, we complete the application so the send + // loop unwinds + _ = CompleteApplicationAfterTransportCompletes(); + // Terminate the client to server channel with an exception - feature.Client.Abort(); + feature.Client.SendAbort(); // Wait for the transport - await Assert.ThrowsAsync(() => transport).OrTimeout(); + await Assert.ThrowsAsync(() => transport).OrTimeout(); var summary = await client.OrTimeout(); Assert.Equal(WebSocketCloseStatus.InternalServerError, summary.CloseResult.CloseStatus);