fix #1815 by disposing linked cts (#1849)

This commit is contained in:
Andrew Stanton-Nurse 2018-04-04 21:12:21 -07:00 committed by GitHub
parent 6b76d1355e
commit fccc9d1b50
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 44 additions and 27 deletions

View File

@ -166,7 +166,7 @@ describe("hubConnection", () => {
// exception expected but none thrown
fail();
}).catch((e) => {
expect(e.message).toBe("The client attempted to invoke the streaming 'EmptyStream' method in a non-streaming fashion.");
expect(e.message).toBe("The client attempted to invoke the streaming 'EmptyStream' method with a non-streaming invocation.");
}).then(() => {
return hubConnection.stop();
}).then(() => {
@ -190,7 +190,7 @@ describe("hubConnection", () => {
// exception expected but none thrown
fail();
}).catch((e) => {
expect(e.message).toBe("The client attempted to invoke the streaming 'Stream' method in a non-streaming fashion.");
expect(e.message).toBe("The client attempted to invoke the streaming 'Stream' method with a non-streaming invocation.");
}).then(() => {
return hubConnection.stop();
}).then(() => {
@ -246,7 +246,7 @@ describe("hubConnection", () => {
fail();
},
error: function error(err) {
expect(err.message).toEqual("The client attempted to invoke the non-streaming 'Echo' method in a streaming fashion.");
expect(err.message).toEqual("The client attempted to invoke the non-streaming 'Echo' method with a streaming invocation.");
hubConnection.stop();
done();
},

View File

@ -2,11 +2,9 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using System.Threading.Channels;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.SignalR.Internal
{
@ -23,7 +21,8 @@ namespace Microsoft.AspNetCore.SignalR.Internal
// Dispose the subscription when the token is cancelled
cancellationToken.Register(state => ((IDisposable)state).Dispose(), subscription);
return GetAsyncEnumerator(channel.Reader, cancellationToken);
// Make sure the subscription is disposed when enumeration is completed.
return new AsyncEnumerator<object>(channel.Reader, cancellationToken, subscription);
}
private class ChannelObserver<T> : IObserver<T>
@ -75,11 +74,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal
public static IAsyncEnumerator<object> GetAsyncEnumerator<T>(ChannelReader<T> channel, CancellationToken cancellationToken = default(CancellationToken))
{
return new AsyncEnumerator<T>(channel, cancellationToken);
// Nothing to dispose when we finish enumerating in this case.
return new AsyncEnumerator<T>(channel, cancellationToken, disposable: null);
}
/// <summary>Provides an async enumerator for the data in a channel.</summary>
internal class AsyncEnumerator<T> : IAsyncEnumerator<object>
internal class AsyncEnumerator<T> : IAsyncEnumerator<object>, IDisposable
{
/// <summary>The channel being enumerated.</summary>
private readonly ChannelReader<T> _channel;
@ -88,10 +88,13 @@ namespace Microsoft.AspNetCore.SignalR.Internal
/// <summary>The current element of the enumeration.</summary>
private object _current;
internal AsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken)
private readonly IDisposable _disposable;
internal AsyncEnumerator(ChannelReader<T> channel, CancellationToken cancellationToken, IDisposable disposable)
{
_channel = channel;
_cancellationToken = cancellationToken;
_disposable = disposable;
}
public object Current => _current;
@ -117,6 +120,11 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}, this, CancellationToken.None, TaskContinuationOptions.ExecuteSynchronously | TaskContinuationOptions.NotOnCanceled, TaskScheduler.Default);
}
public void Dispose()
{
_disposable?.Dispose();
}
}
}

View File

@ -49,10 +49,10 @@ namespace Microsoft.AspNetCore.SignalR.Internal
LoggerMessage.Define<StreamInvocationMessage>(LogLevel.Debug, new EventId(12, "ReceivedStreamHubInvocation"), "Received stream hub invocation: {InvocationMessage}.");
private static readonly Action<ILogger, HubMethodInvocationMessage, Exception> _streamingMethodCalledWithInvoke =
LoggerMessage.Define<HubMethodInvocationMessage>(LogLevel.Error, new EventId(13, "StreamingMethodCalledWithInvoke"), "A streaming method was invoked in the non-streaming fashion : {InvocationMessage}.");
LoggerMessage.Define<HubMethodInvocationMessage>(LogLevel.Error, new EventId(13, "StreamingMethodCalledWithInvoke"), "A streaming method was invoked with a non-streaming invocation : {InvocationMessage}.");
private static readonly Action<ILogger, HubMethodInvocationMessage, Exception> _nonStreamingMethodCalledWithStream =
LoggerMessage.Define<HubMethodInvocationMessage>(LogLevel.Error, new EventId(14, "NonStreamingMethodCalledWithStream"), "A non-streaming method was invoked in the streaming fashion : {InvocationMessage}.");
LoggerMessage.Define<HubMethodInvocationMessage>(LogLevel.Error, new EventId(14, "NonStreamingMethodCalledWithStream"), "A non-streaming method was invoked with a streaming invocation : {InvocationMessage}.");
private static readonly Action<ILogger, string, Exception> _invalidReturnValueFromStreamingMethod =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(15, "InvalidReturnValueFromStreamingMethod"), "A streaming method returned a value that cannot be used to build enumerator {HubMethod}.");

View File

@ -193,7 +193,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
if (isStreamedInvocation)
{
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator))
if (!TryGetStreamingEnumerator(connection, hubMethodInvocationMessage.InvocationId, descriptor, result, out var enumerator, out var streamCts))
{
Log.InvalidReturnValueFromStreamingMethod(_logger, methodExecutor.MethodInfo.Name);
@ -203,7 +203,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
Log.StreamingResult(_logger, hubMethodInvocationMessage.InvocationId, methodExecutor);
await StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator);
await StreamResultsAsync(hubMethodInvocationMessage.InvocationId, connection, enumerator, streamCts);
}
// Non-empty/null InvocationId ==> Blocking invocation that needs a response
else if (!string.IsNullOrEmpty(hubMethodInvocationMessage.InvocationId))
@ -231,7 +231,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
}
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator)
private async Task StreamResultsAsync(string invocationId, HubConnectionContext connection, IAsyncEnumerator<object> enumerator, CancellationTokenSource streamCts)
{
string error = null;
@ -259,7 +259,12 @@ namespace Microsoft.AspNetCore.SignalR.Internal
}
finally
{
await connection.WriteAsync(new CompletionMessage(invocationId, error: error, result: null, hasResult: false));
(enumerator as IDisposable)?.Dispose();
// Dispose the linked CTS for the stream.
streamCts.Dispose();
await connection.WriteAsync(CompletionMessage.WithError(invocationId, error));
if (connection.ActiveRequestCancellationSources.TryRemove(invocationId, out var cts))
{
@ -337,7 +342,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
Log.StreamingMethodCalledWithInvoke(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
$"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method in a non-streaming fashion."));
$"The client attempted to invoke the streaming '{hubMethodInvocationMessage.Target}' method with a non-streaming invocation."));
}
return false;
@ -347,7 +352,7 @@ namespace Microsoft.AspNetCore.SignalR.Internal
{
Log.NonStreamingMethodCalledWithStream(_logger, hubMethodInvocationMessage);
await connection.WriteAsync(CompletionMessage.WithError(hubMethodInvocationMessage.InvocationId,
$"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method in a streaming fashion."));
$"The client attempted to invoke the non-streaming '{hubMethodInvocationMessage.Target}' method with a streaming invocation."));
return false;
}
@ -355,31 +360,35 @@ namespace Microsoft.AspNetCore.SignalR.Internal
return true;
}
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator)
private bool TryGetStreamingEnumerator(HubConnectionContext connection, string invocationId, HubMethodDescriptor hubMethodDescriptor, object result, out IAsyncEnumerator<object> enumerator, out CancellationTokenSource streamCts)
{
if (result != null)
{
if (hubMethodDescriptor.IsObservable)
{
enumerator = hubMethodDescriptor.FromObservable(result, CreateCancellation());
streamCts = CreateCancellation();
enumerator = hubMethodDescriptor.FromObservable(result, streamCts.Token);
return true;
}
if (hubMethodDescriptor.IsChannel)
{
enumerator = hubMethodDescriptor.FromChannel(result, CreateCancellation());
streamCts = CreateCancellation();
enumerator = hubMethodDescriptor.FromChannel(result, streamCts.Token);
return true;
}
}
streamCts = null;
enumerator = null;
return false;
CancellationToken CreateCancellation()
CancellationTokenSource CreateCancellation()
{
var streamCts = new CancellationTokenSource();
connection.ActiveRequestCancellationSources.TryAdd(invocationId, streamCts);
return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, streamCts.Token).Token;
var userCts = new CancellationTokenSource();
connection.ActiveRequestCancellationSources.TryAdd(invocationId, userCts);
return CancellationTokenSource.CreateLinkedTokenSource(connection.ConnectionAborted, userCts.Token);
}
}

View File

@ -618,7 +618,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
await connection.StartAsync().OrTimeout();
var channel = await connection.StreamAsChannelAsync<int>("HelloWorld").OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => channel.ReadAllAsync()).OrTimeout();
Assert.Equal("The client attempted to invoke the non-streaming 'HelloWorld' method in a streaming fashion.", ex.Message);
Assert.Equal("The client attempted to invoke the non-streaming 'HelloWorld' method with a streaming invocation.", ex.Message);
}
catch (Exception ex)
{
@ -645,7 +645,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
await connection.StartAsync().OrTimeout();
var ex = await Assert.ThrowsAsync<HubException>(() => connection.InvokeAsync("Stream", 3)).OrTimeout();
Assert.Equal("The client attempted to invoke the streaming 'Stream' method in a non-streaming fashion.", ex.Message);
Assert.Equal("The client attempted to invoke the streaming 'Stream' method with a non-streaming invocation.", ex.Message);
}
catch (Exception ex)
{