From 32baa655b93b8d64a8d25aa6ad2af288ed9abc54 Mon Sep 17 00:00:00 2001 From: David Fowler Date: Thu, 22 Feb 2018 15:19:25 -0800 Subject: [PATCH] Do over the websocket transport (#1481) * Do over the websocket transport - Unify client and server logic (no code sharing yet) - Removed use of cancellation tokens to communicate shutdown and instead used the pipe reader and socket abort. - Added CloseTimeout to HttpOptions --- .../HttpOptions.cs | 1 + .../WebSocketsTransport.cs | 210 +++++++++++------- .../Transports/WebSocketsTransport.cs | 193 +++++++++++----- .../TestWebSocketConnectionFeature.cs | 3 + .../WebSocketsTests.cs | 30 ++- 5 files changed, 291 insertions(+), 146 deletions(-) diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs index 196ab681cf..81786b5084 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpOptions.cs @@ -13,6 +13,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Http public HttpMessageHandler HttpMessageHandler { get; set; } public IReadOnlyCollection> Headers { get; set; } public Func AccessTokenFactory { get; set; } + public TimeSpan CloseTimeout { get; set; } = TimeSpan.FromSeconds(5); /// /// Gets or sets a delegate that will be invoked with the object used diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index 519b3ec88e..047117a847 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -5,7 +5,6 @@ using System; using System.Diagnostics; using System.IO.Pipelines; using System.Net.WebSockets; -using System.Runtime.InteropServices; using System.Threading; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Client.Http; @@ -19,9 +18,9 @@ namespace Microsoft.AspNetCore.Sockets.Client { private readonly ClientWebSocket _webSocket; private IDuplexPipe _application; - private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); - private readonly CancellationTokenSource _receiveCts = new CancellationTokenSource(); private readonly ILogger _logger; + private readonly TimeSpan _closeTimeout; + private volatile bool _aborted; public Task Running { get; private set; } = Task.CompletedTask; @@ -51,6 +50,7 @@ namespace Microsoft.AspNetCore.Sockets.Client httpOptions?.WebSocketOptions?.Invoke(_webSocket.Options); + _closeTimeout = httpOptions?.CloseTimeout ?? TimeSpan.FromSeconds(5); _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } @@ -77,27 +77,69 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.StartTransport(Mode.Value); await Connect(url); - var sendTask = SendMessages(); - var receiveTask = ReceiveMessages(); // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 - Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t => - { - _webSocket.Dispose(); - _logger.TransportStopped(t.Exception?.InnerException); - - _application.Output.Complete(t.Exception?.InnerException); - _application.Input.Complete(); - - return t; - }).Unwrap(); + Running = ProcessSocketAsync(_webSocket); } - private async Task ReceiveMessages() + private async Task ProcessSocketAsync(WebSocket socket) { - _logger.StartReceive(); + using (socket) + { + // Begin sending and receiving. Receiving must be started first because ExecuteAsync enables SendAsync. + var receiving = StartReceiving(socket); + var sending = StartSending(socket); + // Wait for send or receive to complete + var trigger = await Task.WhenAny(receiving, sending); + + if (trigger == receiving) + { + // We're waiting for the application to finish and there are 2 things it could be doing + // 1. Waiting for application data + // 2. Waiting for a websocket send to complete + + // Cancel the application so that ReadAsync yields + _application.Input.CancelPendingRead(); + + using (var delayCts = new CancellationTokenSource()) + { + var resultTask = await Task.WhenAny(sending, Task.Delay(_closeTimeout, delayCts.Token)); + + if (resultTask != sending) + { + _aborted = true; + + // Abort the websocket if we're stuck in a pending send to the client + socket.Abort(); + } + else + { + // Cancel the timeout + delayCts.Cancel(); + } + } + } + else + { + // We're waiting on the websocket to close and there are 2 things it could be doing + // 1. Waiting for websocket data + // 2. Waiting on a flush to complete (backpressure being applied) + + _aborted = true; + + // Abort the websocket if we're stuck in a pending receive from the client + socket.Abort(); + + // Cancel any pending flush so that we can quit + _application.Output.CancelPendingFlush(); + } + } + } + + private async Task StartReceiving(WebSocket socket) + { try { while (true) @@ -105,15 +147,14 @@ namespace Microsoft.AspNetCore.Sockets.Client var memory = _application.Output.GetMemory(); #if NETCOREAPP2_1 - var receiveResult = await _webSocket.ReceiveAsync(memory, _receiveCts.Token); + var receiveResult = await socket.ReceiveAsync(memory, CancellationToken.None); #else var isArray = memory.TryGetArray(out var arraySegment); Debug.Assert(isArray); // Exceptions are handled above where the send and receive tasks are being run. - var receiveResult = await _webSocket.ReceiveAsync(arraySegment, _receiveCts.Token); + var receiveResult = await socket.ReceiveAsync(arraySegment, CancellationToken.None); #endif - if (receiveResult.MessageType == WebSocketMessageType.Close) { _logger.WebSocketClosed(_webSocket.CloseStatus); @@ -132,7 +173,14 @@ namespace Microsoft.AspNetCore.Sockets.Client if (receiveResult.EndOfMessage) { - await _application.Output.FlushAsync(_transportCts.Token); + var flushResult = await _application.Output.FlushAsync(); + + // We canceled in the middle of applying back pressure + // or if the consumer is done + if (flushResult.IsCancelled || flushResult.IsCompleted) + { + break; + } } } } @@ -140,71 +188,111 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.ReceiveCanceled(); } + catch (Exception ex) + { + if (!_aborted) + { + _application.Output.Complete(ex); + + // We re-throw here so we can communicate that there was an error when sending + // the close frame + throw; + } + } finally { // We're done writing + _application.Output.Complete(); + _logger.ReceiveStopped(); - _transportCts.Cancel(); } } - private async Task SendMessages() + private async Task StartSending(WebSocket socket) { - _logger.SendStarted(); - var webSocketMessageType = Mode == TransferMode.Binary ? WebSocketMessageType.Binary : WebSocketMessageType.Text; + Exception error = null; + try { while (true) { - var result = await _application.Input.ReadAsync(_transportCts.Token); + var result = await _application.Input.ReadAsync(); var buffer = result.Buffer; + + // Get a frame from the application + try { + if (result.IsCancelled) + { + break; + } + if (!buffer.IsEmpty) { - _logger.ReceivedFromApp(buffer.Length); + try + { + _logger.ReceivedFromApp(buffer.Length); - await _webSocket.SendAsync(buffer, webSocketMessageType, _transportCts.Token); + if (WebSocketCanSend(socket)) + { + await socket.SendAsync(buffer, webSocketMessageType); + } + else + { + break; + } + } + catch (Exception ex) + { + if (!_aborted) + { + _logger.ErrorSendingMessage(ex); + } + break; + } } else if (result.IsCompleted) { break; } } - catch (OperationCanceledException) - { - _logger.SendMessageCanceled(); - await CloseWebSocket(); - break; - } - catch (Exception ex) - { - _logger.ErrorSendingMessage(ex); - await CloseWebSocket(); - throw; - } finally { _application.Input.AdvanceTo(buffer.End); } } } - catch (OperationCanceledException) + catch (Exception ex) { - _logger.SendCanceled(); + error = ex; } finally { + if (WebSocketCanSend(socket)) + { + // We're done sending, send the close frame to the client if the websocket is still open + await socket.CloseOutputAsync(error != null ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + } + + _application.Input.Complete(); + _logger.SendStopped(); - TriggerCancel(); } } + private static bool WebSocketCanSend(WebSocket ws) + { + return !(ws.State == WebSocketState.Aborted || + ws.State == WebSocketState.Closed || + ws.State == WebSocketState.CloseSent); + } + private async Task Connect(Uri url) { var uriBuilder = new UriBuilder(url); @@ -224,7 +312,8 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.TransportStopping(); - await CloseWebSocket(); + // Cancel any pending reads from the application, this should start the entire shutdown process + _application.Input.CancelPendingRead(); try { @@ -235,38 +324,5 @@ namespace Microsoft.AspNetCore.Sockets.Client // exceptions have been handled in the Running task continuation by closing the channel with the exception } } - - private async Task CloseWebSocket() - { - try - { - // Best effort - it's still possible (but not likely) that the transport is being closed via StopAsync - // while the webSocket is being closed due to an error. - if (_webSocket.State != WebSocketState.Closed) - { - _logger.ClosingWebSocket(); - - // We intentionally don't pass _transportCts.Token to CloseOutputAsync. The token can be cancelled - // for reasons not related to webSocket in which case we would not close the websocket gracefully. - await _webSocket.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, null, CancellationToken.None); - - // shutdown the transport after a timeout in case the server does not send close frame - TriggerCancel(); - } - } - catch (Exception ex) - { - // This is benign - the exception can happen due to the race described above because we would - // try closing the webSocket twice. - _logger.ClosingWebSocketFailed(ex); - } - } - - private void TriggerCancel() - { - // Give server 5 seconds to respond with a close frame for graceful close. - _receiveCts.CancelAfter(TimeSpan.FromSeconds(5)); - _transportCts.Cancel(); - } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs index a6edb00588..a5ae47d6ba 100644 --- a/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Http/Internal/Transports/WebSocketsTransport.cs @@ -19,6 +19,7 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports private readonly ILogger _logger; private readonly IDuplexPipe _application; private readonly DefaultConnectionContext _connection; + private volatile bool _aborted; public WebSocketsTransport(WebSocketOptions options, IDuplexPipe application, DefaultConnectionContext connection, ILoggerFactory loggerFactory) { @@ -68,41 +69,68 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports var receiving = StartReceiving(socket); var sending = StartSending(socket); - // Wait for something to shut down. - var trigger = await Task.WhenAny( - receiving, - sending); + // Wait for send or receive to complete + var trigger = await Task.WhenAny(receiving, sending); - var failed = trigger.IsCanceled || trigger.IsFaulted; - var task = Task.CompletedTask; if (trigger == receiving) { - task = sending; _logger.WaitingForSend(); + + // We're waiting for the application to finish and there are 2 things it could be doing + // 1. Waiting for application data + // 2. Waiting for a websocket send to complete + + // Cancel the application so that ReadAsync yields + _application.Input.CancelPendingRead(); + + using (var delayCts = new CancellationTokenSource()) + { + var resultTask = await Task.WhenAny(sending, Task.Delay(_options.CloseTimeout, delayCts.Token)); + + if (resultTask != sending) + { + // We timed out so now we're in ungraceful shutdown mode + _logger.CloseTimedOut(); + + // Abort the websocket if we're stuck in a pending send to the client + _aborted = true; + + socket.Abort(); + } + else + { + delayCts.Cancel(); + } + } } else { - task = receiving; _logger.WaitingForClose(); + + // We're waiting on the websocket to close and there are 2 things it could be doing + // 1. Waiting for websocket data + // 2. Waiting on a flush to complete (backpressure being applied) + + using (var delayCts = new CancellationTokenSource()) + { + var resultTask = await Task.WhenAny(receiving, Task.Delay(_options.CloseTimeout, delayCts.Token)); + + if (resultTask != receiving) + { + // Abort the websocket if we're stuck in a pending receive from the client + _aborted = true; + + socket.Abort(); + + // Cancel any pending flush so that we can quit + _application.Output.CancelPendingFlush(); + } + else + { + delayCts.Cancel(); + } + } } - - await socket.CloseOutputAsync(failed ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - - var resultTask = await Task.WhenAny(task, Task.Delay(_options.CloseTimeout)); - - if (resultTask != task) - { - _logger.CloseTimedOut(); - socket.Abort(); - } - else - { - // Observe any exceptions from second completed task - task.GetAwaiter().GetResult(); - } - - // Observe any exceptions from original completed task - trigger.GetAwaiter().GetResult(); } private async Task StartReceiving(WebSocket socket) @@ -133,10 +161,32 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports if (receiveResult.EndOfMessage) { - await _application.Output.FlushAsync(); + var flushResult = await _application.Output.FlushAsync(); + + // We canceled in the middle of applying back pressure + // or if the consumer is done + if (flushResult.IsCancelled || flushResult.IsCompleted) + { + break; + } } } } + catch (OperationCanceledException) + { + // Ignore aborts, don't treat them like transport errors + } + catch (Exception ex) + { + if (!_aborted) + { + _application.Output.Complete(ex); + + // We re-throw here so we can communicate that there was an error when sending + // the close frame + throw; + } + } finally { // We're done writing @@ -144,54 +194,81 @@ namespace Microsoft.AspNetCore.Sockets.Internal.Transports } } - private async Task StartSending(WebSocket ws) + private async Task StartSending(WebSocket socket) { - while (true) + Exception error = null; + + try { - var result = await _application.Input.ReadAsync(); - var buffer = result.Buffer; - - // Get a frame from the application - - try + while (true) { - if (!buffer.IsEmpty) + var result = await _application.Input.ReadAsync(); + var buffer = result.Buffer; + + // Get a frame from the application + + try { - try + if (result.IsCancelled) { - _logger.SendPayload(buffer.Length); + break; + } - var webSocketMessageType = (_connection.TransferMode == TransferMode.Binary - ? WebSocketMessageType.Binary - : WebSocketMessageType.Text); - - if (WebSocketCanSend(ws)) + if (!buffer.IsEmpty) + { + try { - await ws.SendAsync(buffer, webSocketMessageType); + _logger.SendPayload(buffer.Length); + + var webSocketMessageType = (_connection.TransferMode == TransferMode.Binary + ? WebSocketMessageType.Binary + : WebSocketMessageType.Text); + + if (WebSocketCanSend(socket)) + { + await socket.SendAsync(buffer, webSocketMessageType); + } + else + { + break; + } + } + catch (Exception ex) + { + if (!_aborted) + { + _logger.ErrorWritingFrame(ex); + } + break; } } - catch (WebSocketException socketException) when (!WebSocketCanSend(ws)) + else if (result.IsCompleted) { - // this can happen when we send the CloseFrame to the client and try to write afterwards - _logger.SendFailed(socketException); - break; - } - catch (Exception ex) - { - _logger.ErrorWritingFrame(ex); break; } } - else if (result.IsCompleted) + finally { - break; + _application.Input.AdvanceTo(buffer.End); } } - finally - { - _application.Input.AdvanceTo(buffer.End); - } } + catch (Exception ex) + { + error = ex; + } + finally + { + // Send the close frame before calling into user code + if (WebSocketCanSend(socket)) + { + // We're done sending, send the close frame to the client if the websocket is still open + await socket.CloseOutputAsync(error != null ? WebSocketCloseStatus.InternalServerError : WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); + } + + _application.Input.Complete(); + } + } private static bool WebSocketCanSend(WebSocket ws) diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs b/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs index ea085b939f..80b1bfa114 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/TestWebSocketConnectionFeature.cs @@ -134,6 +134,9 @@ namespace Microsoft.AspNetCore.Sockets.Tests break; } + // Complete the client side if there's an error + _output.TryComplete(); + throw; } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index ded1a6bd09..c0e328dff9 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -108,7 +108,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests } [Fact] - public async Task TransportFailsWhenClientDisconnectsAbnormally() + public async Task TransportCommunicatesErrorToApplicationWhenClientDisconnectsAbnormally() { using (StartLog(out var loggerFactory, LogLevel.Debug)) { @@ -119,12 +119,21 @@ namespace Microsoft.AspNetCore.Sockets.Tests { async Task CompleteApplicationAfterTransportCompletes() { - // Wait until the transport completes so that we can end the application - var result = await connection.Transport.Input.ReadAsync(); - connection.Transport.Input.AdvanceTo(result.Buffer.End); - - // Complete the application so that the connection unwinds without aborting - connection.Transport.Output.Complete(); + try + { + // Wait until the transport completes so that we can end the application + var result = await connection.Transport.Input.ReadAsync(); + connection.Transport.Input.AdvanceTo(result.Buffer.End); + } + catch (Exception ex) + { + Assert.IsType(ex); + } + finally + { + // Complete the application so that the connection unwinds without aborting + connection.Transport.Output.Complete(); + } } var connectionContext = new DefaultConnectionContext(string.Empty, null, null); @@ -144,7 +153,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests feature.Client.SendAbort(); // Wait for the transport - await Assert.ThrowsAsync(() => transport).OrTimeout(); + await transport.OrTimeout(); await client.OrTimeout(); } @@ -178,8 +187,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests // Close from the client await feature.Client.CloseAsync(WebSocketCloseStatus.NormalClosure, "", CancellationToken.None); - var ex = await Assert.ThrowsAsync(() => transport.OrTimeout()); - Assert.Equal("Catastrophic failure.", ex.Message); + await transport.OrTimeout(); } } } @@ -247,7 +255,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests // fail the client to server channel connection.Transport.Output.Complete(new Exception()); - await Assert.ThrowsAsync(() => transport).OrTimeout(); + await transport.OrTimeout(); Assert.Equal(WebSocketState.Aborted, serverSocket.State); }