Very hacky protobuff support

This commit is contained in:
moozzyk 2016-10-10 16:25:41 -07:00
parent e893f0c6d7
commit a8c831bad6
15 changed files with 561 additions and 113 deletions

View File

@ -78,11 +78,11 @@ namespace SocketsSample
foreach (var connection in _endPoint.Connections)
{
// TODO: separate serialization from writing to stream
var formatter = _endPoint._serviceProvider.GetRequiredService<SocketFormatters>()
.GetFormatter<InvocationDescriptor>((string)connection.Metadata["formatType"]);
tasks.Add(formatter.WriteAsync(message, connection.Channel.GetStream()));
var invocationAdapter = _endPoint._serviceProvider.GetRequiredService<SocketFormatters>()
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
tasks.Add(invocationAdapter.InvokeClientMethod(connection.Channel.GetStream(), message));
}
return Task.WhenAll(tasks);

View File

@ -10,6 +10,7 @@ using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using SocketsSample.Protobuf;
namespace SocketsSample
{
@ -18,10 +19,12 @@ namespace SocketsSample
{
private readonly Dictionary<string, Func<InvocationDescriptor, InvocationResultDescriptor>> _callbacks
= new Dictionary<string, Func<InvocationDescriptor, InvocationResultDescriptor>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
private readonly ILogger<RpcEndpoint> _logger;
private readonly IServiceProvider _serviceProvider;
public RpcEndpoint(ILogger<RpcEndpoint> logger, IServiceProvider serviceProvider)
{
// TODO: Discover end points
@ -41,40 +44,34 @@ namespace SocketsSample
// TODO: Dispatch from the caller
await Task.Yield();
/*
var formatter = _serviceProvider.GetRequiredService<SocketFormatters>()
.GetFormatter<InvocationDescriptor>((string)connection.Metadata["formatType"]);
*/
var stream = connection.Channel.GetStream();
var invocationAdapter = _serviceProvider.GetRequiredService<SocketFormatters>()
.GetInvocationAdapter((string)connection.Metadata["formatType"]);
while (true)
{
// JSON.NET doesn't handle async reads so we wait for data here
var result = await connection.Channel.Input.ReadAsync();
var invocationDescriptor =
await invocationAdapter.CreateInvocationDescriptor(
stream, methodName => {
Type[] types;
// TODO: null or throw?
return _paramTypes.TryGetValue(methodName, out types) ? types : null;
});
// Don't advance the buffer so we parse sync
connection.Channel.Input.Advance(result.Buffer.Start);
while (!reader.Read())
{
break;
}
JObject request;
try
{
request = JObject.Load(reader);
}
catch (Exception)
{
if (result.IsCompleted)
{
break;
}
throw;
}
/* TODO: ?? */
//if (((Channel)connection.Channel).Reading.IsCompleted)
//{
// break;
//}
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("Received JSON RPC request: {request}", invocationDescriptor.ToString());
_logger.LogDebug("Received RPC request: {request}", invocationDescriptor.ToString());
}
InvocationResultDescriptor result;
@ -93,9 +90,7 @@ namespace SocketsSample
};
}
var resultFormatter = _serviceProvider.GetRequiredService<SocketFormatters>().
GetFormatter<InvocationResultDescriptor>((string)connection.Metadata["formatType"]);
await resultFormatter.WriteAsync(result, connection.Channel.GetStream());
await invocationAdapter.WriteInvocationResult(stream, result);
}
}
@ -105,21 +100,18 @@ namespace SocketsSample
protected void RegisterRPCEndPoint(Type type)
{
var methods = new List<string>();
foreach (var m in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic))
foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic))
{
var methodName = type.FullName + "." + m.Name;
methods.Add(methodName);
var parameters = m.GetParameters();
var methodName = type.FullName + "." + methodInfo.Name;
if (_callbacks.ContainsKey(methodName))
{
throw new NotSupportedException(String.Format("Duplicate definitions of {0}. Overloading is not supported.", m.Name));
throw new NotSupportedException($"Duplicate definitions of '{methodInfo.Name}'. Overloading is not supported.");
}
var parameters = methodInfo.GetParameters();
_paramTypes[methodName] = parameters.Select(p => p.ParameterType).ToArray();
if (_logger.IsEnabled(LogLevel.Debug))
{
_logger.LogDebug("RPC method '{methodName}' is bound", methodName);
@ -145,7 +137,7 @@ namespace SocketsSample
.Zip(parameters, (a, p) => Convert.ChangeType(a, p.ParameterType))
.ToArray();
invocationResult.Result = m.Invoke(value, args);
invocationResult.Result = methodInfo.Invoke(value, args);
}
catch (TargetInvocationException ex)
{

View File

@ -32,5 +32,10 @@ namespace SocketsSample
{
_socketFormatters.RegisterFormatter<T, TFormatterType>(format);
}
public void AddInvocationAdapter(string format, IInvocationAdapter adapter)
{
_socketFormatters.RegisterInvocationAdapter(format, adapter);
}
}
}

View File

@ -0,0 +1,15 @@
using System;
using System.IO;
using System.Threading.Tasks;
namespace SocketsSample
{
public interface IInvocationAdapter
{
Task<InvocationDescriptor> CreateInvocationDescriptor(Stream stream, Func<string, Type[]> getParams);
Task WriteInvocationResult(Stream stream, InvocationResultDescriptor resultDescriptor);
Task InvokeClientMethod(Stream stream, InvocationDescriptor invocationDescriptor);
}
}

View File

@ -1,11 +0,0 @@
using System;
using System.IO;
using System.Threading.Tasks;
namespace SocketsSample
{
interface InvocationDescriptorBuilder
{
Task<InvocationDescriptor> CreateInvocationDescriptor(Stream stream, Func<string, Type[]> getParams);
}
}

View File

@ -0,0 +1,41 @@
using System;
using System.IO;
using System.Threading.Tasks;
using Newtonsoft.Json;
namespace SocketsSample
{
public class JSonInvocationAdapter : IInvocationAdapter
{
IServiceProvider _serviceProvider;
private JsonSerializer _serializer = new JsonSerializer();
public JSonInvocationAdapter(IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider;
}
public async Task<InvocationDescriptor> CreateInvocationDescriptor(Stream stream, Func<string, Type[]> getParams)
{
// TODO: use a formatter (?)
var reader = new JsonTextReader(new StreamReader(stream));
return await Task.Run(() => _serializer.Deserialize<InvocationDescriptor>(reader));
}
public Task WriteInvocationResult(Stream stream, InvocationResultDescriptor resultDescriptor)
{
var writer = new JsonTextWriter(new StreamWriter(stream));
_serializer.Serialize(writer, resultDescriptor);
writer.Flush();
return Task.FromResult(0);
}
public Task InvokeClientMethod(Stream stream, InvocationDescriptor invocationDescriptor)
{
var writer = new JsonTextWriter(new StreamWriter(stream));
_serializer.Serialize(writer, invocationDescriptor);
writer.Flush();
return Task.FromResult(0);
}
}
}

View File

@ -0,0 +1,116 @@
using System;
using System.IO;
using System.Threading.Tasks;
using Google.Protobuf;
namespace SocketsSample.Protobuf
{
public class ProtobufInvocationAdapter : IInvocationAdapter
{
public async Task<InvocationDescriptor> CreateInvocationDescriptor(Stream stream, Func<string, Type[]> getParams)
{
return await Task.Run(() => CreateInvocationDescriptorInt(stream, getParams));
}
private static Task<InvocationDescriptor> CreateInvocationDescriptorInt(Stream stream, Func<string, Type[]> getParams)
{
var inputStream = new CodedInputStream(stream, leaveOpen: true);
var invocationHeader = new RpcInvocationHeader();
inputStream.ReadMessage(invocationHeader);
var argumentTypes = getParams(invocationHeader.Name);
var invocationDescriptor = new InvocationDescriptor();
invocationDescriptor.Method = invocationHeader.Name;
invocationDescriptor.Id = invocationHeader.Id.ToString();
invocationDescriptor.Arguments = new object[argumentTypes.Length];
var primitiveParser = PrimitiveValue.Parser;
for (var i = 0; i < argumentTypes.Length; i++)
{
var value = new PrimitiveValue();
inputStream.ReadMessage(value);
if (typeof(int) == argumentTypes[i])
{
invocationDescriptor.Arguments[i] = value.Int32Value;
}
else if (typeof(string) == argumentTypes[i])
{
invocationDescriptor.Arguments[i] = value.StringValue;
}
else
{
throw new InvalidOperationException();
}
}
return Task.FromResult(invocationDescriptor);
}
public async Task WriteInvocationResult(Stream stream, InvocationResultDescriptor resultDescriptor)
{
var outputStream = new CodedOutputStream(stream, leaveOpen: true);
outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Result });
var resultHeader = new RpcInvocationResultHeader
{
Id = int.Parse(resultDescriptor.Id),
HasResult = resultDescriptor.Result != null
};
if (resultDescriptor.Error != null)
{
resultHeader.Error = resultDescriptor.Error;
}
outputStream.WriteMessage(resultHeader);
if (resultHeader.Error == null && resultDescriptor.Result != null)
{
var result = resultDescriptor.Result;
if (result.GetType() == typeof(int))
{
outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)result });
}
else if (result.GetType() == typeof(string))
{
outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)result });
}
}
outputStream.Flush();
await stream.FlushAsync();
}
public async Task InvokeClientMethod(Stream stream, InvocationDescriptor invocationDescriptor)
{
var outputStream = new CodedOutputStream(stream, leaveOpen: true);
outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Invocation });
var invocationHeader = new RpcInvocationHeader()
{
Id = 0,
Name = invocationDescriptor.Method,
NumArgs = invocationDescriptor.Arguments.Length
};
outputStream.WriteMessage(invocationHeader);
foreach (var arg in invocationDescriptor.Arguments)
{
if (arg.GetType() == typeof(int))
{
outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)arg });
}
else if (arg.GetType() == typeof(string))
{
outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)arg });
}
}
outputStream.Flush();
await stream.FlushAsync();
}
}
}

View File

@ -1,49 +0,0 @@

using System;
using System.IO;
using System.Reflection;
using System.Threading.Tasks;
using Google.Protobuf;
namespace SocketsSample.Protobuf
{
public class ProtobufInvocationDescriptorBuilder : InvocationDescriptorBuilder
{
public Task<InvocationDescriptor> CreateInvocationDescriptor(Stream stream, Func<string, Type[]> getParams)
{
var invocationDescriptor = new InvocationDescriptor();
var inputStream = new CodedInputStream(stream, leaveOpen: true);
var invocationHeader = new RpcInvocationHeader();
inputStream.ReadMessage(invocationHeader);
var argumentTypes = getParams(invocationHeader.Name);
invocationDescriptor.Method = invocationHeader.Name;
invocationDescriptor.Id = invocationHeader.Id.ToString();
invocationDescriptor.Arguments = new object[argumentTypes.Length];
var primitiveValueParser = PrimitiveValue.Parser;
for (var i = 0; i < argumentTypes.Length; i++)
{
if (argumentTypes[i] == typeof(int))
{
invocationDescriptor.Arguments[i] = primitiveValueParser.ParseFrom(inputStream).Int32Value;
}
else if (argumentTypes[i] == typeof(int))
{
invocationDescriptor.Arguments[i] = primitiveValueParser.ParseFrom(inputStream).StringValue;
}
else if (typeof(IMessage).IsAssignableFrom(argumentTypes[i]))
{
throw new NotImplementedException();
}
}
return Task.FromResult(invocationDescriptor);
}
public async Task WriteResult(Stream stream, InvocationResultDescriptor result)
{
throw new NotImplementedException();
}
}
}

View File

@ -20,14 +20,20 @@ public static partial class RpcInvocationReflection {
static RpcInvocationReflection() {
byte[] descriptorData = global::System.Convert.FromBase64String(
string.Concat(
"ChNScGNJbnZvY2F0aW9uLnByb3RvIkAKE1JwY0ludm9jYXRpb25IZWFkZXIS",
"DAoETmFtZRgBIAEoCRIKCgJJZBgCIAEoBRIPCgdOdW1BcmdzGAMgASgFIkcK",
"DlByaW1pdGl2ZVZhbHVlEhQKCkludDMyVmFsdWUYASABKAVIABIVCgtTdHJp",
"bmdWYWx1ZRgCIAEoCUgAQggKBm9uZW9mX2IGcHJvdG8z"));
"ChNScGNJbnZvY2F0aW9uLnByb3RvIl8KDlJwY01lc3NhZ2VLaW5kEikKC01l",
"c3NhZ2VLaW5kGAEgASgOMhQuUnBjTWVzc2FnZUtpbmQuS2luZCIiCgRLaW5k",
"EgoKBlJlc3VsdBAAEg4KCkludm9jYXRpb24QASJAChNScGNJbnZvY2F0aW9u",
"SGVhZGVyEgwKBE5hbWUYASABKAkSCgoCSWQYAiABKAUSDwoHTnVtQXJncxgD",
"IAEoBSJJChlScGNJbnZvY2F0aW9uUmVzdWx0SGVhZGVyEgoKAklkGAEgASgF",
"EhEKCUhhc1Jlc3VsdBgCIAEoCBINCgVFcnJvchgDIAEoCSJHCg5QcmltaXRp",
"dmVWYWx1ZRIUCgpJbnQzMlZhbHVlGAEgASgFSAASFQoLU3RyaW5nVmFsdWUY",
"AiABKAlIAEIICgZvbmVvZl9iBnByb3RvMw=="));
descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData,
new pbr::FileDescriptor[] { },
new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] {
new pbr::GeneratedClrTypeInfo(typeof(global::RpcMessageKind), global::RpcMessageKind.Parser, new[]{ "MessageKind" }, null, new[]{ typeof(global::RpcMessageKind.Types.Kind) }, null),
new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationHeader), global::RpcInvocationHeader.Parser, new[]{ "Name", "Id", "NumArgs" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationResultHeader), global::RpcInvocationResultHeader.Parser, new[]{ "Id", "HasResult", "Error" }, null, null, null),
new pbr::GeneratedClrTypeInfo(typeof(global::PrimitiveValue), global::PrimitiveValue.Parser, new[]{ "Int32Value", "StringValue" }, new[]{ "Oneof" }, null, null)
}));
}
@ -35,6 +41,135 @@ public static partial class RpcInvocationReflection {
}
#region Messages
public sealed partial class RpcMessageKind : pb::IMessage<RpcMessageKind> {
private static readonly pb::MessageParser<RpcMessageKind> _parser = new pb::MessageParser<RpcMessageKind>(() => new RpcMessageKind());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<RpcMessageKind> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcMessageKind() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcMessageKind(RpcMessageKind other) : this() {
messageKind_ = other.messageKind_;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcMessageKind Clone() {
return new RpcMessageKind(this);
}
/// <summary>Field number for the "MessageKind" field.</summary>
public const int MessageKindFieldNumber = 1;
private global::RpcMessageKind.Types.Kind messageKind_ = 0;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public global::RpcMessageKind.Types.Kind MessageKind {
get { return messageKind_; }
set {
messageKind_ = value;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as RpcMessageKind);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(RpcMessageKind other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (MessageKind != other.MessageKind) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (MessageKind != 0) hash ^= MessageKind.GetHashCode();
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (MessageKind != 0) {
output.WriteRawTag(8);
output.WriteEnum((int) MessageKind);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (MessageKind != 0) {
size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MessageKind);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(RpcMessageKind other) {
if (other == null) {
return;
}
if (other.MessageKind != 0) {
MessageKind = other.MessageKind;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 8: {
messageKind_ = (global::RpcMessageKind.Types.Kind) input.ReadEnum();
break;
}
}
}
}
#region Nested types
/// <summary>Container for nested types declared in the RpcMessageKind message type.</summary>
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static partial class Types {
public enum Kind {
[pbr::OriginalName("Result")] Result = 0,
[pbr::OriginalName("Invocation")] Invocation = 1,
}
}
#endregion
}
public sealed partial class RpcInvocationHeader : pb::IMessage<RpcInvocationHeader> {
private static readonly pb::MessageParser<RpcInvocationHeader> _parser = new pb::MessageParser<RpcInvocationHeader>(() => new RpcInvocationHeader());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
@ -42,7 +177,7 @@ public sealed partial class RpcInvocationHeader : pb::IMessage<RpcInvocationHead
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; }
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
@ -208,6 +343,179 @@ public sealed partial class RpcInvocationHeader : pb::IMessage<RpcInvocationHead
}
public sealed partial class RpcInvocationResultHeader : pb::IMessage<RpcInvocationResultHeader> {
private static readonly pb::MessageParser<RpcInvocationResultHeader> _parser = new pb::MessageParser<RpcInvocationResultHeader>(() => new RpcInvocationResultHeader());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pb::MessageParser<RpcInvocationResultHeader> Parser { get { return _parser; } }
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[2]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
pbr::MessageDescriptor pb::IMessage.Descriptor {
get { return Descriptor; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationResultHeader() {
OnConstruction();
}
partial void OnConstruction();
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationResultHeader(RpcInvocationResultHeader other) : this() {
id_ = other.id_;
hasResult_ = other.hasResult_;
error_ = other.error_;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public RpcInvocationResultHeader Clone() {
return new RpcInvocationResultHeader(this);
}
/// <summary>Field number for the "Id" field.</summary>
public const int IdFieldNumber = 1;
private int id_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int Id {
get { return id_; }
set {
id_ = value;
}
}
/// <summary>Field number for the "HasResult" field.</summary>
public const int HasResultFieldNumber = 2;
private bool hasResult_;
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool HasResult {
get { return hasResult_; }
set {
hasResult_ = value;
}
}
/// <summary>Field number for the "Error" field.</summary>
public const int ErrorFieldNumber = 3;
private string error_ = "";
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public string Error {
get { return error_; }
set {
error_ = pb::ProtoPreconditions.CheckNotNull(value, "value");
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override bool Equals(object other) {
return Equals(other as RpcInvocationResultHeader);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public bool Equals(RpcInvocationResultHeader other) {
if (ReferenceEquals(other, null)) {
return false;
}
if (ReferenceEquals(other, this)) {
return true;
}
if (Id != other.Id) return false;
if (HasResult != other.HasResult) return false;
if (Error != other.Error) return false;
return true;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override int GetHashCode() {
int hash = 1;
if (Id != 0) hash ^= Id.GetHashCode();
if (HasResult != false) hash ^= HasResult.GetHashCode();
if (Error.Length != 0) hash ^= Error.GetHashCode();
return hash;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public override string ToString() {
return pb::JsonFormatter.ToDiagnosticString(this);
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void WriteTo(pb::CodedOutputStream output) {
if (Id != 0) {
output.WriteRawTag(8);
output.WriteInt32(Id);
}
if (HasResult != false) {
output.WriteRawTag(16);
output.WriteBool(HasResult);
}
if (Error.Length != 0) {
output.WriteRawTag(26);
output.WriteString(Error);
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public int CalculateSize() {
int size = 0;
if (Id != 0) {
size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id);
}
if (HasResult != false) {
size += 1 + 1;
}
if (Error.Length != 0) {
size += 1 + pb::CodedOutputStream.ComputeStringSize(Error);
}
return size;
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(RpcInvocationResultHeader other) {
if (other == null) {
return;
}
if (other.Id != 0) {
Id = other.Id;
}
if (other.HasResult != false) {
HasResult = other.HasResult;
}
if (other.Error.Length != 0) {
Error = other.Error;
}
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public void MergeFrom(pb::CodedInputStream input) {
uint tag;
while ((tag = input.ReadTag()) != 0) {
switch(tag) {
default:
input.SkipLastField();
break;
case 8: {
Id = input.ReadInt32();
break;
}
case 16: {
HasResult = input.ReadBool();
break;
}
case 26: {
Error = input.ReadString();
break;
}
}
}
}
}
public sealed partial class PrimitiveValue : pb::IMessage<PrimitiveValue> {
private static readonly pb::MessageParser<PrimitiveValue> _parser = new pb::MessageParser<PrimitiveValue>(() => new PrimitiveValue());
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
@ -215,7 +523,7 @@ public sealed partial class PrimitiveValue : pb::IMessage<PrimitiveValue> {
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]
public static pbr::MessageDescriptor Descriptor {
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; }
get { return global::RpcInvocationReflection.Descriptor.MessageTypes[3]; }
}
[global::System.Diagnostics.DebuggerNonUserCodeAttribute]

View File

@ -1,11 +1,22 @@
syntax = "proto3";
message RpcMessageKind {
enum Kind { Result = 0; Invocation = 1; }
Kind MessageKind = 1;
}
message RpcInvocationHeader {
string Name = 1;
int32 Id = 2;
int32 NumArgs = 3;
}
message RpcInvocationResultHeader {
int32 Id = 1;
bool HasResult = 2;
string Error = 3;
}
message PrimitiveValue {
oneof oneof_ {
int32 Int32Value = 1;

View File

@ -9,6 +9,7 @@ namespace SocketsSample
{
private IServiceProvider _serviceProvider;
private Dictionary<string, Dictionary<Type, Type>> _formatters = new Dictionary<string, Dictionary<Type, Type>>();
private Dictionary<string, IInvocationAdapter> _invocationAdapters = new Dictionary<string, IInvocationAdapter>();
public SocketFormatters(IServiceProvider serviceProvider)
{
@ -37,7 +38,20 @@ namespace SocketsSample
return (IFormatter<T>)_serviceProvider.GetRequiredService(targetFormatterType);
}
throw new InvalidOperationException($"No formatter register for format '{format}' and type '{typeof(T).GetType().FullName}'");
return null;
// throw new InvalidOperationException($"No formatter register for format '{format}' and type '{typeof(T).GetType().FullName}'");
}
public void RegisterInvocationAdapter(string format, IInvocationAdapter adapter)
{
_invocationAdapters[format] = adapter;
}
public IInvocationAdapter GetInvocationAdapter(string format)
{
IInvocationAdapter value;
return _invocationAdapters.TryGetValue(format, out value) ? value : null;
}
}
}

View File

@ -53,6 +53,9 @@ namespace SocketsSample
formatters.MapFormatter<InvocationResultDescriptor, InvocationResultDescriptorLineFormatter>("line");
formatters.MapFormatter<InvocationDescriptor, RpcJSonFormatter<InvocationDescriptor>>("json");
formatters.MapFormatter<InvocationResultDescriptor, RpcJSonFormatter<InvocationResultDescriptor>>("json");
formatters.AddInvocationAdapter("protobuf", new Protobuf.ProtobufInvocationAdapter());
formatters.AddInvocationAdapter("json", new JSonInvocationAdapter(app.ApplicationServices));
});
}
}

View File

@ -13,7 +13,8 @@
"Microsoft.AspNetCore.Server.IISIntegration": "1.1.0-*",
"Microsoft.AspNetCore.Server.WebListener": "0.1.0",
"Microsoft.AspNetCore.Server.Kestrel": "1.1.0-*",
"Microsoft.Extensions.Logging.Console": "1.1.0-*"
"Microsoft.Extensions.Logging.Console": "1.1.0-*",
"Google.Protobuf": "3.1.0"
},
"tools": {

View File

@ -65,7 +65,7 @@ namespace Microsoft.AspNetCore.Sockets
var formatType = (string)context.Request.Query["formatType"];
state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType;
var ws = new WebSockets(state.Connection);
var ws = new WebSockets(state.Connection, format);
await DoPersistentConnection(endpoint, ws, context, state.Connection);

View File

@ -11,11 +11,13 @@ namespace Microsoft.AspNetCore.Sockets
{
private readonly HttpChannel _channel;
private readonly Connection _connection;
private readonly WebSocketMessageType _messageType;
public WebSockets(Connection connection)
public WebSockets(Connection connection, Format format)
{
_connection = connection;
_channel = (HttpChannel)connection.Channel;
_messageType = format == Format.Binary ? WebSocketMessageType.Binary : WebSocketMessageType.Text;
}
public async Task ProcessRequest(HttpContext context)
@ -93,7 +95,7 @@ namespace Microsoft.AspNetCore.Sockets
break;
}
await ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None);
await ws.SendAsync(data, _messageType, endOfMessage: true, cancellationToken: CancellationToken.None);
}
}