diff --git a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs index 7f8706bdd1..fe65bf13c6 100644 --- a/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs +++ b/src/SignalR/server/Core/src/Internal/DefaultHubDispatcher.cs @@ -272,7 +272,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal } Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor); - _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts); + _ = StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, scope, hubActivator, hub, cts, hubMethodInvocationMessage); } else if (string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId)) @@ -304,17 +304,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal // And normal invocations handle cleanup below in the finally if (isStreamCall) { - hubActivator?.Release(hub); - scope.Dispose(); - foreach (var stream in hubMethodInvocationMessage.StreamIds) - { - try - { - connection.StreamTracker.Complete(CompletionMessage.Empty(stream)); - } - // ignore failures, it means the client already completed the streams - catch { } - } + CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); } } @@ -352,56 +342,72 @@ namespace Microsoft.AspNetCore.SignalR.Internal { if (disposeScope) { - hubActivator?.Release(hub); - scope.Dispose(); + CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); + } + } + } + + private void CleanupInvocation(HubConnectionContext connection, HubMethodInvocationMessage hubMessage, IHubActivator hubActivator, + THub hub, IServiceScope scope) + { + hubActivator?.Release(hub); + scope.Dispose(); + + if (hubMessage.StreamIds != null) + { + foreach (var stream in hubMessage.StreamIds) + { + try + { + connection.StreamTracker.Complete(CompletionMessage.Empty(stream)); + } + // ignore failures, it means the client already completed the streams + catch (KeyNotFoundException) { } } } } private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator, IServiceScope scope, - IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts) + IHubActivator hubActivator, THub hub, CancellationTokenSource streamCts, HubMethodInvocationMessage hubMethodInvocationMessage) { string error = null; - using (scope) + try { - try + while (await enumerator.MoveNextAsync()) { - while (await enumerator.MoveNextAsync()) - { - // Send the stream item - await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current)); - } + // Send the stream item + await connection.WriteAsync(new StreamItemMessage(invocationId, enumerator.Current)); } - catch (ChannelClosedException ex) + } + catch (ChannelClosedException ex) + { + // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method + error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex.InnerException ?? ex, _enableDetailedErrors); + } + catch (Exception ex) + { + // If the streaming method was canceled we don't want to send a HubException message - this is not an error case + if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts) + && cts.IsCancellationRequested)) { - // If the channel closes from an exception in the streaming method, grab the innerException for the error from the streaming method - error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex.InnerException ?? ex, _enableDetailedErrors); + error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors); } - catch (Exception ex) + } + finally + { + (enumerator as IDisposable)?.Dispose(); + + CleanupInvocation(connection, hubMethodInvocationMessage, hubActivator, hub, scope); + + // Dispose the linked CTS for the stream. + streamCts.Dispose(); + + await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); + + if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) { - // If the streaming method was canceled we don't want to send a HubException message - this is not an error case - if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts) - && cts.IsCancellationRequested)) - { - error = ErrorMessageHelper.BuildErrorMessage("An error occurred on the server while streaming results.", ex, _enableDetailedErrors); - } - } - finally - { - (enumerator as IDisposable)?.Dispose(); - - hubActivator.Release(hub); - - // Dispose the linked CTS for the stream. - streamCts.Dispose(); - - await connection.WriteAsync(CompletionMessage.WithError(invocationId, error)); - - if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) - { - cts.Dispose(); - } + cts.Dispose(); } } } diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs index fb0c3ee00c..c4988c378d 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTestUtils/Hubs.cs @@ -258,6 +258,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests // Wait for an item to appear first then return from the hub method to end the invocation return source.WaitToReadAsync().AsTask(); } + + public ChannelReader StreamAndUploadIgnoreItems(ChannelReader source) + { + var channel = Channel.CreateUnbounded(); + _ = ChannelFunc(channel.Writer, source); + + return channel.Reader; + + async Task ChannelFunc(ChannelWriter output, ChannelReader input) + { + // Wait for an item to appear first then return from the hub method to end the invocation + await input.WaitToReadAsync(); + output.Complete(); + } + } } public abstract class TestHub : Hub diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index 35741b02ca..14b2583314 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -3101,6 +3101,53 @@ namespace Microsoft.AspNetCore.SignalR.Tests Assert.True(errorLogged); } + [Fact] + public async Task UploadStreamAndStreamingMethodClosesStreamsOnServerWhenMethodCompletes() + { + bool errorLogged = false; + bool ExpectedErrors(WriteContext writeContext) + { + if (writeContext.LoggerName == "Microsoft.AspNetCore.SignalR.HubConnectionHandler" && + writeContext.EventId.Name == "ErrorProcessingRequest") + { + errorLogged = true; + return true; + } + + return false; + } + + using (StartVerifiableLog(ExpectedErrors)) + { + var serviceProvider = HubConnectionHandlerTestUtils.CreateServiceProvider(loggerFactory: LoggerFactory); + var connectionHandler = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var connectionHandlerTask = await client.ConnectAsync(connectionHandler).OrTimeout(); + + await client.SendStreamInvocationAsync(nameof(MethodHub.StreamAndUploadIgnoreItems), streamIds: new[] { "id" }, args: Array.Empty()).OrTimeout(); + + await client.SendHubMessageAsync(new StreamItemMessage("id", "ignored")).OrTimeout(); + var result = await client.ReadAsync().OrTimeout(); + + var simpleCompletion = Assert.IsType(result); + Assert.Null(simpleCompletion.Result); + + // This will log an error on the server as the hub method has completed and will complete all associated streams + await client.SendHubMessageAsync(new StreamItemMessage("id", "error!")).OrTimeout(); + + // Shut down + client.Dispose(); + + await connectionHandlerTask.OrTimeout(); + } + } + + // Check that the stream has been completed by noting the existance of an error + Assert.True(errorLogged); + } + [Theory] [InlineData(nameof(LongRunningHub.CancelableStream))] [InlineData(nameof(LongRunningHub.CancelableStream2), 1, 2)]