diff --git a/samples/SignalRSamples/Hubs/Streaming.cs b/samples/SignalRSamples/Hubs/Streaming.cs index e772d0c38d..ee5401b7c1 100644 --- a/samples/SignalRSamples/Hubs/Streaming.cs +++ b/samples/SignalRSamples/Hubs/Streaming.cs @@ -17,7 +17,7 @@ namespace SignalRSamples.Hubs .Select((_, index) => index) .Take(count); - return observable.AsChannelReader(); + return observable.AsChannelReader(Context.ConnectionAborted); } public ChannelReader ChannelCounter(int count, int delay) diff --git a/samples/SignalRSamples/ObservableExtensions.cs b/samples/SignalRSamples/ObservableExtensions.cs index 8cd849b92b..875a769732 100644 --- a/samples/SignalRSamples/ObservableExtensions.cs +++ b/samples/SignalRSamples/ObservableExtensions.cs @@ -3,13 +3,18 @@ using System; using System.Reactive.Linq; +using System.Threading; using System.Threading.Channels; namespace SignalRSamples { public static class ObservableExtensions { - public static ChannelReader AsChannelReader(this IObservable observable, int? maxBufferSize = null) + public static ChannelReader AsChannelReader( + this IObservable observable, + CancellationToken connectionAborted, + int? maxBufferSize = null + ) { // This sample shows adapting an observable to a ChannelReader without // back pressure, if the connection is slower than the producer, memory will @@ -26,9 +31,14 @@ namespace SignalRSamples value => channel.Writer.TryWrite(value), error => channel.Writer.TryComplete(error), () => channel.Writer.TryComplete()); + var abortRegistration = connectionAborted.Register(() => channel.Writer.TryComplete()); // Complete the subscription on the reader completing - channel.Reader.Completion.ContinueWith(task => disposable.Dispose()); + channel.Reader.Completion.ContinueWith(task => + { + disposable.Dispose(); + abortRegistration.Dispose(); + }); return channel.Reader; }