diff --git a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs index 16aca5ecf7..16cf6b93d3 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs @@ -14,9 +14,9 @@ namespace Microsoft.AspNetCore.Sockets.Client { public class WebSocketsTransport : ITransport { - private ClientWebSocket _webSocket = new ClientWebSocket(); + private readonly ClientWebSocket _webSocket = new ClientWebSocket(); private IChannelConnection _application; - private CancellationToken _cancellationToken = new CancellationToken(); + private readonly CancellationTokenSource _transportCts = new CancellationTokenSource(); private readonly ILogger _logger; public WebSocketsTransport() @@ -48,8 +48,8 @@ namespace Microsoft.AspNetCore.Sockets.Client _application = application; await Connect(url); - var sendTask = SendMessages(url, _cancellationToken); - var receiveTask = ReceiveMessages(url, _cancellationToken); + var sendTask = SendMessages(url); + var receiveTask = ReceiveMessages(url); // TODO: Handle TCP connection errors // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 @@ -62,112 +62,131 @@ namespace Microsoft.AspNetCore.Sockets.Client }).Unwrap(); } - private async Task ReceiveMessages(Uri pollUrl, CancellationToken cancellationToken) + private async Task ReceiveMessages(Uri pollUrl) { _logger.LogInformation("Starting receive loop"); - while (!cancellationToken.IsCancellationRequested) + try { - const int bufferSize = 4096; - var totalBytes = 0; - var incomingMessage = new List>(); - WebSocketReceiveResult receiveResult; - do + while (!_transportCts.Token.IsCancellationRequested) { - var buffer = new ArraySegment(new byte[bufferSize]); - - //Exceptions are handled above where the send and receive tasks are being run. - receiveResult = await _webSocket.ReceiveAsync(buffer, cancellationToken); - if (receiveResult.MessageType == WebSocketMessageType.Close) + const int bufferSize = 4096; + var totalBytes = 0; + var incomingMessage = new List>(); + WebSocketReceiveResult receiveResult; + do { - _logger.LogInformation("Websocket closed by the server. Close status {0}", receiveResult.CloseStatus); + var buffer = new ArraySegment(new byte[bufferSize]); - _application.Output.Complete(); - return; + //Exceptions are handled above where the send and receive tasks are being run. + receiveResult = await _webSocket.ReceiveAsync(buffer, _transportCts.Token); + if (receiveResult.MessageType == WebSocketMessageType.Close) + { + _logger.LogInformation("Websocket closed by the server. Close status {0}", receiveResult.CloseStatus); + + _application.Output.Complete(); + return; + } + + _logger.LogDebug("Message received. Type: {0}, size: {1}, EndOfMessage: {2}", + receiveResult.MessageType.ToString(), receiveResult.Count, receiveResult.EndOfMessage); + + var truncBuffer = new ArraySegment(buffer.Array, 0, receiveResult.Count); + incomingMessage.Add(truncBuffer); + totalBytes += receiveResult.Count; + } while (!receiveResult.EndOfMessage); + + //Making sure the message type is either text or binary + Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type"); + + Message message; + var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text; + if (incomingMessage.Count > 1) + { + var messageBuffer = new byte[totalBytes]; + var offset = 0; + for (var i = 0; i < incomingMessage.Count; i++) + { + Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count); + offset += incomingMessage[i].Count; + } + + message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage); + } + else + { + var buffer = new byte[incomingMessage[0].Count]; + Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count); + message = new Message(buffer, messageType, receiveResult.EndOfMessage); } - _logger.LogDebug("Message received. Type: {0}, size: {1}, EndOfMessage: {2}", - receiveResult.MessageType.ToString(), receiveResult.Count, receiveResult.EndOfMessage); - - var truncBuffer = new ArraySegment(buffer.Array, 0, receiveResult.Count); - incomingMessage.Add(truncBuffer); - totalBytes += receiveResult.Count; - } while (!receiveResult.EndOfMessage); - - //Making sure the message type is either text or binary - Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type"); - - Message message; - var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text; - if (incomingMessage.Count > 1) - { - var messageBuffer = new byte[totalBytes]; - var offset = 0; - for (var i = 0; i < incomingMessage.Count; i++) + _logger.LogInformation("Passing message to application. Payload size: {0}", message.Payload.Length); + while (await _application.Output.WaitToWriteAsync(_transportCts.Token)) { - Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count); - offset += incomingMessage[i].Count; - } - - message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage); - } - else - { - var buffer = new byte[incomingMessage[0].Count]; - Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count); - message = new Message(buffer, messageType, receiveResult.EndOfMessage); - } - - _logger.LogInformation("Passing message to application. Payload size: {0}", message.Payload.Length); - while (await _application.Output.WaitToWriteAsync(cancellationToken)) - { - if (_application.Output.TryWrite(message)) - { - incomingMessage.Clear(); - break; + if (_application.Output.TryWrite(message)) + { + incomingMessage.Clear(); + break; + } } } } - - _logger.LogInformation("Receive loop stopped"); + catch (OperationCanceledException) + { + } + finally + { + _transportCts.Cancel(); + _logger.LogInformation("Receive loop stopped"); + } } - private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken) + private async Task SendMessages(Uri sendUrl) { _logger.LogInformation("Starting the send loop"); - while (await _application.Input.WaitToReadAsync(cancellationToken)) + try { - while (_application.Input.TryRead(out SendMessage message)) + while (await _application.Input.WaitToReadAsync(_transportCts.Token)) { - try + while (_application.Input.TryRead(out SendMessage message)) { - _logger.LogDebug("Received message from application. Message type {0}. Payload size: {1}", - message.Type, message.Payload.Length); + try + { + _logger.LogDebug("Received message from application. Message type {0}. Payload size: {1}", + message.Type, message.Payload.Length); - await _webSocket.SendAsync(new ArraySegment(message.Payload), - message.Type == MessageType.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, - true, cancellationToken); + await _webSocket.SendAsync(new ArraySegment(message.Payload), + message.Type == MessageType.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, + true, _transportCts.Token); - message.SendResult.SetResult(null); - } - catch (OperationCanceledException) - { - message.SendResult.SetCanceled(); - await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken); - break; - } - catch (Exception ex) - { - _logger.LogError(ex.Message); - message.SendResult.SetException(ex); - await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken); - throw; + message.SendResult.SetResult(null); + } + catch (OperationCanceledException) + { + _logger.LogInformation("Sending a message canceled."); + message.SendResult.SetCanceled(); + await CloseWebSocket(); + break; + } + catch (Exception ex) + { + _logger.LogError("Error while sending a message {0}", ex.Message); + message.SendResult.SetException(ex); + await CloseWebSocket(); + throw; + } } } } - - _logger.LogInformation("Send loop stopped"); + catch (OperationCanceledException) + { + } + finally + { + _transportCts.Cancel(); + _logger.LogInformation("Send loop stopped"); + } } private async Task Connect(Uri url) @@ -182,14 +201,14 @@ namespace Microsoft.AspNetCore.Sockets.Client uriBuilder.Scheme = "wss"; } - await _webSocket.ConnectAsync(uriBuilder.Uri, _cancellationToken); + await _webSocket.ConnectAsync(uriBuilder.Uri, CancellationToken.None); } public async Task StopAsync() { _logger.LogInformation("Transport {0} is stopping", nameof(WebSocketsTransport)); - await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken); + await CloseWebSocket(); _webSocket.Dispose(); try @@ -203,5 +222,25 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger.LogInformation("Transport {0} stopped", nameof(WebSocketsTransport)); } + + private async Task CloseWebSocket() + { + try + { + // Best effort - it's still possible (but not likely) that the transport is being closed via StopAsync + // while the webSocket is being closed due to an error. + if (_webSocket.State != WebSocketState.Closed) + { + _logger.LogInformation("Closing webSocket"); + await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, CancellationToken.None); + } + } + catch (Exception ex) + { + // This is benign - the exception can happen due to the race described above because we would + // try closing the webSocket twice. + _logger.LogInformation("Closing webSocket failed with {0}", ex); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index ff369ff32e..e652360531 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests const string message = "Hello, World!"; using (var ws = new ClientWebSocket()) { - string socketUrl = _serverFixture.WebSocketsUrl + "/echo"; + var socketUrl = _serverFixture.WebSocketsUrl + "/echo"; logger.LogInformation("Connecting WebSocket to {socketUrl}", socketUrl); await ws.ConnectAsync(new Uri(socketUrl), CancellationToken.None).OrTimeout(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs new file mode 100644 index 0000000000..260e55cc92 --- /dev/null +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -0,0 +1,80 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; +using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Internal; +using Microsoft.AspNetCore.Testing.xunit; +using Xunit; + +namespace Microsoft.AspNetCore.SignalR.Tests +{ + [Collection(EndToEndTestsCollection.Name)] + public class WebSocketsTransportTests + { + private readonly ServerFixture _serverFixture; + + public WebSocketsTransportTests(ServerFixture serverFixture) + { + if (serverFixture == null) + { + throw new ArgumentNullException(nameof(serverFixture)); + } + + _serverFixture = serverFixture; + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStopped() + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + var webSocketsTransport = new WebSocketsTransport(); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); + await webSocketsTransport.StopAsync(); + await webSocketsTransport.Running.OrTimeout(); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + public async Task WebSocketsTransportStopsWhenConnectionChannelClosed() + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + var webSocketsTransport = new WebSocketsTransport(); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); + connectionToTransport.Out.TryComplete(); + await webSocketsTransport.Running.OrTimeout(); + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer() + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + var webSocketsTransport = new WebSocketsTransport(); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection); + + var sendTcs = new TaskCompletionSource(); + connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, MessageType.Binary, sendTcs)); + await sendTcs.Task; + // The echo endpoint close the connection immediately after sending response which should stop the transport + await webSocketsTransport.Running.OrTimeout(); + + Assert.True(transportToConnection.In.TryRead(out var message)); + Assert.Equal(new byte[] { 0x42 }, message.Payload); + } + } +}