diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts index 9cf4e54c6e..e83e6a2cf0 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS.Tests/HubConnection.spec.ts @@ -21,6 +21,32 @@ describe("HubConnection", () => { }); }); + describe("send", () => { + it("sends a non blocking invocation", async () => { + let connection = new TestConnection(); + + let hubConnection = new HubConnection(connection); + var invokePromise = hubConnection.send("testMethod", "arg", 42) + .catch((_) => { }); // Suppress exception and unhandled promise rejection warning. + + // Verify the message is sent + expect(connection.sentData.length).toBe(1); + expect(JSON.parse(connection.sentData[0])).toEqual({ + type: 1, + invocationId: connection.lastInvocationId, + target: "testMethod", + nonblocking: true, + arguments: [ + "arg", + 42 + ] + }); + + // Close the connection + hubConnection.stop(); + }); + }); + describe("invoke", () => { it("sends an invocation", async () => { let connection = new TestConnection(); diff --git a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts index 19178aba34..06061cbfff 100644 --- a/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts +++ b/client-ts/Microsoft.AspNetCore.SignalR.Client.TS/HubConnection.ts @@ -102,7 +102,7 @@ export class HubConnection { } stream(methodName: string, ...args: any[]): Observable { - let invocationDescriptor = this.createInvocation(methodName, args); + let invocationDescriptor = this.createInvocation(methodName, args, false); let subject = new Subject(); @@ -136,8 +136,16 @@ export class HubConnection { return subject; } + send(methodName: string, ...args: any[]): Promise { + let invocationDescriptor = this.createInvocation(methodName, args, true); + + let message = this.protocol.writeMessage(invocationDescriptor); + + return this.connection.send(message); + } + invoke(methodName: string, ...args: any[]): Promise { - let invocationDescriptor = this.createInvocation(methodName, args); + let invocationDescriptor = this.createInvocation(methodName, args, false); let p = new Promise((resolve, reject) => { this.callbacks.set(invocationDescriptor.invocationId, (invocationEvent: CompletionMessage | ResultMessage) => { @@ -175,7 +183,7 @@ export class HubConnection { this.connectionClosedCallback = callback; } - private createInvocation(methodName: string, args: any[]): InvocationMessage { + private createInvocation(methodName: string, args: any[], nonblocking: boolean): InvocationMessage { let id = this.id; this.id++; @@ -184,7 +192,7 @@ export class HubConnection { invocationId: id.toString(), target: methodName, arguments: args, - nonblocking: false + nonblocking: nonblocking }; } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index 6d6624aef3..bcfd1a3863 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -101,29 +101,46 @@ namespace Microsoft.AspNetCore.SignalR.Client public ReadableChannel Stream(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) { var irq = InvocationRequest.Stream(cancellationToken, returnType, GetNextId(), _loggerFactory, out var channel); - InvokeCore(methodName, irq, args); + _ = InvokeCore(methodName, irq, args, nonBlocking: false); return channel; } - public Task Invoke(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) + public async Task InvokeAsync(string methodName, Type returnType, CancellationToken cancellationToken, params object[] args) { var irq = InvocationRequest.Invoke(cancellationToken, returnType, GetNextId(), _loggerFactory, out var task); - InvokeCore(methodName, irq, args); - return task; + await InvokeCore(methodName, irq, args, nonBlocking: false); + return await task; } - private void InvokeCore(string methodName, InvocationRequest irq, object[] args) + public Task SendAsync(string methodName, CancellationToken cancellationToken, params object[] args) + { + var irq = InvocationRequest.Invoke(cancellationToken, typeof(void), GetNextId(), _loggerFactory, out _); + return InvokeCore(methodName, irq, args, nonBlocking: true); + } + + private Task InvokeCore(string methodName, InvocationRequest irq, object[] args, bool nonBlocking) { ThrowIfConnectionTerminated(); - _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, irq.ResultType.AssemblyQualifiedName, args.Length); + if (nonBlocking) + { + _logger.LogTrace("Preparing invocation of '{target}' and {argumentCount} args", methodName, irq.ResultType.AssemblyQualifiedName, args.Length); + } + else + { + _logger.LogTrace("Preparing invocation of '{target}', with return type '{returnType}' and {argumentCount} args", methodName, irq.ResultType.AssemblyQualifiedName, args.Length); + } // Create an invocation descriptor. Client invocations are always blocking - var invocationMessage = new InvocationMessage(irq.InvocationId, nonBlocking: false, target: methodName, arguments: args); + var invocationMessage = new InvocationMessage(irq.InvocationId, nonBlocking, methodName, args); - // I just want an excuse to use 'irq' as a variable name... - _logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId); + // We don't need to track invocations for fire an forget calls + if (!nonBlocking) + { + // I just want an excuse to use 'irq' as a variable name... + _logger.LogDebug("Registering Invocation ID '{invocationId}' for tracking", invocationMessage.InvocationId); - AddInvocation(irq); + AddInvocation(irq); + } // Trace the full invocation, but only if that logging level is enabled (because building the args list is a bit slow) if (_logger.IsEnabled(LogLevel.Trace)) @@ -133,7 +150,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } // We don't need to wait for this to complete. It will signal back to the invocation request. - _ = SendInvocation(invocationMessage, irq); + return SendInvocation(invocationMessage, irq); } private async Task SendInvocation(InvocationMessage invocationMessage, InvocationRequest irq) diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs index 04a29ead22..7ef2ff80cd 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs @@ -5,7 +5,6 @@ using System; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; -using static Microsoft.AspNetCore.SignalR.Client.HubConnection; namespace Microsoft.AspNetCore.SignalR.Client { @@ -21,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Client throw new ArgumentNullException(nameof(hubConnection)); } - return hubConnection.Invoke(methodName, typeof(object), cancellationToken, args); + return hubConnection.InvokeAsync(methodName, typeof(object), cancellationToken, args); } public static Task Invoke(this HubConnection hubConnection, string methodName, params object[] args) => @@ -34,7 +33,12 @@ namespace Microsoft.AspNetCore.SignalR.Client throw new ArgumentNullException(nameof(hubConnection)); } - return (TResult)await hubConnection.Invoke(methodName, typeof(TResult), cancellationToken, args); + return (TResult)await hubConnection.InvokeAsync(methodName, typeof(TResult), cancellationToken, args); + } + + public static Task SendAsync(this HubConnection hubConnection, string methodName, params object[] args) + { + return hubConnection.SendAsync(methodName, CancellationToken.None, args); } public static ReadableChannel Stream(this HubConnection hubConnection, string methodName, params object[] args) => diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index 9e91c1b47e..d1e818aed2 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -274,7 +274,10 @@ namespace Microsoft.AspNetCore.SignalR if (!await IsHubMethodAuthorized(scope.ServiceProvider, connection.User, descriptor.Policies)) { _logger.LogDebug("Failed to invoke {hubMethod} because user is unauthorized", invocationMessage.Target); - await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized")); + if (!invocationMessage.NonBlocking) + { + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, $"Failed to invoke '{invocationMessage.Target}' because user is unauthorized")); + } return; } @@ -309,7 +312,7 @@ namespace Microsoft.AspNetCore.SignalR _logger.LogTrace("[{connectionId}/{invocationId}] Streaming result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName); await StreamResultsAsync(invocationMessage.InvocationId, connection, protocol, enumerator); } - else + else if (!invocationMessage.NonBlocking) { _logger.LogTrace("[{connectionId}/{invocationId}] Sending result of type {resultType}", connection.ConnectionId, invocationMessage.InvocationId, methodExecutor.MethodReturnType.FullName); await SendMessageAsync(connection, protocol, CompletionMessage.WithResult(invocationMessage.InvocationId, result)); @@ -318,12 +321,18 @@ namespace Microsoft.AspNetCore.SignalR catch (TargetInvocationException ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); - await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message)); + if (!invocationMessage.NonBlocking) + { + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.InnerException.Message)); + } } catch (Exception ex) { _logger.LogError(0, ex, "Failed to invoke hub method"); - await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message)); + if (!invocationMessage.NonBlocking) + { + await SendMessageAsync(connection, protocol, CompletionMessage.WithError(invocationMessage.InvocationId, ex.Message)); + } } finally { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs index 8f344a424c..0818fdc422 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionProtocolTests.cs @@ -17,6 +17,30 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests // don't cause problems. public class HubConnectionProtocolTests { + [Fact] + public async Task SendAsyncSendsANonBlockingInvocationMessage() + { + var connection = new TestConnection(); + var hubConnection = new HubConnection(connection, new JsonHubProtocol(new JsonSerializer()), new LoggerFactory()); + try + { + await hubConnection.StartAsync(); + + var invokeTask = hubConnection.SendAsync("Foo"); + + // skip negotiation + await connection.ReadSentTextMessageAsync().OrTimeout(); + var invokeMessage = await connection.ReadSentTextMessageAsync().OrTimeout(); + + Assert.Equal("78:{\"invocationId\":\"1\",\"type\":1,\"target\":\"Foo\",\"nonBlocking\":true,\"arguments\":[]};", invokeMessage); + } + finally + { + await hubConnection.DisposeAsync().OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + } + } + [Fact] public async Task ClientSendsNegotationMessageWhenStartingConnection() { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index dccf6de789..e9cf7c6cdb 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -234,6 +234,32 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + [Theory] + [InlineData(nameof(MethodHub.VoidMethod))] + [InlineData(nameof(MethodHub.MethodThatThrows))] + public async Task NonBlockingInvocationDoesNotSendCompletion(string methodName) + { + var serviceProvider = CreateServiceProvider(); + + var endPoint = serviceProvider.GetService>(); + + using (var client = new TestClient(synchronousCallbacks: true)) + { + var endPointTask = endPoint.OnConnectedAsync(client.Connection); + + // This invocation should be completely synchronous + await client.SendInvocationAsync(methodName, nonBlocking: true).OrTimeout(); + + // Nothing should have been written + Assert.False(client.Application.In.TryRead(out var buffer)); + + // kill the connection + client.Dispose(); + + await endPointTask.OrTimeout(); + } + } + [Fact] public async Task HubMethodWithMultiParam() { diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs index b6fb1bcc2b..b2ec6e1135 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/TestClient.cs @@ -27,10 +27,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests public Channel Application { get; } public Task Connected => Connection.Metadata.Get>("ConnectedTask").Task; - public TestClient() + public TestClient(bool synchronousCallbacks = false) { - var transportToApplication = Channel.CreateUnbounded(); - var applicationToTransport = Channel.CreateUnbounded(); + var options = new ChannelOptimizations { AllowSynchronousContinuations = synchronousCallbacks }; + var transportToApplication = Channel.CreateUnbounded(options); + var applicationToTransport = Channel.CreateUnbounded(options); Application = ChannelConnection.Create(input: applicationToTransport, output: transportToApplication); _transport = ChannelConnection.Create(input: transportToApplication, output: applicationToTransport); @@ -52,7 +53,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task> StreamAsync(string methodName, params object[] args) { - var invocationId = await SendInvocationAsync(methodName, args); + var invocationId = await SendInvocationAsync(methodName, nonBlocking: false, args: args); var messages = new List(); while (true) @@ -85,7 +86,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests public async Task InvokeAsync(string methodName, params object[] args) { - var invocationId = await SendInvocationAsync(methodName, args); + var invocationId = await SendInvocationAsync(methodName, nonBlocking: false, args: args); while (true) { @@ -113,10 +114,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - public async Task SendInvocationAsync(string methodName, params object[] args) + public Task SendInvocationAsync(string methodName, params object[] args) + { + return SendInvocationAsync(methodName, nonBlocking: false, args: args); + } + + public async Task SendInvocationAsync(string methodName, bool nonBlocking, params object[] args) { var invocationId = GetInvocationId(); - var payload = _protocol.WriteToArray(new InvocationMessage(invocationId, nonBlocking: false, target: methodName, arguments: args)); + var payload = _protocol.WriteToArray(new InvocationMessage(invocationId, nonBlocking, methodName, args)); await Application.Out.WriteAsync(payload);