diff --git a/samples/ClientSample/RawSample.cs b/samples/ClientSample/RawSample.cs index d21da7c658..19ba4a07a6 100644 --- a/samples/ClientSample/RawSample.cs +++ b/samples/ClientSample/RawSample.cs @@ -40,7 +40,7 @@ namespace ClientSample try { var cts = new CancellationTokenSource(); - connection.Received += data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}"); + connection.OnReceived(data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}")); connection.Closed += e => { cts.Cancel(); diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs index d20df8ac2c..9642ceaa97 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnection.cs @@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Client private readonly object _pendingCallsLock = new object(); private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource(); private readonly Dictionary _pendingCalls = new Dictionary(); - private readonly ConcurrentDictionary _handlers = new ConcurrentDictionary(); + private readonly ConcurrentDictionary> _handlers = new ConcurrentDictionary>(); private int _nextId = 0; @@ -62,7 +62,7 @@ namespace Microsoft.AspNetCore.SignalR.Client _protocol = protocol; _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance; _logger = _loggerFactory.CreateLogger(); - _connection.Received += OnDataReceivedAsync; + _connection.OnReceived((data, state) => ((HubConnection)state).OnDataReceivedAsync(data), this); _connection.Closed += Shutdown; } @@ -119,10 +119,20 @@ namespace Microsoft.AspNetCore.SignalR.Client } // TODO: Client return values/tasks? - public void On(string methodName, Type[] parameterTypes, Func handler) + public IDisposable On(string methodName, Type[] parameterTypes, Func handler, object state) { - var invocationHandler = new InvocationHandler(parameterTypes, handler); - _handlers.AddOrUpdate(methodName, invocationHandler, (_, __) => invocationHandler); + var invocationHandler = new InvocationHandler(parameterTypes, handler, state); + var invocationList = _handlers.AddOrUpdate(methodName, _ => new List { invocationHandler }, + (_, invocations) => + { + lock (invocations) + { + invocations.Add(invocationHandler); + } + return invocations; + }); + + return new Subscription(invocationHandler, invocationList); } public async Task> StreamAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default(CancellationToken)) @@ -299,18 +309,35 @@ namespace Microsoft.AspNetCore.SignalR.Client return Task.CompletedTask; } - private Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken) + private async Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken) { // Find the handler - if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler)) + if (!_handlers.TryGetValue(invocation.Target, out var handlers)) { _logger.MissingHandler(invocation.Target); - return Task.CompletedTask; + return; } - // TODO: Return values - // TODO: Dispatch to a sync context to ensure we aren't blocking this loop. - return handler.Handler(invocation.Arguments); + //TODO: Optimize this! + // Copying the callbacks to avoid concurrency issues + InvocationHandler[] copiedHandlers; + lock (handlers) + { + copiedHandlers = new InvocationHandler[handlers.Count]; + handlers.CopyTo(copiedHandlers); + } + + foreach (var handler in copiedHandlers) + { + try + { + await handler.InvokeAsync(invocation.Arguments); + } + catch (Exception ex) + { + _logger.ExceptionThrownFromCallback(nameof(On), ex); + } + } } // This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel @@ -404,6 +431,26 @@ namespace Microsoft.AspNetCore.SignalR.Client } } + private class Subscription : IDisposable + { + private readonly InvocationHandler _handler; + private readonly List _handlerList; + + public Subscription(InvocationHandler handler, List handlerList) + { + _handler = handler; + _handlerList = handlerList; + } + + public void Dispose() + { + lock (_handlerList) + { + _handlerList.Remove(_handler); + } + } + } + private class HubBinder : IInvocationBinder { private HubConnection _connection; @@ -415,7 +462,7 @@ namespace Microsoft.AspNetCore.SignalR.Client public Type GetReturnType(string invocationId) { - if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq)) + if (!_connection._pendingCalls.TryGetValue(invocationId, out var irq)) { _connection._logger.ReceivedUnexpectedResponse(invocationId); return null; @@ -425,24 +472,40 @@ namespace Microsoft.AspNetCore.SignalR.Client public Type[] GetParameterTypes(string methodName) { - if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler)) + if (!_connection._handlers.TryGetValue(methodName, out var handlers)) { _connection._logger.MissingHandler(methodName); return Type.EmptyTypes; } - return handler.ParameterTypes; + + // We use the parameter types of the first handler + lock (handlers) + { + if (handlers.Count > 0) + { + return handlers[0].ParameterTypes; + } + throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'"); + } } } private struct InvocationHandler { - public Func Handler { get; } public Type[] ParameterTypes { get; } + private readonly Func _callback; + private readonly object _state; - public InvocationHandler(Type[] parameterTypes, Func handler) + public InvocationHandler(Type[] parameterTypes, Func callback, object state) { - Handler = handler; + _callback = callback; ParameterTypes = parameterTypes; + _state = state; + } + + public Task InvokeAsync(object[] parameters) + { + return _callback(parameters, _state); } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs index 1886aaa52d..edf00f1f83 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/HubConnectionExtensions.cs @@ -10,69 +10,70 @@ namespace Microsoft.AspNetCore.SignalR.Client { public static partial class HubConnectionExtensions { - private static void On(this HubConnection hubConnetion, string methodName, Type[] parameterTypes, Action handler) + private static IDisposable On(this HubConnection hubConnetion, string methodName, Type[] parameterTypes, Action handler) { - hubConnetion.On(methodName, parameterTypes, (parameters) => + return hubConnetion.On(methodName, parameterTypes, (parameters, state) => { - handler(parameters); + var currentHandler = (Action)state; + currentHandler(parameters); return Task.CompletedTask; - }); + }, handler); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, Type.EmptyTypes, args => handler()); + return hubConnection.On(methodName, Type.EmptyTypes, args => handler()); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1) }, args => handler((T1)args[0])); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1), typeof(T2) }, args => handler((T1)args[0], (T2)args[1])); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1), typeof(T2), typeof(T3) }, args => handler((T1)args[0], (T2)args[1], (T3)args[2])); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) }, args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3])); } @@ -89,41 +90,50 @@ namespace Microsoft.AspNetCore.SignalR.Client args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4])); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6) }, args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4], (T6)args[5])); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7) }, args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4], (T6)args[5], (T7)args[6])); } - public static void On(this HubConnection hubConnection, string methodName, Action handler) + public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler) { if (hubConnection == null) { throw new ArgumentNullException(nameof(hubConnection)); } - hubConnection.On(methodName, + return hubConnection.On(methodName, new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8) }, args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4], (T6)args[5], (T7)args[6], (T8)args[7])); } + + public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func handler) + { + return hubConnection.On(methodName, parameterTypes, (parameters, state) => + { + var currentHandler = (Func)state; + return currentHandler(parameters); + }, handler); + } } } diff --git a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs index 037840b0cf..af32121b5f 100644 --- a/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.SignalR.Client.Core/Internal/SignalRClientLoggerExtensions.cs @@ -103,6 +103,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal private static readonly Action _streamItemOnNonStreamInvocation = LoggerMessage.Define(LogLevel.Error, new EventId(4, nameof(StreamItemOnNonStreamInvocation)), "Invocation {invocationId} received stream item but was invoked as a non-streamed invocation."); + private static readonly Action _exceptionThrownFromCallback = + LoggerMessage.Define(LogLevel.Error, new EventId(5, nameof(ExceptionThrownFromCallback)), "An exception was thrown from the '{callback}' callback"); + public static void PreparingNonBlockingInvocation(this ILogger logger, string invocationId, string target, int count) { _preparingNonBlockingInvocation(logger, invocationId, target, count, null); @@ -260,5 +263,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal { _streamItemOnNonStreamInvocation(logger, invocationId, null); } + + public static void ExceptionThrownFromCallback(this ILogger logger, string callbackName, Exception exception) + { + _exceptionThrownFromCallback(logger, callbackName, exception); + } } } diff --git a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs index c01c8cdf05..5efbd777b3 100644 --- a/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Abstractions/IConnection.cs @@ -14,7 +14,8 @@ namespace Microsoft.AspNetCore.Sockets.Client Task SendAsync(byte[] data, CancellationToken cancellationToken); Task DisposeAsync(); - event Func Received; + IDisposable OnReceived(Func callback, object state); + event Func Closed; IFeatureCollection Features { get; } diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs index 3d01f3e69d..aa4fcf47fe 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnection.cs @@ -2,6 +2,8 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Concurrent; +using System.Collections.Generic; using System.Diagnostics; using System.IO; using System.Net.Http; @@ -37,15 +39,14 @@ namespace Microsoft.AspNetCore.Sockets.Client private readonly ITransportFactory _transportFactory; private string _connectionId; private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5); - private ReadableChannel Input => _transportChannel.In; private WritableChannel Output => _transportChannel.Out; + private readonly List _callbacks = new List(); public Uri Url { get; } public IFeatureCollection Features { get; } = new FeatureCollection(); - public event Func Received; public event Func Closed; public HttpConnection(Uri url) @@ -186,7 +187,7 @@ namespace Microsoft.AspNetCore.Sockets.Client } catch (Exception ex) { - _logger.ExceptionThrownFromHandler(_connectionId, nameof(Closed), ex); + _logger.ExceptionThrownFromCallback(_connectionId, nameof(Closed), ex); } } }); @@ -338,16 +339,23 @@ namespace Microsoft.AspNetCore.Sockets.Client { _logger.RaiseReceiveEvent(_connectionId); - var receivedHandler = Received; - if (receivedHandler != null) + // Copying the callbacks to avoid concurrency issues + ReceiveCallback[] callbackCopies; + lock (_callbacks) + { + callbackCopies = new ReceiveCallback[_callbacks.Count]; + _callbacks.CopyTo(callbackCopies); + } + + foreach (var callbackObject in callbackCopies) { try { - await receivedHandler(buffer); + await callbackObject.InvokeAsync(buffer); } catch (Exception ex) { - _logger.ExceptionThrownFromHandler(_connectionId, nameof(Received), ex); + _logger.ExceptionThrownFromCallback(_connectionId, nameof(OnReceived), ex); } } }); @@ -444,6 +452,52 @@ namespace Microsoft.AspNetCore.Sockets.Client _httpClient.Dispose(); } + public IDisposable OnReceived(Func callback, object state) + { + var receiveCallback = new ReceiveCallback(callback, state); + lock (_callbacks) + { + _callbacks.Add(receiveCallback); + } + return new Subscription(receiveCallback, _callbacks); + } + + private class ReceiveCallback + { + private readonly Func _callback; + private readonly object _state; + + public ReceiveCallback(Func callback, object state) + { + _callback = callback; + _state = state; + } + + public Task InvokeAsync(byte[] data) + { + return _callback(data, _state); + } + } + + private class Subscription : IDisposable + { + private readonly ReceiveCallback _receiveCallback; + private readonly List _callbacks; + public Subscription(ReceiveCallback callback, List callbacks) + { + _receiveCallback = callback; + _callbacks = callbacks; + } + + public void Dispose() + { + lock (_callbacks) + { + _callbacks.Remove(_receiveCallback); + } + } + } + private class ConnectionState { public const int Initial = 0; diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs new file mode 100644 index 0000000000..490b6dc3de --- /dev/null +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/HttpConnectionExtensions.cs @@ -0,0 +1,20 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. + +using System; +using System.Threading.Tasks; + +namespace Microsoft.AspNetCore.Sockets.Client +{ + public static partial class HttpConnectionExtensions + { + public static IDisposable OnReceived(this HttpConnection connection, Func callback) + { + return connection.OnReceived((data, state) => + { + var currentCallback = (Func)state; + return currentCallback(data); + }, callback); + } + } +} diff --git a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs index 153fed65ef..1ba8b1e82b 100644 --- a/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs +++ b/src/Microsoft.AspNetCore.Sockets.Client.Http/Internal/SocketClientLoggerExtensions.cs @@ -150,8 +150,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal private static readonly Action _stoppingClient = LoggerMessage.Define(LogLevel.Information, new EventId(18, nameof(StoppingClient)), "{time}: Connection Id {connectionId}: Stopping client."); - private static readonly Action _exceptionThrownFromHandler = - LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromHandler)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{eventHandlerName}' event handler."); + private static readonly Action _exceptionThrownFromCallback = + LoggerMessage.Define(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromCallback)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{callback}' callback"); public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode) @@ -514,11 +514,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal } } - public static void ExceptionThrownFromHandler(this ILogger logger, string connectionId, string eventHandlerName, Exception exception) + public static void ExceptionThrownFromCallback(this ILogger logger, string connectionId, string callbackName, Exception exception) { if (logger.IsEnabled(LogLevel.Error)) { - _exceptionThrownFromHandler(logger, DateTime.Now, connectionId, eventHandlerName, exception); + _exceptionThrownFromCallback(logger, DateTime.Now, connectionId, callbackName, exception); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 26ca8eefa0..c995f36f3c 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -12,7 +12,6 @@ using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; -using Newtonsoft.Json; using Xunit; using Xunit.Abstractions; @@ -164,6 +163,41 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests } } + [Theory] + [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] + public async Task InvokeNonExistantClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) + { + using (StartLog(out var loggerFactory)) + { + var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + path), transportType, loggerFactory); + var connection = new HubConnection(httpConnection, protocol, loggerFactory); + try + { + await connection.StartAsync().OrTimeout(); + var closeTcs = new TaskCompletionSource(); + connection.Closed += ex => + { + if (ex != null) + { + closeTcs.SetException(ex); + } + else + { + closeTcs.SetResult(null); + } + return Task.CompletedTask; + }; + await connection.InvokeAsync("CallHandlerThatDoesntExist").OrTimeout(); + await connection.DisposeAsync().OrTimeout(); + await closeTcs.Task.OrTimeout(); + } + finally + { + await connection.DisposeAsync().OrTimeout(); + } + } + } + [Theory] [MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))] public async Task CanStreamClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs index 2c02e941fe..3246d9c825 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/Hubs.cs @@ -24,6 +24,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests await Clients.Client(Context.ConnectionId).InvokeAsync("Echo", message); } + public async Task CallHandlerThatDoesntExist() + { + await Clients.Client(Context.ConnectionId).InvokeAsync("NoClientHandler"); + } + public IObservable Stream(int count) { return Observable.Interval(TimeSpan.FromMilliseconds(1)) @@ -60,6 +65,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { return Clients.All.Send(message); } + + public async Task CallHandlerThatDoesntExist() + { + await Clients.Client(Context.ConnectionId).NoClientHandler(); + } + } public class TestHubT : Hub @@ -90,12 +101,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests { return Clients.All.Send(message); } + + public async Task CallHandlerThatDoesntExist() + { + await Clients.Client(Context.ConnectionId).NoClientHandler(); + } } public interface ITestHub { Task Echo(string message); Task Send(string message); + Task NoClientHandler(); } } diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs index 20e99ffd0c..3baddd6175 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HttpConnectionTests.cs @@ -255,7 +255,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests } [Fact] - public async Task ReceivedEventNotRaisedAfterConnectionIsDisposed() + public async Task ReceivedCallbackNotRaisedAfterConnectionIsDisposed() { var mockHttpHandler = new Mock(); mockHttpHandler.Protected() @@ -289,16 +289,16 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); - var receivedInvoked = false; - connection.Received += m => + var onReceivedInvoked = false; + connection.OnReceived( _ => { - receivedInvoked = true; + onReceivedInvoked = true; return Task.CompletedTask; - }; + }); await connection.StartAsync(); await connection.DisposeAsync(); - Assert.False(receivedInvoked); + Assert.False(onReceivedInvoked); } [Fact] @@ -336,12 +336,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); - connection.Received += - async m => + connection.OnReceived(_ => { callbackInvokedTcs.SetResult(null); - await closedTcs.Task; - }; + return closedTcs.Task; + }); await connection.StartAsync(); channel.Out.TryWrite(Array.Empty()); @@ -392,11 +391,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var closedTcs = new TaskCompletionSource(); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); - connection.Received += - async m => - { - await blockReceiveCallbackTcs.Task; - }; + connection.OnReceived(_ => blockReceiveCallbackTcs.Task); + connection.Closed += _ => { closedTcs.SetResult(null); return Task.CompletedTask; @@ -445,11 +441,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var closedTcs = new TaskCompletionSource(); var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object); - connection.Received += - m => + connection.OnReceived( _ => { throw new OperationCanceledException(); - }; + }); await connection.StartAsync(); channel.Out.TryWrite(Array.Empty()); @@ -642,11 +637,12 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests try { var receiveTcs = new TaskCompletionSource(); - connection.Received += data => + connection.OnReceived((data, state) => { - receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + var tcs = ((TaskCompletionSource)state); + tcs.TrySetResult(Encoding.UTF8.GetString(data)); return Task.CompletedTask; - }; + }, receiveTcs); connection.Closed += e => { @@ -699,7 +695,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var receiveTcs = new TaskCompletionSource(); var receivedRaised = false; - connection.Received += data => + connection.OnReceived(data => { if (!receivedRaised) { @@ -709,7 +705,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); return Task.CompletedTask; - }; + }); connection.Closed += e => { @@ -762,7 +758,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests var receiveTcs = new TaskCompletionSource(); var receivedRaised = false; - connection.Received += data => + connection.OnReceived((data) => { if (!receivedRaised) { @@ -772,7 +768,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); return Task.CompletedTask; - }; + }); connection.Closed += e => { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs index e3b0f554e2..cf9a7a88ec 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/HubConnectionTests.cs @@ -12,7 +12,6 @@ using Microsoft.AspNetCore.SignalR.Tests.Common; using Microsoft.AspNetCore.Sockets.Client; using Microsoft.Extensions.Logging; using Moq; -using Newtonsoft.Json; using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests @@ -120,23 +119,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Assert.Same(exception, thrown); } - [Fact] - public async Task DoesNotThrowWhenClientMethodCalledButNoInvocationHandlerHasBeenSetUp() - { - var mockConnection = new Mock(); - mockConnection.SetupGet(p => p.Features).Returns(new FeatureCollection()); - - var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 }); - - var mockProtocol = MockHubProtocol.ReturnOnParse(invocation); - - var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null); - await hubConnection.StartAsync(); - - mockConnection.Raise(c => c.Received += null, new object[] { new byte[] { } }); - Assert.Equal(1, mockProtocol.ParseCalls); - } - // Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse private class MockHubProtocol : IHubProtocol { diff --git a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs index 2e57f53cf4..bc234ebb69 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.Tests/TestConnection.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Collections.Generic; using System.IO; using System.Text; using System.Threading; @@ -30,7 +31,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests private TransferMode? _transferMode; public event Func Connected; - public event Func Received; public event Func Closed; public Task Started => _started.Task; @@ -38,6 +38,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests public ReadableChannel SentMessages => _sentMessages.In; public WritableChannel ReceivedMessages => _receivedMessages.Out; + private readonly List _callbacks = new List(); + public IFeatureCollection Features { get; } = new FeatureCollection(); public TestConnection(TransferMode? transferMode = null) @@ -120,7 +122,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests { while (_receivedMessages.In.TryRead(out var message)) { - await Received?.Invoke(message); + ReceiveCallback[] callbackCopies; + lock (_callbacks) + { + callbackCopies = _callbacks.ToArray(); + } + + foreach (var callback in callbackCopies) + { + await callback.InvokeAsync(message); + } } } } @@ -136,5 +147,51 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests Closed?.Invoke(ex); } } + + public IDisposable OnReceived(Func callback, object state) + { + var receiveCallBack = new ReceiveCallback(callback, state); + lock (_callbacks) + { + _callbacks.Add(receiveCallBack); + } + return new Subscription(receiveCallBack, _callbacks); + } + + private class ReceiveCallback + { + private readonly Func _callback; + private readonly object _state; + + public ReceiveCallback(Func callback, object state) + { + _callback = callback; + _state = state; + } + + public Task InvokeAsync(byte[] data) + { + return _callback(data, _state); + } + } + + private class Subscription : IDisposable + { + private readonly ReceiveCallback _callback; + private readonly List _callbacks; + public Subscription(ReceiveCallback callback, List callbacks) + { + _callback = callback; + _callbacks = callbacks; + } + + public void Dispose() + { + lock (_callbacks) + { + _callbacks.Remove(_callback); + } + } + } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs index 4c3b6a0067..bd8f0f7020 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/EndToEndTests.cs @@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests { var receiveTcs = new TaskCompletionSource(); var closeTcs = new TaskCompletionSource(); - connection.Received += data => + connection.OnReceived((data, state) => { logger.LogInformation("Received {length} byte message", data.Length); @@ -124,10 +124,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests { data = Convert.FromBase64String(Encoding.UTF8.GetString(data)); } - - receiveTcs.TrySetResult(Encoding.UTF8.GetString(data)); + var tcs = (TaskCompletionSource)state; + tcs.TrySetResult(Encoding.UTF8.GetString(data)); return Task.CompletedTask; - }; + }, receiveTcs); + connection.Closed += e => { logger.LogInformation("Connection closed"); @@ -224,12 +225,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests try { var receiveTcs = new TaskCompletionSource(); - connection.Received += data => + connection.OnReceived((data, state) => { logger.LogInformation("Received {length} byte message", data.Length); - receiveTcs.TrySetResult(data); + var tcs = (TaskCompletionSource)state; + tcs.TrySetResult(data); return Task.CompletedTask; - }; + }, receiveTcs); logger.LogInformation("Starting connection to {url}", url); await connection.StartAsync().OrTimeout();