Hubs are more fleshed out (#11)

- `HubEndPoint<T>` : `RpcEndPoint<T>` where T is the Hub type. Optimizing for a single hub per connection here.
- Hubs get OnConnectedAsync and OnDisconnectedAsync methods that are invoked at the right time and with the right scope.
- Introduced HubLifetimeManager<THub> (naming TBD) which is the center of the universe for Hub behaviors.
This commit is contained in:
David Fowler 2016-11-01 23:15:31 -07:00 committed by GitHub
parent 50e5827414
commit 53858495dc
22 changed files with 670 additions and 231 deletions

View File

@ -8,9 +8,12 @@ namespace SocketsSample
{
public class ChatEndPoint : EndPoint
{
public ConnectionList Connections { get; } = new ConnectionList();
public override async Task OnConnected(Connection connection)
{
Connections.Add(connection);
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
while (true)
@ -33,6 +36,8 @@ namespace SocketsSample
}
}
Connections.Remove(connection);
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
}

View File

@ -0,0 +1,89 @@
using System;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.EndPoints.Hubs;
using SocketsSample.Hubs;
namespace SocketsSample
{
public class HubEndPoint<THub> : RpcEndpoint<THub>, IHubConnectionContext where THub : Hub
{
private readonly AllClientProxy<THub> _all;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public HubEndPoint(HubLifetimeManager<THub> lifetimeManager,
InvocationAdapterRegistry registry,
ILoggerFactory loggerFactory,
IServiceScopeFactory serviceScopeFactory)
: base(registry, loggerFactory, serviceScopeFactory)
{
_lifetimeManager = lifetimeManager;
_all = new AllClientProxy<THub>(_lifetimeManager);
}
public virtual IClientProxy All => _all;
public virtual IClientProxy Client(string connectionId)
{
return new SingleClientProxy<THub>(_lifetimeManager, connectionId);
}
public virtual IClientProxy Group(string groupName)
{
return new GroupProxy<THub>(_lifetimeManager, groupName);
}
public virtual IClientProxy User(string userId)
{
return new UserProxy<THub>(_lifetimeManager, userId);
}
public override async Task OnConnected(Connection connection)
{
try
{
await _lifetimeManager.OnConnectedAsync(connection);
using (var scope = _serviceScopeFactory.CreateScope())
{
var hub = scope.ServiceProvider.GetService<THub>() ?? Activator.CreateInstance<THub>();
Initialize(connection, hub);
await hub.OnConnectedAsync();
}
await base.OnConnected(connection);
}
finally
{
using (var scope = _serviceScopeFactory.CreateScope())
{
var hub = scope.ServiceProvider.GetService<THub>() ?? Activator.CreateInstance<THub>();
Initialize(connection, hub);
await hub.OnDisconnectedAsync();
}
await _lifetimeManager.OnDisconnectedAsync(connection);
}
}
protected override void BeforeInvoke(Connection connection, THub endpoint)
{
Initialize(connection, endpoint);
}
private void Initialize(Connection connection, THub endpoint)
{
var hub = endpoint;
hub.Clients = this;
hub.Context = new HubCallerContext(connection);
hub.Groups = new GroupManager<THub>(connection, _lifetimeManager);
}
protected override void AfterInvoke(Connection connection, THub endpoint)
{
// Poison the hub make sure it can't be used after invocation
}
}
}

View File

@ -1,117 +0,0 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Channels;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.Hubs;
namespace SocketsSample
{
public class HubEndpoint : RpcEndpoint, IHubConnectionContext
{
private readonly ILogger<HubEndpoint> _logger;
private readonly IServiceProvider _serviceProvider;
public HubEndpoint(ILogger<HubEndpoint> logger, ILogger<RpcEndpoint> jsonRpcLogger, IServiceProvider serviceProvider)
: base(jsonRpcLogger, serviceProvider)
{
_logger = logger;
_serviceProvider = serviceProvider;
All = new AllClientProxy(this);
}
public IClientProxy All { get; }
public IClientProxy Client(string connectionId)
{
return new SingleClientProxy(this, connectionId);
}
protected override void Initialize(Connection connection, object endpoint)
{
var hub = (Hub)endpoint;
hub.Clients = this;
hub.Context = new HubCallerContext(connection.ConnectionId, connection.User);
base.Initialize(connection, endpoint);
}
protected override void DiscoverEndpoints()
{
// Register the chat hub
RegisterRPCEndPoint(typeof(Chat));
}
private class AllClientProxy : IClientProxy
{
private readonly HubEndpoint _endPoint;
public AllClientProxy(HubEndpoint endPoint)
{
_endPoint = endPoint;
}
public Task Invoke(string method, params object[] args)
{
// REVIEW: Thread safety
var tasks = new List<Task>(_endPoint.Connections.Count);
var message = new InvocationDescriptor
{
Method = method,
Arguments = args
};
// TODO: serialize once per format by providing a different stream?
foreach (var connection in _endPoint.Connections)
{
var invocationAdapter =
_endPoint._serviceProvider
.GetRequiredService<InvocationAdapterRegistry>()
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
tasks.Add(invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream()));
}
return Task.WhenAll(tasks);
}
}
private class SingleClientProxy : IClientProxy
{
private readonly string _connectionId;
private readonly HubEndpoint _endPoint;
public SingleClientProxy(HubEndpoint endPoint, string connectionId)
{
_endPoint = endPoint;
_connectionId = connectionId;
}
public Task Invoke(string method, params object[] args)
{
var connection = _endPoint.Connections[_connectionId];
var invocationAdapter =
_endPoint._serviceProvider
.GetRequiredService<InvocationAdapterRegistry>()
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
if (_endPoint._logger.IsEnabled(LogLevel.Debug))
{
_endPoint._logger.LogDebug("Outgoing RPC invocation method '{methodName}'", method);
}
var message = new InvocationDescriptor
{
Method = method,
Arguments = args
};
return invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream());
}
}
}
}

View File

@ -0,0 +1,115 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Channels;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.EndPoints.Hubs
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private readonly ConnectionList _connections = new ConnectionList();
private readonly InvocationAdapterRegistry _registry;
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
{
_registry = registry;
}
public override void AddGroup(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", k => new HashSet<string>());
lock (groups)
{
groups.Add(groupName);
}
}
public override void RemoveGroup(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
lock (groups)
{
groups.Remove(groupName);
}
}
public override Task InvokeAll(string methodName, params object[] args)
{
return InvokeAllWhere(methodName, args, c => true);
}
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
// TODO: serialize once per format by providing a different stream?
foreach (var connection in _connections)
{
if (!include(connection))
{
continue;
}
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
tasks.Add(invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream()));
}
return Task.WhenAll(tasks);
}
public override Task InvokeConnection(string connectionId, string methodName, params object[] args)
{
var connection = _connections[connectionId];
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return invocationAdapter.WriteInvocationDescriptor(message, connection.Channel.GetStream());
}
public override Task InvokeGroup(string groupName, string methodName, params object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
return groups?.Contains(groupName) == true;
});
}
public override Task InvokeUser(string userId, string methodName, params object[] args)
{
return InvokeAllWhere(methodName, args, connection =>
{
return connection.User.Identity.Name == userId;
});
}
public override Task OnConnectedAsync(Connection connection)
{
_connections.Add(connection);
return Task.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
return Task.CompletedTask;
}
}
}

View File

@ -0,0 +1,21 @@
using System.Security.Claims;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.Hubs
{
public class HubCallerContext
{
public HubCallerContext(Connection connection)
{
ConnectionId = connection.ConnectionId;
User = connection.User;
Connection = connection;
}
public Connection Connection { get; }
public ClaimsPrincipal User { get; }
public string ConnectionId { get; }
}
}

View File

@ -0,0 +1,25 @@
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.EndPoints.Hubs
{
public abstract class HubLifetimeManager<THub>
{
public abstract Task OnConnectedAsync(Connection connection);
public abstract Task OnDisconnectedAsync(Connection connection);
public abstract Task InvokeAll(string methodName, params object[] args);
public abstract Task InvokeConnection(string connectionId, string methodName, params object[] args);
public abstract Task InvokeGroup(string groupName, string methodName, params object[] args);
public abstract Task InvokeUser(string userId, string methodName, params object[] args);
public abstract void AddGroup(Connection connection, string groupName);
public abstract void RemoveGroup(Connection connection, string groupName);
}
}

View File

@ -0,0 +1,15 @@
using System.Threading.Tasks;
namespace SocketsSample.Hubs
{
public interface IClientProxy
{
/// <summary>
/// Invokes a method on the connection(s) represented by the <see cref="IClientProxy"/> instance.
/// </summary>
/// <param name="method">name of the method to invoke</param>
/// <param name="args">argumetns to pass to the client</param>
/// <returns>A task that represents when the data has been sent to the client.</returns>
Task Invoke(string method, params object[] args);
}
}

View File

@ -0,0 +1,8 @@
namespace SocketsSample.Hubs
{
public interface IGroupManager
{
void Add(string groupName);
void Remove(string groupName);
}
}

View File

@ -0,0 +1,13 @@
namespace SocketsSample.Hubs
{
public interface IHubConnectionContext
{
IClientProxy All { get; }
IClientProxy Client(string connectionId);
IClientProxy Group(string groupName);
IClientProxy User(string userId);
}
}

View File

@ -0,0 +1,98 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
using SocketsSample.Hubs;
namespace SocketsSample.EndPoints.Hubs
{
public class UserProxy<THub> : IClientProxy
{
private readonly string _userId;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public UserProxy(HubLifetimeManager<THub> lifetimeManager, string userId)
{
_lifetimeManager = lifetimeManager;
_userId = userId;
}
public Task Invoke(string method, params object[] args)
{
return _lifetimeManager.InvokeUser(_userId, method, args);
}
}
public class GroupProxy<THub> : IClientProxy
{
private readonly string _groupName;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public GroupProxy(HubLifetimeManager<THub> lifetimeManager, string groupName)
{
_lifetimeManager = lifetimeManager;
_groupName = groupName;
}
public Task Invoke(string method, params object[] args)
{
return _lifetimeManager.InvokeGroup(_groupName, method, args);
}
}
public class AllClientProxy<THub> : IClientProxy
{
private readonly HubLifetimeManager<THub> _lifetimeManager;
public AllClientProxy(HubLifetimeManager<THub> lifetimeManager)
{
_lifetimeManager = lifetimeManager;
}
public Task Invoke(string method, params object[] args)
{
// TODO: More than just chat
return _lifetimeManager.InvokeAll(method, args);
}
}
public class SingleClientProxy<THub> : IClientProxy
{
private readonly string _connectionId;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public SingleClientProxy(HubLifetimeManager<THub> lifetimeManager, string connectionId)
{
_lifetimeManager = lifetimeManager;
_connectionId = connectionId;
}
public Task Invoke(string method, params object[] args)
{
return _lifetimeManager.InvokeConnection(_connectionId, method, args);
}
}
public class GroupManager<THub> : IGroupManager
{
private readonly Connection _connection;
private HubLifetimeManager<THub> _lifetimeManager;
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
{
_connection = connection;
_lifetimeManager = lifetimeManager;
}
public void Add(string groupName)
{
_lifetimeManager.AddGroup(_connection, groupName);
}
public void Remove(string groupName)
{
_lifetimeManager.RemoveGroup(_connection, groupName);
}
}
}

View File

@ -0,0 +1,123 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Channels;
using Microsoft.AspNetCore.Sockets;
using SocketsSample.Hubs;
namespace SocketsSample.EndPoints.Hubs
{
public class PubSubHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private readonly IPubSub _bus;
private readonly InvocationAdapterRegistry _registry;
public PubSubHubLifetimeManager(IPubSub bus, InvocationAdapterRegistry registry)
{
_bus = bus;
_registry = registry;
}
public override Task InvokeAll(string methodName, params object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return _bus.Publish(typeof(THub).Name, message);
}
public override Task InvokeConnection(string connectionId, string methodName, params object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return _bus.Publish(typeof(THub) + "." + connectionId, message);
}
public override Task InvokeGroup(string groupName, string methodName, params object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return _bus.Publish(typeof(THub) + "." + groupName, message);
}
public override Task InvokeUser(string userId, string methodName, params object[] args)
{
var message = new InvocationDescriptor
{
Method = methodName,
Arguments = args
};
return _bus.Publish(typeof(THub) + "." + userId, message);
}
public override Task OnConnectedAsync(Connection connection)
{
var subs = connection.Metadata.GetOrAdd("subscriptions", k => new List<IDisposable>());
subs.Add(Subscribe(typeof(THub).Name, connection));
subs.Add(Subscribe(typeof(THub).Name + "." + connection.ConnectionId, connection));
subs.Add(Subscribe(typeof(THub).Name + "." + connection.User.Identity.Name, connection));
return Task.CompletedTask;
}
public override Task OnDisconnectedAsync(Connection connection)
{
var subs = connection.Metadata.Get<IList<IDisposable>>("subscriptions");
if (subs != null)
{
foreach (var sub in subs)
{
sub.Dispose();
}
}
return Task.CompletedTask;
}
public override void AddGroup(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", k => new ConcurrentDictionary<string, IDisposable>());
var key = typeof(THub).Name + "." + groupName;
groups.TryAdd(key, Subscribe(key, connection));
}
public override void RemoveGroup(Connection connection, string groupName)
{
var key = typeof(THub) + "." + groupName;
var groups = connection.Metadata.Get<ConcurrentDictionary<string, IDisposable>>("groups");
IDisposable subscription;
if (groups != null && groups.TryRemove(key, out subscription))
{
subscription.Dispose();
}
}
private IDisposable Subscribe(string signal, Connection connection)
{
return _bus.Subscribe(signal, message =>
{
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
return invocationAdapter.WriteInvocationDescriptor((InvocationDescriptor)message, connection.Channel.GetStream());
});
}
}
}

View File

@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;
@ -8,34 +7,26 @@ using Channels;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using SocketsSample.Protobuf;
namespace SocketsSample
{
public class RpcEndpoint : EndPoint
public class RpcEndpoint<T> : EndPoint where T : class
{
private readonly Dictionary<string, Func<Connection, InvocationDescriptor, InvocationResultDescriptor>> _callbacks
= new Dictionary<string, Func<Connection, InvocationDescriptor, InvocationResultDescriptor>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
private readonly ILogger<RpcEndpoint> _logger;
private readonly IServiceProvider _serviceProvider;
private readonly ILogger _logger;
private readonly InvocationAdapterRegistry _registry;
protected readonly IServiceScopeFactory _serviceScopeFactory;
public RpcEndpoint(ILogger<RpcEndpoint> logger, IServiceProvider serviceProvider)
public RpcEndpoint(InvocationAdapterRegistry registry, ILoggerFactory loggerFactory, IServiceScopeFactory serviceScopeFactory)
{
// TODO: Discover end points
_logger = logger;
_serviceProvider = serviceProvider;
_logger = loggerFactory.CreateLogger<RpcEndpoint<T>>();
_registry = registry;
_serviceScopeFactory = serviceScopeFactory;
DiscoverEndpoints();
}
protected virtual void DiscoverEndpoints()
{
RegisterRPCEndPoint(typeof(Echo));
RegisterRPCEndPoint();
}
public override async Task OnConnected(Connection connection)
@ -44,16 +35,14 @@ namespace SocketsSample
await Task.Yield();
var stream = connection.Channel.GetStream();
var invocationAdapter =
_serviceProvider
.GetRequiredService<InvocationAdapterRegistry>()
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
var invocationAdapter = _registry.GetInvocationAdapter((string)connection.Metadata["formatType"]);
while (true)
{
var invocationDescriptor =
await invocationAdapter.ReadInvocationDescriptor(
stream, methodName => {
stream, methodName =>
{
Type[] types;
// TODO: null or throw?
return _paramTypes.TryGetValue(methodName, out types) ? types : null;
@ -90,12 +79,19 @@ namespace SocketsSample
}
}
protected virtual void Initialize(Connection connection, object endpoint)
protected virtual void BeforeInvoke(Connection connection, T endpoint)
{
}
protected void RegisterRPCEndPoint(Type type)
protected virtual void AfterInvoke(Connection connection, T endpoint)
{
}
protected void RegisterRPCEndPoint()
{
var type = typeof(T);
foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic))
{
var methodName = type.FullName + "." + methodInfo.Name;
@ -115,21 +111,22 @@ namespace SocketsSample
_callbacks[methodName] = (connection, invocationDescriptor) =>
{
var invocationResult = new InvocationResultDescriptor();
invocationResult.Id = invocationDescriptor.Id;
var scopeFactory = _serviceProvider.GetRequiredService<IServiceScopeFactory>();
// Scope per call so that deps injected get disposed
using (var scope = scopeFactory.CreateScope())
var invocationResult = new InvocationResultDescriptor()
{
object value = scope.ServiceProvider.GetService(type) ?? Activator.CreateInstance(type);
Id = invocationDescriptor.Id
};
Initialize(connection, value);
using (var scope = _serviceScopeFactory.CreateScope())
{
var value = scope.ServiceProvider.GetService<T>() ?? Activator.CreateInstance<T>();
BeforeInvoke(connection, value);
try
{
var args = invocationDescriptor.Arguments
var arguments = invocationDescriptor.Arguments ?? Array.Empty<object>();
var args = arguments
.Zip(parameters, (a, p) => Convert.ChangeType(a, p.ParameterType))
.ToArray();
@ -137,12 +134,18 @@ namespace SocketsSample
}
catch (TargetInvocationException ex)
{
_logger.LogError(0, ex, "Failed to invoke RPC method");
invocationResult.Error = ex.InnerException.Message;
}
catch (Exception ex)
{
_logger.LogError(0, ex, "Failed to invoke RPC method");
invocationResult.Error = ex.Message;
}
finally
{
AfterInvoke(connection, value);
}
}
return invocationResult;

View File

@ -1,20 +1,23 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace SocketsSample.Hubs
{
public class Chat : Hub
{
public void Send(string message)
public override async Task OnConnectedAsync()
{
Clients.All.Invoke("Send", message);
await Clients.All.Invoke("Send", Context.Connection.ConnectionId + " joined");
}
public Person EchoPerson(Person p)
public override async Task OnDisconnectedAsync()
{
return p;
await Clients.All.Invoke("Send", Context.Connection.ConnectionId + " left");
}
public Task Send(string message)
{
return Clients.All.Invoke("Send", Context.ConnectionId + ": " + message);
}
}
}

View File

@ -1,44 +1,23 @@
using System;
using System.Collections.Generic;
using System.Security.Claims;
using System.Threading.Tasks;
using System.Threading.Tasks;
namespace SocketsSample.Hubs
{
public class Hub
{
public virtual Task OnConnectedAsync()
{
return Task.CompletedTask;
}
public virtual Task OnDisconnectedAsync()
{
return Task.CompletedTask;
}
public IHubConnectionContext Clients { get; set; }
public HubCallerContext Context { get; set; }
}
public interface IHubConnectionContext
{
IClientProxy All { get; }
IClientProxy Client(string connectionId);
}
public interface IClientProxy
{
/// <summary>
/// Invokes a method on the connection(s) represented by the <see cref="IClientProxy"/> instance.
/// </summary>
/// <param name="method">name of the method to invoke</param>
/// <param name="args">argumetns to pass to the client</param>
/// <returns>A task that represents when the data has been sent to the client.</returns>
Task Invoke(string method, params object[] args);
}
public class HubCallerContext
{
public HubCallerContext(string connectionId, ClaimsPrincipal user)
{
ConnectionId = connectionId;
User = user;
}
public ClaimsPrincipal User { get; }
public string ConnectionId { get; }
public IGroupManager Groups { get; set; }
}
}

View File

@ -0,0 +1,57 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
namespace SocketsSample.Hubs
{
public interface IPubSub
{
IDisposable Subscribe(string topic, Func<object, Task> callback);
Task Publish(string topic, object data);
}
public class Bus : IPubSub
{
private readonly ConcurrentDictionary<string, List<Func<object, Task>>> _subscriptions = new ConcurrentDictionary<string, List<Func<object, Task>>>();
public IDisposable Subscribe(string key, Func<object, Task> observer)
{
var subscriptions = _subscriptions.GetOrAdd(key, _ => new List<Func<object, Task>>());
subscriptions.Add(observer);
return new DisposableAction(() =>
{
subscriptions.Remove(observer);
});
}
public async Task Publish(string key, object data)
{
List<Func<object, Task>> subscriptions;
if (_subscriptions.TryGetValue(key, out subscriptions))
{
foreach (var c in subscriptions)
{
await c(data);
}
}
}
private class DisposableAction : IDisposable
{
private Action _action;
public DisposableAction(Action action)
{
_action = action;
}
public void Dispose()
{
Interlocked.Exchange(ref _action, () => { }).Invoke();
}
}
}
}

View File

@ -1,7 +1,4 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
namespace SocketsSample
{
@ -12,5 +9,10 @@ namespace SocketsSample
public string Method { get; set; }
public object[] Arguments { get; set; }
public override string ToString()
{
return $"{Id}: {Method}({(Arguments ?? new object[0]).Length})";
}
}
}

View File

@ -1,10 +1,9 @@
using System;
using System.Collections.Generic;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using SocketsSample.EndPoints.Hubs;
using SocketsSample.Hubs;
namespace SocketsSample
{
@ -16,9 +15,13 @@ namespace SocketsSample
{
services.AddRouting();
services.AddSingleton<HubEndpoint>();
services.AddSingleton<RpcEndpoint>();
services.AddSingleton<IPubSub, Bus>();
services.AddSingleton(typeof(HubLifetimeManager<>), typeof(PubSubHubLifetimeManager<>));
services.AddSingleton(typeof(HubEndPoint<>), typeof(HubEndPoint<>));
services.AddSingleton(typeof(RpcEndpoint<>), typeof(RpcEndpoint<>));
services.AddSingleton<ChatEndPoint>();
services.AddSingleton<Chat>();
services.AddSingleton<ProtobufSerializer>();
services.AddSingleton<InvocationAdapterRegistry>();
@ -36,12 +39,11 @@ namespace SocketsSample
app.UseDeveloperExceptionPage();
}
app.UseSockets(routes =>
{
routes.MapSocketEndpoint<HubEndpoint>("/hubs");
routes.MapSocketEndpoint<HubEndPoint<Chat>>("/hubs");
routes.MapSocketEndpoint<ChatEndPoint>("/chat");
routes.MapSocketEndpoint<RpcEndpoint>("/jsonrpc");
routes.MapSocketEndpoint<RpcEndpoint<Echo>>("/jsonrpc");
});
app.UseRpc(invocationAdapters =>

View File

@ -42,11 +42,11 @@
delete calls[response.Id];
if (response.error) {
cb.error(response.error);
if (response.Error) {
cb.error(response.Error);
}
else {
cb.success(response.result);
cb.success(response.Result);
}
}
else {

View File

@ -16,15 +16,15 @@
ws.onmessage = function (event) {
var response = JSON.parse(event.data);
var cb = calls[response.id];
var cb = calls[response.Id];
delete calls[response.id];
delete calls[response.Id];
if (response.error) {
cb.error(response.error);
if (response.Error) {
cb.error(response.Error);
}
else {
cb.success(response.result);
cb.success(response.Result);
}
};
@ -35,7 +35,7 @@
this.invoke = function (method, args) {
return new Promise((resolve, reject) => {
calls[id] = { success: resolve, error: reject };
ws.send(JSON.stringify({ method: method, params: args, id: id }));
ws.send(JSON.stringify({ method: method, arguments: args, id: id }));
id++;
});
};

View File

@ -1,11 +1,12 @@

using System.Collections.Generic;
using System;
using System.Collections.Concurrent;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionMetadata
{
private IDictionary<string, object> _metadata = new Dictionary<string, object>();
private ConcurrentDictionary<string, object> _metadata = new ConcurrentDictionary<string, object>();
public Format Format { get; set; } = Format.Text;
@ -13,12 +14,24 @@ namespace Microsoft.AspNetCore.Sockets
{
get
{
return _metadata[key];
object value;
_metadata.TryGetValue(key, out value);
return value;
}
set
{
_metadata[key] = value;
}
}
public T GetOrAdd<T>(string key, Func<string, T> factory)
{
return (T)_metadata.GetOrAdd(key, k => factory(k));
}
public T Get<T>(string key)
{
return (T)this[key];
}
}
}

View File

@ -7,11 +7,6 @@ namespace Microsoft.AspNetCore.Sockets
/// </summary>
public abstract class EndPoint
{
/// <summary>
/// Live list of connections for this <see cref="EndPoint"/>
/// </summary>
public ConnectionList Connections { get; } = new ConnectionList();
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>

View File

@ -94,12 +94,8 @@ namespace Microsoft.AspNetCore.Sockets
state.Connection.Channel.Dispose();
await endpointTask;
endpoint.Connections.Remove(state.Connection);
};
endpoint.Connections.Add(state.Connection);
endpointTask = endpoint.OnConnected(state.Connection);
state.Connection.Metadata["endpoint"] = endpointTask;
}
@ -124,8 +120,6 @@ namespace Microsoft.AspNetCore.Sockets
state.Connection.Channel.Dispose();
await transportTask;
endpoint.Connections.Remove(state.Connection);
}
// Mark the connection as inactive
@ -143,8 +137,6 @@ namespace Microsoft.AspNetCore.Sockets
// Register this transport for disconnect
RegisterDisconnect(context, connection);
endpoint.Connections.Add(connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connection);
@ -159,8 +151,6 @@ namespace Microsoft.AspNetCore.Sockets
// Wait for both
await Task.WhenAll(endpointTask, transportTask);
endpoint.Connections.Remove(connection);
}
private static void RegisterLongPollingDisconnect(HttpContext context, Connection connection)