Add support for streaming arguments to clients using Channel (#2441)

This commit is contained in:
Dylan Dmitri Gray 2018-08-01 09:31:43 -07:00 committed by Andrew Stanton-Nurse
parent e6b5c1c0cc
commit 76706144a5
41 changed files with 1797 additions and 222 deletions

View File

@ -2,6 +2,6 @@
<configuration>
<packageSources>
<clear />
<!-- Restore sources should be defined in build/sources.props. -->
<!-- Restore sources should be defined in build/sources.props -->
</packageSources>
</configuration>

View File

@ -59,5 +59,10 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks
}
throw new InvalidOperationException("Unexpected binder call");
}
public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
}

File diff suppressed because it is too large Load Diff

View File

@ -26,6 +26,8 @@ namespace ClientSample
RawSample.Register(app);
HubSample.Register(app);
StreamingSample.Register(app);
UploadSample.Register(app);
app.Command("help", cmd =>
{

View File

@ -0,0 +1,46 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.Extensions.CommandLineUtils;
namespace ClientSample
{
internal class StreamingSample
{
internal static void Register(CommandLineApplication app)
{
app.Command("streaming", cmd =>
{
cmd.Description = "Tests a streaming connection to a hub";
var baseUrlArgument = cmd.Argument("<BASEURL>", "The URL to the Chat Hub to test");
cmd.OnExecute(() => ExecuteAsync(baseUrlArgument.Value));
});
}
public static async Task<int> ExecuteAsync(string baseUrl)
{
var connection = new HubConnectionBuilder()
.WithUrl(baseUrl)
.Build();
await connection.StartAsync();
var reader = await connection.StreamAsChannelAsync<int>("ChannelCounter", 10, 2000);
while (await reader.WaitToReadAsync())
{
while (reader.TryRead(out var item))
{
Console.WriteLine($"received: {item}");
}
}
return 0;
}
}
}

View File

@ -0,0 +1,90 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.Collections.Generic;
using System.Diagnostics;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Client;
using Microsoft.Extensions.CommandLineUtils;
namespace ClientSample
{
internal class UploadSample
{
internal static void Register(CommandLineApplication app)
{
app.Command("uploading", cmd =>
{
cmd.Description = "Tests a streaming invocation from client to hub";
var baseUrlArgument = cmd.Argument("<BASEURL>", "The URL to the Chat Hub to test");
cmd.OnExecute(() => ExecuteAsync(baseUrlArgument.Value));
});
}
public static async Task<int> ExecuteAsync(string baseUrl)
{
var connection = new HubConnectionBuilder()
.WithUrl(baseUrl)
.Build();
await connection.StartAsync();
await BasicInvoke(connection);
//await MultiParamInvoke(connection);
//await AdditionalArgs(connection);
return 0;
}
public static async Task BasicInvoke(HubConnection connection)
{
var channel = Channel.CreateUnbounded<string>();
var invokeTask = connection.InvokeAsync<string>("UploadWord", channel.Reader);
foreach (var c in "hello")
{
await channel.Writer.WriteAsync(c.ToString());
}
channel.Writer.TryComplete();
var result = await invokeTask;
Debug.WriteLine($"You message was: {result}");
}
private static async Task WriteStreamAsync<T>(IEnumerable<T> sequence, ChannelWriter<T> writer)
{
foreach (T element in sequence)
{
await writer.WriteAsync(element);
await Task.Delay(100);
}
writer.TryComplete();
}
public static async Task MultiParamInvoke(HubConnection connection)
{
var letters = Channel.CreateUnbounded<string>();
var numbers = Channel.CreateUnbounded<int>();
_ = WriteStreamAsync(new[] { "h", "i", "!" }, letters.Writer);
_ = WriteStreamAsync(new[] { 1, 2, 3, 4, 5 }, numbers.Writer);
var result = await connection.InvokeAsync<string>("DoubleStreamUpload", letters.Reader, numbers.Reader);
Debug.WriteLine(result);
}
public static async Task AdditionalArgs(HubConnection connection)
{
var channel = Channel.CreateUnbounded<char>();
_ = WriteStreamAsync<char>("main message".ToCharArray(), channel.Writer);
var result = await connection.InvokeAsync<string>("UploadWithSuffix", channel.Reader, " + wooh I'm a suffix");
Debug.WriteLine($"Your message was: {result}");
}
}
}

View File

@ -0,0 +1,110 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR;
namespace SignalRSamples.Hubs
{
public class UploadHub : Hub
{
public async Task<string> DoubleStreamUpload(ChannelReader<string> letters, ChannelReader<int> numbers)
{
var total = await Sum(numbers);
var word = await UploadWord(letters);
return string.Format("You sent over <{0}> <{1}s>", total, word);
}
public async Task<int> Sum(ChannelReader<int> source)
{
var total = 0;
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
total += item;
}
}
return total;
}
public async Task LocalSum(ChannelReader<int> source)
{
var total = 0;
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
total += item;
}
}
Debug.WriteLine(String.Format("Complete, your total is <{0}>.", total));
}
public async Task<string> UploadWord(ChannelReader<string> source)
{
var sb = new StringBuilder();
// receiving a StreamCompleteMessage should cause this WaitToRead to return false
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
Debug.WriteLine($"received: {item}");
Console.WriteLine($"received: {item}");
sb.Append(item);
}
}
// method returns, somewhere else returns a CompletionMessage with any errors
return sb.ToString();
}
public async Task<string> UploadWithSuffix(ChannelReader<string> source, string suffix)
{
var sb = new StringBuilder();
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
await Task.Delay(50);
Debug.WriteLine($"received: {item}");
sb.Append(item);
}
}
sb.Append(suffix);
return sb.ToString();
}
public async Task<string> UploadFile(ChannelReader<byte[]> source, string filepath)
{
var result = Enumerable.Empty<byte>();
int chunk = 1;
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
Debug.WriteLine($"received chunk #{chunk++}");
result = result.Concat(item); // atrocious
await Task.Delay(50);
}
}
File.WriteAllBytes(filepath, result.ToArray());
Debug.WriteLine("returning status code");
return $"file written to '{filepath}'";
}
}
}

View File

@ -60,6 +60,7 @@ namespace SignalRSamples
routes.MapHub<DynamicChat>("/dynamic");
routes.MapHub<Chat>("/default");
routes.MapHub<Streaming>("/streaming");
routes.MapHub<UploadHub>("/uploading");
routes.MapHub<HubTChat>("/hubT");
});

View File

@ -0,0 +1,39 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Channels;
namespace Microsoft.AspNetCore.SignalR
{
internal static class ReflectionHelper
{
public static bool IsStreamingType(Type type)
{
// IMPORTANT !!
// All valid types must be generic
// because HubConnectionContext gets the generic argument and uses it to determine the expected item type of the stream
// The long-term solution is making a (streaming type => expected item type) method.
if (!type.IsGenericType)
{
return false;
}
// walk up inheritance chain, until parent is either null or a ChannelReader<T>
// TODO #2594 - add Streams here, to make sending files easy
while (type != null)
{
if (type.GetGenericTypeDefinition() == typeof(ChannelReader<>))
{
return true;
}
type = type.BaseType;
}
return false;
}
}
}

View File

@ -186,6 +186,18 @@ namespace Microsoft.AspNetCore.SignalR.Client
private static readonly Action<ILogger, Exception> _unableToAcquireConnectionLockForPing =
LoggerMessage.Define(LogLevel.Trace, new EventId(62, "UnableToAcquireConnectionLockForPing"), "Skipping ping because a send is already in progress.");
private static readonly Action<ILogger, string, Exception> _startingStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(63, "StartingStream"), "Initiating stream '{StreamId}'.");
private static readonly Action<ILogger, string, Exception> _sendingStreamItem =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(64, "StreamItemSent"), "Sending item for stream '{StreamId}'.");
private static readonly Action<ILogger, string, Exception> _cancelingStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(65, "CancelingStream"), "Stream '{StreamId}' has been canceled by client.");
private static readonly Action<ILogger, string, Exception> _completingStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(66, "CompletingStream"), "Sending completion message for stream '{StreamId}'.");
public static void PreparingNonBlockingInvocation(ILogger logger, string target, int count)
{
_preparingNonBlockingInvocation(logger, target, count, null);
@ -496,6 +508,26 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
_unableToAcquireConnectionLockForPing(logger, null);
}
public static void StartingStream(ILogger logger, string streamId)
{
_startingStream(logger, streamId, null);
}
public static void SendingStreamItem(ILogger logger, string streamId)
{
_sendingStreamItem(logger, streamId, null);
}
public static void CancelingStream(ILogger logger, string streamId)
{
_cancelingStream(logger, streamId, null);
}
public static void CompletingStream(ILogger logger, string streamId)
{
_completingStream(logger, streamId, null);
}
}
}
}

View File

@ -7,6 +7,8 @@ using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Channels;
@ -37,6 +39,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
// This lock protects the connection state.
private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1);
private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems"));
// Persistent across all connections
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
@ -44,10 +48,13 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly IServiceProvider _serviceProvider;
private readonly IConnectionFactory _connectionFactory;
private readonly ConcurrentDictionary<string, InvocationHandlerList> _handlers = new ConcurrentDictionary<string, InvocationHandlerList>(StringComparer.Ordinal);
private long _nextActivationServerTimeout;
private long _nextActivationSendPing;
private bool _disposed;
private CancellationToken _uploadStreamToken;
private readonly ConnectionLogScope _logScope;
// Transient state to a connection
@ -419,6 +426,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
CheckDisposed();
CheckConnectionActive(nameof(StreamAsChannelCoreAsync));
// I just want an excuse to use 'irq' as a variable name...
var irq = InvocationRequest.Stream(cancellationToken, returnType, _connectionState.GetNextId(), _loggerFactory, this, out channel);
await InvokeStreamCore(methodName, irq, args, cancellationToken);
@ -435,9 +443,84 @@ namespace Microsoft.AspNetCore.SignalR.Client
return channel;
}
private Dictionary<string, object> PackageStreamingParams(object[] args)
{
// lazy initialized, to avoid allocation unecessary dictionaries
Dictionary<string, object> readers = null;
for (var i = 0; i < args.Length; i++)
{
if (ReflectionHelper.IsStreamingType(args[i].GetType()))
{
if (readers == null)
{
readers = new Dictionary<string, object>();
}
var id = _connectionState.GetNextStreamId();
readers[id] = args[i];
args[i] = new StreamPlaceholder(id);
Log.StartingStream(_logger, id);
}
}
return readers;
}
private void LaunchStreams(Dictionary<string, object> readers, CancellationToken cancellationToken)
{
if (readers == null)
{
// if there were no streaming parameters then readers is never initialized
return;
}
foreach (var kvp in readers)
{
var reader = kvp.Value;
// For each stream that needs to be sent, run a "send items" task in the background.
// This reads from the channel, attaches streamId, and sends to server.
// A single background thread here quickly gets messy.
_ = _sendStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetGenericArguments())
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
}
}
// this is called via reflection using the `_sendStreamItems` field
private async Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
{
Log.StartingStream(_logger, streamId);
var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token).Token;
string responseError = null;
try
{
while (await reader.WaitToReadAsync(combinedToken))
{
while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item))
{
await SendWithLock(new StreamDataMessage(streamId, item));
Log.SendingStreamItem(_logger, streamId);
}
}
}
catch (OperationCanceledException)
{
Log.CancelingStream(_logger, streamId);
responseError = $"Stream canceled by client.";
}
Log.CompletingStream(_logger, streamId);
await SendWithLock(new StreamCompleteMessage(streamId, responseError));
}
private async Task<object> InvokeCoreAsyncCore(string methodName, Type returnType, object[] args, CancellationToken cancellationToken)
{
var readers = PackageStreamingParams(args);
CheckDisposed();
await WaitConnectionLockAsync();
@ -455,21 +538,20 @@ namespace Microsoft.AspNetCore.SignalR.Client
ReleaseConnectionLock();
}
// Wait for this outside the lock, because it won't complete until the server responds.
LaunchStreams(readers, cancellationToken);
// Wait for this outside the lock, because it won't complete until the server responds
return await invocationTask;
}
private async Task InvokeCore(string methodName, InvocationRequest irq, object[] args, CancellationToken cancellationToken)
{
AssertConnectionValid();
Log.PreparingBlockingInvocation(_logger, irq.InvocationId, methodName, irq.ResultType.FullName, args.Length);
// Client invocations are always blocking
var invocationMessage = new InvocationMessage(irq.InvocationId, methodName, args);
Log.RegisteringInvocation(_logger, invocationMessage.InvocationId);
_connectionState.AddInvocation(irq);
// Trace the full invocation
@ -495,7 +577,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
var invocationMessage = new StreamInvocationMessage(irq.InvocationId, methodName, args);
// I just want an excuse to use 'irq' as a variable name...
Log.RegisteringInvocation(_logger, invocationMessage.InvocationId);
_connectionState.AddInvocation(irq);
@ -525,28 +606,33 @@ namespace Microsoft.AspNetCore.SignalR.Client
// REVIEW: If a token is passed in and is canceled during FlushAsync it seems to break .Complete()...
await _connectionState.Connection.Transport.Output.FlushAsync();
Log.MessageSent(_logger, hubMessage);
// We've sent a message, so don't ping for a while
ResetSendPing();
Log.MessageSent(_logger, hubMessage);
}
private async Task SendCoreAsyncCore(string methodName, object[] args, CancellationToken cancellationToken)
{
CheckDisposed();
var readers = PackageStreamingParams(args);
Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length);
var invocationMessage = new InvocationMessage(null, methodName, args);
await SendWithLock(invocationMessage, callerName: nameof(SendCoreAsync));
LaunchStreams(readers, cancellationToken);
}
private async Task SendWithLock(HubMessage message, CancellationToken cancellationToken = default, [CallerMemberName] string callerName = "")
{
CheckDisposed();
await WaitConnectionLockAsync();
try
{
CheckConnectionActive(callerName);
CheckDisposed();
CheckConnectionActive(nameof(SendCoreAsync));
Log.PreparingNonBlockingInvocation(_logger, methodName, args.Length);
var invocationMessage = new InvocationMessage(null, methodName, args);
await SendHubMessage(invocationMessage, cancellationToken);
await SendHubMessage(message, cancellationToken);
}
finally
{
@ -575,15 +661,15 @@ namespace Microsoft.AspNetCore.SignalR.Client
if (!connectionState.TryRemoveInvocation(completion.InvocationId, out irq))
{
Log.DroppedCompletionMessage(_logger, completion.InvocationId);
break;
}
else
{
DispatchInvocationCompletion(completion, irq);
irq.Dispose();
}
DispatchInvocationCompletion(completion, irq);
irq.Dispose();
break;
case StreamItemMessage streamItem:
// Complete the invocation with an error, we don't support streaming (yet)
// if there's no open StreamInvocation with the given id, then complete with an error
if (!connectionState.TryGetInvocation(streamItem.InvocationId, out irq))
{
Log.DroppedStreamMessage(_logger, streamItem.InvocationId);
@ -767,6 +853,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
var timer = new TimerAwaitable(TickRate, TickRate);
_ = TimerLoop(timer);
var uploadStreamSource = new CancellationTokenSource();
_uploadStreamToken = uploadStreamSource.Token;
try
{
while (true)
@ -834,6 +923,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
finally
{
timer.Stop();
uploadStreamSource.Cancel();
}
// Clear the connectionState field
@ -916,11 +1006,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
private void OnServerTimeout()
{
if (Debugger.IsAttached)
{
return;
}
_connectionState.CloseException = new TimeoutException(
$"Server timeout ({ServerTimeout.TotalMilliseconds:0.00}ms) elapsed without receiving a message from the server.");
_connectionState.Connection.Transport.Input.CancelPendingRead();
@ -1104,7 +1189,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
private TaskCompletionSource<object> _stopTcs;
private readonly object _lock = new object();
private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>(StringComparer.Ordinal);
private int _nextId;
private int _nextInvocationId;
private int _nextStreamId;
public ConnectionContext Connection { get; }
public Task ReceiveTask { get; set; }
@ -1125,7 +1211,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
Connection = connection;
}
public string GetNextId() => Interlocked.Increment(ref _nextId).ToString(CultureInfo.InvariantCulture);
public string GetNextId() => Interlocked.Increment(ref _nextInvocationId).ToString(CultureInfo.InvariantCulture);
public string GetNextStreamId() => Interlocked.Increment(ref _nextStreamId).ToString(CultureInfo.InvariantCulture);
public void AddInvocation(InvocationRequest irq)
{
@ -1232,6 +1319,18 @@ namespace Microsoft.AspNetCore.SignalR.Client
return irq.ResultType;
}
Type IInvocationBinder.GetStreamItemType(string invocationId)
{
// previously, streaming was only server->client, and used GetReturnType for StreamItems
// literally the same code as the above method
if (!TryGetInvocation(invocationId, out var irq))
{
Log.ReceivedUnexpectedResponse(_hubConnection._logger, invocationId);
return null;
}
return irq.ResultType;
}
IReadOnlyList<Type> IInvocationBinder.GetParameterTypes(string methodName)
{
if (!_hubConnection._handlers.TryGetValue(methodName, out var invocationHandlerList))

View File

@ -10,6 +10,7 @@
<Compile Include="..\Common\AwaitableThreadPool.cs" Link="AwaitableThreadPool.cs" />
<Compile Include="..\Common\ForceAsyncAwaiter.cs" Link="ForceAsyncAwaiter.cs" />
<Compile Include="..\Common\PipeWriterStream.cs" Link="PipeWriterStream.cs" />
<Compile Include="..\Common\ReflectionHelper.cs" Link="ReflectionHelper.cs" />
<Compile Include="..\Common\TimerAwaitable.cs" Link="Internal\TimerAwaitable.cs" />
</ItemGroup>

View File

@ -10,5 +10,6 @@ namespace Microsoft.AspNetCore.SignalR
{
Type GetReturnType(string invocationId);
IReadOnlyList<Type> GetParameterTypes(string methodName);
Type GetStreamItemType(string streamId);
}
}

View File

@ -42,5 +42,15 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
/// Represents the close message type.
/// </summary>
public const int CloseMessageType = 7;
/// <summary>
/// Represents the stream complete message type.
/// </summary>
public const int StreamCompleteMessageType = 8;
/// <summary>
/// Same as StreamItemMessage, except
/// </summary>
public const int StreamDataMessageType = 9;
}
}

View File

@ -0,0 +1,37 @@
using System;
using System.Collections.Generic;
using System.Runtime.ExceptionServices;
using System.Text;
namespace Microsoft.AspNetCore.SignalR.Protocol
{
/// <summary>
/// Represents a failure to bind arguments for a StreamDataMessage. This does not represent an actual
/// message that is sent on the wire, it is returned by <see cref="IHubProtocol.TryParseMessage"/>
/// to indicate that a binding failure occurred when parsing a StreamDataMessage. The stream ID is associated
/// so that the error can be sent to the relevant hub method.
/// </summary>
public class StreamBindingFailureMessage : HubMessage
{
/// <summary>
/// Gets the id of the relevant stream
/// </summary>
public string Id { get; }
/// <summary>
/// Gets the exception thrown during binding.
/// </summary>
public ExceptionDispatchInfo BindingFailure { get; }
/// <summary>
/// Initializes a new instance of the <see cref="InvocationBindingFailureMessage"/> class.
/// </summary>
/// <param name="id">The stream ID.</param>
/// <param name="bindingFailure">The exception thrown during binding.</param>
public StreamBindingFailureMessage(string id, ExceptionDispatchInfo bindingFailure)
{
Id = id;
BindingFailure = bindingFailure;
}
}
}

View File

@ -0,0 +1,41 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.AspNetCore.SignalR.Protocol
{
/// <summary>
/// A message for indicating that a particular stream has ended.
/// </summary>
public class StreamCompleteMessage : HubMessage
{
/// <summary>
/// Gets the stream id.
/// </summary>
public string StreamId { get; }
/// <summary>
/// Gets the error. Will be null if there is no error.
/// </summary>
public string Error { get; }
/// <summary>
/// Whether the message has an error.
/// </summary>
public bool HasError { get => Error != null; }
/// <summary>
/// Initializes a new instance of <see cref="StreamCompleteMessage"/>
/// </summary>
/// <param name="streamId">The streamId of the stream to complete.</param>
/// <param name="error">An optional error field.</param>
public StreamCompleteMessage(string streamId, string error = null)
{
StreamId = streamId;
Error = error;
}
}
}

View File

@ -0,0 +1,33 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.SignalR.Protocol
{
/// <summary>
/// Sent to parameter streams.
/// Similar to <see cref="StreamItemMessage"/>, except the data is sent to a parameter stream, rather than in response to an invocation.
/// </summary>
public class StreamDataMessage : HubMessage
{
/// <summary>
/// The piece of data this message carries.
/// </summary>
public object Item { get; }
/// <summary>
/// The stream to which to deliver data.
/// </summary>
public string StreamId { get; }
public StreamDataMessage(string streamId, object item)
{
StreamId = streamId;
Item = item;
}
public override string ToString()
{
return $"StreamDataMessage {{ {nameof(StreamId)}: \"{StreamId}\", {nameof(Item)}: {Item ?? "<<null>>"} }}";
}
}
}

View File

@ -0,0 +1,25 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Text;
namespace Microsoft.AspNetCore.SignalR.Protocol
{
/// <summary>
/// Used by protocol serializers/deserializers to transfer information about streaming parameters.
/// Is packed as an argument in the form `{"streamId": "42"}`, and sent over wire.
/// Is then unpacked on the other side, and a new channel is created and saved under the streamId.
/// Then, each <see cref="StreamDataMessage"/> is routed to the appropiate channel based on streamId.
/// </summary>
public class StreamPlaceholder
{
public string StreamId { get; private set; }
public StreamPlaceholder(string streamId)
{
StreamId = streamId;
}
}
}

View File

@ -8,5 +8,10 @@
"TypeId": "public interface Microsoft.AspNetCore.SignalR.Protocol.IHubProtocol",
"MemberId": "System.Int32 get_MinorVersion()",
"Kind": "Addition"
},
{
"TypeId": "public interface Microsoft.AspNetCore.SignalR.IInvocationBinder",
"MemberId": "System.Type GetStreamItemType(System.String streamId)",
"Kind": "Addition"
}
]

View File

@ -21,6 +21,7 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubConnectionContext
{
private StreamTracker _streamTracker;
private static readonly WaitCallback _abortedCallback = AbortConnection;
private readonly ConnectionContext _connectionContext;
@ -54,6 +55,18 @@ namespace Microsoft.AspNetCore.SignalR
_clientTimeoutInterval = clientTimeoutInterval.Ticks;
}
internal StreamTracker StreamTracker
{
get
{
// lazy for performance reasons
if (_streamTracker == null)
{
_streamTracker = new StreamTracker();
}
return _streamTracker;
}
}
/// <summary>
/// Initializes a new instance of the <see cref="HubConnectionContext"/> class.
/// </summary>

View File

@ -186,6 +186,9 @@ namespace Microsoft.AspNetCore.SignalR
{
var input = connection.Input;
var protocol = connection.Protocol;
var binder = new HubConnectionBinder<THub>(_dispatcher, connection);
while (true)
{
var result = await input.ReadAsync();
@ -202,7 +205,7 @@ namespace Microsoft.AspNetCore.SignalR
{
connection.ResetClientTimeout();
while (protocol.TryParseMessage(ref buffer, _dispatcher, out var message))
while (protocol.TryParseMessage(ref buffer, binder, out var message))
{
await _dispatcher.DispatchMessageAsync(connection, message);
}

View File

@ -57,6 +57,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal
private static readonly Action<ILogger, string, Exception> _invalidReturnValueFromStreamingMethod =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(15, "InvalidReturnValueFromStreamingMethod"), "A streaming method returned a value that cannot be used to build enumerator {HubMethod}.");
private static readonly Action<ILogger, string, Exception> _receivedStreamItem =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(16, "ReceivedStreamItem"), "Received item for stream '{StreamId}'.");
private static readonly Action<ILogger, string, Exception> _startingParameterStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(17, "StartingParameterStream"), "Creating streaming parameter channel '{StreamId}'.");
private static readonly Action<ILogger, string, Exception> _completingStream =
LoggerMessage.Define<string>(LogLevel.Trace, new EventId(18, "CompletingStream"), "Stream '{StreamId}' has been completed by client.");
private static readonly Action<ILogger, string, string, Exception> _closingStreamWithBindingError =
LoggerMessage.Define<string, string>(LogLevel.Warning, new EventId(19, "ClosingStreamWithBindingError"), "Stream '{StreamId}' closed with error '{Error}'.");
public static void ReceivedHubInvocation(ILogger logger, InvocationMessage invocationMessage)
{
_receivedHubInvocation(logger, invocationMessage, null);
@ -133,6 +145,26 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
_invalidReturnValueFromStreamingMethod(logger, hubMethod, null);
}
public static void ReceivedStreamItem(ILogger logger, StreamDataMessage message)
{
_receivedStreamItem(logger, message.StreamId, null);
}
public static void StartingParameterStream(ILogger logger, string streamId)
{
_startingParameterStream(logger, streamId, null);
}
public static void CompletingStream(ILogger logger, StreamCompleteMessage message)
{
_completingStream(logger, message.StreamId, null);
}
public static void ClosingStreamWithBindingError(ILogger logger, StreamCompleteMessage message)
{
_closingStreamWithBindingError(logger, message.StreamId, message.Error, null);
}
}
}
}

View File

@ -81,15 +81,18 @@ namespace Microsoft.AspNetCore.SignalR.Internal
switch (hubMessage)
{
case InvocationBindingFailureMessage bindingFailureMessage:
return ProcessBindingFailure(connection, bindingFailureMessage);
return ProcessInvocationBindingFailure(connection, bindingFailureMessage);
case StreamBindingFailureMessage bindingFailureMessage:
return ProcessStreamBindingFailure(connection, bindingFailureMessage);
case InvocationMessage invocationMessage:
Log.ReceivedHubInvocation(_logger, invocationMessage);
return ProcessInvocation(connection, invocationMessage, isStreamedInvocation: false);
return ProcessInvocation(connection, invocationMessage, isStreamResponse: false);
case StreamInvocationMessage streamInvocationMessage:
Log.ReceivedStreamHubInvocation(_logger, streamInvocationMessage);
return ProcessInvocation(connection, streamInvocationMessage, isStreamedInvocation: true);
return ProcessInvocation(connection, streamInvocationMessage, isStreamResponse: true);
case CancelInvocationMessage cancelInvocationMessage:
// Check if there is an associated active stream and cancel it if it exists.
@ -110,6 +113,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
connection.StartClientTimeout();
break;
case StreamDataMessage streamItem:
Log.ReceivedStreamItem(_logger, streamItem);
return ProcessStreamItem(connection, streamItem);
case StreamCompleteMessage streamCompleteMessage:
// closes channels, removes from Lookup dict
// user's method can see the channel is complete and begin wrapping up
Log.CompletingStream(_logger, streamCompleteMessage);
connection.StreamTracker.Complete(streamCompleteMessage);
break;
// Other kind of message we weren't expecting
default:
Log.UnsupportedMessageReceived(_logger, hubMessage.GetType().FullName);
@ -119,30 +133,37 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return Task.CompletedTask;
}
private Task ProcessBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage)
private Task ProcessInvocationBindingFailure(HubConnectionContext connection, InvocationBindingFailureMessage bindingFailureMessage)
{
Log.FailedInvokingHubMethod(_logger, bindingFailureMessage.Target, bindingFailureMessage.BindingFailure.SourceException);
var errorMessage = ErrorMessageHelper.BuildErrorMessage($"Failed to invoke '{bindingFailureMessage.Target}' due to an error on the server.",
bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
return SendInvocationError(bindingFailureMessage.InvocationId, connection, errorMessage);
}
public override Type GetReturnType(string invocationId)
private Task ProcessStreamBindingFailure(HubConnectionContext connection, StreamBindingFailureMessage bindingFailureMessage)
{
return typeof(object);
var errorString = ErrorMessageHelper.BuildErrorMessage(
$"Failed to bind Stream Item arguments to proper type.",
bindingFailureMessage.BindingFailure.SourceException, _enableDetailedErrors);
var message = new StreamCompleteMessage(bindingFailureMessage.Id, errorString);
Log.ClosingStreamWithBindingError(_logger, message);
connection.StreamTracker.Complete(message);
return Task.CompletedTask;
}
public override IReadOnlyList<Type> GetParameterTypes(string methodName)
private Task ProcessStreamItem(HubConnectionContext connection, StreamDataMessage message)
{
if (!_methods.TryGetValue(methodName, out var descriptor))
{
return Type.EmptyTypes;
}
return descriptor.ParameterTypes;
Log.ReceivedStreamItem(_logger, message);
return connection.StreamTracker.ProcessItem(message);
}
private Task ProcessInvocation(HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamedInvocation)
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse)
{
if (!_methods.TryGetValue(hubMethodInvocationMessage.Target, out var descriptor))
{
@ -153,12 +174,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
else
{
return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamedInvocation);
bool isStreamCall = descriptor.HasStreamingParameters;
if (isStreamResponse && isStreamCall)
{
throw new NotSupportedException("Streaming responses for streaming uploads are not supported.");
}
return Invoke(descriptor, connection, hubMethodInvocationMessage, isStreamResponse, isStreamCall);
}
}
private async Task Invoke(HubMethodDescriptor descriptor, HubConnectionContext connection,
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamedInvocation)
HubMethodInvocationMessage hubMethodInvocationMessage, bool isStreamResponse, bool isStreamCall)
{
var methodExecutor = descriptor.MethodExecutor;
@ -176,7 +202,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return;
}
if (!await ValidateInvocationMode(descriptor, isStreamedInvocation, hubMethodInvocationMessage, connection))
if (!await ValidateInvocationMode(descriptor, isStreamResponse, hubMethodInvocationMessage, connection))
{
return;
}
@ -184,33 +210,73 @@ namespace Microsoft.AspNetCore.SignalR.Internal
hubActivator = scope.ServiceProvider.GetRequiredService<IHubActivator<THub>>();
hub = hubActivator.Create();
if (isStreamCall)
{
// swap out placeholders for channels
var args = hubMethodInvocationMessage.Arguments;
for (int i = 0; i < args.Length; i++)
{
var placeholder = args[i] as StreamPlaceholder;
if (placeholder == null)
{
continue;
}
Log.StartingParameterStream(_logger, placeholder.StreamId);
var itemType = methodExecutor.MethodParameters[i].ParameterType.GetGenericArguments()[0];
args[i] = connection.StreamTracker.AddStream(placeholder.StreamId, itemType);
}
}
try
{
InitializeHub(hub, connection);
Task invocation = null;
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
if (isStreamedInvocation)
if (isStreamResponse)
{
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
await SendInvocationError(hubMethodInvocationMessage.InvocationId, connection,
$"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>.");
return;
}
disposeScope = false;
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
// Fire-and-forget stream invocations, otherwise they would block other hub invocations from being able to run
_ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, streamCts);
}
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
{
Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
// Send Async, no response expected
invocation = ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
}
else
{
// Invoke Async, one reponse expected
async Task ExecuteInvocation()
{
var result = await ExecuteHubMethod(methodExecutor, hub, hubMethodInvocationMessage.Arguments);
Log.SendingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
await connection.WriteAsync(CompletionMessage.WithResult(hubMethodInvocationMessage.InvocationId, result));
}
invocation = ExecuteInvocation();
}
if (isStreamCall || isStreamResponse)
{
// don't await streaming invocations
// leave them running in the background, allowing dispatcher to process other messages between streaming items
disposeScope = false;
}
else
{
// complete the non-streaming calls now
await invocation;
}
}
catch (TargetInvocationException ex)
@ -236,7 +302,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope, IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts)
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, IServiceScope scope,
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts)
{
string error = null;
@ -424,5 +491,15 @@ namespace Microsoft.AspNetCore.SignalR.Internal
Log.HubMethodBound(_logger, hubName, methodName);
}
}
public override IReadOnlyList<Type> GetParameterTypes(string methodName)
{
if (!_methods.TryGetValue(methodName, out var descriptor))
{
return Type.EmptyTypes;
}
return descriptor.ParameterTypes;
}
}
}

View File

@ -0,0 +1,36 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.SignalR.Internal;
namespace Microsoft.AspNetCore.SignalR.Internal
{
internal class HubConnectionBinder<THub> : IInvocationBinder where THub : Hub
{
private HubDispatcher<THub> _dispatcher;
private HubConnectionContext _connection;
public HubConnectionBinder(HubDispatcher<THub> dispatcher, HubConnectionContext connection)
{
_dispatcher = dispatcher;
_connection = connection;
}
public IReadOnlyList<Type> GetParameterTypes(string methodName)
{
return _dispatcher.GetParameterTypes(methodName);
}
public Type GetReturnType(string invocationId)
{
return typeof(object);
}
public Type GetStreamItemType(string streamId)
{
return _connection.StreamTracker.GetStreamItemType(streamId);
}
}
}

View File

@ -8,12 +8,11 @@ using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.SignalR.Internal
{
public abstract class HubDispatcher<THub> : IInvocationBinder where THub : Hub
public abstract class HubDispatcher<THub> where THub : Hub
{
public abstract Task OnConnectedAsync(HubConnectionContext connection);
public abstract Task OnDisconnectedAsync(HubConnectionContext connection, Exception exception);
public abstract Task DispatchMessageAsync(HubConnectionContext connection, HubMessage hubMessage);
public abstract IReadOnlyList<Type> GetParameterTypes(string methodName);
public abstract Type GetReturnType(string invocationId);
public abstract IReadOnlyList<Type> GetParameterTypes(string name);
}
}

View File

@ -9,6 +9,7 @@ using System.Reflection;
using System.Threading;
using System.Threading.Channels;
using Microsoft.AspNetCore.Authorization;
using Microsoft.AspNetCore.SignalR.Protocol;
using Microsoft.Extensions.Internal;
namespace Microsoft.AspNetCore.SignalR.Internal
@ -22,7 +23,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public HubMethodDescriptor(ObjectMethodExecutor methodExecutor, IEnumerable<IAuthorizeData> policies)
{
MethodExecutor = methodExecutor;
ParameterTypes = methodExecutor.MethodParameters.Select(p => p.ParameterType).ToArray();
ParameterTypes = methodExecutor.MethodParameters.Select(GetParameterType).ToArray();
Policies = policies.ToArray();
NonAsyncReturnType = (MethodExecutor.IsMethodAsync)
@ -36,6 +37,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
public bool HasStreamingParameters { get; private set; }
private Func<object, CancellationToken, IAsyncEnumerator<object>> _convertToEnumerator;
public ObjectMethodExecutor MethodExecutor { get; }
@ -52,6 +55,17 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public IList<IAuthorizeData> Policies { get; }
private Type GetParameterType(ParameterInfo p)
{
var type = p.ParameterType;
if (ReflectionHelper.IsStreamingType(type))
{
HasStreamingParameters = true;
return typeof(StreamPlaceholder);
}
return type;
}
private static bool IsChannelType(Type type, out Type payloadType)
{
var channelType = type.AllBaseTypes().FirstOrDefault(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ChannelReader<>));

View File

@ -6,6 +6,10 @@
<RootNamespace>Microsoft.AspNetCore.SignalR</RootNamespace>
</PropertyGroup>
<ItemGroup>
<Compile Include="..\Common\ReflectionHelper.cs" Link="ReflectionHelper.cs" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Common\Microsoft.AspNetCore.SignalR.Common.csproj" />
<ProjectReference Include="..\Microsoft.AspNetCore.SignalR.Protocols.Json\Microsoft.AspNetCore.SignalR.Protocols.Json.csproj" />

View File

@ -0,0 +1,105 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.SignalR.Protocol;
namespace Microsoft.AspNetCore.SignalR
{
internal class StreamTracker
{
private static readonly MethodInfo _buildConverterMethod = typeof(StreamTracker).GetMethods(BindingFlags.NonPublic | BindingFlags.Static).Single(m => m.Name.Equals("BuildStream"));
private ConcurrentDictionary<string, IStreamConverter> _lookup = new ConcurrentDictionary<string, IStreamConverter>();
/// <summary>
/// Creates a new stream and returns the ChannelReader for it as an object.
/// </summary>
public object AddStream(string streamId, Type itemType)
{
var newConverter = (IStreamConverter)_buildConverterMethod.MakeGenericMethod(itemType).Invoke(null, Array.Empty<object>());
_lookup[streamId] = newConverter;
return newConverter.GetReaderAsObject();
}
private IStreamConverter TryGetConverter(string streamId)
{
if (_lookup.TryGetValue(streamId, out var converter))
{
return converter;
}
else
{
throw new KeyNotFoundException($"No stream with id '{streamId}' could be found.");
}
}
public Task ProcessItem(StreamDataMessage message)
{
return TryGetConverter(message.StreamId).WriteToStream(message.Item);
}
public Type GetStreamItemType(string streamId)
{
return TryGetConverter(streamId).GetItemType();
}
public void Complete(StreamCompleteMessage message)
{
_lookup.TryRemove(message.StreamId, out var converter);
if (converter == null)
{
throw new KeyNotFoundException($"No stream with id '{message.StreamId}' could be found.");
}
converter.TryComplete(message.HasError ? new Exception(message.Error) : null);
}
private static IStreamConverter BuildStream<T>()
{
return new ChannelConverter<T>();
}
private interface IStreamConverter
{
Type GetItemType();
object GetReaderAsObject();
Task WriteToStream(object item);
void TryComplete(Exception ex);
}
private class ChannelConverter<T> : IStreamConverter
{
private Channel<T> _channel;
public ChannelConverter()
{
_channel = Channel.CreateUnbounded<T>();
}
public Type GetItemType()
{
return typeof(T);
}
public object GetReaderAsObject()
{
return _channel.Reader;
}
public Task WriteToStream(object o)
{
return _channel.Writer.WriteAsync((T)o).AsTask();
}
public void TryComplete(Exception ex)
{
_channel.Writer.TryComplete(ex);
}
}
}
}

View File

@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<Description>Implements the SignalR Hub Protocol over JSON.</Description>

View File

@ -24,6 +24,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
private const string ResultPropertyName = "result";
private const string ItemPropertyName = "item";
private const string InvocationIdPropertyName = "invocationId";
private const string StreamIdPropertyName = "streamId";
private const string TypePropertyName = "type";
private const string ErrorPropertyName = "error";
private const string TargetPropertyName = "target";
@ -119,6 +120,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
int? type = null;
string invocationId = null;
string streamId = null;
string target = null;
string error = null;
var hasItem = false;
@ -165,6 +167,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
case InvocationIdPropertyName:
invocationId = JsonUtils.ReadAsString(reader, InvocationIdPropertyName);
break;
case StreamIdPropertyName:
streamId = JsonUtils.ReadAsString(reader, StreamIdPropertyName);
break;
case TargetPropertyName:
target = JsonUtils.ReadAsString(reader, TargetPropertyName);
break;
@ -199,15 +204,32 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
hasItem = true;
if (string.IsNullOrEmpty(invocationId))
string id = null;
if (!string.IsNullOrEmpty(invocationId))
{
// If we don't have an invocation id then we need to store it as a JToken so we can parse it later
itemToken = JToken.Load(reader);
id = invocationId;
}
else if (!string.IsNullOrEmpty(streamId))
{
id = streamId;
}
else
{
var returnType = binder.GetReturnType(invocationId);
item = PayloadSerializer.Deserialize(reader, returnType);
// If we don't have an id yetmthen we need to store it as a JToken to parse later
itemToken = JToken.Load(reader);
break;
}
Type itemType = binder.GetStreamItemType(id);
try
{
item = PayloadSerializer.Deserialize(reader, itemType);
}
catch (JsonSerializationException ex)
{
return new StreamBindingFailureMessage(id, ExceptionDispatchInfo.Capture(ex));
}
break;
case ArgumentsPropertyName:
@ -313,11 +335,33 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
: BindStreamInvocationMessage(invocationId, target, arguments, hasArguments, binder);
}
break;
case HubProtocolConstants.StreamDataMessageType:
if (itemToken != null)
{
var itemType = binder.GetStreamItemType(streamId);
try
{
item = itemToken.ToObject(itemType, PayloadSerializer);
}
catch (JsonSerializationException ex)
{
return new StreamBindingFailureMessage(streamId, ExceptionDispatchInfo.Capture(ex));
}
}
message = BindParamStreamMessage(streamId, item, hasItem, binder);
break;
case HubProtocolConstants.StreamItemMessageType:
if (itemToken != null)
{
var returnType = binder.GetReturnType(invocationId);
item = itemToken.ToObject(returnType, PayloadSerializer);
var returnType = binder.GetStreamItemType(invocationId);
try
{
item = itemToken.ToObject(returnType, PayloadSerializer);
}
catch (JsonSerializationException ex)
{
return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
};
}
message = BindStreamItemMessage(invocationId, item, hasItem, binder);
@ -338,6 +382,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return PingMessage.Instance;
case HubProtocolConstants.CloseMessageType:
return BindCloseMessage(error);
case HubProtocolConstants.StreamCompleteMessageType:
message = BindStreamCompleteMessage(streamId, error);
break;
case null:
throw new InvalidDataException($"Missing required property '{TypePropertyName}'.");
default:
@ -408,6 +455,10 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
WriteHeaders(writer, m);
WriteStreamInvocationMessage(m, writer);
break;
case StreamDataMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamDataMessageType);
WriteStreamDataMessage(m, writer);
break;
case StreamItemMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamItemMessageType);
WriteHeaders(writer, m);
@ -430,6 +481,10 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
WriteMessageType(writer, HubProtocolConstants.CloseMessageType);
WriteCloseMessage(m, writer);
break;
case StreamCompleteMessage m:
WriteMessageType(writer, HubProtocolConstants.StreamCompleteMessageType);
WriteStreamCompleteMessage(m, writer);
break;
default:
throw new InvalidOperationException($"Unsupported message type: {message.GetType().FullName}");
}
@ -478,6 +533,18 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
WriteInvocationId(message, writer);
}
private void WriteStreamCompleteMessage(StreamCompleteMessage message, JsonTextWriter writer)
{
writer.WritePropertyName(StreamIdPropertyName);
writer.WriteValue(message.StreamId);
if (message.Error != null)
{
writer.WritePropertyName(ErrorPropertyName);
writer.WriteValue(message.Error);
}
}
private void WriteStreamItemMessage(StreamItemMessage message, JsonTextWriter writer)
{
WriteInvocationId(message, writer);
@ -485,6 +552,14 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
PayloadSerializer.Serialize(writer, message.Item);
}
private void WriteStreamDataMessage(StreamDataMessage message, JsonTextWriter writer)
{
writer.WritePropertyName(StreamIdPropertyName);
writer.WriteValue(message.StreamId);
writer.WritePropertyName(ItemPropertyName);
PayloadSerializer.Serialize(writer, message.Item);
}
private void WriteInvocationMessage(InvocationMessage message, JsonTextWriter writer)
{
WriteInvocationId(message, writer);
@ -548,6 +623,17 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return new CancelInvocationMessage(invocationId);
}
private HubMessage BindStreamCompleteMessage(string streamId, string error)
{
if (string.IsNullOrEmpty(streamId))
{
throw new InvalidDataException($"Missing required property '{StreamIdPropertyName}'.");
}
// note : if the stream completes normally, the error should be `null`
return new StreamCompleteMessage(streamId, error);
}
private HubMessage BindCompletionMessage(string invocationId, string error, object result, bool hasResult, IInvocationBinder binder)
{
if (string.IsNullOrEmpty(invocationId))
@ -568,6 +654,20 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return new CompletionMessage(invocationId, error, result: null, hasResult: false);
}
private HubMessage BindParamStreamMessage(string streamId, object item, bool hasItem, IInvocationBinder binder)
{
if (string.IsNullOrEmpty(streamId))
{
throw new InvalidDataException($"Missing required property '{StreamIdPropertyName}");
}
if (!hasItem)
{
throw new InvalidDataException($"Missing required property '{ItemPropertyName}");
}
return new StreamDataMessage(streamId, item);
}
private HubMessage BindStreamItemMessage(string invocationId, object item, bool hasItem, IInvocationBinder binder)
{
if (string.IsNullOrEmpty(invocationId))
@ -658,7 +758,6 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
{
if (paramIndex < paramCount)
{
// Set all known arguments
arguments[paramIndex] = PayloadSerializer.Deserialize(reader, paramTypes[paramIndex]);
}
else

View File

@ -121,7 +121,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver)
{
_ = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize);
MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize);
startOffset += readSize;
var messageType = ReadInt32(input, ref startOffset, "messageType");
@ -142,6 +142,8 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return PingMessage.Instance;
case HubProtocolConstants.CloseMessageType:
return CreateCloseMessage(input, ref startOffset);
case HubProtocolConstants.StreamCompleteMessageType:
return CreateStreamCompleteMessage(input, ref startOffset);
default:
// Future protocol changes can add message types, old clients can ignore them
return null;
@ -179,6 +181,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var target = ReadString(input, ref offset, "target");
var parameterTypes = binder.GetParameterTypes(target);
try
@ -196,7 +199,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var itemType = binder.GetReturnType(invocationId);
var itemType = binder.GetStreamItemType(invocationId);
var value = DeserializeObject(input, ref offset, itemType, "item", resolver);
return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
}
@ -244,6 +247,17 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return new CloseMessage(error);
}
private static StreamCompleteMessage CreateStreamCompleteMessage(byte[] input, ref int offset)
{
var streamId = ReadString(input, ref offset, "streamId");
var error = ReadString(input, ref offset, "error");
if (string.IsNullOrEmpty(error))
{
error = null;
}
return new StreamCompleteMessage(streamId, error);
}
private static Dictionary<string, string> ReadHeaders(byte[] input, ref int offset)
{
var headerCount = ReadMapLength(input, ref offset, "headers");
@ -376,6 +390,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
case CloseMessage closeMessage:
WriteCloseMessage(closeMessage, packer);
break;
case StreamCompleteMessage m:
WriteStreamCompleteMessage(m, packer);
break;
default:
throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}");
}
@ -469,6 +486,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
MessagePackBinary.WriteString(packer, message.InvocationId);
}
private void WriteStreamCompleteMessage(StreamCompleteMessage message, Stream packer)
{
MessagePackBinary.WriteArrayHeader(packer, 3);
MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamCompleteMessageType);
MessagePackBinary.WriteString(packer, message.StreamId);
if (message.HasError)
{
MessagePackBinary.WriteString(packer, message.Error);
}
else
{
MessagePackBinary.WriteNil(packer);
}
}
private void WriteCloseMessage(CloseMessage message, Stream packer)
{
MessagePackBinary.WriteArrayHeader(packer, 2);

View File

@ -3,6 +3,9 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.SignalR.Protocol;
@ -145,6 +148,227 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
[Fact]
public async Task StreamIntsToServer()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace))
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: loggerFactory);
await hubConnection.StartAsync().OrTimeout();
var channel = Channel.CreateUnbounded<int>();
var invokeTask = hubConnection.InvokeAsync<int>("SomeMethod", channel.Reader);
var invocation = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]);
Assert.Equal("SomeMethod", invocation["target"]);
var streamId = invocation["arguments"][0]["streamId"];
foreach (var number in new[] { 42, 43, 322, 3145, -1234 })
{
await channel.Writer.WriteAsync(number).AsTask().OrTimeout();
var item = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.StreamDataMessageType, item["type"]);
Assert.Equal(number, item["item"]);
Assert.Equal(streamId, item["streamId"]);
}
channel.Writer.TryComplete();
var completion = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.StreamCompleteMessageType, completion["type"]);
await connection.ReceiveJsonMessage(
new { type = HubProtocolConstants.CompletionMessageType, invocationId = invocation["invocationId"], result = 42 }
).OrTimeout();
var result = await invokeTask.OrTimeout();
Assert.Equal(42, result);
}
}
[Fact]
public async Task StreamIntsToServerViaSend()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace))
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: loggerFactory);
await hubConnection.StartAsync().OrTimeout();
var channel = Channel.CreateUnbounded<int>();
var sendTask = hubConnection.SendAsync("SomeMethod", channel.Reader);
var invocation = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]);
Assert.Equal("SomeMethod", invocation["target"]);
Assert.Null(invocation["invocationId"]);
var streamId = invocation["arguments"][0]["streamId"];
foreach (var item in new[] { 2, 3, 10, 5 })
{
await channel.Writer.WriteAsync(item);
var received = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.StreamDataMessageType, received["type"]);
Assert.Equal(item, received["item"]);
Assert.Equal(streamId, received["streamId"]);
}
}
}
[Fact]
public async Task StreamsObjectsToServer()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace))
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: loggerFactory);
await hubConnection.StartAsync().OrTimeout();
var channel = Channel.CreateUnbounded<object>();
var invokeTask = hubConnection.InvokeAsync<SampleObject>("UploadMethod", channel.Reader);
var invocation = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]);
Assert.Equal("UploadMethod", invocation["target"]);
var id = invocation["invocationId"];
var items = new[] { new SampleObject("ab", 12), new SampleObject("ef", 23) };
foreach (var item in items)
{
await channel.Writer.WriteAsync(item);
var received = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.StreamDataMessageType, received["type"]);
Assert.Equal(item.Foo, received["item"]["foo"]);
Assert.Equal(item.Bar, received["item"]["bar"]);
}
channel.Writer.TryComplete();
var completion = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.StreamCompleteMessageType, completion["type"]);
var expected = new SampleObject("oof", 14);
await connection.ReceiveJsonMessage(
new { type = HubProtocolConstants.CompletionMessageType, invocationId = id, result = expected }
).OrTimeout();
var result = await invokeTask.OrTimeout();
Assert.Equal(expected.Foo, result.Foo);
Assert.Equal(expected.Bar, result.Bar);
}
}
[Fact]
public async Task UploadStreamCancelationSendsStreamComplete()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace))
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: loggerFactory);
await hubConnection.StartAsync().OrTimeout();
var cts = new CancellationTokenSource();
var channel = Channel.CreateUnbounded<int>();
var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader, cts.Token);
var invokeMessage = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invokeMessage["type"]);
cts.Cancel();
// after cancellation, don't send from the pipe
foreach (var number in new[] { 42, 43, 322, 3145, -1234 })
{
await channel.Writer.WriteAsync(number);
}
// the next sent message should be a completion message
var complete = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.StreamCompleteMessageType, complete["type"]);
Assert.EndsWith("canceled by client.", ((string)complete["error"]));
}
}
[Fact]
public async Task InvocationCanCompleteBeforeStreamCompletes()
{
using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace))
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: loggerFactory);
await hubConnection.StartAsync().OrTimeout();
var channel = Channel.CreateUnbounded<int>();
var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader);
var invocation = await connection.ReadSentJsonAsync().OrTimeout();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]);
var id = invocation["invocationId"];
await connection.ReceiveJsonMessage(new { type = HubProtocolConstants.CompletionMessageType, invocationId = id, result = 10 });
var result = await invokeTask.OrTimeout();
Assert.Equal(10L, result);
// after the server returns, with whatever response
// the client's behavior is undefined, and the server is responsible for ignoring stray messages
}
}
[Fact]
public async Task WrongTypeOnServerResponse()
{
bool ExpectedErrors(WriteContext writeContext)
{
return writeContext.LoggerName == typeof(HubConnection).FullName &&
(writeContext.EventId.Name == "ServerDisconnectedWithError"
|| writeContext.EventId.Name == "ShutdownWithError");
}
using (StartVerifiableLog(out var loggerFactory, LogLevel.Trace, expectedErrorsFilter: ExpectedErrors))
{
var connection = new TestConnection();
var hubConnection = CreateHubConnection(connection, loggerFactory: loggerFactory);
await hubConnection.StartAsync().OrTimeout();
// we expect to get sent ints, and receive an int back
var channel = Channel.CreateUnbounded<int>();
var invokeTask = hubConnection.InvokeAsync<int>("SumInts", channel.Reader);
var invocation = await connection.ReadSentJsonAsync();
Assert.Equal(HubProtocolConstants.InvocationMessageType, invocation["type"]);
var id = invocation["invocationId"];
await channel.Writer.WriteAsync(5);
await channel.Writer.WriteAsync(10);
await connection.ReceiveJsonMessage(new { type = HubProtocolConstants.CompletionMessageType, invocationId = id, result = "humbug" });
try
{
await invokeTask;
}
catch (Exception ex)
{
Assert.Equal(typeof(Newtonsoft.Json.JsonSerializationException), ex.GetType());
}
}
}
private class SampleObject
{
public SampleObject(string foo, int bar)
{
Foo = foo;
Bar = bar;
}
public string Foo { get; private set; }
public int Bar { get; private set; }
}
// Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse
private class MockHubProtocol : IHubProtocol
{

View File

@ -163,6 +163,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
}
}
public async Task<JObject> ReadSentJsonAsync()
{
return JObject.Parse(await ReadSentTextMessageAsync());
}
public async Task<IList<string>> ReadAllSentMessagesAsync(bool ignorePings = true)
{
if (!Disposed.IsCompleted)

View File

@ -37,5 +37,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
arg is StreamItemMessage ||
arg is StreamInvocationMessage;
}
public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
}

View File

@ -40,6 +40,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new JsonProtocolTestData("InvocationMessage_HasCustomArgumentWithNullValueIgnore", new InvocationMessage(null, "Target", new object[] { new CustomObject() }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"byteArrProp\":\"AQID\"}]}"),
new JsonProtocolTestData("InvocationMessage_HasCustomArgumentWithNullValueIgnoreAndNoCamelCase", new InvocationMessage(null, "Target", new object[] { new CustomObject() }), false, NullValueHandling.Include, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"StringProp\":\"SignalR!\",\"DoubleProp\":6.2831853071,\"IntProp\":42,\"DateTimeProp\":\"2017-04-11T00:00:00Z\",\"NullProp\":null,\"ByteArrProp\":\"AQID\"}]}"),
new JsonProtocolTestData("InvocationMessage_HasCustomArgumentWithNullValueInclude", new InvocationMessage(null, "Target", new object[] { new CustomObject() }), true, NullValueHandling.Include, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"nullProp\":null,\"byteArrProp\":\"AQID\"}]}"),
new JsonProtocolTestData("InvocationMessage_HasStreamPlaceholder", new InvocationMessage(null, "Target", new object[] { new StreamPlaceholder("__test_id__")}), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Target\",\"arguments\":[{\"streamId\":\"__test_id__\"}]}"),
new JsonProtocolTestData("InvocationMessage_HasHeaders", AddHeaders(TestHeaders, new InvocationMessage("123", "Target", new object[] { 1, "Foo", 2.0f })), true, NullValueHandling.Ignore, "{\"type\":1," + SerializedHeaders + ",\"invocationId\":\"123\",\"target\":\"Target\",\"arguments\":[1,\"Foo\",2.0]}"),
new JsonProtocolTestData("InvocationMessage_StringIsoDateArgument", new InvocationMessage("Method", new object[] { "2016-05-10T13:51:20+12:34" }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Method\",\"arguments\":[\"2016-05-10T13:51:20+12:34\"]}"),
new JsonProtocolTestData("InvocationMessage_DateTimeOffsetArgument", new InvocationMessage("Method", new object[] { DateTimeOffset.Parse("2016-05-10T13:51:20+12:34") }), true, NullValueHandling.Ignore, "{\"type\":1,\"target\":\"Method\",\"arguments\":[\"2016-05-10T13:51:20+12:34\"]}"),
@ -55,7 +56,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new JsonProtocolTestData("StreamItemMessage_HasCustomItemWithNullValueInclude", new StreamItemMessage("123", new CustomObject()), true, NullValueHandling.Include, "{\"type\":2,\"invocationId\":\"123\",\"item\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"nullProp\":null,\"byteArrProp\":\"AQID\"}}"),
new JsonProtocolTestData("StreamItemMessage_HasHeaders", AddHeaders(TestHeaders, new StreamItemMessage("123", new CustomObject())), true, NullValueHandling.Include, "{\"type\":2," + SerializedHeaders + ",\"invocationId\":\"123\",\"item\":{\"stringProp\":\"SignalR!\",\"doubleProp\":6.2831853071,\"intProp\":42,\"dateTimeProp\":\"2017-04-11T00:00:00Z\",\"nullProp\":null,\"byteArrProp\":\"AQID\"}}"),
new JsonProtocolTestData("CompletionMessage_HasIntergerResult", CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"type\":3,\"invocationId\":\"123\",\"result\":1}"),
new JsonProtocolTestData("CompletionMessage_HasIntegerResult", CompletionMessage.WithResult("123", 1), true, NullValueHandling.Ignore, "{\"type\":3,\"invocationId\":\"123\",\"result\":1}"),
new JsonProtocolTestData("CompletionMessage_HasStringResult", CompletionMessage.WithResult("123", "Foo"), true, NullValueHandling.Ignore, "{\"type\":3,\"invocationId\":\"123\",\"result\":\"Foo\"}"),
new JsonProtocolTestData("CompletionMessage_HasFloatResult", CompletionMessage.WithResult("123", 2.0f), true, NullValueHandling.Ignore, "{\"type\":3,\"invocationId\":\"123\",\"result\":2.0}"),
new JsonProtocolTestData("CompletionMessage_HasBoolResult", CompletionMessage.WithResult("123", true), true, NullValueHandling.Ignore, "{\"type\":3,\"invocationId\":\"123\",\"result\":true}"),
@ -88,7 +89,11 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new JsonProtocolTestData("CloseMessage", CloseMessage.Empty, false, NullValueHandling.Ignore, "{\"type\":7}"),
new JsonProtocolTestData("CloseMessage_HasError", new CloseMessage("Error!"), false, NullValueHandling.Ignore, "{\"type\":7,\"error\":\"Error!\"}"),
new JsonProtocolTestData("CloseMessage_HasErrorWithCamelCase", new CloseMessage("Error!"), true, NullValueHandling.Ignore, "{\"type\":7,\"error\":\"Error!\"}"),
new JsonProtocolTestData("CloseMessage_HasErrorEmptyString", new CloseMessage(""), false, NullValueHandling.Ignore, "{\"type\":7,\"error\":\"\"}")
new JsonProtocolTestData("CloseMessage_HasErrorEmptyString", new CloseMessage(""), false, NullValueHandling.Ignore, "{\"type\":7,\"error\":\"\"}"),
new JsonProtocolTestData("StreamCompleteMessage", new StreamCompleteMessage("123"), true, NullValueHandling.Ignore, "{\"type\":8,\"streamId\":\"123\"}"),
new JsonProtocolTestData("StreamCompleteMessageWithError", new StreamCompleteMessage("123", "zoinks"), true, NullValueHandling.Ignore, "{\"type\":8,\"streamId\":\"123\",\"error\":\"zoinks\"}"),
}.ToDictionary(t => t.Name);
public static IEnumerable<object[]> ProtocolTestDataNames => ProtocolTestData.Keys.Select(name => new object[] { name });
@ -101,6 +106,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
new JsonProtocolTestData("StreamInvocationMessage_IntegerArrayArgumentFirst", new StreamInvocationMessage("3", "Method", new object[] { 1, 2 }), false, NullValueHandling.Ignore, "{ \"type\":4, \"arguments\": [1,2], \"target\": \"Method\", \"invocationId\": \"3\" }"),
new JsonProtocolTestData("CompletionMessage_ResultFirst", new CompletionMessage("15", null, 10, hasResult: true), false, NullValueHandling.Ignore, "{ \"type\":3, \"result\": 10, \"invocationId\": \"15\" }"),
new JsonProtocolTestData("StreamItemMessage_ItemFirst", new StreamItemMessage("1a", "foo"), false, NullValueHandling.Ignore, "{ \"item\": \"foo\", \"invocationId\": \"1a\", \"type\":2 }")
}.ToDictionary(t => t.Name);
public static IEnumerable<object[]> OutOfOrderJsonTestDataNames => OutOfOrderJsonTestData.Keys.Select(name => new object[] { name });

View File

@ -91,6 +91,11 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
name: "InvocationWithHeadersNoIdAndArrayOfCustomObjectArgs",
message: AddHeaders(TestHeaders, new InvocationMessage("method", new object[] { new CustomObject(), new CustomObject() })),
binary: "lQGDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmXApm1ldGhvZJKGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgOGqlN0cmluZ1Byb3CoU2lnbmFsUiGqRG91YmxlUHJvcMtAGSH7VELPEqdJbnRQcm9wKqxEYXRlVGltZVByb3DW/1jsHICoTnVsbFByb3DAq0J5dGVBcnJQcm9wxAMBAgM="),
new ProtocolTestData(
name: "InvocationWithStreamPlaceholderObject",
message: new InvocationMessage(null, "Target", new object[] { new StreamPlaceholder("__test_id__")}),
binary: "lQGAwKZUYXJnZXSRgahTdHJlYW1JZKtfX3Rlc3RfaWRfXw=="
),
// StreamItem Messages
new ProtocolTestData(
@ -228,6 +233,16 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
message: AddHeaders(TestHeaders, new CancelInvocationMessage("xyz")),
binary: "kwWDo0Zvb6NCYXKyS2V5V2l0aApOZXcNCkxpbmVzq1N0aWxsIFdvcmtzsVZhbHVlV2l0aE5ld0xpbmVzsEFsc28KV29ya3MNCkZpbmWjeHl6"),
// StreamComplete Messages
new ProtocolTestData(
name: "StreamComplete",
message: new StreamCompleteMessage("xyz"),
binary: "kwijeHl6wA=="),
new ProtocolTestData(
name: "StreamCompleteWithError",
message: new StreamCompleteMessage("xyz", "zoinks"),
binary: "kwijeHl6pnpvaW5rcw=="),
// Ping Messages
new ProtocolTestData(
name: "Ping",
@ -259,7 +274,14 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
var expectedMessage = new InvocationMessage("xyz", "method", Array.Empty<object>());
// Verify that the input binary string decodes to the expected MsgPack primitives
var bytes = new byte[] { ArrayBytes(6), 1, 0x80, StringBytes(3), (byte)'x', (byte)'y', (byte)'z', StringBytes(6), (byte)'m', (byte)'e', (byte)'t', (byte)'h', (byte)'o', (byte)'d', ArrayBytes(0), StringBytes(2), (byte)'e', (byte)'x' };
var bytes = new byte[] { ArrayBytes(8),
1,
0x80,
StringBytes(3), (byte)'x', (byte)'y', (byte)'z',
StringBytes(6), (byte)'m', (byte)'e', (byte)'t', (byte)'h', (byte)'o', (byte)'d',
ArrayBytes(0),
0xc3,
StringBytes(2), (byte)'e', (byte)'x' };
// Parse the input fully now.
bytes = Frame(bytes);

View File

@ -58,5 +58,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
}
throw new InvalidOperationException("Unexpected binder call");
}
public Type GetStreamItemType(string streamId)
{
// In v1, stream items were only sent from server -> client
// and so they had items typed based on what the hub method returned.
// We just forward here for backwards compatibility.
return GetReturnType(streamId);
}
}
}

View File

@ -34,11 +34,13 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
return StreamInvocationMessagesEqual(streamInvocationMessage, (StreamInvocationMessage)y);
case CancelInvocationMessage cancelItemMessage:
return string.Equals(cancelItemMessage.InvocationId, ((CancelInvocationMessage)y).InvocationId, StringComparison.Ordinal);
case PingMessage pingMessage:
case PingMessage _:
// If the types are equal (above), then we're done.
return true;
case CloseMessage closeMessage:
return string.Equals(closeMessage.Error, ((CloseMessage) y).Error);
case StreamCompleteMessage streamCompleteMessage:
return StreamCompleteMessagesEqual(streamCompleteMessage, (StreamCompleteMessage)y);
default:
throw new InvalidOperationException($"Unknown message type: {x.GetType().FullName}");
}
@ -46,34 +48,40 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
private bool CompletionMessagesEqual(CompletionMessage x, CompletionMessage y)
{
return SequenceEqual(x.Headers, y.Headers) &&
string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal) &&
string.Equals(x.Error, y.Error, StringComparison.Ordinal) &&
x.HasResult == y.HasResult &&
(Equals(x.Result, y.Result) || SequenceEqual(x.Result, y.Result));
return SequenceEqual(x.Headers, y.Headers)
&& string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal)
&& string.Equals(x.Error, y.Error, StringComparison.Ordinal)
&& x.HasResult == y.HasResult
&& (Equals(x.Result, y.Result) || SequenceEqual(x.Result, y.Result));
}
private bool StreamItemMessagesEqual(StreamItemMessage x, StreamItemMessage y)
{
return SequenceEqual(x.Headers, y.Headers) &&
string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal) &&
(Equals(x.Item, y.Item) || SequenceEqual(x.Item, y.Item));
return SequenceEqual(x.Headers, y.Headers)
&& string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal)
&& (Equals(x.Item, y.Item) || SequenceEqual(x.Item, y.Item));
}
private bool InvocationMessagesEqual(InvocationMessage x, InvocationMessage y)
{
return SequenceEqual(x.Headers, y.Headers) &&
string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal) &&
string.Equals(x.Target, y.Target, StringComparison.Ordinal) &&
ArgumentListsEqual(x.Arguments, y.Arguments);
return SequenceEqual(x.Headers, y.Headers)
&& string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal)
&& string.Equals(x.Target, y.Target, StringComparison.Ordinal)
&& ArgumentListsEqual(x.Arguments, y.Arguments);
}
private bool StreamInvocationMessagesEqual(StreamInvocationMessage x, StreamInvocationMessage y)
{
return SequenceEqual(x.Headers, y.Headers) &&
string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal) &&
string.Equals(x.Target, y.Target, StringComparison.Ordinal) &&
ArgumentListsEqual(x.Arguments, y.Arguments);
return SequenceEqual(x.Headers, y.Headers)
&& string.Equals(x.InvocationId, y.InvocationId, StringComparison.Ordinal)
&& string.Equals(x.Target, y.Target, StringComparison.Ordinal)
&& ArgumentListsEqual(x.Arguments, y.Arguments);
}
private bool StreamCompleteMessagesEqual(StreamCompleteMessage x, StreamCompleteMessage y)
{
return x.StreamId == y.StreamId
&& y.Error == y.Error;
}
private bool ArgumentListsEqual(object[] left, object[] right)
@ -90,7 +98,7 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
for (var i = 0; i < left.Length; i++)
{
if (!(Equals(left[i], right[i]) || SequenceEqual(left[i], right[i])))
if (!(Equals(left[i], right[i]) || SequenceEqual(left[i], right[i]) || PlaceholdersEqual(left[i], right[i])))
{
return false;
}
@ -98,6 +106,21 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
return true;
}
private bool PlaceholdersEqual(object left, object right)
{
if (left.GetType() != right.GetType())
{
return false;
}
switch(left)
{
case StreamPlaceholder leftPlaceholder:
return leftPlaceholder.StreamId == (right as StreamPlaceholder).StreamId;
default:
return false;
}
}
private bool SequenceEqual(object left, object right)
{
if (left == null && right == null)

View File

@ -178,6 +178,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests
return SendHubMessageAsync(new StreamInvocationMessage(invocationId, methodName, args));
}
public Task<string> BeginUploadStreamAsync(string invocationId, string methodName, params object[] args)
{
var message = new InvocationMessage(invocationId, methodName, args);
return SendHubMessageAsync(message);
}
public async Task<string> SendHubMessageAsync(HubMessage message)
{
var payload = _protocol.GetMessageBytes(message);
@ -295,6 +301,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
}
private class DefaultInvocationBinder : IInvocationBinder
{
public IReadOnlyList<Type> GetParameterTypes(string methodName)
@ -307,6 +314,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
return typeof(object);
}
public Type GetStreamItemType(string streamId)
{
throw new NotImplementedException();
}
}
}
}

View File

@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.Text;
using System.Threading.Channels;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Authorization;
@ -174,6 +175,78 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public SelfRef Self;
}
public async Task<string> StreamingConcat(ChannelReader<string> source)
{
var sb = new StringBuilder();
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
sb.Append(item);
}
}
return sb.ToString();
}
public async Task<int> StreamingSum(ChannelReader<int> source)
{
var total = 0;
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
total += item;
}
}
return total;
}
public async Task<List<object>> UploadArray(ChannelReader<object> source)
{
var results = new List<object>();
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
results.Add(item);
}
}
return results;
}
public async Task<string> TestTypeCastingErrors(ChannelReader<int> source)
{
try
{
await source.WaitToReadAsync();
}
catch (Exception ex)
{
Console.WriteLine(ex.ToString());
return "error identified and caught";
}
return "wrong type accepted, this is bad";
}
public async Task<bool> TestCustomErrorPassing(ChannelReader<int> source)
{
try
{
await source.WaitToReadAsync();
}
catch (Exception ex)
{
return ex.Message == HubConnectionHandlerTests.CustomErrorMessage;
}
return false;
}
}
public abstract class TestHub : Hub

View File

@ -4,6 +4,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Security.Claims;
using System.Text;
@ -1489,7 +1490,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<StreamingHub>>();
var invocationBinder = new Mock<IInvocationBinder>();
invocationBinder.Setup(b => b.GetReturnType(It.IsAny<string>())).Returns(typeof(string));
invocationBinder.Setup(b => b.GetStreamItemType(It.IsAny<string>())).Returns(typeof(string));
using (var client = new TestClient(protocol: protocol, invocationBinder: invocationBinder.Object))
{
@ -2361,6 +2362,93 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task UploadStringsToConcat()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new StreamPlaceholder("id"));
foreach (var letter in new[] { "B", "E", "A", "N", "E", "D" })
{
await client.SendHubMessageAsync(new StreamDataMessage("id", letter)).OrTimeout();
}
await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout();
var result = (CompletionMessage)await client.ReadAsync().OrTimeout();
Assert.Equal("BEANED", result.Result);
}
}
[Fact]
public async Task UploadStreamedObjects()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.UploadArray), new StreamPlaceholder("id"));
var objects = new[] { new SampleObject("solo", 322), new SampleObject("ggez", 3145) };
foreach (var thing in objects)
{
await client.SendHubMessageAsync(new StreamDataMessage("id", thing)).OrTimeout();
}
await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout();
var response = (CompletionMessage)await client.ReadAsync().OrTimeout();
var result = ((JArray)response.Result).ToArray<object>();
Assert.Equal(objects[0].Foo, ((JContainer)result[0])["foo"]);
Assert.Equal(objects[0].Bar, ((JContainer)result[0])["bar"]);
Assert.Equal(objects[1].Foo, ((JContainer)result[1])["foo"]);
Assert.Equal(objects[1].Bar, ((JContainer)result[1])["bar"]);
}
}
[Fact]
public async Task UploadManyStreams()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
var ids = new[] { "0", "1", "2" };
foreach (string id in ids)
{
await client.BeginUploadStreamAsync("invocation_"+id, nameof(MethodHub.StreamingConcat), new StreamPlaceholder(id));
}
var words = new[] { "zygapophyses", "qwerty", "abcd" };
var pos = new[] { 0, 0, 0 };
var order = new[] { 2, 2, 0, 2, 1, 0, 0, 0, 0, 0, 0, 2, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1 };
foreach (var spot in order)
{
await client.SendHubMessageAsync(new StreamDataMessage(spot.ToString(), words[spot][pos[spot]])).OrTimeout();
pos[spot] += 1;
}
foreach (string id in new[] { "0", "2", "1" })
{
await client.SendHubMessageAsync(new StreamCompleteMessage(id)).OrTimeout();
var response = await client.ReadAsync().OrTimeout();
Debug.Write(response);
Assert.Equal(words[int.Parse(id)], ((CompletionMessage)response).Result);
}
}
}
[Fact]
public async Task ConnectionAbortedIfSendFailsWithProtocolError()
{
@ -2381,6 +2469,30 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task UploadStreamItemInvalidTypeAutoCasts()
{
// NOTE -- json.net is flexible here, and casts for us
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.StreamingConcat), new StreamPlaceholder("id")).OrTimeout();
// send integers that are then cast to strings
await client.SendHubMessageAsync(new StreamDataMessage("id", 5)).OrTimeout();
await client.SendHubMessageAsync(new StreamDataMessage("id", 10)).OrTimeout();
await client.SendHubMessageAsync(new StreamCompleteMessage("id")).OrTimeout();
var response = (CompletionMessage)await client.ReadAsync().OrTimeout();
Assert.Equal("510", response.Result);
}
}
[Fact]
public async Task ServerReportsProtocolMinorVersion()
{
@ -2395,7 +2507,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
using (var client = new TestClient(protocol: testProtocol.Object))
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler);
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
Assert.NotNull(client.HandshakeResponseMessage);
Assert.Equal(112, client.HandshakeResponseMessage.MinorVersion);
@ -2405,6 +2517,89 @@ namespace Microsoft.AspNetCore.SignalR.Tests
}
}
[Fact]
public async Task UploadStreamItemInvalidType()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocationId", nameof(MethodHub.TestTypeCastingErrors), new StreamPlaceholder("channelId")).OrTimeout();
// client is running wild, sending strings not ints.
// this error should be propogated to the user's HubMethod code
await client.SendHubMessageAsync(new StreamItemMessage("channelId", "not a number")).OrTimeout();
var response = await client.ReadAsync().OrTimeout();
Assert.Equal(typeof(CompletionMessage), response.GetType());
Assert.Equal("error identified and caught", (string)((CompletionMessage)response).Result);
}
}
[Fact]
public async Task UploadStreamItemInvalidId()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options => options.EnableDetailedErrors = true);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.SendHubMessageAsync(new StreamItemMessage("fake_id", "not a number")).OrTimeout();
// Client is breaking protocol by sending an invalid id, and should be closed.
var message = client.TryRead();
Assert.IsType<CloseMessage>(message);
Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error);
}
}
[Fact]
public async Task UploadStreamCompleteInvalidId()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(services =>
{
services.AddSignalR(options => options.EnableDetailedErrors = true);
});
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.SendHubMessageAsync(new StreamCompleteMessage("fake_id")).OrTimeout();
// Client is breaking protocol by sending an invalid id, and should be closed.
var message = client.TryRead();
Assert.IsType<CloseMessage>(message);
Assert.Equal("Connection closed with an error. KeyNotFoundException: No stream with id 'fake_id' could be found.", ((CloseMessage)message).Error);
}
}
public static string CustomErrorMessage = "custom error for testing ::::)";
[Fact]
public async Task UploadStreamCompleteWithError()
{
var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider();
var connectionHandler = serviceProvider.GetService<HubConnectionHandler<MethodHub>>();
using (var client = new TestClient())
{
var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout();
await client.BeginUploadStreamAsync("invocation", nameof(MethodHub.TestCustomErrorPassing), new StreamPlaceholder("id")).OrTimeout();
await client.SendHubMessageAsync(new StreamCompleteMessage("id", CustomErrorMessage)).OrTimeout();
var response = (CompletionMessage)await client.ReadAsync().OrTimeout();
Assert.True((bool)response.Result);
}
}
private class CustomHubActivator<THub> : IHubActivator<THub> where THub : Hub
{
public int ReleaseCount;
@ -2450,5 +2645,16 @@ namespace Microsoft.AspNetCore.SignalR.Tests
public string GetUserId(HubConnectionContext connection) => _getUserId(connection);
}
private class SampleObject
{
public SampleObject(string foo, int bar)
{
Bar = bar;
Foo = foo;
}
public int Bar { get; }
public string Foo { get; }
}
}
}