Client to Server Streaming with IAsyncEnumerable (#9310)
This commit is contained in:
parent
6074daacae
commit
ebb9ad20db
|
|
@ -42,7 +42,9 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
private readonly SemaphoreSlim _connectionLock = new SemaphoreSlim(1, 1);
|
||||
|
||||
private static readonly MethodInfo _sendStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendStreamItems"));
|
||||
|
||||
#if NETCOREAPP3_0
|
||||
private static readonly MethodInfo _sendIAsyncStreamItemsMethod = typeof(HubConnection).GetMethods(BindingFlags.NonPublic | BindingFlags.Instance).Single(m => m.Name.Equals("SendIAsyncEnumerableStreamItems"));
|
||||
#endif
|
||||
// Persistent across all connections
|
||||
private readonly ILoggerFactory _loggerFactory;
|
||||
private readonly ILogger _logger;
|
||||
|
|
@ -533,13 +535,11 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
}
|
||||
|
||||
LaunchStreams(readers, cancellationToken);
|
||||
|
||||
return channel;
|
||||
}
|
||||
|
||||
private Dictionary<string, object> PackageStreamingParams(ref object[] args, out List<string> streamIds)
|
||||
{
|
||||
// lazy initialized, to avoid allocating unecessary dictionaries
|
||||
Dictionary<string, object> readers = null;
|
||||
streamIds = null;
|
||||
var newArgs = new List<object>(args.Length);
|
||||
|
|
@ -572,7 +572,6 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
}
|
||||
|
||||
args = newArgs.ToArray();
|
||||
|
||||
return readers;
|
||||
}
|
||||
|
||||
|
|
@ -590,6 +589,15 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
// For each stream that needs to be sent, run a "send items" task in the background.
|
||||
// This reads from the channel, attaches streamId, and sends to server.
|
||||
// A single background thread here quickly gets messy.
|
||||
#if NETCOREAPP3_0
|
||||
if (ReflectionHelper.IsIAsyncEnumerable(reader.GetType()))
|
||||
{
|
||||
_ = _sendIAsyncStreamItemsMethod
|
||||
.MakeGenericMethod(reader.GetType().GetInterface("IAsyncEnumerable`1").GetGenericArguments())
|
||||
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
|
||||
continue;
|
||||
}
|
||||
#endif
|
||||
_ = _sendStreamItemsMethod
|
||||
.MakeGenericMethod(reader.GetType().GetGenericArguments())
|
||||
.Invoke(this, new object[] { kvp.Key.ToString(), reader, cancellationToken });
|
||||
|
|
@ -597,24 +605,52 @@ namespace Microsoft.AspNetCore.SignalR.Client
|
|||
}
|
||||
|
||||
// this is called via reflection using the `_sendStreamItems` field
|
||||
private async Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
|
||||
private Task SendStreamItems<T>(string streamId, ChannelReader<T> reader, CancellationToken token)
|
||||
{
|
||||
Log.StartingStream(_logger, streamId);
|
||||
|
||||
var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token).Token;
|
||||
|
||||
string responseError = null;
|
||||
try
|
||||
async Task ReadChannelStream(CancellationTokenSource tokenSource)
|
||||
{
|
||||
while (await reader.WaitToReadAsync(combinedToken))
|
||||
while (await reader.WaitToReadAsync(tokenSource.Token))
|
||||
{
|
||||
while (!combinedToken.IsCancellationRequested && reader.TryRead(out var item))
|
||||
while (!tokenSource.Token.IsCancellationRequested && reader.TryRead(out var item))
|
||||
{
|
||||
await SendWithLock(new StreamItemMessage(streamId, item));
|
||||
Log.SendingStreamItem(_logger, streamId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return CommonStreaming(streamId, token, ReadChannelStream);
|
||||
}
|
||||
|
||||
#if NETCOREAPP3_0
|
||||
// this is called via reflection using the `_sendIAsyncStreamItemsMethod` field
|
||||
private Task SendIAsyncEnumerableStreamItems<T>(string streamId, IAsyncEnumerable<T> stream, CancellationToken token)
|
||||
{
|
||||
async Task ReadAsyncEnumerableStream(CancellationTokenSource tokenSource)
|
||||
{
|
||||
var streamValues = AsyncEnumerableAdapters.MakeCancelableTypedAsyncEnumerable(stream, tokenSource);
|
||||
|
||||
await foreach (var streamValue in streamValues)
|
||||
{
|
||||
await SendWithLock(new StreamItemMessage(streamId, streamValue));
|
||||
Log.SendingStreamItem(_logger, streamId);
|
||||
}
|
||||
}
|
||||
|
||||
return CommonStreaming(streamId, token, ReadAsyncEnumerableStream);
|
||||
}
|
||||
#endif
|
||||
|
||||
private async Task CommonStreaming(string streamId, CancellationToken token, Func<CancellationTokenSource, Task> createAndConsumeStream)
|
||||
{
|
||||
var cts = CancellationTokenSource.CreateLinkedTokenSource(_uploadStreamToken, token);
|
||||
|
||||
Log.StartingStream(_logger, streamId);
|
||||
string responseError = null;
|
||||
try
|
||||
{
|
||||
await createAndConsumeStream(cts);
|
||||
}
|
||||
catch (OperationCanceledException)
|
||||
{
|
||||
Log.CancelingStream(_logger, streamId);
|
||||
|
|
|
|||
|
|
@ -661,6 +661,106 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
|
||||
[LogLevel(LogLevel.Trace)]
|
||||
public async Task CanStreamToServerWithIAsyncEnumerable(string protocolName, HttpTransportType transportType, string path)
|
||||
{
|
||||
var protocol = HubProtocols[protocolName];
|
||||
using (StartServer<Startup>(out var server))
|
||||
{
|
||||
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
|
||||
try
|
||||
{
|
||||
async IAsyncEnumerable<string> clientStreamData()
|
||||
{
|
||||
var items = new string[] { "A", "B", "C", "D" };
|
||||
foreach (var item in items)
|
||||
{
|
||||
await Task.Delay(10);
|
||||
yield return item;
|
||||
}
|
||||
}
|
||||
|
||||
await connection.StartAsync().OrTimeout();
|
||||
|
||||
var stream = clientStreamData();
|
||||
|
||||
var channel = await connection.StreamAsChannelAsync<string>("StreamEcho", stream).OrTimeout();
|
||||
|
||||
Assert.Equal("A", await channel.ReadAsync().AsTask().OrTimeout());
|
||||
Assert.Equal("B", await channel.ReadAsync().AsTask().OrTimeout());
|
||||
Assert.Equal("C", await channel.ReadAsync().AsTask().OrTimeout());
|
||||
Assert.Equal("D", await channel.ReadAsync().AsTask().OrTimeout());
|
||||
|
||||
var results = await channel.ReadAndCollectAllAsync().OrTimeout();
|
||||
Assert.Empty(results);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
|
||||
throw;
|
||||
}
|
||||
finally
|
||||
{
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
|
||||
[LogLevel(LogLevel.Trace)]
|
||||
public async Task CanCancelIAsyncEnumerableClientToServerUpload(string protocolName, HttpTransportType transportType, string path)
|
||||
{
|
||||
var protocol = HubProtocols[protocolName];
|
||||
using (StartServer<Startup>(out var server))
|
||||
{
|
||||
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
|
||||
try
|
||||
{
|
||||
async IAsyncEnumerable<int> clientStreamData()
|
||||
{
|
||||
for (var i = 0; i < 1000; i++)
|
||||
{
|
||||
yield return i;
|
||||
await Task.Delay(10);
|
||||
}
|
||||
}
|
||||
|
||||
await connection.StartAsync().OrTimeout();
|
||||
var results = new List<int>();
|
||||
var stream = clientStreamData();
|
||||
var cts = new CancellationTokenSource();
|
||||
var ex = await Assert.ThrowsAsync<OperationCanceledException>(async () =>
|
||||
{
|
||||
var channel = await connection.StreamAsChannelAsync<int>("StreamEchoInt", stream, cts.Token).OrTimeout();
|
||||
|
||||
while (await channel.WaitToReadAsync())
|
||||
{
|
||||
while (channel.TryRead(out var item))
|
||||
{
|
||||
results.Add(item);
|
||||
cts.Cancel();
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Assert.True(results.Count > 0 && results.Count < 1000);
|
||||
Assert.True(cts.IsCancellationRequested);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
|
||||
throw;
|
||||
}
|
||||
finally
|
||||
{
|
||||
await connection.DisposeAsync().OrTimeout();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
[Theory]
|
||||
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
|
||||
[LogLevel(LogLevel.Trace)]
|
||||
|
|
@ -673,7 +773,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
try
|
||||
{
|
||||
await connection.StartAsync().OrTimeout();
|
||||
var stream = connection.StreamAsync<int>("Stream", 1000 );
|
||||
var stream = connection.StreamAsync<int>("Stream", 1000);
|
||||
var results = new List<int>();
|
||||
|
||||
var cts = new CancellationTokenSource();
|
||||
|
|
|
|||
|
|
@ -43,6 +43,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
|
||||
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
|
||||
|
||||
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
|
||||
|
||||
public string GetUserIdentifier()
|
||||
{
|
||||
return Context.UserIdentifier;
|
||||
|
|
@ -121,6 +123,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
}
|
||||
|
||||
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
|
||||
|
||||
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
|
||||
}
|
||||
|
||||
public class TestHubT : Hub<ITestHub>
|
||||
|
|
@ -151,6 +155,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
}
|
||||
|
||||
public ChannelReader<string> StreamEcho(ChannelReader<string> source) => TestHubMethodsImpl.StreamEcho(source);
|
||||
|
||||
public ChannelReader<int> StreamEchoInt(ChannelReader<int> source) => TestHubMethodsImpl.StreamEchoInt(source);
|
||||
}
|
||||
|
||||
internal static class TestHubMethodsImpl
|
||||
|
|
@ -212,6 +218,30 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
|
|||
|
||||
return output.Reader;
|
||||
}
|
||||
|
||||
public static ChannelReader<int> StreamEchoInt(ChannelReader<int> source)
|
||||
{
|
||||
var output = Channel.CreateUnbounded<int>();
|
||||
_ = Task.Run(async () =>
|
||||
{
|
||||
try
|
||||
{
|
||||
while (await source.WaitToReadAsync())
|
||||
{
|
||||
while (source.TryRead(out var item))
|
||||
{
|
||||
await output.Writer.WriteAsync(item);
|
||||
}
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
output.Writer.TryComplete();
|
||||
}
|
||||
});
|
||||
|
||||
return output.Reader;
|
||||
}
|
||||
}
|
||||
|
||||
public interface ITestHub
|
||||
|
|
|
|||
|
|
@ -210,12 +210,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol
|
|||
return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams));
|
||||
}
|
||||
|
||||
private static StreamItemMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
|
||||
private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver)
|
||||
{
|
||||
var headers = ReadHeaders(input, ref offset);
|
||||
var invocationId = ReadInvocationId(input, ref offset);
|
||||
var itemType = binder.GetStreamItemType(invocationId);
|
||||
var value = DeserializeObject(input, ref offset, itemType, "item", resolver);
|
||||
object value;
|
||||
try
|
||||
{
|
||||
var itemType = binder.GetStreamItemType(invocationId);
|
||||
value = DeserializeObject(input, ref offset, itemType, "item", resolver);
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
return new StreamBindingFailureMessage(invocationId, ExceptionDispatchInfo.Capture(ex));
|
||||
}
|
||||
|
||||
return ApplyHeaders(headers, new StreamItemMessage(invocationId, value));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
|
||||
|
||||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Channels;
|
||||
|
||||
namespace Microsoft.AspNetCore.SignalR
|
||||
|
|
@ -13,6 +15,13 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
public static bool IsStreamingType(Type type, bool mustBeDirectType = false)
|
||||
{
|
||||
// TODO #2594 - add Streams here, to make sending files easy
|
||||
|
||||
#if NETCOREAPP3_0
|
||||
if (IsIAsyncEnumerable(type))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
do
|
||||
{
|
||||
if (type.IsGenericType && type.GetGenericTypeDefinition() == typeof(ChannelReader<>))
|
||||
|
|
@ -25,5 +34,22 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
return false;
|
||||
}
|
||||
|
||||
#if NETCOREAPP3_0
|
||||
public static bool IsIAsyncEnumerable(Type type)
|
||||
{
|
||||
return type.GetInterfaces().Any(t =>
|
||||
{
|
||||
if (t.IsGenericType)
|
||||
{
|
||||
return t.GetGenericTypeDefinition() == typeof(IAsyncEnumerable<>);
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
});
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -277,8 +277,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol
|
|||
// StreamItemMessage
|
||||
new InvalidMessageData("StreamItemMissingId", new byte[] { 0x92, 2, 0x80 }, "Reading 'invocationId' as String failed."),
|
||||
new InvalidMessageData("StreamItemInvocationIdBoolean", new byte[] { 0x93, 2, 0x80, 0xc2 }, "Reading 'invocationId' as String failed."),
|
||||
new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
|
||||
new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
|
||||
|
||||
// These now trigger StreamBindingInvocationFailureMessages
|
||||
//new InvalidMessageData("StreamItemMissing", new byte[] { 0x93, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z' }, "Deserializing object of the `String` type for 'item' failed."),
|
||||
//new InvalidMessageData("StreamItemTypeMismatch", new byte[] { 0x94, 2, 0x80, 0xa3, (byte)'x', (byte)'y', (byte)'z', 42 }, "Deserializing object of the `String` type for 'item' failed."),
|
||||
|
||||
// CompletionMessage
|
||||
new InvalidMessageData("CompletionMissingId", new byte[] { 0x92, 3, 0x80 }, "Reading 'invocationId' as String failed."),
|
||||
|
|
|
|||
Loading…
Reference in New Issue