diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index 568a508749..cf041618c0 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -11,6 +11,8 @@ using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; +using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; @@ -77,6 +79,18 @@ namespace Microsoft.AspNetCore.SignalR.Client public async Task StartAsync() { + var transferModeFeature = _connection.Features.Get(); + if (transferModeFeature == null) + { + transferModeFeature = new TransferModeFeature(); + _connection.Features.Set(transferModeFeature); + } + + transferModeFeature.TransferMode = + (_protocol.Type == ProtocolType.Binary) + ? TransferMode.Binary + : TransferMode.Text; + await _connection.StartAsync(); using (var memoryStream = new MemoryStream()) @@ -389,5 +403,10 @@ namespace Microsoft.AspNetCore.SignalR.Client ParameterTypes = parameterTypes; } } + + private class TransferModeFeature : ITransferModeFeature + { + public TransferMode TransferMode { get; set; } + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs index 57c9124f6f..4ccb27e0c2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/IHubProtocol.cs @@ -11,6 +11,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol { string Name { get; } + ProtocolType Type { get; } + bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages); void WriteMessage(HubMessage message, Stream output); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs index df908e323f..d98f2968f2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/JsonHubProtocol.cs @@ -46,6 +46,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public string Name => "json"; + public ProtocolType Type => ProtocolType.Text; + public bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs index 34e0ee6bec..12b9269153 100644 --- a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/MessagePackHubProtocol.cs @@ -18,6 +18,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal.Protocol public string Name => "messagepack"; + public ProtocolType Type => ProtocolType.Binary; + public bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages) { messages = new List(); diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/ProtocolType.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/ProtocolType.cs new file mode 100644 index 0000000000..6b97082e2e --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Protocol/ProtocolType.cs @@ -0,0 +1,11 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Protocol +{ + public enum ProtocolType + { + Binary, + Text + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs index 8eeb8879e0..3bdce8877e 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/IConnectionTransportFeature.cs @@ -1,9 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; -using System.Text; using System.Threading.Tasks.Channels; namespace Microsoft.AspNetCore.Sockets.Features diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/ITransferModeFeature.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/ITransferModeFeature.cs new file mode 100644 index 0000000000..2962559a19 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/Features/ITransferModeFeature.cs @@ -0,0 +1,10 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.Sockets.Features +{ + public interface ITransferModeFeature + { + TransferMode TransferMode { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs index 09c9954198..56643ed655 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs @@ -4,6 +4,7 @@ using System; using System.Threading; using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; namespace Microsoft.AspNetCore.Sockets.Client { @@ -16,5 +17,7 @@ namespace Microsoft.AspNetCore.Sockets.Client event Func Connected; event Func Received; event Func Closed; + + IFeatureCollection Features { get; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/TransferMode.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/TransferMode.cs new file mode 100644 index 0000000000..bd9adc3838 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/TransferMode.cs @@ -0,0 +1,14 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; + +namespace Microsoft.AspNetCore.Sockets +{ + [Flags] + public enum TransferMode + { + Binary = 0x01, + Text = 0x02 + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 53938ac4f9..5ddf582865 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -2,11 +2,14 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Diagnostics; using System.IO; using System.Net.Http; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Client.Internal; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.Logging; @@ -35,6 +38,8 @@ namespace Microsoft.AspNetCore.Sockets.Client public Uri Url { get; } + public IFeatureCollection Features { get; } = new FeatureCollection(); + public event Func Connected; public event Func Received; public event Func Closed; @@ -48,7 +53,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { } public HttpConnection(Uri url, TransportType transportType) - : this(url, transportType, loggerFactory: null) + : this(url, transportType, loggerFactory: null) { } @@ -262,7 +267,11 @@ namespace Microsoft.AspNetCore.Sockets.Client // Start the transport, giving it one end of the pipeline try { - await _transport.StartAsync(connectUrl, applicationSide, _connectionId); + await _transport.StartAsync(connectUrl, applicationSide, requestedTransferMode: GetTransferMode(), connectionId: _connectionId); + + // actual transfer mode can differ from the one that was requested so set it on the feature + Debug.Assert(_transport.Mode.HasValue, "transfer mode not set after transport started"); + SetTransferMode(_transport.Mode.Value); } catch (Exception ex) { @@ -271,6 +280,29 @@ namespace Microsoft.AspNetCore.Sockets.Client } } + private TransferMode GetTransferMode() + { + var transferModeFeature = Features.Get(); + if (transferModeFeature == null) + { + return TransferMode.Text; + } + + return transferModeFeature.TransferMode; + } + + private void SetTransferMode(TransferMode transferMode) + { + var transferModeFeature = Features.Get(); + if (transferModeFeature == null) + { + transferModeFeature = new TransferModeFeature(); + Features.Set(transferModeFeature); + } + + transferModeFeature.TransferMode = transferMode; + } + private async Task ReceiveAsync() { try diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs index 81d442d636..7591183e33 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ITransport.cs @@ -9,7 +9,8 @@ namespace Microsoft.AspNetCore.Sockets.Client { public interface ITransport { - Task StartAsync(Uri url, Channel application, string connectionId); + Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId); Task StopAsync(); + TransferMode? Mode { get; } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs index 296d87a591..4b277f1dc7 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs @@ -10,8 +10,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal internal static class SocketClientLoggerExtensions { // Category: Shared with LongPollingTransport, WebSocketsTransport and ServerSentEventsTransport - private static readonly Action _startTransport = - LoggerMessage.Define(LogLevel.Information, 0, "{time}: Connection Id {connectionId}: Starting transport."); + private static readonly Action _startTransport = + LoggerMessage.Define(LogLevel.Information, 0, "{time}: Connection Id {connectionId}: Starting transport. Transfer mode: {transferMode}."); private static readonly Action _transportStopped = LoggerMessage.Define(LogLevel.Debug, 1, "{time}: Connection Id {connectionId}: Transport stopped."); @@ -147,11 +147,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal private static readonly Action _stoppingClient = LoggerMessage.Define(LogLevel.Information, 18, "{time}: Connection Id {connectionId}: Stopping client."); - public static void StartTransport(this ILogger logger, string connectionId) + public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode) { if (logger.IsEnabled(LogLevel.Information)) { - _startTransport(logger, DateTime.Now, connectionId, null); + _startTransport(logger, DateTime.Now, connectionId, transferMode, null); } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs index c705438970..666c37d87b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/LongPollingTransport.cs @@ -26,6 +26,8 @@ namespace Microsoft.AspNetCore.Sockets.Client public Task Running { get; private set; } = Task.CompletedTask; + public TransferMode? Mode { get; private set; } + public LongPollingTransport(HttpClient httpClient) : this(httpClient, null) { } @@ -36,12 +38,18 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, Channel application, string connectionId) + public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId) { - _connectionId = connectionId; - _logger.StartTransport(_connectionId); + if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) + { + throw new ArgumentException("Invalid transfer mode.", nameof(requestedTransferMode)); + } _application = application; + Mode = requestedTransferMode; + _connectionId = connectionId; + + _logger.StartTransport(_connectionId, Mode.Value); // Start sending and polling (ask for binary if the server supports it) _poller = Poll(url, _transportCts.Token); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs index fb62f25604..a9048f3725 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/ServerSentEventsTransport.cs @@ -27,6 +27,8 @@ namespace Microsoft.AspNetCore.Sockets.Client public Task Running { get; private set; } = Task.CompletedTask; + public TransferMode? Mode { get; private set; } + public ServerSentEventsTransport(HttpClient httpClient) : this(httpClient, null) { } @@ -42,12 +44,19 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task StartAsync(Uri url, Channel application, string connectionId) + public Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId) { - _connectionId = connectionId; - _logger.StartTransport(_connectionId); + if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) + { + throw new ArgumentException("Invalid transfer mode.", nameof(requestedTransferMode)); + } _application = application; + Mode = TransferMode.Text; // Server Sent Events is a text only transport + _connectionId = connectionId; + + _logger.StartTransport(_connectionId, Mode.Value); + var sendTask = SendUtils.SendMessages(url, _application, _httpClient, _transportCts, _logger, _connectionId); var receiveTask = OpenConnection(_application, url, _transportCts.Token); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/TransferModeFeature.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/TransferModeFeature.cs new file mode 100644 index 0000000000..77f5697a01 --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/TransferModeFeature.cs @@ -0,0 +1,12 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.Sockets.Features; + +namespace Microsoft.AspNetCore.Sockets.Client +{ + public class TransferModeFeature : ITransferModeFeature + { + public TransferMode TransferMode { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs index f9cb848bf4..5f7345162f 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/WebSocketsTransport.cs @@ -22,6 +22,10 @@ namespace Microsoft.AspNetCore.Sockets.Client private readonly ILogger _logger; private string _connectionId; + public Task Running { get; private set; } = Task.CompletedTask; + + public TransferMode? Mode { get; private set; } + public WebSocketsTransport() : this(null) { @@ -32,9 +36,7 @@ namespace Microsoft.AspNetCore.Sockets.Client _logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger(); } - public Task Running { get; private set; } = Task.CompletedTask; - - public async Task StartAsync(Uri url, Channel application, string connectionId) + public async Task StartAsync(Uri url, Channel application, TransferMode requestedTransferMode, string connectionId) { if (url == null) { @@ -46,10 +48,17 @@ namespace Microsoft.AspNetCore.Sockets.Client throw new ArgumentNullException(nameof(application)); } + if (requestedTransferMode != TransferMode.Binary && requestedTransferMode != TransferMode.Text) + { + throw new ArgumentException("Invalid transfer mode.", nameof(requestedTransferMode)); + } + _application = application; + Mode = requestedTransferMode; _connectionId = connectionId; - _logger.StartTransport(_connectionId); + _logger.StartTransport(_connectionId, Mode.Value); + await Connect(url); var sendTask = SendMessages(url); var receiveTask = ReceiveMessages(url); @@ -145,6 +154,11 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.SendStarted(_connectionId); + var webSocketMessageType = + Mode == TransferMode.Binary + ? WebSocketMessageType.Binary + : WebSocketMessageType.Text; + try { while (await _application.In.WaitToReadAsync(_transportCts.Token)) @@ -155,7 +169,7 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.ReceivedFromApp(_connectionId, message.Payload.Length); - await _webSocket.SendAsync(new ArraySegment(message.Payload), WebSocketMessageType.Text, true, _transportCts.Token); + await _webSocket.SendAsync(new ArraySegment(message.Payload), webSocketMessageType, true, _transportCts.Token); message.SendResult.SetResult(null); } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs similarity index 91% rename from test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs rename to test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 0c82f7273f..97a8722cda 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -10,16 +10,15 @@ using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Testing; using Moq; using Moq.Protected; using Xunit; -using Xunit.Abstractions; namespace Microsoft.AspNetCore.Sockets.Client.Tests { - public class ConnectionTests + public class HttpConnectionTests { [Fact] public void CannotCreateConnectionWithNullUrl() @@ -138,7 +137,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var transport = new Mock(); transport.Setup(t => t.StopAsync()).Returns(async () => { await releaseDisposeTcs.Task; }); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(transport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); @@ -154,7 +152,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests releaseDisposeTcs.SetResult(null); await disposeTask.OrTimeout(); - transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny()), Times.Never); + transport.Verify(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny()), Times.Never); } [Fact] @@ -180,7 +178,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); await connection.StartAsync(); @@ -205,8 +202,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + try { var connectedEventRaisedTcs = new TaskCompletionSource(); @@ -241,10 +238,9 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests }); var mockTransport = new Mock(); - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny())) + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) .Returns(Task.FromException(new InvalidOperationException("Transport failed to start"))); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); var connectedEventRaised = false; @@ -281,7 +277,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests : ResponseUtils.CreateResponse(HttpStatusCode.OK); }); - var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); var closedEventTcs = new TaskCompletionSource(); @@ -350,8 +345,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny())) - .Returns, string>((url, c, id) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string>((url, c, transferMode, connectionId) => { channel = c; return Task.CompletedTask; @@ -365,9 +360,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests channel.Out.TryComplete(); return Task.CompletedTask; }); - + mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + var receivedInvoked = false; connection.Received += m => { @@ -396,8 +392,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var mockTransport = new Mock(); Channel channel = null; - mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny())) - .Returns, string>((url, c, id) => + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string>((url, c, transferMode, connectionId) => { channel = c; return Task.CompletedTask; @@ -408,12 +404,13 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests channel.Out.TryComplete(); return Task.CompletedTask; }); - + mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Text); var callbackInvokedTcs = new TaskCompletionSource(); var closedTcs = new TaskCompletionSource(); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + connection.Received += async m => { @@ -584,7 +581,6 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests }); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), TransportType.LongPolling, loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); - await connection.StartAsync(); var exception = await Assert.ThrowsAsync( @@ -773,5 +769,48 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests Assert.Equal("No requested transports available on the server.", exception.Message); } + + [Fact] + public async Task CanStartConnectionWithoutSettingTransferModeFeature() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return request.Method == HttpMethod.Options + ? ResponseUtils.CreateResponse(HttpStatusCode.OK, ResponseUtils.CreateNegotiationResponse()) + : ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + var mockTransport = new Mock(); + Channel channel = null; + mockTransport.Setup(t => t.StartAsync(It.IsAny(), It.IsAny>(), It.IsAny(), It.IsAny())) + .Returns, TransferMode, string>((url, c, transferMode, connectionId) => + { + channel = c; + return Task.CompletedTask; + }); + mockTransport.Setup(t => t.StopAsync()) + .Returns(() => + { + channel.Out.TryComplete(); + return Task.CompletedTask; + }); + mockTransport.SetupGet(t => t.Mode).Returns(TransferMode.Binary); + + var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), + loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); + + await connection.StartAsync().OrTimeout(); + var transferModeFeature = connection.Features.Get(); + await connection.DisposeAsync().OrTimeout(); + + mockTransport.Verify(t => t.StartAsync( + It.IsAny(), It.IsAny>(), TransferMode.Text, It.IsAny()), Times.Once); + Assert.NotNull(transferModeFeature); + Assert.Equal(TransferMode.Binary, transferModeFeature.TransferMode); + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index 4d96a7fb28..d35eaa4f3c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Threading.Tasks; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; @@ -21,6 +22,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task StartAsyncCallsConnectionStart() { var connection = new Mock(); + connection.SetupGet(p => p.Features).Returns(new FeatureCollection()); connection.Setup(m => m.StartAsync()).Returns(Task.CompletedTask).Verifiable(); var hubConnection = new HubConnection(connection.Object); await hubConnection.StartAsync(); @@ -125,6 +127,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { var exception = new InvalidOperationException(); var mockConnection = new Mock(); + mockConnection.SetupGet(p => p.Features).Returns(new FeatureCollection()); mockConnection .Setup(m => m.DisposeAsync()) .Callback(() => mockConnection.Raise(c => c.Closed += null, exception)) @@ -144,6 +147,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public async Task DoesNotThrowWhenClientMethodCalledButNoInvocationHandlerHasBeenSetUp() { var mockConnection = new Mock(); + mockConnection.SetupGet(p => p.Features).Returns(new FeatureCollection()); var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 }); @@ -181,7 +185,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests }; } - public string Name { get => "MockHubProtocol"; } + public string Name => "MockHubProtocol"; + + public ProtocolType Type => ProtocolType.Binary; public bool TryParseMessages(ReadOnlyBuffer input, IInvocationBinder binder, out IList messages) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs index 360157aedc..3a3f760ab4 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/LongPollingTransportTests.cs @@ -10,6 +10,7 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; using Moq; @@ -43,7 +44,7 @@ namespace Microsoft.AspNetCore.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); transportActiveTask = longPollingTransport.Running; @@ -79,7 +80,7 @@ namespace Microsoft.AspNetCore.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = ChannelConnection.Create(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); await longPollingTransport.Running.OrTimeout(); Assert.True(transportToConnection.In.Completion.IsCompleted); @@ -132,7 +133,7 @@ namespace Microsoft.AspNetCore.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); var data = await transportToConnection.In.ReadAllAsync().OrTimeout(); await longPollingTransport.Running.OrTimeout(); @@ -168,7 +169,7 @@ namespace Microsoft.AspNetCore.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); var exception = await Assert.ThrowsAsync(async () => await transportToConnection.In.Completion.OrTimeout()); @@ -204,7 +205,7 @@ namespace Microsoft.AspNetCore.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); await connectionToTransport.Out.WriteAsync(new SendMessage()); @@ -245,7 +246,7 @@ namespace Microsoft.AspNetCore.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); connectionToTransport.Out.Complete(); @@ -296,7 +297,7 @@ namespace Microsoft.AspNetCore.Client.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); // Wait for the transport to finish await longPollingTransport.Running.OrTimeout(); @@ -361,7 +362,7 @@ namespace Microsoft.AspNetCore.Client.Tests await connectionToTransport.Out.WriteAsync(new SendMessage(Encoding.UTF8.GetBytes("World"), tcs2)).OrTimeout(); // Start the transport - await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Binary, connectionId: string.Empty); connectionToTransport.Out.Complete(); @@ -378,5 +379,63 @@ namespace Microsoft.AspNetCore.Client.Tests } } } + + [Theory] + [InlineData(TransferMode.Binary)] + [InlineData(TransferMode.Text)] + public async Task LongPollingTransportSetsTransferMode(TransferMode transferMode) + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + + try + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + Assert.Null(longPollingTransport.Mode); + await longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty); + Assert.Equal(transferMode, longPollingTransport.Mode); + } + finally + { + await longPollingTransport.StopAsync(); + } + } + } + + [Fact] + public async Task LongPollingTransportThrowsForInvalidTransferMode() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return ResponseUtils.CreateResponse(HttpStatusCode.OK); + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var longPollingTransport = new LongPollingTransport(httpClient); + var exception = await Assert.ThrowsAsync(() => + longPollingTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty)); + + Assert.Contains("Invalid transfer mode.", exception.Message); + Assert.Equal("requestedTransferMode", exception.ParamName); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs index 4682181e70..ed5a2d0054 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ServerSentEventsTransportTests.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; using Moq; @@ -52,7 +53,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var connectionToTransport = Channel.CreateUnbounded(); var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); - await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty).OrTimeout(); + await sseTransport.StartAsync( + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); await sseTransport.StopAsync().OrTimeout(); @@ -103,7 +105,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); transportActiveTask = sseTransport.Running; Assert.False(transportActiveTask.IsCompleted); @@ -150,7 +152,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); var exception = await Assert.ThrowsAsync(() => sseTransport.Running.OrTimeout()); Assert.Equal("Incomplete message.", exception.Message); @@ -195,7 +197,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); await eventStreamTcs.Task; var sendTcs = new TaskCompletionSource(); @@ -241,7 +243,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); await eventStreamTcs.Task.OrTimeout(); connectionToTransport.Out.TryComplete(null); @@ -270,7 +272,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests var transportToConnection = Channel.CreateUnbounded(); var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); await sseTransport.StartAsync( - new Uri("http://fakeuri.org"), channelConnection, connectionId: string.Empty).OrTimeout(); + new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text, connectionId: string.Empty).OrTimeout(); var message = await transportToConnection.In.ReadAsync().AsTask().OrTimeout(); Assert.Equal("3:abc", Encoding.ASCII.GetString(message)); @@ -278,5 +280,58 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await sseTransport.Running.OrTimeout(); } } + + [Theory] + [InlineData(TransferMode.Text)] + [InlineData(TransferMode.Binary)] + public async Task SSETransportSetsTransferMode(TransferMode transferMode) + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var sseTransport = new ServerSentEventsTransport(httpClient); + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + Assert.Null(sseTransport.Mode); + await sseTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, transferMode, connectionId: string.Empty).OrTimeout(); + Assert.Equal(TransferMode.Text, sseTransport.Mode); + await sseTransport.StopAsync().OrTimeout(); + } + } + + [Fact] + public async Task SSETransportThrowsForInvalidTransferMode() + { + var mockHttpHandler = new Mock(); + mockHttpHandler.Protected() + .Setup>("SendAsync", ItExpr.IsAny(), ItExpr.IsAny()) + .Returns(async (request, cancellationToken) => + { + await Task.Yield(); + return new HttpResponseMessage { Content = new StringContent(string.Empty) }; + }); + + using (var httpClient = new HttpClient(mockHttpHandler.Object)) + { + var sseTransport = new ServerSentEventsTransport(httpClient); + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + var exception = await Assert.ThrowsAsync(() => + sseTransport.StartAsync(new Uri("http://fakeuri.org"), null, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty)); + + Assert.Contains("Invalid transfer mode.", exception.Message); + Assert.Equal("requestedTransferMode", exception.ParamName); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 2126f7d4a6..ffd89f1a1b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -7,6 +7,7 @@ using System.Text; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Newtonsoft.Json; @@ -34,6 +35,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public ReadableChannel SentMessages => _sentMessages.In; public WritableChannel ReceivedMessages => _receivedMessages.Out; + public IFeatureCollection Features { get; } = new FeatureCollection(); + public TestConnection() { _receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index aa8a5226ed..b67b3d1442 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -11,6 +11,7 @@ using System.Threading.Tasks; using Microsoft.AspNetCore.SignalR.Client; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Testing.xunit; using Microsoft.Extensions.Logging; @@ -78,6 +79,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests [ConditionalTheory] [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] [MemberData(nameof(TransportTypes))] + // TODO: transfer types public async Task ConnectionCanSendAndReceiveMessages(TransportType transportType) { using (StartLog(out var loggerFactory, testName: $"ConnectionCanSendAndReceiveMessages_{transportType.ToString()}")) @@ -88,6 +90,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests var url = _serverFixture.BaseUrl + "/echo"; var connection = new HttpConnection(new Uri(url), transportType, loggerFactory); + + connection.Features.Set( + new TransferModeFeature { TransferMode = TransferMode.Text }); try { var receiveTcs = new TaskCompletionSource(); @@ -163,6 +168,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests var url = _serverFixture.BaseUrl + "/echo"; var connection = new HttpConnection(new Uri(url), loggerFactory); + connection.Features.Set( + new TransferModeFeature { TransferMode = TransferMode.Binary }); + try { var receiveTcs = new TaskCompletionSource(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs index 07b64d9026..2ea9c73e36 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/WebSocketsTransportTests.cs @@ -5,6 +5,7 @@ using System; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.AspNetCore.Testing.xunit; @@ -40,7 +41,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); var webSocketsTransport = new WebSocketsTransport(loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, connectionId: string.Empty).OrTimeout(); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, + TransferMode.Binary, connectionId: string.Empty).OrTimeout(); await webSocketsTransport.StopAsync().OrTimeout(); await webSocketsTransport.Running.OrTimeout(); } @@ -57,15 +59,18 @@ namespace Microsoft.AspNetCore.SignalR.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); var webSocketsTransport = new WebSocketsTransport(loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, connectionId: string.Empty); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, + TransferMode.Binary, connectionId: string.Empty); connectionToTransport.Out.TryComplete(); await webSocketsTransport.Running.OrTimeout(); } } - [ConditionalFact] + [ConditionalTheory] [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] - public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer() + [InlineData(TransferMode.Text)] + [InlineData(TransferMode.Binary)] + public async Task WebSocketsTransportStopsWhenConnectionClosedByTheServer(TransferMode transferMode) { using (StartLog(out var loggerFactory)) { @@ -74,7 +79,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); var webSocketsTransport = new WebSocketsTransport(loggerFactory); - await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, connectionId: string.Empty); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, transferMode, connectionId: string.Empty); var sendTcs = new TaskCompletionSource(); connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs)); @@ -86,5 +91,48 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.Equal(new byte[] { 0x42 }, buffer); } } + + [ConditionalTheory] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + [InlineData(TransferMode.Text)] + [InlineData(TransferMode.Binary)] + public async Task WebSocketsTransportSetsTransferMode(TransferMode transferMode) + { + using (StartLog(out var loggerFactory)) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + var webSocketsTransport = new WebSocketsTransport(loggerFactory); + + Assert.Null(webSocketsTransport.Mode); + await webSocketsTransport.StartAsync(new Uri(_serverFixture.WebSocketsUrl + "/echo"), channelConnection, + transferMode, connectionId: string.Empty).OrTimeout(); + Assert.Equal(transferMode, webSocketsTransport.Mode); + + await webSocketsTransport.StopAsync().OrTimeout(); + await webSocketsTransport.Running.OrTimeout(); + } + } + + [ConditionalFact] + [OSSkipCondition(OperatingSystems.Windows, WindowsVersions.Win7, WindowsVersions.Win2008R2, SkipReason = "No WebSockets Client for this platform")] + public async Task WebSocketsTransportThrowsForInvalidTransferMode() + { + using (StartLog(out var loggerFactory)) + { + var connectionToTransport = Channel.CreateUnbounded(); + var transportToConnection = Channel.CreateUnbounded(); + var channelConnection = new ChannelConnection(connectionToTransport, transportToConnection); + + var webSocketsTransport = new WebSocketsTransport(loggerFactory); + var exception = await Assert.ThrowsAsync(() => + webSocketsTransport.StartAsync(new Uri("http://fakeuri.org"), channelConnection, TransferMode.Text | TransferMode.Binary, connectionId: string.Empty)); + + Assert.Contains("Invalid transfer mode.", exception.Message); + Assert.Equal("requestedTransferMode", exception.ParamName); + } + } } }