Added Connection property to EndPoint
- Exposes a list of connections for user code to act on - The connection list is thread safe (uses a concurrent dictionary under the hood) - Removed the Bus and just used the connection list in the samples
This commit is contained in:
parent
4cd4ddfad5
commit
1647432ef6
|
|
@ -1,71 +0,0 @@
|
|||
using System;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Threading;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class Message
|
||||
{
|
||||
public byte[] Payload { get; set; }
|
||||
}
|
||||
|
||||
public class Bus
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, List<Func<Message, Task>>> _subscriptions = new ConcurrentDictionary<string, List<Func<Message, Task>>>();
|
||||
|
||||
public IDisposable Subscribe(string key, Func<Message, Task> observer)
|
||||
{
|
||||
var connections = _subscriptions.GetOrAdd(key, _ => new List<Func<Message, Task>>());
|
||||
lock (connections)
|
||||
{
|
||||
connections.Add(observer);
|
||||
}
|
||||
|
||||
return new DisposableAction(() =>
|
||||
{
|
||||
lock (connections)
|
||||
{
|
||||
connections.Remove(observer);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public async Task Publish(string key, Message message)
|
||||
{
|
||||
List<Func<Message, Task>> connections;
|
||||
if (_subscriptions.TryGetValue(key, out connections))
|
||||
{
|
||||
Task[] tasks = null;
|
||||
lock (connections)
|
||||
{
|
||||
tasks = new Task[connections.Count];
|
||||
for (int i = 0; i < connections.Count; i++)
|
||||
{
|
||||
tasks[i] = connections[i](message);
|
||||
}
|
||||
}
|
||||
|
||||
await Task.WhenAll(tasks);
|
||||
}
|
||||
}
|
||||
|
||||
private class DisposableAction : IDisposable
|
||||
{
|
||||
private Action _action;
|
||||
|
||||
public DisposableAction(Action action)
|
||||
{
|
||||
_action = action;
|
||||
}
|
||||
|
||||
public void Dispose()
|
||||
{
|
||||
Interlocked.Exchange(ref _action, () => { }).Invoke();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -10,51 +10,48 @@ namespace SocketsSample
|
|||
{
|
||||
public class ChatEndPoint : EndPoint
|
||||
{
|
||||
private Bus bus = new Bus();
|
||||
|
||||
public override async Task OnConnected(Connection connection)
|
||||
{
|
||||
await bus.Publish(nameof(ChatEndPoint), new Message
|
||||
{
|
||||
Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})")
|
||||
});
|
||||
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
|
||||
|
||||
using (bus.Subscribe(nameof(ChatEndPoint), message => OnMessage(message, connection)))
|
||||
|
||||
while (true)
|
||||
{
|
||||
while (true)
|
||||
var input = await connection.Channel.Input.ReadAsync();
|
||||
try
|
||||
{
|
||||
var input = await connection.Channel.Input.ReadAsync();
|
||||
try
|
||||
if (input.IsEmpty && connection.Channel.Input.Reading.IsCompleted)
|
||||
{
|
||||
if (input.IsEmpty && connection.Channel.Input.Reading.IsCompleted)
|
||||
{
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
await bus.Publish(nameof(ChatEndPoint), new Message()
|
||||
{
|
||||
Payload = input.ToArray()
|
||||
});
|
||||
}
|
||||
finally
|
||||
{
|
||||
connection.Channel.Input.Advance(input.End);
|
||||
}
|
||||
// We can avoid the copy here but we'll deal with that later
|
||||
await Broadcast(input.ToArray());
|
||||
}
|
||||
finally
|
||||
{
|
||||
connection.Channel.Input.Advance(input.End);
|
||||
}
|
||||
}
|
||||
|
||||
await bus.Publish(nameof(ChatEndPoint), new Message
|
||||
{
|
||||
Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})")
|
||||
});
|
||||
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
|
||||
}
|
||||
|
||||
private async Task OnMessage(Message message, Connection connection)
|
||||
private Task Broadcast(string text)
|
||||
{
|
||||
var buffer = connection.Channel.Output.Alloc();
|
||||
var payload = message.Payload;
|
||||
buffer.Write(payload);
|
||||
await buffer.FlushAsync();
|
||||
return Broadcast(Encoding.UTF8.GetBytes(text));
|
||||
}
|
||||
|
||||
private Task Broadcast(byte[] payload)
|
||||
{
|
||||
var tasks = new List<Task>(Connections.Count);
|
||||
|
||||
foreach (var c in Connections)
|
||||
{
|
||||
tasks.Add(c.Channel.Output.WriteAsync(payload));
|
||||
}
|
||||
|
||||
return Task.WhenAll(tasks);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,55 +1,34 @@
|
|||
using System;
|
||||
using System.Collections.Generic;
|
||||
using System.Linq;
|
||||
using System.Text;
|
||||
using System.Threading.Tasks;
|
||||
using Channels;
|
||||
using Microsoft.AspNetCore.Sockets;
|
||||
using Microsoft.Extensions.Logging;
|
||||
using Newtonsoft.Json.Linq;
|
||||
using SocketsSample.Hubs;
|
||||
|
||||
namespace SocketsSample
|
||||
{
|
||||
public class HubEndpoint : JsonRpcEndpoint
|
||||
public class HubEndpoint : JsonRpcEndpoint, IHubConnectionContext
|
||||
{
|
||||
private readonly ILogger<HubEndpoint> _logger;
|
||||
private readonly Bus _bus = new Bus();
|
||||
|
||||
public HubEndpoint(ILogger<HubEndpoint> logger, ILogger<JsonRpcEndpoint> jsonRpcLogger, IServiceProvider serviceProvider)
|
||||
: base(jsonRpcLogger, serviceProvider)
|
||||
{
|
||||
_logger = logger;
|
||||
All = new AllClientProxy(this);
|
||||
}
|
||||
|
||||
public override Task OnConnected(Connection connection)
|
||||
public IClientProxy All { get; }
|
||||
|
||||
public IClientProxy Client(string connectionId)
|
||||
{
|
||||
// TODO: Get the list of hubs and signals over the connection
|
||||
|
||||
// Subscribe to the hub
|
||||
_bus.Subscribe(nameof(Chat), message => OnMessage(connection, message));
|
||||
|
||||
// Subscribe to the connection id
|
||||
_bus.Subscribe(connection.ConnectionId, message => OnMessage(connection, message));
|
||||
|
||||
return base.OnConnected(connection);
|
||||
return new SingleClientProxy(this, connectionId);
|
||||
}
|
||||
|
||||
protected override bool HandleResponse(string connectionId, JObject response)
|
||||
{
|
||||
var ignore = _bus.Publish(connectionId, new Message
|
||||
{
|
||||
Payload = Encoding.UTF8.GetBytes(response.ToString())
|
||||
});
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
private Task OnMessage(Connection connection, Message message)
|
||||
{
|
||||
return connection.Channel.Output.WriteAsync(message.Payload);
|
||||
}
|
||||
|
||||
public Task Invoke(string key, string method, object[] args)
|
||||
private byte[] Pack(string method, object[] args)
|
||||
{
|
||||
var obj = new JObject();
|
||||
obj["method"] = method;
|
||||
|
|
@ -57,18 +36,15 @@ namespace SocketsSample
|
|||
|
||||
if (_logger.IsEnabled(LogLevel.Debug))
|
||||
{
|
||||
_logger.LogDebug("Outgoing RPC invocation method '{methodName}' to {signal}", method, key);
|
||||
_logger.LogDebug("Outgoing RPC invocation method '{methodName}'", method);
|
||||
}
|
||||
|
||||
return _bus.Publish(key, new Message
|
||||
{
|
||||
Payload = Encoding.UTF8.GetBytes(obj.ToString())
|
||||
});
|
||||
return Encoding.UTF8.GetBytes(obj.ToString());
|
||||
}
|
||||
|
||||
protected override void Initialize(object endpoint)
|
||||
{
|
||||
((Hub)endpoint).Clients = new HubConnectionContext(endpoint.GetType().Name, this);
|
||||
((Hub)endpoint).Clients = this;
|
||||
base.Initialize(endpoint);
|
||||
}
|
||||
|
||||
|
|
@ -77,5 +53,53 @@ namespace SocketsSample
|
|||
// Register the chat hub
|
||||
RegisterJsonRPCEndPoint(typeof(Chat));
|
||||
}
|
||||
|
||||
private class AllClientProxy : IClientProxy
|
||||
{
|
||||
private readonly HubEndpoint _endPoint;
|
||||
|
||||
public AllClientProxy(HubEndpoint endPoint)
|
||||
{
|
||||
_endPoint = endPoint;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
// REVIEW: Thread safety
|
||||
var tasks = new List<Task>(_endPoint.Connections.Count);
|
||||
|
||||
byte[] message = null;
|
||||
|
||||
foreach (var connection in _endPoint.Connections)
|
||||
{
|
||||
if (message == null)
|
||||
{
|
||||
message = _endPoint.Pack(method, args);
|
||||
}
|
||||
|
||||
tasks.Add(connection.Channel.Output.WriteAsync(message));
|
||||
}
|
||||
|
||||
return Task.WhenAll(tasks);
|
||||
}
|
||||
}
|
||||
|
||||
private class SingleClientProxy : IClientProxy
|
||||
{
|
||||
private readonly string _connectionId;
|
||||
private readonly HubEndpoint _endPoint;
|
||||
|
||||
public SingleClientProxy(HubEndpoint endPoint, string connectionId)
|
||||
{
|
||||
_endPoint = endPoint;
|
||||
_connectionId = connectionId;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
var connection = _endPoint.Connections[_connectionId];
|
||||
return connection?.Channel.Output.WriteAsync(_endPoint.Pack(method, args));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,22 +86,14 @@ namespace SocketsSample
|
|||
response["error"] = string.Format("Unknown method '{0}'", request.Value<string>("method"));
|
||||
}
|
||||
|
||||
if (!HandleResponse(connection.ConnectionId, response))
|
||||
{
|
||||
_logger.LogDebug("Sending JSON RPC response: {data}", response);
|
||||
_logger.LogDebug("Sending JSON RPC response: {data}", response);
|
||||
|
||||
var writer = new JsonTextWriter(new StreamWriter(stream));
|
||||
response.WriteTo(writer);
|
||||
writer.Flush();
|
||||
}
|
||||
var writer = new JsonTextWriter(new StreamWriter(stream));
|
||||
response.WriteTo(writer);
|
||||
writer.Flush();
|
||||
}
|
||||
}
|
||||
|
||||
protected virtual bool HandleResponse(string connectionId, JObject response)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
protected virtual void Initialize(object endpoint)
|
||||
{
|
||||
|
||||
|
|
|
|||
|
|
@ -16,41 +16,6 @@ namespace SocketsSample.Hubs
|
|||
IClientProxy Client(string connectionId);
|
||||
}
|
||||
|
||||
public class HubConnectionContext : IHubConnectionContext
|
||||
{
|
||||
private readonly HubEndpoint _endPoint;
|
||||
|
||||
public HubConnectionContext(string hubName, HubEndpoint endpoint)
|
||||
{
|
||||
_endPoint = endpoint;
|
||||
All = new HubClientProxy(endpoint, hubName);
|
||||
}
|
||||
|
||||
public IClientProxy All { get; }
|
||||
|
||||
public IClientProxy Client(string connectionId)
|
||||
{
|
||||
return new HubClientProxy(_endPoint, connectionId);
|
||||
}
|
||||
}
|
||||
|
||||
public class HubClientProxy : IClientProxy
|
||||
{
|
||||
private readonly HubEndpoint _endPoint;
|
||||
private readonly string _key;
|
||||
|
||||
public HubClientProxy(HubEndpoint endPoint, string key)
|
||||
{
|
||||
_endPoint = endPoint;
|
||||
_key = key;
|
||||
}
|
||||
|
||||
public Task Invoke(string method, params object[] args)
|
||||
{
|
||||
return _endPoint.Invoke(_key, method, args);
|
||||
}
|
||||
}
|
||||
|
||||
public interface IClientProxy
|
||||
{
|
||||
/// <summary>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,51 @@
|
|||
using System;
|
||||
using System.Collections;
|
||||
using System.Collections.Concurrent;
|
||||
using System.Collections.Generic;
|
||||
|
||||
namespace Microsoft.AspNetCore.Sockets
|
||||
{
|
||||
public class ConnectionList : IReadOnlyCollection<Connection>
|
||||
{
|
||||
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
|
||||
|
||||
public Connection this[string connectionId]
|
||||
{
|
||||
get
|
||||
{
|
||||
Connection connection;
|
||||
if (_connections.TryGetValue(connectionId, out connection))
|
||||
{
|
||||
return connection;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
public int Count => _connections.Count;
|
||||
|
||||
public void Add(Connection connection)
|
||||
{
|
||||
_connections.TryAdd(connection.ConnectionId, connection);
|
||||
}
|
||||
|
||||
public void Remove(Connection connection)
|
||||
{
|
||||
Connection dummy;
|
||||
_connections.TryRemove(connection.ConnectionId, out dummy);
|
||||
}
|
||||
|
||||
public IEnumerator<Connection> GetEnumerator()
|
||||
{
|
||||
foreach (var item in _connections)
|
||||
{
|
||||
yield return item.Value;
|
||||
}
|
||||
}
|
||||
|
||||
IEnumerator IEnumerable.GetEnumerator()
|
||||
{
|
||||
return GetEnumerator();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
ConnectionState s;
|
||||
if (_connections.TryRemove(c.Key, out s))
|
||||
{
|
||||
s.Connection.Channel.Dispose();
|
||||
s?.Close();
|
||||
}
|
||||
else
|
||||
{
|
||||
|
|
|
|||
|
|
@ -4,8 +4,11 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
{
|
||||
public class ConnectionState
|
||||
{
|
||||
public Connection Connection { get; set; }
|
||||
|
||||
// These are used for long polling mostly
|
||||
public Action Close { get; set; }
|
||||
public DateTimeOffset LastSeen { get; set; }
|
||||
public bool Active { get; set; } = true;
|
||||
public Connection Connection { get; set; }
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,16 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
/// </summary>
|
||||
public class EndPoint
|
||||
{
|
||||
// This is a stream based API, we might just want to change to a message based API or invent framing
|
||||
// over this stream based API to do a message based API
|
||||
/// <summary>
|
||||
/// Live list of connections for this <see cref="EndPoint"/>
|
||||
/// </summary>
|
||||
public ConnectionList Connections { get; } = new ConnectionList();
|
||||
|
||||
/// <summary>
|
||||
/// Called when a new connection is accepted to the endpoint
|
||||
/// </summary>
|
||||
/// <param name="connection">The new <see cref="Connection"/></param>
|
||||
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
|
||||
public virtual Task OnConnected(Connection connection)
|
||||
{
|
||||
return Task.CompletedTask;
|
||||
|
|
|
|||
|
|
@ -64,6 +64,9 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, connectionState.Connection);
|
||||
|
||||
// Add the connection to the list
|
||||
endpoint.Connections.Add(connectionState.Connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
|
||||
|
|
@ -80,6 +83,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
await Task.WhenAll(endpointTask, transportTask);
|
||||
|
||||
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
|
||||
{
|
||||
|
|
@ -93,6 +98,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
// Register this transport for disconnect
|
||||
RegisterDisconnect(context, connectionState.Connection);
|
||||
|
||||
endpoint.Connections.Add(connectionState.Connection);
|
||||
|
||||
// Call into the end point passing the connection
|
||||
var endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
|
||||
|
|
@ -109,6 +116,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
await Task.WhenAll(endpointTask, transportTask);
|
||||
|
||||
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
}
|
||||
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
|
||||
{
|
||||
|
|
@ -144,6 +153,19 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
connectionState.Connection.Metadata["transport"] = "poll";
|
||||
connectionState.Connection.Metadata.Format = format;
|
||||
connectionState.Connection.User = context.User;
|
||||
|
||||
// REVIEW: This is super gross, this all needs to be cleaned up...
|
||||
connectionState.Close = async () =>
|
||||
{
|
||||
connectionState.Connection.Channel.Dispose();
|
||||
|
||||
await endpointTask;
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
};
|
||||
|
||||
endpoint.Connections.Add(connectionState.Connection);
|
||||
|
||||
endpointTask = endpoint.OnConnected(connectionState.Connection);
|
||||
connectionState.Connection.Metadata["endpoint"] = endpointTask;
|
||||
}
|
||||
|
|
@ -168,6 +190,8 @@ namespace Microsoft.AspNetCore.Sockets
|
|||
connectionState.Connection.Channel.Dispose();
|
||||
|
||||
await transportTask;
|
||||
|
||||
endpoint.Connections.Remove(connectionState.Connection);
|
||||
}
|
||||
|
||||
// Mark the connection as inactive
|
||||
|
|
|
|||
Loading…
Reference in New Issue