using System; using System.Collections.Generic; using System.IO; using System.Linq; using System.Reflection; using System.Threading.Tasks; using Channels; using Microsoft.AspNetCore.Sockets; using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Newtonsoft.Json; using Newtonsoft.Json.Linq; namespace SocketsSample { // This end point implementation is used for framing JSON objects from the stream public class RpcEndpoint : EndPoint { private readonly Dictionary> _callbacks = new Dictionary>(StringComparer.OrdinalIgnoreCase); private readonly ILogger _logger; private readonly IServiceProvider _serviceProvider; public RpcEndpoint(ILogger logger, IServiceProvider serviceProvider) { // TODO: Discover end points _logger = logger; _serviceProvider = serviceProvider; DiscoverEndpoints(); } protected virtual void DiscoverEndpoints() { RegisterRPCEndPoint(typeof(Echo)); } public override async Task OnConnected(Connection connection) { // TODO: Dispatch from the caller await Task.Yield(); var formatter = _serviceProvider.GetRequiredService() .GetFormatter((string)connection.Metadata["formatType"]); while (true) { // JSON.NET doesn't handle async reads so we wait for data here var result = await connection.Channel.Input.ReadAsync(); // Don't advance the buffer so we parse sync connection.Channel.Input.Advance(result.Buffer.Start); while (!reader.Read()) { break; } JObject request; try { request = JObject.Load(reader); } catch (Exception) { if (result.IsCompleted) { break; } throw; } if (_logger.IsEnabled(LogLevel.Debug)) { _logger.LogDebug("Received JSON RPC request: {request}", invocationDescriptor.ToString()); } InvocationResultDescriptor result; Func callback; if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback)) { result = callback(invocationDescriptor); } else { // If there's no method then return a failed response for this request result = new InvocationResultDescriptor { Id = invocationDescriptor.Id, Error = $"Unknown method '{invocationDescriptor.Method}'" }; } var resultFormatter = _serviceProvider.GetRequiredService(). GetFormatter((string)connection.Metadata["formatType"]); await resultFormatter.WriteAsync(result, connection.Channel.GetStream()); } } protected virtual void Initialize(object endpoint) { } protected void RegisterRPCEndPoint(Type type) { var methods = new List(); foreach (var m in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic)) { var methodName = type.FullName + "." + m.Name; methods.Add(methodName); var parameters = m.GetParameters(); if (_callbacks.ContainsKey(methodName)) { throw new NotSupportedException(String.Format("Duplicate definitions of {0}. Overloading is not supported.", m.Name)); } if (_logger.IsEnabled(LogLevel.Debug)) { _logger.LogDebug("RPC method '{methodName}' is bound", methodName); } _callbacks[methodName] = invocationDescriptor => { var invocationResult = new InvocationResultDescriptor(); invocationResult.Id = invocationDescriptor.Id; var scopeFactory = _serviceProvider.GetRequiredService(); // Scope per call so that deps injected get disposed using (var scope = scopeFactory.CreateScope()) { object value = scope.ServiceProvider.GetService(type) ?? Activator.CreateInstance(type); Initialize(value); try { var args = invocationDescriptor.Arguments .Zip(parameters, (a, p) => Convert.ChangeType(a, p.ParameterType)) .ToArray(); invocationResult.Result = m.Invoke(value, args); } catch (TargetInvocationException ex) { invocationResult.Error = ex.InnerException.Message; } catch (Exception ex) { invocationResult.Error = ex.Message; } } return invocationResult; }; }; } } }