diff --git a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts index 0b7fb0f391..6f6bcc8a4d 100644 --- a/clients/ts/FunctionalTests/ts/HubConnectionTests.ts +++ b/clients/ts/FunctionalTests/ts/HubConnectionTests.ts @@ -166,7 +166,7 @@ describe("hubConnection", () => { // exception expected but none thrown fail(); }).catch((e) => { - expect(e.message).toBe("The client attempted to invoke the streaming 'EmptyStream' method in a non-streaming fashion."); + expect(e.message).toBe("The client attempted to invoke the streaming 'EmptyStream' method with a non-streaming invocation."); }).then(() => { return hubConnection.stop(); }).then(() => { @@ -190,7 +190,7 @@ describe("hubConnection", () => { // exception expected but none thrown fail(); }).catch((e) => { - expect(e.message).toBe("The client attempted to invoke the streaming 'Stream' method in a non-streaming fashion."); + expect(e.message).toBe("The client attempted to invoke the streaming 'Stream' method with a non-streaming invocation."); }).then(() => { return hubConnection.stop(); }).then(() => { @@ -246,7 +246,7 @@ describe("hubConnection", () => { fail(); }, error: function error(err) { - expect(err.message).toEqual("The client attempted to invoke the non-streaming 'Echo' method in a streaming fashion."); + expect(err.message).toEqual("The client attempted to invoke the non-streaming 'Echo' method with a streaming invocation."); hubConnection.stop(); done(); }, diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs index 0ef0958d5a..966431e4df 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/AsyncEnumeratorAdapters.cs @@ -2,11 +2,9 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Linq; -using System.Reflection; using System.Threading; -using System.Threading.Tasks; using System.Threading.Channels; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.SignalR.Internal { @@ -23,7 +21,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal // Dispose the subscription when the token is cancelled cancellationToken.Register(state => ((IDisposable)state).Dispose(), subscription); - return GetAsyncEnumerator(channel.Reader, cancellationToken); + // Make sure the subscription is disposed when enumeration is completed. + return new AsyncEnumerator(channel.Reader, cancellationToken, subscription); } private class ChannelObserver : IObserver @@ -75,11 +74,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal public static IAsyncEnumerator GetAsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken = default(CancellationToken)) { - return new AsyncEnumerator(channel, cancellationToken); + // Nothing to dispose when we finish enumerating in this case. + return new AsyncEnumerator(channel, cancellationToken, disposable: null); } /// Provides an async enumerator for the data in a channel. - internal class AsyncEnumerator : IAsyncEnumerator + internal class AsyncEnumerator : IAsyncEnumerator, IDisposable { /// The channel being enumerated. private readonly ChannelReader _channel; @@ -88,10 +88,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal /// The current element of the enumeration. private object _current; - internal AsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken) + private readonly IDisposable _disposable; + + internal AsyncEnumerator(ChannelReader channel, CancellationToken cancellationToken, IDisposable disposable) { _channel = channel; _cancellationToken = cancellationToken; + _disposable = disposable; } public object Current => _current; @@ -117,6 +120,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; }, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default); } + + public void Dispose() + { + _disposable?.Dispose(); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs index c4aa4ec16c..18ededd7d1 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.Log.cs @@ -49,10 +49,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal LoggerMessage.Define(LogLevel.Debug, new EventId(12, "ReceivedStreamHubInvocation"), "Received stream hub invocation: {InvocationMessage}."); private static readonly Action _streamingMethodCalledWithInvoke = - LoggerMessage.Define(LogLevel.Error, new EventId(13, "StreamingMethodCalledWithInvoke"), "A streaming method was invoked in the non-streaming fashion : {InvocationMessage}."); + LoggerMessage.Define(LogLevel.Error, new EventId(13, "StreamingMethodCalledWithInvoke"), "A streaming method was invoked with a non-streaming invocation : {InvocationMessage}."); private static readonly Action _nonStreamingMethodCalledWithStream = - LoggerMessage.Define(LogLevel.Error, new EventId(14, "NonStreamingMethodCalledWithStream"), "A non-streaming method was invoked in the streaming fashion : {InvocationMessage}."); + LoggerMessage.Define(LogLevel.Error, new EventId(14, "NonStreamingMethodCalledWithStream"), "A non-streaming method was invoked with a streaming invocation : {InvocationMessage}."); private static readonly Action _invalidReturnValueFromStreamingMethod = LoggerMessage.Define(LogLevel.Error, new EventId(15, "InvalidReturnValueFromStreamingMethod"), "A streaming method returned a value that cannot be used to build enumerator {HubMethod}."); diff --git a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs index 1330ca1837..67c293420b 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/Internal/DefaultHubDispatcher.cs @@ -193,7 +193,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal if (isStreamedInvocation) { - if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator)) + if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts)) { Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name); @@ -203,7 +203,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal } Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - await StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator); + await StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, streamCts); } // Non-empty/null InvocationId ==> Blocking invocation that needs a response else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -231,7 +231,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal } } - private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator) + private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator, CancellationTokenSource streamCts) { string error = null; @@ -259,7 +259,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal } finally { - await connection.WriteAsync(new CompletionMessage(invocationId, error: error, result: null, hasResult: false)); + (enumerator as IDisposable)?.Dispose(); + + // Dispose the linked CTS for the stream. + streamCts.Dispose(); + + await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) { @@ -337,7 +342,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage); await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, - $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method in a non-streaming fashion.")); + $"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation.")); } return false; @@ -347,7 +352,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal { Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage); await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId, - $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method in a streaming fashion.")); + $"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation.")); return false; } @@ -355,31 +360,35 @@ namespace Microsoft.AspNetCore.SignalR.Internal return true; } - private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator enumerator) + private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator enumerator, out CancellationTokenSource streamCts) { if (result != null) { if (hubMethodDescriptor.IsObservable) { - enumerator = hubMethodDescriptor.FromObservable(result, CreateCancellation()); + streamCts = CreateCancellation(); + enumerator = hubMethodDescriptor.FromObservable(result, streamCts.Token); return true; } if (hubMethodDescriptor.IsChannel) { - enumerator = hubMethodDescriptor.FromChannel(result, CreateCancellation()); + streamCts = CreateCancellation(); + enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token); return true; } } + streamCts = null; enumerator = null; return false; - CancellationToken CreateCancellation() + CancellationTokenSource CreateCancellation() { - var streamCts = new CancellationTokenSource(); - connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts); - return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, streamCts.Token).Token; + var userCts = new CancellationTokenSource(); + connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts); + + return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index ee5208252f..dd31e19714 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -618,7 +618,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await connection.StartAsync().OrTimeout(); var channel = await connection.StreamAsChannelAsync("HelloWorld").OrTimeout(); var ex = await Assert.ThrowsAsync(() => channel.ReadAllAsync()).OrTimeout(); - Assert.Equal("The client attempted to invoke the non-streaming 'HelloWorld' method in a streaming fashion.", ex.Message); + Assert.Equal("The client attempted to invoke the non-streaming 'HelloWorld' method with a streaming invocation.", ex.Message); } catch (Exception ex) { @@ -645,7 +645,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await connection.StartAsync().OrTimeout(); var ex = await Assert.ThrowsAsync(() => connection.InvokeAsync("Stream", 3)).OrTimeout(); - Assert.Equal("The client attempted to invoke the streaming 'Stream' method in a non-streaming fashion.", ex.Message); + Assert.Equal("The client attempted to invoke the streaming 'Stream' method with a non-streaming invocation.", ex.Message); } catch (Exception ex) {