Replace Received Event with OnReceived (#1006)

This commit is contained in:
Mikael Mengistu 2017-10-18 17:10:51 -07:00 committed by GitHub
parent bb308ff72e
commit 04d4da2987
14 changed files with 347 additions and 103 deletions

View File

@ -40,7 +40,7 @@ namespace ClientSample
try
{
var cts = new CancellationTokenSource();
connection.Received += data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}");
connection.OnReceived(data => Console.Out.WriteLineAsync($"{Encoding.UTF8.GetString(data)}"));
connection.Closed += e =>
{
cts.Cancel();

View File

@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
private readonly object _pendingCallsLock = new object();
private readonly CancellationTokenSource _connectionActive = new CancellationTokenSource();
private readonly Dictionary<string, InvocationRequest> _pendingCalls = new Dictionary<string, InvocationRequest>();
private readonly ConcurrentDictionary<string, InvocationHandler> _handlers = new ConcurrentDictionary<string, InvocationHandler>();
private readonly ConcurrentDictionary<string, List<InvocationHandler>> _handlers = new ConcurrentDictionary<string, List<InvocationHandler>>();
private int _nextId = 0;
@ -62,7 +62,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
_protocol = protocol;
_loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
_logger = _loggerFactory.CreateLogger<HubConnection>();
_connection.Received += OnDataReceivedAsync;
_connection.OnReceived((data, state) => ((HubConnection)state).OnDataReceivedAsync(data), this);
_connection.Closed += Shutdown;
}
@ -119,10 +119,20 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
// TODO: Client return values/tasks?
public void On(string methodName, Type[] parameterTypes, Func<object[], Task> handler)
public IDisposable On(string methodName, Type[] parameterTypes, Func<object[], object, Task> handler, object state)
{
var invocationHandler = new InvocationHandler(parameterTypes, handler);
_handlers.AddOrUpdate(methodName, invocationHandler, (_, __) => invocationHandler);
var invocationHandler = new InvocationHandler(parameterTypes, handler, state);
var invocationList = _handlers.AddOrUpdate(methodName, _ => new List<InvocationHandler> { invocationHandler },
(_, invocations) =>
{
lock (invocations)
{
invocations.Add(invocationHandler);
}
return invocations;
});
return new Subscription(invocationHandler, invocationList);
}
public async Task<ReadableChannel<object>> StreamAsync(string methodName, Type returnType, object[] args, CancellationToken cancellationToken = default(CancellationToken))
@ -299,18 +309,35 @@ namespace Microsoft.AspNetCore.SignalR.Client
return Task.CompletedTask;
}
private Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken)
private async Task DispatchInvocationAsync(InvocationMessage invocation, CancellationToken cancellationToken)
{
// Find the handler
if (!_handlers.TryGetValue(invocation.Target, out InvocationHandler handler))
if (!_handlers.TryGetValue(invocation.Target, out var handlers))
{
_logger.MissingHandler(invocation.Target);
return Task.CompletedTask;
return;
}
// TODO: Return values
// TODO: Dispatch to a sync context to ensure we aren't blocking this loop.
return handler.Handler(invocation.Arguments);
//TODO: Optimize this!
// Copying the callbacks to avoid concurrency issues
InvocationHandler[] copiedHandlers;
lock (handlers)
{
copiedHandlers = new InvocationHandler[handlers.Count];
handlers.CopyTo(copiedHandlers);
}
foreach (var handler in copiedHandlers)
{
try
{
await handler.InvokeAsync(invocation.Arguments);
}
catch (Exception ex)
{
_logger.ExceptionThrownFromCallback(nameof(On), ex);
}
}
}
// This async void is GROSS but we need to dispatch asynchronously because we're writing to a Channel
@ -404,6 +431,26 @@ namespace Microsoft.AspNetCore.SignalR.Client
}
}
private class Subscription : IDisposable
{
private readonly InvocationHandler _handler;
private readonly List<InvocationHandler> _handlerList;
public Subscription(InvocationHandler handler, List<InvocationHandler> handlerList)
{
_handler = handler;
_handlerList = handlerList;
}
public void Dispose()
{
lock (_handlerList)
{
_handlerList.Remove(_handler);
}
}
}
private class HubBinder : IInvocationBinder
{
private HubConnection _connection;
@ -415,7 +462,7 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Type GetReturnType(string invocationId)
{
if (!_connection._pendingCalls.TryGetValue(invocationId, out InvocationRequest irq))
if (!_connection._pendingCalls.TryGetValue(invocationId, out var irq))
{
_connection._logger.ReceivedUnexpectedResponse(invocationId);
return null;
@ -425,24 +472,40 @@ namespace Microsoft.AspNetCore.SignalR.Client
public Type[] GetParameterTypes(string methodName)
{
if (!_connection._handlers.TryGetValue(methodName, out InvocationHandler handler))
if (!_connection._handlers.TryGetValue(methodName, out var handlers))
{
_connection._logger.MissingHandler(methodName);
return Type.EmptyTypes;
}
return handler.ParameterTypes;
// We use the parameter types of the first handler
lock (handlers)
{
if (handlers.Count > 0)
{
return handlers[0].ParameterTypes;
}
throw new InvalidOperationException($"There are no callbacks registered for the method '{methodName}'");
}
}
}
private struct InvocationHandler
{
public Func<object[], Task> Handler { get; }
public Type[] ParameterTypes { get; }
private readonly Func<object[], object, Task> _callback;
private readonly object _state;
public InvocationHandler(Type[] parameterTypes, Func<object[], Task> handler)
public InvocationHandler(Type[] parameterTypes, Func<object[], object, Task> callback, object state)
{
Handler = handler;
_callback = callback;
ParameterTypes = parameterTypes;
_state = state;
}
public Task InvokeAsync(object[] parameters)
{
return _callback(parameters, _state);
}
}

View File

@ -10,69 +10,70 @@ namespace Microsoft.AspNetCore.SignalR.Client
{
public static partial class HubConnectionExtensions
{
private static void On(this HubConnection hubConnetion, string methodName, Type[] parameterTypes, Action<object[]> handler)
private static IDisposable On(this HubConnection hubConnetion, string methodName, Type[] parameterTypes, Action<object[]> handler)
{
hubConnetion.On(methodName, parameterTypes, (parameters) =>
return hubConnetion.On(methodName, parameterTypes, (parameters, state) =>
{
handler(parameters);
var currentHandler = (Action<object[]>)state;
currentHandler(parameters);
return Task.CompletedTask;
});
}, handler);
}
public static void On(this HubConnection hubConnection, string methodName, Action handler)
public static IDisposable On(this HubConnection hubConnection, string methodName, Action handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName, Type.EmptyTypes, args => handler());
return hubConnection.On(methodName, Type.EmptyTypes, args => handler());
}
public static void On<T1>(this HubConnection hubConnection, string methodName, Action<T1> handler)
public static IDisposable On<T1>(this HubConnection hubConnection, string methodName, Action<T1> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1) },
args => handler((T1)args[0]));
}
public static void On<T1, T2>(this HubConnection hubConnection, string methodName, Action<T1, T2> handler)
public static IDisposable On<T1, T2>(this HubConnection hubConnection, string methodName, Action<T1, T2> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1), typeof(T2) },
args => handler((T1)args[0], (T2)args[1]));
}
public static void On<T1, T2, T3>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3> handler)
public static IDisposable On<T1, T2, T3>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1), typeof(T2), typeof(T3) },
args => handler((T1)args[0], (T2)args[1], (T3)args[2]));
}
public static void On<T1, T2, T3, T4>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4> handler)
public static IDisposable On<T1, T2, T3, T4>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4) },
args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3]));
}
@ -89,41 +90,50 @@ namespace Microsoft.AspNetCore.SignalR.Client
args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4]));
}
public static void On<T1, T2, T3, T4, T5, T6>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4, T5, T6> handler)
public static IDisposable On<T1, T2, T3, T4, T5, T6>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4, T5, T6> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6) },
args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4], (T6)args[5]));
}
public static void On<T1, T2, T3, T4, T5, T6, T7>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4, T5, T6, T7> handler)
public static IDisposable On<T1, T2, T3, T4, T5, T6, T7>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4, T5, T6, T7> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7) },
args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4], (T6)args[5], (T7)args[6]));
}
public static void On<T1, T2, T3, T4, T5, T6, T7, T8>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4, T5, T6, T7, T8> handler)
public static IDisposable On<T1, T2, T3, T4, T5, T6, T7, T8>(this HubConnection hubConnection, string methodName, Action<T1, T2, T3, T4, T5, T6, T7, T8> handler)
{
if (hubConnection == null)
{
throw new ArgumentNullException(nameof(hubConnection));
}
hubConnection.On(methodName,
return hubConnection.On(methodName,
new[] { typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8) },
args => handler((T1)args[0], (T2)args[1], (T3)args[2], (T4)args[3], (T5)args[4], (T6)args[5], (T7)args[6], (T8)args[7]));
}
public static IDisposable On(this HubConnection hubConnection, string methodName, Type[] parameterTypes, Func<object[], Task> handler)
{
return hubConnection.On(methodName, parameterTypes, (parameters, state) =>
{
var currentHandler = (Func<object[], Task>)state;
return currentHandler(parameters);
}, handler);
}
}
}

View File

@ -103,6 +103,9 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
private static readonly Action<ILogger, string, Exception> _streamItemOnNonStreamInvocation =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(4, nameof(StreamItemOnNonStreamInvocation)), "Invocation {invocationId} received stream item but was invoked as a non-streamed invocation.");
private static readonly Action<ILogger, string, Exception> _exceptionThrownFromCallback =
LoggerMessage.Define<string>(LogLevel.Error, new EventId(5, nameof(ExceptionThrownFromCallback)), "An exception was thrown from the '{callback}' callback");
public static void PreparingNonBlockingInvocation(this ILogger logger, string invocationId, string target, int count)
{
_preparingNonBlockingInvocation(logger, invocationId, target, count, null);
@ -260,5 +263,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.Internal
{
_streamItemOnNonStreamInvocation(logger, invocationId, null);
}
public static void ExceptionThrownFromCallback(this ILogger logger, string callbackName, Exception exception)
{
_exceptionThrownFromCallback(logger, callbackName, exception);
}
}
}

View File

@ -14,7 +14,8 @@ namespace Microsoft.AspNetCore.Sockets.Client
Task SendAsync(byte[] data, CancellationToken cancellationToken);
Task DisposeAsync();
event Func<byte[], Task> Received;
IDisposable OnReceived(Func<byte[], object, Task> callback, object state);
event Func<Exception, Task> Closed;
IFeatureCollection Features { get; }

View File

@ -2,6 +2,8 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Net.Http;
@ -37,15 +39,14 @@ namespace Microsoft.AspNetCore.Sockets.Client
private readonly ITransportFactory _transportFactory;
private string _connectionId;
private readonly TimeSpan _eventQueueDrainTimeout = TimeSpan.FromSeconds(5);
private ReadableChannel<byte[]> Input => _transportChannel.In;
private WritableChannel<SendMessage> Output => _transportChannel.Out;
private readonly List<ReceiveCallback> _callbacks = new List<ReceiveCallback>();
public Uri Url { get; }
public IFeatureCollection Features { get; } = new FeatureCollection();
public event Func<byte[], Task> Received;
public event Func<Exception, Task> Closed;
public HttpConnection(Uri url)
@ -186,7 +187,7 @@ namespace Microsoft.AspNetCore.Sockets.Client
}
catch (Exception ex)
{
_logger.ExceptionThrownFromHandler(_connectionId, nameof(Closed), ex);
_logger.ExceptionThrownFromCallback(_connectionId, nameof(Closed), ex);
}
}
});
@ -338,16 +339,23 @@ namespace Microsoft.AspNetCore.Sockets.Client
{
_logger.RaiseReceiveEvent(_connectionId);
var receivedHandler = Received;
if (receivedHandler != null)
// Copying the callbacks to avoid concurrency issues
ReceiveCallback[] callbackCopies;
lock (_callbacks)
{
callbackCopies = new ReceiveCallback[_callbacks.Count];
_callbacks.CopyTo(callbackCopies);
}
foreach (var callbackObject in callbackCopies)
{
try
{
await receivedHandler(buffer);
await callbackObject.InvokeAsync(buffer);
}
catch (Exception ex)
{
_logger.ExceptionThrownFromHandler(_connectionId, nameof(Received), ex);
_logger.ExceptionThrownFromCallback(_connectionId, nameof(OnReceived), ex);
}
}
});
@ -444,6 +452,52 @@ namespace Microsoft.AspNetCore.Sockets.Client
_httpClient.Dispose();
}
public IDisposable OnReceived(Func<byte[], object, Task> callback, object state)
{
var receiveCallback = new ReceiveCallback(callback, state);
lock (_callbacks)
{
_callbacks.Add(receiveCallback);
}
return new Subscription(receiveCallback, _callbacks);
}
private class ReceiveCallback
{
private readonly Func<byte[], object, Task> _callback;
private readonly object _state;
public ReceiveCallback(Func<byte[], object, Task> callback, object state)
{
_callback = callback;
_state = state;
}
public Task InvokeAsync(byte[] data)
{
return _callback(data, _state);
}
}
private class Subscription : IDisposable
{
private readonly ReceiveCallback _receiveCallback;
private readonly List<ReceiveCallback> _callbacks;
public Subscription(ReceiveCallback callback, List<ReceiveCallback> callbacks)
{
_receiveCallback = callback;
_callbacks = callbacks;
}
public void Dispose()
{
lock (_callbacks)
{
_callbacks.Remove(_receiveCallback);
}
}
}
private class ConnectionState
{
public const int Initial = 0;

View File

@ -0,0 +1,20 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets.Client
{
public static partial class HttpConnectionExtensions
{
public static IDisposable OnReceived(this HttpConnection connection, Func<byte[], Task> callback)
{
return connection.OnReceived((data, state) =>
{
var currentCallback = (Func<byte[], Task>)state;
return currentCallback(data);
}, callback);
}
}
}

View File

@ -150,8 +150,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
private static readonly Action<ILogger, DateTime, string, Exception> _stoppingClient =
LoggerMessage.Define<DateTime, string>(LogLevel.Information, new EventId(18, nameof(StoppingClient)), "{time}: Connection Id {connectionId}: Stopping client.");
private static readonly Action<ILogger, DateTime, string, string, Exception> _exceptionThrownFromHandler =
LoggerMessage.Define<DateTime, string, string>(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromHandler)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{eventHandlerName}' event handler.");
private static readonly Action<ILogger, DateTime, string, string, Exception> _exceptionThrownFromCallback =
LoggerMessage.Define<DateTime, string, string>(LogLevel.Error, new EventId(19, nameof(ExceptionThrownFromCallback)), "{time}: Connection Id {connectionId}: An exception was thrown from the '{callback}' callback");
public static void StartTransport(this ILogger logger, string connectionId, TransferMode transferMode)
@ -514,11 +514,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Internal
}
}
public static void ExceptionThrownFromHandler(this ILogger logger, string connectionId, string eventHandlerName, Exception exception)
public static void ExceptionThrownFromCallback(this ILogger logger, string connectionId, string callbackName, Exception exception)
{
if (logger.IsEnabled(LogLevel.Error))
{
_exceptionThrownFromHandler(logger, DateTime.Now, connectionId, eventHandlerName, exception);
_exceptionThrownFromCallback(logger, DateTime.Now, connectionId, callbackName, exception);
}
}
}

View File

@ -12,7 +12,6 @@ using Microsoft.AspNetCore.Sockets;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Testing;
using Newtonsoft.Json;
using Xunit;
using Xunit.Abstractions;
@ -164,6 +163,41 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task InvokeNonExistantClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path)
{
using (StartLog(out var loggerFactory))
{
var httpConnection = new HttpConnection(new Uri(_serverFixture.BaseUrl + path), transportType, loggerFactory);
var connection = new HubConnection(httpConnection, protocol, loggerFactory);
try
{
await connection.StartAsync().OrTimeout();
var closeTcs = new TaskCompletionSource<object>();
connection.Closed += ex =>
{
if (ex != null)
{
closeTcs.SetException(ex);
}
else
{
closeTcs.SetResult(null);
}
return Task.CompletedTask;
};
await connection.InvokeAsync("CallHandlerThatDoesntExist").OrTimeout();
await connection.DisposeAsync().OrTimeout();
await closeTcs.Task.OrTimeout();
}
finally
{
await connection.DisposeAsync().OrTimeout();
}
}
}
[Theory]
[MemberData(nameof(HubProtocolsAndTransportsAndHubPaths))]
public async Task CanStreamClientMethodFromServer(IHubProtocol protocol, TransportType transportType, string path)

View File

@ -24,6 +24,11 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
await Clients.Client(Context.ConnectionId).InvokeAsync("Echo", message);
}
public async Task CallHandlerThatDoesntExist()
{
await Clients.Client(Context.ConnectionId).InvokeAsync("NoClientHandler");
}
public IObservable<int> Stream(int count)
{
return Observable.Interval(TimeSpan.FromMilliseconds(1))
@ -60,6 +65,12 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
return Clients.All.Send(message);
}
public async Task CallHandlerThatDoesntExist()
{
await Clients.Client(Context.ConnectionId).NoClientHandler();
}
}
public class TestHubT : Hub<ITestHub>
@ -90,12 +101,18 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests
{
return Clients.All.Send(message);
}
public async Task CallHandlerThatDoesntExist()
{
await Clients.Client(Context.ConnectionId).NoClientHandler();
}
}
public interface ITestHub
{
Task Echo(string message);
Task Send(string message);
Task NoClientHandler();
}
}

View File

@ -255,7 +255,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
}
[Fact]
public async Task ReceivedEventNotRaisedAfterConnectionIsDisposed()
public async Task ReceivedCallbackNotRaisedAfterConnectionIsDisposed()
{
var mockHttpHandler = new Mock<HttpMessageHandler>();
mockHttpHandler.Protected()
@ -289,16 +289,16 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
var receivedInvoked = false;
connection.Received += m =>
var onReceivedInvoked = false;
connection.OnReceived( _ =>
{
receivedInvoked = true;
onReceivedInvoked = true;
return Task.CompletedTask;
};
});
await connection.StartAsync();
await connection.DisposeAsync();
Assert.False(receivedInvoked);
Assert.False(onReceivedInvoked);
}
[Fact]
@ -336,12 +336,11 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
connection.Received +=
async m =>
connection.OnReceived(_ =>
{
callbackInvokedTcs.SetResult(null);
await closedTcs.Task;
};
return closedTcs.Task;
});
await connection.StartAsync();
channel.Out.TryWrite(Array.Empty<byte>());
@ -392,11 +391,8 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var closedTcs = new TaskCompletionSource<object>();
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
connection.Received +=
async m =>
{
await blockReceiveCallbackTcs.Task;
};
connection.OnReceived(_ => blockReceiveCallbackTcs.Task);
connection.Closed += _ => {
closedTcs.SetResult(null);
return Task.CompletedTask;
@ -445,11 +441,10 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var closedTcs = new TaskCompletionSource<object>();
var connection = new HttpConnection(new Uri("http://fakeuri.org/"), new TestTransportFactory(mockTransport.Object), loggerFactory: null, httpMessageHandler: mockHttpHandler.Object);
connection.Received +=
m =>
connection.OnReceived( _ =>
{
throw new OperationCanceledException();
};
});
await connection.StartAsync();
channel.Out.TryWrite(Array.Empty<byte>());
@ -642,11 +637,12 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
try
{
var receiveTcs = new TaskCompletionSource<string>();
connection.Received += data =>
connection.OnReceived((data, state) =>
{
receiveTcs.TrySetResult(Encoding.UTF8.GetString(data));
var tcs = ((TaskCompletionSource<string>)state);
tcs.TrySetResult(Encoding.UTF8.GetString(data));
return Task.CompletedTask;
};
}, receiveTcs);
connection.Closed += e =>
{
@ -699,7 +695,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var receiveTcs = new TaskCompletionSource<string>();
var receivedRaised = false;
connection.Received += data =>
connection.OnReceived(data =>
{
if (!receivedRaised)
{
@ -709,7 +705,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
receiveTcs.TrySetResult(Encoding.UTF8.GetString(data));
return Task.CompletedTask;
};
});
connection.Closed += e =>
{
@ -762,7 +758,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
var receiveTcs = new TaskCompletionSource<string>();
var receivedRaised = false;
connection.Received += data =>
connection.OnReceived((data) =>
{
if (!receivedRaised)
{
@ -772,7 +768,7 @@ namespace Microsoft.AspNetCore.Sockets.Client.Tests
receiveTcs.TrySetResult(Encoding.UTF8.GetString(data));
return Task.CompletedTask;
};
});
connection.Closed += e =>
{

View File

@ -12,7 +12,6 @@ using Microsoft.AspNetCore.SignalR.Tests.Common;
using Microsoft.AspNetCore.Sockets.Client;
using Microsoft.Extensions.Logging;
using Moq;
using Newtonsoft.Json;
using Xunit;
namespace Microsoft.AspNetCore.SignalR.Client.Tests
@ -120,23 +119,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
Assert.Same(exception, thrown);
}
[Fact]
public async Task DoesNotThrowWhenClientMethodCalledButNoInvocationHandlerHasBeenSetUp()
{
var mockConnection = new Mock<IConnection>();
mockConnection.SetupGet(p => p.Features).Returns(new FeatureCollection());
var invocation = new InvocationMessage(Guid.NewGuid().ToString(), nonBlocking: true, target: "NonExistingMethod123", arguments: new object[] { true, "arg2", 123 });
var mockProtocol = MockHubProtocol.ReturnOnParse(invocation);
var hubConnection = new HubConnection(mockConnection.Object, mockProtocol, null);
await hubConnection.StartAsync();
mockConnection.Raise(c => c.Received += null, new object[] { new byte[] { } });
Assert.Equal(1, mockProtocol.ParseCalls);
}
// Moq really doesn't handle out parameters well, so to make these tests work I added a manual mock -anurse
private class MockHubProtocol : IHubProtocol
{

View File

@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Text;
using System.Threading;
@ -30,7 +31,6 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
private TransferMode? _transferMode;
public event Func<Task> Connected;
public event Func<byte[], Task> Received;
public event Func<Exception, Task> Closed;
public Task Started => _started.Task;
@ -38,6 +38,8 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
public ReadableChannel<byte[]> SentMessages => _sentMessages.In;
public WritableChannel<byte[]> ReceivedMessages => _receivedMessages.Out;
private readonly List<ReceiveCallback> _callbacks = new List<ReceiveCallback>();
public IFeatureCollection Features { get; } = new FeatureCollection();
public TestConnection(TransferMode? transferMode = null)
@ -120,7 +122,16 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
{
while (_receivedMessages.In.TryRead(out var message))
{
await Received?.Invoke(message);
ReceiveCallback[] callbackCopies;
lock (_callbacks)
{
callbackCopies = _callbacks.ToArray();
}
foreach (var callback in callbackCopies)
{
await callback.InvokeAsync(message);
}
}
}
}
@ -136,5 +147,51 @@ namespace Microsoft.AspNetCore.SignalR.Client.Tests
Closed?.Invoke(ex);
}
}
public IDisposable OnReceived(Func<byte[], object, Task> callback, object state)
{
var receiveCallBack = new ReceiveCallback(callback, state);
lock (_callbacks)
{
_callbacks.Add(receiveCallBack);
}
return new Subscription(receiveCallBack, _callbacks);
}
private class ReceiveCallback
{
private readonly Func<byte[], object, Task> _callback;
private readonly object _state;
public ReceiveCallback(Func<byte[], object, Task> callback, object state)
{
_callback = callback;
_state = state;
}
public Task InvokeAsync(byte[] data)
{
return _callback(data, _state);
}
}
private class Subscription : IDisposable
{
private readonly ReceiveCallback _callback;
private readonly List<ReceiveCallback> _callbacks;
public Subscription(ReceiveCallback callback, List<ReceiveCallback> callbacks)
{
_callback = callback;
_callbacks = callbacks;
}
public void Dispose()
{
lock (_callbacks)
{
_callbacks.Remove(_callback);
}
}
}
}
}

View File

@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
var receiveTcs = new TaskCompletionSource<string>();
var closeTcs = new TaskCompletionSource<object>();
connection.Received += data =>
connection.OnReceived((data, state) =>
{
logger.LogInformation("Received {length} byte message", data.Length);
@ -124,10 +124,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests
{
data = Convert.FromBase64String(Encoding.UTF8.GetString(data));
}
receiveTcs.TrySetResult(Encoding.UTF8.GetString(data));
var tcs = (TaskCompletionSource<string>)state;
tcs.TrySetResult(Encoding.UTF8.GetString(data));
return Task.CompletedTask;
};
}, receiveTcs);
connection.Closed += e =>
{
logger.LogInformation("Connection closed");
@ -224,12 +225,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests
try
{
var receiveTcs = new TaskCompletionSource<byte[]>();
connection.Received += data =>
connection.OnReceived((data, state) =>
{
logger.LogInformation("Received {length} byte message", data.Length);
receiveTcs.TrySetResult(data);
var tcs = (TaskCompletionSource<byte[]>)state;
tcs.TrySetResult(data);
return Task.CompletedTask;
};
}, receiveTcs);
logger.LogInformation("Starting connection to {url}", url);
await connection.StartAsync().OrTimeout();