diff --git a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs index 8e23595710..fdfdf3a94b 100644 --- a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs +++ b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs @@ -2,13 +2,13 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.WebSockets.Internal; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.WebSockets.Internal; namespace Microsoft.AspNetCore.Sockets @@ -24,10 +24,18 @@ namespace Microsoft.AspNetCore.Sockets public WebSockets(Connection connection, Format format, ILoggerFactory loggerFactory) { + if (connection == null) + { + throw new ArgumentNullException(nameof(connection)); + } + if (loggerFactory == null) + { + throw new ArgumentNullException(nameof(loggerFactory)); + } + _channel = (HttpConnection)connection.Channel; _opcode = format == Format.Binary ? WebSocketOpcode.Binary : WebSocketOpcode.Text; - - _logger = (ILogger)loggerFactory?.CreateLogger() ?? NullLogger.Instance; + _logger = loggerFactory.CreateLogger(); } public async Task ProcessRequestAsync(HttpContext context) @@ -43,46 +51,64 @@ namespace Microsoft.AspNetCore.Sockets { _logger.LogInformation("Socket opened."); - // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. - var receiving = ws.ExecuteAsync((frame, state) => ((WebSockets)state).HandleFrame(frame), this); - var sending = StartSending(ws); - - // Wait for something to shut down. - var trigger = await Task.WhenAny( - receiving, - sending); - - // What happened? - if (trigger == receiving) - { - // Shutting down because we received a close frame from the client. - // Complete the input writer so that the application knows there won't be any more input. - _logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description); - _channel.Input.CompleteWriter(); - - // Wait for the application to finish sending. - _logger.LogDebug("Waiting for the application to finish sending data"); - await sending; - - // Send the server's close frame - await ws.CloseAsync(WebSocketCloseStatus.NormalClosure); - } - else - { - // The application finished sending. We're not going to keep the connection open, - // so close it and wait for the client to ack the close - _channel.Input.CompleteWriter(); - _logger.LogDebug("Application finished sending. Sending close frame."); - await ws.CloseAsync(WebSocketCloseStatus.NormalClosure); - - _logger.LogDebug("Waiting for the client to close the socket"); - // TODO: Timeout. - await receiving; - } + await ProcessSocketAsync(ws); } _logger.LogInformation("Socket closed."); } + public async Task ProcessSocketAsync(IWebSocketConnection socket) + { + // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + var receiving = socket.ExecuteAsync((frame, state) => ((WebSockets)state).HandleFrame(frame), this); + var sending = StartSending(socket); + + // Wait for something to shut down. + var trigger = await Task.WhenAny( + receiving, + sending); + + // What happened? + if (trigger == receiving) + { + if (receiving.IsCanceled || receiving.IsFaulted) + { + // The receiver faulted or cancelled. This means the client is probably broken. Just propagate the exception and exit + receiving.GetAwaiter().GetResult(); + + // Should never get here because GetResult above will throw + Debug.Fail("GetResult didn't throw?"); + return; + } + + // Shutting down because we received a close frame from the client. + // Complete the input writer so that the application knows there won't be any more input. + _logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description); + _channel.Input.CompleteWriter(); + + // Wait for the application to finish sending. + _logger.LogDebug("Waiting for the application to finish sending data"); + await sending; + + // Send the server's close frame + await socket.CloseAsync(WebSocketCloseStatus.NormalClosure); + } + else + { + var failed = sending.IsFaulted || sending.IsCompleted; + + // The application finished sending. Close our end of the connection + _logger.LogDebug(!failed ? "Application finished sending. Sending close frame." : "Application failed during sending. Sending InternalServerError close frame"); + await socket.CloseAsync(!failed ? WebSocketCloseStatus.NormalClosure : WebSocketCloseStatus.InternalServerError); + + _logger.LogDebug("Waiting for the client to close the socket"); + + // Wait for the client to close. + // TODO: Consider timing out here and cancelling the receive loop. + await receiving; + _channel.Input.CompleteWriter(); + } + } + private Task HandleFrame(WebSocketFrame frame) { // Is this a frame we care about? diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs new file mode 100644 index 0000000000..79b7e3ebc7 --- /dev/null +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -0,0 +1,185 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Text; +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Microsoft.Extensions.WebSockets.Internal; +using Microsoft.Extensions.WebSockets.Internal.Tests; +using Xunit; + +namespace Microsoft.AspNetCore.Sockets.Tests +{ + public class WebSocketsTests + { + [Fact] + public async Task ReceivedFramesAreWrittenToPipeline() + { + using (var factory = new PipelineFactory()) + using (var pair = WebSocketPair.Create(factory)) + { + var connection = new Connection(); + connection.ConnectionId = Guid.NewGuid().ToString(); + var httpConnection = new HttpConnection(factory); + connection.Channel = httpConnection; + var ws = new WebSockets(connection, Format.Text, new LoggerFactory()); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(pair.ServerSocket); + + // Run the client socket + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + // Send a frame, then close + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")))); + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + + // Capture everything out of the input channel and then complete the writer (to do our end of the close) + var buffer = (await connection.Channel.Input.ReadToEndAsync()).ToArray(); + httpConnection.Output.CompleteWriter(); + + // The transport should finish now + await transport; + + // The connection should close after this, which means the client will get a close frame. + var clientSummary = await client; + + // Read from the connection pipeline + Assert.Equal("Hello", Encoding.UTF8.GetString(buffer)); + Assert.Equal(WebSocketCloseStatus.NormalClosure, clientSummary.CloseResult.Status); + } + } + + [Theory] + [InlineData(Format.Text, WebSocketOpcode.Text)] + [InlineData(Format.Binary, WebSocketOpcode.Binary)] + public async Task DataWrittenToOutputPipelineAreSentAsFrames(Format format, WebSocketOpcode expectedOpcode) + { + using (var factory = new PipelineFactory()) + using (var pair = WebSocketPair.Create(factory)) + { + var connection = new Connection(); + connection.ConnectionId = Guid.NewGuid().ToString(); + var httpConnection = new HttpConnection(factory); + connection.Channel = httpConnection; + var ws = new WebSockets(connection, format, new LoggerFactory()); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(pair.ServerSocket); + + // Run the client socket + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + // Write to the output channel, and then complete it + await httpConnection.Output.WriteAsync(Encoding.UTF8.GetBytes("Hello")); + httpConnection.Output.CompleteWriter(); + + // The client should finish now, as should the server + var clientSummary = await client; + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + await transport; + + Assert.Equal(1, clientSummary.Received.Count); + Assert.True(clientSummary.Received[0].EndOfMessage); + Assert.Equal(expectedOpcode, clientSummary.Received[0].Opcode); + Assert.Equal("Hello", Encoding.UTF8.GetString(clientSummary.Received[0].Payload.ToArray())); + } + } + + [Fact] + public async Task FrameReceivedAfterServerCloseSent() + { + using (var factory = new PipelineFactory()) + using (var pair = WebSocketPair.Create(factory)) + { + var connection = new Connection(); + connection.ConnectionId = Guid.NewGuid().ToString(); + var httpConnection = new HttpConnection(factory); + connection.Channel = httpConnection; + var ws = new WebSockets(connection, Format.Binary, new LoggerFactory()); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(pair.ServerSocket); + + // Run the client socket + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + // Close the output and wait for the close frame + httpConnection.Output.CompleteWriter(); + await client; + + // Send another frame. Then close + await pair.ClientSocket.SendAsync(new WebSocketFrame( + endOfMessage: true, + opcode: WebSocketOpcode.Text, + payload: ReadableBuffer.Create(Encoding.UTF8.GetBytes("Hello")))); + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + + // Read that frame from the input + var result = (await httpConnection.Input.ReadToEndAsync()).ToArray(); + Assert.Equal("Hello", Encoding.UTF8.GetString(result)); + + await transport; + } + } + + [Fact] + public async Task TransportFailsWhenClientDisconnectsAbnormally() + { + using (var factory = new PipelineFactory()) + using (var pair = WebSocketPair.Create(factory)) + { + var connection = new Connection(); + connection.ConnectionId = Guid.NewGuid().ToString(); + var httpConnection = new HttpConnection(factory); + connection.Channel = httpConnection; + var ws = new WebSockets(connection, Format.Binary, new LoggerFactory()); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(pair.ServerSocket); + + // Run the client socket + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + // Terminate the client to server channel with an exception + pair.TerminateFromClient(new InvalidOperationException()); + + // Wait for the transport + await Assert.ThrowsAsync(() => transport); + } + } + + [Fact] + public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails() + { + using (var factory = new PipelineFactory()) + using (var pair = WebSocketPair.Create(factory)) + { + var connection = new Connection(); + connection.ConnectionId = Guid.NewGuid().ToString(); + var httpConnection = new HttpConnection(factory); + connection.Channel = httpConnection; + var ws = new WebSockets(connection, Format.Binary, new LoggerFactory()); + + // Give the server socket to the transport and run it + var transport = ws.ProcessSocketAsync(pair.ServerSocket); + + // Run the client socket + var client = pair.ClientSocket.ExecuteAndCaptureFramesAsync(); + + // Fail in the app + httpConnection.Output.CompleteWriter(new InvalidOperationException()); + var clientSummary = await client; + Assert.Equal(WebSocketCloseStatus.InternalServerError, clientSummary.CloseResult.Status); + + // Close from the client + await pair.ClientSocket.CloseAsync(WebSocketCloseStatus.NormalClosure); + } + } + } +} diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/project.json b/test/Microsoft.AspNetCore.Sockets.Tests/project.json index dc5d40fd3f..9fac3b002f 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/project.json +++ b/test/Microsoft.AspNetCore.Sockets.Tests/project.json @@ -1,6 +1,11 @@ { "buildOptions": { - "warningsAsErrors": true + "warningsAsErrors": true, + "compile": [ + "../Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionExtensions.cs", + "../Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionSummary.cs", + "../Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs" + ] }, "dependencies": { diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs index 1a7237b953..4d09cccad7 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs @@ -8,7 +8,11 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests { internal class WebSocketPair : IDisposable { + private static readonly WebSocketOptions DefaultServerOptions = new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking(); + private static readonly WebSocketOptions DefaultClientOptions = new WebSocketOptions().WithAllFramesPassedThrough(); + private PipelineFactory _factory; + private readonly bool _ownFactory; public PipelineReaderWriter ServerToClient { get; } public PipelineReaderWriter ClientToServer { get; } @@ -16,8 +20,9 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests public IWebSocketConnection ClientSocket { get; } public IWebSocketConnection ServerSocket { get; } - public WebSocketPair(PipelineFactory factory, PipelineReaderWriter serverToClient, PipelineReaderWriter clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket) + public WebSocketPair(bool ownFactory, PipelineFactory factory, PipelineReaderWriter serverToClient, PipelineReaderWriter clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket) { + _ownFactory = ownFactory; _factory = factory; ServerToClient = serverToClient; ClientToServer = clientToServer; @@ -25,26 +30,32 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests ServerSocket = serverSocket; } - public static WebSocketPair Create() => Create(new WebSocketOptions().WithAllFramesPassedThrough().WithRandomMasking(), new WebSocketOptions().WithAllFramesPassedThrough()); + public static WebSocketPair Create() => Create(new PipelineFactory(), DefaultServerOptions, DefaultClientOptions, ownFactory: true); + public static WebSocketPair Create(PipelineFactory factory) => Create(factory, DefaultServerOptions, DefaultClientOptions, ownFactory: false); + public static WebSocketPair Create(WebSocketOptions serverOptions, WebSocketOptions clientOptions) => Create(new PipelineFactory(), serverOptions, clientOptions, ownFactory: true); + public static WebSocketPair Create(PipelineFactory factory, WebSocketOptions serverOptions, WebSocketOptions clientOptions) => Create(factory, serverOptions, clientOptions, ownFactory: false); - public static WebSocketPair Create(WebSocketOptions serverOptions, WebSocketOptions clientOptions) + private static WebSocketPair Create(PipelineFactory factory, WebSocketOptions serverOptions, WebSocketOptions clientOptions, bool ownFactory) { // Create channels - var factory = new PipelineFactory(); var serverToClient = factory.Create(); var clientToServer = factory.Create(); var serverSocket = new WebSocketConnection(clientToServer, serverToClient, options: serverOptions); var clientSocket = new WebSocketConnection(serverToClient, clientToServer, options: clientOptions); - return new WebSocketPair(factory, serverToClient, clientToServer, clientSocket, serverSocket); + return new WebSocketPair(ownFactory, factory, serverToClient, clientToServer, clientSocket, serverSocket); } public void Dispose() { ServerSocket.Dispose(); ClientSocket.Dispose(); - _factory.Dispose(); + + if (_ownFactory) + { + _factory.Dispose(); + } } public void TerminateFromClient(Exception ex = null)