WebSockets Transport (#185)

WebSockets transport
This commit is contained in:
Mikael Mengistu 2017-02-15 09:24:41 -08:00 committed by Pawel Kadluczka
parent a728e1da41
commit cb7692d16e
3 changed files with 220 additions and 7 deletions

View File

@ -1,7 +1,5 @@
<Project Sdk="Microsoft.NET.Sdk">
<Import Project="..\..\build\common.props" />
<PropertyGroup>
<Description>Client for ASP.NET Core SignalR</Description>
<TargetFramework>netstandard1.3</TargetFramework>
@ -9,13 +7,12 @@
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<PackageTags>aspnetcore;signalr</PackageTags>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.Sockets.Common\Microsoft.AspNetCore.Sockets.Common.csproj" />
<PackageReference Include="System.IO.Pipelines" Version="0.1.0-*" />
<PackageReference Include="System.Threading.Tasks.Channels" Version="0.1.0-*" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="1.2.0-*" />
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="1.2.0-*" PrivateAssets="All"/>
<PackageReference Include="Microsoft.Extensions.TaskCache.Sources" Version="1.2.0-*" PrivateAssets="All" />
<PackageReference Include="System.Net.WebSockets.Client" Version="4.4.0-*" />
</ItemGroup>
</Project>
</Project>

View File

@ -0,0 +1,167 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Threading;
using System.Threading.Tasks;
using System.Collections.Generic;
using Microsoft.Extensions.Logging;
using System.Diagnostics;
namespace Microsoft.AspNetCore.Sockets.Client
{
public class WebSocketsTransport : ITransport
{
private ClientWebSocket _webSocket = new ClientWebSocket();
private IChannelConnection<Message> _application;
private CancellationToken _cancellationToken = new CancellationToken();
private readonly ILogger _logger;
public WebSocketsTransport()
: this(null)
{
}
public WebSocketsTransport(ILoggerFactory loggerFactory)
{
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("WebSocketsTransport");
}
public Task Running { get; private set; }
public async Task StartAsync(Uri url, IChannelConnection<Message> application)
{
if (url == null)
{
throw new ArgumentNullException(nameof(url));
}
if (application == null)
{
throw new ArgumentNullException(nameof(application));
}
_application = application;
await Connect(url);
var sendTask = SendMessages(url, _cancellationToken);
var receiveTask = ReceiveMessages(url, _cancellationToken);
// TODO: Handle TCP connection errors
// https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251
Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t => {
_application.Output.TryComplete(t.IsFaulted ? t.Exception.InnerException : null);
return t;
}).Unwrap();
}
private async Task ReceiveMessages(Uri pollUrl, CancellationToken cancellationToken)
{
while (!cancellationToken.IsCancellationRequested)
{
const int bufferSize = 1024;
var totalBytes = 0;
var incomingMessage = new List<ArraySegment<byte>>();
WebSocketReceiveResult receiveResult;
do
{
var buffer = new ArraySegment<byte>(new byte[bufferSize]);
//Exceptions are handled above where the send and receive tasks are being run.
receiveResult = await _webSocket.ReceiveAsync(buffer, cancellationToken);
if(receiveResult.MessageType == WebSocketMessageType.Close)
{
_application.Output.Complete();
return;
}
var truncBuffer = new ArraySegment<byte>(buffer.Array, 0, receiveResult.Count);
incomingMessage.Add(truncBuffer);
totalBytes += receiveResult.Count;
} while (!receiveResult.EndOfMessage);
//Making sure the message type is either text or binary
Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text ), "Unexpected message type");
Message message;
var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? Format.Binary : Format.Text;
if (incomingMessage.Count > 1)
{
var messageBuffer = new byte[totalBytes];
var offset = 0;
for (var i = 0 ; i < incomingMessage.Count; i++)
{
Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count);
offset += incomingMessage[i].Count;
}
message = new Message(ReadableBuffer.Create(messageBuffer).Preserve(), messageType, receiveResult.EndOfMessage);
}
else
{
message = new Message(ReadableBuffer.Create(incomingMessage[0].Array, incomingMessage[0].Offset, incomingMessage[0].Count).Preserve(), messageType, receiveResult.EndOfMessage);
}
while (await _application.Output.WaitToWriteAsync(cancellationToken))
{
if (_application.Output.TryWrite(message))
{
incomingMessage.Clear();
break;
}
}
}
}
private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken)
{
while (await _application.Input.WaitToReadAsync(cancellationToken))
{
Message message;
while (_application.Input.TryRead(out message))
{
using (message)
{
try
{
await _webSocket.SendAsync(new ArraySegment<byte>(message.Payload.Buffer.ToArray()),
message.MessageFormat == Format.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, true,
cancellationToken);
}
catch (OperationCanceledException ex)
{
_logger?.LogError(ex.Message);
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
break;
}
}
}
}
}
private async Task Connect(Uri url)
{
var uriBuilder = new UriBuilder(url);
if (url.Scheme == "http")
{
uriBuilder.Scheme = "ws";
}
else if (url.Scheme == "https")
{
uriBuilder.Scheme = "wss";
}
await _webSocket.ConnectAsync(uriBuilder.Uri, _cancellationToken);
}
public void Dispose()
{
_webSocket.Dispose();
}
public async Task StopAsync()
{
await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken);
}
}
}

View File

@ -2,7 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO.Pipelines;
using System.Collections.Generic;
using System.Net.Http;
using System.Net.WebSockets;
using System.Text;
@ -75,5 +75,54 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
}
[ConditionalFact]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
public async Task ConnectionCanSendAndReceiveSmallMessagesWebSocketsTransport()
{
const string message = "Major Key";
var baseUrl = _serverFixture.BaseUrl;
var loggerFactory = new LoggerFactory();
var transport = new WebSocketsTransport();
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory))
{
await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
}
}
public static IEnumerable<object[]> MessageSizesData
{
get
{
yield return new object[] { new string('A', 5 * 1024)};
yield return new object[] { new string('A', 5 * 1024 * 1024 + 32)};
}
}
[ConditionalTheory]
[OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")]
[MemberData(nameof(MessageSizesData))]
public async Task ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport(string message)
{
var baseUrl = _serverFixture.BaseUrl;
var loggerFactory = new LoggerFactory();
var transport = new WebSocketsTransport();
using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory))
{
await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text);
var receiveData = new ReceiveData();
Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout());
Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data));
}
}
}
}