From cc42b0eaef121b12bf19bb7504e547dd6bbb2983 Mon Sep 17 00:00:00 2001 From: Pawel Kadluczka Date: Wed, 25 Oct 2017 13:51:43 -0700 Subject: [PATCH] Fixing a bug where cancellation could result in HubException When the client cancels a streaming method the server would send an error completion. This was not correct because cancellation is not an error. We did not see this because our client ignores any messages for a given streaming invocation after sending a CancelInvokationMessage but other clients may want to drain messages before considering a streaming method canceled. --- .../HubEndPoint.cs | 18 ++++++++--- .../HubEndpointTests.cs | 32 ++++++++++++++++++- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs index e55d31dc57..e3ce3f1b7d 100644 --- a/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR.Core/HubEndPoint.cs @@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR IHubContext hubContext, IOptions hubOptions, ILogger> logger, - IServiceScopeFactory serviceScopeFactory, + IServiceScopeFactory serviceScopeFactory, IUserIdProvider userIdProvider) { _protocolResolver = protocolResolver; @@ -282,7 +282,8 @@ namespace Microsoft.AspNetCore.SignalR case CancelInvocationMessage cancelInvocationMessage: // Check if there is an associated active stream and cancel it if it exists. - if (connection.ActiveRequestCancellationSources.TryRemove(cancelInvocationMessage.InvocationId, out var cts)) + // The cts will be removed when the streaming method completes executing + if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts)) { _logger.CancelStream(cancelInvocationMessage.InvocationId); cts.Cancel(); @@ -464,6 +465,8 @@ namespace Microsoft.AspNetCore.SignalR private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator enumerator) { + string error = null; + try { while (await enumerator.MoveNextAsync()) @@ -471,15 +474,20 @@ namespace Microsoft.AspNetCore.SignalR // Send the stream item await SendMessageAsync(connection, new StreamItemMessage(invocationId, enumerator.Current)); } - - await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: null)); } catch (Exception ex) { - await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: ex.Message)); + // If the streaming method was canceled we don't want to send a HubException message - this is not an error case + if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts) + && cts.IsCancellationRequested)) + { + error = ex.Message; + } } finally { + await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: error)); + if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts)) { cts.Dispose(); diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index 02972620c1..188a9511d0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -132,7 +132,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests client.Dispose(); - // We don't care if this throws, we just expect it to complete try { @@ -1024,6 +1023,32 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Fact] + public async Task NonErrorCompletionSentWhenStreamCanceledFromClient() + { + var serviceProvider = CreateServiceProvider(); + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient()) + { + var endPointLifetime = endPoint.OnConnectedAsync(client.Connection); + + await client.Connected.OrTimeout(); + + var invocationId = await client.SendInvocationAsync(nameof(StreamingHub.BlockingStream)).OrTimeout(); + // cancel the Streaming method + await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout(); + + var hubMessage = Assert.IsType(await client.ReadAsync().OrTimeout()); + Assert.Equal(invocationId, hubMessage.InvocationId); + Assert.Null(hubMessage.Error); + + client.Dispose(); + + await endPointLifetime.OrTimeout(); + } + } + public static IEnumerable StreamingMethodAndHubProtocols { get @@ -1587,6 +1612,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests return channel.In; } + public ReadableChannel BlockingStream() + { + return Channel.CreateUnbounded().In; + } + private class CountingObservable : IObservable { private int _count;