Enabling binary protocols over text transports

This commit is contained in:
Pawel Kadluczka 2017-07-18 18:14:16 -07:00 committed by Pawel Kadluczka
parent f9ee7911a5
commit a0e490e549
11 changed files with 201 additions and 13 deletions

View File

@ -4,16 +4,18 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Newtonsoft.Json;
@ -27,6 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly IConnection _connection;
private readonly IHubProtocol _protocol;
private readonly HubBinder _binder;
private IDataEncoder _encoder;
private readonly object _pendingCallsLock = new object();
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
@ -86,12 +89,27 @@ namespace Microsoft.AspNetCore.SignalR.Client
_connection.Features.Set(transferModeFeature);
}
transferModeFeature.TransferMode =
(_protocol.Type == ProtocolType.Binary)
var requestedTransferMode =
_protocol.Type == ProtocolType.Binary
? TransferMode.Binary
: TransferMode.Text;
transferModeFeature.TransferMode = requestedTransferMode;
await _connection.StartAsync();
var actualTransferMode = transferModeFeature.TransferMode;
if (requestedTransferMode == TransferMode.Binary && actualTransferMode == TransferMode.Text)
{
// This is for instance for SSE which is a Text protocol and the user wants to use a binary
// protocol so we need to encode messages.
_encoder = new Base64Encoder();
}
else
{
Debug.Assert(requestedTransferMode == actualTransferMode, "All transports besides SSE are expected to support binary mode.");
_encoder = new PassThroughEncoder();
}
using (var memoryStream = new MemoryStream())
{
@ -171,7 +189,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
try
{
var payload = _protocol.WriteToArray(invocationMessage);
var payload = _encoder.Encode(_protocol.WriteToArray(invocationMessage));
_logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId);
@ -188,6 +206,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
private async Task OnDataReceivedAsync(byte[] data)
{
data = _encoder.Decode(data);
if (_protocol.TryParseMessages(data, _binder, out var messages))
{
foreach (var message in messages)

View File

@ -0,0 +1,21 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Text;
namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
{
public class Base64Encoder : IDataEncoder
{
public byte[] Decode(byte[] payload)
{
return Convert.FromBase64String(Encoding.UTF8.GetString(payload));
}
public byte[] Encode(byte[] payload)
{
return Encoding.UTF8.GetBytes(Convert.ToBase64String(payload));
}
}
}

View File

@ -0,0 +1,11 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
{
public interface IDataEncoder
{
byte[] Encode(byte[] payload);
byte[] Decode(byte[] payload);
}
}

View File

@ -0,0 +1,18 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Internal.Encoders
{
public class PassThroughEncoder : IDataEncoder
{
public byte[] Decode(byte[] payload)
{
return payload;
}
public byte[] Encode(byte[] payload)
{
return payload;
}
}
}

View File

@ -0,0 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
namespace Microsoft.AspNetCore.SignalR.Features
{
public interface IDataEncoderFeature
{
IDataEncoder DataEncoder { get; set; }
}
public class DataEncoderFeature : IDataEncoderFeature
{
public IDataEncoder DataEncoder { get; set; }
}
}

View File

@ -1,9 +1,6 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Features

View File

@ -1,12 +1,12 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
@ -26,6 +26,8 @@ namespace Microsoft.AspNetCore.SignalR
private IHubFeature HubFeature => Features.Get<IHubFeature>();
private IDataEncoderFeature DataEncoderFeature => Features.Get<IDataEncoderFeature>();
// Used by the HubEndPoint only
internal ReadableChannel<byte[]> Input => _connectionContext.Transport;
@ -43,6 +45,12 @@ namespace Microsoft.AspNetCore.SignalR
set => HubFeature.Protocol = value;
}
public IDataEncoder DataEncoder
{
get => DataEncoderFeature.DataEncoder;
set => DataEncoderFeature.DataEncoder = value;
}
public virtual WritableChannel<byte[]> Output => _output;
}
}

View File

@ -15,14 +15,19 @@ using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.Logging;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
namespace Microsoft.AspNetCore.SignalR
{
public class HubEndPoint<THub> : IInvocationBinder where THub : Hub
{
private static readonly Base64Encoder Base64Encoder = new Base64Encoder();
private static readonly PassThroughEncoder PassThroughEncoder = new PassThroughEncoder();
private readonly Dictionary<string, HubMethodDescriptor> _methods = new Dictionary<string, HubMethodDescriptor>(StringComparer.OrdinalIgnoreCase);
private readonly HubLifetimeManager<THub> _lifetimeManager;
@ -51,13 +56,16 @@ namespace Microsoft.AspNetCore.SignalR
var output = Channel.CreateUnbounded<byte[]>();
// Set the hub feature before doing anything else. This stores
// all the relevant state for a SignalR Hub connection
// all the relevant state for a SignalR Hub connection.
connection.Features.Set<IHubFeature>(new HubFeature());
connection.Features.Set<IDataEncoderFeature>(new DataEncoderFeature());
var connectionContext = new HubConnectionContext(output, connection);
await ProcessNegotiate(connectionContext);
var encoder = connectionContext.DataEncoder;
// Hubs support multiple producers so we set up this loop to copy
// data written to the HubConnectionContext's channel to the transport channel
async Task WriteToTransport()
@ -66,6 +74,8 @@ namespace Microsoft.AspNetCore.SignalR
{
while (output.In.TryRead(out var buffer))
{
buffer = encoder.Encode(buffer);
while (await connection.Transport.Out.WaitToWriteAsync())
{
if (connection.Transport.Out.TryWrite(buffer))
@ -107,7 +117,20 @@ namespace Microsoft.AspNetCore.SignalR
// Resolve the Hub Protocol for the connection and store it in metadata
// Other components, outside the Hub, may need to know what protocol is in use
// for a particular connection, so we store it here.
connection.Protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection);
var protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection);
connection.Protocol = protocol;
var transportCapabilities = connection.Features.Get<IConnectionTransportFeature>()?.TransportCapabilities
?? throw new InvalidOperationException("Unable to read transport capabilities.");
if (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0)
{
connection.DataEncoder = Base64Encoder;
}
else
{
connection.DataEncoder = PassThroughEncoder;
}
return;
}
@ -201,6 +224,8 @@ namespace Microsoft.AspNetCore.SignalR
{
while (connection.Input.TryRead(out var buffer))
{
buffer = connection.DataEncoder.Decode(buffer);
if (protocol.TryParseMessages(buffer, this, out var hubMessages))
{
foreach (var hubMessage in hubMessages)

View File

@ -3,7 +3,6 @@
using System;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Newtonsoft.Json;
namespace Microsoft.AspNetCore.SignalR.Internal

View File

@ -2,11 +2,15 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.IO;
using System.Text;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Logging;
using Moq;
using Newtonsoft.Json;
using Xunit;
@ -320,5 +324,57 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task MessagesEncodedWhenUsingBinaryProtocolOverTextTransport()
{
var connection = new TestConnection(TransferMode.Text);
var hubConnection = new HubConnection(connection, new MessagePackHubProtocol(), new LoggerFactory());
try
{
await hubConnection.StartAsync().OrTimeout();
await hubConnection.SendAsync("MyMethod", 42).OrTimeout();
await connection.ReadSentTextMessageAsync().OrTimeout();
var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout();
// this throws if the message is not a valid base64 string
Convert.FromBase64String(invokeMessage);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
[Fact]
public async Task MessagesDecodedWhenUsingBinaryProtocolOverTextTransport()
{
var connection = new TestConnection(TransferMode.Text);
var hubConnection = new HubConnection(connection, new MessagePackHubProtocol(), new LoggerFactory());
var invocationTcs = new TaskCompletionSource<int>();
try
{
await hubConnection.StartAsync().OrTimeout();
hubConnection.On<int>("MyMethod", result => invocationTcs.SetResult(result));
using (var ms = new MemoryStream())
{
new MessagePackHubProtocol().WriteMessage(new InvocationMessage("1", true, "MyMethod", 42), ms);
var invokeMessage = Convert.ToBase64String(ms.ToArray());
connection.ReceivedMessages.TryWrite(Encoding.UTF8.GetBytes(invokeMessage));
}
Assert.Equal(42, await invocationTcs.Task);
}
finally
{
await hubConnection.DisposeAsync().OrTimeout();
await connection.DisposeAsync().OrTimeout();
}
}
}
}

View File

@ -8,10 +8,11 @@ using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.AspNetCore.Sockets.Features;
using Microsoft.AspNetCore.Sockets.Internal.Formatters;
using Newtonsoft.Json;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
@ -26,6 +27,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
private CancellationTokenSource _receiveShutdownToken = new CancellationTokenSource();
private Task _receiveLoop;
private TransferMode? _transferMode;
public event Func<Task> Connected;
public event Func<byte[], Task> Received;
public event Func<Exception, Task> Closed;
@ -37,8 +40,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public IFeatureCollection Features { get; } = new FeatureCollection();
public TestConnection()
public TestConnection(TransferMode? transferMode = null)
{
_transferMode = transferMode;
_receiveLoop = ReceiveLoopAsync(_receiveShutdownToken.Token);
}
@ -68,6 +72,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public Task StartAsync()
{
if (_transferMode.HasValue)
{
var transferModeFeature = Features.Get<ITransferModeFeature>();
if (transferModeFeature == null)
{
transferModeFeature = new TransferModeFeature();
Features.Set(transferModeFeature);
}
transferModeFeature.TransferMode = _transferMode.Value;
}
_started.TrySetResult(null);
Connected?.Invoke();
return Task.CompletedTask;