Added Connection property to EndPoint

- Exposes a list of connections for user code to act on
- The connection list is thread safe (uses a concurrent dictionary under the hood)
- Removed the Bus and just used the connection list in the samples
This commit is contained in:
David Fowler 2016-10-03 22:58:56 -07:00
parent 4cd4ddfad5
commit 1647432ef6
10 changed files with 182 additions and 189 deletions

View File

@ -1,71 +0,0 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using Channels;
namespace Microsoft.AspNetCore.Sockets
{
public class Message
{
public byte[] Payload { get; set; }
}
public class Bus
{
private readonly ConcurrentDictionary<string, List<Func<Message, Task>>> _subscriptions = new ConcurrentDictionary<string, List<Func<Message, Task>>>();
public IDisposable Subscribe(string key, Func<Message, Task> observer)
{
var connections = _subscriptions.GetOrAdd(key, _ => new List<Func<Message, Task>>());
lock (connections)
{
connections.Add(observer);
}
return new DisposableAction(() =>
{
lock (connections)
{
connections.Remove(observer);
}
});
}
public async Task Publish(string key, Message message)
{
List<Func<Message, Task>> connections;
if (_subscriptions.TryGetValue(key, out connections))
{
Task[] tasks = null;
lock (connections)
{
tasks = new Task[connections.Count];
for (int i = 0; i < connections.Count; i++)
{
tasks[i] = connections[i](message);
}
}
await Task.WhenAll(tasks);
}
}
private class DisposableAction : IDisposable
{
private Action _action;
public DisposableAction(Action action)
{
_action = action;
}
public void Dispose()
{
Interlocked.Exchange(ref _action, () => { }).Invoke();
}
}
}
}

View File

@ -10,51 +10,48 @@ namespace SocketsSample
{
public class ChatEndPoint : EndPoint
{
private Bus bus = new Bus();
public override async Task OnConnected(Connection connection)
{
await bus.Publish(nameof(ChatEndPoint), new Message
{
Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})")
});
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
using (bus.Subscribe(nameof(ChatEndPoint), message => OnMessage(message, connection)))
while (true)
{
while (true)
var input = await connection.Channel.Input.ReadAsync();
try
{
var input = await connection.Channel.Input.ReadAsync();
try
if (input.IsEmpty && connection.Channel.Input.Reading.IsCompleted)
{
if (input.IsEmpty && connection.Channel.Input.Reading.IsCompleted)
{
break;
}
break;
}
await bus.Publish(nameof(ChatEndPoint), new Message()
{
Payload = input.ToArray()
});
}
finally
{
connection.Channel.Input.Advance(input.End);
}
// We can avoid the copy here but we'll deal with that later
await Broadcast(input.ToArray());
}
finally
{
connection.Channel.Input.Advance(input.End);
}
}
await bus.Publish(nameof(ChatEndPoint), new Message
{
Payload = Encoding.UTF8.GetBytes($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})")
});
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
}
private async Task OnMessage(Message message, Connection connection)
private Task Broadcast(string text)
{
var buffer = connection.Channel.Output.Alloc();
var payload = message.Payload;
buffer.Write(payload);
await buffer.FlushAsync();
return Broadcast(Encoding.UTF8.GetBytes(text));
}
private Task Broadcast(byte[] payload)
{
var tasks = new List<Task>(Connections.Count);
foreach (var c in Connections)
{
tasks.Add(c.Channel.Output.WriteAsync(payload));
}
return Task.WhenAll(tasks);
}
}

View File

@ -1,55 +1,34 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Channels;
using Microsoft.AspNetCore.Sockets;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
using SocketsSample.Hubs;
namespace SocketsSample
{
public class HubEndpoint : JsonRpcEndpoint
public class HubEndpoint : JsonRpcEndpoint, IHubConnectionContext
{
private readonly ILogger<HubEndpoint> _logger;
private readonly Bus _bus = new Bus();
public HubEndpoint(ILogger<HubEndpoint> logger, ILogger<JsonRpcEndpoint> jsonRpcLogger, IServiceProvider serviceProvider)
: base(jsonRpcLogger, serviceProvider)
{
_logger = logger;
All = new AllClientProxy(this);
}
public override Task OnConnected(Connection connection)
public IClientProxy All { get; }
public IClientProxy Client(string connectionId)
{
// TODO: Get the list of hubs and signals over the connection
// Subscribe to the hub
_bus.Subscribe(nameof(Chat), message => OnMessage(connection, message));
// Subscribe to the connection id
_bus.Subscribe(connection.ConnectionId, message => OnMessage(connection, message));
return base.OnConnected(connection);
return new SingleClientProxy(this, connectionId);
}
protected override bool HandleResponse(string connectionId, JObject response)
{
var ignore = _bus.Publish(connectionId, new Message
{
Payload = Encoding.UTF8.GetBytes(response.ToString())
});
return true;
}
private Task OnMessage(Connection connection, Message message)
{
return connection.Channel.Output.WriteAsync(message.Payload);
}
public Task Invoke(string key, string method, object[] args)
private byte[] Pack(string method, object[] args)
{
var obj = new JObject();
obj["method"] = method;
@ -57,18 +36,15 @@ namespace SocketsSample
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Outgoing RPC invocation method '{methodName}' to {signal}", method, key);
_logger.LogDebug("Outgoing RPC invocation method '{methodName}'", method);
}
return _bus.Publish(key, new Message
{
Payload = Encoding.UTF8.GetBytes(obj.ToString())
});
return Encoding.UTF8.GetBytes(obj.ToString());
}
protected override void Initialize(object endpoint)
{
((Hub)endpoint).Clients = new HubConnectionContext(endpoint.GetType().Name, this);
((Hub)endpoint).Clients = this;
base.Initialize(endpoint);
}
@ -77,5 +53,53 @@ namespace SocketsSample
// Register the chat hub
RegisterJsonRPCEndPoint(typeof(Chat));
}
private class AllClientProxy : IClientProxy
{
private readonly HubEndpoint _endPoint;
public AllClientProxy(HubEndpoint endPoint)
{
_endPoint = endPoint;
}
public Task Invoke(string method, params object[] args)
{
// REVIEW: Thread safety
var tasks = new List<Task>(_endPoint.Connections.Count);
byte[] message = null;
foreach (var connection in _endPoint.Connections)
{
if (message == null)
{
message = _endPoint.Pack(method, args);
}
tasks.Add(connection.Channel.Output.WriteAsync(message));
}
return Task.WhenAll(tasks);
}
}
private class SingleClientProxy : IClientProxy
{
private readonly string _connectionId;
private readonly HubEndpoint _endPoint;
public SingleClientProxy(HubEndpoint endPoint, string connectionId)
{
_endPoint = endPoint;
_connectionId = connectionId;
}
public Task Invoke(string method, params object[] args)
{
var connection = _endPoint.Connections[_connectionId];
return connection?.Channel.Output.WriteAsync(_endPoint.Pack(method, args));
}
}
}
}

View File

@ -86,22 +86,14 @@ namespace SocketsSample
response["error"] = string.Format("Unknown method '{0}'", request.Value<string>("method"));
}
if (!HandleResponse(connection.ConnectionId, response))
{
_logger.LogDebug("Sending JSON RPC response: {data}", response);
_logger.LogDebug("Sending JSON RPC response: {data}", response);
var writer = new JsonTextWriter(new StreamWriter(stream));
response.WriteTo(writer);
writer.Flush();
}
var writer = new JsonTextWriter(new StreamWriter(stream));
response.WriteTo(writer);
writer.Flush();
}
}
protected virtual bool HandleResponse(string connectionId, JObject response)
{
return false;
}
protected virtual void Initialize(object endpoint)
{

View File

@ -16,41 +16,6 @@ namespace SocketsSample.Hubs
IClientProxy Client(string connectionId);
}
public class HubConnectionContext : IHubConnectionContext
{
private readonly HubEndpoint _endPoint;
public HubConnectionContext(string hubName, HubEndpoint endpoint)
{
_endPoint = endpoint;
All = new HubClientProxy(endpoint, hubName);
}
public IClientProxy All { get; }
public IClientProxy Client(string connectionId)
{
return new HubClientProxy(_endPoint, connectionId);
}
}
public class HubClientProxy : IClientProxy
{
private readonly HubEndpoint _endPoint;
private readonly string _key;
public HubClientProxy(HubEndpoint endPoint, string key)
{
_endPoint = endPoint;
_key = key;
}
public Task Invoke(string method, params object[] args)
{
return _endPoint.Invoke(_key, method, args);
}
}
public interface IClientProxy
{
/// <summary>

View File

@ -0,0 +1,51 @@
using System;
using System.Collections;
using System.Collections.Concurrent;
using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionList : IReadOnlyCollection<Connection>
{
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
public Connection this[string connectionId]
{
get
{
Connection connection;
if (_connections.TryGetValue(connectionId, out connection))
{
return connection;
}
return null;
}
}
public int Count => _connections.Count;
public void Add(Connection connection)
{
_connections.TryAdd(connection.ConnectionId, connection);
}
public void Remove(Connection connection)
{
Connection dummy;
_connections.TryRemove(connection.ConnectionId, out dummy);
}
public IEnumerator<Connection> GetEnumerator()
{
foreach (var item in _connections)
{
yield return item.Value;
}
}
IEnumerator IEnumerable.GetEnumerator()
{
return GetEnumerator();
}
}
}

View File

@ -83,7 +83,7 @@ namespace Microsoft.AspNetCore.Sockets
ConnectionState s;
if (_connections.TryRemove(c.Key, out s))
{
s.Connection.Channel.Dispose();
s?.Close();
}
else
{

View File

@ -4,8 +4,11 @@ namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionState
{
public Connection Connection { get; set; }
// These are used for long polling mostly
public Action Close { get; set; }
public DateTimeOffset LastSeen { get; set; }
public bool Active { get; set; } = true;
public Connection Connection { get; set; }
}
}

View File

@ -7,8 +7,16 @@ namespace Microsoft.AspNetCore.Sockets
/// </summary>
public class EndPoint
{
// This is a stream based API, we might just want to change to a message based API or invent framing
// over this stream based API to do a message based API
/// <summary>
/// Live list of connections for this <see cref="EndPoint"/>
/// </summary>
public ConnectionList Connections { get; } = new ConnectionList();
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
/// <param name="connection">The new <see cref="Connection"/></param>
/// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns>
public virtual Task OnConnected(Connection connection)
{
return Task.CompletedTask;

View File

@ -64,6 +64,9 @@ namespace Microsoft.AspNetCore.Sockets
// Register this transport for disconnect
RegisterDisconnect(context, connectionState.Connection);
// Add the connection to the list
endpoint.Connections.Add(connectionState.Connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connectionState.Connection);
@ -80,6 +83,8 @@ namespace Microsoft.AspNetCore.Sockets
await Task.WhenAll(endpointTask, transportTask);
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
endpoint.Connections.Remove(connectionState.Connection);
}
else if (context.Request.Path.StartsWithSegments(path + "/ws"))
{
@ -93,6 +98,8 @@ namespace Microsoft.AspNetCore.Sockets
// Register this transport for disconnect
RegisterDisconnect(context, connectionState.Connection);
endpoint.Connections.Add(connectionState.Connection);
// Call into the end point passing the connection
var endpointTask = endpoint.OnConnected(connectionState.Connection);
@ -109,6 +116,8 @@ namespace Microsoft.AspNetCore.Sockets
await Task.WhenAll(endpointTask, transportTask);
_manager.RemoveConnection(connectionState.Connection.ConnectionId);
endpoint.Connections.Remove(connectionState.Connection);
}
else if (context.Request.Path.StartsWithSegments(path + "/poll"))
{
@ -144,6 +153,19 @@ namespace Microsoft.AspNetCore.Sockets
connectionState.Connection.Metadata["transport"] = "poll";
connectionState.Connection.Metadata.Format = format;
connectionState.Connection.User = context.User;
// REVIEW: This is super gross, this all needs to be cleaned up...
connectionState.Close = async () =>
{
connectionState.Connection.Channel.Dispose();
await endpointTask;
endpoint.Connections.Remove(connectionState.Connection);
};
endpoint.Connections.Add(connectionState.Connection);
endpointTask = endpoint.OnConnected(connectionState.Connection);
connectionState.Connection.Metadata["endpoint"] = endpointTask;
}
@ -168,6 +190,8 @@ namespace Microsoft.AspNetCore.Sockets
connectionState.Connection.Channel.Dispose();
await transportTask;
endpoint.Connections.Remove(connectionState.Connection);
}
// Mark the connection as inactive