diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index 2e366cb2d1..b6a4c56ef6 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -40,7 +40,7 @@ namespace ClientSample try { var cts = new CancellationTokenSource(); - connection.Received += data => Console.WriteLine($"{Encoding.UTF8.GetString(data)}"); + connection.Received += data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}"); connection.Closed += e => cts.Cancel(); await connection.StartAsync(); diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs index afbc2ad32a..c555343736 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnection.cs @@ -70,7 +70,7 @@ namespace Microsoft.AspNetCore.SignalR.Client _protocol = protocol; _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); - _connection.Received += OnDataReceived; + _connection.Received += OnDataReceivedAsync; _connection.Closed += Shutdown; } @@ -85,7 +85,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } // TODO: Client return values/tasks? - public void On(string methodName, Type[] parameterTypes, Action handler) + public void On(string methodName, Type[] parameterTypes, Func handler) { var invocationHandler = new InvocationHandler(parameterTypes, handler); _handlers.AddOrUpdate(methodName, invocationHandler, (_, __) => invocationHandler); @@ -148,7 +148,7 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private void OnDataReceived(byte[] data) + private async Task OnDataReceivedAsync(byte[] data) { if (_protocol.TryParseMessages(data, _binder, out var messages)) { @@ -163,7 +163,7 @@ namespace Microsoft.AspNetCore.SignalR.Client var argsList = string.Join(", ", invocation.Arguments.Select(a => a.GetType().FullName)); _logger.LogTrace("Received Invocation '{invocationId}': {methodName}({args})", invocation.InvocationId, invocation.Target, argsList); } - DispatchInvocation(invocation, _connectionActive.Token); + await DispatchInvocationAsync(invocation, _connectionActive.Token); break; case CompletionMessage completion: if (!TryRemoveInvocation(completion.InvocationId, out irq)) @@ -218,18 +218,18 @@ namespace Microsoft.AspNetCore.SignalR.Client } } - private void DispatchInvocation(InvocationMessage invocation, CancellationToken cancellationToken) + private Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken) { // Find the handler if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler)) { _logger.LogWarning("Failed to find handler for '{target}' method", invocation.Target); - return; + return Task.CompletedTask; } // TODO: Return values // TODO: Dispatch to a sync context to ensure we aren't blocking this loop. - handler.Handler(invocation.Arguments); + return handler.Handler(invocation.Arguments); } // This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel @@ -355,10 +355,10 @@ namespace Microsoft.AspNetCore.SignalR.Client private struct InvocationHandler { - public Action Handler { get; } + public Func Handler { get; } public Type[] ParameterTypes { get; } - public InvocationHandler(Type[] parameterTypes, Action handler) + public InvocationHandler(Type[] parameterTypes, Func handler) { Handler = handler; ParameterTypes = parameterTypes; diff --git a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs index c26b4db924..04a29ead22 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client/HubConnectionExtensions.cs @@ -5,6 +5,7 @@ using System; using System.Threading; using System.Threading.Tasks; using System.Threading.Tasks.Channels; +using static Microsoft.AspNetCore.SignalR.Client.HubConnection; namespace Microsoft.AspNetCore.SignalR.Client { @@ -89,6 +90,15 @@ namespace Microsoft.AspNetCore.SignalR.Client return outputChannel.In; } + private static void On(this HubConnection hubConnetion, string methodName, Type[] parameterTypes, Action handler) + { + hubConnetion.On(methodName, parameterTypes, (parameters) => + { + handler(parameters); + return Task.CompletedTask; + }); + } + public static void On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs index b2d62197f9..315ecc794c 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs @@ -14,7 +14,7 @@ namespace Microsoft.AspNetCore.Sockets.Client Task DisposeAsync(); event Action Connected; - event Action Received; + event Func Received; event Action Closed; } } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 1c8d4cde63..432c3a1597 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.Sockets.Client public Uri Url { get; } public event Action Connected; - public event Action Received; + public event Func Received; public event Action Closed; public HttpConnection(Uri url) @@ -284,11 +284,17 @@ namespace Microsoft.AspNetCore.Sockets.Client if (Input.TryRead(out var buffer)) { _logger.LogDebug("Scheduling raising Received event."); - var ignore = _eventQueue.Enqueue(() => + _ = _eventQueue.Enqueue(() => { _logger.LogDebug("Raising Received event."); - Received?.Invoke(buffer); + // Making a copy of the Received handler to ensure that its not null + // Can't use the ? operator because we specifically want to check if the handler is null + var receivedHandler = Received; + if (receivedHandler != null) + { + return receivedHandler(buffer); + } return Task.CompletedTask; }); diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs index 7b78d390ff..10cf204c88 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/TaskQueue.cs @@ -49,7 +49,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal { return t; } - return taskFunc(s1); + + return taskFunc(s1) ?? Task.CompletedTask; }, state).Unwrap(); _lastQueuedTask = newTask; diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs index 0dbfb9fbb7..e0a0313387 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/ConnectionTests.cs @@ -11,9 +11,11 @@ using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Client.Tests; using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Logging.Testing; using Moq; using Moq.Protected; using Xunit; +using Xunit.Abstractions; namespace Microsoft.AspNetCore.Sockets.Client.Tests { @@ -351,7 +353,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); var receivedInvoked = false; - connection.Received += (m) => receivedInvoked = true; + connection.Received += m => + { + receivedInvoked = true; + return Task.CompletedTask; + }; await connection.StartAsync(); await connection.DisposeAsync(); @@ -388,30 +394,30 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests }); + var callbackInvokedTcs = new TaskCompletionSource(); var closedTcs = new TaskCompletionSource(); - var allowDisposeTcs = new TaskCompletionSource(); - int receivedInvocationCount = 0; var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); connection.Received += - async (m) => + async m => { - if (Interlocked.Increment(ref receivedInvocationCount) == 2) - { - allowDisposeTcs.TrySetResult(null); - } + callbackInvokedTcs.SetResult(null); await closedTcs.Task; }; - connection.Closed += e => closedTcs.SetResult(null); await connection.StartAsync(); channel.Out.TryWrite(Array.Empty()); + + // Ensure that the Received callback has been called before attempting the second write + await callbackInvokedTcs.Task.OrTimeout(); channel.Out.TryWrite(Array.Empty()); - await allowDisposeTcs.Task.OrTimeout(); + + // Ensure that SignalR isn't blocked by the receive callback + Assert.False(channel.In.TryRead(out var message)); + + closedTcs.SetResult(null); + await connection.DisposeAsync(); - Assert.Equal(2, receivedInvocationCount); - // if the events were running on the main loop they would deadlock - await closedTcs.Task.OrTimeout(); } [Fact] @@ -593,7 +599,12 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests try { var receiveTcs = new TaskCompletionSource(); - connection.Received += (data) => receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + connection.Received += data => + { + receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + return Task.CompletedTask; + }; + connection.Closed += e => { if (e != null) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs index 8a55e55fb9..caed80bc89 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionExtensionsTests.cs @@ -148,7 +148,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests await connection.ReceiveJsonMessage( new { - invocationId = "1", + invocationId = "1", type = 1, target = "Foo", arguments = new object[] { 42, "42" } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index a2ab964f01..535f605fc3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -27,7 +27,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests private Task _receiveLoop; public event Action Connected; - public event Action Received; + public event Func Received; public event Action Closed; public Task Started => _started.Task; @@ -102,7 +102,7 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { while (_receivedMessages.In.TryRead(out var message)) { - Received?.Invoke(message); + await Received?.Invoke(message); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 09bece49a4..2419908a3b 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -96,6 +96,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { logger.LogInformation("Received {length} byte message", data.Length); receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + return Task.CompletedTask; }; connection.Closed += e => { @@ -163,6 +164,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { logger.LogInformation("Received {length} byte message", data.Length); receiveTcs.TrySetResult(data); + return Task.CompletedTask; }; logger.LogInformation("Starting connection to {url}", url);