diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index ee0059f6f8..c981542c22 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -688,7 +688,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private async Task<(bool close, Exception exception)> ProcessMessagesAsync(HubMessage message, ConnectionState connectionState) + private async Task<(bool close, Exception exception)> ProcessMessagesAsync(HubMessage message, ConnectionState connectionState, ChannelWriter invocationMessageWriter) { Log.ResettingKeepAliveTimer(_logger); ResetTimeout(); @@ -703,7 +703,7 @@ namespace Microsoft.AspNetCore.SignalR.Client break; case InvocationMessage invocation: Log.ReceivedInvocation(_logger, invocation.InvocationId, invocation.Target, invocation.Arguments); - await DispatchInvocationAsync(invocation); + await invocationMessageWriter.WriteAsync(invocation); break; case CompletionMessage completion: if (!connectionState.TryRemoveInvocation(completion.InvocationId, out irq)) @@ -903,6 +903,19 @@ namespace Microsoft.AspNetCore.SignalR.Client var uploadStreamSource = new CancellationTokenSource(); _uploadStreamToken = uploadStreamSource.Token; + var invocationMessageChannel = Channel.CreateUnbounded(); + var invocationMessageReceiveTask = StartProcessingInvocationMessages(invocationMessageChannel.Reader); + + async Task StartProcessingInvocationMessages(ChannelReader invocationMessageChannelReader) + { + while (await invocationMessageChannelReader.WaitToReadAsync()) + { + while (invocationMessageChannelReader.TryRead(out var invocationMessage)) + { + await DispatchInvocationAsync(invocationMessage); + } + } + } try { @@ -929,7 +942,7 @@ namespace Microsoft.AspNetCore.SignalR.Client Exception exception; // We have data, process it - (close, exception) = await ProcessMessagesAsync(message, connectionState); + (close, exception) = await ProcessMessagesAsync(message, connectionState, invocationMessageChannel.Writer); if (close) { // Closing because we got a close frame, possibly with an error in it. @@ -970,6 +983,8 @@ namespace Microsoft.AspNetCore.SignalR.Client } finally { + invocationMessageChannel.Writer.TryComplete(); + await invocationMessageReceiveTask; timer.Stop(); uploadStreamSource.Cancel(); } @@ -1233,8 +1248,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - // Represents all the transient state about a connection - // This includes binding information because return type binding depends upon _pendingCalls + //TODO: Refactor all transient state about the connection into the ConnectionState class. private class ConnectionState : IInvocationBinder { private volatile bool _stopping; diff --git a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs index 5e811aa36a..1d403d4609 100644 --- a/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/FunctionalTests/HubConnectionTests.cs @@ -247,6 +247,46 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + [LogLevel(LogLevel.Trace)] + public async Task CanInvokeFromOnHandler(string protocolName, HttpTransportType transportType, string path) + { + var protocol = HubProtocols[protocolName]; + using (StartServer(out var server)) + { + const string originalMessage = "SignalR"; + + var connection = CreateHubConnection(server.Url, path, transportType, protocol, LoggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + + var helloWorldTcs = new TaskCompletionSource(); + var echoTcs = new TaskCompletionSource(); + connection.On("Echo", async (message) => + { + echoTcs.SetResult(message); + helloWorldTcs.SetResult(await connection.InvokeAsync(nameof(TestHub.HelloWorld)).OrTimeout()); + }); + + await connection.InvokeAsync("CallEcho", originalMessage).OrTimeout(); + + Assert.Equal(originalMessage, await echoTcs.Task.OrTimeout()); + Assert.Equal("Hello World!", await helloWorldTcs.Task.OrTimeout()); + } + catch (Exception ex) + { + LoggerFactory.CreateLogger().LogError(ex, "{ExceptionType} from test", ex.GetType().FullName); + throw; + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] [LogLevel(LogLevel.Trace)]