Fixing a bug where cancellation could result in HubException
When the client cancels a streaming method the server would send an error completion. This was not correct because cancellation is not an error. We did not see this because our client ignores any messages for a given streaming invocation after sending a CancelInvokationMessage but other clients may want to drain messages before considering a streaming method canceled.
This commit is contained in:
parent
8c446fc02d
commit
cc42b0eaef
|
|
@ -46,7 +46,7 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
IHubContext<THub> hubContext,
|
||||
IOptions<HubOptions> hubOptions,
|
||||
ILogger<HubEndPoint<THub>> logger,
|
||||
IServiceScopeFactory serviceScopeFactory,
|
||||
IServiceScopeFactory serviceScopeFactory,
|
||||
IUserIdProvider userIdProvider)
|
||||
{
|
||||
_protocolResolver = protocolResolver;
|
||||
|
|
@ -282,7 +282,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
case CancelInvocationMessage cancelInvocationMessage:
|
||||
// Check if there is an associated active stream and cancel it if it exists.
|
||||
if (connection.ActiveRequestCancellationSources.TryRemove(cancelInvocationMessage.InvocationId, out var cts))
|
||||
// The cts will be removed when the streaming method completes executing
|
||||
if (connection.ActiveRequestCancellationSources.TryGetValue(cancelInvocationMessage.InvocationId, out var cts))
|
||||
{
|
||||
_logger.CancelStream(cancelInvocationMessage.InvocationId);
|
||||
cts.Cancel();
|
||||
|
|
@ -464,6 +465,8 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
|
||||
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator)
|
||||
{
|
||||
string error = null;
|
||||
|
||||
try
|
||||
{
|
||||
while (await enumerator.MoveNextAsync())
|
||||
|
|
@ -471,15 +474,20 @@ namespace Microsoft.AspNetCore.SignalR
|
|||
// Send the stream item
|
||||
await SendMessageAsync(connection, new StreamItemMessage(invocationId, enumerator.Current));
|
||||
}
|
||||
|
||||
await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: null));
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: ex.Message));
|
||||
// If the streaming method was canceled we don't want to send a HubException message - this is not an error case
|
||||
if (!(ex is OperationCanceledException && connection.ActiveRequestCancellationSources.TryGetValue(invocationId, out var cts)
|
||||
&& cts.IsCancellationRequested))
|
||||
{
|
||||
error = ex.Message;
|
||||
}
|
||||
}
|
||||
finally
|
||||
{
|
||||
await SendMessageAsync(connection, new StreamCompletionMessage(invocationId, error: error));
|
||||
|
||||
if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts))
|
||||
{
|
||||
cts.Dispose();
|
||||
|
|
|
|||
|
|
@ -132,7 +132,6 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
|
||||
client.Dispose();
|
||||
|
||||
|
||||
// We don't care if this throws, we just expect it to complete
|
||||
try
|
||||
{
|
||||
|
|
@ -1024,6 +1023,32 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
}
|
||||
}
|
||||
|
||||
[Fact]
|
||||
public async Task NonErrorCompletionSentWhenStreamCanceledFromClient()
|
||||
{
|
||||
var serviceProvider = CreateServiceProvider();
|
||||
var endPoint = serviceProvider.GetService<HubEndPoint<StreamingHub>>();
|
||||
|
||||
using (var client = new TestClient())
|
||||
{
|
||||
var endPointLifetime = endPoint.OnConnectedAsync(client.Connection);
|
||||
|
||||
await client.Connected.OrTimeout();
|
||||
|
||||
var invocationId = await client.SendInvocationAsync(nameof(StreamingHub.BlockingStream)).OrTimeout();
|
||||
// cancel the Streaming method
|
||||
await client.SendHubMessageAsync(new CancelInvocationMessage(invocationId)).OrTimeout();
|
||||
|
||||
var hubMessage = Assert.IsType<StreamCompletionMessage>(await client.ReadAsync().OrTimeout());
|
||||
Assert.Equal(invocationId, hubMessage.InvocationId);
|
||||
Assert.Null(hubMessage.Error);
|
||||
|
||||
client.Dispose();
|
||||
|
||||
await endPointLifetime.OrTimeout();
|
||||
}
|
||||
}
|
||||
|
||||
public static IEnumerable<object[]> StreamingMethodAndHubProtocols
|
||||
{
|
||||
get
|
||||
|
|
@ -1587,6 +1612,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
|
|||
return channel.In;
|
||||
}
|
||||
|
||||
public ReadableChannel<string> BlockingStream()
|
||||
{
|
||||
return Channel.CreateUnbounded<string>().In;
|
||||
}
|
||||
|
||||
private class CountingObservable : IObservable<string>
|
||||
{
|
||||
private int _count;
|
||||
|
|
|
|||
Loading…
Reference in New Issue