Resolve deadlock with `InvokeAsync` in `On` handler (#8334)

- Use a channel to unblock the dispatch loop
- Added tests
This commit is contained in:
Mikael Mengistu 2019-03-14 22:04:36 -07:00 committed by David Fowler
parent a673be3b9a
commit cfe0cc38ec
2 changed files with 59 additions and 5 deletions

View File

@ -688,7 +688,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
private async Task<(bool close, Exception exception)> ProcessMessagesAsync(HubMessage message, ConnectionState connectionState)
private async Task<(bool close, Exception exception)> ProcessMessagesAsync(HubMessage message, ConnectionState connectionState, ChannelWriter<InvocationMessage> invocationMessageWriter)
{
Log.ResettingKeepAliveTimer(_logger);
ResetTimeout();
@ -703,7 +703,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
break;
case InvocationMessage invocation:
Log.ReceivedInvocation(_logger, invocation.InvocationId, invocation.Target, invocation.Arguments);
await DispatchInvocationAsync(invocation);
await invocationMessageWriter.WriteAsync(invocation);
break;
case CompletionMessage completion:
if (!connectionState.TryRemoveInvocation(completion.InvocationId, out irq))
@ -903,6 +903,19 @@ namespace Microsoft.AspNetCore.SignalR.Client
var uploadStreamSource = new CancellationTokenSource();
_uploadStreamToken = uploadStreamSource.Token;
var invocationMessageChannel = Channel.CreateUnbounded<InvocationMessage>();
var invocationMessageReceiveTask = StartProcessingInvocationMessages(invocationMessageChannel.Reader);
async Task StartProcessingInvocationMessages(ChannelReader<InvocationMessage> invocationMessageChannelReader)
{
while (await invocationMessageChannelReader.WaitToReadAsync())
{
while (invocationMessageChannelReader.TryRead(out var invocationMessage))
{
await DispatchInvocationAsync(invocationMessage);
}
}
}
try
{
@ -929,7 +942,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
Exception exception;
// We have data, process it
(close, exception) = await ProcessMessagesAsync(message, connectionState);
(close, exception) = await ProcessMessagesAsync(message, connectionState, invocationMessageChannel.Writer);
if (close)
{
// Closing because we got a close frame, possibly with an error in it.
@ -970,6 +983,8 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
finally
{
invocationMessageChannel.Writer.TryComplete();
await invocationMessageReceiveTask;
timer.Stop();
uploadStreamSource.Cancel();
}
@ -1233,8 +1248,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
// Represents all the transient state about a connection
// This includes binding information because return type binding depends upon _pendingCalls
//TODO: Refactor all transient state about the connection into the ConnectionState class.
private class ConnectionState : IInvocationBinder
{
private volatile bool _stopping;

View File

@ -247,6 +247,46 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]
public async Task CanInvokeFromOnHandler(string protocolName, HttpTransportType transportType, string path)
{
var protocol = HubProtocols[protocolName];
using (StartServer<Startup>(out var server))
{
const string originalMessage = "SignalR";
var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory);
try
{
await connection.StartAsync().OrTimeout();
var helloWorldTcs = new TaskCompletionSource<string>();
var echoTcs = new TaskCompletionSource<string>();
connection.On<string>("Echo", async (message) =>
{
echoTcs.SetResult(message);
helloWorldTcs.SetResult(await connection.InvokeAsync<string>(nameof(TestHub.HelloWorld)).OrTimeout());
});
await connection.InvokeAsync("CallEcho", originalMessage).OrTimeout();
Assert.Equal(originalMessage, await echoTcs.Task.OrTimeout());
Assert.Equal("Hello World!", await helloWorldTcs.Task.OrTimeout());
}
catch (Exception ex)
{
LoggerFactory.CreateLogger<HubConnectionTests>().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName);
throw;
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
[LogLevel(LogLevel.Trace)]