Hubs are more fleshed out (#11)
- `HubEndPoint<T>` : `RpcEndPoint<T>` where T is the Hub type. Optimizing for a single hub per connection here. - Hubs get OnConnectedAsync and OnDisconnectedAsync methods that are invoked at the right time and with the right scope. - Introduced HubLifetimeManager<THub> (naming TBD) which is the center of the universe for Hub behaviors.
This commit is contained in:
parent
50e5827414
commit
53858495dc
|
|
@ -8,9 +8,12 @@ namespace SocketsSample
|
|||
{
|
||||
public class ChatEndPoint : EndPoint
|
||||
{
|
||||
public ConnectionList Connections { get; } = new ConnectionList();
|
||||
|
||||
public override async Task OnConnected(Connection connection)
|
||||
{
|
||||
Connections.Add(connection);
|
||||
|
||||
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
|
||||
|
||||
while (true)
|
||||
|
|
@ -33,6 +36,8 @@ namespace SocketsSample
|
|||
}
|
||||
}
|
||||
|
||||
Connections.Remove(connection);
|
||||
|
||||
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
using System;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using SocketsSample.EndPoints.Hubs;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
public class HubEndPoint<THub> : RpcEndpoint<THub>, IHubConnectionContext where THub : Hub
|
||||
{
|
||||
private readonly AllClientProxy<THub> _all;
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
|
||||
InvocationAdapterRegistry registry,
|
||||
ILoggerFactory loggerFactory,
|
||||
IServiceScopeFactory serviceScopeFactory)
|
||||
: base(registry, loggerFactory, serviceScopeFactory)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
_all = new AllClientProxy<THub>(_lifetimeManager);
|
||||
}
|
||||
|
||||
public virtual IClientProxy All => _all;
|
||||
|
||||
public virtual IClientProxy Client(string connectionId)
|
||||
{
|
||||
return new SingleClientProxy<THub>(_lifetimeManager, connectionId);
|
||||
}
|
||||
|
||||
public virtual IClientProxy Group(string groupName)
|
||||
{
|
||||
return new GroupProxy<THub>(_lifetimeManager, groupName);
|
||||
}
|
||||
|
||||
public virtual IClientProxy User(string userId)
|
||||
{
|
||||
return new UserProxy<THub>(_lifetimeManager, userId);
|
||||
}
|
||||
|
||||
public override async Task OnConnected(Connection connection)
|
||||
{
|
||||
try
|
||||
{
|
||||
await _lifetimeManager.OnConnectedAsync(connection);
|
||||
|
||||
using (var scope = _serviceScopeFactory.CreateScope())
|
||||
{
|
||||
var hub = scope.ServiceProvider.GetService<THub>() ?? Activator.CreateInstance<THub>();
|
||||
Initialize(connection, hub);
|
||||
await hub.OnConnectedAsync();
|
||||
}
|
||||
|
||||
await base.OnConnected(connection);
|
||||
}
|
||||
finally
|
||||
{
|
||||
using (var scope = _serviceScopeFactory.CreateScope())
|
||||
{
|
||||
var hub = scope.ServiceProvider.GetService<THub>() ?? Activator.CreateInstance<THub>();
|
||||
Initialize(connection, hub);
|
||||
await hub.OnDisconnectedAsync();
|
||||
}
|
||||
|
||||
await _lifetimeManager.OnDisconnectedAsync(connection);
|
||||
}
|
||||
}
|
||||
|
||||
protected override void BeforeInvoke(Connection connection, THub endpoint)
|
||||
{
|
||||
Initialize(connection, endpoint);
|
||||
}
|
||||
|
||||
private void Initialize(Connection connection, THub endpoint)
|
||||
{
|
||||
var hub = endpoint;
|
||||
hub.Clients = this;
|
||||
hub.Context = new HubCallerContext(connection);
|
||||
hub.Groups = new GroupManager<THub>(connection, _lifetimeManager);
|
||||
}
|
||||
|
||||
protected override void AfterInvoke(Connection connection, THub endpoint)
|
||||
{
|
||||
// Poison the hub make sure it can't be used after invocation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,117 +0,0 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
public class HubEndpoint : RpcEndpoint, IHubConnectionContext
|
||||
{
|
||||
private readonly ILogger<HubEndpoint> _logger;
|
||||
private readonly IServiceProvider _serviceProvider;
|
||||
|
||||
public HubEndpoint(ILogger<HubEndpoint> logger, ILogger<RpcEndpoint> jsonRpcLogger, IServiceProvider serviceProvider)
|
||||
: base(jsonRpcLogger, serviceProvider)
|
||||
{
|
||||
_logger = logger;
|
||||
_serviceProvider = serviceProvider;
|
||||
All = new AllClientProxy(this);
|
||||
}
|
||||
|
||||
public IClientProxy All { get; }
|
||||
|
||||
public IClientProxy Client(string connectionId)
|
||||
{
|
||||
return new SingleClientProxy(this, connectionId);
|
||||
}
|
||||
|
||||
protected override void Initialize(Connection connection, object endpoint)
|
||||
{
|
||||
var hub = (Hub)endpoint;
|
||||
hub.Clients = this;
|
||||
hub.Context = new HubCallerContext(connection.ConnectionId, connection.User);
|
||||
|
||||
base.Initialize(connection, endpoint);
|
||||
}
|
||||
|
||||
protected override void DiscoverEndpoints()
|
||||
{
|
||||
// Register the chat hub
|
||||
RegisterRPCEndPoint(typeof(Chat));
|
||||
}
|
||||
|
||||
private class AllClientProxy : IClientProxy
|
||||
{
|
||||
private readonly HubEndpoint _endPoint;
|
||||
|
||||
public AllClientProxy(HubEndpoint endPoint)
|
||||
{
|
||||
_endPoint = endPoint;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
// REVIEW: Thread safety
|
||||
var tasks = new List<Task>(_endPoint.Connections.Count);
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = method,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
// TODO: serialize once per format by providing a different stream?
|
||||
foreach (var connection in _endPoint.Connections)
|
||||
{
|
||||
|
||||
var invocationAdapter =
|
||||
_endPoint._serviceProvider
|
||||
.GetRequiredService<InvocationAdapterRegistry>()
|
||||
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
|
||||
tasks.Add(invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream()));
|
||||
}
|
||||
|
||||
return Task.WhenAll(tasks);
|
||||
}
|
||||
}
|
||||
|
||||
private class SingleClientProxy : IClientProxy
|
||||
{
|
||||
private readonly string _connectionId;
|
||||
private readonly HubEndpoint _endPoint;
|
||||
|
||||
public SingleClientProxy(HubEndpoint endPoint, string connectionId)
|
||||
{
|
||||
_endPoint = endPoint;
|
||||
_connectionId = connectionId;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
var connection = _endPoint.Connections[_connectionId];
|
||||
|
||||
var invocationAdapter =
|
||||
_endPoint._serviceProvider
|
||||
.GetRequiredService<InvocationAdapterRegistry>()
|
||||
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
|
||||
if (_endPoint._logger.IsEnabled(LogLevel.Debug))
|
||||
{
|
||||
_endPoint._logger.LogDebug("Outgoing RPC invocation method '{methodName}'", method);
|
||||
}
|
||||
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = method,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
||||
namespace SocketsSample.EndPoints.Hubs
|
||||
{
|
||||
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
|
||||
{
|
||||
private readonly ConnectionList _connections = new ConnectionList();
|
||||
private readonly InvocationAdapterRegistry _registry;
|
||||
|
||||
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
|
||||
{
|
||||
_registry = registry;
|
||||
}
|
||||
|
||||
public override void AddGroup(Connection connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.GetOrAdd("groups", k => new HashSet<string>());
|
||||
|
||||
lock (groups)
|
||||
{
|
||||
groups.Add(groupName);
|
||||
}
|
||||
}
|
||||
|
||||
public override void RemoveGroup(Connection connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
||||
|
||||
lock (groups)
|
||||
{
|
||||
groups.Remove(groupName);
|
||||
}
|
||||
}
|
||||
|
||||
public override Task InvokeAll(string methodName, params object[] args)
|
||||
{
|
||||
return InvokeAllWhere(methodName, args, c => true);
|
||||
}
|
||||
|
||||
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
|
||||
{
|
||||
var tasks = new List<Task>(_connections.Count);
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
// TODO: serialize once per format by providing a different stream?
|
||||
foreach (var connection in _connections)
|
||||
{
|
||||
if (!include(connection))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
|
||||
tasks.Add(invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream()));
|
||||
}
|
||||
|
||||
return Task.WhenAll(tasks);
|
||||
}
|
||||
|
||||
public override Task InvokeConnection(string connectionId, string methodName, params object[] args)
|
||||
{
|
||||
var connection = _connections[connectionId];
|
||||
|
||||
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream());
|
||||
}
|
||||
|
||||
public override Task InvokeGroup(string groupName, string methodName, params object[] args)
|
||||
{
|
||||
return InvokeAllWhere(methodName, args, connection =>
|
||||
{
|
||||
var groups = connection.Metadata.Get<HashSet<string>>("groups");
|
||||
return groups?.Contains(groupName) == true;
|
||||
});
|
||||
}
|
||||
|
||||
public override Task InvokeUser(string userId, string methodName, params object[] args)
|
||||
{
|
||||
return InvokeAllWhere(methodName, args, connection =>
|
||||
{
|
||||
return connection.User.Identity.Name == userId;
|
||||
});
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
_connections.Add(connection);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task OnDisconnectedAsync(Connection connection)
|
||||
{
|
||||
_connections.Remove(connection);
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
using System.Security.Claims;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
||||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public class HubCallerContext
|
||||
{
|
||||
public HubCallerContext(Connection connection)
|
||||
{
|
||||
ConnectionId = connection.ConnectionId;
|
||||
User = connection.User;
|
||||
Connection = connection;
|
||||
}
|
||||
|
||||
public Connection Connection { get; }
|
||||
|
||||
public ClaimsPrincipal User { get; }
|
||||
|
||||
public string ConnectionId { get; }
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,25 @@
|
|||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
|
||||
namespace SocketsSample.EndPoints.Hubs
|
||||
{
|
||||
public abstract class HubLifetimeManager<THub>
|
||||
{
|
||||
public abstract Task OnConnectedAsync(Connection connection);
|
||||
|
||||
public abstract Task OnDisconnectedAsync(Connection connection);
|
||||
|
||||
public abstract Task InvokeAll(string methodName, params object[] args);
|
||||
|
||||
public abstract Task InvokeConnection(string connectionId, string methodName, params object[] args);
|
||||
|
||||
public abstract Task InvokeGroup(string groupName, string methodName, params object[] args);
|
||||
|
||||
public abstract Task InvokeUser(string userId, string methodName, params object[] args);
|
||||
|
||||
public abstract void AddGroup(Connection connection, string groupName);
|
||||
|
||||
public abstract void RemoveGroup(Connection connection, string groupName);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
using System.Threading.Tasks;
|
||||
|
||||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public interface IClientProxy
|
||||
{
|
||||
/// <summary>
|
||||
/// Invokes a method on the connection(s) represented by the <see cref="IClientProxy"/> instance.
|
||||
/// </summary>
|
||||
/// <param name="method">name of the method to invoke</param>
|
||||
/// <param name="args">argumetns to pass to the client</param>
|
||||
/// <returns>A task that represents when the data has been sent to the client.</returns>
|
||||
Task Invoke(string method, params object[] args);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public interface IGroupManager
|
||||
{
|
||||
void Add(string groupName);
|
||||
void Remove(string groupName);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,13 @@
|
|||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public interface IHubConnectionContext
|
||||
{
|
||||
IClientProxy All { get; }
|
||||
|
||||
IClientProxy Client(string connectionId);
|
||||
|
||||
IClientProxy Group(string groupName);
|
||||
|
||||
IClientProxy User(string userId);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading.Tasks;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample.EndPoints.Hubs
|
||||
{
|
||||
public class UserProxy<THub> : IClientProxy
|
||||
{
|
||||
private readonly string _userId;
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
public UserProxy(HubLifetimeManager<THub> lifetimeManager, string userId)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
_userId = userId;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
return _lifetimeManager.InvokeUser(_userId, method, args);
|
||||
}
|
||||
}
|
||||
|
||||
public class GroupProxy<THub> : IClientProxy
|
||||
{
|
||||
private readonly string _groupName;
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
public GroupProxy(HubLifetimeManager<THub> lifetimeManager, string groupName)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
_groupName = groupName;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
return _lifetimeManager.InvokeGroup(_groupName, method, args);
|
||||
}
|
||||
}
|
||||
|
||||
public class AllClientProxy<THub> : IClientProxy
|
||||
{
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
public AllClientProxy(HubLifetimeManager<THub> lifetimeManager)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
// TODO: More than just chat
|
||||
return _lifetimeManager.InvokeAll(method, args);
|
||||
}
|
||||
}
|
||||
|
||||
public class SingleClientProxy<THub> : IClientProxy
|
||||
{
|
||||
private readonly string _connectionId;
|
||||
private readonly HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
|
||||
public SingleClientProxy(HubLifetimeManager<THub> lifetimeManager, string connectionId)
|
||||
{
|
||||
_lifetimeManager = lifetimeManager;
|
||||
_connectionId = connectionId;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
return _lifetimeManager.InvokeConnection(_connectionId, method, args);
|
||||
}
|
||||
}
|
||||
|
||||
public class GroupManager<THub> : IGroupManager
|
||||
{
|
||||
private readonly Connection _connection;
|
||||
private HubLifetimeManager<THub> _lifetimeManager;
|
||||
|
||||
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
|
||||
{
|
||||
_connection = connection;
|
||||
_lifetimeManager = lifetimeManager;
|
||||
}
|
||||
|
||||
public void Add(string groupName)
|
||||
{
|
||||
_lifetimeManager.AddGroup(_connection, groupName);
|
||||
}
|
||||
|
||||
public void Remove(string groupName)
|
||||
{
|
||||
_lifetimeManager.RemoveGroup(_connection, groupName);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample.EndPoints.Hubs
|
||||
{
|
||||
public class PubSubHubLifetimeManager<THub> : HubLifetimeManager<THub>
|
||||
{
|
||||
private readonly IPubSub _bus;
|
||||
private readonly InvocationAdapterRegistry _registry;
|
||||
|
||||
public PubSubHubLifetimeManager(IPubSub bus, InvocationAdapterRegistry registry)
|
||||
{
|
||||
_bus = bus;
|
||||
_registry = registry;
|
||||
}
|
||||
|
||||
public override Task InvokeAll(string methodName, params object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return _bus.Publish(typeof(THub).Name, message);
|
||||
}
|
||||
|
||||
public override Task InvokeConnection(string connectionId, string methodName, params object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return _bus.Publish(typeof(THub) + "." + connectionId, message);
|
||||
}
|
||||
|
||||
public override Task InvokeGroup(string groupName, string methodName, params object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return _bus.Publish(typeof(THub) + "." + groupName, message);
|
||||
}
|
||||
|
||||
public override Task InvokeUser(string userId, string methodName, params object[] args)
|
||||
{
|
||||
var message = new InvocationDescriptor
|
||||
{
|
||||
Method = methodName,
|
||||
Arguments = args
|
||||
};
|
||||
|
||||
return _bus.Publish(typeof(THub) + "." + userId, message);
|
||||
}
|
||||
|
||||
public override Task OnConnectedAsync(Connection connection)
|
||||
{
|
||||
var subs = connection.Metadata.GetOrAdd("subscriptions", k => new List<IDisposable>());
|
||||
|
||||
subs.Add(Subscribe(typeof(THub).Name, connection));
|
||||
subs.Add(Subscribe(typeof(THub).Name + "." + connection.ConnectionId, connection));
|
||||
subs.Add(Subscribe(typeof(THub).Name + "." + connection.User.Identity.Name, connection));
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override Task OnDisconnectedAsync(Connection connection)
|
||||
{
|
||||
var subs = connection.Metadata.Get<IList<IDisposable>>("subscriptions");
|
||||
|
||||
if (subs != null)
|
||||
{
|
||||
foreach (var sub in subs)
|
||||
{
|
||||
sub.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public override void AddGroup(Connection connection, string groupName)
|
||||
{
|
||||
var groups = connection.Metadata.GetOrAdd("groups", k => new ConcurrentDictionary<string, IDisposable>());
|
||||
var key = typeof(THub).Name + "." + groupName;
|
||||
groups.TryAdd(key, Subscribe(key, connection));
|
||||
}
|
||||
|
||||
public override void RemoveGroup(Connection connection, string groupName)
|
||||
{
|
||||
var key = typeof(THub) + "." + groupName;
|
||||
var groups = connection.Metadata.Get<ConcurrentDictionary<string, IDisposable>>("groups");
|
||||
|
||||
IDisposable subscription;
|
||||
if (groups != null && groups.TryRemove(key, out subscription))
|
||||
{
|
||||
subscription.Dispose();
|
||||
}
|
||||
}
|
||||
|
||||
private IDisposable Subscribe(string signal, Connection connection)
|
||||
{
|
||||
return _bus.Subscribe(signal, message =>
|
||||
{
|
||||
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
|
||||
return invocationAdapter.WriteInvocationDescriptor((InvocationDescriptor)message, connection.Channel.GetStream());
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.IO;
|
||||
using System.Linq;
|
||||
using System.Reflection;
|
||||
using System.Threading.Tasks;
|
||||
|
|
@ -8,34 +7,26 @@ using Channels;
|
|||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using SocketsSample.Protobuf;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
public class RpcEndpoint : EndPoint
|
||||
public class RpcEndpoint<T> : EndPoint where T : class
|
||||
{
|
||||
private readonly Dictionary<string, Func<Connection, InvocationDescriptor, InvocationResultDescriptor>> _callbacks
|
||||
= new Dictionary<string, Func<Connection, InvocationDescriptor, InvocationResultDescriptor>>(StringComparer.OrdinalIgnoreCase);
|
||||
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
|
||||
|
||||
private readonly ILogger<RpcEndpoint> _logger;
|
||||
private readonly IServiceProvider _serviceProvider;
|
||||
private readonly ILogger _logger;
|
||||
private readonly InvocationAdapterRegistry _registry;
|
||||
protected readonly IServiceScopeFactory _serviceScopeFactory;
|
||||
|
||||
|
||||
public RpcEndpoint(ILogger<RpcEndpoint> logger, IServiceProvider serviceProvider)
|
||||
public RpcEndpoint(InvocationAdapterRegistry registry, ILoggerFactory loggerFactory, IServiceScopeFactory serviceScopeFactory)
|
||||
{
|
||||
// TODO: Discover end points
|
||||
_logger = logger;
|
||||
_serviceProvider = serviceProvider;
|
||||
_logger = loggerFactory.CreateLogger<RpcEndpoint<T>>();
|
||||
_registry = registry;
|
||||
_serviceScopeFactory = serviceScopeFactory;
|
||||
|
||||
DiscoverEndpoints();
|
||||
}
|
||||
|
||||
protected virtual void DiscoverEndpoints()
|
||||
{
|
||||
RegisterRPCEndPoint(typeof(Echo));
|
||||
RegisterRPCEndPoint();
|
||||
}
|
||||
|
||||
public override async Task OnConnected(Connection connection)
|
||||
|
|
@ -44,16 +35,14 @@ namespace SocketsSample
|
|||
await Task.Yield();
|
||||
|
||||
var stream = connection.Channel.GetStream();
|
||||
var invocationAdapter =
|
||||
_serviceProvider
|
||||
.GetRequiredService<InvocationAdapterRegistry>()
|
||||
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
|
||||
|
||||
while (true)
|
||||
{
|
||||
var invocationDescriptor =
|
||||
await invocationAdapter.ReadInvocationDescriptor(
|
||||
stream, methodName => {
|
||||
stream, methodName =>
|
||||
{
|
||||
Type[] types;
|
||||
// TODO: null or throw?
|
||||
return _paramTypes.TryGetValue(methodName, out types) ? types : null;
|
||||
|
|
@ -90,12 +79,19 @@ namespace SocketsSample
|
|||
}
|
||||
}
|
||||
|
||||
protected virtual void Initialize(Connection connection, object endpoint)
|
||||
protected virtual void BeforeInvoke(Connection connection, T endpoint)
|
||||
{
|
||||
}
|
||||
|
||||
protected void RegisterRPCEndPoint(Type type)
|
||||
protected virtual void AfterInvoke(Connection connection, T endpoint)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
protected void RegisterRPCEndPoint()
|
||||
{
|
||||
var type = typeof(T);
|
||||
|
||||
foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic))
|
||||
{
|
||||
var methodName = type.FullName + "." + methodInfo.Name;
|
||||
|
|
@ -115,21 +111,22 @@ namespace SocketsSample
|
|||
|
||||
_callbacks[methodName] = (connection, invocationDescriptor) =>
|
||||
{
|
||||
var invocationResult = new InvocationResultDescriptor();
|
||||
invocationResult.Id = invocationDescriptor.Id;
|
||||
|
||||
var scopeFactory = _serviceProvider.GetRequiredService<IServiceScopeFactory>();
|
||||
|
||||
// Scope per call so that deps injected get disposed
|
||||
using (var scope = scopeFactory.CreateScope())
|
||||
var invocationResult = new InvocationResultDescriptor()
|
||||
{
|
||||
object value = scope.ServiceProvider.GetService(type) ?? Activator.CreateInstance(type);
|
||||
Id = invocationDescriptor.Id
|
||||
};
|
||||
|
||||
Initialize(connection, value);
|
||||
using (var scope = _serviceScopeFactory.CreateScope())
|
||||
{
|
||||
var value = scope.ServiceProvider.GetService<T>() ?? Activator.CreateInstance<T>();
|
||||
|
||||
BeforeInvoke(connection, value);
|
||||
|
||||
try
|
||||
{
|
||||
var args = invocationDescriptor.Arguments
|
||||
var arguments = invocationDescriptor.Arguments ?? Array.Empty<object>();
|
||||
|
||||
var args = arguments
|
||||
.Zip(parameters, (a, p) => Convert.ChangeType(a, p.ParameterType))
|
||||
.ToArray();
|
||||
|
||||
|
|
@ -137,12 +134,18 @@ namespace SocketsSample
|
|||
}
|
||||
catch (TargetInvocationException ex)
|
||||
{
|
||||
_logger.LogError(0, ex, "Failed to invoke RPC method");
|
||||
invocationResult.Error = ex.InnerException.Message;
|
||||
}
|
||||
catch (Exception ex)
|
||||
{
|
||||
_logger.LogError(0, ex, "Failed to invoke RPC method");
|
||||
invocationResult.Error = ex.Message;
|
||||
}
|
||||
finally
|
||||
{
|
||||
AfterInvoke(connection, value);
|
||||
}
|
||||
}
|
||||
|
||||
return invocationResult;
|
||||
|
|
|
|||
|
|
@ -1,20 +1,23 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public class Chat : Hub
|
||||
{
|
||||
public void Send(string message)
|
||||
public override async Task OnConnectedAsync()
|
||||
{
|
||||
Clients.All.Invoke("Send", message);
|
||||
await Clients.All.Invoke("Send", Context.Connection.ConnectionId + " joined");
|
||||
}
|
||||
|
||||
public Person EchoPerson(Person p)
|
||||
public override async Task OnDisconnectedAsync()
|
||||
{
|
||||
return p;
|
||||
await Clients.All.Invoke("Send", Context.Connection.ConnectionId + " left");
|
||||
}
|
||||
|
||||
public Task Send(string message)
|
||||
{
|
||||
return Clients.All.Invoke("Send", Context.ConnectionId + ": " + message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,44 +1,23 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Security.Claims;
|
||||
using System.Threading.Tasks;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public class Hub
|
||||
{
|
||||
public virtual Task OnConnectedAsync()
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public virtual Task OnDisconnectedAsync()
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
}
|
||||
|
||||
public IHubConnectionContext Clients { get; set; }
|
||||
|
||||
public HubCallerContext Context { get; set; }
|
||||
}
|
||||
|
||||
public interface IHubConnectionContext
|
||||
{
|
||||
IClientProxy All { get; }
|
||||
|
||||
IClientProxy Client(string connectionId);
|
||||
}
|
||||
|
||||
public interface IClientProxy
|
||||
{
|
||||
/// <summary>
|
||||
/// Invokes a method on the connection(s) represented by the <see cref="IClientProxy"/> instance.
|
||||
/// </summary>
|
||||
/// <param name="method">name of the method to invoke</param>
|
||||
/// <param name="args">argumetns to pass to the client</param>
|
||||
/// <returns>A task that represents when the data has been sent to the client.</returns>
|
||||
Task Invoke(string method, params object[] args);
|
||||
}
|
||||
|
||||
public class HubCallerContext
|
||||
{
|
||||
public HubCallerContext(string connectionId, ClaimsPrincipal user)
|
||||
{
|
||||
ConnectionId = connectionId;
|
||||
User = user;
|
||||
}
|
||||
|
||||
public ClaimsPrincipal User { get; }
|
||||
|
||||
public string ConnectionId { get; }
|
||||
public IGroupManager Groups { get; set; }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,57 @@
|
|||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace SocketsSample.Hubs
|
||||
{
|
||||
public interface IPubSub
|
||||
{
|
||||
IDisposable Subscribe(string topic, Func<object, Task> callback);
|
||||
Task Publish(string topic, object data);
|
||||
}
|
||||
|
||||
public class Bus : IPubSub
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, List<Func<object, Task>>> _subscriptions = new ConcurrentDictionary<string, List<Func<object, Task>>>();
|
||||
|
||||
public IDisposable Subscribe(string key, Func<object, Task> observer)
|
||||
{
|
||||
var subscriptions = _subscriptions.GetOrAdd(key, _ => new List<Func<object, Task>>());
|
||||
subscriptions.Add(observer);
|
||||
|
||||
return new DisposableAction(() =>
|
||||
{
|
||||
subscriptions.Remove(observer);
|
||||
});
|
||||
}
|
||||
|
||||
public async Task Publish(string key, object data)
|
||||
{
|
||||
List<Func<object, Task>> subscriptions;
|
||||
if (_subscriptions.TryGetValue(key, out subscriptions))
|
||||
{
|
||||
foreach (var c in subscriptions)
|
||||
{
|
||||
await c(data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class DisposableAction : IDisposable
|
||||
{
|
||||
private Action _action;
|
||||
|
||||
public DisposableAction(Action action)
|
||||
{
|
||||
_action = action;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
Interlocked.Exchange(ref _action, () => { }).Invoke();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,7 +1,4 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading.Tasks;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
|
|
@ -12,5 +9,10 @@ namespace SocketsSample
|
|||
public string Method { get; set; }
|
||||
|
||||
public object[] Arguments { get; set; }
|
||||
|
||||
public override string ToString()
|
||||
{
|
||||
return $"{Id}: {Method}({(Arguments ?? new object[0]).Length})";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,9 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Builder;
|
||||
using Microsoft.AspNetCore.Hosting;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.DependencyInjection;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using SocketsSample.EndPoints.Hubs;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
|
|
@ -16,9 +15,13 @@ namespace SocketsSample
|
|||
{
|
||||
services.AddRouting();
|
||||
|
||||
services.AddSingleton<HubEndpoint>();
|
||||
services.AddSingleton<RpcEndpoint>();
|
||||
services.AddSingleton<IPubSub, Bus>();
|
||||
services.AddSingleton(typeof(HubLifetimeManager<>), typeof(PubSubHubLifetimeManager<>));
|
||||
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
|
||||
services.AddSingleton(typeof(RpcEndpoint<>), typeof(RpcEndpoint<>));
|
||||
|
||||
services.AddSingleton<ChatEndPoint>();
|
||||
services.AddSingleton<Chat>();
|
||||
|
||||
services.AddSingleton<ProtobufSerializer>();
|
||||
services.AddSingleton<InvocationAdapterRegistry>();
|
||||
|
|
@ -36,12 +39,11 @@ namespace SocketsSample
|
|||
app.UseDeveloperExceptionPage();
|
||||
}
|
||||
|
||||
|
||||
app.UseSockets(routes =>
|
||||
{
|
||||
routes.MapSocketEndpoint<HubEndpoint>("/hubs");
|
||||
routes.MapSocketEndpoint<HubEndPoint<Chat>>("/hubs");
|
||||
routes.MapSocketEndpoint<ChatEndPoint>("/chat");
|
||||
routes.MapSocketEndpoint<RpcEndpoint>("/jsonrpc");
|
||||
routes.MapSocketEndpoint<RpcEndpoint<Echo>>("/jsonrpc");
|
||||
});
|
||||
|
||||
app.UseRpc(invocationAdapters =>
|
||||
|
|
|
|||
|
|
@ -42,11 +42,11 @@
|
|||
|
||||
delete calls[response.Id];
|
||||
|
||||
if (response.error) {
|
||||
cb.error(response.error);
|
||||
if (response.Error) {
|
||||
cb.error(response.Error);
|
||||
}
|
||||
else {
|
||||
cb.success(response.result);
|
||||
cb.success(response.Result);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
|
|
|||
|
|
@ -16,15 +16,15 @@
|
|||
ws.onmessage = function (event) {
|
||||
var response = JSON.parse(event.data);
|
||||
|
||||
var cb = calls[response.id];
|
||||
var cb = calls[response.Id];
|
||||
|
||||
delete calls[response.id];
|
||||
delete calls[response.Id];
|
||||
|
||||
if (response.error) {
|
||||
cb.error(response.error);
|
||||
if (response.Error) {
|
||||
cb.error(response.Error);
|
||||
}
|
||||
else {
|
||||
cb.success(response.result);
|
||||
cb.success(response.Result);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -35,7 +35,7 @@
|
|||
this.invoke = function (method, args) {
|
||||
return new Promise((resolve, reject) => {
|
||||
calls[id] = { success: resolve, error: reject };
|
||||
ws.send(JSON.stringify({ method: method, params: args, id: id }));
|
||||
ws.send(JSON.stringify({ method: method, arguments: args, id: id }));
|
||||
id++;
|
||||
});
|
||||
};
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
|
||||
using System.Collections.Generic;
|
||||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class ConnectionMetadata
|
||||
{
|
||||
private IDictionary<string, object> _metadata = new Dictionary<string, object>();
|
||||
private ConcurrentDictionary<string, object> _metadata = new ConcurrentDictionary<string, object>();
|
||||
|
||||
public Format Format { get; set; } = Format.Text;
|
||||
|
||||
|
|
@ -13,12 +14,24 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
get
|
||||
{
|
||||
return _metadata[key];
|
||||
object value;
|
||||
_metadata.TryGetValue(key, out value);
|
||||
return value;
|
||||
}
|
||||
set
|
||||
{
|
||||
_metadata[key] = value;
|
||||
}
|
||||
}
|
||||
|
||||
public T GetOrAdd<T>(string key, Func<string, T> factory)
|
||||
{
|
||||
return (T)_metadata.GetOrAdd(key, k => factory(k));
|
||||
}
|
||||
|
||||
public T Get<T>(string key)
|
||||
{
|
||||
return (T)this[key];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,11 +7,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
/// </summary>
|
||||
public abstract class EndPoint
|
||||
{
|
||||
/// <summary>
|
||||
/// Live list of connections for this <see cref="EndPoint"/>
|
||||
/// </summary>
|
||||
public ConnectionList Connections { get; } = new ConnectionList();
|
||||
|
||||
/// <summary>
|
||||
/// Called when a new connection is accepted to the endpoint
|
||||
/// </summary>
|
||||
|
|
|
|||
|
|
@ -94,12 +94,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
state.Connection.Channel.Dispose();
|
||||
|
||||
await endpointTask;
|
||||
|
||||
endpoint.Connections.Remove(state.Connection);
|
||||
};
|
||||
|
||||
endpoint.Connections.Add(state.Connection);
|
||||
|
||||
endpointTask = endpoint.OnConnected(state.Connection);
|
||||
state.Connection.Metadata["endpoint"] = endpointTask;
|
||||
}
|
||||
|
|
@ -124,8 +120,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
state.Connection.Channel.Dispose();
|
||||
|
||||
await transportTask;
|
||||
|
||||
endpoint.Connections.Remove(state.Connection);
|
||||
}
|
||||
|
||||
// Mark the connection as inactive
|
||||
|
|
@ -143,8 +137,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, connection);
|
||||
|
||||
endpoint.Connections.Add(connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connection);
|
||||
|
||||
|
|
@ -159,8 +151,6 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
|
||||
// Wait for both
|
||||
await Task.WhenAll(endpointTask, transportTask);
|
||||
|
||||
endpoint.Connections.Remove(connection);
|
||||
}
|
||||
|
||||
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)
|
||||
|
|
|
|||
Loading…
Reference in New Issue