Fixing WebSocketsTransport to handle exceptions correctly

Exceptions thrown when sending or receiving messages would leave the
WebSockets transport in a half-closed state when one of the loops is
closed but the other one is still running preventing from the
Connection.Closed event to be fired.

Fixes: #412
This commit is contained in:
Pawel Kadluczka 2017-05-18 09:50:36 -07:00
parent fbf7e1fb72
commit cb9f44ddf6
3 changed files with 206 additions and 87 deletions

View File

@ -14,9 +14,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
public class WebSocketsTransport : ITransport
{
private ClientWebSocket _webSocket = new ClientWebSocket();
private readonly ClientWebSocket _webSocket = new ClientWebSocket();
private IChannelConnection<SendMessage, Message> _application;
private CancellationToken _cancellationToken = new CancellationToken();
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
private readonly ILogger _logger;
public WebSocketsTransport()
@ -48,8 +48,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
_application = application;
await Connect(url);
var sendTask = SendMessages(url, _cancellationToken);
var receiveTask = ReceiveMessages(url, _cancellationToken);
var sendTask = SendMessages(url);
var receiveTask = ReceiveMessages(url);
// TODO: Handle TCP connection errors
// https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251
@ -62,112 +62,131 @@ namespace Microsoft.AspNetCore.Sockets.Client
}).Unwrap();
}
private async Task ReceiveMessages(Uri pollUrl, CancellationToken cancellationToken)
private async Task ReceiveMessages(Uri pollUrl)
{
_logger.LogInformation("Starting receive loop");
while (!cancellationToken.IsCancellationRequested)
try
{
const int bufferSize = 4096;
var totalBytes = 0;
var incomingMessage = new List<ArraySegment<byte>>();
WebSocketReceiveResult receiveResult;
do
while (!_transportCts.Token.IsCancellationRequested)
{
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
//Exceptions are handled above where the send and receive tasks are being run.
receiveResult = await _webSocket.ReceiveAsync(buffer, cancellationToken);
if (receiveResult.MessageType == WebSocketMessageType.Close)
const int bufferSize = 4096;
var totalBytes = 0;
var incomingMessage = new List<ArraySegment<byte>>();
WebSocketReceiveResult receiveResult;
do
{
_logger.LogInformation("Websocket closed by the server. Close status {0}", receiveResult.CloseStatus);
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
_application.Output.Complete();
return;
//Exceptions are handled above where the send and receive tasks are being run.
receiveResult = await _webSocket.ReceiveAsync(buffer, _transportCts.Token);
if (receiveResult.MessageType == WebSocketMessageType.Close)
{
_logger.LogInformation("Websocket closed by the server. Close status {0}", receiveResult.CloseStatus);
_application.Output.Complete();
return;
}
_logger.LogDebug("Message received. Type: {0}, size: {1}, EndOfMessage: {2}",
receiveResult.MessageType.ToString(), receiveResult.Count, receiveResult.EndOfMessage);
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
totalBytes += receiveResult.Count;
} while (!receiveResult.EndOfMessage);
//Making sure the message type is either text or binary
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
Message message;
var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text;
if (incomingMessage.Count > 1)
{
var messageBuffer = new byte[totalBytes];
var offset = 0;
for (var i = 0; i < incomingMessage.Count; i++)
{
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
offset += incomingMessage[i].Count;
}
message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage);
}
else
{
var buffer = new byte[incomingMessage[0].Count];
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count);
message = new Message(buffer, messageType, receiveResult.EndOfMessage);
}
_logger.LogDebug("Message received. Type: {0}, size: {1}, EndOfMessage: {2}",
receiveResult.MessageType.ToString(), receiveResult.Count, receiveResult.EndOfMessage);
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
totalBytes += receiveResult.Count;
} while (!receiveResult.EndOfMessage);
//Making sure the message type is either text or binary
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
Message message;
var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text;
if (incomingMessage.Count > 1)
{
var messageBuffer = new byte[totalBytes];
var offset = 0;
for (var i = 0; i < incomingMessage.Count; i++)
_logger.LogInformation("Passing message to application. Payload size: {0}", message.Payload.Length);
while (await _application.Output.WaitToWriteAsync(_transportCts.Token))
{
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
offset += incomingMessage[i].Count;
}
message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage);
}
else
{
var buffer = new byte[incomingMessage[0].Count];
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count);
message = new Message(buffer, messageType, receiveResult.EndOfMessage);
}
_logger.LogInformation("Passing message to application. Payload size: {0}", message.Payload.Length);
while (await _application.Output.WaitToWriteAsync(cancellationToken))
{
if (_application.Output.TryWrite(message))
{
incomingMessage.Clear();
break;
if (_application.Output.TryWrite(message))
{
incomingMessage.Clear();
break;
}
}
}
}
_logger.LogInformation("Receive loop stopped");
catch (OperationCanceledException)
{
}
finally
{
_transportCts.Cancel();
_logger.LogInformation("Receive loop stopped");
}
}
private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken)
private async Task SendMessages(Uri sendUrl)
{
_logger.LogInformation("Starting the send loop");
while (await _application.Input.WaitToReadAsync(cancellationToken))
try
{
while (_application.Input.TryRead(out SendMessage message))
while (await _application.Input.WaitToReadAsync(_transportCts.Token))
{
try
while (_application.Input.TryRead(out SendMessage message))
{
_logger.LogDebug("Received message from application. Message type {0}. Payload size: {1}",
message.Type, message.Payload.Length);
try
{
_logger.LogDebug("Received message from application. Message type {0}. Payload size: {1}",
message.Type, message.Payload.Length);
await _webSocket.SendAsync(new ArraySegment<byte>(message.Payload),
message.Type == MessageType.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
true, cancellationToken);
await _webSocket.SendAsync(new ArraySegment<byte>(message.Payload),
message.Type == MessageType.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
true, _transportCts.Token);
message.SendResult.SetResult(null);
}
catch (OperationCanceledException)
{
message.SendResult.SetCanceled();
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
break;
}
catch (Exception ex)
{
_logger.LogError(ex.Message);
message.SendResult.SetException(ex);
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
throw;
message.SendResult.SetResult(null);
}
catch (OperationCanceledException)
{
_logger.LogInformation("Sending a message canceled.");
message.SendResult.SetCanceled();
await CloseWebSocket();
break;
}
catch (Exception ex)
{
_logger.LogError("Error while sending a message {0}", ex.Message);
message.SendResult.SetException(ex);
await CloseWebSocket();
throw;
}
}
}
}
_logger.LogInformation("Send loop stopped");
catch (OperationCanceledException)
{
}
finally
{
_transportCts.Cancel();
_logger.LogInformation("Send loop stopped");
}
}
private async Task Connect(Uri url)
@ -182,14 +201,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
uriBuilder.Scheme = "wss";
}
await _webSocket.ConnectAsync(uriBuilder.Uri, _cancellationToken);
await _webSocket.ConnectAsync(uriBuilder.Uri, CancellationToken.None);
}
public async Task StopAsync()
{
_logger.LogInformation("Transport {0} is stopping", nameof(WebSocketsTransport));
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
await CloseWebSocket();
_webSocket.Dispose();
try
@ -203,5 +222,25 @@ namespace Microsoft.AspNetCore.Sockets.Client
_logger.LogInformation("Transport {0} stopped", nameof(WebSocketsTransport));
}
private async Task CloseWebSocket()
{
try
{
// Best effort - it's still possible (but not likely) that the transport is being closed via StopAsync
// while the webSocket is being closed due to an error.
if (_webSocket.State != WebSocketState.Closed)
{
_logger.LogInformation("Closing webSocket");
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, CancellationToken.None);
}
}
catch (Exception ex)
{
// This is benign - the exception can happen due to the race described above because we would
// try closing the webSocket twice.
_logger.LogInformation("Closing webSocket failed with {0}", ex);
}
}
}
}

View File

@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
const string message = "Hello, World!";
using (var ws = new ClientWebSocket())
{
string socketUrl = _serverFixture.WebSocketsUrl + "/echo";
var socketUrl = _serverFixture.WebSocketsUrl + "/echo";
logger.LogInformation("Connecting WebSocket to {socketUrl}", socketUrl);
await ws.ConnectAsync(new Uri(socketUrl), CancellationToken.None).OrTimeout();

View File

@ -0,0 +1,80 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Internal;
using Microsoft.AspNetCore.Testing.xunit;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Tests
{
[Collection(EndToEndTestsCollection.Name)]
public class WebSocketsTransportTests
{
private readonly ServerFixture _serverFixture;
public WebSocketsTransportTests(ServerFixture serverFixture)
{
if (serverFixture == null)
{
throw new ArgumentNullException(nameof(serverFixture));
}
_serverFixture = serverFixture;
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStopped()
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
var webSocketsTransport = new WebSocketsTransport();
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
await webSocketsTransport.StopAsync();
await webSocketsTransport.Running.OrTimeout();
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task WebSocketsTransportStopsWhenConnectionChannelClosed()
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
var webSocketsTransport = new WebSocketsTransport();
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
connectionToTransport.Out.TryComplete();
await webSocketsTransport.Running.OrTimeout();
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer()
{
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
var transportToConnection = Channel.CreateUnbounded<Message>();
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
var webSocketsTransport = new WebSocketsTransport();
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
var sendTcs = new TaskCompletionSource<object>();
connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, MessageType.Binary, sendTcs));
await sendTcs.Task;
// The echo endpoint close the connection immediately after sending response which should stop the transport
await webSocketsTransport.Running.OrTimeout();
Assert.True(transportToConnection.In.TryRead(out var message));
Assert.Equal(new byte[] { 0x42 }, message.Payload);
}
}
}