diff --git a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj index 0b41b30530..ee35f9d6bf 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj +++ b/src/Microsoft.AspNetCore.Sockets.Client/Microsoft.AspNetCore.Sockets.Client.csproj @@ -1,7 +1,5 @@  - - Client for ASP.NET Core SignalR netstandard1.3 @@ -9,13 +7,12 @@ true aspnetcore;signalr - - + + - - + \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs new file mode 100644 index 0000000000..c72201f5ca --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client/WebSocketsTransport.cs @@ -0,0 +1,167 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.IO.Pipelines; +using System.Net.WebSockets; +using System.Threading; +using System.Threading.Tasks; +using System.Collections.Generic; +using Microsoft.Extensions.Logging; +using System.Diagnostics; + +namespace Microsoft.AspNetCore.Sockets.Client +{ + public class WebSocketsTransport : ITransport + { + private ClientWebSocket _webSocket = new ClientWebSocket(); + private IChannelConnection _application; + private CancellationToken _cancellationToken = new CancellationToken(); + private readonly ILogger _logger; + + public WebSocketsTransport() + : this(null) + { + } + + public WebSocketsTransport(ILoggerFactory loggerFactory) + { + _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger("WebSocketsTransport"); + } + + public Task Running { get; private set; } + + public async Task StartAsync(Uri url, IChannelConnection application) + { + if (url == null) + { + throw new ArgumentNullException(nameof(url)); + } + + if (application == null) + { + throw new ArgumentNullException(nameof(application)); + } + + _application = application; + await Connect(url); + var sendTask = SendMessages(url, _cancellationToken); + var receiveTask = ReceiveMessages(url, _cancellationToken); + + // TODO: Handle TCP connection errors + // https://github.com/SignalR/SignalR/blob/1fba14fa3437e24c204dfaf8a18db3fce8acad3c/src/Microsoft.AspNet.SignalR.Core/Owin/WebSockets/WebSocketHandler.cs#L248-L251 + Running = Task.WhenAll(sendTask, receiveTask).ContinueWith(t => { + _application.Output.TryComplete(t.IsFaulted ? t.Exception.InnerException : null); + return t; + }).Unwrap(); + } + + private async Task ReceiveMessages(Uri pollUrl, CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested) + { + const int bufferSize = 1024; + var totalBytes = 0; + var incomingMessage = new List>(); + WebSocketReceiveResult receiveResult; + do + { + var buffer = new ArraySegment(new byte[bufferSize]); + + //Exceptions are handled above where the send and receive tasks are being run. + receiveResult = await _webSocket.ReceiveAsync(buffer, cancellationToken); + if(receiveResult.MessageType == WebSocketMessageType.Close) + { + _application.Output.Complete(); + return; + } + var truncBuffer = new ArraySegment(buffer.Array, 0, receiveResult.Count); + incomingMessage.Add(truncBuffer); + totalBytes += receiveResult.Count; + } while (!receiveResult.EndOfMessage); + + //Making sure the message type is either text or binary + Debug.Assert((receiveResult.MessageType == WebSocketMessageType.Binary || receiveResult.MessageType == WebSocketMessageType.Text ), "Unexpected message type"); + + Message message; + var messageType = receiveResult.MessageType == WebSocketMessageType.Binary ? Format.Binary : Format.Text; + if (incomingMessage.Count > 1) + { + var messageBuffer = new byte[totalBytes]; + var offset = 0; + for (var i = 0 ; i < incomingMessage.Count; i++) + { + Buffer.BlockCopy(incomingMessage[i].Array, 0, messageBuffer, offset, incomingMessage[i].Count); + offset += incomingMessage[i].Count; + } + + message = new Message(ReadableBuffer.Create(messageBuffer).Preserve(), messageType, receiveResult.EndOfMessage); + } + else + { + message = new Message(ReadableBuffer.Create(incomingMessage[0].Array, incomingMessage[0].Offset, incomingMessage[0].Count).Preserve(), messageType, receiveResult.EndOfMessage); + } + + while (await _application.Output.WaitToWriteAsync(cancellationToken)) + { + if (_application.Output.TryWrite(message)) + { + incomingMessage.Clear(); + break; + } + } + } + } + + private async Task SendMessages(Uri sendUrl, CancellationToken cancellationToken) + { + while (await _application.Input.WaitToReadAsync(cancellationToken)) + { + Message message; + while (_application.Input.TryRead(out message)) + { + using (message) + { + try + { + await _webSocket.SendAsync(new ArraySegment(message.Payload.Buffer.ToArray()), + message.MessageFormat == Format.Text ? WebSocketMessageType.Text : WebSocketMessageType.Binary, true, + cancellationToken); + } + catch (OperationCanceledException ex) + { + _logger?.LogError(ex.Message); + await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken); + break; + } + } + } + } + } + + private async Task Connect(Uri url) + { + var uriBuilder = new UriBuilder(url); + if (url.Scheme == "http") + { + uriBuilder.Scheme = "ws"; + } + else if (url.Scheme == "https") + { + uriBuilder.Scheme = "wss"; + } + + await _webSocket.ConnectAsync(uriBuilder.Uri, _cancellationToken); + } + + public void Dispose() + { + _webSocket.Dispose(); + } + + public async Task StopAsync() + { + await _webSocket.CloseAsync(WebSocketCloseStatus.Empty, null, _cancellationToken); + } + } +} diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index ab5024505f..e271f0260f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -2,7 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.IO.Pipelines; +using System.Collections.Generic; using System.Net.Http; using System.Net.WebSockets; using System.Text; @@ -75,5 +75,54 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + public async Task ConnectionCanSendAndReceiveSmallMessagesWebSocketsTransport() + { + const string message = "Major Key"; + var baseUrl = _serverFixture.BaseUrl; + var loggerFactory = new LoggerFactory(); + + var transport = new WebSocketsTransport(); + using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory)) + { + await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text); + + var receiveData = new ReceiveData(); + + Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout()); + Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data)); + } + } + + public static IEnumerable MessageSizesData + { + get + { + yield return new object[] { new string('A', 5 * 1024)}; + yield return new object[] { new string('A', 5 * 1024 * 1024 + 32)}; + } + } + + [ConditionalTheory] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + [MemberData(nameof(MessageSizesData))] + public async Task ConnectionCanSendAndReceiveDifferentMessageSizesWebSocketsTransport(string message) + { + var baseUrl = _serverFixture.BaseUrl; + var loggerFactory = new LoggerFactory(); + + var transport = new WebSocketsTransport(); + using (var connection = await ClientConnection.ConnectAsync(new Uri(baseUrl + "/echo/ws"), transport, loggerFactory)) + { + await connection.SendAsync(Encoding.UTF8.GetBytes(message), Format.Text); + + var receiveData = new ReceiveData(); + + Assert.True(await connection.ReceiveAsync(receiveData).OrTimeout()); + Assert.Equal(message, Encoding.UTF8.GetString(receiveData.Data)); + } + } } }