Client to Server Streaming with IAsyncEnumerable (#9310)

This commit is contained in:
Mikael Mengistu 2019-04-18 13:20:39 -07:00 committed by GitHub
parent 6074daacae
commit ebb9ad20db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 222 additions and 19 deletions

View File

@ -42,7 +42,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1);
private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems"));
#if NETCOREAPP3_0
private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendIAsyncEnumerableStreamItems"));
#endif
// Persistent across all connections
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
@ -533,13 +535,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
LaunchStreams(readers, cancellationToken);
return channel;
}
private Dictionary<string, object> PackageStreamingParams(ref object[] args, out List<string> streamIds)
{
// lazy initialized, to avoid allocating unecessary dictionaries
Dictionary<string, object> readers = null;
streamIds = null;
var newArgs = new List<object>(args.Length);
@ -572,7 +572,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
args = newArgs.ToArray();
return readers;
}
@ -590,6 +589,15 @@ namespace Microsoft.AspNetCore.SignalR.Client
// For each stream that needs to be sent, run a "send items" task in the background.
// This reads from the channel, attaches streamId, and sends to server.
// A single background thread here quickly gets messy.
#if NETCOREAPP3_0
if (ReflectionHelper.IsIAsyncEnumerable(reader.GetType()))
{
_ = _sendIAsyncStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1").GetGenericArguments())
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
continue;
}
#endif
_ = _sendStreamItemsMethod
.MakeGenericMethod(reader.GetType().GetGenericArguments())
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
@ -597,24 +605,52 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
// this is called via reflection using the `_sendStreamItems` field
private async Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
private Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
{
Log.StartingStream(_logger, streamId);
var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token).Token;
string responseError = null;
try
async Task ReadChannelStream(CancellationTokenSource tokenSource)
{
while (await reader.WaitToReadAsync(combinedToken))
while (await reader.WaitToReadAsync(tokenSource.Token))
{
while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item))
while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
{
await SendWithLock(new StreamItemMessage(streamId, item));
Log.SendingStreamItem(_logger, streamId);
}
}
}
return CommonStreaming(streamId, token, ReadChannelStream);
}
#if NETCOREAPP3_0
// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
private Task SendIAsyncEnumerableStreamItems<T>(string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
{
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
{
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);
await foreach (var streamValue in streamValues)
{
await SendWithLock(new StreamItemMessage(streamId, streamValue));
Log.SendingStreamItem(_logger, streamId);
}
}
return CommonStreaming(streamId, token, ReadAsyncEnumerableStream);
}
#endif
private async Task CommonStreaming(string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
{
var cts = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token);
Log.StartingStream(_logger, streamId);
string responseError = null;
try
{
await createAndConsumeStream(cts);
}
catch (OperationCanceledException)
{
Log.CancelingStream(_logger, streamId);

View File

@ -661,6 +661,106 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
public async Task CanStreamToServerWithIAsyncEnumerable(string protocolName, HttpTransportType transportType, string path)
{
var protocol = HubProtocols[protocolName];
using (StartServer<Startup>(out var server))
{
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
try
{
async IAsyncEnumerable<string> clientStreamData()
{
var items = new string[] { "A", "B", "C", "D" };
foreach (var item in items)
{
await Task.Delay(10);
yield return item;
}
}
await connection.StartAsync().OrTimeout();
var stream = clientStreamData();
var channel = await connection.StreamAsChannelAsync<string>("StreamEcho", stream).OrTimeout();
Assert.Equal("A", await channel.ReadAsync().AsTask().OrTimeout());
Assert.Equal("B", await channel.ReadAsync().AsTask().OrTimeout());
Assert.Equal("C", await channel.ReadAsync().AsTask().OrTimeout());
Assert.Equal("D", await channel.ReadAsync().AsTask().OrTimeout());
var results = await channel.ReadAndCollectAllAsync().OrTimeout();
Assert.Empty(results);
}
catch (Exception ex)
{
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
public async Task CanCancelIAsyncEnumerableClientToServerUpload(string protocolName, HttpTransportType transportType, string path)
{
var protocol = HubProtocols[protocolName];
using (StartServer<Startup>(out var server))
{
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
try
{
async IAsyncEnumerable<int> clientStreamData()
{
for (var i = 0; i < 1000; i++)
{
yield return i;
await Task.Delay(10);
}
}
await connection.StartAsync().OrTimeout();
var results = new List<int>();
var stream = clientStreamData();
var cts = new CancellationTokenSource();
var ex = await Assert.ThrowsAsync<OperationCanceledException>(async () =>
{
var channel = await connection.StreamAsChannelAsync<int>("StreamEchoInt", stream, cts.Token).OrTimeout();
while (await channel.WaitToReadAsync())
{
while (channel.TryRead(out var item))
{
results.Add(item);
cts.Cancel();
}
}
});
Assert.True(results.Count > 0 && results.Count < 1000);
Assert.True(cts.IsCancellationRequested);
}
catch (Exception ex)
{
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
@ -673,7 +773,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
try
{
await connection.StartAsync().OrTimeout();
var stream = connection.StreamAsync<int>("Stream", 1000 );
var stream = connection.StreamAsync<int>("Stream", 1000);
var results = new List<int>();
var cts = new CancellationTokenSource();

View File

@ -43,6 +43,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
public string GetUserIdentifier()
{
return Context.UserIdentifier;
@ -121,6 +123,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
}
public class TestHubT : Hub<ITestHub>
@ -151,6 +155,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
}
internal static class TestHubMethodsImpl
@ -212,6 +218,30 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
return output.Reader;
}
public static ChannelReader<int> StreamEchoInt(ChannelReader<int> source)
{
var output = Channel.CreateUnbounded<int>();
_ = Task.Run(async () =>
{
try
{
while (await source.WaitToReadAsync())
{
while (source.TryRead(out var item))
{
await output.Writer.WriteAsync(item);
}
}
}
finally
{
output.Writer.TryComplete();
}
});
return output.Reader;
}
}
public interface ITestHub

View File

@ -210,12 +210,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
}
private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
{
var headers = ReadHeaders(input, ref offset);
var invocationId = ReadInvocationId(input, ref offset);
var itemType = binder.GetStreamItemType(invocationId);
var value = DeserializeObject(input, ref offset, itemType, "item", resolver);
object value;
try
{
var itemType = binder.GetStreamItemType(invocationId);
value = DeserializeObject(input, ref offset, itemType, "item", resolver);
}
catch (Exception ex)
{
return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
}
return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
}

View File

@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Channels;
namespace Microsoft.AspNetCore.SignalR
@ -13,6 +15,13 @@ namespace Microsoft.AspNetCore.SignalR
public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
{
// TODO #2594 - add Streams here, to make sending files easy
#if NETCOREAPP3_0
if (IsIAsyncEnumerable(type))
{
return true;
}
#endif
do
{
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(ChannelReader<>))
@ -25,5 +34,22 @@ namespace Microsoft.AspNetCore.SignalR
return false;
}
#if NETCOREAPP3_0
public static bool IsIAsyncEnumerable(Type type)
{
return type.GetInterfaces().Any(t =>
{
if (t.IsGenericType)
{
return t.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>);
}
else
{
return false;
}
});
}
#endif
}
}

View File

@ -277,8 +277,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
// StreamItemMessage
new InvalidMessageData("StreamItemMissingId", new byte[] { 0x92, 2, 0x80 }, "Reading 'invocationId' as String failed."),
new InvalidMessageData("StreamItemInvocationIdBoolean", new byte[] { 0x93, 2, 0x80, 0xc2 }, "Reading 'invocationId' as String failed."),
new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
// These now trigger StreamBindingInvocationFailureMessages
//new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
//new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
// CompletionMessage
new InvalidMessageData("CompletionMissingId", new byte[] { 0x92, 3, 0x80 }, "Reading 'invocationId' as String failed."),