Introducing HubProtocolReaderWriter

This commit is contained in:
Pawel Kadluczka 2017-08-04 16:34:42 -07:00 committed by Pawel Kadluczka
parent ad4784dbd2
commit 3a1d4c5dd6
11 changed files with 99 additions and 122 deletions

View File

@ -29,7 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly IConnection _connection;
private readonly IHubProtocol _protocol;
private readonly HubBinder _binder;
private IDataEncoder _encoder;
private HubProtocolReaderWriter _protocolReaderWriter;
private readonly object _pendingCallsLock = new object();
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
@ -98,18 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
await _connection.StartAsync();
var actualTransferMode = transferModeFeature.TransferMode;
if (requestedTransferMode == TransferMode.Binary && actualTransferMode == TransferMode.Text)
{
// This is for instance for SSE which is a Text protocol and the user wants to use a binary
// protocol so we need to encode messages.
_encoder = new Base64Encoder();
}
else
{
Debug.Assert(requestedTransferMode == actualTransferMode, "All transports besides SSE are expected to support binary mode.");
_encoder = new PassThroughEncoder();
}
_protocolReaderWriter = new HubProtocolReaderWriter(_protocol, GetDataEncoder(requestedTransferMode, actualTransferMode));
using (var memoryStream = new MemoryStream())
{
@ -118,6 +107,20 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
private IDataEncoder GetDataEncoder(TransferMode requestedTransferMode, TransferMode actualTransferMode)
{
if (requestedTransferMode == TransferMode.Binary && actualTransferMode == TransferMode.Text)
{
// This is for instance for SSE which is a Text protocol and the user wants to use a binary
// protocol so we need to encode messages.
return new Base64Encoder();
}
Debug.Assert(requestedTransferMode == actualTransferMode, "All transports besides SSE are expected to support binary mode.");
return new PassThroughEncoder();
}
public async Task DisposeAsync()
{
await _connection.DisposeAsync();
@ -189,8 +192,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
try
{
var payload = _encoder.Encode(_protocol.WriteToArray(invocationMessage));
var payload = _protocolReaderWriter.WriteMessage(invocationMessage);
_logger.LogInformation("Sending Invocation '{invocationId}'", invocationMessage.InvocationId);
await _connection.SendAsync(payload, irq.CancellationToken);
@ -206,9 +208,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
private async Task OnDataReceivedAsync(byte[] data)
{
data = _encoder.Decode(data);
if (_protocol.TryParseMessages(data, _binder, out var messages))
if (_protocolReaderWriter.ReadMessages(data, _binder, out var messages))
{
foreach (var message in messages)
{

View File

@ -0,0 +1,37 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.IO;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public class HubProtocolReaderWriter
{
private readonly IHubProtocol _hubProtocol;
private readonly IDataEncoder _dataEncoder;
public HubProtocolReaderWriter(IHubProtocol hubProtocol, IDataEncoder dataEncoder)
{
_hubProtocol = hubProtocol;
_dataEncoder = dataEncoder;
}
public bool ReadMessages(byte[] input, IInvocationBinder binder, out IList<HubMessage> messages)
{
var buffer = _dataEncoder.Decode(input);
return _hubProtocol.TryParseMessages(buffer, binder, out messages);
}
public byte[] WriteMessage(HubMessage hubMessage)
{
using (var ms = new MemoryStream())
{
_hubProtocol.WriteMessage(hubMessage, ms);
return _dataEncoder.Encode(ms.ToArray());
}
}
}
}

View File

@ -1,21 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
namespace Microsoft.AspNetCore.SignalR.Internal.Protocol
{
public static class HubProtocolWriteMessageExtensions
{
public static byte[] WriteToArray(this IHubProtocol protocol, HubMessage message)
{
using (var output = new MemoryStream())
{
// Encode the message
protocol.WriteMessage(message, output);
return output.ToArray();
}
}
}
}

View File

@ -1,17 +0,0 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
namespace Microsoft.AspNetCore.SignalR.Features
{
public interface IDataEncoderFeature
{
IDataEncoder DataEncoder { get; set; }
}
public class DataEncoderFeature : IDataEncoderFeature
{
public IDataEncoder DataEncoder { get; set; }
}
}

View File

@ -1,17 +1,17 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.SignalR.Internal;
namespace Microsoft.AspNetCore.SignalR.Features
{
public interface IHubFeature
{
IHubProtocol Protocol { get; set; }
HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
}
public class HubFeature : IHubFeature
{
public IHubProtocol Protocol { get; set; }
public HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
}
}

View File

@ -6,7 +6,7 @@ using System.Security.Claims;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.SignalR.Features;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Features;
@ -26,8 +26,6 @@ namespace Microsoft.AspNetCore.SignalR
private IHubFeature HubFeature => Features.Get<IHubFeature>();
private IDataEncoderFeature DataEncoderFeature => Features.Get<IDataEncoderFeature>();
// Used by the HubEndPoint only
internal ReadableChannel<byte[]> Input => _connectionContext.Transport;
@ -39,17 +37,7 @@ namespace Microsoft.AspNetCore.SignalR
public virtual IDictionary<object, object> Metadata => _connectionContext.Metadata;
public virtual IHubProtocol Protocol
{
get => HubFeature.Protocol;
set => HubFeature.Protocol = value;
}
public IDataEncoder DataEncoder
{
get => DataEncoderFeature.DataEncoder;
set => DataEncoderFeature.DataEncoder = value;
}
public virtual HubProtocolReaderWriter ProtocolReaderWriter { get; set; }
public virtual WritableChannel<HubMessage> Output => _output;
}

View File

@ -58,25 +58,21 @@ namespace Microsoft.AspNetCore.SignalR
// Set the hub feature before doing anything else. This stores
// all the relevant state for a SignalR Hub connection.
connection.Features.Set<IHubFeature>(new HubFeature());
connection.Features.Set<IDataEncoderFeature>(new DataEncoderFeature());
var connectionContext = new HubConnectionContext(output, connection);
await ProcessNegotiate(connectionContext);
var encoder = connectionContext.DataEncoder;
// Hubs support multiple producers so we set up this loop to copy
// data written to the HubConnectionContext's channel to the transport channel
var protocolReaderWriter = connectionContext.ProtocolReaderWriter;
async Task WriteToTransport()
{
while (await output.In.WaitToReadAsync())
{
while (output.In.TryRead(out var hubMessage))
{
var buffer = connectionContext.Protocol.WriteToArray(hubMessage);
buffer = encoder.Encode(buffer);
var buffer = protocolReaderWriter.WriteMessage(hubMessage);
while (await connection.Transport.Out.WaitToWriteAsync())
{
if (connection.Transport.Out.TryWrite(buffer))
@ -115,23 +111,16 @@ namespace Microsoft.AspNetCore.SignalR
{
if (NegotiationProtocol.TryParseMessage(buffer, out var negotiationMessage))
{
// Resolve the Hub Protocol for the connection and store it in metadata
// Other components, outside the Hub, may need to know what protocol is in use
// for a particular connection, so we store it here.
var protocol = _protocolResolver.GetProtocol(negotiationMessage.Protocol, connection);
connection.Protocol = protocol;
var transportCapabilities = connection.Features.Get<IConnectionTransportFeature>()?.TransportCapabilities
?? throw new InvalidOperationException("Unable to read transport capabilities.");
if (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0)
{
connection.DataEncoder = Base64Encoder;
}
else
{
connection.DataEncoder = PassThroughEncoder;
}
var dataEncoder = (protocol.Type == ProtocolType.Binary && (transportCapabilities & TransferMode.Binary) == 0)
? (IDataEncoder)Base64Encoder
: PassThroughEncoder;
connection.ProtocolReaderWriter = new HubProtocolReaderWriter(protocol, dataEncoder);
return;
}
@ -217,7 +206,6 @@ namespace Microsoft.AspNetCore.SignalR
// is used to get the exception so we can bubble it up the stack
var cts = new CancellationTokenSource();
var completion = new TaskCompletionSource<object>();
var protocol = connection.Protocol;
try
{
@ -225,9 +213,7 @@ namespace Microsoft.AspNetCore.SignalR
{
while (connection.Input.TryRead(out var buffer))
{
buffer = connection.DataEncoder.Decode(buffer);
if (protocol.TryParseMessages(buffer, this, out var hubMessages))
if (connection.ProtocolReaderWriter.ReadMessages(buffer, this, out var hubMessages))
{
foreach (var hubMessage in hubMessages)
{
@ -241,7 +227,7 @@ namespace Microsoft.AspNetCore.SignalR
// Don't wait on the result of execution, continue processing other
// incoming messages on this connection.
var ignore = ProcessInvocation(connection, protocol, invocationMessage, cts, completion);
var ignore = ProcessInvocation(connection, invocationMessage, cts, completion);
break;
// Other kind of message we weren't expecting
@ -262,7 +248,6 @@ namespace Microsoft.AspNetCore.SignalR
}
private async Task ProcessInvocation(HubConnectionContext connection,
IHubProtocol protocol,
InvocationMessage invocationMessage,
CancellationTokenSource dispatcherCancellation,
TaskCompletionSource<object> dispatcherCompletion)
@ -271,7 +256,7 @@ namespace Microsoft.AspNetCore.SignalR
{
// If an unexpected exception occurs then we want to kill the entire connection
// by ending the processing loop
await Execute(connection, protocol, invocationMessage);
await Execute(connection, invocationMessage);
}
catch (Exception ex)
{
@ -283,21 +268,21 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task Execute(HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
private async Task Execute(HubConnectionContext connection, InvocationMessage invocationMessage)
{
if (!_methods.TryGetValue(invocationMessage.Target, out var descriptor))
{
// Send an error to the client. Then let the normal completion process occur
_logger.LogError("Unknown hub method '{method}'", invocationMessage.Target);
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, $"Unknown hub method '{invocationMessage.Target}'"));
}
else
{
await Invoke(descriptor, connection, protocol, invocationMessage);
await Invoke(descriptor, connection, invocationMessage);
}
}
private async Task SendMessageAsync(HubConnectionContext connection, IHubProtocol protocol, HubMessage hubMessage)
private async Task SendMessageAsync(HubConnectionContext connection, HubMessage hubMessage)
{
while (await connection.Output.WaitToWriteAsync())
{
@ -312,7 +297,7 @@ namespace Microsoft.AspNetCore.SignalR
throw new OperationCanceledException("Outbound channel was closed while trying to write hub message");
}
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, IHubProtocol protocol, InvocationMessage invocationMessage)
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection, InvocationMessage invocationMessage)
{
var methodExecutor = descriptor.MethodExecutor;
@ -323,7 +308,7 @@ namespace Microsoft.AspNetCore.SignalR
_logger.LogDebug("Failed to invoke {hubMethod} because user is unauthorized", invocationMessage.Target);
if (!invocationMessage.NonBlocking)
{
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized"));
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized"));
}
return;
}
@ -357,12 +342,12 @@ namespace Microsoft.AspNetCore.SignalR
if (IsStreamed(methodExecutor, result, methodExecutor.MethodReturnType, out var enumerator))
{
_logger.LogTrace("[{connectionId}/{invocationId}] Streaming result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await StreamResultsAsync(invocationMessage.InvocationId, connection, protocol, enumerator);
await StreamResultsAsync(invocationMessage.InvocationId, connection, enumerator);
}
else if (!invocationMessage.NonBlocking)
{
_logger.LogTrace("[{connectionId}/{invocationId}] Sending result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName);
await SendMessageAsync(connection, protocol, CompletionMessage.WithResult(invocationMessage.InvocationId, result));
await SendMessageAsync(connection, CompletionMessage.WithResult(invocationMessage.InvocationId, result));
}
}
catch (TargetInvocationException ex)
@ -370,7 +355,7 @@ namespace Microsoft.AspNetCore.SignalR
_logger.LogError(0, ex, "Failed to invoke hub method");
if (!invocationMessage.NonBlocking)
{
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message));
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message));
}
}
catch (Exception ex)
@ -378,7 +363,7 @@ namespace Microsoft.AspNetCore.SignalR
_logger.LogError(0, ex, "Failed to invoke hub method");
if (!invocationMessage.NonBlocking)
{
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message));
await SendMessageAsync(connection, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message));
}
}
finally
@ -410,7 +395,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IHubProtocol protocol, IAsyncEnumerator<object> enumerator)
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection,IAsyncEnumerator<object> enumerator)
{
// TODO: Cancellation? See https://github.com/aspnet/SignalR/issues/481
try
@ -418,14 +403,14 @@ namespace Microsoft.AspNetCore.SignalR
while (await enumerator.MoveNextAsync())
{
// Send the stream item
await SendMessageAsync(connection, protocol, new StreamItemMessage(invocationId, enumerator.Current));
await SendMessageAsync(connection, new StreamItemMessage(invocationId, enumerator.Current));
}
await SendMessageAsync(connection, protocol, CompletionMessage.Empty(invocationId));
await SendMessageAsync(connection, CompletionMessage.Empty(invocationId));
}
catch (Exception ex)
{
await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationId, ex.Message));
await SendMessageAsync(connection, CompletionMessage.WithError(invocationId, ex.Message));
}
}

View File

@ -44,7 +44,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
using (StartLog(out var loggerFactory))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + path), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, protocol, loggerFactory);
try

View File

@ -61,10 +61,14 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
};
var protocol = new JsonHubProtocol(jsonSerializer);
var encoded = protocol.WriteToArray(message);
var json = Encoding.UTF8.GetString(encoded);
Assert.Equal(expectedOutput, json);
using (var ms = new MemoryStream())
{
protocol.WriteMessage(message, ms);
var json = Encoding.UTF8.GetString(ms.ToArray());
Assert.Equal(expectedOutput, json);
}
}
[Theory]

View File

@ -9,6 +9,7 @@ using System.Threading;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.SignalR.Internal;
using Microsoft.AspNetCore.SignalR.Internal.Encoders;
using Microsoft.AspNetCore.SignalR.Internal.Protocol;
using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Internal;
@ -19,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public class TestClient : IDisposable, IInvocationBinder
{
private static int _id;
private IHubProtocol _protocol;
private readonly HubProtocolReaderWriter _protocolReaderWriter;
private CancellationTokenSource _cts;
private ChannelConnection<byte[]> _transport;
@ -40,13 +41,14 @@ namespace Microsoft.AspNetCore.SignalR.Tests
Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) }));
Connection.Metadata["ConnectedTask"] = new TaskCompletionSource<bool>();
_protocol = new JsonHubProtocol(new JsonSerializer());
var protocol = new JsonHubProtocol(new JsonSerializer());
_protocolReaderWriter = new HubProtocolReaderWriter(protocol, new PassThroughEncoder());
_cts = new CancellationTokenSource();
using (var memoryStream = new MemoryStream())
{
NegotiationProtocol.WriteMessage(new NegotiationMessage(_protocol.Name), memoryStream);
NegotiationProtocol.WriteMessage(new NegotiationMessage(protocol.Name), memoryStream);
Application.Out.TryWrite(memoryStream.ToArray());
}
}
@ -122,8 +124,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public async Task<string> SendInvocationAsync(string methodName, bool nonBlocking, params object[] args)
{
var invocationId = GetInvocationId();
var payload = _protocol.WriteToArray(new InvocationMessage(invocationId, nonBlocking, methodName, args));
var payload = _protocolReaderWriter.WriteMessage(new InvocationMessage(invocationId, nonBlocking, methodName, args));
await Application.Out.WriteAsync(payload);
return invocationId;
@ -152,7 +154,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public HubMessage TryRead()
{
if (Application.In.TryRead(out var buffer) &&
_protocol.TryParseMessages(buffer, this, out var messages))
_protocolReaderWriter.ReadMessages(buffer, this, out var messages))
{
return messages[0];
}

View File

@ -84,7 +84,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var sendTcs = new TaskCompletionSource<object>();
connectionToTransport.Out.TryWrite(new SendMessage(new byte[] { 0x42 }, sendTcs));
await sendTcs.Task;
// The echo endpoint close the connection immediately after sending response which should stop the transport
// The echo endpoint closes the connection immediately after sending response which should stop the transport
await webSocketsTransport.Running.OrTimeout();
Assert.True(transportToConnection.In.TryRead(out var buffer));