From 86083c0302d6eecb11808f9bb6406be62bb0075b Mon Sep 17 00:00:00 2001 From: David Fowler Date: Sat, 7 Apr 2018 15:10:39 -0700 Subject: [PATCH] Removing native support for IObservable (#1890) - There are too many issues and questions with respect to back pressure and the buffering policy we should use when the client being streamed to can't support the data being pushed via OnNext. As a result, we're dropping support for IObservable but keeping ChannelReader and we'll eventually support IAsyncEnumerable when that makes it into the BCL. - Add sample showing Observable -> ChannelReader adaption --- .../DefaultHubDispatcherBenchmark.cs | 58 ----- clients/ts/FunctionalTests/TestHub.cs | 18 +- clients/ts/FunctionalTests/ts/Utils.ts | 2 - samples/SignalRSamples/Hubs/Streaming.cs | 6 +- .../SignalRSamples/ObservableExtensions.cs | 36 +++ .../CastObservable.cs | 54 ----- .../Internal/AsyncEnumeratorAdapters.cs | 65 ------ .../Internal/DefaultHubDispatcher.cs | 9 +- .../Internal/HubMethodDescriptor.cs | 48 +--- .../HubConnectionTests.cs | 2 +- .../Hubs.cs | 25 ++- .../HubConnectionHandlerTests.cs | 211 +----------------- .../HubEndpointTestUtils/Hubs.cs | 124 +--------- 13 files changed, 81 insertions(+), 577 deletions(-) create mode 100644 samples/SignalRSamples/ObservableExtensions.cs delete mode 100644 src/Microsoft.AspNetCore.SignalR.Client.Core/CastObservable.cs diff --git a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs index 5575712075..cc524f714b 100644 --- a/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs +++ b/benchmarks/Microsoft.AspNetCore.SignalR.Microbenchmarks/DefaultHubDispatcherBenchmark.cs @@ -93,8 +93,6 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks public class TestHub : Hub { - private static readonly IObservable ObservableInstance = Observable.Empty(); - public void Invocation() { } @@ -119,21 +117,6 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks return new ValueTask(1); } - public IObservable StreamObservable() - { - return ObservableInstance; - } - - public Task> StreamObservableAsync() - { - return Task.FromResult(ObservableInstance); - } - - public ValueTask> StreamObservableValueTaskAsync() - { - return new ValueTask>(ObservableInstance); - } - public ChannelReader StreamChannelReader() { var channel = Channel.CreateUnbounded(); @@ -173,11 +156,6 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks return channel.Reader; } - - public IObservable StreamObservableCount(int count) - { - return Observable.Range(0, count); - } } [Benchmark] @@ -210,24 +188,6 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks return _dispatcher.DispatchMessageAsync(_connectionContext, new InvocationMessage("123", "InvocationValueTaskAsync", null)); } - [Benchmark] - public Task StreamObservable() - { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservable", null)); - } - - [Benchmark] - public Task StreamObservableAsync() - { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableAsync", null)); - } - - [Benchmark] - public Task StreamObservableValueTaskAsync() - { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableValueTaskAsync", null)); - } - [Benchmark] public Task StreamChannelReader() { @@ -263,23 +223,5 @@ namespace Microsoft.AspNetCore.SignalR.Microbenchmarks { return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamChannelReaderCount", argumentBindingException: null, new object[] { 1000 })); } - - [Benchmark] - public Task StreamObservableCount_Zero() - { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableCount", argumentBindingException: null, new object[] { 0 })); - } - - [Benchmark] - public Task StreamObservableCount_One() - { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableCount", argumentBindingException: null, new object[] { 1 })); - } - - [Benchmark] - public Task StreamObservableCount_Thousand() - { - return _dispatcher.DispatchMessageAsync(_connectionContext, new StreamInvocationMessage("123", "StreamObservableCount", argumentBindingException: null, new object[] { 1000 })); - } } } diff --git a/clients/ts/FunctionalTests/TestHub.cs b/clients/ts/FunctionalTests/TestHub.cs index 8cb5b3c90f..e555408983 100644 --- a/clients/ts/FunctionalTests/TestHub.cs +++ b/clients/ts/FunctionalTests/TestHub.cs @@ -3,6 +3,7 @@ using System; using System.Reactive.Linq; +using System.Threading.Channels; using System.Threading.Tasks; using Microsoft.AspNetCore.Http.Connections; using Microsoft.AspNetCore.SignalR; @@ -38,17 +39,24 @@ namespace FunctionalTests return Clients.Client(Context.ConnectionId).SendAsync("CustomObject", customObject); } - public IObservable Stream() + public ChannelReader Stream() { - return new string[] { "a", "b", "c" }.ToObservable(); + var channel = Channel.CreateUnbounded(); + channel.Writer.TryWrite("a"); + channel.Writer.TryWrite("b"); + channel.Writer.TryWrite("c"); + channel.Writer.Complete(); + return channel.Reader; } - public IObservable EmptyStream() + public ChannelReader EmptyStream() { - return Array.Empty().ToObservable(); + var channel = Channel.CreateUnbounded(); + channel.Writer.Complete(); + return channel.Reader; } - public IObservable StreamThrowException(string message) + public ChannelReader StreamThrowException(string message) { throw new InvalidOperationException(message); } diff --git a/clients/ts/FunctionalTests/ts/Utils.ts b/clients/ts/FunctionalTests/ts/Utils.ts index 2379cfdc57..2f0abcb00f 100644 --- a/clients/ts/FunctionalTests/ts/Utils.ts +++ b/clients/ts/FunctionalTests/ts/Utils.ts @@ -1,8 +1,6 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -jasmine.DEFAULT_TIMEOUT_INTERVAL = 20000; - export function getParameterByName(name: string) { const url = window.location.href; name = name.replace(/[\[\]]/g, "\\$&"); diff --git a/samples/SignalRSamples/Hubs/Streaming.cs b/samples/SignalRSamples/Hubs/Streaming.cs index 00d81fba3d..e772d0c38d 100644 --- a/samples/SignalRSamples/Hubs/Streaming.cs +++ b/samples/SignalRSamples/Hubs/Streaming.cs @@ -11,11 +11,13 @@ namespace SignalRSamples.Hubs { public class Streaming : Hub { - public IObservable ObservableCounter(int count, int delay) + public ChannelReader ObservableCounter(int count, int delay) { - return Observable.Interval(TimeSpan.FromMilliseconds(delay)) + var observable = Observable.Interval(TimeSpan.FromMilliseconds(delay)) .Select((_, index) => index) .Take(count); + + return observable.AsChannelReader(); } public ChannelReader ChannelCounter(int count, int delay) diff --git a/samples/SignalRSamples/ObservableExtensions.cs b/samples/SignalRSamples/ObservableExtensions.cs new file mode 100644 index 0000000000..04f5a5b549 --- /dev/null +++ b/samples/SignalRSamples/ObservableExtensions.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Reactive.Linq; +using System.Threading.Channels; + +namespace SignalRSamples +{ + public static class ObservableExtensions + { + public static ChannelReader AsChannelReader(this IObservable observable) + { + // This sample shows adapting an observable to a ChannelReader without + // back pressure, if the connection is slower than the producer, memory will + // start to increase. + + // If the channel is unbounded, TryWrite will return false and effectively + // drop items. + + // The other alternative is to use a bounded channel, and when the limit is reached + // block on WaitToWriteAsync. This will block a thread pool thread and isn't recommended + var channel = Channel.CreateUnbounded(); + + var disposable = observable.Subscribe( + value => channel.Writer.TryWrite(value), + error => channel.Writer.TryComplete(error), + () => channel.Writer.TryComplete()); + + // Complete the subscription on the reader completing + channel.Reader.Completion.ContinueWith(task => disposable.Dispose()); + + return channel.Reader; + } + } +} \ No newline at end of file diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/CastObservable.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/CastObservable.cs deleted file mode 100644 index 629cdf24bd..0000000000 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/CastObservable.cs +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; - -namespace Microsoft.AspNetCore.SignalR.Client -{ - internal class CastObservable : IObservable - { - private IObservable _innerObservable; - - public CastObservable(IObservable innerObservable) - { - _innerObservable = innerObservable; - } - - public IDisposable Subscribe(IObserver observer) - { - return _innerObservable.Subscribe(new CastObserver(observer)); - } - - private class CastObserver : IObserver - { - private IObserver _innerObserver; - - public CastObserver(IObserver innerObserver) - { - _innerObserver = innerObserver; - } - - public void OnCompleted() - { - _innerObserver.OnCompleted(); - } - - public void OnError(Exception error) - { - _innerObserver.OnError(error); - } - - public void OnNext(object value) - { - try - { - _innerObserver.OnNext((TResult)value); - } - catch(Exception ex) - { - _innerObserver.OnError(ex); - } - } - } - } -} diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs index 2839a974f0..0e50b5c433 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs @@ -11,71 +11,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal // True-internal because this is a weird and tricky class to use :) internal static class AsyncEnumeratorAdapters { - public static IAsyncEnumerator FromObservable(IObservable observable, CancellationToken cancellationToken) - { - // TODO: Allow bounding and optimizations? - var channel = Channel.CreateUnbounded(); - var observer = new ObserverState(); - var channelObserver = new ChannelObserver(channel.Writer); - observer.Subscription = observable.Subscribe(channelObserver); - observer.TokenRegistration = cancellationToken.Register(obs => ((ChannelObserver)obs).OnCompleted(), channelObserver); - - // Make sure the subscription and token registration is disposed when enumeration is completed. - return new AsyncEnumerator(channel.Reader, cancellationToken, observer); - } - - // To track and dispose of the Subscription and the cancellation token registration. - private class ObserverState : IDisposable - { - public CancellationTokenRegistration TokenRegistration; - public IDisposable Subscription; - - public void Dispose() - { - TokenRegistration.Dispose(); - Subscription.Dispose(); - } - } - - private class ChannelObserver : IObserver - { - private readonly ChannelWriter _output; - - public ChannelObserver(ChannelWriter output) - { - _output = output; - } - - public void OnCompleted() - { - _output.TryComplete(); - } - - public void OnError(Exception error) - { - _output.TryComplete(error); - } - - public void OnNext(T value) - { - // This will block the thread emitting the object if the channel is bounded and full - // I think this is OK, since we want to push the backpressure up. However, we may need - // to find a way to force the entire subscription off to a dedicated thread in order to - // ensure we don't block other tasks - - // Right now however, we use unbounded channels, so all of the above is moot because TryWrite will always succeed - while (!_output.TryWrite(value)) - { - // Wait for a spot - if (!_output.WaitToWriteAsync().Result) - { - // Channel was closed so we just no-op. The observer shouldn't throw. - return; - } - } - } - } - public static IAsyncEnumerator GetAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) { // Nothing to dispose when we finish enumerating in this case. diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index 67c293420b..d52ae5bbfa 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -198,7 +198,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); await SendInvocationError(hubMethodInvocationMessage, connection, - $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is null, does not implement the IObservable<> interface or is not a ReadableChannel<>."); + $"The value returned by the streaming method '{methodExecutor.MethodInfo.Name}' is not a ChannelReader<>."); return; } @@ -364,13 +364,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { if (result != null) { - if (hubMethodDescriptor.IsObservable) - { - streamCts = CreateCancellation(); - enumerator = hubMethodDescriptor.FromObservable(result, streamCts.Token); - return true; - } - if (hubMethodDescriptor.IsChannel) { streamCts = CreateCancellation(); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs index 55171bcc82..a15dce772e 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/HubMethodDescriptor.cs @@ -15,10 +15,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { internal class HubMethodDescriptor { - private static readonly MethodInfo FromObservableMethod = typeof(AsyncEnumeratorAdapters) - .GetRuntimeMethods() - .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.FromObservable)) && m.IsGenericMethod); - private static readonly MethodInfo GetAsyncEnumeratorMethod = typeof(AsyncEnumeratorAdapters) .GetRuntimeMethods() .Single(m => m.Name.Equals(nameof(AsyncEnumeratorAdapters.GetAsyncEnumerator)) && m.IsGenericMethod); @@ -33,12 +29,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal ? MethodExecutor.AsyncResultType : MethodExecutor.MethodReturnType; - if (IsObservableType(NonAsyncReturnType, out var observableItemType)) - { - IsObservable = true; - StreamReturnType = observableItemType; - } - else if (IsChannelType(NonAsyncReturnType, out var channelItemType)) + if (IsChannelType(NonAsyncReturnType, out var channelItemType)) { IsChannel = true; StreamReturnType = channelItemType; @@ -53,11 +44,9 @@ namespace Microsoft.AspNetCore.SignalR.Internal public Type NonAsyncReturnType { get; } - public bool IsObservable { get; } - public bool IsChannel { get; } - public bool IsStreamable => IsObservable || IsChannel; + public bool IsStreamable => IsChannel; public Type StreamReturnType { get; } @@ -76,35 +65,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; } - private static bool IsObservableType(Type type, out Type payloadType) - { - var observableInterface = IsIObservable(type) ? type : type.GetInterfaces().FirstOrDefault(IsIObservable); - if (observableInterface == null) - { - payloadType = null; - return false; - } - - payloadType = observableInterface.GetGenericArguments()[0]; - return true; - - bool IsIObservable(Type iface) - { - return iface.IsGenericType && iface.GetGenericTypeDefinition() == typeof(IObservable<>); - } - } - - public IAsyncEnumerator FromObservable(object observable, CancellationToken cancellationToken) - { - // there is the potential for compile to be called times but this has no harmful effect other than perf - if (_convertToEnumerator == null) - { - _convertToEnumerator = CompileConvertToEnumerator(FromObservableMethod, StreamReturnType); - } - - return _convertToEnumerator.Invoke(observable, cancellationToken); - } - public IAsyncEnumerator FromChannel(object channel, CancellationToken cancellationToken) { // there is the potential for compile to be called times but this has no harmful effect other than perf @@ -120,10 +80,6 @@ namespace Microsoft.AspNetCore.SignalR.Internal { // This will call one of two adapter methods to wrap the passed in streamable value // and cancellation token to an IAsyncEnumerator - // - // IObservable: - // AsyncEnumeratorAdapters.FromObservable(observable, cancellationToken); - // // ChannelReader // AsyncEnumeratorAdapters.GetAsyncEnumerator(channelReader, cancellationToken); diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 79c5cf1f8d..61f98651e0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -681,7 +681,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await connection.StartAsync().OrTimeout(); var channel = await connection.StreamAsChannelAsync("StreamBroken").OrTimeout(); var ex = await Assert.ThrowsAsync(() => channel.ReadAllAsync()).OrTimeout(); - Assert.Equal("The value returned by the streaming method 'StreamBroken' is null, does not implement the IObservable<> interface or is not a ReadableChannel<>.", ex.Message); + Assert.Equal("The value returned by the streaming method 'StreamBroken' is not a ChannelReader<>.", ex.Message); } catch (Exception ex) { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index b8579ff3bc..ee7c962a45 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public string Echo(string message) => TestHubMethodsImpl.Echo(message); - public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); + public ChannelReader Stream(int count) => TestHubMethodsImpl.Stream(count); public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); @@ -87,7 +87,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public string Echo(string message) => TestHubMethodsImpl.Echo(message); - public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); + public ChannelReader Stream(int count) => TestHubMethodsImpl.Stream(count); public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); @@ -110,7 +110,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests public string Echo(string message) => TestHubMethodsImpl.Echo(message); - public IObservable Stream(int count) => TestHubMethodsImpl.Stream(count); + public ChannelReader Stream(int count) => TestHubMethodsImpl.Stream(count); public ChannelReader StreamException() => TestHubMethodsImpl.StreamException(); @@ -139,11 +139,22 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests return message; } - public static IObservable Stream(int count) + public static ChannelReader Stream(int count) { - return Observable.Interval(TimeSpan.FromMilliseconds(1)) - .Select((_, index) => index) - .Take(count); + var channel = Channel.CreateUnbounded(); + + Task.Run(async () => + { + for (var i = 0; i < count; i++) + { + await channel.Writer.WriteAsync(i); + await Task.Delay(100); + } + + channel.Writer.TryComplete(); + }); + + return channel.Reader; } public static ChannelReader StreamException() diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs index 1bb27f1637..a94b8d0619 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubConnectionHandlerTests.cs @@ -85,212 +85,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - [Fact] - public async Task ObservableHubRemovesSubscriptionsWithInfiniteStreams() - { - var observable = new Observable(); - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(s => s.AddSingleton(observable)); - var connectionHandler = serviceProvider.GetService>(); - - var waitForSubscribe = new TaskCompletionSource(); - observable.OnSubscribe = o => - { - waitForSubscribe.TrySetResult(null); - }; - - var waitForDispose = new TaskCompletionSource(); - observable.OnDispose = o => - { - waitForDispose.TrySetResult(null); - }; - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - - async Task Produce() - { - var i = 0; - while (true) - { - observable.OnNext(i++); - await Task.Delay(100); - } - } - - _ = Produce(); - - Assert.Empty(observable.Observers); - - var subscribeTask = client.StreamAsync(nameof(ObservableHub.Subscribe)); - - await waitForSubscribe.Task.OrTimeout(); - - Assert.Single(observable.Observers); - - client.Dispose(); - - // We don't care if this throws, we just expect it to complete - try - { - await subscribeTask.OrTimeout(); - } - catch - { - - } - - await waitForDispose.Task.OrTimeout(); - - Assert.Empty(observable.Observers); - - await connectionHandlerTask.OrTimeout(); - } - } - - [Fact] - public async Task OberserverDoesntThrowWhenOnNextIsCalledAfterChannelIsCompleted() - { - var observable = new Observable(); - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(s => s.AddSingleton(observable)); - var connectionHandler = serviceProvider.GetService>(); - - var waitForSubscribe = new TaskCompletionSource(); - observable.OnSubscribe = o => - { - waitForSubscribe.TrySetResult(null); - }; - - var waitForDispose = new TaskCompletionSource(); - observable.OnDispose = o => - { - waitForDispose.TrySetResult(null); - }; - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - - var subscribeTask = client.StreamAsync(nameof(ObservableHub.Subscribe)); - - await waitForSubscribe.Task.OrTimeout(); - - Assert.Single(observable.Observers); - - // Disposing the client to complete the observer. Further calls to OnNext should no-op - client.Dispose(); - - // Calling OnNext after the client has disconnected shouldn't throw. - observable.OnNext(1); - - await waitForDispose.Task.OrTimeout(); - - Assert.Empty(observable.Observers); - - await connectionHandlerTask.OrTimeout(); - } - } - - [Fact] - public async Task ObservableHubRemovesSubscriptions() - { - var observable = new Observable(); - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(s => s.AddSingleton(observable)); - var connectionHandler = serviceProvider.GetService>(); - - var waitForSubscribe = new TaskCompletionSource(); - observable.OnSubscribe = o => - { - waitForSubscribe.TrySetResult(null); - }; - - var waitForDispose = new TaskCompletionSource(); - observable.OnDispose = o => - { - waitForDispose.TrySetResult(null); - }; - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - - async Task Subscribe() - { - var results = await client.StreamAsync(nameof(ObservableHub.Subscribe)); - - var items = results.OfType().ToList(); - - Assert.Single(items); - Assert.Equal(2, (long)items[0].Item); - } - - observable.OnNext(1); - - Assert.Empty(observable.Observers); - - var subscribeTask = Subscribe(); - - await waitForSubscribe.Task.OrTimeout(); - - Assert.Single(observable.Observers); - - observable.OnNext(2); - - observable.Complete(); - - await subscribeTask.OrTimeout(); - - client.Dispose(); - - await waitForDispose.Task.OrTimeout(); - - Assert.Empty(observable.Observers); - - await connectionHandlerTask.OrTimeout(); - } - } - - [Fact] - public async Task ObservableHubRemovesSubscriptionWhenCanceledFromClient() - { - var observable = new Observable(); - var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(s => s.AddSingleton(observable)); - var connectionHandler = serviceProvider.GetService>(); - - var waitForSubscribe = new TaskCompletionSource(); - observable.OnSubscribe = o => - { - waitForSubscribe.TrySetResult(null); - }; - - var waitForDispose = new TaskCompletionSource(); - observable.OnDispose = o => - { - waitForDispose.TrySetResult(null); - }; - - using (var client = new TestClient()) - { - var connectionHandlerTask = await client.ConnectAsync(connectionHandler); - - var invocationId = await client.SendStreamInvocationAsync(nameof(ObservableHub.Subscribe)).OrTimeout(); - - await waitForSubscribe.Task.OrTimeout(); - - await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout(); - - await waitForDispose.Task.OrTimeout(); - - var message = await client.ReadAsync().OrTimeout(); - - Assert.IsType(message); - - client.Dispose(); - - await connectionHandlerTask.OrTimeout(); - } - } - [Fact] public async Task MissingHandshakeAndMessageSentFromHubConnectionCanBeDisposedCleanly() { @@ -1677,7 +1471,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.NotNull(completion); if (detailedErrors) { - Assert.Equal("An error occurred on the server while streaming results. Exception: Exception from observable", completion.Error); + Assert.Equal("An error occurred on the server while streaming results. Exception: Exception from channel", completion.Error); } else { @@ -1733,8 +1527,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { foreach (var method in new[] { - nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync), - nameof(StreamingHub.CounterObservable), nameof(StreamingHub.CounterObservableAsync), nameof(StreamingHub.CounterObservableValueTaskAsync) + nameof(StreamingHub.CounterChannel), nameof(StreamingHub.CounterChannelAsync), nameof(StreamingHub.CounterChannelValueTaskAsync) }) { foreach (var protocolName in HubProtocolHelpers.AllProtocolNames) diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs index dadea91d70..8782293e8c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTestUtils/Hubs.cs @@ -421,18 +421,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - public class ObservableHub : Hub - { - private readonly Observable _numbers; - - public ObservableHub(Observable numbers) - { - _numbers = numbers; - } - - public IObservable Subscribe() => _numbers; - } - public class AbortHub : Hub { public void Kill() @@ -441,89 +429,8 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - public class Observable : IObservable - { - public List> Observers = new List>(); - - public Action> OnSubscribe; - - public Action> OnDispose; - - public IDisposable Subscribe(IObserver observer) - { - lock (Observers) - { - Observers.Add(observer); - } - - OnSubscribe?.Invoke(observer); - - return new DisposableAction(() => - { - lock (Observers) - { - Observers.Remove(observer); - } - - OnDispose?.Invoke(observer); - }); - } - - public void OnNext(T value) - { - lock (Observers) - { - foreach (var observer in Observers) - { - observer.OnNext(value); - } - } - } - - public void Complete() - { - lock (Observers) - { - foreach (var observer in Observers) - { - observer.OnCompleted(); - } - } - } - - private class DisposableAction : IDisposable - { - private readonly Action _action; - public DisposableAction(Action action) - { - _action = action; - } - - public void Dispose() - { - _action(); - } - } - } - public class StreamingHub : TestHub { - public IObservable CounterObservable(int count) - { - return new CountingObservable(count); - } - - public async Task> CounterObservableAsync(int count) - { - await Task.Yield(); - return CounterObservable(count); - } - - public async ValueTask> CounterObservableValueTaskAsync(int count) - { - await Task.Yield(); - return CounterObservable(count); - } public ChannelReader CounterChannel(int count) { @@ -558,34 +465,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests return Channel.CreateUnbounded().Reader; } - public IObservable ThrowStream() + public ChannelReader ThrowStream() { - return Observable.Throw(new Exception("Exception from observable")); - } - - private class CountingObservable : IObservable - { - private int _count; - - public CountingObservable(int count) - { - _count = count; - } - - public IDisposable Subscribe(IObserver observer) - { - var cts = new CancellationTokenSource(); - Task.Run(() => - { - for (int i = 0; !cts.Token.IsCancellationRequested && i < _count; i++) - { - observer.OnNext(i.ToString()); - } - observer.OnCompleted(); - }); - - return new CancellationDisposable(cts); - } + var channel = Channel.CreateUnbounded(); + channel.Writer.TryComplete(new Exception("Exception from channel")); + return channel.Reader; } }