diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index cf041618c0..a0e296bf01 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -4,16 +4,18 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Diagnostics; using System.IO; using System.Linq; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; -using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Newtonsoft.Json; @@ -27,6 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Client private readonly IConnection _connection; private readonly IHubProtocol _protocol; private readonly HubBinder _binder; + private IDataEncoder _encoder; private readonly object _pendingCallsLock = new object(); private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource(); @@ -86,12 +89,27 @@ namespace Microsoft.AspNetCore.SignalR.Client _connection.Features.Set(transferModeFeature); } - transferModeFeature.TransferMode = - (_protocol.Type == ProtocolType.Binary) + var requestedTransferMode = + _protocol.Type == ProtocolType.Binary ? TransferMode.Binary : TransferMode.Text; + transferModeFeature.TransferMode = requestedTransferMode; await _connection.StartAsync(); + var actualTransferMode = transferModeFeature.TransferMode; + + if (requestedTransferMode == TransferMode.Binary && actualTransferMode == TransferMode.Text) + { + // This is for instance for SSE which is a Text protocol and the user wants to use a binary + // protocol so we need to encode messages. + _encoder = new Base64Encoder(); + } + else + { + Debug.Assert(requestedTransferMode == actualTransferMode, "All transports besides SSE are expected to support binary mode."); + + _encoder = new PassThroughEncoder(); + } using (var memoryStream = new MemoryStream()) { @@ -171,7 +189,7 @@ namespace Microsoft.AspNetCore.SignalR.Client { try { - var payload = _protocol.WriteToArray(invocationMessage); + var payload = _encoder.Encode(_protocol.WriteToArray(invocationMessage)); _logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId); @@ -188,6 +206,8 @@ namespace Microsoft.AspNetCore.SignalR.Client private async Task OnDataReceivedAsync(byte[] data) { + data = _encoder.Decode(data); + if (_protocol.TryParseMessages(data, _binder, out var messages)) { foreach (var message in messages) diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs new file mode 100644 index 0000000000..f35965243e --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/Base64Encoder.cs @@ -0,0 +1,21 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Text; + +namespace Microsoft.AspNetCore.SignalR.Internal.Encoders +{ + public class Base64Encoder : IDataEncoder + { + public byte[] Decode(byte[] payload) + { + return Convert.FromBase64String(Encoding.UTF8.GetString(payload)); + } + + public byte[] Encode(byte[] payload) + { + return Encoding.UTF8.GetBytes(Convert.ToBase64String(payload)); + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs new file mode 100644 index 0000000000..07b724dfcf --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/IDataEncoder.cs @@ -0,0 +1,11 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Encoders +{ + public interface IDataEncoder + { + byte[] Encode(byte[] payload); + byte[] Decode(byte[] payload); + } +} diff --git a/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs new file mode 100644 index 0000000000..a3e5016495 --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR.Common/Internal/Encoders/PassThroughEncoder.cs @@ -0,0 +1,18 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +namespace Microsoft.AspNetCore.SignalR.Internal.Encoders +{ + public class PassThroughEncoder : IDataEncoder + { + public byte[] Decode(byte[] payload) + { + return payload; + } + + public byte[] Encode(byte[] payload) + { + return payload; + } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/Features/IDataEncoderFeature.cs b/src/Microsoft.AspNetCore.SignalR/Features/IDataEncoderFeature.cs new file mode 100644 index 0000000000..ff36291a5c --- /dev/null +++ b/src/Microsoft.AspNetCore.SignalR/Features/IDataEncoderFeature.cs @@ -0,0 +1,17 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using Microsoft.AspNetCore.SignalR.Internal.Encoders; + +namespace Microsoft.AspNetCore.SignalR.Features +{ + public interface IDataEncoderFeature + { + IDataEncoder DataEncoder { get; set; } + } + + public class DataEncoderFeature : IDataEncoderFeature + { + public IDataEncoder DataEncoder { get; set; } + } +} diff --git a/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs b/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs index 0269c4458a..9755cb0616 100644 --- a/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs +++ b/src/Microsoft.AspNetCore.SignalR/Features/IHubFeature.cs @@ -1,9 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; -using System.Collections.Generic; -using System.Text; using Microsoft.AspNetCore.SignalR.Internal.Protocol; namespace Microsoft.AspNetCore.SignalR.Features diff --git a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs index 9159d25c49..b84efb3b81 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubConnectionContext.cs @@ -1,12 +1,12 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System; using System.Collections.Generic; using System.Security.Claims; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.SignalR.Features; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Features; @@ -26,6 +26,8 @@ namespace Microsoft.AspNetCore.SignalR private IHubFeature HubFeature => Features.Get(); + private IDataEncoderFeature DataEncoderFeature => Features.Get(); + // Used by the HubEndPoint only internal ReadableChannel Input => _connectionContext.Transport; @@ -43,6 +45,12 @@ namespace Microsoft.AspNetCore.SignalR set => HubFeature.Protocol = value; } + public IDataEncoder DataEncoder + { + get => DataEncoderFeature.DataEncoder; + set => DataEncoderFeature.DataEncoder = value; + } + public virtual WritableChannel Output => _output; } } diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index bafc117c11..26db9b46e4 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -15,14 +15,19 @@ using Microsoft.AspNetCore.SignalR.Features; using Microsoft.AspNetCore.SignalR.Internal; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.Sockets; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Internal; using Microsoft.Extensions.Logging; +using Microsoft.AspNetCore.SignalR.Internal.Encoders; namespace Microsoft.AspNetCore.SignalR { public class HubEndPoint : IInvocationBinder where THub : Hub { + private static readonly Base64Encoder Base64Encoder = new Base64Encoder(); + private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder(); + private readonly Dictionary _methods = new Dictionary(StringComparer.OrdinalIgnoreCase); private readonly HubLifetimeManager _lifetimeManager; @@ -51,13 +56,16 @@ namespace Microsoft.AspNetCore.SignalR var output = Channel.CreateUnbounded(); // Set the hub feature before doing anything else. This stores - // all the relevant state for a SignalR Hub connection + // all the relevant state for a SignalR Hub connection. connection.Features.Set(new HubFeature()); + connection.Features.Set(new DataEncoderFeature()); var connectionContext = new HubConnectionContext(output, connection); await ProcessNegotiate(connectionContext); + var encoder = connectionContext.DataEncoder; + // Hubs support multiple producers so we set up this loop to copy // data written to the HubConnectionContext's channel to the transport channel async Task WriteToTransport() @@ -66,6 +74,8 @@ namespace Microsoft.AspNetCore.SignalR { while (output.In.TryRead(out var buffer)) { + buffer = encoder.Encode(buffer); + while (await connection.Transport.Out.WaitToWriteAsync()) { if (connection.Transport.Out.TryWrite(buffer)) @@ -107,7 +117,20 @@ namespace Microsoft.AspNetCore.SignalR // Resolve the Hub Protocol for the connection and store it in metadata // Other components, outside the Hub, may need to know what protocol is in use // for a particular connection, so we store it here. - connection.Protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); + var protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection); + connection.Protocol = protocol; + + var transportCapabilities = connection.Features.Get()?.TransportCapabilities + ?? throw new InvalidOperationException("Unable to read transport capabilities."); + + if (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0) + { + connection.DataEncoder = Base64Encoder; + } + else + { + connection.DataEncoder = PassThroughEncoder; + } return; } @@ -201,6 +224,8 @@ namespace Microsoft.AspNetCore.SignalR { while (connection.Input.TryRead(out var buffer)) { + buffer = connection.DataEncoder.Decode(buffer); + if (protocol.TryParseMessages(buffer, this, out var hubMessages)) { foreach (var hubMessage in hubMessages) diff --git a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs index 1dec7f4cfc..e6da126131 100644 --- a/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs +++ b/src/Microsoft.AspNetCore.SignalR/Internal/DefaultHubProtocolResolver.cs @@ -3,7 +3,6 @@ using System; using Microsoft.AspNetCore.SignalR.Internal.Protocol; -using Microsoft.AspNetCore.Sockets; using Newtonsoft.Json; namespace Microsoft.AspNetCore.SignalR.Internal diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index f2cca4d845..6b57c4473f 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -2,11 +2,15 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.IO; +using System.Text; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.SignalR.Internal.Protocol; using Microsoft.AspNetCore.SignalR.Tests.Common; +using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.Logging; +using Moq; using Newtonsoft.Json; using Xunit; @@ -320,5 +324,57 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.DisposeAsync().OrTimeout(); } } + + [Fact] + public async Task MessagesEncodedWhenUsingBinaryProtocolOverTextTransport() + { + var connection = new TestConnection(TransferMode.Text); + + var hubConnection = new HubConnection(connection, new MessagePackHubProtocol(), new LoggerFactory()); + try + { + await hubConnection.StartAsync().OrTimeout(); + await hubConnection.SendAsync("MyMethod", 42).OrTimeout(); + + await connection.ReadSentTextMessageAsync().OrTimeout(); + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + // this throws if the message is not a valid base64 string + Convert.FromBase64String(invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + + [Fact] + public async Task MessagesDecodedWhenUsingBinaryProtocolOverTextTransport() + { + var connection = new TestConnection(TransferMode.Text); + var hubConnection = new HubConnection(connection, new MessagePackHubProtocol(), new LoggerFactory()); + + var invocationTcs = new TaskCompletionSource(); + try + { + await hubConnection.StartAsync().OrTimeout(); + hubConnection.On("MyMethod", result => invocationTcs.SetResult(result)); + + using (var ms = new MemoryStream()) + { + new MessagePackHubProtocol().WriteMessage(new InvocationMessage("1", true, "MyMethod", 42), ms); + var invokeMessage = Convert.ToBase64String(ms.ToArray()); + connection.ReceivedMessages.TryWrite(Encoding.UTF8.GetBytes(invokeMessage)); + } + + Assert.Equal(42, await invocationTcs.Task); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index ffd89f1a1b..ffe1b7f314 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -8,10 +8,11 @@ using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Http.Features; +using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; +using Microsoft.AspNetCore.Sockets.Features; using Microsoft.AspNetCore.Sockets.Internal.Formatters; using Newtonsoft.Json; -using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests { @@ -26,6 +27,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource(); private Task _receiveLoop; + private TransferMode? _transferMode; + public event Func Connected; public event Func Received; public event Func Closed; @@ -37,8 +40,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public IFeatureCollection Features { get; } = new FeatureCollection(); - public TestConnection() + public TestConnection(TransferMode? transferMode = null) { + _transferMode = transferMode; _receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token); } @@ -68,6 +72,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public Task StartAsync() { + if (_transferMode.HasValue) + { + var transferModeFeature = Features.Get(); + if (transferModeFeature == null) + { + transferModeFeature = new TransferModeFeature(); + Features.Set(transferModeFeature); + } + + transferModeFeature.TransferMode = _transferMode.Value; + } + _started.TrySetResult(null); Connected?.Invoke(); return Task.CompletedTask;