Fixing WebSocketsTransport to handle exceptions correctly
Exceptions thrown when sending or receiving messages would leave the WebSockets transport in a half-closed state when one of the loops is closed but the other one is still running preventing from the Connection.Closed event to be fired. Fixes: #412
This commit is contained in:
parent
fbf7e1fb72
commit
cb9f44ddf6
|
|
@ -14,9 +14,9 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
{
|
||||
public class WebSocketsTransport : ITransport
|
||||
{
|
||||
private ClientWebSocket _webSocket = new ClientWebSocket();
|
||||
private readonly ClientWebSocket _webSocket = new ClientWebSocket();
|
||||
private IChannelConnection<SendMessage, Message> _application;
|
||||
private CancellationToken _cancellationToken = new CancellationToken();
|
||||
private readonly CancellationTokenSource _transportCts = new CancellationTokenSource();
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public WebSocketsTransport()
|
||||
|
|
@ -48,8 +48,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
_application = application;
|
||||
|
||||
await Connect(url);
|
||||
var sendTask = SendMessages(url, _cancellationToken);
|
||||
var receiveTask = ReceiveMessages(url, _cancellationToken);
|
||||
var sendTask = SendMessages(url);
|
||||
var receiveTask = ReceiveMessages(url);
|
||||
|
||||
// TODO: Handle TCP connection errors
|
||||
// https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251
|
||||
|
|
@ -62,112 +62,131 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
}).Unwrap();
|
||||
}
|
||||
|
||||
private async Task ReceiveMessages(Uri pollUrl, CancellationToken cancellationToken)
|
||||
private async Task ReceiveMessages(Uri pollUrl)
|
||||
{
|
||||
_logger.LogInformation("Starting receive loop");
|
||||
|
||||
while (!cancellationToken.IsCancellationRequested)
|
||||
try
|
||||
{
|
||||
const int bufferSize = 4096;
|
||||
var totalBytes = 0;
|
||||
var incomingMessage = new List<ArraySegment<byte>>();
|
||||
WebSocketReceiveResult receiveResult;
|
||||
do
|
||||
while (!_transportCts.Token.IsCancellationRequested)
|
||||
{
|
||||
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
|
||||
|
||||
//Exceptions are handled above where the send and receive tasks are being run.
|
||||
receiveResult = await _webSocket.ReceiveAsync(buffer, cancellationToken);
|
||||
if (receiveResult.MessageType == WebSocketMessageType.Close)
|
||||
const int bufferSize = 4096;
|
||||
var totalBytes = 0;
|
||||
var incomingMessage = new List<ArraySegment<byte>>();
|
||||
WebSocketReceiveResult receiveResult;
|
||||
do
|
||||
{
|
||||
_logger.LogInformation("Websocket closed by the server. Close status {0}", receiveResult.CloseStatus);
|
||||
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
|
||||
|
||||
_application.Output.Complete();
|
||||
return;
|
||||
//Exceptions are handled above where the send and receive tasks are being run.
|
||||
receiveResult = await _webSocket.ReceiveAsync(buffer, _transportCts.Token);
|
||||
if (receiveResult.MessageType == WebSocketMessageType.Close)
|
||||
{
|
||||
_logger.LogInformation("Websocket closed by the server. Close status {0}", receiveResult.CloseStatus);
|
||||
|
||||
_application.Output.Complete();
|
||||
return;
|
||||
}
|
||||
|
||||
_logger.LogDebug("Message received. Type: {0}, size: {1}, EndOfMessage: {2}",
|
||||
receiveResult.MessageType.ToString(), receiveResult.Count, receiveResult.EndOfMessage);
|
||||
|
||||
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
|
||||
incomingMessage.Add(truncBuffer);
|
||||
totalBytes += receiveResult.Count;
|
||||
} while (!receiveResult.EndOfMessage);
|
||||
|
||||
//Making sure the message type is either text or binary
|
||||
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
|
||||
|
||||
Message message;
|
||||
var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text;
|
||||
if (incomingMessage.Count > 1)
|
||||
{
|
||||
var messageBuffer = new byte[totalBytes];
|
||||
var offset = 0;
|
||||
for (var i = 0; i < incomingMessage.Count; i++)
|
||||
{
|
||||
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
|
||||
offset += incomingMessage[i].Count;
|
||||
}
|
||||
|
||||
message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage);
|
||||
}
|
||||
else
|
||||
{
|
||||
var buffer = new byte[incomingMessage[0].Count];
|
||||
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count);
|
||||
message = new Message(buffer, messageType, receiveResult.EndOfMessage);
|
||||
}
|
||||
|
||||
_logger.LogDebug("Message received. Type: {0}, size: {1}, EndOfMessage: {2}",
|
||||
receiveResult.MessageType.ToString(), receiveResult.Count, receiveResult.EndOfMessage);
|
||||
|
||||
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
|
||||
incomingMessage.Add(truncBuffer);
|
||||
totalBytes += receiveResult.Count;
|
||||
} while (!receiveResult.EndOfMessage);
|
||||
|
||||
//Making sure the message type is either text or binary
|
||||
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text), "Unexpected message type");
|
||||
|
||||
Message message;
|
||||
var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? MessageType.Binary : MessageType.Text;
|
||||
if (incomingMessage.Count > 1)
|
||||
{
|
||||
var messageBuffer = new byte[totalBytes];
|
||||
var offset = 0;
|
||||
for (var i = 0; i < incomingMessage.Count; i++)
|
||||
_logger.LogInformation("Passing message to application. Payload size: {0}", message.Payload.Length);
|
||||
while (await _application.Output.WaitToWriteAsync(_transportCts.Token))
|
||||
{
|
||||
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
|
||||
offset += incomingMessage[i].Count;
|
||||
}
|
||||
|
||||
message = new Message(messageBuffer, messageType, receiveResult.EndOfMessage);
|
||||
}
|
||||
else
|
||||
{
|
||||
var buffer = new byte[incomingMessage[0].Count];
|
||||
Buffer.BlockCopy(incomingMessage[0].Array, incomingMessage[0].Offset, buffer, 0, incomingMessage[0].Count);
|
||||
message = new Message(buffer, messageType, receiveResult.EndOfMessage);
|
||||
}
|
||||
|
||||
_logger.LogInformation("Passing message to application. Payload size: {0}", message.Payload.Length);
|
||||
while (await _application.Output.WaitToWriteAsync(cancellationToken))
|
||||
{
|
||||
if (_application.Output.TryWrite(message))
|
||||
{
|
||||
incomingMessage.Clear();
|
||||
break;
|
||||
if (_application.Output.TryWrite(message))
|
||||
{
|
||||
incomingMessage.Clear();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Receive loop stopped");
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
}
|
||||
finally
|
||||
{
|
||||
_transportCts.Cancel();
|
||||
_logger.LogInformation("Receive loop stopped");
|
||||
}
|
||||
}
|
||||
|
||||
private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken)
|
||||
private async Task SendMessages(Uri sendUrl)
|
||||
{
|
||||
_logger.LogInformation("Starting the send loop");
|
||||
|
||||
while (await _application.Input.WaitToReadAsync(cancellationToken))
|
||||
try
|
||||
{
|
||||
while (_application.Input.TryRead(out SendMessage message))
|
||||
while (await _application.Input.WaitToReadAsync(_transportCts.Token))
|
||||
{
|
||||
try
|
||||
while (_application.Input.TryRead(out SendMessage message))
|
||||
{
|
||||
_logger.LogDebug("Received message from application. Message type {0}. Payload size: {1}",
|
||||
message.Type, message.Payload.Length);
|
||||
try
|
||||
{
|
||||
_logger.LogDebug("Received message from application. Message type {0}. Payload size: {1}",
|
||||
message.Type, message.Payload.Length);
|
||||
|
||||
await _webSocket.SendAsync(new ArraySegment<byte>(message.Payload),
|
||||
message.Type == MessageType.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
|
||||
true, cancellationToken);
|
||||
await _webSocket.SendAsync(new ArraySegment<byte>(message.Payload),
|
||||
message.Type == MessageType.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary,
|
||||
true, _transportCts.Token);
|
||||
|
||||
message.SendResult.SetResult(null);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
message.SendResult.SetCanceled();
|
||||
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(ex.Message);
|
||||
message.SendResult.SetException(ex);
|
||||
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
|
||||
throw;
|
||||
message.SendResult.SetResult(null);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
_logger.LogInformation("Sending a message canceled.");
|
||||
message.SendResult.SetCanceled();
|
||||
await CloseWebSocket();
|
||||
break;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError("Error while sending a message {0}", ex.Message);
|
||||
message.SendResult.SetException(ex);
|
||||
await CloseWebSocket();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
_logger.LogInformation("Send loop stopped");
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
}
|
||||
finally
|
||||
{
|
||||
_transportCts.Cancel();
|
||||
_logger.LogInformation("Send loop stopped");
|
||||
}
|
||||
}
|
||||
|
||||
private async Task Connect(Uri url)
|
||||
|
|
@ -182,14 +201,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
uriBuilder.Scheme = "wss";
|
||||
}
|
||||
|
||||
await _webSocket.ConnectAsync(uriBuilder.Uri, _cancellationToken);
|
||||
await _webSocket.ConnectAsync(uriBuilder.Uri, CancellationToken.None);
|
||||
}
|
||||
|
||||
public async Task StopAsync()
|
||||
{
|
||||
_logger.LogInformation("Transport {0} is stopping", nameof(WebSocketsTransport));
|
||||
|
||||
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
|
||||
await CloseWebSocket();
|
||||
_webSocket.Dispose();
|
||||
|
||||
try
|
||||
|
|
@ -203,5 +222,25 @@ namespace Microsoft.AspNetCore.Sockets.Client
|
|||
|
||||
_logger.LogInformation("Transport {0} stopped", nameof(WebSocketsTransport));
|
||||
}
|
||||
|
||||
private async Task CloseWebSocket()
|
||||
{
|
||||
try
|
||||
{
|
||||
// Best effort - it's still possible (but not likely) that the transport is being closed via StopAsync
|
||||
// while the webSocket is being closed due to an error.
|
||||
if (_webSocket.State != WebSocketState.Closed)
|
||||
{
|
||||
_logger.LogInformation("Closing webSocket");
|
||||
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, CancellationToken.None);
|
||||
}
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
// This is benign - the exception can happen due to the race described above because we would
|
||||
// try closing the webSocket twice.
|
||||
_logger.LogInformation("Closing webSocket failed with {0}", ex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
const string message = "Hello, World!";
|
||||
using (var ws = new ClientWebSocket())
|
||||
{
|
||||
string socketUrl = _serverFixture.WebSocketsUrl + "/echo";
|
||||
var socketUrl = _serverFixture.WebSocketsUrl + "/echo";
|
||||
|
||||
logger.LogInformation("Connecting WebSocket to {socketUrl}", socketUrl);
|
||||
await ws.ConnectAsync(new Uri(socketUrl), CancellationToken.None).OrTimeout();
|
||||
|
|
|
|||
|
|
@ -0,0 +1,80 @@
|
|||
// Copyright (c) .NET Foundation. All rights reserved.
|
||||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks.Channels;
|
||||
using Microsoft.AspNetCore.SignalR.Tests.Common;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.AspNetCore.Sockets.Client;
|
||||
using Microsoft.AspNetCore.Sockets.Internal;
|
||||
using Microsoft.AspNetCore.Testing.xunit;
|
||||
using Xunit;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR.Tests
|
||||
{
|
||||
[Collection(EndToEndTestsCollection.Name)]
|
||||
public class WebSocketsTransportTests
|
||||
{
|
||||
private readonly ServerFixture _serverFixture;
|
||||
|
||||
public WebSocketsTransportTests(ServerFixture serverFixture)
|
||||
{
|
||||
if (serverFixture == null)
|
||||
{
|
||||
throw new ArgumentNullException(nameof(serverFixture));
|
||||
}
|
||||
|
||||
_serverFixture = serverFixture;
|
||||
}
|
||||
|
||||
[ConditionalFact]
|
||||
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
|
||||
public async Task WebSocketsTransportStopsSendAndReceiveLoopsWhenTransportIsStopped()
|
||||
{
|
||||
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
|
||||
var transportToConnection = Channel.CreateUnbounded<Message>();
|
||||
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
|
||||
|
||||
var webSocketsTransport = new WebSocketsTransport();
|
||||
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
|
||||
await webSocketsTransport.StopAsync();
|
||||
await webSocketsTransport.Running.OrTimeout();
|
||||
}
|
||||
|
||||
[ConditionalFact]
|
||||
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
|
||||
public async Task WebSocketsTransportStopsWhenConnectionChannelClosed()
|
||||
{
|
||||
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
|
||||
var transportToConnection = Channel.CreateUnbounded<Message>();
|
||||
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
|
||||
|
||||
var webSocketsTransport = new WebSocketsTransport();
|
||||
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
|
||||
connectionToTransport.Out.TryComplete();
|
||||
await webSocketsTransport.Running.OrTimeout();
|
||||
}
|
||||
|
||||
[ConditionalFact]
|
||||
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
|
||||
public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer()
|
||||
{
|
||||
var connectionToTransport = Channel.CreateUnbounded<SendMessage>();
|
||||
var transportToConnection = Channel.CreateUnbounded<Message>();
|
||||
var channelConnection = new ChannelConnection<SendMessage, Message>(connectionToTransport, transportToConnection);
|
||||
|
||||
var webSocketsTransport = new WebSocketsTransport();
|
||||
await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection);
|
||||
|
||||
var sendTcs = new TaskCompletionSource<object>();
|
||||
connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, MessageType.Binary, sendTcs));
|
||||
await sendTcs.Task;
|
||||
// The echo endpoint close the connection immediately after sending response which should stop the transport
|
||||
await webSocketsTransport.Running.OrTimeout();
|
||||
|
||||
Assert.True(transportToConnection.In.TryRead(out var message));
|
||||
Assert.Equal(new byte[] { 0x42 }, message.Payload);
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue