diff --git a/samples/SocketsSample/EndPoints/HubEndpoint.cs b/samples/SocketsSample/EndPoints/HubEndpoint.cs index fe0abc9044..c92fa439d5 100644 --- a/samples/SocketsSample/EndPoints/HubEndpoint.cs +++ b/samples/SocketsSample/EndPoints/HubEndpoint.cs @@ -78,11 +78,11 @@ namespace SocketsSample foreach (var connection in _endPoint.Connections) { - // TODO: separate serialization from writing to stream - var formatter = _endPoint._serviceProvider.GetRequiredService() - .GetFormatter((string)connection.Metadata["formatType"]); - tasks.Add(formatter.WriteAsync(message, connection.Channel.GetStream())); + var invocationAdapter = _endPoint._serviceProvider.GetRequiredService() + .GetInvocationAdapter((string)connection.Metadata["formatType"]); + + tasks.Add(invocationAdapter.InvokeClientMethod(connection.Channel.GetStream(), message)); } return Task.WhenAll(tasks); diff --git a/samples/SocketsSample/EndPoints/RpcEndpoint.cs b/samples/SocketsSample/EndPoints/RpcEndpoint.cs index 18b8bc8d16..602f9ef8e2 100644 --- a/samples/SocketsSample/EndPoints/RpcEndpoint.cs +++ b/samples/SocketsSample/EndPoints/RpcEndpoint.cs @@ -10,6 +10,7 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Newtonsoft.Json; using Newtonsoft.Json.Linq; +using SocketsSample.Protobuf; namespace SocketsSample { @@ -18,10 +19,12 @@ namespace SocketsSample { private readonly Dictionary> _callbacks = new Dictionary>(StringComparer.OrdinalIgnoreCase); + private readonly Dictionary _paramTypes = new Dictionary(); private readonly ILogger _logger; private readonly IServiceProvider _serviceProvider; + public RpcEndpoint(ILogger logger, IServiceProvider serviceProvider) { // TODO: Discover end points @@ -41,40 +44,34 @@ namespace SocketsSample // TODO: Dispatch from the caller await Task.Yield(); + /* var formatter = _serviceProvider.GetRequiredService() .GetFormatter((string)connection.Metadata["formatType"]); + */ + + var stream = connection.Channel.GetStream(); + var invocationAdapter = _serviceProvider.GetRequiredService() + .GetInvocationAdapter((string)connection.Metadata["formatType"]); while (true) { - // JSON.NET doesn't handle async reads so we wait for data here - var result = await connection.Channel.Input.ReadAsync(); + var invocationDescriptor = + await invocationAdapter.CreateInvocationDescriptor( + stream, methodName => { + Type[] types; + // TODO: null or throw? + return _paramTypes.TryGetValue(methodName, out types) ? types : null; + }); - // Don't advance the buffer so we parse sync - connection.Channel.Input.Advance(result.Buffer.Start); - - while (!reader.Read()) - { - break; - } - - JObject request; - try - { - request = JObject.Load(reader); - } - catch (Exception) - { - if (result.IsCompleted) - { - break; - } - - throw; - } + /* TODO: ?? */ + //if (((Channel)connection.Channel).Reading.IsCompleted) + //{ + // break; + //} if (_logger.IsEnabled(LogLevel.Debug)) { - _logger.LogDebug("Received JSON RPC request: {request}", invocationDescriptor.ToString()); + _logger.LogDebug("Received RPC request: {request}", invocationDescriptor.ToString()); } InvocationResultDescriptor result; @@ -93,9 +90,7 @@ namespace SocketsSample }; } - var resultFormatter = _serviceProvider.GetRequiredService(). - GetFormatter((string)connection.Metadata["formatType"]); - await resultFormatter.WriteAsync(result, connection.Channel.GetStream()); + await invocationAdapter.WriteInvocationResult(stream, result); } } @@ -105,21 +100,18 @@ namespace SocketsSample protected void RegisterRPCEndPoint(Type type) { - var methods = new List(); - - foreach (var m in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic)) + foreach (var methodInfo in type.GetTypeInfo().DeclaredMethods.Where(m => m.IsPublic)) { - var methodName = type.FullName + "." + m.Name; - - methods.Add(methodName); - - var parameters = m.GetParameters(); + var methodName = type.FullName + "." + methodInfo.Name; if (_callbacks.ContainsKey(methodName)) { - throw new NotSupportedException(String.Format("Duplicate definitions of {0}. Overloading is not supported.", m.Name)); + throw new NotSupportedException($"Duplicate definitions of '{methodInfo.Name}'. Overloading is not supported."); } + var parameters = methodInfo.GetParameters(); + _paramTypes[methodName] = parameters.Select(p => p.ParameterType).ToArray(); + if (_logger.IsEnabled(LogLevel.Debug)) { _logger.LogDebug("RPC method '{methodName}' is bound", methodName); @@ -145,7 +137,7 @@ namespace SocketsSample .Zip(parameters, (a, p) => Convert.ChangeType(a, p.ParameterType)) .ToArray(); - invocationResult.Result = m.Invoke(value, args); + invocationResult.Result = methodInfo.Invoke(value, args); } catch (TargetInvocationException ex) { diff --git a/samples/SocketsSample/FormatterExtensions.cs b/samples/SocketsSample/FormatterExtensions.cs index 687fee2ffe..094ecbaada 100644 --- a/samples/SocketsSample/FormatterExtensions.cs +++ b/samples/SocketsSample/FormatterExtensions.cs @@ -32,5 +32,10 @@ namespace SocketsSample { _socketFormatters.RegisterFormatter(format); } + + public void AddInvocationAdapter(string format, IInvocationAdapter adapter) + { + _socketFormatters.RegisterInvocationAdapter(format, adapter); + } } } diff --git a/samples/SocketsSample/IInvocationAdapter.cs b/samples/SocketsSample/IInvocationAdapter.cs new file mode 100644 index 0000000000..8900fea3e3 --- /dev/null +++ b/samples/SocketsSample/IInvocationAdapter.cs @@ -0,0 +1,15 @@ +using System; +using System.IO; +using System.Threading.Tasks; + +namespace SocketsSample +{ + public interface IInvocationAdapter + { + Task CreateInvocationDescriptor(Stream stream, Func getParams); + + Task WriteInvocationResult(Stream stream, InvocationResultDescriptor resultDescriptor); + + Task InvokeClientMethod(Stream stream, InvocationDescriptor invocationDescriptor); + } +} diff --git a/samples/SocketsSample/InvocationDescriptorBuilder.cs b/samples/SocketsSample/InvocationDescriptorBuilder.cs deleted file mode 100644 index 8ca8bb7ea3..0000000000 --- a/samples/SocketsSample/InvocationDescriptorBuilder.cs +++ /dev/null @@ -1,11 +0,0 @@ -using System; -using System.IO; -using System.Threading.Tasks; - -namespace SocketsSample -{ - interface InvocationDescriptorBuilder - { - Task CreateInvocationDescriptor(Stream stream, Func getParams); - } -} diff --git a/samples/SocketsSample/JSonInvocationAdapter.cs b/samples/SocketsSample/JSonInvocationAdapter.cs new file mode 100644 index 0000000000..49262b075b --- /dev/null +++ b/samples/SocketsSample/JSonInvocationAdapter.cs @@ -0,0 +1,41 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Newtonsoft.Json; + +namespace SocketsSample +{ + public class JSonInvocationAdapter : IInvocationAdapter + { + IServiceProvider _serviceProvider; + private JsonSerializer _serializer = new JsonSerializer(); + + public JSonInvocationAdapter(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider; + } + + public async Task CreateInvocationDescriptor(Stream stream, Func getParams) + { + // TODO: use a formatter (?) + var reader = new JsonTextReader(new StreamReader(stream)); + return await Task.Run(() => _serializer.Deserialize(reader)); + } + + public Task WriteInvocationResult(Stream stream, InvocationResultDescriptor resultDescriptor) + { + var writer = new JsonTextWriter(new StreamWriter(stream)); + _serializer.Serialize(writer, resultDescriptor); + writer.Flush(); + return Task.FromResult(0); + } + + public Task InvokeClientMethod(Stream stream, InvocationDescriptor invocationDescriptor) + { + var writer = new JsonTextWriter(new StreamWriter(stream)); + _serializer.Serialize(writer, invocationDescriptor); + writer.Flush(); + return Task.FromResult(0); + } + } +} diff --git a/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs b/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs new file mode 100644 index 0000000000..2ff60774af --- /dev/null +++ b/samples/SocketsSample/Protobuf/ProtobufInvocationAdapter.cs @@ -0,0 +1,116 @@ +using System; +using System.IO; +using System.Threading.Tasks; +using Google.Protobuf; + +namespace SocketsSample.Protobuf +{ + public class ProtobufInvocationAdapter : IInvocationAdapter + { + public async Task CreateInvocationDescriptor(Stream stream, Func getParams) + { + return await Task.Run(() => CreateInvocationDescriptorInt(stream, getParams)); + } + + private static Task CreateInvocationDescriptorInt(Stream stream, Func getParams) + { + var inputStream = new CodedInputStream(stream, leaveOpen: true); + var invocationHeader = new RpcInvocationHeader(); + inputStream.ReadMessage(invocationHeader); + var argumentTypes = getParams(invocationHeader.Name); + + var invocationDescriptor = new InvocationDescriptor(); + invocationDescriptor.Method = invocationHeader.Name; + invocationDescriptor.Id = invocationHeader.Id.ToString(); + invocationDescriptor.Arguments = new object[argumentTypes.Length]; + + var primitiveParser = PrimitiveValue.Parser; + + for (var i = 0; i < argumentTypes.Length; i++) + { + var value = new PrimitiveValue(); + inputStream.ReadMessage(value); + if (typeof(int) == argumentTypes[i]) + { + invocationDescriptor.Arguments[i] = value.Int32Value; + } + else if (typeof(string) == argumentTypes[i]) + { + invocationDescriptor.Arguments[i] = value.StringValue; + } + else + { + throw new InvalidOperationException(); + } + } + + return Task.FromResult(invocationDescriptor); + } + + public async Task WriteInvocationResult(Stream stream, InvocationResultDescriptor resultDescriptor) + { + var outputStream = new CodedOutputStream(stream, leaveOpen: true); + outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Result }); + + var resultHeader = new RpcInvocationResultHeader + { + Id = int.Parse(resultDescriptor.Id), + HasResult = resultDescriptor.Result != null + }; + + if (resultDescriptor.Error != null) + { + resultHeader.Error = resultDescriptor.Error; + } + + outputStream.WriteMessage(resultHeader); + + if (resultHeader.Error == null && resultDescriptor.Result != null) + { + var result = resultDescriptor.Result; + + if (result.GetType() == typeof(int)) + { + outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)result }); + } + else if (result.GetType() == typeof(string)) + { + outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)result }); + } + } + + outputStream.Flush(); + await stream.FlushAsync(); + } + + public async Task InvokeClientMethod(Stream stream, InvocationDescriptor invocationDescriptor) + { + var outputStream = new CodedOutputStream(stream, leaveOpen: true); + outputStream.WriteMessage(new RpcMessageKind() { MessageKind = RpcMessageKind.Types.Kind.Invocation }); + + var invocationHeader = new RpcInvocationHeader() + { + Id = 0, + Name = invocationDescriptor.Method, + NumArgs = invocationDescriptor.Arguments.Length + }; + + outputStream.WriteMessage(invocationHeader); + + foreach (var arg in invocationDescriptor.Arguments) + { + if (arg.GetType() == typeof(int)) + { + outputStream.WriteMessage(new PrimitiveValue { Int32Value = (int)arg }); + } + else if (arg.GetType() == typeof(string)) + { + outputStream.WriteMessage(new PrimitiveValue { StringValue = (string)arg }); + } + } + + outputStream.Flush(); + await stream.FlushAsync(); + } + } +} diff --git a/samples/SocketsSample/Protobuf/ProtobufInvocationDescriptorBuilder.cs b/samples/SocketsSample/Protobuf/ProtobufInvocationDescriptorBuilder.cs deleted file mode 100644 index f61a4bcf4a..0000000000 --- a/samples/SocketsSample/Protobuf/ProtobufInvocationDescriptorBuilder.cs +++ /dev/null @@ -1,49 +0,0 @@ - -using System; -using System.IO; -using System.Reflection; -using System.Threading.Tasks; -using Google.Protobuf; - -namespace SocketsSample.Protobuf -{ - public class ProtobufInvocationDescriptorBuilder : InvocationDescriptorBuilder - { - public Task CreateInvocationDescriptor(Stream stream, Func getParams) - { - var invocationDescriptor = new InvocationDescriptor(); - var inputStream = new CodedInputStream(stream, leaveOpen: true); - var invocationHeader = new RpcInvocationHeader(); - inputStream.ReadMessage(invocationHeader); - var argumentTypes = getParams(invocationHeader.Name); - - invocationDescriptor.Method = invocationHeader.Name; - invocationDescriptor.Id = invocationHeader.Id.ToString(); - invocationDescriptor.Arguments = new object[argumentTypes.Length]; - - var primitiveValueParser = PrimitiveValue.Parser; - for (var i = 0; i < argumentTypes.Length; i++) - { - if (argumentTypes[i] == typeof(int)) - { - invocationDescriptor.Arguments[i] = primitiveValueParser.ParseFrom(inputStream).Int32Value; - } - else if (argumentTypes[i] == typeof(int)) - { - invocationDescriptor.Arguments[i] = primitiveValueParser.ParseFrom(inputStream).StringValue; - } - else if (typeof(IMessage).IsAssignableFrom(argumentTypes[i])) - { - throw new NotImplementedException(); - } - } - - return Task.FromResult(invocationDescriptor); - } - - public async Task WriteResult(Stream stream, InvocationResultDescriptor result) - { - throw new NotImplementedException(); - } - } -} diff --git a/samples/SocketsSample/Protobuf/RpcInvocation.cs b/samples/SocketsSample/Protobuf/RpcInvocation.cs index 6b5202dcce..1a4ed9ed19 100644 --- a/samples/SocketsSample/Protobuf/RpcInvocation.cs +++ b/samples/SocketsSample/Protobuf/RpcInvocation.cs @@ -20,14 +20,20 @@ public static partial class RpcInvocationReflection { static RpcInvocationReflection() { byte[] descriptorData = global::System.Convert.FromBase64String( string.Concat( - "ChNScGNJbnZvY2F0aW9uLnByb3RvIkAKE1JwY0ludm9jYXRpb25IZWFkZXIS", - "DAoETmFtZRgBIAEoCRIKCgJJZBgCIAEoBRIPCgdOdW1BcmdzGAMgASgFIkcK", - "DlByaW1pdGl2ZVZhbHVlEhQKCkludDMyVmFsdWUYASABKAVIABIVCgtTdHJp", - "bmdWYWx1ZRgCIAEoCUgAQggKBm9uZW9mX2IGcHJvdG8z")); + "ChNScGNJbnZvY2F0aW9uLnByb3RvIl8KDlJwY01lc3NhZ2VLaW5kEikKC01l", + "c3NhZ2VLaW5kGAEgASgOMhQuUnBjTWVzc2FnZUtpbmQuS2luZCIiCgRLaW5k", + "EgoKBlJlc3VsdBAAEg4KCkludm9jYXRpb24QASJAChNScGNJbnZvY2F0aW9u", + "SGVhZGVyEgwKBE5hbWUYASABKAkSCgoCSWQYAiABKAUSDwoHTnVtQXJncxgD", + "IAEoBSJJChlScGNJbnZvY2F0aW9uUmVzdWx0SGVhZGVyEgoKAklkGAEgASgF", + "EhEKCUhhc1Jlc3VsdBgCIAEoCBINCgVFcnJvchgDIAEoCSJHCg5QcmltaXRp", + "dmVWYWx1ZRIUCgpJbnQzMlZhbHVlGAEgASgFSAASFQoLU3RyaW5nVmFsdWUY", + "AiABKAlIAEIICgZvbmVvZl9iBnByb3RvMw==")); descriptor = pbr::FileDescriptor.FromGeneratedCode(descriptorData, new pbr::FileDescriptor[] { }, new pbr::GeneratedClrTypeInfo(null, new pbr::GeneratedClrTypeInfo[] { + new pbr::GeneratedClrTypeInfo(typeof(global::RpcMessageKind), global::RpcMessageKind.Parser, new[]{ "MessageKind" }, null, new[]{ typeof(global::RpcMessageKind.Types.Kind) }, null), new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationHeader), global::RpcInvocationHeader.Parser, new[]{ "Name", "Id", "NumArgs" }, null, null, null), + new pbr::GeneratedClrTypeInfo(typeof(global::RpcInvocationResultHeader), global::RpcInvocationResultHeader.Parser, new[]{ "Id", "HasResult", "Error" }, null, null, null), new pbr::GeneratedClrTypeInfo(typeof(global::PrimitiveValue), global::PrimitiveValue.Parser, new[]{ "Int32Value", "StringValue" }, new[]{ "Oneof" }, null, null) })); } @@ -35,6 +41,135 @@ public static partial class RpcInvocationReflection { } #region Messages +public sealed partial class RpcMessageKind : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcMessageKind()); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::RpcInvocationReflection.Descriptor.MessageTypes[0]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RpcMessageKind() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RpcMessageKind(RpcMessageKind other) : this() { + messageKind_ = other.messageKind_; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RpcMessageKind Clone() { + return new RpcMessageKind(this); + } + + /// Field number for the "MessageKind" field. + public const int MessageKindFieldNumber = 1; + private global::RpcMessageKind.Types.Kind messageKind_ = 0; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public global::RpcMessageKind.Types.Kind MessageKind { + get { return messageKind_; } + set { + messageKind_ = value; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RpcMessageKind); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RpcMessageKind other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (MessageKind != other.MessageKind) return false; + return true; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (MessageKind != 0) hash ^= MessageKind.GetHashCode(); + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (MessageKind != 0) { + output.WriteRawTag(8); + output.WriteEnum((int) MessageKind); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (MessageKind != 0) { + size += 1 + pb::CodedOutputStream.ComputeEnumSize((int) MessageKind); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RpcMessageKind other) { + if (other == null) { + return; + } + if (other.MessageKind != 0) { + MessageKind = other.MessageKind; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + input.SkipLastField(); + break; + case 8: { + messageKind_ = (global::RpcMessageKind.Types.Kind) input.ReadEnum(); + break; + } + } + } + } + + #region Nested types + /// Container for nested types declared in the RpcMessageKind message type. + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static partial class Types { + public enum Kind { + [pbr::OriginalName("Result")] Result = 0, + [pbr::OriginalName("Invocation")] Invocation = 1, + } + + } + #endregion + +} + public sealed partial class RpcInvocationHeader : pb::IMessage { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcInvocationHeader()); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -42,7 +177,7 @@ public sealed partial class RpcInvocationHeader : pb::IMessage { + private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new RpcInvocationResultHeader()); + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pb::MessageParser Parser { get { return _parser; } } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public static pbr::MessageDescriptor Descriptor { + get { return global::RpcInvocationReflection.Descriptor.MessageTypes[2]; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + pbr::MessageDescriptor pb::IMessage.Descriptor { + get { return Descriptor; } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RpcInvocationResultHeader() { + OnConstruction(); + } + + partial void OnConstruction(); + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RpcInvocationResultHeader(RpcInvocationResultHeader other) : this() { + id_ = other.id_; + hasResult_ = other.hasResult_; + error_ = other.error_; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public RpcInvocationResultHeader Clone() { + return new RpcInvocationResultHeader(this); + } + + /// Field number for the "Id" field. + public const int IdFieldNumber = 1; + private int id_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int Id { + get { return id_; } + set { + id_ = value; + } + } + + /// Field number for the "HasResult" field. + public const int HasResultFieldNumber = 2; + private bool hasResult_; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool HasResult { + get { return hasResult_; } + set { + hasResult_ = value; + } + } + + /// Field number for the "Error" field. + public const int ErrorFieldNumber = 3; + private string error_ = ""; + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public string Error { + get { return error_; } + set { + error_ = pb::ProtoPreconditions.CheckNotNull(value, "value"); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override bool Equals(object other) { + return Equals(other as RpcInvocationResultHeader); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public bool Equals(RpcInvocationResultHeader other) { + if (ReferenceEquals(other, null)) { + return false; + } + if (ReferenceEquals(other, this)) { + return true; + } + if (Id != other.Id) return false; + if (HasResult != other.HasResult) return false; + if (Error != other.Error) return false; + return true; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override int GetHashCode() { + int hash = 1; + if (Id != 0) hash ^= Id.GetHashCode(); + if (HasResult != false) hash ^= HasResult.GetHashCode(); + if (Error.Length != 0) hash ^= Error.GetHashCode(); + return hash; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public override string ToString() { + return pb::JsonFormatter.ToDiagnosticString(this); + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void WriteTo(pb::CodedOutputStream output) { + if (Id != 0) { + output.WriteRawTag(8); + output.WriteInt32(Id); + } + if (HasResult != false) { + output.WriteRawTag(16); + output.WriteBool(HasResult); + } + if (Error.Length != 0) { + output.WriteRawTag(26); + output.WriteString(Error); + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public int CalculateSize() { + int size = 0; + if (Id != 0) { + size += 1 + pb::CodedOutputStream.ComputeInt32Size(Id); + } + if (HasResult != false) { + size += 1 + 1; + } + if (Error.Length != 0) { + size += 1 + pb::CodedOutputStream.ComputeStringSize(Error); + } + return size; + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(RpcInvocationResultHeader other) { + if (other == null) { + return; + } + if (other.Id != 0) { + Id = other.Id; + } + if (other.HasResult != false) { + HasResult = other.HasResult; + } + if (other.Error.Length != 0) { + Error = other.Error; + } + } + + [global::System.Diagnostics.DebuggerNonUserCodeAttribute] + public void MergeFrom(pb::CodedInputStream input) { + uint tag; + while ((tag = input.ReadTag()) != 0) { + switch(tag) { + default: + input.SkipLastField(); + break; + case 8: { + Id = input.ReadInt32(); + break; + } + case 16: { + HasResult = input.ReadBool(); + break; + } + case 26: { + Error = input.ReadString(); + break; + } + } + } + } + +} + public sealed partial class PrimitiveValue : pb::IMessage { private static readonly pb::MessageParser _parser = new pb::MessageParser(() => new PrimitiveValue()); [global::System.Diagnostics.DebuggerNonUserCodeAttribute] @@ -215,7 +523,7 @@ public sealed partial class PrimitiveValue : pb::IMessage { [global::System.Diagnostics.DebuggerNonUserCodeAttribute] public static pbr::MessageDescriptor Descriptor { - get { return global::RpcInvocationReflection.Descriptor.MessageTypes[1]; } + get { return global::RpcInvocationReflection.Descriptor.MessageTypes[3]; } } [global::System.Diagnostics.DebuggerNonUserCodeAttribute] diff --git a/samples/SocketsSample/Protobuf/RpcInvocation.proto b/samples/SocketsSample/Protobuf/RpcInvocation.proto index ddd5f331b7..7cb9ca3361 100644 --- a/samples/SocketsSample/Protobuf/RpcInvocation.proto +++ b/samples/SocketsSample/Protobuf/RpcInvocation.proto @@ -1,11 +1,22 @@ syntax = "proto3"; +message RpcMessageKind { + enum Kind { Result = 0; Invocation = 1; } + Kind MessageKind = 1; +} + message RpcInvocationHeader { string Name = 1; int32 Id = 2; int32 NumArgs = 3; } +message RpcInvocationResultHeader { + int32 Id = 1; + bool HasResult = 2; + string Error = 3; +} + message PrimitiveValue { oneof oneof_ { int32 Int32Value = 1; diff --git a/samples/SocketsSample/SocketFormatters.cs b/samples/SocketsSample/SocketFormatters.cs index 08175fd8c8..52bbe94f35 100644 --- a/samples/SocketsSample/SocketFormatters.cs +++ b/samples/SocketsSample/SocketFormatters.cs @@ -9,6 +9,7 @@ namespace SocketsSample { private IServiceProvider _serviceProvider; private Dictionary> _formatters = new Dictionary>(); + private Dictionary _invocationAdapters = new Dictionary(); public SocketFormatters(IServiceProvider serviceProvider) { @@ -37,7 +38,20 @@ namespace SocketsSample return (IFormatter)_serviceProvider.GetRequiredService(targetFormatterType); } - throw new InvalidOperationException($"No formatter register for format '{format}' and type '{typeof(T).GetType().FullName}'"); + return null; + // throw new InvalidOperationException($"No formatter register for format '{format}' and type '{typeof(T).GetType().FullName}'"); + } + + public void RegisterInvocationAdapter(string format, IInvocationAdapter adapter) + { + _invocationAdapters[format] = adapter; + } + + public IInvocationAdapter GetInvocationAdapter(string format) + { + IInvocationAdapter value; + + return _invocationAdapters.TryGetValue(format, out value) ? value : null; } } } \ No newline at end of file diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index 4695f136be..dc88274411 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -53,6 +53,9 @@ namespace SocketsSample formatters.MapFormatter("line"); formatters.MapFormatter>("json"); formatters.MapFormatter>("json"); + + formatters.AddInvocationAdapter("protobuf", new Protobuf.ProtobufInvocationAdapter()); + formatters.AddInvocationAdapter("json", new JSonInvocationAdapter(app.ApplicationServices)); }); } } diff --git a/samples/SocketsSample/project.json b/samples/SocketsSample/project.json index ecfbf94513..4b87fd5135 100644 --- a/samples/SocketsSample/project.json +++ b/samples/SocketsSample/project.json @@ -13,7 +13,8 @@ "Microsoft.AspNetCore.Server.IISIntegration": "1.1.0-*", "Microsoft.AspNetCore.Server.WebListener": "0.1.0", "Microsoft.AspNetCore.Server.Kestrel": "1.1.0-*", - "Microsoft.Extensions.Logging.Console": "1.1.0-*" + "Microsoft.Extensions.Logging.Console": "1.1.0-*", + "Google.Protobuf": "3.1.0" }, "tools": { diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index 1da06ac2de..6203fd833a 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -65,7 +65,7 @@ namespace Microsoft.AspNetCore.Sockets var formatType = (string)context.Request.Query["formatType"]; state.Connection.Metadata["formatType"] = string.IsNullOrEmpty(formatType) ? "json" : formatType; - var ws = new WebSockets(state.Connection); + var ws = new WebSockets(state.Connection, format); await DoPersistentConnection(endpoint, ws, context, state.Connection); diff --git a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs index de3b7e4cc5..fdec59ed28 100644 --- a/src/Microsoft.AspNetCore.Sockets/WebSockets.cs +++ b/src/Microsoft.AspNetCore.Sockets/WebSockets.cs @@ -11,11 +11,13 @@ namespace Microsoft.AspNetCore.Sockets { private readonly HttpChannel _channel; private readonly Connection _connection; + private readonly WebSocketMessageType _messageType; - public WebSockets(Connection connection) + public WebSockets(Connection connection, Format format) { _connection = connection; _channel = (HttpChannel)connection.Channel; + _messageType = format == Format.Binary ? WebSocketMessageType.Binary : WebSocketMessageType.Text; } public async Task ProcessRequest(HttpContext context) @@ -93,7 +95,7 @@ namespace Microsoft.AspNetCore.Sockets break; } - await ws.SendAsync(data, WebSocketMessageType.Text, endOfMessage: true, cancellationToken: CancellationToken.None); + await ws.SendAsync(data, _messageType, endOfMessage: true, cancellationToken: CancellationToken.None); } }