diff --git a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs index 6ebd1dcb25..93556e8d94 100644 --- a/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs +++ b/src/SignalR/common/Shared/AsyncEnumerableAdapters.cs @@ -51,37 +51,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal { ((CancellationTokenSource)ctsState).Cancel(); }, _cts); - - return new CancelableEnumerator(_asyncEnumerable.GetAsyncEnumerator(), registration); } return enumerator; } - - private class CancelableEnumerator : IAsyncEnumerator - { - private IAsyncEnumerator _asyncEnumerator; - private readonly CancellationTokenRegistration _cancellationTokenRegistration; - - public T Current => (T)_asyncEnumerator.Current; - - public CancelableEnumerator(IAsyncEnumerator asyncEnumerator, CancellationTokenRegistration registration) - { - _asyncEnumerator = asyncEnumerator; - _cancellationTokenRegistration = registration; - } - - public ValueTask MoveNextAsync() - { - return _asyncEnumerator.MoveNextAsync(); - } - - public ValueTask DisposeAsync() - { - _cancellationTokenRegistration.Dispose(); - return _asyncEnumerator.DisposeAsync(); - } - } } /// Converts an IAsyncEnumerable of T to an IAsyncEnumerable of object. @@ -98,10 +71,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal public IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken = default) { - // Assume that this will be iterated through with await foreach which always passes a default token. - // Instead use the token from the ctor. - Debug.Assert(cancellationToken == default); - var enumeratorOfT = _asyncEnumerable.GetAsyncEnumerator(_cancellationToken); return enumeratorOfT as IAsyncEnumerator ?? new BoxedAsyncEnumerator(enumeratorOfT); } diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index ee39dcad9d..6d097d731d 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -403,13 +403,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) { string error = null; - try { - await foreach (var streamItem in enumerable) + await foreach(var item in enumerable.WithCancellation(streamCts.Token)) { // Send the stream item - await connection.WriteAsync(new StreamItemMessage(invocationId, streamItem)); + await connection.WriteAsync(new StreamItemMessage(invocationId, item)); } } catch (ChannelClosedException ex)