add tests to WebSockets transport (#35)

* add tests to WebSockets transport
* adds some error handling
* make logger factory required
* allow frames to be received after the application closes the output
This commit is contained in:
Andrew Stanton-Nurse 2016-11-23 11:26:12 -08:00 committed by GitHub
parent 638b4b5fc4
commit 940ccf5c65
4 changed files with 273 additions and 46 deletions

View File

@ -2,13 +2,13 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.WebSockets.Internal;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Microsoft.Extensions.WebSockets.Internal;
namespace Microsoft.AspNetCore.Sockets
@ -24,10 +24,18 @@ namespace Microsoft.AspNetCore.Sockets
public WebSockets(Connection connection, Format format, ILoggerFactory loggerFactory)
{
if (connection == null)
{
throw new ArgumentNullException(nameof(connection));
}
if (loggerFactory == null)
{
throw new ArgumentNullException(nameof(loggerFactory));
}
_channel = (HttpConnection)connection.Channel;
_opcode = format == Format.Binary ? WebSocketOpcode.Binary : WebSocketOpcode.Text;
_logger = (ILogger)loggerFactory?.CreateLogger<WebSockets>() ?? NullLogger.Instance;
_logger = loggerFactory.CreateLogger<WebSockets>();
}
public async Task ProcessRequestAsync(HttpContext context)
@ -43,46 +51,64 @@ namespace Microsoft.AspNetCore.Sockets
{
_logger.LogInformation("Socket opened.");
// Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync.
var receiving = ws.ExecuteAsync((frame, state) => ((WebSockets)state).HandleFrame(frame), this);
var sending = StartSending(ws);
// Wait for something to shut down.
var trigger = await Task.WhenAny(
receiving,
sending);
// What happened?
if (trigger == receiving)
{
// Shutting down because we received a close frame from the client.
// Complete the input writer so that the application knows there won't be any more input.
_logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description);
_channel.Input.CompleteWriter();
// Wait for the application to finish sending.
_logger.LogDebug("Waiting for the application to finish sending data");
await sending;
// Send the server's close frame
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure);
}
else
{
// The application finished sending. We're not going to keep the connection open,
// so close it and wait for the client to ack the close
_channel.Input.CompleteWriter();
_logger.LogDebug("Application finished sending. Sending close frame.");
await ws.CloseAsync(WebSocketCloseStatus.NormalClosure);
_logger.LogDebug("Waiting for the client to close the socket");
// TODO: Timeout.
await receiving;
}
await ProcessSocketAsync(ws);
}
_logger.LogInformation("Socket closed.");
}
public async Task ProcessSocketAsync(IWebSocketConnection socket)
{
// Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync.
var receiving = socket.ExecuteAsync((frame, state) => ((WebSockets)state).HandleFrame(frame), this);
var sending = StartSending(socket);
// Wait for something to shut down.
var trigger = await Task.WhenAny(
receiving,
sending);
// What happened?
if (trigger == receiving)
{
if (receiving.IsCanceled || receiving.IsFaulted)
{
// The receiver faulted or cancelled. This means the client is probably broken. Just propagate the exception and exit
receiving.GetAwaiter().GetResult();
// Should never get here because GetResult above will throw
Debug.Fail("GetResult didn't throw?");
return;
}
// Shutting down because we received a close frame from the client.
// Complete the input writer so that the application knows there won't be any more input.
_logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description);
_channel.Input.CompleteWriter();
// Wait for the application to finish sending.
_logger.LogDebug("Waiting for the application to finish sending data");
await sending;
// Send the server's close frame
await socket.CloseAsync(WebSocketCloseStatus.NormalClosure);
}
else
{
var failed = sending.IsFaulted || sending.IsCompleted;
// The application finished sending. Close our end of the connection
_logger.LogDebug(!failed ? "Application finished sending. Sending close frame." : "Application failed during sending. Sending InternalServerError close frame");
await socket.CloseAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError);
_logger.LogDebug("Waiting for the client to close the socket");
// Wait for the client to close.
// TODO: Consider timing out here and cancelling the receive loop.
await receiving;
_channel.Input.CompleteWriter();
}
}
private Task HandleFrame(WebSocketFrame frame)
{
// Is this a frame we care about?

View File

@ -0,0 +1,185 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Text;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.WebSockets.Internal;
using Microsoft.Extensions.WebSockets.Internal.Tests;
using Xunit;
namespace Microsoft.AspNetCore.Sockets.Tests
{
public class WebSocketsTests
{
[Fact]
public async Task ReceivedFramesAreWrittenToPipeline()
{
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Text, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Send a frame, then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
// Capture everything out of the input channel and then complete the writer (to do our end of the close)
var buffer = (await connection.Channel.Input.ReadToEndAsync()).ToArray();
httpConnection.Output.CompleteWriter();
// The transport should finish now
await transport;
// The connection should close after this, which means the client will get a close frame.
var clientSummary = await client;
// Read from the connection pipeline
Assert.Equal("Hello", Encoding.UTF8.GetString(buffer));
Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status);
}
}
[Theory]
[InlineData(Format.Text, WebSocketOpcode.Text)]
[InlineData(Format.Binary, WebSocketOpcode.Binary)]
public async Task DataWrittenToOutputPipelineAreSentAsFrames(Format format, WebSocketOpcode expectedOpcode)
{
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, format, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Write to the output channel, and then complete it
await httpConnection.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello"));
httpConnection.Output.CompleteWriter();
// The client should finish now, as should the server
var clientSummary = await client;
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
await transport;
Assert.Equal(1, clientSummary.Received.Count);
Assert.True(clientSummary.Received[0].EndOfMessage);
Assert.Equal(expectedOpcode, clientSummary.Received[0].Opcode);
Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Payload.ToArray()));
}
}
[Fact]
public async Task FrameReceivedAfterServerCloseSent()
{
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Binary, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Close the output and wait for the close frame
httpConnection.Output.CompleteWriter();
await client;
// Send another frame. Then close
await pair.ClientSocket.SendAsync(new WebSocketFrame(
endOfMessage: true,
opcode: WebSocketOpcode.Text,
payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello"))));
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
// Read that frame from the input
var result = (await httpConnection.Input.ReadToEndAsync()).ToArray();
Assert.Equal("Hello", Encoding.UTF8.GetString(result));
await transport;
}
}
[Fact]
public async Task TransportFailsWhenClientDisconnectsAbnormally()
{
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Binary, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Terminate the client to server channel with an exception
pair.TerminateFromClient(new InvalidOperationException());
// Wait for the transport
await Assert.ThrowsAsync<InvalidOperationException>(() => transport);
}
}
[Fact]
public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails()
{
using (var factory = new PipelineFactory())
using (var pair = WebSocketPair.Create(factory))
{
var connection = new Connection();
connection.ConnectionId = Guid.NewGuid().ToString();
var httpConnection = new HttpConnection(factory);
connection.Channel = httpConnection;
var ws = new WebSockets(connection, Format.Binary, new LoggerFactory());
// Give the server socket to the transport and run it
var transport = ws.ProcessSocketAsync(pair.ServerSocket);
// Run the client socket
var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync();
// Fail in the app
httpConnection.Output.CompleteWriter(new InvalidOperationException());
var clientSummary = await client;
Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.Status);
// Close from the client
await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure);
}
}
}
}

View File

@ -1,6 +1,11 @@
{
"buildOptions": {
"warningsAsErrors": true
"warningsAsErrors": true,
"compile": [
"../Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionExtensions.cs",
"../Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionSummary.cs",
"../Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs"
]
},
"dependencies": {

View File

@ -8,7 +8,11 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
{
internal class WebSocketPair : IDisposable
{
private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking();
private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough();
private PipelineFactory _factory;
private readonly bool _ownFactory;
public PipelineReaderWriter ServerToClient { get; }
public PipelineReaderWriter ClientToServer { get; }
@ -16,8 +20,9 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
public IWebSocketConnection ClientSocket { get; }
public IWebSocketConnection ServerSocket { get; }
public WebSocketPair(PipelineFactory factory, PipelineReaderWriter serverToClient, PipelineReaderWriter clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket)
public WebSocketPair(bool ownFactory, PipelineFactory factory, PipelineReaderWriter serverToClient, PipelineReaderWriter clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket)
{
_ownFactory = ownFactory;
_factory = factory;
ServerToClient = serverToClient;
ClientToServer = clientToServer;
@ -25,26 +30,32 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests
ServerSocket = serverSocket;
}
public static WebSocketPair Create() => Create(new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking(), new WebSocketOptions().WithAllFramesPassedThrough());
public static WebSocketPair Create() => Create(new PipelineFactory(), DefaultServerOptions, DefaultClientOptions, ownFactory: true);
public static WebSocketPair Create(PipelineFactory factory) => Create(factory, DefaultServerOptions, DefaultClientOptions, ownFactory: false);
public static WebSocketPair Create(WebSocketOptions serverOptions, WebSocketOptions clientOptions) => Create(new PipelineFactory(), serverOptions, clientOptions, ownFactory: true);
public static WebSocketPair Create(PipelineFactory factory, WebSocketOptions serverOptions, WebSocketOptions clientOptions) => Create(factory, serverOptions, clientOptions, ownFactory: false);
public static WebSocketPair Create(WebSocketOptions serverOptions, WebSocketOptions clientOptions)
private static WebSocketPair Create(PipelineFactory factory, WebSocketOptions serverOptions, WebSocketOptions clientOptions, bool ownFactory)
{
// Create channels
var factory = new PipelineFactory();
var serverToClient = factory.Create();
var clientToServer = factory.Create();
var serverSocket = new WebSocketConnection(clientToServer, serverToClient, options: serverOptions);
var clientSocket = new WebSocketConnection(serverToClient, clientToServer, options: clientOptions);
return new WebSocketPair(factory, serverToClient, clientToServer, clientSocket, serverSocket);
return new WebSocketPair(ownFactory, factory, serverToClient, clientToServer, clientSocket, serverSocket);
}
public void Dispose()
{
ServerSocket.Dispose();
ClientSocket.Dispose();
_factory.Dispose();
if (_ownFactory)
{
_factory.Dispose();
}
}
public void TerminateFromClient(Exception ex = null)