Revert "Remove CancelableEnumerator (#10099)" (#10129)

This commit is contained in:
Mikael Mengistu 2019-05-09 21:08:20 -07:00 committed by GitHub
parent 3cd84a9b03
commit 0adbfc6d25
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 2 deletions

View File

@ -51,10 +51,37 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
((CancellationTokenSource)ctsState).Cancel();
}, _cts);
return new CancelableEnumerator<TResult>(_asyncEnumerable.GetAsyncEnumerator(), registration);
}
return enumerator;
}
private class CancelableEnumerator<T> : IAsyncEnumerator<T>
{
private IAsyncEnumerator<T> _asyncEnumerator;
private readonly CancellationTokenRegistration _cancellationTokenRegistration;
public T Current => (T)_asyncEnumerator.Current;
public CancelableEnumerator(IAsyncEnumerator<T> asyncEnumerator, CancellationTokenRegistration registration)
{
_asyncEnumerator = asyncEnumerator;
_cancellationTokenRegistration = registration;
}
public ValueTask<bool> MoveNextAsync()
{
return _asyncEnumerator.MoveNextAsync();
}
public ValueTask DisposeAsync()
{
_cancellationTokenRegistration.Dispose();
return _asyncEnumerator.DisposeAsync();
}
}
}
/// <summary>Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object.</summary>
@ -71,6 +98,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public IAsyncEnumerator<object> GetAsyncEnumerator(CancellationToken cancellationToken = default)
{
// Assume that this will be iterated through with await foreach which always passes a default token.
// Instead use the token from the ctor.
Debug.Assert(cancellationToken == default);
var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken);
return enumeratorOfT as IAsyncEnumerator<object> ?? new BoxedAsyncEnumerator(enumeratorOfT);
}

View File

@ -403,12 +403,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal
IHubActivator<THub> hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage)
{
string error = null;
try
{
await foreach(var item in enumerable.WithCancellation(streamCts.Token))
await foreach (var streamItem in enumerable)
{
// Send the stream item
await connection.WriteAsync(new StreamItemMessage(invocationId, item));
await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem));
}
}
catch (ChannelClosedException ex)